检查点存储流程

保存checkpoint日志解析

  • 保存检查点开始
    1
    2024-11-21 10:18:54,402 >> Saving model checkpoint to /home/dell/sdb/saves/Qwen2-0___5B-Instruct/freeze/sft/checkpoint-25
    PGSQL
    模型检查点将被保存到指定路径 /home/dell/sdb/saves/Qwen2-0___5B-Instruct/freeze/sft/checkpoint-25,代表训练中的第25次迭代或某个训练进度标志。
  • 保存配置文件
    1
    2
    [INFO|configuration_utils.py:472] 2024-11-21 10:18:54,406 >> Configuration saved in /home/dell/sdb/saves/Qwen2-0___5B-Instruct/freeze/sft/checkpoint-25/config.json
    [INFO|configuration_utils.py:807] 2024-11-21 10:18:54,406 >> Configuration saved in /home/dell/sdb/saves/Qwen2-0___5B-Instruct/freeze/sft/checkpoint-25/generation_config.json
    PGSQL
    config.json 和 generation_config.json 分别保存模型的基础配置和生成参数配置,确保模型加载时可以正确重现训练环境。
  • 保存模型权重文件
1
[INFO|modeling_utils.py:2766] 2024-11-21 10:19:00,213 >> Model weights saved in /home/dell/sdb/saves/Qwen2-0___5B-Instruct/freeze/sft/checkpoint-25/model.safetensors
PGSQL

权重文件保存为 model.safetensors 格式。

  • 保存分词器文件
1
2
[INFO|tokenization_utils_base.py:2702] 2024-11-21 10:19:00,214 >> tokenizer config file saved in /home/dell/sdb/saves/Qwen2-0___5B-Instruct/freeze/sft/checkpoint-25/tokenizer_config.json
[INFO|tokenization_utils_base.py:2711] 2024-11-21 10:19:00,214 >> Special tokens file saved in /home/dell/sdb/saves/Qwen2-0___5B-Instruct/freeze/sft/checkpoint-25/special_tokens_map.json
PGSQL

tokenizer_config.json:分词器的配置信息。
special_tokens_map.json:保存特殊 token(如 pad_token、cls_token 等)的映射关系。

保存 DeepSpeed 检查点

  • 记录全局步骤的检查点信息
1
2
[2024-11-21 10:20:31,627] [INFO] [logging.py:96:log_dist] [Rank 0] [Torch] Checkpoint global_step25 is about to be saved!

PGSQL

global_step25 代表训练到第 25 个全局步骤的状态。

  • 保存零冗余优化器的检查点文件
    1
    [2024-11-21 10:22:33,574] [INFO] [logging.py:96:log_dist] [Rank 0] Saving model checkpoint: /home/dell/sdb/saves/Qwen2-0___5B-Instruct/freeze/sft/checkpoint-25/global_step25/zero_pp_rank_0_mp_rank_00_model_states.pt
    PGSQL
    zero_pp_rank_0_mp_rank_00_model_states.pt:保存当前模型状态,用于零冗余优化(ZeRO)的分布式训练。
1
2
[2024-11-21 10:22:33,575] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving ...
[2024-11-21 10:26:56,161] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved ...
PROLOG
1
[2024-11-21 10:27:11,351] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved .../zero_pp_rank_0_mp_rank_00_optim_states.pt.
PROLOG

zero_pp_rank_0_mp_rank_00_optim_states.pt 保存优化器状态,保证恢复训练时可以正确加载优化器的参数。

1
2
3
[2024-11-21 10:27:11,369] [INFO] [engine.py:3589:_save_zero_checkpoint] zero checkpoint saved ...
[2024-11-21 10:27:11,369] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint global_step25 is ready now!

INFORM7

所有文件均保存完成,检查点 global_step25 可以正常使用。

总结

DeepSpeed 和 PyTorch 结合使用,完整的检查点保存流程

  1. 模型参数(权重)
  2. 优化器状态
  3. 配置文件
  4. 特殊辅助文件(如冻结参数形状和分词器信息)

transformers/trainer.py

核心逻辑

