Skip to content

Commit

Permalink
Merge pull request #16985 from velconia/local_rel_1_4_fix_save_load
Browse files Browse the repository at this point in the history
fix dygraph save load
  • Loading branch information
velconia authored Apr 19, 2019
2 parents eca2be5 + d72db90 commit 3df4cbf
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 46 deletions.
67 changes: 29 additions & 38 deletions python/paddle/fluid/dygraph/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def save_persistables(vardict, dirname, filename=None):
_save_var_to_file(vardict, dirname, filename)


def load_persistables(vardict, dirname, filename=None):
def load_persistables(dirname):
"""
This function trys to load persistable variables from the folder
`dirname` or the file `filename`.
Expand All @@ -86,11 +86,7 @@ def load_persistables(vardict, dirname, filename=None):
the file name.
Args:
vardict(dict of Parameters): The parameters will be loaded.
dirname(str): The directory path.
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
saved in differnet files, set it to None.
Default: None
Returns:
dict: The parameter-dict resumed from file
Expand All @@ -104,10 +100,7 @@ def load_persistables(vardict, dirname, filename=None):
param_1 = param_dict['PtbModel_0.w_1']
"""
if isinstance(vardict, collections.OrderedDict):
return _load_var_from_file(vardict, dirname, filename)

return {}
return _load_var_from_file(dirname)


def _save_var_to_file(stat_dict, file_dir, file_name):
Expand Down Expand Up @@ -139,41 +132,39 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
})


def _load_var_from_file(stat_dict, file_dir, file_name):
def _load_var_from_file(file_dir):
def walk_filename(file_dir):
base_path = os.path.join(file_dir)
var_name_list = []
if os.path.exists(base_path):
for dirpath, dirnames, filenames in os.walk(base_path):
pt = dirpath.replace(base_path, "", 1)
if pt.startswith("/") or pt.startswith("\\"):
pt = pt[1:]
for fth_name in filenames:
if fth_name[0] != '.':
name_path = os.path.join(pt, fth_name)
if "\\" in name_path:
name_path = name_path.replace("\\", "/")
var_name_list.append(name_path)

return var_name_list

load_block = default_main_program().global_block()
load_var_map = {}

for var_key, each_var in stat_dict.items():
assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW:
continue
new_var = _clone_var_in_block_(load_block, each_var)
if file_name is None:
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [new_var]},
attrs={
'file_path': os.path.join(file_dir,
os.path.normpath(each_var.name))
})

load_var_map[new_var.name] = new_var

if file_name is not None:
load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])

file_var_list = walk_filename(file_dir)
for var_name in file_var_list:
new_var = Variable(block=load_block, name=var_name)
load_block.append_op(
type='load_combine',
type='load',
inputs={},
outputs={"Out": load_var_list},
outputs={'Out': [new_var]},
attrs={
'file_path': os.path.join(file_dir, os.path.normpath(file_name))
'file_path': os.path.join(file_dir,
os.path.normpath(new_var.name))
})
for res_var in load_var_list:
load_var_map[res_var.name] = res_var

load_var_map[new_var.name] = new_var

return load_var_map

Expand Down
10 changes: 9 additions & 1 deletion python/paddle/fluid/dygraph/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, name_scope, dtype=core.VarDesc.VarType.FP32):
self._dtype = dtype
self._parameters = collections.OrderedDict()
self._sub_layers = collections.OrderedDict()
self._loaddict_holder = collections.OrderedDict()

self._helper = LayerObjectHelper(self._full_name)

Expand Down Expand Up @@ -193,6 +194,9 @@ def add_parameter(self, name, parameter):
"""
assert isinstance(parameter, framework.Parameter)
self._parameters[name] = parameter
if parameter.name in self._loaddict_holder:
self._parameters[name] = self._loaddict_holder[parameter.name]
parameter = self._loaddict_holder[parameter.name]
return parameter

def __getattr__(self, name):
Expand All @@ -207,7 +211,10 @@ def __setattr__(self, name, value):
if params is None:
raise ValueError(
"super(YourLayer, self).__init__() should be called first")
params[name] = value
if value.name in self._loaddict_holder:
params[name] = self._loaddict_holder[value.name]
else:
params[name] = value
elif isinstance(value, core.Layer):
layers = self.__dict__.get('_sub_layers', None)
if layers is None:
Expand Down Expand Up @@ -244,6 +251,7 @@ def state_dict(self, destination=None, prefix='', include_sublayers=True):
return destination

def load_dict(self, stat_dict, include_sublayers=True):
self._loaddict_holder = stat_dict
for name, item in self.__dict__.get('_parameters', None).items():
if item.name in stat_dict:
var = item._ivar.value()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,11 @@ def test_save_load_persistables(self):
for param in mnist.parameters():
dy_param_init_value[param.name] = param.numpy()

mnist.load_dict(
fluid.dygraph.load_persistables(mnist.state_dict(),
"save_dir"))

restore = mnist.parameters()
restore = fluid.dygraph.load_persistables("save_dir")
mnist.load_dict(restore)

self.assertEqual(len(dy_param_init_value), len(restore))
for value in restore:
for ky, value in restore.items():
self.assertTrue(
np.allclose(value.numpy(), dy_param_init_value[
value.name]))
Expand All @@ -158,7 +155,7 @@ def test_save_load_persistables(self):

step += 1

if step > 20:
if step > 10:
break


Expand Down

0 comments on commit 3df4cbf

Please sign in to comment.