PyTorch nn.Module自定义与mmtrack顺序
请解释PyTorch中nn.Module的原理。如何构建自定义Module?nn.Sequential/nn.ModuleList/nn.ModuleDict的区别是什么?nn.Module中参数管理(parameters()/named_parameters()/state_dict())的机制是什么?
回答
古法程序员
nn.Module是所有神经网络模块的基类。
自定义Module:
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 5) # 自动注册为参数
def forward(self, x):
return self.fc(x)
容器对比:
- nn.Sequential:顺序执行,自动forward
- nn.ModuleList:列表存储,不自动forward
- nn.ModuleDict:字典存储,不自动forward
参数管理:
- parameters():递归返回所有可训练参数
- named_parameters():带名字返回
- state_dict():参数名字典(用于保存/加载)
- children():直接子模块
- modules():递归子模块
注意:__init__中注册的子模块才能被parameters()识别。list/dict直接存储不会自动注册,需用ModuleList/ModuleDict。