Use Mixins
In the previous section, we introduced the most straightforward way to define a GS model. In this section, we show how to compose a GS model from predefined mixins.
Tip
The code in this section is simplified for illustration purposes only and is not executable. A full implementation of 3DGS is included in the Model Zoo.
Many 3DGS studies modify specific modules of the original GS pipeline while leaving the others unchanged. This suggests that 3DGS can be decomposed into functionally distinct modules that can be composed flexibly. Using Python’s multiple inheritance (the mixin pattern), we have decomposed GS into several such components.
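To make the idea concrete, here is a minimal, self-contained sketch of mixin composition in plain Python. The classes (`ToyBaseModel`, `ToyRenderMixin`, `ToyLossMixin`) are hypothetical stand-ins, not SplatWizard classes: each mixin contributes one method, the base class holds shared state, and the composed class inherits all of them.

```python
class ToyBaseModel:
    """Hypothetical stand-in for a base model that holds shared state."""

    def __init__(self, num_points: int) -> None:
        self.num_points = num_points


class ToyRenderMixin:
    """Hypothetical mixin that contributes only a render() method."""

    def render(self) -> str:
        # Relies on state defined by the base class (self.num_points).
        return f"rendered {self.num_points} points"


class ToyLossMixin:
    """Hypothetical mixin that contributes only a loss_func() method."""

    def loss_func(self, prediction: float, target: float) -> float:
        # Plain L1 loss, used here purely for illustration.
        return abs(prediction - target)


class ToyModel(ToyRenderMixin, ToyLossMixin, ToyBaseModel):
    """Composed model: behavior comes from the mixins, state from the base."""


model = ToyModel(num_points=3)
print(model.render())              # rendered 3 points
print(model.loss_func(0.75, 0.5))  # 0.25
print([cls.__name__ for cls in ToyModel.__mro__])
# ['ToyModel', 'ToyRenderMixin', 'ToyLossMixin', 'ToyBaseModel', 'object']
```

The last line prints the method resolution order (MRO), which is the order Python searches when looking up a method on the composed class.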
When defining a new GS model, researchers can now:
- combine modules freely according to the requirements of the task;
- keep the model fully customizable by overriding any method inherited from a mixin class.
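Overriding relies on ordinary method resolution: a method defined on the model class shadows the mixin's version, and `super()` can still reach the mixin's implementation when you only want to extend it. The sketch below uses hypothetical classes (`ToyLossMixin`, `WeightedLossModel`) and a made-up regularization term; it is not the SplatWizard API.

```python
class ToyLossMixin:
    """Hypothetical mixin providing a default loss."""

    def loss_func(self, prediction: float, target: float) -> float:
        return abs(prediction - target)  # default: L1 loss


class WeightedLossModel(ToyLossMixin):
    """Overrides loss_func() but reuses the mixin's version via super()."""

    def __init__(self, reg_weight: float) -> None:
        self.reg_weight = reg_weight

    def loss_func(self, prediction: float, target: float) -> float:
        base = super().loss_func(prediction, target)  # mixin's L1 loss
        # Hypothetical extra term: penalize large predictions.
        return base + self.reg_weight * prediction ** 2


model = WeightedLossModel(reg_weight=0.5)
print(model.loss_func(1.0, 0.5))  # 0.5 + 0.5 * 1.0 = 1.0
```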
Currently, the system is organized into three core functional modules, each provided as a mixin class:
- Rendering module: `RenderMixin`
- Loss function module: `LossMixin`
- Densification module: `DensificationAndPruneMixin`
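Because each module lives in its own class, a variant that changes only one module can swap that single mixin and keep the rest. In the hypothetical sketch below, two loss mixins (`L1LossMixin` and `L2LossMixin`) are interchangeable while the render mixin is shared; none of these names come from SplatWizard.

```python
class SharedRenderMixin:
    """Hypothetical render module shared by both model variants."""

    def render(self, value: float) -> float:
        return 2.0 * value  # stand-in for a real rasterizer


class L1LossMixin:
    def loss_func(self, prediction: float, target: float) -> float:
        return abs(prediction - target)


class L2LossMixin:
    def loss_func(self, prediction: float, target: float) -> float:
        return (prediction - target) ** 2


class BaselineModel(SharedRenderMixin, L1LossMixin):
    """Uses the shared renderer with an L1 loss."""


class VariantModel(SharedRenderMixin, L2LossMixin):
    """Same rendering as BaselineModel; only the loss module differs."""


for model_cls in (BaselineModel, VariantModel):
    model = model_cls()
    prediction = model.render(1.5)  # 3.0 for both variants
    print(model_cls.__name__, model.loss_func(prediction, 1.0))
# BaselineModel 2.0
# VariantModel 4.0
```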
Let’s start from the last code snippet in the previous section.
```python
import torch
from splatwizard.modules.gaussian_model import BaseGaussianModel
from splatwizard.config import OptimizationParams, PipelineParams
from splatwizard.modules.dataclass import RenderResult, LossPack
from splatwizard.scheduler import Scheduler, task


class GSModel(BaseGaussianModel):

    def __init__(self):
        ...  # init operations

    def setup_functions(self):
        ...  # setup operations

    def training_setup(self, opt: OptimizationParams):
        ...  # define optimizer

    @task
    def update_learning_rate(self, step: int):
        ...  # define lr update operation

    def register_pre_task(
        self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
    ):
        ...  # register tasks

    def register_post_task(
        self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
    ):
        ...  # register tasks

    # Densification and pruning (including optimizer state updates)
    def replace_tensor_to_optimizer(self, tensor, name):
        ...

    def _prune_optimizer(self, mask):
        ...

    def prune_points(self, mask):
        ...

    def cat_tensors_to_optimizer(self, tensors_dict):
        ...

    def densification_postfix(
        self, new_xyz, new_features_dc, new_features_rest,
        new_opacities, new_scaling, new_rotation
    ):
        ...

    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
        ...

    def densify_and_clone(self, grads, grad_threshold, scene_extent):
        ...

    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
        ...

    @task
    def add_densification_stats(self, render_result: RenderResult):
        ...

    @task
    def densify_and_prune_task(self, opt: OptimizationParams, step: int):
        ...

    # Rendering
    def render(
        self, viewpoint_camera, bg_color, pipe, opt=None, step=0,
        scaling_modifier=1.0, override_color=None
    ):
        ...  # render function

    # Loss
    def loss_func(
        self, viewpoint_cam, render_result: RenderResult, opt
    ) -> tuple[torch.Tensor, LossPack]:
        ...  # loss function
```
With the mixin mechanism, the same model takes the following form:
```python
from splatwizard.modules.gaussian_model import BaseGaussianModel
from splatwizard.config import OptimizationParams, PipelineParams
from splatwizard.scheduler import Scheduler, task
from splatwizard.modules.loss_mixin import LossMixin
from splatwizard.modules.render_mixin import RenderMixin
from splatwizard.modules.dp_mixin import DensificationAndPruneMixin


# The mixins are listed before BaseGaussianModel, so Python's method
# resolution order (MRO) finds their methods first.
class GSModel(RenderMixin, LossMixin, DensificationAndPruneMixin, BaseGaussianModel):

    def __init__(self):
        ...  # init operations

    def setup_functions(self):
        ...  # setup operations

    def training_setup(self, opt: OptimizationParams):
        ...  # define optimizer

    @task
    def update_learning_rate(self, step: int):
        ...  # define lr update operation

    def register_pre_task(
        self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
    ):
        ...  # register tasks

    def register_post_task(
        self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
    ):
        ...  # register tasks

    # render(), loss_func() and the densification/pruning helpers are no
    # longer defined here; only the model-specific task below remains.
    @task
    def densify_and_prune_task(self, opt: OptimizationParams, step: int):
        ...
```
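The model definition is now considerably simpler: rendering, the loss function, and the densification/pruning helpers come from the mixins, so only model-specific logic, such as `densify_and_prune_task`, remains in `GSModel`.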
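The order of the base classes matters. In the mixin version, the mixins are listed before `BaseGaussianModel`, so Python's method resolution order finds their methods first. The hypothetical sketch below shows what would happen if a base class that also defines a placeholder method were listed first: its placeholder would shadow the mixin. The class names are illustrative only, and whether `BaseGaussianModel` itself defines such placeholders is an assumption here.

```python
class ToyBase:
    """Hypothetical base class with a placeholder render()."""

    def render(self) -> str:
        return "placeholder"


class ToyRenderMixin:
    """Hypothetical mixin with the real render() implementation."""

    def render(self) -> str:
        return "mixin render"


class MixinFirst(ToyRenderMixin, ToyBase):
    """MRO: MixinFirst, ToyRenderMixin, ToyBase, object."""


class BaseFirst(ToyBase, ToyRenderMixin):
    """MRO: BaseFirst, ToyBase, ToyRenderMixin, object."""


print(MixinFirst().render())  # mixin render
print(BaseFirst().render())   # placeholder
```

Listing mixins first and the concrete base last is the usual convention for exactly this reason.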