Skip to content

Commit

Permalink
fix pre-commit?
Browse files Browse the repository at this point in the history
  • Loading branch information
federicazanca committed Jul 31, 2024
1 parent 10338e9 commit 1bfa751
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions aiida_mlip/workflows/training_workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from pathlib import Path

from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
from aiida_workgraph.workgraph import WorkGraph, task
from aiida_workgraph import WorkGraph, task
from ase.io import read
from sklearn.model_selection import train_test_split

from aiida.orm import Dict, SinglefileData, load_code
from aiida.plugins import CalculationFactory, WorkflowFactory, entry_point
from aiida.orm import SinglefileData
from aiida.plugins import CalculationFactory, WorkflowFactory

from aiida_mlip.data.config import JanusConfigfile
from aiida_mlip.helpers.help_load import load_structure
Expand Down Expand Up @@ -67,12 +67,12 @@ def create_input(**inputs: dict) -> SinglefileData:
"""

input_data = []
for name, structure in inputs.items():
for _, structure in inputs.items():
ase_structure = structure.to_ase()
extxyz_str = ase_structure.write(format="extxyz")
input_data.append(extxyz_str)
temp_file_path = "tmp.extxyz"
with open(temp_file_path, "w") as temp_file:
with open(temp_file_path, "w", encoding="utf8") as temp_file:
temp_file.write("\n".join(input_data))

file_data = SinglefileData(file=temp_file_path)
Expand Down Expand Up @@ -113,11 +113,11 @@ def split_xyz_file(xyz_file: SinglefileData) -> dict:
test_path = "test.extxyz"
validation_path = "validation.extxyz"

with open(train_path, "w") as f:
with open(train_path, "w", encoding="utf8") as f:
f.write("\n".join(train_data))
with open(test_path, "w") as f:
with open(test_path, "w", encoding="utf8") as f:
f.write("\n".join(test_data))
with open(validation_path, "w") as f:
with open(validation_path, "w", encoding="utf8") as f:
f.write("\n".join(validation_data))

return {
Expand Down Expand Up @@ -152,12 +152,13 @@ def update_janusconfigfile(janusconfigfile: JanusConfigfile) -> JanusConfigfile:

new_config_path = "./config.yml"

with open(new_config_path, "w") as file:
with open(new_config_path, "w", encoding="utf8") as file:
file.write(content)

return JanusConfigfile(file=new_config_path)


# pylint: disable=unused-variable
def TrainWorkGraph(
folder_path: Path, inputs: dict, janusconfigfile: JanusConfigfile
) -> WorkGraph:
Expand Down Expand Up @@ -202,6 +203,7 @@ def TrainWorkGraph(
training_calc = CalculationFactory("mlip.train")
train_inputs = {}
train_inputs["config_file"] = update_config_task.outputs.result

train_task = wg.add_task(
training_calc, name="training", mlip_config=update_config_task.outputs.result
)
Expand Down

0 comments on commit 1bfa751

Please sign in to comment.