PyTorch 对模型 save/load 的支持主要由 torch.serialization._save() 提供..pth 文件其实并不是单纯的 tensor data,而是模型的逻辑结构,这是因为 PyTorch 支持 tensor weight 共享,而共享本质是使用相同的 tensor data,所以两个共享权重的 module 在保存时,只需要保存一份 tensor data 即可.所以 PyTorch 在保存模型时,选择将模型拓扑结构与底层数据分开存的方式.前者用 pickle 存,后者用 storage.
Zipfile-Based Approach
从 PyTorch 1.6 开始,torch.save() 使用的是一种新的、基于 zipfile 的保存方式.并且默认启用.文件的结构是这样的:
1 | checkpoint.pth |
data.pkl是 pickled object except containedtorch.Storagebyteorder指示是 little-endian 还是 big-endiandata/包含 object 中的 storage,每一个 storage 都是独立的文件version记录版本信息,load()的时候会用到