from torch.nn.modules.module import _addindent import torch import numpy as np
deftorch_summarize(model, show_weights=True, show_parameters=True): """Summarizes torch model by showing trainable parameters and weights.""" tmpstr = model.__class__.__name__ + ' (\n' for key, module in model._modules.items(): # if it contains layers let call it recursively to get params and weights iftype(module) in [ torch.nn.modules.container.Container, torch.nn.modules.container.Sequential ]: modstr = torch_summarize(module) else: modstr = module.__repr__() modstr = _addindent(modstr, 2)
params = sum([np.prod(p.size()) for p in module.parameters()]) weights = tuple([tuple(p.size()) for p in module.parameters()])