【动手学深度学习笔记】之读取和存储

时间:2022-07-23
本文章向大家介绍【动手学深度学习笔记】之读取和存储,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

1. 读取和存储

在训练好模型后,有时需要把训练好的模型参数存储起来以供后续使用。

1.1 读写Tensor

存储和读取Tensor可以分别使用save函数和load函数实现。save函数的操作对象包括模型、张量和字典等。

首先创建两个Tensor

import torch

x = torch.ones(3)
y = torch.zeros(4)

读取和存储单个Tensor实例:

torch.save(x,'x.pt')
x2 = torch.load('x.pt')
print(x2)

Out[1]:

tensor([1., 1., 1.])

读取和存储一个Tensor列表实例:

torch.save([x,y],'xy.pt')
xy_list = torch.load('xy.pt')
print(xy_list)

Out[1]:

[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

读取和存储一个Tensor字典实例:

torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
print(xy)

Out[1]:

{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

1.2 读写模型

优化器和具有可学习参数的层的参数名称和参数被存储在state_dict

下面以实例调用state_dict来显示模型参数和名称。

class MLP(nn.Module):
    def __init__(self):
        super(MLP,self).__init__()
        self.hidden = nn.Linear(3,2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2,1)
        
    def forward(self,x):
        a = self.act(self.hidden(x))
        return self.output(a)
    
net = MLP()
net.state_dict()

Out[1]:
    
OrderedDict([('hidden.weight',
              tensor([[-0.3303, -0.2529, -0.4268],
                      [ 0.4672, -0.2530, -0.0974]])),
             ('hidden.bias', tensor([-0.1994, -0.2971])),
             ('output.weight', tensor([[-0.3032, -0.0526]])),
             ('output.bias', tensor([0.5046]))])

下面以实例调用state_dict来显示优化器状态和超参数。

optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()

Out[1]:

{'state': {},
 'param_groups': [{'lr': 0.001,
   'momentum': 0.9,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'params': [1967586368392, 1967586368472, 1967586368632, 196758368712]}]}

可以通过保存和加载模型参数(state_dict)来实现保存和加载模型。

同样通过实例来显示整个过程。

#保存:
torch.save(net.state_dict(),'K:sd.pt')

#读取:
net1 = MLP()
net1.load_state_dict(torch.load('K:sd.pt'))

Out[1]:
    <All keys matched successfully>

也可以直接存储和读取整个模型。

#存储
torch.save(net,'K:sd1.pt')

#读取
net2 = torch.load('K:sd1.pt')

通过这两种方法保存和读取的模型具有相同的模型参数,因此他们的正向传播结果是相同的。