Pytorch中Cross Entropy的用法
本文假设读者已经对熵(Entropy)的概念有所了解。
1. 交叉熵
本文只对交叉熵做简要介绍。详细步骤请参阅各大机器学习书籍、博客等。
交叉熵(Cross
Entropy)是深度学习中常用的损失函数。交叉熵是KL散度在特定情况下的变种。KL散度公式如下,
深度学习中,原概率分布
2. Pytorch中的nn.CrossEntropyLoss()
Pytorch中的nn.CrossEntropyLoss()一般输入两个参数:output和ground truth,output一般为模型的输出,而ground truth有多种表现形式。下面举几个例子。
2.1 one-hot编码型
交叉熵一般用于分类任务,对于多分类任务,可以进行one-hot编码。我们直接来看代码示例。
1 | import torch |
输出为
1 | tensor(1.2914, grad_fn=<DivBackward1>) |
这里直接将one-hot编码之后的
2.2 索引输入型
除了直接输入one-hot编码形式的GT,我们还可以直接输入one-hot编码后
1 | import torch |
同样,输出结果为 1
tensor(1.2914, grad_fn=<DivBackward1>)
对比两段代码,数组
3. 应用
写这个博客的原因是学习MoCo时发现计算CrossEntropyLoss时直接将所有label置零,感到有些不解,于是查阅资料发现Pytorch中的CrossEntropyLoss有多种输入方法,于此记录。
MoCo中可以直接将label置零的原因是在伪代码中有这样一行
1 | logits = cat([l_pos, 1_neg], dim=1) |
也就是说MoCo始终将正样本放在第一个位置,因此one-hot编码时值为1的地方都是首位,直接将label全部置零,用索引的形式输入CrossEntropyLoss即可。