检查点序列化过程

torch.save()

def _save入口

  • /home/dell/sdb/pytorch/torch/serialization.py(813)_save()
    1
    2
    3
    4
    5
def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
    # Excerpt (truncated) of torch/serialization.py:_save, the worker behind
    # torch.save(). Breakpoint inserted for this debugging walkthrough.
    pdb.set_trace()
    serialized_storages = {}
    # NOTE(review): the two maps below appear to key tensor storages by id for
    # dedup/dtype bookkeeping — confirm against the full body in serialization.py.
    id_map: Dict[int, str] = {}
    storage_dtypes: Dict[int, torch.dtype] = {}

序列化obj到data_buf

1
2
3
4
5
6
7
8
# Write the pickle data for `obj` (excerpt from torch.serialization._save).
data_buf = io.BytesIO()  # in-memory buffer that receives the pickle stream
pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
# persistent_id is consulted by Pickler.save() for every object; torch uses it
# to divert tensor storages out of the pickle stream (defined earlier in _save).
pickler.persistent_id = persistent_id
pdb.set_trace()  # breakpoint inserted for this debugging walkthrough
pickler.dump(obj)
data_value = data_buf.getvalue()
# Store the pickled bytes as the 'data.pkl' entry of the checkpoint archive.
zip_file.write_record('data.pkl', data_value, len(data_value))
  • 将对象 obj 使用 pickle 模块序列化,并将序列化后的数据存储在一个内存缓冲区中,然后将缓冲区中的数据写入到一个压缩文件中。

pickle.py

  • /home/dell/anaconda3/envs/torch_new_env/lib/python3.11/pickle.py

def dump

1
2
3
4
5
6
7
8
9
10
11
12
13
def dump(self, obj):
    """Write a pickled representation of obj to the open file."""
    if not hasattr(self, "_file_write"):
        raise PicklingError("Pickler.__init__() was not called by "
                            "%s.__init__()" % (self.__class__.__name__,))
    if self.proto >= 2:
        self.write(PROTO + pack("<B", self.proto))
        # PROTO = b'\x80'; pack("<B", self.proto) packs the protocol number
        # as a single little-endian unsigned byte (proto 2 -> b'\x02').
    if self.proto >= 4:
        self.framer.start_framing()
    self.save(obj)
    self.write(STOP)
    self.framer.end_framing()
  • write(PROTO + pack("<B", self.proto)) data: b'\x80\x02'
  • save(obj)
  • write(STOP) b'.'

def write

1
2
3
4
5
def write(self, data):
    # Route output through the current frame when framing is active
    # (started by dump() for proto >= 4); otherwise write straight to
    # the underlying file.
    if self.current_frame:
        return self.current_frame.write(data)
    else:
        return self.file_write(data)
  • pack 是 struct 模块中的一个函数,用于将数据打包为二进制格式。"<B" 是一个格式字符串,表示将 self.proto 打包成一个字节(B)并以小端字节序(<)进行编码。
  • "B":表示一个无符号字节(unsigned char),即一个 8 位的整数。
  • "<":表示小端字节序(Little-Endian)
  • PROTO:b'\x80',标识序列化协议的常量
  • pack("<B", self.proto) 将 self.proto 的值打包成一个字节流。
  • self.proto = 2,被打包为 b'\x02'
  • self.file_write(data) 写入 data

self.save(obj)

obj

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
(Pdb) p obj
{'model_state_dict': OrderedDict([('embedding.weight', tensor([[-0.6647],
[ 0.9021],
[-0.8437],
...,
[-1.1032],
[ 0.5453],
[ 0.5881]])), ('rnn.weight_ih_l0', tensor([[-0.5051],
[ 0.1924]])), ('rnn.weight_hh_l0', tensor([[-0.6455, -0.5423],
[-0.0997, 0.0933]])), ('rnn.bias_ih_l0', tensor([ 0.0648, -0.3999])), ('rnn.bias_hh_l0', tensor([0.0913, 0.1312])), ('fc.weight', tensor([[ 0.1756, 0.3320],
[ 0.0310, -0.0897],
[ 0.1182, 0.0745],
...,
[ 0.4240, -0.1028],
[ 0.1716, -0.0779],
[-0.0436, 0.3810]])), ('fc.bias', tensor([ 0.4243, -0.6157, -0.2059, ..., -0.5051, 0.5370, -0.2196]))])}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def save(self, obj, save_persistent_id=True):
    """Serialize obj into the pickle stream, dispatching on its type."""
    self.framer.commit_frame()

    # Check for persistent id (defined by a subclass)
    pid = self.persistent_id(obj)
    if pid is not None and save_persistent_id:
        self.save_pers(pid)
        return

    # Check the memo: an already-pickled object is emitted as a reference
    # to its memo slot instead of being serialized again.
    x = self.memo.get(id(obj))
    if x is not None:
        self.write(self.get(x[0]))
        return

    rv = NotImplemented
    reduce = getattr(self, "reducer_override", None)
    if reduce is not None:
        rv = reduce(obj)

    if rv is NotImplemented:
        # Check the type dispatch table
        t = type(obj)
        f = self.dispatch.get(t)
        if f is not None:
            f(self, obj)  # Call unbound method with explicit self
            return

        # Check private dispatch table if any, or else
        # copyreg.dispatch_table
        reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
        if reduce is not None:
            rv = reduce(obj)
        else:
            # Check for a class with a custom metaclass; treat as regular
            # class
            if issubclass(t, type):
                self.save_global(obj)
                return

            # Check for a __reduce_ex__ method, fall back to __reduce__
            reduce = getattr(obj, "__reduce_ex__", None)
            if reduce is not None:
                rv = reduce(self.proto)
            else:
                reduce = getattr(obj, "__reduce__", None)
                if reduce is not None:
                    rv = reduce()
                else:
                    raise PicklingError("Can't pickle %r object: %r" %
                                        (t.__name__, obj))

    # Check for string returned by reduce(), meaning "save as global"
    if isinstance(rv, str):
        self.save_global(obj, rv)
        return

    # Assert that reduce() returned a tuple
    if not isinstance(rv, tuple):
        raise PicklingError("%s must return string or tuple" % reduce)

    # Assert that it returned an appropriately sized tuple
    l = len(rv)
    if not (2 <= l <= 6):
        raise PicklingError("Tuple returned by %s must have "
                            "two to six elements" % reduce)

    # Save the reduce() output and finally memoize the object
    self.save_reduce(obj=obj, *rv)

