CrossEntropyLoss
The formula for cross entropy is

$$
H(p, q) = -\sum_{i} p_i \log q_i
$$

where \(p\) is the true label distribution and \(q\) is the network's prediction. Since \(p\) is usually a one-hot label and \(q\) is usually obtained by applying a softmax to the network's logits, the expression above simplifies to

$$
H(p, q) = -\log q_y
$$

where \(y\) is the index of the true class.
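As a quick numeric sanity check (a tiny sketch with arbitrary example numbers), the full sum and the simplified one-hot form agree:

```python
import torch

# One-hot true label p (true class is index 1) and a softmax prediction q
p = torch.tensor([0.0, 1.0, 0.0])
q = torch.softmax(torch.tensor([0.5, 2.0, 0.1]), dim=0)

full = -(p * torch.log(q)).sum()    # general cross entropy: -sum_i p_i * log(q_i)
short = -torch.log(q[1])            # one-hot case: only the true-class term survives
print(torch.allclose(full, short))  # True
```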
Given an input \(x\) of size \((N, C)\), where \(N\) is the batch size and \(C\) is the number of classes, and a target class index \(y\) of size \((N,)\), the CrossEntropyLoss is calculated as:

$$
\ell(x, y) = -\frac{1}{N}\sum_{i=1}^{N} \log\left(\frac{\exp(x_{i,y_i})}{\sum\limits_{j=1}^{C} \exp(x_{i,j})}\right)
$$
Here, \(x_{i,y_i}\) represents the logit value of the target class \(y_i\) for the \(i\)-th sample in the batch, and \(x_{i,j}\) represents the logit value for class \(j\) for the \(i\)-th sample. The sum in the denominator is taken over all classes; the exponentials and the normalization turn the logits into a proper probability distribution, and the logarithm converts the probability of the target class into a log-likelihood.
The negative sign turns this log-likelihood into a loss to be minimized, so the model is encouraged to assign a high probability to the correct class and lower probabilities to the incorrect classes.
The CrossEntropyLoss function applies a softmax to the logits internally to obtain class probabilities before computing the loss. This formulation takes both the predicted probabilities and the true class index into account, making it suitable for training models on multi-class classification tasks.
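The formula can be checked directly against PyTorch; the sketch below uses arbitrary random logits and targets and relies on the default mean reduction over the batch:

```python
import torch
import torch.nn as nn

N, C = 4, 3                       # batch size and number of classes (arbitrary)
x = torch.randn(N, C)             # raw logits from the network
y = torch.randint(0, C, (N,))     # target class indices

loss = nn.CrossEntropyLoss()(x, y)

# Manual version of the formula above; the default reduction averages over the batch
manual = -(x[torch.arange(N), y] - torch.logsumexp(x, dim=1)).mean()
print(torch.allclose(loss, manual))  # True
```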
The torch.nn.CrossEntropyLoss function in PyTorch is equivalent to the combination of torch.nn.LogSoftmax and torch.nn.NLLLoss. Let's break down each component individually:
torch.nn.LogSoftmax: This function applies the logarithm to the output of the softmax function. The softmax function normalizes the input values into a probability distribution, and the logarithm then transforms the resulting probabilities into logarithmic scale. The formula for LogSoftmax can be expressed as:
$$
\text{LogSoftmax}(x_i) = \log\left(\frac{\exp(x_i)}{\sum\limits_{j=1}^{\text{num\_classes}} \exp(x_j)}\right)
$$
Here, \(x_i\) represents the input logits for class \(i\), and \(\exp\) denotes the exponential function.
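A minimal sketch (with arbitrary example logits) showing that LogSoftmax matches taking the log of the softmax output, while computing it in a single, numerically more stable step:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.1]])
log_probs = torch.nn.LogSoftmax(dim=1)(logits)

# Same value as log applied to the softmax probabilities
manual = torch.log(torch.softmax(logits, dim=1))
print(torch.allclose(log_probs, manual))  # True
```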
torch.nn.NLLLoss (Negative Log Likelihood Loss): This loss function measures the negative log likelihood of the predicted class probability distribution compared to the true distribution. It expects the input to be in logarithmic scale, which is why we use LogSoftmax before applying NLLLoss. The formula for NLLLoss can be expressed as:
$$
\text{NLLLoss}(x, y) = -x[y]
$$
Here, \(x\) represents the input log-probabilities (the output of LogSoftmax), and \(y\) represents the target class index.
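In code, NLLLoss simply picks out the log-probability of the target class and negates it; a small sketch with a single example sample:

```python
import torch
import torch.nn as nn

log_probs = nn.LogSoftmax(dim=1)(torch.tensor([[2.0, 1.0, 0.1]]))
target = torch.tensor([0])

loss = nn.NLLLoss()(log_probs, target)

# -log_probs[i, target[i]], averaged over the batch (here just one sample)
manual = -log_probs[0, target[0]]
print(torch.allclose(loss, manual))  # True
```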
Now, let's derive how the combination of LogSoftmax and NLLLoss is equivalent to CrossEntropyLoss:
Given an input \(x\) and its corresponding target class index \(y\), the CrossEntropyLoss function can be defined as follows:

$$
\text{CrossEntropyLoss}(x, y) = \text{NLLLoss}(\text{LogSoftmax}(x), y)
$$

Substituting the definitions of NLLLoss and LogSoftmax, we have:

$$
\text{CrossEntropyLoss}(x, y) = -\log\left(\frac{\exp(x[y])}{\sum\limits_{j=1}^{\text{num\_classes}} \exp(x[j])}\right) = \log\left(\sum\limits_{j=1}^{\text{num\_classes}} \exp(x[j])\right) - x[y]
$$
The final expression above is equivalent to the CrossEntropyLoss formula. The first term, \(\log\left(\sum\limits_{j=1}^{\text{num\_classes}} \exp(x[j])\right)\), is the logarithm of the sum of exponentials of all logits, which acts as a normalization factor. The second term, \(x[y]\), is the logit value corresponding to the target class. Thus, minimizing CrossEntropyLoss encourages the model to assign a high probability to the correct class and lower probabilities to the incorrect classes.
In summary, by combining LogSoftmax and NLLLoss, the CrossEntropyLoss function provides a convenient and efficient way to compute the loss for multi-class classification problems while ensuring proper handling of logarithmic scaling and probability distributions.
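The equivalence is easy to verify numerically; a short sketch with random logits and targets:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 5)          # 4 samples, 5 classes
target = torch.randint(0, 5, (4,))  # target class indices

ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
print(torch.allclose(ce, nll))  # True
```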
torch's CrossEntropyLoss also has an optional label_smoothing parameter, denoted \(\varepsilon\). It exists because one-hot labels can make the network overconfident, which hurts its predictions. With label smoothing, each target probability \(p_k\) is replaced by

$$
p_k^{\text{smooth}} = (1 - \varepsilon)\, p_k + \frac{\varepsilon}{K}
$$

where \(K\) is the number of classes.
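A sketch of how the option behaves, assuming PyTorch 1.10+ (where label_smoothing was added) and the smoothed-target formula above; the manual computation mixes the one-hot target with a uniform distribution:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

K, eps = 5, 0.1
logits = torch.randn(4, K)
target = torch.randint(0, K, (4,))

loss = nn.CrossEntropyLoss(label_smoothing=eps)(logits, target)

# Manual version: smooth the one-hot target, then take the usual cross entropy
log_probs = torch.log_softmax(logits, dim=1)
smoothed = F.one_hot(target, K).float() * (1 - eps) + eps / K
manual = -(smoothed * log_probs).sum(dim=1).mean()
print(torch.allclose(loss, manual))  # True
```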