Shortcuts

配置loss函数

单个loss函数

loss需要被包装在LossContainer中,比如交叉熵损失:

from rainbowneko.train.loss import LossContainer

loss=LossContainer(loss=CrossEntropyLoss())

注解

放在LossContainer中的loos最好继承自nn.Module

多个loss函数

通过LossGroup可以组合多个loss函数,并为每个loss设置权重:

from rainbowneko.train.loss import LossContainer, LossGroup

loss=LossGroup([
    LossContainer(loss=CrossEntropyLoss()),
    LossContainer(loss=MSELoss(), weight=0.2),
])

数据流控制

我们通过LossContainerkey_map可以指定用哪些变量去计算loss,比如对于半监督学习场景:

LossContainer(CrossEntropyLoss(), key_map=('pred.pred_student -> 0', 'inputs.label -> 1'))

其中pred是模型输出结果,inputs是输入的所有数据。 把模型预测输出中的pred_student作为loss的第0个输入,把输入数据中的label作为第1个输入。

LossContainer默认的key_map('pred.pred -> 0', 'inputs.label -> 1'),即默认情况下,loss的第0个输入是模型的输出,第1个输入是输入数据中的label