Shortcuts

配置文件详解

基本格式

小技巧

RainbowNeko Engine的配置文件有pythonyaml两种格式。推荐使用python格式,灵活度更高,简单易用,可读性也更强。

python格式

python格式的配置文件支持完整的python语法,可以在配置中进行函数和类的调用。比如下面的例子:

from functools import partial

from cfgs.py.train.classify import multi_class
from rainbowneko.data import BaseDataset
from rainbowneko.data.handler import MixUPHandler, HandlerChain
from rainbowneko.train.loss import LossContainer, SoftCELoss

num_classes = 10
multi_class.num_classes = num_classes


def make_cfg():
    return dict(
        _base_=[multi_class],

        train=dict(
            loss=LossContainer(loss=SoftCELoss()),
            metrics=None,
        ),

        data_train=dict(
            dataset1=BaseDataset(
                batch_handler=HandlerChain(
                    mixup=MixUPHandler(num_classes=num_classes)
                )
            )
        ),
    )

配置需要定义在make_cfg函数中,定义一个dict。在配置中支持完整的python语法,包括函数调用,运算等。

注解

配置函数不会直接执行,会由解析器使用ast(抽象句法树)进行解析,将所有的call(调用)操作都转换成dictlist。在解析之后,会由框架在需要的地方进行实例化。 比如:

dict(
    layer=Linear(4, 4, bias=False)
)

在解析阶段,会自动翻译成:

dict(
    layer=dict(_target_=Linear, _args_=[4,4], bias=False)
)

注解

+-*/等运算操作左右的call节点,不会被解析器转换成dictlist,会直接执行。

partial使用

在配置中,有一些模块需要使用partial定义,因为在使用该模块时需要传入额外的参数。通过以下两种方式都可以实现:

optimizer=partial(torch.optim.AdamW, weight_decay=5e-4)
# 由解析器自动转换
optimizer=torch.optim.AdamW(_partial_=True, weight_decay=5e-4)

配置函数

yaml格式

在yaml格式的配置文件中,当需要引用某个类或函数时,需要写完整的路径。比如:

_base_:
  - cfgs/yaml/train/classify/multi_class.yaml

num_classes: 10

train:
  loss:
    _target_: rainbowneko.train.loss.LossContainer
    loss:
      _target_: rainbowneko.train.loss.SoftCELoss
    metrics: null
    
data_train:
  dataset1:
    _target_: rainbowneko.train.data.BaseDataset
    batch_handler:
      _target_: rainbowneko.train.data.handler.HandlerChain
      mixup:
        _target_: rainbowneko.train.data.handler.MixUPHandler
        num_classes: ${num_classes} # 引用配置

继承

配置文件可以继承,比如在python配置文件中,通过import导入配置文件,在_base_中配置,就可以继承另一个配置文件的配置。

from cfgs.py.train.classify import multi_class

def make_cfg():
    return dict(
        _base_=[multi_class],
        ...
    )

比如这里继承multi_class这个配置文件,可以自动获取multi_class的内容并继承。

在配置中定义的参数,会覆盖父配置文件的内容。对于嵌套的配置,只会替换内层定义的参数,不会把dict或者调用整个替换。

比如父配置文件的data_train是下面的配置:

dict(
    dataset1=partial(BaseDataset, batch_size=128, loss_weight=1.0,
        source=dict(
            data_source1=IndexSource(
                data=torchvision.datasets.cifar.CIFAR10(root=r'D:\others\dataset\cifar', train=True, download=True)
            ),
        ),
        handler=HandlerChain(
            load=LoadImageHandler(),
            bucket=FixedBucket.handler, # bucket 会自带一些处理模块
            image=ImageHandler(transform=T.Compose([
                    T.RandomCrop(size=32, padding=4),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                ]),
            )
        ),
        bucket=FixedBucket(target_size=32),
    )
)

如果要修改数据集路径,则可以在子配置文件中这样定义:

dict(
    dataset1=partial(BaseDataset
        source=dict(
            data_source1=IndexSource(
                data=torchvision.datasets.cifar.CIFAR10(root='data path')
            ),
        ),
    )
)

这样只会修改CIFAR10root参数,其他参数不会被修改。dataset1中的handlerbucket参数不会被修改。

小技巧

由于解析器会把调用转换成dict,所以在继承时,可以这样修改参数:

dict(
    dataset1=dict(
        source=dict(
            data_source1=dict(
                data=dict(root='data path')
            ),
        ),
    )
)

IndexSource()等价于dict(_target_=IndexSource)

全部替换

如果要全部替换父配置文件某个节点的内容,不保留父配置文件的部分,可以添加_replace_:

dataset1=partial(BaseDataset
    _replace_=True,
    ...
)

删除

如果要删除父配置文件的一个节点,可以使用---:

dict(
    dataset1='---', # 删除dataset1节点
    dataset_new=...
)

引用其他配置

在一个节点中可以引用其他节点的配置,比如:

train=dict(
    train_epochs=100,
)
epochs='${train.train_epochs}' # 引用train节点的train_epochs

也可以使用相对路径进行引用:

model=dict(
    wrapper=DistillationWrapper(_partial_=True, _replace_=True,
        model_teacher=load_resnet(torchvision.models.resnet18()),
        model_student='${.model_teacher}', # 引用同级节点的model_teacher
        ...
    )
),