1、保存模型
- # 定义模型
- model = BPNetModel(n_feature=n_feature,n_hidden=n_hidden,n_output=n_output) #调用网络
-
- # 保存模型
- torch.save(model, 'BPNetModel0.pth')
2、加载模型
- import torch
-
- ## 读取模型
- model = torch.load('BPNetModel0.pth')
3、保存模型参数
- #调用网络
- model = BPNetModel(n_feature=n_feature,n_hidden=n_hidden,n_output=n_output)
-
- # 保存模型
- torch.save({'model': model.state_dict()}, 'BPNetModel0.pth')
4、加载参数
- # 读取模型
- state_dict = torch.load('model_name.pth')
- model.load_state_dict(state_dict['model'])