diff --git a/README.md b/README.md index 8dd80d34a..7f30c42cd 100644 --- a/README.md +++ b/README.md @@ -1,72 +1,9 @@ # EasyOCR - -[![PyPI Status](https://badge.fury.io/py/easyocr.svg)](https://badge.fury.io/py/easyocr) -[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/JaidedAI/EasyOCR/blob/master/LICENSE) -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.to/easyocr) -[![Tweet](https://img.shields.io/twitter/url/https/github.com/JaidedAI/EasyOCR.svg?style=social)](https://twitter.com/intent/tweet?text=Check%20out%20this%20awesome%20library:%20EasyOCR%20https://github.com/JaidedAI/EasyOCR) -[![Twitter](https://img.shields.io/badge/twitter-@JaidedAI-blue.svg?style=flat)](https://twitter.com/JaidedAI) - -Ready-to-use OCR with 80+ [supported languages](https://www.jaided.ai/easyocr) and all popular writing scripts including: Latin, Chinese, Arabic, Devanagari, Cyrillic, etc. - -[Try Demo on our website](https://www.jaided.ai/easyocr) - -Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/tomofi/EasyOCR) - - +- This repository is forked from JaidedAI/EasyOCR in order to use it for custom training purposes ## What's new -- 25 May 2023 - Version 1.7.0 - - Add Apple Silicon support (thanks[@rayeesoft](https://github.com/rayeesoft) and [@ArtemBernatskyy](https://github.com/ArtemBernatskyy), see [PR](https://github.com/JaidedAI/EasyOCR/pull/1004)) - - Fix several compatibilities -- 15 September 2022 - Version 1.6.2 - - Add CPU support for DBnet - - DBnet will only be compiled when users initialize DBnet detector. -- 1 September 2022 - Version 1.6.1 - - Fix DBnet path bug for Windows - - Add new built-in model `cyrillic_g2`. This model is a new default for Cyrillic script. -- 24 August 2022 - Version 1.6.0 - - Restructure code to support alternative text detectors. - - Add detector `DBnet`, see [paper](https://arxiv.org/abs/2202.10304v1). It can be used by initializing like this `reader = easyocr.Reader(['en'], detect_network = 'dbnet18')`. -- 2 June 2022 - Version 1.5.0 - - Add trainer for CRAFT detection model (thanks[@gmuffiness](https://github.com/gmuffiness), see [PR](https://github.com/JaidedAI/EasyOCR/pull/739)) -- 9 April 2022 - Version 1.4.2 - - Update dependencies (opencv and pillow issues) -- 11 September 2021 - Version 1.4.1 - - Add trainer folder - - Add `readtextlang` method (thanks[@arkya-art](https://github.com/arkya-art), see [PR](https://github.com/JaidedAI/EasyOCR/pull/525)) - - Extend `rotation_info` argument to support all possible angles (thanks[abde0103](https://github.com/abde0103), see [PR](https://github.com/JaidedAI/EasyOCR/pull/515)) -- 29 June 2021 - Version 1.4 - - [Instructions](https://github.com/JaidedAI/EasyOCR/blob/master/custom_model.md) on training/using custom recognition models - - Example [dataset](https://www.jaided.ai/easyocr/modelhub) for model training - - Batched image inference for GPUs (thanks [@SamSamhuns](https://github.com/SamSamhuns), see [PR](https://github.com/JaidedAI/EasyOCR/pull/458)) - - Vertical text support (thanks [@interactivetech](https://github.com/interactivetech)). This is for rotated text, not to be confused with vertical Chinese or Japanese text. 
(see [PR](https://github.com/JaidedAI/EasyOCR/pull/450)) - - Output in dictionary format (thanks [@A2va](https://github.com/A2va), see [PR](https://github.com/JaidedAI/EasyOCR/pull/441)) -- 30 May 2021 - Version 1.3.2 - - Faster greedy decoder (thanks [@samayala22](https://github.com/samayala22)) - - Fix bug when a text box's aspect ratio is disproportional (thanks [iQuartic](https://iquartic.com/) for bug report) -- 20 April 2021 - Version 1.3.1 - - Add support for PIL image (thanks [@prays](https://github.com/prays)) - - Add Tajik language (tjk) - - Update argument setting for command line - - Add `x_ths` and `y_ths` to control merging behavior when `paragraph=True` -- 21 March 2021 - Version 1.3 - - Second-generation models: multiple times smaller size, multiple times faster inference, additional characters and comparable accuracy to the first generation models. - EasyOCR will choose the latest model by default but you can also specify which model to use by passing `recog_network` argument when creating a `Reader` instance. - For example, `reader = easyocr.Reader(['en','fr'], recog_network='latin_g1')` will use the 1st generation Latin model - - List of all models: [Model hub](https://www.jaided.ai/easyocr/modelhub) - -- [Read all release notes](https://github.com/JaidedAI/EasyOCR/blob/master/releasenotes.md) - -## What's coming next -- Handwritten text support - -## Examples - -![example](examples/example.png) - -![example2](examples/example2.png) - -![example3](examples/example3.png) - +- trainer.py and trainer.ipynb created/modified for custom training +- there will be a detailed description .txt file for dataset usage +- preprocess and usage of EasyOCR scripts has been added to the scripts folder ## Installation @@ -81,12 +18,12 @@ pip install easyocr For the latest development release: ``` bash -pip install git+https://github.com/JaidedAI/EasyOCR.git +pip install git+https://github.com/yilmazmusa08/EasyOCR.git ``` Note 1: For Windows, please install torch and torchvision first by following the official instructions here https://pytorch.org. On the pytorch website, be sure to select the right CUDA version you have. If you intend to run on CPU mode only, select `CUDA = None`. -Note 2: We also provide a Dockerfile [here](https://github.com/JaidedAI/EasyOCR/blob/master/Dockerfile). +Note 2: We also provide a Dockerfile [here](https://github.com/yilmazmusa08/EasyOCR/blob/master/Dockerfile). ## Usage @@ -143,11 +80,6 @@ For more information, read the [tutorial](https://www.jaided.ai/easyocr/tutorial $ easyocr -l ch_sim en -f chinese.jpg --detail=1 --gpu=True ``` -## Train/use your own model - -For recognition model, [Read here](https://github.com/JaidedAI/EasyOCR/blob/master/custom_model.md). - -For detection model (CRAFT), [Read here](https://github.com/JaidedAI/EasyOCR/blob/master/trainer/craft/README.md). 
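+The sketch below shows one way a recognition model trained with this fork might be loaded back into EasyOCR for inference. It follows EasyOCR's custom-model convention (a `recog_network` name with matching `.pth`, `.yaml` and `.py` files); the `custom_model` name, the directory paths and the test image are placeholders, not something this fork sets up automatically.
+
+``` python
+import easyocr
+
+# Assumption: custom_model.pth, custom_model.yaml and custom_model.py have been
+# copied into the two directories below, following EasyOCR's custom model docs.
+reader = easyocr.Reader(
+    ['en'],
+    recog_network='custom_model',
+    model_storage_directory='model',
+    user_network_directory='user_network',
+    gpu=False,
+)
+print(reader.readtext('tests/inputs/example.jpg', detail=1))
+```
+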
## Implementation Roadmap diff --git a/scripts/cv2_fix.py b/scripts/cv2_fix.py new file mode 100644 index 000000000..35c4d23c6 --- /dev/null +++ b/scripts/cv2_fix.py @@ -0,0 +1,82 @@ +import cv2 +from PIL import Image +import os +import numpy as np +from typing import Tuple, Union +import math + +angle_zero = 0.16 +target_size = (400, 72) + +# Path to the directory containing image files +dir_path = os.path.join("..", "tests/inputs") +out_dir = os.path.join("..", "tests/outputs") + +# Create the directories if they do not exist +if not os.path.exists(dir_path): + os.makedirs(dir_path) + +if not os.path.exists(out_dir): + os.makedirs(out_dir) + +def rotate( + image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]] +) -> np.ndarray: + old_width, old_height = image.shape[:2] + angle_radian = math.radians(angle) + width = abs(np.sin(angle_radian) * old_height) + abs(np.cos(angle_radian) * old_width) + height = abs(np.sin(angle_radian) * old_width) + abs(np.cos(angle_radian) * old_height) + + image_center = tuple(np.array(image.shape[1::-1]) / 2) + rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) + rot_mat[1, 2] += (width - old_width) / 2 + rot_mat[0, 2] += (height - old_height) / 2 + return cv2.warpAffine(image, rot_mat, (int(round(height)), int(round(width))), borderValue=background) + +# Loop through all image files in the directory +for filename in os.listdir(dir_path): + # Load the image and convert it to a NumPy array + image_path = os.path.join(dir_path, filename) + image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) + w, h = image.shape + + # Crop the left 12.5% of the image + crop_width = image.shape[1] // 8 + image_cropped = image[:, :crop_width] + + # Split the cropped image into top and bottom halves + height, width = image_cropped.shape + top_half = image_cropped[:height//2, :] + bottom_half = image_cropped[height//2:, :] + + # Compute the mean intensity values of the top and bottom halves + top_whiteness = np.mean(255 - top_half) + bottom_whiteness = np.mean(255 - bottom_half) + + # Print the aspect ratio, top whiteness, and bottom whiteness + aspect_ratio = round(float(w) / float(h), 2) + print(f"Aspect Ratio: {aspect_ratio}, Top Whiteness: {top_whiteness}, Bottom Whiteness: {bottom_whiteness}") + + + image_np = np.array(image) + image_np = np.rot90(image_np, k=1) + + if aspect_ratio >= angle_zero: + angle = math.degrees(math.atan(aspect_ratio - angle_zero)) - 90 + image_rot1 = rotate(image_np, angle, (0, 0, 0)) + image_rot2 = rotate(image_np, 180-angle, (0, 0, 0)) + + # Compare the mean intensity values of the top and bottom halves + threshold = 0 + if top_whiteness > bottom_whiteness * (1 + threshold): + image_np = image_rot2 + print(f"{filename}: top half is brighter than bottom half") + elif bottom_whiteness > top_whiteness * (1 + threshold): + print(f"{filename}: bottom half is brighter than top half") + image_np = image_rot1 + else: + print(f"{filename}: top and bottom halves have similar brightness") + # Save the resized image to a file + out_path = os.path.join(out_dir, f'{filename}_processed.jpg') + Image.fromarray(image_np).save(out_path) + diff --git a/scripts/final_predict.py b/scripts/final_predict.py new file mode 100644 index 000000000..b80fd0777 --- /dev/null +++ b/scripts/final_predict.py @@ -0,0 +1,121 @@ +import cv2 +import easyocr +from PIL import Image +import os +import numpy as np +from typing import Tuple, Union +import math +import matplotlib +matplotlib.use('TkAgg') +import matplotlib.pyplot as plt + +angle_zero = 0.16 
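+# angle_zero is an empirically chosen offset (assumed to be tuned for these inputs) for the
+# aspect ratio computed below as image.shape[0] / image.shape[1]; images at or above it are
+# deskewed by an angle of math.degrees(math.atan(aspect_ratio - angle_zero)) - 90.
+# target_size is defined here but is not referenced later in this script.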
+target_size = (400, 72) + +# Initialize the Reader object with the languages you want to recognize and the desired parameters +reader = easyocr.Reader(['en', 'tr'], gpu=True) + +reader.model_path = './EasyOCR-master/trainer/saved_models/custom_model/best_accuracy.pth' + +# Path to the directory containing image files +dir_path = os.path.join("..", "tests/inputs") +out_dir = os.path.join("..", "tests/outputs") + +# Create the directories if they do not exist +if not os.path.exists(dir_path): + os.makedirs(dir_path) + +if not os.path.exists(out_dir): + os.makedirs(out_dir) + +def rotate( + image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]] +) -> np.ndarray: + old_width, old_height = image.shape[:2] + angle_radian = math.radians(angle) + width = abs(np.sin(angle_radian) * old_height) + abs(np.cos(angle_radian) * old_width) + height = abs(np.sin(angle_radian) * old_width) + abs(np.cos(angle_radian) * old_height) + + image_center = tuple(np.array(image.shape[1::-1]) / 2) + rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) + rot_mat[1, 2] += (width - old_width) / 2 + rot_mat[0, 2] += (height - old_height) / 2 + return cv2.warpAffine(image, rot_mat, (int(round(height)), int(round(width))), borderValue=background) + +# Loop through all image files in the directory +for filename in os.listdir(dir_path): + # Load the image and convert it to a NumPy array + image_path = os.path.join(dir_path, filename) + image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) + w, h = image.shape + + # Crop the left 12.5% of the image + crop_width = image.shape[1] // 8 + image_cropped = image[:, :crop_width] + + # Split the cropped image into top and bottom halves + height, width = image_cropped.shape + top_half = image_cropped[:height//2, :] + bottom_half = image_cropped[height//2:, :] + + # Compute the mean intensity values of the top and bottom halves + top_whiteness = np.mean(255 - top_half) + bottom_whiteness = np.mean(255 - bottom_half) + + # Print the aspect ratio, top whiteness, and bottom whiteness + aspect_ratio = round(float(w) / float(h), 2) + print(f"Aspect Ratio: {aspect_ratio}, Top Whiteness: {top_whiteness}, Bottom Whiteness: {bottom_whiteness}") + + + image_np = np.array(image) + image_np = np.rot90(image_np, k=1) + + if aspect_ratio >= angle_zero: + angle = math.degrees(math.atan(aspect_ratio - angle_zero)) - 90 + image_rot1 = rotate(image_np, angle, (0, 0, 0)) + image_rot2 = rotate(image_np, 180-angle, (0, 0, 0)) + + # Compare the mean intensity values of the top and bottom halves + threshold = 0 + if top_whiteness > bottom_whiteness * (1 + threshold): + image_np = image_rot2 + print(f"{filename}: top half is brighter than bottom half") + elif bottom_whiteness > top_whiteness * (1 + threshold): + print(f"{filename}: bottom half is brighter than top half") + image_np = image_rot1 + else: + print(f"{filename}: top and bottom halves have similar brightness") + + + # Perform OCR on the image + result = reader.readtext(image_np, allowlist='0123456789:', + contrast_ths=0.1, adjust_contrast=0.5, + text_threshold=0.7, low_text=0.4, link_threshold=0.4, canvas_size=2560, + mag_ratio=1.0, slope_ths=0.1, ycenter_ths=0.5, height_ths=0.5, + width_ths=0.5, add_margin=0.1, x_ths=1.0, y_ths=0.5, + decoder='greedy', beamWidth=50, batch_size=2) + + # Plot the image with bounding boxes and recognized text + fig, ax = plt.subplots(figsize=(10,10)) + ax.imshow(image_np, cmap='gray') + + + for r in result: + bbox = r[0] + text = r[1] + confidence = r[2] + x1, y1 = bbox[0] + x2, y2 = 
bbox[1] + x3, y3 = bbox[2] + x4, y4 = bbox[3] + poly = plt.Polygon(bbox, facecolor=None, edgecolor='green', linewidth=2, fill=False) + ax.add_patch(poly) + ax.text(x1, y1-10, f'{text} ({confidence:.2f})', fontsize=12, color='red', bbox=dict(facecolor='white', alpha=0.8, edgecolor='none')) + + + # Save the plotted image to a file + plt.savefig(os.path.join(out_dir, f'{filename}')) + + # Close the figure to release memory resources + plt.close(fig) + diff --git a/trainer/GPU_clear.py b/trainer/GPU_clear.py new file mode 100644 index 000000000..ed8ac5b89 --- /dev/null +++ b/trainer/GPU_clear.py @@ -0,0 +1,3 @@ +import torch + +torch.cuda.empty_cache() \ No newline at end of file diff --git a/trainer/all_data/folder.txt b/trainer/all_data/folder.txt index 8e455ad77..f9006aab6 100644 --- a/trainer/all_data/folder.txt +++ b/trainer/all_data/folder.txt @@ -1 +1 @@ -place dataset folder here +There should be a "train" and a "val" folder, each containing the images and a "labels.csv" file; "labels.csv" should have a "filename" column with the image file names and a "words" column with the corresponding label text. diff --git a/trainer/config_files/model.ini b/trainer/config_files/model.ini new file mode 100644 index 000000000..1f53556ea --- /dev/null +++ b/trainer/config_files/model.ini @@ -0,0 +1,21 @@ + Transformation: This parameter determines the preprocessing the model applies to the images, such as resizing, cropping or color adjustment. For example, you may want to add a resizing step. + + FeatureExtraction: This parameter determines the network architecture used to extract image features, typically a convolutional neural network (CNN). For example, you can try different architectures such as ResNet, EfficientNet, Inception or VGG. + + SequenceModeling: This parameter determines how the character sequence is modeled; recurrent neural networks such as LSTM are usually used at this stage. For example, you can try alternatives such as GRU or Convolutional Sequence to Sequence Learning. + + Prediction: This parameter determines the method the model uses to decide on the characters. Different methods such as Connectionist Temporal Classification (CTC) or Attn can be used. + + num_fiducial: This parameter is the fixed number of fiducial points used to compute the geometric transformation of the images. + + input_channel: This parameter determines how many channels the input image has. It is usually set to 1 for single-channel grayscale images and 3 for color images. + + output_channel: This parameter determines the number of output channels produced by FeatureExtraction. + + hidden_size: This parameter determines the hidden layer size used in SequenceModeling. + + decode: This parameter determines the character prediction method used in the Prediction stage. + + freeze_FeatureFxtraction: This parameter determines whether the weights of the FeatureExtraction network are frozen during training. For example, to reduce training time you can use the weights of a pretrained VGG network and freeze them. + + freeze_SequenceModeling: This parameter determines whether the weights of the SequenceModeling network are frozen during training. For example, for faster training you can use the weights of a pretrained LSTM network and freeze them.
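The dataset layout described in trainer/all_data/folder.txt above can be sanity-checked before training. The following is a minimal sketch: the `trainer/all_data` root and the `filename`/`words` column names are assumptions taken from folder.txt and from the CSV-reading code in trainer.py further down, whose read_csv call it mirrors.

``` python
import os
import pandas as pd

root = "trainer/all_data"  # assumed dataset root: train/ and val/ each hold images plus labels.csv

for split in ("train", "val"):
    csv_path = os.path.join(root, split, "labels.csv")
    # Same read_csv call as get_config() in trainer.py, so the columns match what training expects.
    df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',
                     usecols=['filename', 'words'], keep_default_na=False)
    # Flag rows whose image file is missing from the split folder.
    missing = [f for f in df['filename']
               if not os.path.exists(os.path.join(root, split, f))]
    print(f"{split}: {len(df)} labels, {len(missing)} missing image files")
```

A check like this would catch the kind of problem that aborts the notebook run below, where a file listed in labels.csv (train/57.28.png) is missing from disk.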
\ No newline at end of file diff --git a/trainer/config_files/train.yaml b/trainer/config_files/train.yaml new file mode 100644 index 000000000..8c5938a9f --- /dev/null +++ b/trainer/config_files/train.yaml @@ -0,0 +1,45 @@ +number: '0123456789' +symbol: "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ €" +lang_char: 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +experiment_name: 'custom_model' +train_data: './EasyOCR-master/trainer/all_data' +valid_data: './EasyOCR-master/trainer/all_data/val' +manualSeed: 1111 +workers: 10 +batch_size: 32 +num_iter: 200000 +valInterval: 1000 +saved_model: '.EasyOCR-master/trainer/saved_models/best_accuracy.pth' +FT: False +optim: False # default is Adadelta +lr: 0.1 +beta1: 0.9 +rho: 0.95 +eps: 0.00000001 +grad_clip: 5 +# Data processing +select_data: '.EasyOCR-master/trainer/all_data/train' # dataset name +batch_ratio: '876-443' +total_data_usage_ratio: 0.15 +batch_max_length: 6 # total digits of the numbers + dot + comma +imgH: 72 +imgW: 400 +rgb: False +sensitive: True +PAD: True +contrast_adjust: 0.0 +data_filtering_off: False +# Model Architecture +Transformation: 'None' +FeatureExtraction: 'VGG' +SequenceModeling: 'BiLSTM' +Prediction: 'CTC' +num_fiducial: 16 +input_channel: 1 +output_channel: 256 +hidden_size: 256 +decode: 'greedy' +new_prediction: False +freeze_FeatureFxtraction: False +freeze_SequenceModeling: False + diff --git a/trainer/saved_models/best_accuracy.pth b/trainer/saved_models/best_accuracy.pth new file mode 100644 index 000000000..141a7997d Binary files /dev/null and b/trainer/saved_models/best_accuracy.pth differ diff --git a/trainer/saved_models/best_model.pth b/trainer/saved_models/best_model.pth new file mode 100644 index 000000000..141a7997d Binary files /dev/null and b/trainer/saved_models/best_model.pth differ diff --git a/trainer/train.py b/trainer/train.py index e0066f3d0..3bc6a1772 100644 --- a/trainer/train.py +++ b/trainer/train.py @@ -271,8 +271,8 @@ def train(opt, show_number = 2, amp=False): log.write(predicted_result_log + '\n') print('validation time: ', time.time()-t1) t1=time.time() - # save model per 1e+4 iter. - if (i + 1) % 1e+4 == 0: + # save model per 1e+3 iter. 
+ if (i + 1) % 1e+3 == 0: torch.save( model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth') diff --git a/trainer/trainer.ipynb b/trainer/trainer.ipynb index 712bf8334..08fa6bd21 100644 --- a/trainer/trainer.ipynb +++ b/trainer/trainer.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-07-23T04:19:23.488642Z", @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-07-23T04:19:23.885144Z", @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2021-07-23T04:19:24.119144Z", @@ -56,7 +56,7 @@ " if opt.lang_char == 'None':\n", " characters = ''\n", " for data in opt['select_data'].split('-'):\n", - " csv_path = os.path.join(opt['train_data'], data, 'labels.csv')\n", + " csv_path = os.path.join(opt['train'], data, 'labels.csv')\n", " df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)\n", " all_char = ''.join(df['words'])\n", " characters += ''.join(set(all_char))\n", @@ -68,19 +68,330 @@ " return opt" ] }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "opt = get_config(\"config_files/train.yaml\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Create a device object to represent the GPU\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n" + ] + }, { "cell_type": "code", "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " \n", + " print(\"CUDA is available\")\n", + "else:\n", + " print(\"CUDA is not available\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2021-07-23T04:49:07.045060Z", "start_time": "2021-07-23T04:20:15.050992Z" } }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Filtering the images containing characters which are not in opt.character\n", + "Filtering the images whose label is longer than opt.batch_max_length\n", + "--------------------------------------------------------------------------------\n", + "dataset_root: /home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data\n", + "opt.select_data: ['/home/musa/Desktop/Work/Mina/Project3/EasyOCR', 'master/trainer/all_data/train']\n", + "opt.batch_ratio: ['10162', '2591']\n", + "--------------------------------------------------------------------------------\n", + "dataset_root: /home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data\t dataset: /home/musa/Desktop/Work/Mina/Project3/EasyOCR\n", + "/home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data/val\n", + " filename words\n", + "0 10445.jpg 1813\n", + "1 10714.jpg 1801\n", + "2 10819.jpg 22\n", + "3 11635.jpg 1806\n", + "4 11720.jpg 1808\n", + "... ... 
...\n", + "2627 11010.jpg 1819\n", + "2628 11949.jpg 1827\n", + "2629 11498.jpg 1823\n", + "2630 11557.jpg 1801\n", + "2631 10504.jpg 1801\n", + "\n", + "[2632 rows x 2 columns]\n", + "sub-directory:\t/val\t num samples: 2607\n", + "/home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data/train\n", + " filename words\n", + "0 4046.jpg 1816\n", + "1 6996.jpg 1804\n", + "2 2601.jpg 1832\n", + "3 232.jpg 1824\n", + "4 6493.jpg 1841\n", + "... ... ...\n", + "10224 4189.jpg 1837\n", + "10225 8945.jpg 1826\n", + "10226 img005-055.png 4\n", + "10227 67.jpg 1837\n", + "10228 1326.jpg 1837\n", + "\n", + "[10229 rows x 2 columns]\n", + "sub-directory:\t/train\t num samples: 10166\n", + "num total samples of /home/musa/Desktop/Work/Mina/Project3/EasyOCR: 12773 x 1.0 (total_data_usage_ratio) = 12773\n", + "num samples of /home/musa/Desktop/Work/Mina/Project3/EasyOCR per batch: 32 x 10162.0 (batch_ratio) = 325184\n", + "--------------------------------------------------------------------------------\n", + "dataset_root: /home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data\t dataset: master/trainer/all_data/train\n", + "/home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data/train\n", + " filename words\n", + "0 4046.jpg 1816\n", + "1 6996.jpg 1804\n", + "2 2601.jpg 1832\n", + "3 232.jpg 1824\n", + "4 6493.jpg 1841\n", + "... ... ...\n", + "10224 4189.jpg 1837\n", + "10225 8945.jpg 1826\n", + "10226 img005-055.png 4\n", + "10227 67.jpg 1837\n", + "10228 1326.jpg 1837\n", + "\n", + "[10229 rows x 2 columns]\n", + "sub-directory:\t/train\t num samples: 10166\n", + "num total samples of master/trainer/all_data/train: 10166 x 1.0 (total_data_usage_ratio) = 10166\n", + "num samples of master/trainer/all_data/train per batch: 32 x 2591.0 (batch_ratio) = 82912\n", + "--------------------------------------------------------------------------------\n", + "Total_batch_size: 325184+82912 = 408096\n", + "--------------------------------------------------------------------------------\n", + "dataset_root: /home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data/val\t dataset: /\n", + "/home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data/val/\n", + " filename words\n", + "0 10445.jpg 1813\n", + "1 10714.jpg 1801\n", + "2 10819.jpg 22\n", + "3 11635.jpg 1806\n", + "4 11720.jpg 1808\n", + "... ... 
...\n", + "2627 11010.jpg 1819\n", + "2628 11949.jpg 1827\n", + "2629 11498.jpg 1823\n", + "2630 11557.jpg 1801\n", + "2631 10504.jpg 1801\n", + "\n", + "[2632 rows x 2 columns]\n", + "sub-directory:\t/.\t num samples: 2607\n", + "--------------------------------------------------------------------------------\n", + "No Transformation module specified\n", + "model input parameters 64 600 20 1 256 256 12 5 None VGG BiLSTM CTC\n", + "Model:\n", + "DataParallel(\n", + " (module): Model(\n", + " (FeatureExtraction): VGG_FeatureExtractor(\n", + " (ConvNet): Sequential(\n", + " (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): ReLU(inplace=True)\n", + " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (4): ReLU(inplace=True)\n", + " (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (7): ReLU(inplace=True)\n", + " (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (9): ReLU(inplace=True)\n", + " (10): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", + " (11): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (13): ReLU(inplace=True)\n", + " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (16): ReLU(inplace=True)\n", + " (17): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n", + " (18): Conv2d(256, 256, kernel_size=(2, 2), stride=(1, 1))\n", + " (19): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (AdaptiveAvgPool): AdaptiveAvgPool2d(output_size=(None, 1))\n", + " (SequenceModeling): Sequential(\n", + " (0): BidirectionalLSTM(\n", + " (rnn): LSTM(256, 256, batch_first=True, bidirectional=True)\n", + " (linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (1): BidirectionalLSTM(\n", + " (rnn): LSTM(256, 256, batch_first=True, bidirectional=True)\n", + " (linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " )\n", + " (Prediction): Linear(in_features=256, out_features=12, bias=True)\n", + " )\n", + ")\n", + "Modules, Parameters\n", + "module.FeatureExtraction.ConvNet.0.weight 288\n", + "module.FeatureExtraction.ConvNet.0.bias 32\n", + "module.FeatureExtraction.ConvNet.3.weight 18432\n", + "module.FeatureExtraction.ConvNet.3.bias 64\n", + "module.FeatureExtraction.ConvNet.6.weight 73728\n", + "module.FeatureExtraction.ConvNet.6.bias 128\n", + "module.FeatureExtraction.ConvNet.8.weight 147456\n", + "module.FeatureExtraction.ConvNet.8.bias 128\n", + "module.FeatureExtraction.ConvNet.11.weight 294912\n", + "module.FeatureExtraction.ConvNet.12.weight 256\n", + "module.FeatureExtraction.ConvNet.12.bias 256\n", + "module.FeatureExtraction.ConvNet.14.weight 589824\n", + "module.FeatureExtraction.ConvNet.15.weight 256\n", + "module.FeatureExtraction.ConvNet.15.bias 256\n", + "module.FeatureExtraction.ConvNet.18.weight 262144\n", + "module.FeatureExtraction.ConvNet.18.bias 256\n", + "module.SequenceModeling.0.rnn.weight_ih_l0 262144\n", + "module.SequenceModeling.0.rnn.weight_hh_l0 262144\n", + 
"module.SequenceModeling.0.rnn.bias_ih_l0 1024\n", + "module.SequenceModeling.0.rnn.bias_hh_l0 1024\n", + "module.SequenceModeling.0.rnn.weight_ih_l0_reverse 262144\n", + "module.SequenceModeling.0.rnn.weight_hh_l0_reverse 262144\n", + "module.SequenceModeling.0.rnn.bias_ih_l0_reverse 1024\n", + "module.SequenceModeling.0.rnn.bias_hh_l0_reverse 1024\n", + "module.SequenceModeling.0.linear.weight 131072\n", + "module.SequenceModeling.0.linear.bias 256\n", + "module.SequenceModeling.1.rnn.weight_ih_l0 262144\n", + "module.SequenceModeling.1.rnn.weight_hh_l0 262144\n", + "module.SequenceModeling.1.rnn.bias_ih_l0 1024\n", + "module.SequenceModeling.1.rnn.bias_hh_l0 1024\n", + "module.SequenceModeling.1.rnn.weight_ih_l0_reverse 262144\n", + "module.SequenceModeling.1.rnn.weight_hh_l0_reverse 262144\n", + "module.SequenceModeling.1.rnn.bias_ih_l0_reverse 1024\n", + "module.SequenceModeling.1.rnn.bias_hh_l0_reverse 1024\n", + "module.SequenceModeling.1.linear.weight 131072\n", + "module.SequenceModeling.1.linear.bias 256\n", + "module.Prediction.weight 3072\n", + "module.Prediction.bias 12\n", + "Total Trainable Params: 3759500\n", + "Trainable params num : 3759500\n", + "Optimizer:\n", + "Adadelta (\n", + "Parameter Group 0\n", + " differentiable: False\n", + " eps: 1e-08\n", + " foreach: None\n", + " lr: 1.0\n", + " maximize: False\n", + " rho: 0.95\n", + " weight_decay: 0\n", + ")\n", + "------------ Options -------------\n", + "number: 0123456789\n", + "symbol: .\n", + "lang_char: \n", + "experiment_name: custom_model\n", + "train_data: /home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data\n", + "valid_data: /home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data/val\n", + "manualSeed: 1111\n", + "workers: 6\n", + "batch_size: 408096\n", + "num_iter: 300000\n", + "valInterval: 20000\n", + "saved_model: \n", + "FT: False\n", + "optim: False\n", + "lr: 1.0\n", + "beta1: 0.9\n", + "rho: 0.95\n", + "eps: 1e-08\n", + "grad_clip: 5\n", + "select_data: ['/home/musa/Desktop/Work/Mina/Project3/EasyOCR', 'master/trainer/all_data/train']\n", + "batch_ratio: ['10162', '2591']\n", + "total_data_usage_ratio: 1.0\n", + "batch_max_length: 5\n", + "imgH: 64\n", + "imgW: 600\n", + "rgb: False\n", + "sensitive: True\n", + "PAD: True\n", + "contrast_adjust: 0.0\n", + "data_filtering_off: False\n", + "Transformation: None\n", + "FeatureExtraction: VGG\n", + "SequenceModeling: BiLSTM\n", + "Prediction: CTC\n", + "num_fiducial: 20\n", + "input_channel: 1\n", + "output_channel: 256\n", + "hidden_size: 256\n", + "decode: greedy\n", + "new_prediction: False\n", + "freeze_FeatureFxtraction: False\n", + "freeze_SequenceModeling: False\n", + "character: 0123456789.\n", + "num_class: 12\n", + "---------------------------------------\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/musa/anaconda3/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py:120: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.\n", + " warnings.warn(\"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. 
Disabling.\")\n" + ] + }, + { + "ename": "FileNotFoundError", + "evalue": "Caught FileNotFoundError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py\", line 308, in _worker_loop\n data = fetcher.fetch(index)\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py\", line 51, in fetch\n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py\", line 51, in \n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataset.py\", line 298, in __getitem__\n return self.dataset[self.indices[idx]]\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataset.py\", line 243, in __getitem__\n return self.datasets[dataset_idx][sample_idx]\n File \"/home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/dataset.py\", line 182, in __getitem__\n img = Image.open(img_fpath).convert('L')\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/PIL/Image.py\", line 3227, in open\n fp = builtins.open(filename, \"rb\")\nFileNotFoundError: [Errno 2] No such file or directory: '/home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data/train/57.28.png'\n", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mamp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n", + "File \u001b[0;32m~/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/train.py:203\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(opt, show_number, amp)\u001b[0m\n\u001b[1;32m 201\u001b[0m scaler\u001b[38;5;241m.\u001b[39mupdate()\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 203\u001b[0m image_tensors, labels \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_dataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 204\u001b[0m image \u001b[38;5;241m=\u001b[39m image_tensors\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 205\u001b[0m text, length \u001b[38;5;241m=\u001b[39m converter\u001b[38;5;241m.\u001b[39mencode(labels, batch_max_length\u001b[38;5;241m=\u001b[39mopt\u001b[38;5;241m.\u001b[39mbatch_max_length)\n", + "File \u001b[0;32m~/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/dataset.py:101\u001b[0m, in \u001b[0;36mBatch_Balanced_Dataset.get_batch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, data_loader_iter \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataloader_iter_list):\n\u001b[1;32m 100\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 101\u001b[0m image, text \u001b[38;5;241m=\u001b[39m \u001b[43mdata_loader_iter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__next__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 102\u001b[0m 
balanced_batch_images\u001b[38;5;241m.\u001b[39mappend(image)\n\u001b[1;32m 103\u001b[0m balanced_batch_texts \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m text\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py:634\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 632\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 634\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 635\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 636\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 637\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 638\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1346\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1344\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1345\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_task_info[idx]\n\u001b[0;32m-> 1346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_process_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1372\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._process_data\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 1370\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_try_put_index()\n\u001b[1;32m 1371\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, ExceptionWrapper):\n\u001b[0;32m-> 1372\u001b[0m \u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreraise\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1373\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", + "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/_utils.py:644\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 641\u001b[0m \u001b[38;5;66;03m# If the exception takes multiple 
arguments, don't try to\u001b[39;00m\n\u001b[1;32m 642\u001b[0m \u001b[38;5;66;03m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m 643\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n\u001b[0;32m--> 644\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exception\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: Caught FileNotFoundError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py\", line 308, in _worker_loop\n data = fetcher.fetch(index)\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py\", line 51, in fetch\n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py\", line 51, in \n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataset.py\", line 298, in __getitem__\n return self.dataset[self.indices[idx]]\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/torch/utils/data/dataset.py\", line 243, in __getitem__\n return self.datasets[dataset_idx][sample_idx]\n File \"/home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/dataset.py\", line 182, in __getitem__\n img = Image.open(img_fpath).convert('L')\n File \"/home/musa/anaconda3/lib/python3.10/site-packages/PIL/Image.py\", line 3227, in open\n fp = builtins.open(filename, \"rb\")\nFileNotFoundError: [Errno 2] No such file or directory: '/home/musa/Desktop/Work/Mina/Project3/EasyOCR-master/trainer/all_data/train/57.28.png'\n" + ] + } + ], + "source": [ + "train(opt, amp=False).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "opt = get_config(\"config_files/en_filtered_config.yaml\")\n", - "train(opt, amp=False)" + "\"\"\"# Move the model to the GPU\n", + "model = train(opt, amp=False).to(device)\n", + "\n", + "# Move the data to the GPU\n", + "# Here you should replace this with your own data loading and preprocessing code\n", + "data = torch.randn(32, 3, 224, 224).to(device)\n", + "\n", + "# Train the model on the GPU\n", + "model.train()\n", + "for i in range(100):\n", + " output = model(data)\n", + " loss = output.mean()\n", + " loss.backward()\"\"\"" ] }, { @@ -88,12 +399,14 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "print(\"finish\")" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -107,9 +420,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.10.9" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/trainer/trainer.py b/trainer/trainer.py new file mode 100644 index 000000000..d76b74e83 --- /dev/null +++ b/trainer/trainer.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[2]: + + +import os +import torch.backends.cudnn as cudnn +import yaml +from train import train +from utils import AttrDict +import pandas as pd + + +# In[3]: + + +cudnn.benchmark = True +cudnn.deterministic = False + + +# In[4]: + + +def get_config(file_path): + with open(file_path, 'r', encoding="utf8") as stream: + opt = 
yaml.safe_load(stream) + opt = AttrDict(opt) + if opt.lang_char == 'None': + characters = '' + for data in opt['select_data'].split('-'): + csv_path = os.path.join(opt['train_data'], data, 'labels.csv') + df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False) + all_char = ''.join(df['words']) + characters += ''.join(set(all_char)) + characters = sorted(set(characters)) + opt.character = ''.join(characters) + else: + opt.character = opt.number + opt.symbol + opt.lang_char + os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) + return opt + + +# In[5]: + + +opt = get_config("config_files/train.yaml") + + +import torch + +# Create a device object to represent the GPU +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + + +if torch.cuda.is_available(): + + print("CUDA is available") +else: + print("CUDA is not available") + + +train(opt, amp=True) + + +print("finish")
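Once training has produced checkpoints, a quick sanity check is to load the saved state_dict and inspect it. The path below is an assumption (train.py writes iteration checkpoints to ./saved_models/<experiment_name>/, and a best_accuracy.pth is committed under trainer/saved_models/ in this diff); the key name comes from the parameter dump in the notebook output above.

``` python
import torch

# Load the checkpoint on CPU; adjust the path to wherever your run saved it.
state = torch.load('saved_models/custom_model/best_accuracy.pth', map_location='cpu')

print(len(state), 'parameter tensors')
# Keys keep the 'module.' prefix because the model is wrapped in DataParallel,
# e.g. the CTC output layer:
print(state['module.Prediction.weight'].shape)
```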