在分布式训练中,多个进程通常会并行运行,为了避免重复保存检查点(例如由每个进程都保存一次),通常只有主进程(通常称为 rank 0 或“进程 0”)负责保存检查点。
如果该函数运行到了这一步,说明调用者已经确保此代码只会在进程 0 上执行。因此,无需在此处再次检查是否为主进程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
# 确定支持的模型类型
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

# 模型不是支持的类型,state_dict:获取模型参数的状态字典,用于保存权重。
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
# 如果通过 self.accelerator.unwrap_model 解包后的模型属于支持的类:调用 save_pretrained() 方法保存模型到 output_dir,并传入 state_dict 和 safe_serialization 参数。
if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
self.accelerator.unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
# 如果解包后的模型仍然不属于支持的类:记录日志:只保存模型state_dict
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
# 如果 self.model 类型属于支持的类,直接调用 save_pretrained() 保存模型,包括 state_dict 和 safe_serialization 参数。
else:
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
#save_pretrained 是主函数,用于保存分词器的完整状态
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

# 训练参数
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
PYTHON

总结

这段代码的功能是保存训练过程中生成的检查点,包括

  • 模型权重state_dict:获取模型参数的状态字典,用于保存权重。
  • 模型及其配置保存到指定的目录:self.model.save_pretrained(只有主进程才会保存模型)
  • 分词器配置:self.tokenizer.save_pretrained(output_dir),用于保存模型的函数,默认为 torch.save
  • 训练参数:torch.save

save_pretrained核心逻辑

save_pretrained 函数的核心逻辑是将训练好的模型及其配置保存到指定的目录,方便后续通过 from_pretrained 方法重新加载。

1
2
3
4
5
# 如果 self.model 类型属于支持的类,直接调用 save_pretrained() 保存模型,包括 state_dict 和 safe_serialization 参数。
else:
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
ROUTEROS

输入参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
state_dict: Optional[dict] = None,
# 用于保存模型的函数,默认为 torch.save。
save_function: Callable = torch.save,
push_to_hub: bool = False,
# 每个模型检查点的最大大小,超过该大小时将进行分片保存。
max_shard_size: Union[int, str] = "5GB",
# 使用 safetensors进行序列化
safe_serialization: bool = True,
variant: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
# 是否将适配器权重(如果存在)以兼容 PEFT 库的格式保存
save_peft_format: bool = True,
**kwargs,
):
PYTHON

offload 的模块进行特殊处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# for offloaded modules
module_map = {}

# Save the model
if state_dict is None:
# if any model parameters are offloaded, make module map
if (
# 检查模型是否有设备映射属性
hasattr(self, "hf_device_map")
# 确认设备映射到多个设备
and len(set(self.hf_device_map.values())) > 1
# 判断是否有部分模型参数被 offload(存储到 CPU 或磁盘)。
and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
):
# 发出警告,提示用户 CPU 的空闲内存需要大于 shard_size(默认值为 5GB)。
warnings.warn(
"Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
)

# 遍历 model_to_save 的所有子模块。named_modules() 方法返回所有模块的名字和模块本身。准备为每个模块提取参数字典(state_dict)。
for name, module in model_to_save.named_modules():
if name == "":
continue
# 获取当前模块的参数字典(state_dict)。
module_state_dict = module.state_dict()
# 将模块名和参数名组合,记录到 module_map 中。为每个参数生成唯一的键名:模块名.参数名。在 module_map 中记录键名和模块的对应关系,用于后续加载这些参数。
for key in module_state_dict:
module_map[name + f".{key}"] = module

# 调用 model_to_save 的 state_dict() 方法生成完整的模型参数字典。
state_dict = model_to_save.state_dict()
PYTHON

这部分代码的核心功能是为保存模型时支持 offload 的模块进行特殊处理

  1. 检查模型是否包含 offload 到 CPU 或磁盘的参数。
  2. 如果存在 offload 参数,创建 module_map,记录参数与模块的对应关系。
  3. 最后生成完整的 state_dict,用于保存模型。

分片保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Save the model
filename_to_tensors = state_dict_split.filename_to_tensors.items()
if module_map:
filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors}
# remake shard with onloaded parameters if necessary
if module_map:
if accelerate_version < version.parse("0.31"):
raise ImportError(
f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. "
f"Please upgrade accelerate with `pip install -U accelerate`"
)
# init state_dict for this shard
shard_state_dict = {name: "" for name in shard}
for module_name in shard:
module = module_map[module_name]
# update state dict with onloaded parameters
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)

