diff --git a/hanlp/common/torch_component.py b/hanlp/common/torch_component.py index 6d8e07f6a..f808fd02d 100644 --- a/hanlp/common/torch_component.py +++ b/hanlp/common/torch_component.py @@ -3,6 +3,7 @@ # Date: 2020-05-08 21:20 import logging import os +import pickle import re import time from abc import ABC, abstractmethod @@ -97,7 +98,10 @@ def load_weights(self, save_dir, filename='model.pt', **kwargs): save_dir = get_resource(save_dir) filename = os.path.join(save_dir, filename) # flash(f'Loading model: {filename} [blink]...[/blink][/yellow]') - self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False) + try: + self.model_.load_state_dict(torch.load(filename, map_location='cpu', weights_only=True), strict=False) + except pickle.UnpicklingError: + self.model_.load_state_dict(torch.load(filename, map_location='cpu', weights_only=False), strict=False) # flash('') def save_config(self, save_dir, filename='config.json'):