在PyTorch中,可以通过定义一个函数来初始化模型的权重。以下是一个示例代码:
import torch import torch.nn as nn def init_weights(m): if type(m) == nn.Linear or type(m) == nn.Conv2d: nn.init.xavier_uniform_(m.weight) nn.init.zeros_(m.bias) # 定义模型 model = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3), nn.ReLU(), nn.Linear(64*28*28, 10) ) # 初始化模型权重 model.apply(init_weights)
在上面的代码中,定义了一个init_weights
函数,该函数根据模型的类型对权重进行初始化。然后通过调用model.apply(init_weights)
来初始化模型的权重。这样就可以保证模型的权重被正确地初始化。