From 15a7a7c8ea06a9d6e1a6db7c5dc2be8ed5e6b314 Mon Sep 17 00:00:00 2001 From: Ievgen Popovych Date: Thu, 12 Jan 2023 19:57:20 +0200 Subject: [PATCH] Refactor model download logic This eases on repetition a bit and makes it easier to use utils module to download (by offloading more tasks to it). Signed-off-by: Ievgen Popovych --- easyocr/easyocr.py | 41 ++++++++++++----------------------------- easyocr/utils.py | 14 +++++++++++--- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/easyocr/easyocr.py b/easyocr/easyocr.py index 4ef943401..cbb619eac 100644 --- a/easyocr/easyocr.py +++ b/easyocr/easyocr.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from .recognition import get_recognizer, get_text -from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\ +from .utils import group_text_box, get_image_list, file_md5_matches, get_paragraph,\ download_and_unzip, printProgressBar, diff, reformat_input,\ make_rotated_img_list, set_result_with_confidence,\ reformat_input_batched @@ -167,24 +167,14 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None, model_path = os.path.join(self.model_storage_directory, model['filename']) # check recognition model file if recognizer: - if os.path.isfile(model_path) == False: + if not os.path.isfile(model_path) or not file_md5_matches(model_path, model['md5sum']): if not self.download_enabled: - raise FileNotFoundError("Missing %s and downloads disabled" % model_path) + raise FileNotFoundError("Missing or corrupted {} and downloading is disabled".format(model_path)) LOGGER.warning('Downloading recognition model, please wait. ' 'This may take several minutes depending upon your network connection.') - download_and_unzip(model['url'], model['filename'], self.model_storage_directory, verbose) - assert calculate_md5(model_path) == model['md5sum'], corrupt_msg + download_and_unzip(model['url'], model['filename'], model['md5sum'], + self.model_storage_directory, verbose) LOGGER.info('Download complete.') - elif calculate_md5(model_path) != model['md5sum']: - if not self.download_enabled: - raise FileNotFoundError("MD5 mismatch for %s and downloads disabled" % model_path) - LOGGER.warning(corrupt_msg) - os.remove(model_path) - LOGGER.warning('Re-downloading the recognition model, please wait. ' - 'This may take several minutes depending upon your network connection.') - download_and_unzip(model['url'], model['filename'], self.model_storage_directory, verbose) - assert calculate_md5(model_path) == model['md5sum'], corrupt_msg - LOGGER.info('Download complete') self.setLanguageList(lang_list, model) else: # user-defined model @@ -240,25 +230,18 @@ def getDetectorPath(self, detect_network): raise RuntimeError("Unsupport detector network. Support networks are craft and dbnet18.") self.get_textbox = get_textbox self.get_detector = get_detector - corrupt_msg = 'MD5 hash mismatch, possible file corruption' detector_path = os.path.join(self.model_storage_directory, self.detection_models[self.detect_network]['filename']) - if os.path.isfile(detector_path) == False: + if not os.path.isfile(detector_path) or \ + not file_md5_matches(detector_path, self.detection_models[self.detect_network]['md5sum']): if not self.download_enabled: - raise FileNotFoundError("Missing %s and downloads disabled" % detector_path) + raise FileNotFoundError("Missing or corrupted {} and downloading is disabled".format(detector_path)) LOGGER.warning('Downloading detection model, please wait. ' 'This may take several minutes depending upon your network connection.') - download_and_unzip(self.detection_models[self.detect_network]['url'], self.detection_models[self.detect_network]['filename'], self.model_storage_directory, self.verbose) - assert calculate_md5(detector_path) == self.detection_models[self.detect_network]['md5sum'], corrupt_msg + download_and_unzip(self.detection_models[self.detect_network]['url'], + self.detection_models[self.detect_network]['filename'], + self.detection_models[self.detect_network]['md5sum'], + self.model_storage_directory, self.verbose) LOGGER.info('Download complete') - elif calculate_md5(detector_path) != self.detection_models[self.detect_network]['md5sum']: - if not self.download_enabled: - raise FileNotFoundError("MD5 mismatch for %s and downloads disabled" % detector_path) - LOGGER.warning(corrupt_msg) - os.remove(detector_path) - LOGGER.warning('Re-downloading the detection model, please wait. ' - 'This may take several minutes depending upon your network connection.') - download_and_unzip(self.detection_models[self.detect_network]['url'], self.detection_models[self.detect_network]['filename'], self.model_storage_directory, self.verbose) - assert calculate_md5(detector_path) == self.detection_models[self.detect_network]['md5sum'], corrupt_msg else: raise RuntimeError("Unsupport detector network. Support networks are {}.".format(', '.join(self.support_detection_network))) diff --git a/easyocr/utils.py b/easyocr/utils.py index 64435cfdb..4dbf3f2af 100644 --- a/easyocr/utils.py +++ b/easyocr/utils.py @@ -580,20 +580,28 @@ def get_image_list(horizontal_list, free_list, img, model_height = 64, sort_outp image_list = sorted(image_list, key=lambda item: item[0][0][1]) # sort by vertical position return image_list, max_width -def download_and_unzip(url, filename, model_storage_directory, verbose=True): +def download_and_unzip(url, filename, md5, model_storage_directory, verbose=True): + desired_file = os.path.join(model_storage_directory, filename) + if os.path.isfile(desired_file): + if file_md5_matches(desired_file, md5): + return + else: + os.remove(desired_file) zip_path = os.path.join(model_storage_directory, 'temp.zip') reporthook = printProgressBar(prefix='Progress:', suffix='Complete', length=50) if verbose else None urlretrieve(url, zip_path, reporthook=reporthook) with ZipFile(zip_path, 'r') as zipObj: zipObj.extract(filename, model_storage_directory) os.remove(zip_path) + if not file_md5_matches(desired_file, md5): + raise Exception('MD5 hash mismatch after download, possible file corruption') -def calculate_md5(fname): +def file_md5_matches(fname, expected_md5): hash_md5 = hashlib.md5() with open(fname, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) - return hash_md5.hexdigest() + return hash_md5.hexdigest() == expected_md5 def diff(input_list): return max(input_list)-min(input_list)