在PyTorch中,权重衰减(weight decay)是一种正则化技术,用于防止模型过拟合。它通过在损失函数中添加一个惩罚项来实现,这个惩罚项是模型参数的平方和乘以权重衰减参数。
要在PyTorch中设置权重衰减,可以使用torch.optim.SGD
优化器的weight_decay
参数。以下是一个示例:
import torch import torch.nn as nn import torch.optim as optim # 定义一个简单的模型 model = nn.Linear(10, 1) # 定义损失函数和优化器 criterion = nn.MSELoss() optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001) # 训练模型 for epoch in range(100): # 前向传播 output = model(torch.randn(1, 10)) loss = criterion(output, torch.randn(1, 1)) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step()
在这个例子中,我们使用optim.SGD
优化器训练一个简单的线性模型。我们将权重衰减参数设置为0.001
。你可以根据需要调整这个参数来控制正则化的强度。
另外,你还可以使用torch.optim.Adam
优化器,它也支持权重衰减。只需将weight_decay
参数传递给Adam
优化器即可:
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)