定义ModuleList
我们可以将我们需要的层放入到一个集合中,然后将这个集合作为参数传入nn.ModuleList中,但是这个子类并不可以直接使用,因为这个子类并没有实现forward函数,所以要使用还需要放在继承了nn.Module的模型中进行使用。文章来源地址https://www.toymoban.com/news/detail-726020.html
model_list = nn.ModuleList([nn.Conv2d(1, 5, 2), nn.Linear(10, 2), nn.Sigmoid()])
x = torch.randn(32, 3, 24, 24)
for model in model_list:
model_list(x)
使用ModuleList定义网络
class Net(nn.Module):
def __init__(self):
super().__init__()
self.model_list = nn.ModuleList([nn.Conv2d(1, 5, 2), nn.Linear(10, 2), nn.Sigmoid()])
def forward(self, x):
return self.model_list(x)
打印网络层结构
model = Net()
print(model)
Net(
(model_list): ModuleList(
(0): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))
(1): Linear(in_features=10, out_features=2, bias=True)
(2): Sigmoid()
)
)
文章来源:https://www.toymoban.com/news/detail-726020.html
到了这里,关于pytorch中nn.ModuleList()使用方法的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!