Checkpoint 其实挺恶心的,不同格式之间保存的方式不同.而且当需要为一种框架适配另一种框架的输出张量时,这种情况就很突出了(没错,就是在 InfiniTrain 上适配 PyTorch,还有什么 LLMC 的东西).那么接下来先来讲一讲 InfiniTrain 里写的 Checkpoint 框架.

General Checkpoint Design

以 InfiniTrain 项目里 Checkpoint Design 来简单谈谈 General Checkpoint Design

总而言之的话,Checkpoint 框架可以分成两层:

  1. 第一层是通用状态训练层,包含了参数 state_dict、优化器和 trainer_state
  2. 第二层则专注于模型格式适配,即 LLMC、model.bin 通用二进制、和 PyTorch .pth

为了能拿到所有参数,先在 module.hoptimizer.hdistributed_optimizer.h 里添加 StateDict()LoadStateDict() API.

接下来,定义一些数据结构:

  • TrainerState 包含一些全局的数据,例如 global_step, best_loss, checkpoint_format 和 parallelism size 等等
  • CheckpointOptions 则包含 format 和一个匿名函数,负责将 model parameter 写入文件
  • CheckpointLoadOptions 则相反,负责读取 model parameter 并写入 nn::Module* 指针指向的模型里

之所以这里在 checkpoint IO 里用回调函数,是因为在后续 training loop 里可以使用异步写入文件,不影响 main loop 训练的速度.

同时,适配新格式的时候也更加方便.

接着是 Checkpoint 类,最最重要的是 Save(), Load() 两个 API

1
2
3
4
5
6
7
8
9
10
11
12
13
static void Save(
const std::filesystem::path &checkpoint_dir,
const nn::Module &model,
const Optimizer &optimizer,
const TrainerState &state,
const CheckpointOptions &options = {});

static void Load(
const std::filesystem::path &checkpoint_dir,
nn::Module *model,
Optimizer *optimizer,
TrainerState *state,
const CheckpointLoadOptions &options = {});

Save() 路径

  • 如果 format=bin 且提供了 model_bin_writer,优先走模型专有写出(LLMC 兼容),否则走通用 SaveStateDictBinary(...)
  • optimizer 状态按 save_optimizer_state 决定是否落盘
  • trainer_state.json 始终写,记录 step/best_loss/lr/并行拓扑。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// 创建文件夹
std::filesystem::create_directories(checkpoint_dir);

// 处理模型参数
const auto model_path = checkpoint_dir / (options.format == "pth" ? "model.pth" : "model.bin");
if (options.format == "bin" && options.model_bin_writer) {
options.model_bin_writer(model, model_path);
} else {
SaveStateDictBinary(model_path, model.StateDict());
}

// 处理优化器
if (options.save_optimizer_state) {
auto opt_state = optimizer.StateDict();
if (!opt_state.empty()) {
const auto opt_path = checkpoint_dir / (options.format == "pth" ? "optimizer.pth" : "optimizer.bin");
SaveStateDictBinary(opt_path, opt_state);
}
}

// 处理 trainer_state.json
SaveTrainerState(checkpoint_dir / "trainer_state.json", state);

最重要的,要在 main.cc 里写入 model_bin_writer 的逻辑,即该如何保存参数:

1
2
3
4
// 这里是直接用了原来库里提供好的 SaveAsLLMC
options.model_bin_writer = [&](const nn::Module &, const std::filesystem::path &model_path) {
llmc_model->SaveAsLLMC(model_path.string());
};

Load() 路径

也是类似的,分别处理模型参数、优化器参数和全局训练参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
const std::string format = InferFormat(checkpoint_dir);
const auto model_path = checkpoint_dir / (format == "pth" ? "model.pth" : "model.bin");

// 加载模型参数
if (format == "bin" && options.model_bin_loader) {
const uint32_t magic = PeekMagic(model_path);
if (magic == kCkptMagic) {
model->LoadStateDict(LoadStateDictBinary(model_path));
} else {
options.model_bin_loader(model, model_path);
}
} else {
model->LoadStateDict(LoadStateDictBinary(model_path));
}

// 加载优化器
if (optimizer != nullptr && options.load_optimizer_state) {
const auto opt_path = checkpoint_dir / (format == "pth" ? "optimizer.pth" : "optimizer.bin");
if (std::filesystem::exists(opt_path)) {
optimizer->LoadStateDict(LoadStateDictBinary(opt_path));
}
}

// 加载全局训练参数
*state = LoadTrainerState(checkpoint_dir / "trainer_state.json");

最重要的,要在 main.cc 里写入 model_bin_loader 的逻辑,即该如何读取参数:

1
2
3
4
5
// 这里是直接用了原来库里提供好的 FromLLMC
load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) {
auto loaded_model = GPT2::FromLLMC(model_path.string());
target_model->LoadStateDict(loaded_model->StateDict());
};