1. Saving model weights: torch.save

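import os
import torch

# args, num_classes and weights_dir are assumed to be defined elsewhere in the training script
model_name = args.model
if model_name == "ResNet18" or model_name == "ResNet34":
    from models.ResNet1 import BasicBlock
    from models.ResNet1 import ResNet as PATCHMODEL
if model_name == "ResNet18":
    net = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
elif model_name == "ResNet34":
    net = PATCHMODEL(BasicBlock, [3, 4, 6, 3], num_classes=num_classes).cuda()  # standard ResNet-34 block counts

# Save only the learned parameters and buffers (the state_dict), not the whole model object
torch.save(net.state_dict(), os.path.join(weights_dir, model_name + '_train_loss_min_numCls{}.pth'.format(num_classes)))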
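To show what torch.save(net.state_dict(), ...) actually writes, here is a minimal CPU-only sketch. It assumes a small hypothetical network, TinyNet, standing in for models.ResNet1.ResNet (whose source is not shown here), and a temporary directory standing in for weights_dir. The state_dict is an ordered mapping from parameter and buffer names to tensors, and that mapping is all the .pth file contains.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for PATCHMODEL: conv + batchnorm + linear classifier."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.bn(self.conv(x)))
        x = x.mean(dim=(2, 3))  # global average pooling over H and W
        return self.fc(x)


def save_weights(net: nn.Module, weights_dir: str, model_name: str, num_classes: int) -> str:
    """Save only the state_dict, using the same file-name pattern as the snippet above."""
    os.makedirs(weights_dir, exist_ok=True)
    path = os.path.join(
        weights_dir, "{}_train_loss_min_numCls{}.pth".format(model_name, num_classes)
    )
    torch.save(net.state_dict(), path)
    return path


if __name__ == "__main__":
    net = TinyNet(num_classes=10)
    # Keys are parameter and buffer names; BatchNorm running statistics are buffers.
    for name, tensor in net.state_dict().items():
        print(name, tuple(tensor.shape))
    with tempfile.TemporaryDirectory() as tmp:
        print(save_weights(net, tmp, "TinyNet", 10))
```

Because only tensors keyed by name are stored, the file does not record the class or its constructor arguments; the loading side has to rebuild the same architecture itself.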

2. Loading model weights: load_state_dict

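import torch

# args, num_classes and model_path are assumed to be defined elsewhere in the script;
# the architecture built here must match the one the weights were saved from
model_name = args.model
if model_name == "ResNet18" or model_name == "ResNet34":
    from models.ResNet1 import BasicBlock
    from models.ResNet1 import ResNet as PATCHMODEL
if model_name == "ResNet18":
    model = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
elif model_name == "ResNet34":
    model = PATCHMODEL(BasicBlock, [3, 4, 6, 3], num_classes=num_classes).cuda()  # standard ResNet-34 block counts

# strict=False skips missing/unexpected keys instead of raising, so report what was skipped
result = model.load_state_dict(torch.load(model_path, map_location='cuda'), strict=False)
print(result.missing_keys, result.unexpected_keys)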
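The effect of strict=False can be checked with a small CPU-only sketch. It assumes a hypothetical build_model helper (a plain nn.Sequential) in place of PATCHMODEL(...) and a temporary file in place of model_path. It shows three cases: a matching architecture loads every key; an extra layer is reported in missing_keys and keeps its random initialization; and a shape mismatch on a key present in both still raises, because strict=False only relaxes missing and unexpected keys.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model(num_classes: int) -> nn.Module:
    """Hypothetical stand-in for PATCHMODEL(...); any nn.Module loads the same way."""
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, num_classes))


def load_weights(model: nn.Module, model_path: str, strict: bool = True):
    """Load a saved state_dict into an already-constructed model.

    map_location="cpu" lets a checkpoint written on a GPU machine be read on a
    CPU-only one; move the model with .cuda() afterwards if needed.
    """
    state_dict = torch.load(model_path, map_location="cpu")
    result = model.load_state_dict(state_dict, strict=strict)
    # With strict=False, name mismatches are reported here instead of raising.
    return result.missing_keys, result.unexpected_keys


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "weights.pth")
        torch.save(build_model(10).state_dict(), path)

        # Same architecture: every key matches.
        print(load_weights(build_model(10), path))

        # Extra final layer: its weights are not in the file and stay randomly initialized.
        bigger = nn.Sequential(*build_model(10), nn.Linear(10, 10))
        print(load_weights(bigger, path, strict=False))

        # Different num_classes: same key names but different shapes, so this raises.
        try:
            load_weights(build_model(5), path, strict=False)
        except RuntimeError as err:
            print("size mismatch:", type(err).__name__)
```

In practice this means num_classes at load time must equal the value used when saving, even with strict=False.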
