PyTorch 中实现早停(Early Stopping)的方法有很多种,其中一种常见的方法是使用验证集(validation set)来监控模型的性能,当验证集上的性能不再提高时,停止训练。下面是一个简单的实现方法:
- 定义一个验证集和数据加载器(DataLoader):
from torch.utils.data import DataLoader, random_split # 假设 train_dataset 是你的训练数据集,val_dataset 是你的验证数据集 train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
- 定义一个函数来计算模型在验证集上的性能指标(例如准确率):
def compute_metrics(model, val_loader): model.eval() correct = 0 total = 0 with torch.no_grad(): for images, labels in val_loader: outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = 100 * correct / total return {'accuracy': accuracy}
- 定义一个早停类,用于监控验证集上的性能指标,并在性能不再提高时停止训练:
class EarlyStopping:
def __init__(self, patience=10, verbose=False):
self.patience = patience
self.verbose = verbose
self.best_score = float('-inf')
self.wait = 0
def step(self, model, val_loader):
score = compute_metrics(model, val_loader)['accuracy']
if score > self.best_score:
self.best_score = score
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
if self.verbose:
print(f'Early stopping at epoch {epoch}, best accuracy: {self.best_score:.2f}')
return True
return False
- 在训练循环中使用早停类:
early_stopping = EarlyStopping(patience=10, verbose=True) for epoch in range(num_epochs): model.train() # 训练代码 model.eval() should_stop = early_stopping.step(model, val_loader) if should_stop: break
在这个例子中,我们定义了一个 EarlyStopping
类,它会在验证集上的性能指标不再提高时停止训练。我们在训练循环中使用这个类,并在每个 epoch 结束时调用它的 step
方法。如果 step
方法返回 True
,则表示应该停止训练。