# assign shard to be the completed state dict
shard = shard_state_dict
del shard_state_dict
gc.collect()

if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
else:
save_function(shard, os.path.join(save_directory, shard_file))

MIPSASM

代码核心逻辑概述:

  • 分片处理:将大模型的参数按需分片保存,减少内存占用。
  • 设备卸载管理:处理可能被分配到不同设备(如 CPU 或磁盘)的参数,从设备或磁盘加载卸载的参数到内存中,确保能够正确重新加载。
  • 依赖库校验:确保环境中的库版本满足保存的要求(如 Accelerate 库版本)。

保存配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Save the config
if is_main_process:
if not _hf_peft_config_loaded:
model_to_save.config.save_pretrained(save_directory)
if self.can_generate():
# generation config built from the model config + the model config holds generation kwargs -> generate
# may revert to legacy behavior if the two don't match
if (
model_to_save.generation_config._from_model_config
and model_to_save.config._has_non_default_generation_parameters()
):
new_generation_config = GenerationConfig.from_model_config(model_to_save.config)
if new_generation_config != model_to_save.generation_config:
logger.warning(
"Your generation config was originally created from the model config, but the model "
"config has changed since then. Unless you pass the `generation_config` argument to this "
"model's `generate` calls, they will revert to the legacy behavior where the base "
"`generate` parameterization is loaded from the model config instead. "
"To avoid this behavior and this warning, we recommend you to overwrite the generation "
"config model attribute before calling the model's `save_pretrained`, preferably also "
"removing any generation kwargs from the model config. This warning will be raised to an "
"exception in v4.41."
)
model_to_save.generation_config.save_pretrained(save_directory)

RUBY

总结:

  • 保存模型配置:如果没有加载 PEFT 配置文件,模型的配置文件会被保存。
  • 生成配置检查与保存:如果模型支持生成,并且生成配置与模型配置不一致,会生成一个新的生成配置,并保存该配置。如果生成配置发生变化,会发出警告,提醒用户注意生成配置的行为。
  • 最终,保存模型的配置文件和生成配置文件config.json 和 generation_config.json ,以确保模型的完整性和生成能力能够在后续使用中被正确恢复。

torch.save源码分析

在大模型训练中checkpoint的存储流程在/home/dell/anaconda3/envs/llama_factory/lib/python3.11/site-packages/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py中,调用torch.save(state_dict, path)函数。

1
2
3
4
5
def save(self, state_dict, path: str):
logger.info(f"[Torch] Saving {path}...")
torch.save(state_dict, path)
logger.info(f"[Torch] Saved {path}.")
return None
PYTHON
  • Pytorch 保存和加载模型后缀:**.pt 和.pth**
    作用:保存一个序列化(serialized)的目标到磁盘。
    函数使用了Python的pickle程序用于序列化。模型(models),张量
    (tensors)和文件夹(dictionaries)都是可以用这个函数保存的目标类型

在保存用于推理或者继续训练的常规检查点的时候,除了模型的state_dict之外,还必须保存其他参数。保存优化器的state_dict也非常重要,因为它包含了模型在训练时候优化器的缓存和参数。

save 函数的参数列表

1
2
3
4
5
6
7
8
def save(
obj: object,
f: FILE_LIKE,
pickle_module: Any = pickle,
pickle_protocol: int = DEFAULT_PROTOCOL,
_use_new_zipfile_serialization: bool = True,
_disable_byteorder_record: bool = False
) -> None:
PHP
  • obj: object: 需要保存的对象(如张量、模型等)。
  • f: FILE_LIKE: 目标文件,必须是一个文件类对象,能够实现 write 和 flush 方法,或者是一个包含文件路径的字符串或 os.PathLike 对象。
  • pickle_module: Any = pickle: 序列化时使用的模块,默认是 Python 的 pickle 模块,可以传入其他模块来控制序列化过程。
  • pickle_protocol: int = DEFAULT_PROTOCOL: 指定使用的 pickle 协议版本,默认使用 PyTorch 默认的协议。
  • _use_new_zipfile_serialization: bool = True: 是否使用新的基于 ZIP 的文件格式进行序列化。默认为 True,表示使用新格式。需要使用旧格式保存,可以将 _use_new_zipfile_serialization=False
  • _disable_byteorder_record: bool = False:用于控制是否记录字节顺序。

