在PyTorch中,正则化是一种常用的技术,用于防止模型过拟合。常见的正则化方法包括L1正则化和L2正则化。
在PyTorch中,可以使用nn.Module
的add_weight()
方法为模型参数添加正则化项。例如,以下代码为模型的权重添加了L2正则化项:
import torch.nn as nn class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3) self.conv2 = nn.Conv2d(64, 128, kernel_size=3) self.fc1 = nn.Linear(128 * 25 * 25, 1024) self.fc2 = nn.Linear(1024, 512) self.fc3 = nn.Linear(512, 10) # 添加L2正则化项 for param in self.parameters(): param.requires_grad = True param.register_hook(lambda x: x * (1 - 0.001)) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2) x = x.view(-1, 128 * 25 * 25) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
在上面的代码中,我们使用了register_hook()
方法为每个参数添加了一个钩子函数,该函数将参数乘以一个因子(在这里是1 - 0.001
),从而实现了L2正则化。
除了L2正则化外,还可以使用其他正则化方法,例如L1正则化和Dropout。在PyTorch中,这些方法也可以很容易地实现。