Use Mixins

In the previous section, we introduced the most straightforward way to define a GS model. In this section, we show how to compose a GS model from pre-defined mixins.

Tip

The code in this section is simplified for illustration purposes and is not executable. A full implementation of 3DGS is included in the Model Zoo.

In many 3DGS studies, researchers modify specific modules of the original GS pipeline while leaving the others unchanged. This suggests that 3DGS can be functionally decomposed into distinct modules, enabling flexible feature composition. Leveraging Python’s multiple-inheritance (mixin) mechanism, we have decomposed GS into multiple components.
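The mixin mechanism itself is plain Python. The toy classes below (illustrative names only, not part of splatwizard) sketch how a composed class picks up methods from several mixins, with Python's method resolution order (MRO) deciding which definition wins:

```python
# Toy illustration of Python's mixin mechanism: each mixin contributes
# one capability, and a concrete class composes them via inheritance.

class Base:
    def describe(self):
        return "base"

class GreetMixin:
    def greet(self):
        # self.describe() is resolved through the MRO of the *instance*,
        # so a sibling mixin's override takes effect here.
        return f"hello from {self.describe()}"

class ShoutMixin:
    # Overrides describe() from Base, because mixins listed before the
    # base class precede it in the method resolution order.
    def describe(self):
        return "BASE"

class Composed(GreetMixin, ShoutMixin, Base):
    pass

print(Composed().greet())                        # hello from BASE
print([c.__name__ for c in Composed.__mro__])
```

The MRO is `Composed → GreetMixin → ShoutMixin → Base → object`, which is why `ShoutMixin.describe` shadows `Base.describe` even though `GreetMixin.greet` is the caller.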

When defining new GS models, researchers can now:

  1. Freely combine different modules based on task requirements

  2. Maintain full model customizability by arbitrarily overriding methods from mixin classes
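Point 2 relies on ordinary attribute lookup: a method defined on the model class shadows the mixin's version, and `super()` can still delegate to it. A minimal sketch with a toy stand-in class (not the real splatwizard `LossMixin`):

```python
# Toy sketch: a subclass overrides a mixin-provided method while still
# reusing the mixin's implementation through super().

class LossMixinToy:
    def loss(self, pred, target):
        return abs(pred - target)  # stand-in "L1" loss

class Model(LossMixinToy):
    def loss(self, pred, target):
        # Override: rescale the mixin's loss, delegating via super()
        return 2 * super().loss(pred, target)

print(Model().loss(3.0, 1.0))  # 4.0
```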

Currently, we primarily categorize the system into three core functional modules:

  1. Rendering Module RenderMixin

  2. Loss Function Module LossMixin

  3. Densification Module DensificationAndPruneMixin

Let’s start from the last code snippet in the previous section.

```python
import torch
from splatwizard.modules.gaussian_model import BaseGaussianModel
from splatwizard.config import OptimizationParams, PipelineParams
from splatwizard.modules.dataclass import RenderResult, LossPack
from splatwizard.scheduler import Scheduler, task


class GSModel(BaseGaussianModel):

    def __init__(self):
        ...  # init operations

    def setup_functions(self):
        ...  # setup operations

    def training_setup(self, opt: OptimizationParams):
        ...  # define optimizer

    @task
    def update_learning_rate(self, step: int):
        ...  # define lr update operation

    def register_pre_task(
            self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
    ):
        ...  # register tasks

    def register_post_task(
            self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
    ):
        ...  # register tasks

    def replace_tensor_to_optimizer(self, tensor, name):
        ...

    def _prune_optimizer(self, mask):
        ...

    def prune_points(self, mask):
        ...

    def cat_tensors_to_optimizer(self, tensors_dict):
        ...

    def densification_postfix(
            self, new_xyz, new_features_dc, new_features_rest,
            new_opacities, new_scaling, new_rotation
    ):
        ...

    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
        ...

    def densify_and_clone(self, grads, grad_threshold, scene_extent):
        ...

    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
        ...

    @task
    def add_densification_stats(self, render_result: RenderResult):
        ...

    @task
    def densify_and_prune_task(self, opt: OptimizationParams, step: int):
        ...

    def render(self, viewpoint_camera, bg_color, pipe, opt=None, step=0, scaling_modifier=1.0, override_color=None):
        ...  # render function

    def loss_func(self, viewpoint_cam, render_result: RenderResult, opt) -> (torch.Tensor, LossPack):
        ...  # loss function
```

When implemented with the mixin mechanism, the code structure transforms into the following form:

```python
from splatwizard.modules.gaussian_model import BaseGaussianModel
from splatwizard.config import OptimizationParams, PipelineParams
from splatwizard.scheduler import Scheduler, task
from splatwizard.modules.loss_mixin import LossMixin
from splatwizard.modules.render_mixin import RenderMixin
from splatwizard.modules.dp_mixin import DensificationAndPruneMixin


class GSModel(RenderMixin, LossMixin, DensificationAndPruneMixin, BaseGaussianModel):

    def __init__(self):
        ...  # init operations

    def setup_functions(self):
        ...  # setup operations

    def training_setup(self, opt: OptimizationParams):
        ...  # define optimizer

    @task
    def update_learning_rate(self, step: int):
        ...  # define lr update operation

    def register_pre_task(
            self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
    ):
        ...  # register tasks

    def register_post_task(
            self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
    ):
        ...  # register tasks

    @task
    def densify_and_prune_task(self, opt: OptimizationParams, step: int):
        ...
```

Thus, our model implementation becomes significantly simpler.
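As a toy illustration of this simplification (stand-in classes, not the real splatwizard mixins), a new model can override just one method while inheriting everything else unchanged:

```python
# Hypothetical sketch with toy stand-ins echoing the real mixin names:
# only the loss is customized; rendering and densification are inherited.

class RenderMixinToy:
    def render(self, x):
        return x + 1  # stand-in for the real rasterization call

class LossMixinToy:
    def loss_func(self, pred, target):
        return abs(pred - target)

class DensifyMixinToy:
    def densify_and_prune(self, n):
        return n * 2  # stand-in for point growth

class MyGSModel(RenderMixinToy, LossMixinToy, DensifyMixinToy):
    # Only the loss is customized; super() reaches LossMixinToy via the MRO.
    def loss_func(self, pred, target):
        return super().loss_func(pred, target) ** 2

m = MyGSModel()
print(m.render(1), m.loss_func(3.0, 1.0), m.densify_and_prune(4))
```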