其实整个转换的流程也并不是特别复杂.假设我们有一个 torch.bfloat16 的张量,我们想把它 dump 成二进制文件

1
reference = torch.randn(..., dtype=torch.bfloat16)

NumPy 里有一个函数 .tobytes() 可以很方便地将张量转换为二进制 bytes 以写入文件.如果此时 referencedtype 是 NumPy 支持的类型,例如 torch.float16,那么此时直接将 PyTorch 转换为 NumPy 即可~这个工具 .numpy() 是 PyTorch 自带的.

1
2
3
4
5
6
7
8
9
reference_bytes = reference
.cpu() # 先转回 CPU,因为 NumPy 只能在 CPU 端
.numpy() # 直接转换
.astype(np.float16) # 直接用 NumPy 原生表示 float16
.tobytes()

# 直接打开文件写入
with open('some.file', "wb") as f:
f.write(reference_bytes)

但是 NumPy 却并不支持 torch.bfloat16,所以 .astype() 这一步会直接报错.

所以这里的 workaround 是,将 refenrence 张量在 PyTorch 侧转化为 torch.int16,于是在 NumPy 一侧也是 np.int16.因为我们其实并不需要在 NumPy 一侧对张量进行处理,所以可以直接 dump np.int16 的张量:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
ref_bytes = reference
.to(torch.bfloat16)
.view(torch.int16)
# 这里用 .view() 是因为 bfloat16, int16 都是 16 bit 的
# 所以直接 reinterpret 即可,无需拷贝
#
# 而且不能用 .to(torch.int16),
# 因为 .to() 会对浮点数进行 rounding,但这不是我们想要的
.cpu()
.numpy()
.tobytes()

with open('some.file', "wb") as f:
f.write(ref_bytes)