PyTorch 最佳实践:模型保存和加载

时间:2022-07-25
本文章向大家介绍PyTorch 最佳实践:模型保存和加载,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

PyTorch模型保存和加载有两种方法,官方最佳实践指南推荐其中一种,但似乎效果没啥区别。最近做模型量化,遇到一个意外的错误,才理解了最佳实践背后的原理,以及不遵循它可能会遇到什么问题。

作者:Lernapparat 编译:McGL

我们研究了一些最佳实践,同时尝试阐明其背后的基本原理。

你是中级 PyTorch 程序员吗?你是否遵循官方文档的最佳实践指南?你对哪些应该坚持,哪些可以放弃而不会搞出问题有自己的经验和看法吗?

我承认有时候很难遵循最佳实践,因为他们反对的方法似乎也能工作,而我并不完全理解他们的基本原理。这是发生在我身上的一件小事。

一个我做量化 (Quantization)的故事

在Raspberry Pi 上搭建 PyTorch 之后,我一直期待着用它做一些有趣的项目。当然,我找到了一个模型,我想在Pi上适配并跑起来。我很快就让它跑起来了,但是它没有我想象的那么快。所以我开始着手量化它。

量化使得任何操作都是有状态的 / 暂时的(stateful / temporarily)

如果你把 PyTorch 计算看作是一组由操作链接起来的值(张量),量化包括对每个操作进行量化,并形成一个意见(opinion),即通过一个仿射变换对量化元素类型进行整数范围近似,张量值输出的范围应该是多少。如果这听起来很复杂,不要担心,重点是现在每个操作都需要与“一个意见”相关联,或者更准确的说,是一个观察者,记录模型的一些典型应用中所看到的最小值和最大值。但是现在这意味着在量化期间,所有操作都是有状态的。更准确的说,在准备量化和进行量化之前,它们都是有状态的。

我经常提到这一点,我主张不要声明一次激活函数,然后多次重用。这是因为在使用函数的计算中的各个点上,观察者通常会看到不同的值,所以现在它们的工作方式不同了。

这种新的有状态特性也适用于简单的事情,比如张量相加,通常表示为 a + b。为此, PyTorch 提供了 torch.nn.quantized.FloatFunctional模块。这是一个常见的 Module ,但是做了修改,在计算中不使用 forward ,而是有几种方法对应基本的操作,如我们这里的.add

所以我使用了残差(residual)模块,它看起来大概像这样(注意它是如何分开独立声明激活的,这是一件好事!):

class ResBlock(torch.nn.Module):
  def __init__(self, ...):
     self.conv1 = ...
     self.act1 = ...
     self.conv2 = ...
     self.act2 = ...
  def forward(self, x):
     return self.act2(x + self.conv2(self.act1(self.conv1(x))))

我还添加了 self.add = torch.nn.quantized.FloatFunctional() 到 __init__ 并把 x + ... 替换为 self.add.add(x, ...)。搞定!

根据准备好的模型,我可以添加量化本身,依据PyTorch 教程执行很简单。在评估脚本的最后,模型全部加载、设置为 eval 等之后,我添加了以下内容并重新启动了正在使用的 notebook kernel,然后运行了所有这些。

#config
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.backends.quantized.engine = 'qnnpack'
# wrap in quantization of inputs / de-quantization of output)
model = torch.quantization.QuantWrapper(model)
# insert observers
torch.quantization.prepare(model, inplace=True)

因此稍后(在运行模型以获得观察结果之后) ,我会调用

torch.quantization.convert(model, inplace=True)

来得到一个模型。很简单!

一个意外的错误

现在我只需要运行几个批次的输入。

preds = model(inp)

但是发生了什么呢?

ModuleAttributeError: 'ResBlock' object has no attribute 'add'

糟糕!

出什么问题了? 是不是我在 ResBlock 中有拼写错误?

在 Jupyter中你可以非常容易地使用 ?? model.resblock1来检查。但是这没问题,没有拼写错误。

这就是 PyTorch 最佳实践的用武之地。

