Checkpoints are honestly pretty nasty: every format stores things in its own way, and it gets especially painful when you need to adapt one framework to another framework's output tensors (yes, I mean adapting PyTorch on InfiniTrain, plus the LLMC stuff). So let's start with the Checkpoint framework written for InfiniTrain.
General Checkpoint Design
Using the Checkpoint design in the InfiniTrain project as an example, let's briefly talk about general checkpoint design.
In short, the Checkpoint framework splits into two layers:
- The first layer is the generic training-state layer: the parameter state_dict, the optimizer state, and trainer_state.
- The second layer handles model-format adaptation: LLMC, the generic model.bin binary, and PyTorch .pth.
To be able to collect all the parameters, first add StateDict() and LoadStateDict() APIs in module.h, optimizer.h, and distributed_optimizer.h.
Next, define a few data structures:
- TrainerState holds global data such as global_step, best_loss, checkpoint_format, and the parallelism sizes.
- CheckpointOptions holds the format plus a lambda responsible for writing the model parameters to a file.
- CheckpointLoadOptions is the mirror image: it reads the model parameters and writes them into the model pointed to by the nn::Module* pointer.
The reason checkpoint IO takes callbacks here is that the training loop can later write files asynchronously, without slowing down the main training loop. It also makes adapting new formats easier.
Then comes the Checkpoint class itself; its two most important APIs are Save() and Load():
```cpp
static void Save(const std::filesystem::path &checkpoint_dir,
                 const nn::Module &model,
                 const Optimizer &optimizer,
                 const TrainerState &state,
                 const CheckpointOptions &options = {});

static void Load(const std::filesystem::path &checkpoint_dir,
                 nn::Module *model,
                 Optimizer *optimizer,
                 TrainerState *state,
                 const CheckpointLoadOptions &options = {});
```
The Save() path:
- If format=bin and a model_bin_writer is provided, the format-specific writer runs first (LLMC-compatible); otherwise, fall back to the generic SaveStateDictBinary(...).
- Whether the optimizer state hits disk is controlled by save_optimizer_state.
- trainer_state.json is always written; it records step/best_loss/lr and the parallelism topology.
```cpp
std::filesystem::create_directories(checkpoint_dir);

const auto model_path =
    checkpoint_dir / (options.format == "pth" ? "model.pth" : "model.bin");
if (options.format == "bin" && options.model_bin_writer) {
    options.model_bin_writer(model, model_path);
} else {
    SaveStateDictBinary(model_path, model.StateDict());
}

if (options.save_optimizer_state) {
    auto opt_state = optimizer.StateDict();
    if (!opt_state.empty()) {
        const auto opt_path = checkpoint_dir / (options.format == "pth"
                                                    ? "optimizer.pth"
                                                    : "optimizer.bin");
        SaveStateDictBinary(opt_path, opt_state);
    }
}

SaveTrainerState(checkpoint_dir / "trainer_state.json", state);
```
Most importantly, the model_bin_writer logic, i.e. how to save the parameters, is wired up in main.cc:
```cpp
options.model_bin_writer = [&](const nn::Module &,
                               const std::filesystem::path &model_path) {
    llmc_model->SaveAsLLMC(model_path.string());
};
```
The Load() path is similar: it handles the model parameters, the optimizer parameters, and the global training state separately.
```cpp
const std::string format = InferFormat(checkpoint_dir);
const auto model_path =
    checkpoint_dir / (format == "pth" ? "model.pth" : "model.bin");

if (format == "bin" && options.model_bin_loader) {
    const uint32_t magic = PeekMagic(model_path);
    if (magic == kCkptMagic) {
        model->LoadStateDict(LoadStateDictBinary(model_path));
    } else {
        options.model_bin_loader(model, model_path);
    }
} else {
    model->LoadStateDict(LoadStateDictBinary(model_path));
}

if (optimizer != nullptr && options.load_optimizer_state) {
    const auto opt_path = checkpoint_dir / (format == "pth" ? "optimizer.pth"
                                                            : "optimizer.bin");
    if (std::filesystem::exists(opt_path)) {
        optimizer->LoadStateDict(LoadStateDictBinary(opt_path));
    }
}

*state = LoadTrainerState(checkpoint_dir / "trainer_state.json");
```
Most importantly, the model_bin_loader logic, i.e. how to read the parameters back, is wired up in main.cc:
```cpp
load_options.model_bin_loader = [](nn::Module *target_model,
                                   const std::filesystem::path &model_path) {
    auto loaded_model = GPT2::FromLLMC(model_path.string());
    target_model->LoadStateDict(loaded_model->StateDict());
};
```