要在PyTorch中实现分布式训练,可以使用torch.distributed包提供的工具和函数。下面是一个简单的示例代码,演示如何在PyTorch中设置并运行分布式训练:
import torch import torch.distributed as dist import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP def setup(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' # 初始化进程组 dist.init_process_group("gloo", rank=rank, world_size=world_size) def cleanup(): dist.destroy_process_group() def train(rank, world_size): setup(rank, world_size) # 创建模型和优化器 model = MyModel() model = DDP(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 加载数据 train_dataset = MyDataset() train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=train_sampler) # 训练 for epoch in range(10): for data, target in train_loader: optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() optimizer.step() cleanup() if __name__ == '__main__': world_size = 4 mp.spawn(train, args=(world_size,), nprocs=world_size)
在这个示例中,我们首先设置了进程组,然后创建了模型、优化器和数据加载器。然后在train
函数中,我们使用torch.multiprocessing.spawn
函数来启动多个进程,每个进程运行train
函数。在train
函数中,我们将模型包装成DistributedDataParallel
对象来实现分布式训练,同时使用torch.utils.data.distributed.DistributedSampler
来分配数据。最后,我们在训练循环中进行模型训练,并在训练结束后清理进程组。