序列化(Serialization)最佳实践

PyTorch 官方文档有个关于序列化的说明,其中包含一个最佳实践部分。它这样开头

序列化和还原模型主要有两种方法。第一个(推荐)是只保存和加载模型参数:

然后展示了如何用 state_dict() 和 load_state_dict() 方法来运作. 第二种方法是保存和加载模型。

该说明提供了优先只使用序列化参数的理由如下:

然而,在[保存模型的情况]下,序列化的数据绑定到特定的类和所使用的确切目录结构,因此在其他项目中使用时,或在一些重度的重构之后,它可能会以各种方式中断。

事实证明,这是一个相当轻描淡写的说法,甚至在我们非常温和的修改中——几乎算不上重大的修改——也遇到了它所提到的问题。

什么出了问题?

为了找到问题的核心,我们必须思考 Python 中的对象是什么。在一个粗略的过度简化中,它完全由其 __dict__属性定义, 该属性包含所有("data")成员,其__class__ 属性指向它的类型( 例如,对于 Module 实例,是Module, 而对于 Module 本身 (一个类) ,是 type) 。当我们调用一个方法时,它通常不在 __dict__ 中(其实也可以,但改动会比较复杂)。但是 Python 会自动查询 __class__ 来寻找方法 (或者其他在 __dict__中找不到的东西)。

当反序列化模型时(我使用的模型的作者没有遵循最佳实践建议) ,Python 将通过查找 __class__ 的类型并将其与反序列化__dict__组合来构造一个对象。但是它(正确地)没有做的是调用 __init__ 来设置类(它不应该这样做,尤其是担心在 __init__ 和序列化之间可能已经修改了内容,或者它可能有我们不希望的副作用)。这意味着,当我们调用模块时,我们使用了新的forward 但是得到了原作者的__init__ 准备的__dict__ 和后续的训练,而没有我们修改过的 __init__ 添加的新属性add。

所以简而言之,这就是为什么在 Python 中序列化 PyTorch 模块或通常意义上的对象是危险的: 你很容易就会得到数据属性和代码不同步的结果。

保持兼容性

这里有一个显而易见的问题——也可以说是一个缺点——那就是除了状态字典(state dict)之外,我们还需要跟踪 setup 的配置。但是如果你愿意的话,你可以轻松地序列化所有参数以及状态字典——只需将它们粘贴到一个联合字典中。

但是不序列化模块本身还有其他优点:

显而易见的是,我们可以使用状态字典。可以无需模块加载状态字典,如果我们改变了一些重要的东西,可以检查和修改状态字典。

不太明显的是,实现者或用户还可以自定义模块处理状态字典。这有两个方面:

  • 对于用户来说,有钩子(hooks)。好吧,它们不是非常官方,但是有_register_load_state_dict_pre_hook ,你可以用它来注册钩子,在更新模型之前处理状态字典,还有_register_state_dict_hook 来注册钩子,这些钩子在状态字典被收集之后和从 state_dict()返回之前被调用。
  • 更重要的是,实现者可以覆写 _load_from_state_dict 。当类具有属性 _version时,这将在状态字典中保存为 version 元数据(metadata). 有了这个,你可以添加来自旧状态字典的转换。BatchNorm提供了一个怎么做到这点的例子,大致看起来像这样:
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.have_new_thing:
            new_key = prefix + 'new_thing_param'
            if new_key not in state_dict:
                state_dict[new_key] = ... # some default here

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

因此在这里我们检查版本是否是旧的,并且需要一个新的key,先添加它,然后再交给超类 (通常是 torch.nn.Module)常规处理。

总结

当保存整个模型而不是按照最佳实践只保存参数时,我们已经看到了什么出错了的非常详细的描述。 我个人的看法是,保存模型的陷阱是相当大的,很容易掉坑里,所以我们真的应该注意只保存模型参数,而不是 Module 类。

希望你喜欢这个深入 PyTorch 最佳实践的小插曲。

原文:http://lernapparat.de/pytorch-best-practices/