在PyTorch中打印模型摘要的方法

在PyTorch中打印模型摘要的方法

技术背景

在深度学习模型开发过程中,了解模型的结构和参数情况是非常重要的。在Keras中,可以使用model.summary()方法方便地打印出模型的详细摘要信息。然而,PyTorch并没有直接提供类似的功能。不过,有多种方法可以在PyTorch中实现类似的模型摘要打印效果。

实现步骤

1. 直接打印模型

这是最简单的方法,直接使用print(model)语句即可。

1
2
3
from torchvision import models
model = models.vgg16()
print(model)

这种方法会输出模型的各个层及其基本参数设置,但不会给出每层的输出形状和参数数量的详细统计。

2. 使用torchsummary(已不推荐)

早期可以使用torchsummary库来获取类似Keras的模型摘要信息。首先需要安装该库:

1
pip install torchsummary

然后使用以下代码:

1
2
3
4
5
from torchvision import models
from torchsummary import summary

vgg = models.vgg16()
summary(vgg, (3, 224, 224))

不过,torchsummary现在已经不再维护,建议使用torchinfo

3. 使用torchinfo

torchinfo是目前推荐的获取模型摘要信息的库。首先安装该库:

1
pip install torchinfo

然后使用以下代码:

1
2
3
4
5
6
from torchinfo import summary
import torchvision.models as models

model = models.alexnet()
batch_size = 16
summary(model, input_size=(batch_size, 3, 224, 224))

torchinfo会输出每层的详细信息,包括层类型、输出形状、参数数量等,还会给出总的参数数量、可训练参数数量、非可训练参数数量等统计信息。

4. 自定义函数获取模型信息

可以自定义一个函数来获取模型的参数和权重信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from torch.nn.modules.module import _addindent
import torch
import numpy as np

def torch_summarize(model, show_weights=True, show_parameters=True):
"""Summarizes torch model by showing trainable parameters and weights."""
tmpstr = model.__class__.__name__ + ' (\n'
for key, module in model._modules.items():
# if it contains layers let call it recursively to get params and weights
if type(module) in [
torch.nn.modules.container.Container,
torch.nn.modules.container.Sequential
]:
modstr = torch_summarize(module)
else:
modstr = module.__repr__()
modstr = _addindent(modstr, 2)

params = sum([np.prod(p.size()) for p in module.parameters()])
weights = tuple([tuple(p.size()) for p in module.parameters()])

tmpstr += ' (' + key + '): ' + modstr
if show_weights:
tmpstr += ', weights={}'.format(weights)
if show_parameters:
tmpstr += ', parameters={}'.format(params)
tmpstr += '\n'

tmpstr = tmpstr + ')'
return tmpstr

# Test
import torchvision.models as models
model = models.alexnet()
print(torch_summarize(model))

核心代码

以下是使用torchinfo打印模型摘要的核心代码:

1
2
3
4
5
6
7
8
9
10
11
from torchinfo import summary
import torchvision.models as models

# 加载模型
model = models.resnet18()

# 定义输入尺寸
input_size = (1, 3, 224, 224)

# 打印模型摘要
summary(model, input_size=input_size)

最佳实践

  • 对于简单的模型查看,可以直接使用print(model)
  • 如果需要详细的模型信息,包括每层的输出形状和参数数量,推荐使用torchinfo库。
  • 在使用torchinfo时,确保正确指定输入尺寸,因为PyTorch的动态计算图特性,模型的输出形状依赖于输入。

常见问题

1. torchsummarytorchinfo有什么区别?

torchsummary是早期用于获取模型摘要信息的库,但现在已经不再维护。torchinfo是其替代库,功能更稳定,推荐使用。

2. 如何确定模型的输入尺寸?

输入尺寸取决于模型的设计和任务需求。对于图像分类任务,常见的输入尺寸如(3, 224, 224)表示RGB图像的通道数为3,高度和宽度均为224像素。如果是自定义模型,需要根据模型的输入层设置来确定。

3. 使用torchinfo时出现错误怎么办?

确保已经正确安装torchinfo库,并且输入尺寸的格式正确。如果问题仍然存在,可以查看错误信息,检查模型定义是否有问题。


在PyTorch中打印模型摘要的方法
https://119291.xyz/posts/2025-04-22.print-model-summary-in-pytorch/
作者
ww
发布于
2025年4月22日
许可协议