Define a GS Model#

To reduce researchers’ cognitive load, we adopted a model architecture similar to the original 3DGS, in which a single Gaussian Splatting class simultaneously manages trainable parameters, optimizers, densification operations, and other components. This design simplifies integration with existing research. However, to enhance implementation flexibility, we introduced several modifications. Below, we reimplement the original 3DGS model step by step to explain the key concepts of our framework.

Tip

The code in this section is simplified for illustration purposes only and is not executable. A full implementation of 3DGS is included in the Model Zoo.

Step 1: Define GS Model Class and Configurations#

Similar to the original 3DGS, all trainable parameters are defined in the __init__ method:

 1import torch
 2from splatwizard.modules.gaussian_model import BaseGaussianModel
 3
 4
 5class GSModel(BaseGaussianModel):
 6
 7    def __init__(self, model_params: GSModelParams):
 8        super().__init__()
 9
10        self._xyz = torch.empty(0)
11        ...  # define other trainable parameters
12
13        self.setup_functions()
14
15    def setup_functions(self):
16        ...  # setup operations
17
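In the full 3DGS model, the elided parameters typically cover positions, spherical-harmonics features, scales, rotations, and opacities. A minimal sketch of what the attribute set usually looks like (the class name here is illustrative, not part of SplatWizard):

```python
import torch


class GSParamSketch:
    """Illustrative list of the trainable tensors a 3DGS model typically holds."""

    def __init__(self):
        self._xyz = torch.empty(0)            # Gaussian centers (N, 3)
        self._features_dc = torch.empty(0)    # DC spherical-harmonics coefficients
        self._features_rest = torch.empty(0)  # higher-order SH coefficients
        self._scaling = torch.empty(0)        # per-axis scales (stored in log space)
        self._rotation = torch.empty(0)       # rotations as quaternions
        self._opacity = torch.empty(0)        # opacities (stored pre-sigmoid)
```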

The model parameters and the optimization parameters also need to be properly defined, by subclassing the framework’s base configuration dataclasses:

from dataclasses import dataclass
from splatwizard.config import ModelParams, OptimizationParams


@dataclass
class GSModelParams(ModelParams):
    ...


@dataclass
class GSOptimizationParams(OptimizationParams):
    ...
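To illustrate the pattern, subclassing the base dataclass lets a model add its own options on top of the shared configuration. The field names below are assumptions for demonstration, not the real ModelParams fields:

```python
from dataclasses import dataclass


@dataclass
class ModelParamsSketch:
    """Stand-in for splatwizard.config.ModelParams (fields are illustrative)."""
    sh_degree: int = 3


@dataclass
class GSModelParamsSketch(ModelParamsSketch):
    # Model-specific options are added simply by extending the base dataclass
    white_background: bool = False
```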

Step 2: Setup Optimizer#

The process of defining the optimizer also remains consistent with the original implementation.

 1from splatwizard.modules.gaussian_model import BaseGaussianModel
 2from splatwizard.config import OptimizationParams
 3
 4
 5class GSModel(BaseGaussianModel):
 6
 7    def __init__(self):
 8        ...  # init operations
 9
10    def setup_functions(self):
11        ...  # setup operations
12
13    def training_setup(self, opt: OptimizationParams):
14        ...  # define optimizer
15
16    def update_learning_rate(self, step: int):
17        ...  # define lr update operation 
18
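As a reference for the elided bodies, the original 3DGS builds one Adam parameter group per attribute so each can carry its own learning rate, and decays the position learning rate log-linearly. A minimal self-contained sketch of that idea (the class, attribute names, and hyperparameter values are illustrative):

```python
import math

import torch


class TinyGS:
    """Minimal stand-in illustrating per-group learning rates (not the real GSModel)."""

    def __init__(self):
        self._xyz = torch.nn.Parameter(torch.zeros(10, 3))
        self._opacity = torch.nn.Parameter(torch.zeros(10, 1))

    def training_setup(self, lr_init=1.6e-4, lr_final=1.6e-6,
                       max_steps=30_000, opacity_lr=0.05):
        self.lr_init, self.lr_final, self.max_steps = lr_init, lr_final, max_steps
        # One parameter group per attribute; "name" tags the group for later lookup
        groups = [
            {"params": [self._xyz], "lr": lr_init, "name": "xyz"},
            {"params": [self._opacity], "lr": opacity_lr, "name": "opacity"},
        ]
        self.optimizer = torch.optim.Adam(groups, lr=0.0, eps=1e-15)

    def update_learning_rate(self, step):
        # Log-linear interpolation between lr_init and lr_final, as in the original 3DGS
        t = min(step / self.max_steps, 1.0)
        lr = math.exp((1 - t) * math.log(self.lr_init) + t * math.log(self.lr_final))
        for group in self.optimizer.param_groups:
            if group["name"] == "xyz":
                group["lr"] = lr
        return lr
```

Extra keys such as "name" are preserved inside optimizer param groups, which is what makes per-attribute lookup during pruning and densification straightforward.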

Step 3: Define Densification and Prune Method#

Densification and pruning remain consistent with the original version. One minor difference is that we use a dedicated dataclass, RenderResult, as the input to the statistics method. By adopting a dedicated dataclass for intermediate data transfer, we avoid functions with excessive parameters, minimizing potential coding mistakes.

 1from splatwizard.modules.gaussian_model import BaseGaussianModel
 2from splatwizard.config import OptimizationParams
 3from splatwizard.modules.dataclass import RenderResult
 4
 5
 6class GSModel(BaseGaussianModel):
 7
 8    def __init__(self):
 9        ...  # init operations
10
11    def setup_functions(self):
12        ...  # setup operations
13
14    def training_setup(self, opt: OptimizationParams):
15        ...  # define optimizer
16
17    def update_learning_rate(self, step: int):
18        ...  # define lr update operation   
19
20    def replace_tensor_to_optimizer(self, tensor, name):
21        ...
22
23    def _prune_optimizer(self, mask):
24        ...
25
26    def prune_points(self, mask):
27        ...
28
29    def cat_tensors_to_optimizer(self, tensors_dict):
30        ...
31
32    def densification_postfix(
33            self, new_xyz, new_features_dc, new_features_rest,
34            new_opacities, new_scaling, new_rotation
35    ):
36        ...
37
38    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
39        ...
40
41    def densify_and_clone(self, grads, grad_threshold, scene_extent):
42        ...
43
44    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
45        ...
46
47    def add_densification_stats(self, render_result: RenderResult):
48        ... 
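To illustrate the idea behind RenderResult (the real class lives in splatwizard.modules.dataclass; the fields below are assumptions based on what the original 3DGS passes between render and densification):

```python
from dataclasses import dataclass

import torch


@dataclass
class RenderResultSketch:
    """Illustrative stand-in for RenderResult -- the fields are assumptions."""
    render: torch.Tensor             # rendered image (C, H, W)
    viewspace_points: torch.Tensor   # screen-space means; their grads drive densification
    visibility_filter: torch.Tensor  # boolean mask of Gaussians visible in this view
    radii: torch.Tensor              # projected radii in pixels
```

Passing one such object into add_densification_stats replaces several loose arguments, which is the mistake-proofing benefit described above.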

Step 4: Define Render Function and Loss Function#

Unlike the original implementation, we integrate both the rendering function and the loss function into the model itself. This design aims to standardize the training pipeline across different models as much as possible, thereby simplifying the implementation of the evaluation module.

 1import torch
 2from splatwizard.modules.gaussian_model import BaseGaussianModel
 3from splatwizard.config import OptimizationParams
 4from splatwizard.modules.dataclass import RenderResult, LossPack
 5
 6
 7class GSModel(BaseGaussianModel):
 8
 9    def __init__(self):
10        ...  # init operations
11
12    def setup_functions(self):
13        ...  # setup operations
14
15    def training_setup(self, opt: OptimizationParams):
16        ...  # define optimizer
17
18    def update_learning_rate(self, step: int):
19        ...  # define lr update operation
20
21    def replace_tensor_to_optimizer(self, tensor, name):
22        ...
23
24    def _prune_optimizer(self, mask):
25        ...
26
27    def prune_points(self, mask):
28        ...
29
30    def cat_tensors_to_optimizer(self, tensors_dict):
31        ...
32
33    def densification_postfix(
34            self, new_xyz, new_features_dc, new_features_rest,
35            new_opacities, new_scaling, new_rotation
36    ):
37        ...
38
39    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
40        ...
41
42    def densify_and_clone(self, grads, grad_threshold, scene_extent):
43        ...
44
45    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
46        ...
47
48    def add_densification_stats(self, render_result: RenderResult):
49        ...
50
51    def render(self, viewpoint_camera, bg_color, pipe, opt=None, step=0, scaling_modifier=1.0, override_color=None):
52        ...  # render function
53
54    def loss_func(self, viewpoint_cam, render_result: RenderResult, opt) -> tuple[torch.Tensor, LossPack]:
55        ...  # loss function
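The loss computation itself follows the original 3DGS photometric loss, a weighted sum of an L1 term and a D-SSIM term. A minimal sketch of that formula (the SSIM implementation is injected as a callable, since SplatWizard’s own utilities are not shown here):

```python
import torch


def l1_dssim_loss(img, gt, lambda_dssim=0.2, ssim_fn=None):
    """3DGS photometric loss: (1 - lambda) * L1 + lambda * (1 - SSIM).

    ssim_fn is a user-supplied SSIM function; when omitted,
    only the pure L1 term is returned.
    """
    l1 = (img - gt).abs().mean()
    if ssim_fn is None:
        return l1
    return (1.0 - lambda_dssim) * l1 + lambda_dssim * (1.0 - ssim_fn(img, gt))
```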

Step 5: Register Tasks#

From this point onward, our design philosophy diverges significantly from the original implementation. 3DGS requires execution of diverse operations at different training stages, such as collecting statistics, updating learning rates, densification, etc. For compression-related models, these operations become even more varied with increasingly complex scheduling requirements. Manually managing these tasks within training scripts proves both cumbersome and error-prone.

To address this, we designed a dedicated Scheduler that consolidates all task operations within a unified framework. In practice, a GS model is required to define two core methods:

  1. register_pre_task: Registers operations executed before rendering (e.g., learning rate updates)

  2. register_post_task: Registers operations executed after rendering (e.g., densification)

This architecture consolidates operations that were previously dispersed throughout the training script. Taking learning rate adjustment as an example, we register it in register_pre_task, while operations like densification are handled through register_post_task, establishing clear execution boundaries and logical flow.

The scheduler ensures operations execute in their designated phases while maintaining compatibility with existing optimization steps. This proves particularly valuable when extending the framework to support novel compression techniques requiring additional processing stages.

 1import torch
 2from splatwizard.modules.gaussian_model import BaseGaussianModel
 3from splatwizard.config import OptimizationParams, PipelineParams
 4from splatwizard.modules.dataclass import RenderResult, LossPack
 5from splatwizard.scheduler import Scheduler, task
 6
 7
 8class GSModel(BaseGaussianModel):
 9
10    def __init__(self):
11        ...  # init operations
12
13    def setup_functions(self):
14        ...  # setup operations
15
16    def training_setup(self, opt: OptimizationParams):
17        ...  # define optimizer
18
19    @task
20    def update_learning_rate(self, step: int):
21        ...  # define lr update operation
22
23    def register_pre_task(
24            self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
25    ):
26        scheduler.register_task(
27            range(opt.iterations),
28            task=self.update_learning_rate
29        )
30        ...  # other tasks
31
32    def register_post_task(
33            self, scheduler: Scheduler, ppl: PipelineParams, opt: OptimizationParams
34    ):
35        scheduler.register_task(
36            range(opt.densify_until_iter),
37            task=self.add_densification_stats)
38
39        scheduler.register_task(
40            range(opt.densify_from_iter, opt.densify_until_iter, opt.densification_interval),
41            task=self.densify_and_prune_task
42        )
43        ...  # other tasks
44
45    def replace_tensor_to_optimizer(self, tensor, name):
46        ...
47
48    def _prune_optimizer(self, mask):
49        ...
50
51    def prune_points(self, mask):
52        ...
53
54    def cat_tensors_to_optimizer(self, tensors_dict):
55        ...
56
57    def densification_postfix(
58            self, new_xyz, new_features_dc, new_features_rest,
59            new_opacities, new_scaling, new_rotation
60    ):
61        ...
62
63    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
64        ...
65
66    def densify_and_clone(self, grads, grad_threshold, scene_extent):
67        ...
68
69    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
70        ...
71
72    @task
73    def add_densification_stats(self, render_result: RenderResult):
74        ...
75
76    @task
77    def densify_and_prune_task(self, opt: OptimizationParams, step: int):
78        size_threshold = 20 if step > opt.opacity_reset_interval else None
79        self.densify_and_prune(
80            opt.densify_grad_threshold, 0.005,
81            self.spatial_lr_scale, size_threshold
82        )
83
84    def render(self, viewpoint_camera, bg_color, pipe, opt=None, step=0, scaling_modifier=1.0, override_color=None):
85        ...  # render function
86
87    def loss_func(self, viewpoint_cam, render_result: RenderResult, opt) -> tuple[torch.Tensor, LossPack]:
88        ...  # loss function

The scheduler provides task functions with three fixed parameters: render results, optimization parameters, and the current training step. To simplify task function development, we introduce the @task decorator, which automatically supplies appropriate arguments during scheduling based on the function’s parameter type annotations.
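The annotation-driven dispatch can be sketched as follows. This is a conceptual illustration of the mechanism, not SplatWizard’s actual scheduler code, and the two stand-in classes substitute for the real ones:

```python
import inspect


class RenderResult:          # stand-in for splatwizard's RenderResult
    pass


class OptimizationParams:    # stand-in for splatwizard's OptimizationParams
    pass


def call_task(fn, render_result, opt, step):
    """Bind each parameter by its type annotation, then call the task."""
    pool = [(RenderResult, render_result), (OptimizationParams, opt), (int, step)]
    args = []
    for param in inspect.signature(fn).parameters.values():
        for typ, value in pool:
            # issubclass makes derivative classes match their base class slot
            if isinstance(param.annotation, type) and issubclass(param.annotation, typ):
                args.append(value)
                break
        else:
            raise TypeError(f"unsupported annotation for {param.name!r}")
    return fn(*args)
```

Because matching uses issubclass, derivative classes such as GSOptimizationParams are accepted wherever the base class is expected, and each task declares only the arguments it actually needs.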

Tip

Task function parameters must be exclusively selected from:

  • RenderResult or its derivative class

  • OptimizationParams or its derivative class (here, GSOptimizationParams defined in Step 1)

  • int (representing the training step count)

See Line 20, Line 73 and Line 77 in the code above for actual use cases.

Summary#

Congratulations! You have finished your first GS model in SplatWizard. Next, we will simplify the process further with a more streamlined approach.