其实整个转换的流程也并不是特别复杂.假设我们有一个 torch.bfloat16 的张量,我们想把它 dump 成二进制文件
1 | reference = torch.randn(..., dtype=torch.bfloat16) |
NumPy 里有一个函数 .tobytes() 可以很方便地将张量转换为二进制 bytes 以写入文件.如果此时 reference 的 dtype 是 NumPy 支持的类型,例如 torch.float16,那么此时直接将 PyTorch 转换为 NumPy 即可~这个工具 .numpy() 是 PyTorch 自带的.
1 | reference_bytes = reference |
但是 NumPy 却并不支持 torch.bfloat16,所以 .astype() 这一步会直接报错.
所以这里的 workaround 是,将 refenrence 张量在 PyTorch 侧转化为 torch.int16,于是在 NumPy 一侧也是 np.int16.因为我们其实并不需要在 NumPy 一侧对张量进行处理,所以可以直接 dump np.int16 的张量:
1 | ref_bytes = reference |