核心逻辑

1
2
3
4
torch._C._log_api_usage_once("torch.save")
_check_dill_version(pickle_module)
_check_save_filelike(f)

ISBL
  • torch._C._log_api_usage_once(“torch.save”)调用内部 C++ 函数 _log_api_usage_once 记录 torch.save 的 API 使用信息。PyTorch 内部会收集 API 调用的统计数据。
  • _check_dill_version(pickle_module)验证它的版本是否支持当前 PyTorch 的序列化需求。目的是避免由于模块版本问题导致序列化失败。
  • _check_save_filelike(f) 验证 f 是否是有效的文件类对象或路径。

检查是否使用新 ZIP 序列化格式

1
2
3
4
5
6
7
8

if _use_new_zipfile_serialization:
with _open_zipfile_writer(f) as opened_zipfile:
_save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
return
else:
with _open_file_like(f, 'wb') as opened_file:
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
SQF
  • _open_zipfile_writer(f): 打开一个 ZIP 文件写入器。 f 是文件路径,则创建对应文件。
  • _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record):实际保存逻辑。将对象 obj 序列化并写入到 ZIP 文件中,使用指定的 pickle_module 和 pickle_protocol。
    _disable_byteorder_record: 控制是否记录字节顺序,通常与硬件架构有关(如小端序/大端序)。默认false。

def _save实际保存逻辑分析

初始化数据结构

1
2
3
4
def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
serialized_storages = {}
id_map: Dict[int, str] = {}
storage_dtypes: Dict[int, torch.dtype] = {}
PYTHON
  • serialized_storages = {}存储所有被序列化的 PyTorch 存储对象(Storage)及其对应的键值对。
  • id_map: Dict[int, str] = {}维护存储对象的唯一标识映射。
    键:存储对象的 _cdata 属性(内存地址)。
    值:存储对象的唯一字符串标识符(str(len(id_map)) 自动生成)。
    避免多次序列化同一个存储对象。
  • storage_dtypes: Dict[int, torch.dtype] = {}记录存储对象的 data_ptr(内存指针)和其数据类型。保证同一内存地址的数据类型一致。

这部分代码的目的是初始化保存过程中需要追踪的信息结构:

  1. serialized_storages:记录被序列化的存储对象,便于后续写入 ZIP 文件
  2. id_map:确保存储对象的唯一标识,用于多次引用时避免重复保存。
  3. storage_dtypes:确保数据类型一致性,防止跨张量操作引发错误。

生成存储对象的唯一标识符和元信息流程

  • 判断对象是否为 PyTorch 存储类型(TypedStorage 或旧版 Storage)。
  • 为特定对象生成持久化 ID,在序列化过程中标记特殊类型的数据对象。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    def persistent_id(obj):
    if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):

    if isinstance(obj, torch.storage.TypedStorage):
    # TODO: Once we decide to break serialization FC, this case
    # can be deleted
    storage = obj._untyped_storage
    storage_dtype = obj.dtype
    storage_type_str = obj._pickle_storage_type()
    storage_type = getattr(torch, storage_type_str)
    storage_numel = obj._size()

    else:
    storage = obj
    storage_dtype = torch.uint8
    storage_type = normalize_storage_type(type(obj))
    storage_numel = storage.nbytes()

    ELM
  • 处理 TypedStorage 类型(具备类型信息的存储):
  1. storage = obj._untyped_storage 获取 TypedStorage 的底层非类型化存储(UntypedStorage),TypedStorage 是 Storage 的高级抽象,增加了数据类型信息。
  2. storage_dtype = obj.dtype提取存储对象的数据类型(如 torch.float16)。
  3. storage_type_str = obj._pickle_storage_type() 获取存储对象的类型名称(如 FloatStorage)。通过 getattr 获取对应的类型对象。
  4. storage_numel = obj._size()获取存储对象的元素数量(numel)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
