PyTorch 对模型 save/load 的支持主要由 torch.serialization._save() 提供..pth 文件其实并不是单纯的 tensor data,而是模型的逻辑结构,这是因为 PyTorch 支持 tensor weight 共享,而共享本质是使用相同的 tensor data,所以两个共享权重的 module 在保存时,只需要保存一份 tensor data 即可.所以 PyTorch 在保存模型时,选择将模型拓扑结构与底层数据分开存的方式.前者用 pickle 存,后者用 storage.

Zipfile-Based Approach

从 PyTorch 1.6 开始,torch.save() 使用的是一种新的、基于 zipfile 的保存方式.并且默认启用.文件的结构是这样的:

1
2
3
4
5
6
7
8
9
checkpoint.pth
├── data.pkl
├── byteorder # added in PyTorch 2.1.0
├── data/
│ ├── 0
│ ├── 1
│ ├── 2
│ └── …
└── version
  • data.pkl 是 pickled object except contained torch.Storage
  • byteorder 指示是 little-endian 还是 big-endian
  • data/ 包含 object 中的 storage,每一个 storage 都是独立的文件
  • version 记录版本信息,load() 的时候会用到