检查pid

Check the memo

1
2
3
4
x = self.memo.get(id(obj))
if x is not None:
self.write(self.get(x[0]))
return
  • 这里x是None,不执行
  • id(obj):128170083121344

执行序列化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Excerpt of Pickler.save: type-based dispatch.
rv = NotImplemented
reduce = getattr(self, "reducer_override", None)
if reduce is not None:
    rv = reduce(obj)

if rv is NotImplemented:
    # Check the type dispatch table
    # Get the type of the object obj
    t = type(obj)
    # Look up the handler registered for type t
    f = self.dispatch.get(t)
    if f is not None:
        f(self, obj)  # Call unbound method with explicit self
        return

1.save_dict(self, obj)

  • t:dict
  • f:save_dict
  • 字典序列化
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    def save_dict(self, obj):

    if self.bin:
    self.write(EMPTY_DICT) b'}'
    else: # proto 0 -- can't use EMPTY_DICT
    self.write(MARK + DICT)

    self.memoize(obj)
    self._batch_setitems(obj.items())

    dispatch[dict] = save_dict
    if PyStringMap is not None:
    dispatch[PyStringMap] = save_dict

将obj存储到memo

  • def memoize
1
2
3
4
5
6
7
def memoize(self, obj):
    """Store obj in the memo so later occurrences are written as references."""
    if self.fast:
        return
    assert id(obj) not in self.memo
    idx = len(self.memo)           # next free memo slot
    self.write(self.put(idx))      # emit the PUT opcode for this slot
    self.memo[id(obj)] = idx, obj  # map object identity -> (slot, obj)
  • memo 是用来缓存已经序列化过的对象,避免重复序列化相同的对象。
  • idx:0 当前长度,memo 中下一个可用的位置
  • self.write(self.put(idx)) 将 idx 存储到输出流中。
  • self.memo[id(obj)] = idx, obj
  1. memo[127373469135808]=0,dict

def _batch_setitems(self, items)

  • 将字典 obj 中的所有键值对通过 _batch_setitems 方法批量处理,通常是进行序列化操作。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def _batch_setitems(self, items):
    # Helper to batch up SETITEMS sequences; proto >= 1 only
    save = self.save
    write = self.write
    # self.bin is True in this walkthrough (binary protocol), so the
    # branch below is skipped.
    if not self.bin:
        for k, v in items:
            save(k)
            save(v)
            write(SETITEM)
        return

    it = iter(items)
    while True:
        # Take up to _BATCHSIZE key/value pairs per iteration.
        tmp = list(islice(it, self._BATCHSIZE))
        n = len(tmp)
        if n > 1:
            write(MARK)
            for k, v in tmp:
                save(k)
                save(v)
            write(SETITEMS)
        elif n:
            pdb.set_trace()  # breakpoint inserted for this walkthrough
            k, v = tmp[0]
            save(k)
            save(v)
            write(SETITEM)
        # else tmp is empty, and we're done
        if n < self._BATCHSIZE:
            return
  • 键值对保存到 tmp 列表中,最多取1000个键值对
  • 在一个无限循环处理完所有的键值对
  • tmp
(Pdb) p items
dict_items([('model_state_dict', OrderedDict([('embedding.weight', tensor([[ 0.7320],
        [-1.1317],
        [ 1.5068],
        ...,
        [ 0.6015],
        [ 0.3011],
        [-1.1067]])), ('rnn.weight_ih_l0', tensor([[-0.3000],
        [-0.5426]])), ('rnn.weight_hh_l0', tensor([[-0.3055, -0.6061],
        [ 0.3116, -0.4966]])), ('rnn.bias_ih_l0', tensor([0.1714, 0.3571])), ('rnn.bias_hh_l0', tensor([-0.6939, -0.2296])), ('fc.weight', tensor([[-0.1484,  0.5944],
        [-0.4861, -0.2301],
        [ 0.5660, -0.4483],
        ...,
        [-0.2943, -0.1674],
        [-0.4656,  0.3088],
        [-0.3656,  0.0259]])), ('fc.bias', tensor([ 0.6055, -0.0741, -0.1650,  ..., -0.1479,  0.0097,  0.2772]))]))])

总结

图片


检查点序列化过程
http://sjx.com/2024/12/10/检查点序列化过程/
作者
sjx
发布于
2024年12月10日
许可协议