Loading Pretrained Models Elegantly
Motivation
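After building a model with the MODELS registry, you may need to load pretrained parameters into it. PyTorch's built-in model.load_state_dict(state_dict) function covers most needs. In the following scenarios, however, the built-in loading function falls short: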
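A model on a single machine needs to load a checkpoint trained in a distributed environment (DistributedDataParallel); loading fails because of the module. prefix on the parameter names.
A new classification model changes the number of classes in its head module; loading fails because the fc layer parameters have mismatched shapes.
The backbone module of a new classification model should load only the backbone module of a self-supervised model; this usually requires manual work.
The pretrained parameters are stored on a different file system; they usually have to be downloaded manually.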
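To make the first and third scenarios concrete, here is a minimal sketch of the manual key handling they usually involve. It is not essmc2 code: the helper names strip_prefix and select_sub_level are hypothetical, and the checkpoint is mocked with placeholder string values so the example runs without PyTorch.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy with `prefix` removed from every key that starts with it."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def select_sub_level(state_dict, sub_level):
    """Keep only the entries under `sub_level.` and drop that prefix from their keys."""
    prefix = sub_level + "."
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


# Keys as a DataParallel-wrapped classifier would save them.
# The values are placeholders standing in for tensors (assumption for this sketch).
saved = OrderedDict([
    ("module.backbone.conv1.weight", "w0"),
    ("module.head.fc.weight", "w1"),
    ("module.head.fc.bias", "b1"),
])

plain = strip_prefix(saved)
backbone_only = select_sub_level(plain, "backbone")

print(list(plain))          # ['backbone.conv1.weight', 'head.fc.weight', 'head.fc.bias']
print(list(backbone_only))  # ['conv1.weight']
```

Both helpers only rewrite keys, so they would work the same way on a real state_dict whose values are tensors.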
This article introduces a mechanism that loads pretrained models elegantly and addresses the problems above.
Solution
When defining the model structure, specify how the pretrained model should be loaded through the pretrain parameter.
import torch
from torch.nn.parallel import DataParallel
from essmc2.models import MODELS, BACKBONES

# Mock a model and save its parameters.
# Wrapping it in DataParallel prefixes every state_dict key with "module.".
model = MODELS.build(dict(
    type="Classifier",
    backbone=dict(type='ResNet'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(type='ClassifierHead', dim=2048, num_classes=100),
))
model = DataParallel(model)
torch.save({
    "state_dict": model.state_dict()
}, "pretrain.pth")

# Create a new model with a different number of classes and load the parameters
new_classifier = MODELS.build(dict(
    type="Classifier",
    backbone=dict(type="ResNet"),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(type='ClassifierHead', dim=2048, num_classes=200),
    pretrain="pretrain.pth"
))
# 2021-12-02 17:20:12,821 - essmc2 - INFO - Load pretrained model [Classifier] from pretrain.pth
# 2021-12-02 17:20:12,925 - essmc2 - WARNING - ignore keys from source:
# head.fc.weight: invalid shape, dst torch.Size([200, 2048]) vs. src torch.Size([100, 2048])
# head.fc.bias: invalid shape, dst torch.Size([200]) vs. src torch.Size([100])
# 2021-12-02 17:20:12,974 - essmc2 - WARNING - missing key in source state_dict: head.fc.weight, head.fc.bias

# Create a backbone and load only the "backbone" sub-level of the checkpoint
new_backbone = BACKBONES.build(dict(
    type="ResNet",
    pretrain=dict(path='pretrain.pth', sub_level="backbone")
))
# 2021-12-02 17:21:58,616 - essmc2 - INFO - Load pretrained model [ResNet] from pretrain.pth

# Create a backbone that loads its parameters from a URL
new_backbone_2 = BACKBONES.build(dict(
    type="ResNet",
    pretrain="https://download.pytorch.org/models/resnet50-0676ba61.pth"
))
# 2021-12-02 17:23:31,761 - essmc2 - INFO - Load pretrained model [ResNet] from https://download.pytorch.org/models/resnet50-0676ba61.pth
# Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to .cache/torch/hub/checkpoints/resnet50-0676ba61.pth
# 100%|██████████| 97.8M/97.8M [00:17<00:00, 5.97MB/s]
# 2021-12-02 17:23:53,279 - essmc2 - WARNING - unexpected key in source state_dict: fc.weight, fc.bias
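
The warnings in the logs above fall into three groups: keys ignored because of a shape mismatch, keys missing from the source, and unexpected keys in the source. The following sketch shows one way such a report could be computed. It is an assumption about the general technique, not the essmc2 implementation; filter_loadable is a hypothetical helper, and FakeTensor is a stand-in that exposes only a shape attribute so the example runs without PyTorch.

```python
from collections import namedtuple

# Hypothetical stand-in for a tensor: only `.shape` is used below.
FakeTensor = namedtuple("FakeTensor", ["shape"])


def filter_loadable(src, dst):
    """Split `src` into loadable entries and report everything else.

    `src` and `dst` map parameter names to objects exposing `.shape`.
    """
    loadable, mismatched, unexpected = {}, [], []
    for key, value in src.items():
        if key not in dst:
            unexpected.append(key)            # present in source only
        elif tuple(value.shape) != tuple(dst[key].shape):
            mismatched.append(key)            # same name, different shape
        else:
            loadable[key] = value
    # Skipped keys also count as missing, as in the log for new_classifier.
    missing = [key for key in dst if key not in loadable]
    return loadable, mismatched, missing, unexpected


# Source has 100 classes, destination has 200, as in the classifier example.
src = {
    "backbone.conv1.weight": FakeTensor((64, 3, 7, 7)),
    "head.fc.weight": FakeTensor((100, 2048)),
    "head.fc.bias": FakeTensor((100,)),
}
dst = {
    "backbone.conv1.weight": FakeTensor((64, 3, 7, 7)),
    "head.fc.weight": FakeTensor((200, 2048)),
    "head.fc.bias": FakeTensor((200,)),
}

loadable, mismatched, missing, unexpected = filter_loadable(src, dst)
print(list(loadable))  # ['backbone.conv1.weight']
print(mismatched)      # ['head.fc.weight', 'head.fc.bias']
print(missing)         # ['head.fc.weight', 'head.fc.bias']
print(unexpected)      # []
```

With real PyTorch tensors, the loadable dict could then be passed to model.load_state_dict(loadable, strict=False), which tolerates the remaining missing keys.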