From 1f2389e2197db6400ec9f588044d79f499c6b5f8 Mon Sep 17 00:00:00 2001 From: federica Date: Wed, 31 Jul 2024 15:58:13 +0100 Subject: [PATCH] fix pre-commit? --- aiida_mlip/workflows/training_workgraph.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/aiida_mlip/workflows/training_workgraph.py b/aiida_mlip/workflows/training_workgraph.py index 19b7dc0..3eef6ac 100644 --- a/aiida_mlip/workflows/training_workgraph.py +++ b/aiida_mlip/workflows/training_workgraph.py @@ -3,12 +3,12 @@ from pathlib import Path from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain -from aiida_workgraph.workgraph import WorkGraph, task +from aiida_workgraph import WorkGraph, task from ase.io import read from sklearn.model_selection import train_test_split -from aiida.orm import Dict, SinglefileData, load_code -from aiida.plugins import CalculationFactory, WorkflowFactory, entry_point +from aiida.orm import SinglefileData +from aiida.plugins import CalculationFactory, WorkflowFactory from aiida_mlip.data.config import JanusConfigfile from aiida_mlip.helpers.help_load import load_structure @@ -67,12 +67,12 @@ def create_input(**inputs: dict) -> SinglefileData: """ input_data = [] - for name, structure in inputs.items(): + for _, structure in inputs.items(): ase_structure = structure.to_ase() extxyz_str = ase_structure.write(format="extxyz") input_data.append(extxyz_str) temp_file_path = "tmp.extxyz" - with open(temp_file_path, "w") as temp_file: + with open(temp_file_path, "w", encoding="utf8") as temp_file: temp_file.write("\n".join(input_data)) file_data = SinglefileData(file=temp_file_path) @@ -113,11 +113,11 @@ def split_xyz_file(xyz_file: SinglefileData) -> dict: test_path = "test.extxyz" validation_path = "validation.extxyz" - with open(train_path, "w") as f: + with open(train_path, "w", encoding="utf8") as f: f.write("\n".join(train_data)) - with open(test_path, "w") as f: + with open(test_path, "w", encoding="utf8") as f: f.write("\n".join(test_data)) - with open(validation_path, "w") as f: + with open(validation_path, "w", encoding="utf8") as f: f.write("\n".join(validation_data)) return { @@ -152,12 +152,13 @@ def update_janusconfigfile(janusconfigfile: JanusConfigfile) -> JanusConfigfile: new_config_path = "./config.yml" - with open(new_config_path, "w") as file: + with open(new_config_path, "w", encoding="utf8") as file: file.write(content) return JanusConfigfile(file=new_config_path) +# pylint: disable=unused-variable def TrainWorkGraph( folder_path: Path, inputs: dict, janusconfigfile: JanusConfigfile ) -> WorkGraph: @@ -202,6 +203,7 @@ def TrainWorkGraph( training_calc = CalculationFactory("mlip.train") train_inputs = {} train_inputs["config_file"] = update_config_task.outputs.result + train_task = wg.add_task( training_calc, name="training", mlip_config=update_config_task.outputs.result )