if storage.data_ptr() != 0:
if storage.data_ptr() in storage_dtypes:
if storage_dtype != storage_dtypes[storage.data_ptr()]:
raise RuntimeError(
'Cannot save multiple tensors or storages that '
'view the same data as different types')
else:
storage_dtypes[storage.data_ptr()] = storage_dtype

storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
location = location_tag(storage)
serialized_storages[storage_key] = storage

return ('storage',
storage_type,
storage_key,
location,
storage_numel)

return None
STYLUS
  • if storage.data_ptr() != 0:检查存储对象的指针(data_ptr)是否非零。data_ptr() 返回存储的底层内存地址。
  • storage.data_ptr() in storage_dtypes:检查当前存储对象的内存地址是否已存在于 storage_dtypes 中。如果该存储的内存地址尚未记录,存储其对应的数据类型。
  • storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))为存储对象分配唯一的键(storage_key)。
  • location = location_tag(storage)获取存储的设备位置信息(如 cpu 或 cuda:0)。
  • serialized_storages[storage_key] = storage将存储对象按其键(storage_key)保存到 serialized_storages 字典中。
  • 生成并返回该存储对象的持久化 ID(元组)。
  • 如果对象 obj 不是存储对象,则返回 None,表示无需特殊处理。

总结:

  • 首先判断对象类型(TypedStorage 和普通 Storage),提取存储元信息:数据类型、存储类型、大小等元信息。
  • 然后检查存储是否已分配内存storage.data_ptr() != 0,未分配无需进行类型检查。否则检查存储的 data_ptr() 是否已经存在于 storage_dtypes,if storage_dtype != storage_dtypes[storage.data_ptr()],如果当前存储的 dtype 与之前记录的不一致,抛出一个运行时错误。这确保了同一块内存地址上的存储不能在序列化过程中以不同的数据类型保存。如果该 data_ptr() 尚未记录,将当前存储的 dtype 添加到 storage_dtypes 字典中,方便后续检查。storage_dtypes[storage.data_ptr()] = storage_dtype确保引用相同内存地址的多个存储对象在序列化过程中具有一致的数据类型。
  • 为当前存储生成或获取一个唯一的 storage_keystorage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
  • 最后返回持久化标识符,用于后续保存和加载过程中的关联

对象序列化为二进制数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Write the pickle data for `obj`
data_buf = io.BytesIO()
pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
pickler.persistent_id = persistent_id
pickler.dump(obj)
data_value = data_buf.getvalue()
zip_file.write_record('data.pkl', data_value, len(data_value))

# Write byte order marker
if not _disable_byteorder_record:
if sys.byteorder not in ['little', 'big']:
raise ValueError('Unknown endianness type: ' + sys.byteorder)

zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
PYTHON
  • data_buf = io.BytesIO() 创建一个内存中的二进制数据缓冲区(BytesIO 对象),存在于内存中,用于存储后续序列化生成的数据。
  • 创建一个 Pickler 对象,用于将 obj 序列化为二进制数据。pickle_protocol: 序列化使用的协议版本。
  • pickler.persistent_id = persistent_id将 persistent_id 函数分配给 pickler 对象,用于处理序列化中的持久性对象。
  • pickler.dump(obj)使用 Pickler 对象将 obj 序列化为二进制数据,并写入 data_buf 中。dump 方法是 Pickler 提供的序列化操作。
  • data_value = data_buf.getvalue()从 data_buf 中提取已经序列化的二进制数据,存储到变量 data_value 中。
  • zip_file.write_record(‘data.pkl’, data_value, len(data_value))将序列化的二进制数据 data_value 写入到 zip_file 中,文件名为 ‘data.pkl’。
  • 检查是否需要禁用字节序标记写入功能,sys.byteorder: 系统的字节序,可以是 ‘little’(小端序)或 ‘big’(大端序)。如果检测到未知字节序,则抛出异常。
  • zip_file.write_record(‘byteorder’, sys.byteorder, len(sys.byteorder))将系统的字节序(sys.byteorder)作为记录写入 zip_file,文件名为 ‘byteorder’。数据内容是 ‘little’ 或 ‘big’。

张量存储器的保存

