理解Logits、Softmax和softmax_cross_entropy_with_logits的区别

理解Logits、Softmax和softmax_cross_entropy_with_logits的区别

技术背景

在机器学习尤其是深度学习中,分类问题是一个常见的任务。在解决分类问题时,我们需要将模型的输出转换为概率分布,以便确定每个类别的可能性。同时,我们需要一个损失函数来衡量模型预测结果与真实标签之间的差异,从而进行模型的训练和优化。在TensorFlow中,logitssoftmaxsoftmax_cross_entropy_with_logits是与这些任务相关的重要概念和函数。

实现步骤

1. 理解Logits

Logits是神经网络模型的未归一化输出,通常是线性层的输出。这些值没有经过任何归一化处理,因此它们的和不一定为1,也不代表概率。例如,在一个简单的神经网络中,通过y = W * x + b计算得到的结果就是logits。

2. 使用Softmax函数

Softmax函数是一种常用的归一化函数,它将logits转换为概率分布。Softmax函数的作用是将输入的logits“压缩”到[0, 1]的范围内,并且使得所有输出值的和为1。具体来说,对于输入的logits向量 $z$,Softmax函数的输出为:
$$\sigma(z)j = \frac{e^{z_j}}{\sum{k=1}^{K} e^{z_k}}$$
其中,$K$ 是类别的数量,$z_j$ 是第 $j$ 个类别的logit值。

3. 计算交叉熵损失

交叉熵是一种常用的损失函数,用于衡量两个概率分布之间的差异。在分类问题中,我们通常使用交叉熵来衡量模型预测的概率分布与真实标签的概率分布之间的差异。对于单个样本,交叉熵损失的计算公式为:
$$H(p, q) = - \sum_{i=1}^{K} p_i \log(q_i)$$
其中,$p$ 是真实标签的概率分布,$q$ 是模型预测的概率分布。

4. 使用softmax_cross_entropy_with_logits函数

softmax_cross_entropy_with_logits函数将Softmax函数和交叉熵损失的计算合并在一起,并且在数值计算上进行了优化,避免了直接计算可能出现的数值不稳定问题。该函数的输入是logits和真实标签,输出是每个样本的交叉熵损失。

核心代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import tensorflow as tf
import numpy as np

# 创建示例logits
y_hat = tf.convert_to_tensor(np.array([[0.5, 1.5, 0.1], [2.2, 1.3, 1.7]]))

# 使用softmax函数
y_hat_softmax = tf.nn.softmax(y_hat)

# 创建示例真实标签(one-hot编码)
y_true = tf.convert_to_tensor(np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))

# 手动计算交叉熵损失
loss_per_instance_1 = -tf.reduce_sum(y_true * tf.log(y_hat_softmax), reduction_indices=[1])
total_loss_1 = tf.reduce_mean(loss_per_instance_1)

# 使用softmax_cross_entropy_with_logits函数计算交叉熵损失
loss_per_instance_2 = tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true)
total_loss_2 = tf.reduce_mean(loss_per_instance_2)

# 创建会话并运行计算
sess = tf.Session()
print("Softmax输出:")
print(sess.run(y_hat_softmax))
print("手动计算的每个样本的损失:")
print(sess.run(loss_per_instance_1))
print("手动计算的总损失:")
print(sess.run(total_loss_1))
print("使用softmax_cross_entropy_with_logits计算的每个样本的损失:")
print(sess.run(loss_per_instance_2))
print("使用softmax_cross_entropy_with_logits计算的总损失:")
print(sess.run(total_loss_2))
sess.close()

最佳实践

  • 在模型评估阶段,使用softmax函数将logits转换为概率分布,以便进行分类预测。
  • 在模型训练阶段,使用softmax_cross_entropy_with_logits函数计算交叉熵损失,避免手动计算可能出现的数值不稳定问题。
  • 如果标签是稀疏的(即每个样本只有一个真实类别),可以使用sparse_softmax_cross_entropy_with_logits函数,避免将标签转换为one-hot编码。

常见问题

1. softmax_cross_entropy_with_logits函数的输入应该是什么?

该函数的输入应该是未经过Softmax处理的logits和真实标签。如果输入的是经过Softmax处理的概率分布,会导致结果错误。

2. 为什么要使用softmax_cross_entropy_with_logits函数而不是手动计算?

手动计算Softmax和交叉熵损失可能会出现数值不稳定的问题,例如当logits值很大或很小时,指数运算可能会导致溢出或下溢。softmax_cross_entropy_with_logits函数在数值计算上进行了优化,避免了这些问题。

3. 在TensorFlow 2.0中如何使用这些函数?

在TensorFlow 2.0中,可以使用tf.compat.v2.nn.softmaxtf.compat.v2.nn.softmax_cross_entropy_with_logitstf.compat.v2.nn.sparse_softmax_cross_entropy_with_logits来替代TensorFlow 1.x中的相应函数。


理解Logits、Softmax和softmax_cross_entropy_with_logits的区别
https://119291.xyz/posts/2025-04-22.understanding-logits-softmax-and-softmax-cross-entropy-with-logits/
作者
ww
发布于
2025年4月22日
许可协议