1.模型权重保存 torch.save
model_name = args.model
if model_name == "ResNet18" or model_name == "ResNet34":
from models.ResNet1 import BasicBlock
from models.ResNet1 import ResNet as PATCHMODEL
if model_name == "ResNet18":
net = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
torch.save(net.state_dict(), weights_dir + '/' + model_name + '_train_loss_min_numCls{}.pth'.format(num_classes))
2.模型权重上传 load_state_dict
model_name = args.model
if model_name == "ResNet18" or model_name == "ResNet34":
from models.ResNet1 import BasicBlock
from models.ResNet1 import ResNet as PATCHMODEL
if model_name == "ResNet18":
model = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
model.load_state_dict(torch.load(model_path), strict=False)
本站资源均来自互联网,仅供研究学习,禁止违法使用和商用,产生法律纠纷本站概不负责!如果侵犯了您的权益请与我们联系!
转载请注明出处: 免费源码网-免费的源码资源网站 » 【Pytorch】模型权重保存与上传
发表评论 取消回复