在Torch中定义一个神经网络模型通常需要使用nn.Module类。下面是一个示例代码,展示了如何定义一个简单的全连接神经网络模型:
import torch import torch.nn as nn class SimpleNN(nn.Module): def __init__(self): super(SimpleNN, self).__init__() self.fc1 = nn.Linear(784, 128) self.relu = nn.ReLU() self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x model = SimpleNN()
在上面的代码中,我们定义了一个名为SimpleNN的神经网络模型,它包含两个全连接层和一个ReLU激活函数。在__init__
方法中,我们定义了模型的各个层,然后在forward
方法中定义了数据在模型中的流动路径。
需要注意的是,在定义神经网络模型时,通常需要继承nn.Module类,并实现__init__
和forward
方法。__init__
方法用于初始化模型的结构,forward
方法用于定义数据在模型中的传播路径。