forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
unpickler.h
203 lines (176 loc) · 7.42 KB
/
unpickler.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#pragma once
#include <ATen/core/ivalue.h>
#include <c10/util/ArrayRef.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>
namespace torch {
namespace jit {
using TypeResolver =
std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
using ObjLoader = std::function<
c10::intrusive_ptr<c10::ivalue::Object>(const at::StrongTypePtr&, IValue)>;
class DeserializationStorageContext;
// [unpickler refactor] there is some cruft around PickleOpCode::BUILD,
// PickleOpCode::NEWOBJ, and the last_opcode_ member below that should be
// deleted at some point, the Pickler doesn't produce it and it's only around to
// support models saved before 1.1
class TORCH_API Unpickler {
AT_DISALLOW_COPY_AND_ASSIGN(Unpickler);
using TypeParserT = c10::TypePtr (*)(const std::string&);
public:
// tensors inside the pickle are references to the tensor_table.
// class_resolver is to resolve strong class type, type_resolver_ is
// to resolve any JIT type. class_resolver and type_resolver are not merged
// here because some use cases need to get strong class type that
// type_resolver_ can not return.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Unpickler(
std::function<size_t(char*, size_t)> reader,
TypeResolver type_resolver,
c10::ArrayRef<at::Tensor> tensor_table,
TypeParserT type_parser = defaultTypeParser)
: reader_(std::move(reader)),
tensor_table_(tensor_table),
type_resolver_(std::move(type_resolver)),
use_storage_device_(false),
type_parser_(type_parser),
version_(caffe2::serialize::kProducedFileFormatVersion) {}
Unpickler(
std::function<size_t(char*, size_t)> reader,
TypeResolver type_resolver,
c10::ArrayRef<at::Tensor> tensor_table,
ObjLoader obj_loader,
TypeParserT type_parser = defaultTypeParser)
: reader_(std::move(reader)),
tensor_table_(tensor_table),
type_resolver_(std::move(type_resolver)),
obj_loader_(std::move(obj_loader)),
use_storage_device_(false),
type_parser_(type_parser),
version_(caffe2::serialize::kProducedFileFormatVersion) {}
// tensors inside the pickle contain meta-data, the raw tensor
// dead is retrieved by calling `read_record`.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Unpickler(
std::function<size_t(char*, size_t)> reader,
TypeResolver type_resolver,
ObjLoader obj_loader,
std::function<at::DataPtr(const std::string&)> read_record,
c10::optional<at::Device> device,
bool use_storage_device = false,
TypeParserT type_parser = defaultTypeParser,
std::shared_ptr<DeserializationStorageContext> storage_context = nullptr)
: reader_(std::move(reader)),
tensor_table_(),
type_resolver_(std::move(type_resolver)),
obj_loader_(std::move(obj_loader)),
read_record_(std::move(read_record)),
// NOLINTNEXTLINE(performance-move-const-arg)
device_(std::move(device)),
use_storage_device_(use_storage_device),
type_parser_(type_parser),
storage_context_(std::move(storage_context)),
version_(caffe2::serialize::kProducedFileFormatVersion) {}
// consume the pickle stream, producing an IValue from the contents.
// Type Tags: the pickler will restore the type tags on
// List and Dict objects when possible IValue is an Object.
// Otherwise, Dict and List objects will end up with Any as their tag.
// If you know the type of the ivalue, tags can be restored with
// restoreAccurateTypeTags
IValue parse_ivalue();
// [type tag serialization]
// This is used to determine whether to restore type tags be recursively
// descending into the returned stack object (if version_number <= 2), or
// if version_number >= 3, to use the type strings included in the pickle
// archive for container types. By default this is set to
// `kProducedFileFormatVersion` so unless you're loading a pickle file
// from alongside a corresponding `version` file, you don't need to set
// the version manually.
void set_version(uint64_t version_number) {
version_ = version_number;
}
static c10::TypePtr defaultTypeParser(const std::string& str) {
ScriptTypeParser parser;
return parser.parseType(str);
}
private:
// No arguments ensures that a template argument must be specified
// so that the number of bytes read / type read is explicit
template <typename T>
T read() {
T item;
if (sizeof(T) <= buffer_remaining_) {
// Fast path: entirely from buffer.
memcpy(&item, buffer_.data() + buffer_pos_, sizeof(T));
buffer_remaining_ -= sizeof(T);
buffer_pos_ += sizeof(T);
} else {
// Don't over-template the slow path, to avoid code size bloat.
readSlowWithBuffer(reinterpret_cast<char*>(&item), sizeof(T));
}
return item;
}
void readSlowWithBuffer(char* dest, size_t sz);
std::string readBytes(size_t num_bytes);
double readFloat();
void readGlobal(
const std::string& module_name,
const std::string& class_name);
void rebuildTensor(bool quantized);
void rebuildTensorFromTypeV2();
void rebuildSparseTensor();
#ifdef USE_DISTRIBUTED
void rebuildRRef();
#endif
PickleOpCode readInstruction();
PickleOpCode readOpCode() {
return static_cast<PickleOpCode>(read<uint8_t>());
}
std::string readString();
void readList(IValue list_ivalue);
void readListElements(IValue list_ivalue, size_t start);
void setInput(size_t memo_id);
void run();
// Returns the number of bytes read. This should statefully
// remember the position. Don't call reader_ directly.
std::function<size_t(char*, size_t)> reader_;
// Small buffer to avoid calling reader_ on a per-byte basis.
std::array<char, 256> buffer_;
size_t buffer_pos_{0};
size_t buffer_remaining_{0};
std::vector<IValue> stack_;
// globals are represented on the stack as IValue integer indices
// into this list
std::vector<std::function<void(void)>> globals_;
std::vector<IValue> memo_table_;
std::vector<size_t> marks_;
c10::ArrayRef<at::Tensor> tensor_table_;
// When deserializing types on lists and dicts, cache the type here
// so we don't have to parse the same type multiple times. Strings
// are already de-duplicated and replaced with BINGETs in the
// pickler, so we can just use the actual data pointer of each string.
std::unordered_map<std::string, c10::TypePtr> type_cache_;
// optionally nullptr, needs to be present for creating classes
TypeResolver type_resolver_;
ObjLoader obj_loader_;
IValue empty_tuple_;
std::function<at::DataPtr(const std::string&)> read_record_;
c10::optional<at::Device> device_;
// When set to true, Unpickler will ignore the pickled device and use the
// device of the DataPtr returned by the read_record_ function. The default
// value of this flag is false.
const bool use_storage_device_;
TypeParserT type_parser_{defaultTypeParser};
// Used for torch.package to enable sharing of storages across
// ScriptModules and eager modules
std::shared_ptr<DeserializationStorageContext> storage_context_;
// See [type tag serialization]
uint64_t version_;
// See [NOTE] skip_next_read_global
uint8_t skip_next_read_global = 0;
};
void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag);
} // namespace jit
} // namespace torch