1
2
3
4
5
6
7
8
9
10
11
12
# Write each tensor to a file named tensor/the_tensor_key in the zip archive
for key in sorted(serialized_storages.keys()):
name = f'data/{key}'
storage = serialized_storages[key]
# given that we copy things around anyway, we might use storage.cpu()
# this means to that to get tensors serialized, you need to implement
# .cpu() on the underlying Storage
if storage.device.type != 'cpu':
storage = storage.cpu()
# Now that it is on the CPU we can directly copy it into the zip file
num_bytes = storage.nbytes()
zip_file.write_record(name, storage, num_bytes)
PGSQL
  • 遍历 serialized_storages 的所有键,并按字母顺序排序,serialized_storages: 一个字典,其中键是张量的名称或标识符,值是对应的存储对象(通常为张量的底层数据)。
  • **name = f’data/{key}’**,生成该键对应的文件路径,存储为变量 name,文件路径以 ‘data/‘ 为前缀,再加上键值 key。
  • storage = serialized_storages[key] serialized_storages 字典中获取当前键 key 对应的存储对象,赋值给变量 storage。
  • if storage.device.type != ‘cpu’:检查当前存储对象是否位于非 CPU 设备上(如 GPU)。 storage = storage.cpu()如果设备类型不是 ‘cpu’,则将存储对象移至 CPU。
  • num_bytes = storage.nbytes()获取当前存储对象占用的字节数,赋值给变量 num_bytes。
  • **zip_file.write_record(name, storage, num_bytes)**将存储对象 storage 写入压缩文件中,文件名为 name。
    name: 文件路径(如 ‘data/weight’)。
    storage: 需要写入的数据内容。
    num_bytes: 数据的字节大小,用于告知写入函数内假设是一个自定义方法,负责将数据写入到压缩文件中。

总结

  1. 对象序列化(Pickle)并存储:使用 pickle 序列化模块将指定的对象 obj 转换为二进制数据。序列化后的数据被存储在内存缓冲区中,并写入压缩文件 zip_file,文件名为 ‘data.pkl’。
  2. 记录字节序(Byte Order Marker):检查系统的字节序(sys.byteorder),确保它是有效的(’little’ 或 ‘big’)。将字节序信息写入压缩文件中,文件名为 ‘byteorder’。
  3. 张量数据的存储:遍历所有张量存储(serialized_storages 字典中的键值对)。为每个张量生成文件路径,存储为 ‘data/‘。
    如果张量数据不在 CPU 上(如位于 GPU 上),将其移动到 CPU,确保数据可以被序列化并存储。获取张量数据的字节大小,并将其内容写入压缩文件中。

def_save函数核心逻辑

这段代码提供了一个通用的机制,用于将复杂对象和张量序列化并存储在压缩文件中

  1. 初始化数据结构
  • 定义 serialized_storages 用于存储序列化后的张量存储器。
  • 使用 id_map 记录每个存储器的唯一标识符,避免重复序列化相同数据。
  • 通过 storage_dtypes 确保多个存储器(storages)共享同一底层数据时,它们的数据类型(dtype)一致
  1. 生成存储对象的唯一标识符和元信息
  • 该方法处理张量的存储器序列化逻辑,并返回一个包含存储器元数据的标识符(tuple)。
  • 记录存储器的类型、唯一键、位置(设备)、元素数量等信息。
  • 检查共享数据指针的存储器是否具有相同的数据类型,确保数据一致性。
    将存储器添加到 serialized_storages 中,并通过 id_map 分配唯一键名。
  1. 对象序列化和保存:
  • 使用 pickle 模块将对象(obj)序列化为二进制数据。
  • 通过 pickle_module.Pickler 实现自定义序列化行为,并将persistent_id 用于处理张量存储器的特殊序列化需求。
  • 将序列化后的对象数据写入压缩文件,文件名为 data.pkl。
    4.存储字节序(Byte Order Marker):
  • 检查系统的字节序(sys.byteorder),确保数据可以跨平台读取。将字节序信息写入压缩文件,文件名为 byteorder。
  1. 张量存储器的保存:
  • 遍历 serialized_storages 中的所有张量存储器。如果存储器不在 CPU 上,则将其移动到 CPU。
  • 获取存储器的字节大小并将其内容写入压缩文件,路径为 data/

检查点存储流程
http://sjx.com/2024/11/27/检查点存储流程/
作者
sjx
发布于
2024年11月27日
许可协议