在PyTorch中获取预训练模型有多种方法,以下是一些常用的途径:
使用PyTorch内置的预训练模型
PyTorch的torchvision.models
模块提供了多种预训练模型,这些模型已经在大型数据集(如ImageNet)上进行了训练,可以直接使用。例如,要加载一个预训练的ResNet-50模型,可以使用以下代码:
import torchvision.models as models model = models.resnet50(pretrained=True)
在这个代码中,pretrained=True
参数指示PyTorch下载并加载预训练权重。一旦加载,你就可以直接使用这个模型进行预测,或者在特定数据集上进行微调。
使用PyTorch Hub
PyTorch Hub是一个存储和共享预训练模型、预处理代码和数据集的仓库。你可以通过PyTorch Hub轻松地访问和使用这些资源。例如,要使用预训练的ResNet-18模型,可以这样做:
import torch import torchvision.models as models model = models.resnet18(pretrained=True)
PyTorch Hub还允许你通过指定GitHub存储库的名称和分支来加载模型,提供了更多的灵活性和自定义选项。
通过上述方法,你可以轻松地获取和使用PyTorch的预训练模型,无论是通过PyTorch内置的模型库还是通过PyTorch Hub。