配置loss函数¶
单个loss函数¶
loss需要被包装在LossContainer中,比如交叉熵损失:
from rainbowneko.train.loss import LossContainer
loss=LossContainer(loss=CrossEntropyLoss())
注解
放在LossContainer中的loos最好继承自nn.Module
多个loss函数¶
通过LossGroup可以组合多个loss函数,并为每个loss设置权重:
from rainbowneko.train.loss import LossContainer, LossGroup
loss=LossGroup([
LossContainer(loss=CrossEntropyLoss()),
LossContainer(loss=MSELoss(), weight=0.2),
])
数据流控制¶
我们通过LossContainer的key_map可以指定用哪些变量去计算loss,比如对于半监督学习场景:
LossContainer(CrossEntropyLoss(), key_map=('pred.pred_student -> 0', 'inputs.label -> 1'))
其中pred是模型输出结果,inputs是输入的所有数据。
把模型预测输出中的pred_student作为loss的第0个输入,把输入数据中的label作为第1个输入。
LossContainer默认的key_map是('pred.pred -> 0', 'inputs.label -> 1'),即默认情况下,loss的第0个输入是模型的输出,第1个输入是输入数据中的label。