safetensors [1] 号称提供一种更安全的存数据方式,支持多种框架,见 [2]。不过在处理玄数据(metadata)时:
- 只支持 Dict[str, str] 的形式,即值必须是字符串,而不能是 int、float 或嵌套 dict,而这些在 PyTorch 原先的 torch.save、torch.load 是支持的。考虑用
json.dumps
将 dict 转写成字符串,读时则用json.loads
恢复回 dict。 - 没有专门从 checkpoint 文件读出 metadata 的方法。考虑采用 [3] 中 Ok_Storage_1799 的回答所讲利用
safetensors.safe_open
的方法读 metadata。
下面是存、取 PyTorch 模型参数、metadata 的简例:
import time, json, pprint
import torch
from safetensors import safe_open # to read metadata
from safetensors.torch import save_model, load_modelprint("建模型")
model = torch.nn.Linear(2, 3)
# 初始参数值
for pn, p in model.named_parameters():print(pn, p)print("存模型、metadata")
# 将模型参数置零 (模拟 training)
for p in model.parameters():p.data.zero_()
# 存模型
save_model(model,"ckpt.safetensors",# metadata 用 json 转写成 str{"metadata": json.dumps({"time": time.asctime(),"epoch": 57,"acc": 0.56,"args": {"debug": False,"dataset": "MNIST","decay_steps": [10, 20]}})}
)print("读模型")
load_model(model, "ckpt.safetensors")
# 验证更新(置零)后参数值
for pn, p in model.named_parameters():print(pn, p)print("读 metadata")
with safe_open("ckpt.safetensors", framework="pt") as f:print(type(f), dir(f))print(list(f.keys())) # 模型参数的名字print(type(f.metadata())) # dictfor k, v in f.metadata().items():print(k, v)# 用 json 恢复 metadata 成 dictif "metadata" == k:metadata = json.loads(v)pprint.pprint(metadata)
References
- huggingface/safetensors
- Python documentation
- How to get metadata from a safetensor file?