Shortcuts

Configuration File Explanation

Basic Format

Tip

The configuration files for the RainbowNeko Engine support both Python and YAML formats. It is recommended to use the Python format due to its higher flexibility, simplicity, ease of use, and better readability.

Python Format

Configuration files in Python format support full Python syntax, allowing for function and class calls within the configuration. For example:

from functools import partial

from cfgs.py.train.classify import multi_class
from rainbowneko.data import BaseDataset
from rainbowneko.data.handler import MixUPHandler, HandlerChain
from rainbowneko.train.loss import LossContainer, SoftCELoss

num_classes = 10
multi_class.num_classes = num_classes


def make_cfg():
    return dict(
        _base_=[multi_class],

        train=dict(
            loss=LossContainer(loss=SoftCELoss()),
            metrics=None,
        ),

        data_train=dict(
            dataset1=BaseDataset(
                batch_handler=HandlerChain(
                    mixup=MixUPHandler(num_classes=num_classes)
                )
            )
        ),
    )

The configuration should be defined within a make_cfg function that returns a dict. Full Python syntax is supported in the configuration, including function calls and operations.

Note

The configuration function is not executed directly. Instead, it is parsed by an interpreter using AST (Abstract Syntax Tree), which converts all call operations into dict and list. After parsing, the framework instantiates them where necessary.

For example:

dict(
    layer=Linear(4, 4, bias=False)
)

During parsing, it will be automatically translated into:

dict(
    layer=dict(_target_=Linear, _args_=[4, 4], bias=False)
)

Note

Operations such as +-*/ on both sides of a call node will not be converted into dict or list by the parser; they will be executed directly.

Using partial

Some modules in the configuration may require additional parameters during use. These can be defined using partial, which can be implemented in two ways:

optimizer = partial(torch.optim.AdamW, weight_decay=5e-4)
# Automatically converted by the parser
optimizer = torch.optim.AdamW(_partial_=True, weight_decay=5e-4)

Configuration Function

YAML Format

In YAML format configuration files, when referencing a class or function, you must provide its full path. For example:

_base_:
  - cfgs/yaml/train/classify/multi_class.yaml

num_classes: 10

train:
  loss:
    _target_: rainbowneko.train.loss.LossContainer
    loss:
      _target_: rainbowneko.train.loss.SoftCELoss
    metrics: null
    
data_train:
  dataset1:
    _target_: rainbowneko.train.data.BaseDataset
    batch_handler:
      _target_: rainbowneko.train.data.handler.HandlerChain
      mixup:
        _target_: rainbowneko.train.data.handler.MixUPHandler
        num_classes: ${num_classes} # Reference to another configuration parameter

Inheritance

Configuration files can inherit from others. For example, in Python configuration files, you can inherit another file’s settings by importing it and specifying it in _base_:

from cfgs.py.train.classify import multi_class

dict(
    _base_=[multi_class],
    ...
)

Here, inheriting the multi_class configuration file automatically includes its content.

Parameters defined in the current configuration override those from the parent file. For nested configurations, only inner parameters are replaced; the entire dict or call is not replaced.

For instance, if the parent file’s data_train has this structure:

dict(
    dataset1=partial(BaseDataset, batch_size=128, loss_weight=1.0,
        source=dict(
            data_source1=IndexSource(
                data=torchvision.datasets.cifar.CIFAR10(root=r'D:\others\dataset\cifar', train=True, download=True)
            ),
        ),
        handler=HandlerChain(
            load=LoadImageHandler(),
            bucket=FixedBucket.handler,
            image=ImageHandler(transform=T.Compose([
                    T.RandomCrop(size=32, padding=4),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                ]),
            )
        ),
        bucket=FixedBucket(target_size=32),
    )
)

You can modify just the dataset path in a child file like this:

dict(
    dataset1=partial(BaseDataset,
        source=dict(
            data_source1=IndexSource(
                data=torchvision.datasets.cifar.CIFAR10(root='data path')
            ),
        ),
    )
)

This only modifies the root parameter of CIFAR10, leaving other parameters unchanged. The handler and bucket parameters within dataset1 remain unaltered.

Tip

Since the parser converts calls into dictionaries during inheritance, you can modify parameters like this:

dict(
    dataset1=dict(
        source=dict(
            data_source1=dict(
                data=dict(root='data path')
            ),
        ),
    )
)

Here, calling IndexSource() is equivalent to writing dict(_target_=IndexSource).

Complete Replacement

To completely replace a parent file’s node without retaining any part of it:

dataset1 = partial(BaseDataset,
    _replace_=True,
    ...
)

Deletion

To delete a node from a parent file:

dict(
    dataset1='---', # Deletes the dataset1 node
    dataset_new=...
)

Referencing Other Configurations

A node can reference another node’s parameter. For example:

train=dict(
    train_epochs=100,
)
epochs='${train.train_epochs}' # Reference to train's train_epochs parameter

You can also use relative paths for references:

model=dict(
    wrapper=DistillationWrapper(_partial_=True, _replace_=True,
        model_teacher=load_resnet(torchvision.models.resnet18()),
        model_student='${.model_teacher}', # Reference to sibling node model_teacher
        ...
    )
),