diff --git a/textgen/language_modeling/songnet_model.py b/textgen/language_modeling/songnet_model.py
index 5f2b0bc..99b83cf 100644
--- a/textgen/language_modeling/songnet_model.py
+++ b/textgen/language_modeling/songnet_model.py
@@ -31,7 +31,11 @@
 from textgen.language_modeling.songnet_utils import (
     ZHCharTokenizer, s2t, s2xy, s2xy_polish,
     SongNetDataLoader,
-    BOS, EOS,
+    BOS,
+    EOS,
+    PRETRAINED_MODELS,
+    LOCAL_DIR,
+    http_get,
 )
 
 has_cuda = torch.cuda.is_available()
@@ -594,7 +598,7 @@ class SongNetModel:
     def __init__(
             self,
             model_type='songnet',
-            model_name='shibing624/songnet-base-chinese-couplet',
+            model_name='songnet-base-chinese',
             args=None,
             use_cuda=has_cuda,
             cuda_device=-1,
@@ -644,6 +648,17 @@ def __init__(
         self.results = {}
 
         if model_name:
+            bin_path = os.path.join(model_name, 'pytorch_model.bin')
+            if not os.path.exists(bin_path):
+                if model_name in PRETRAINED_MODELS:
+                    local_model_dir = os.path.join(LOCAL_DIR, model_name)
+                    local_bin_path = os.path.join(local_model_dir, 'pytorch_model.bin')
+                    if not os.path.exists(local_bin_path):
+                        url = PRETRAINED_MODELS[model_name]
+                        http_get(url, local_model_dir)
+                    else:
+                        logger.warning(f'Model {bin_path} does not exist, use local model {local_bin_path}')
+                    model_name, bin_path = local_model_dir, local_bin_path
             self.tokenizer = ZHCharTokenizer.from_pretrained(model_name, **kwargs)
             self.model = SongNet(
                 self.tokenizer,
@@ -655,7 +670,7 @@ def __init__(
                 num_layers=self.args.num_layers,
                 smoothing_factor=self.args.smoothing_factor,
             )
-            self.model.load_state_dict(torch.load(os.path.join(model_name, 'pytorch_model.bin')))
+            self.model.load_state_dict(torch.load(bin_path))
 
         self.args.model_type = model_type
         if model_name is None:
diff --git a/textgen/language_modeling/songnet_utils.py b/textgen/language_modeling/songnet_utils.py
index 8492df2..a4b54ab 100644
--- a/textgen/language_modeling/songnet_utils.py
+++ b/textgen/language_modeling/songnet_utils.py
@@ -4,11 +4,18 @@
 @description: 
 """
 
 import os
+import sys
 import random
 import numpy as np
 import torch
-from loguru import logger
+import shutil
+import tarfile
+import zipfile
+import six
+import requests
+
+from tqdm.autonotebook import tqdm
 
 PAD, UNK, BOS, EOS = '<pad>', '<unk>', '<bos>', '<eos>'
 BOC, EOC = '<boc>', '<eoc>'
@@ -18,7 +25,14 @@
 PS = ['<p-1>'] + ['<p' + str(i) + '>' for i in range(512)]  # position
 TS = ['<t-1>'] + ['<t' + str(i) + '>' for i in range(32)]  # other types
 PUNCS = {",", ".", "?", "!", ":", "，", "。", "？", "！", "："}
-
+PRETRAINED_MODELS = {
+    'songnet-base-chinese':
+        'https://github.com/shibing624/pycorrector/releases/download/0.4.5/convseq2seq_correction.tar.gz',
+    'songnet-base-chinese-couplet': '',
+    'songnet-base-chinese-poem': '',
+    'songnet-base-chinese-songci': '',
+}
+LOCAL_DIR = os.path.expanduser('~/.cache/torch/shibing624/')
 
 class ZHCharTokenizer(object):
     def __init__(self, vocab_file, specials=None):
@@ -68,7 +82,7 @@ def token2idx(self, x):
         return self._token2idx.get(x, self.unk_idx)
 
     def __repr__(self):
-        return f"ZHCharTokenizer<_token2idx size:{len(self._token2idx)}>"
+        return f"ZHCharTokenizer"
 
     @classmethod
     def from_pretrained(cls, model_dir, *init_inputs, **kwargs):
@@ -377,3 +391,78 @@ def preprocess_data(line, max_length, min_length):
     if len(ys) < min_length:
         return None
     return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos
+
+
+def http_get(url, path, extract: bool = True):
+    """
+    Downloads a URL to a given path on disc
+    """
+    if os.path.dirname(path) != '':
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+
+    req = requests.get(url, stream=True)
+    if req.status_code != 200:
+        print("Exception when trying to download {}. Response {}".format(url, req.status_code), file=sys.stderr)
+        req.raise_for_status()
+        return
+
+    download_filepath = path + "_part"
+    with open(download_filepath, "wb") as file_binary:
+        content_length = req.headers.get('Content-Length')
+        total = int(content_length) if content_length is not None else None
+        progress = tqdm(unit="B", total=total, unit_scale=True)
+        for chunk in req.iter_content(chunk_size=1024):
+            if chunk:  # filter out keep-alive new chunks
+                progress.update(len(chunk))
+                file_binary.write(chunk)
+
+    os.rename(download_filepath, path)
+    progress.close()
+
+    if extract:
+        data_dir = os.path.dirname(os.path.abspath(path))
+        _extract_archive(path, data_dir, 'auto')
+
+
+def _extract_archive(file_path, path='.', archive_format='auto'):
+    """
+    Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
+    :param file_path: path to the archive file
+    :param path: path to extract the archive file
+    :param archive_format: Archive format to try for extracting the file.
+        Options are 'auto', 'tar', 'zip', and None.
+        'tar' includes tar, tar.gz, and tar.bz files.
+        The default 'auto' is ['tar', 'zip'].
+        None or an empty list will return no matches found.
+    :return: True if a match was found and an archive extraction was completed,
+        False otherwise.
+    """
+    if archive_format is None:
+        return False
+    if archive_format == 'auto':
+        archive_format = ['tar', 'zip']
+    if isinstance(archive_format, six.string_types):
+        archive_format = [archive_format]
+
+    for archive_type in archive_format:
+        if archive_type == 'tar':
+            open_fn = tarfile.open
+            is_match_fn = tarfile.is_tarfile
+        if archive_type == 'zip':
+            open_fn = zipfile.ZipFile
+            is_match_fn = zipfile.is_zipfile
+
+        if is_match_fn(file_path):
+            with open_fn(file_path) as archive:
+                try:
+                    archive.extractall(path)
+                except (tarfile.TarError, RuntimeError,
+                        KeyboardInterrupt):
+                    if os.path.exists(path):
+                        if os.path.isfile(path):
+                            os.remove(path)
+                        else:
+                            shutil.rmtree(path)
+                    raise
+            return True
+    return False