diff --git a/src/neuroconv/datainterfaces/__init__.py b/src/neuroconv/datainterfaces/__init__.py index 9028c1d42..f45a78639 100644 --- a/src/neuroconv/datainterfaces/__init__.py +++ b/src/neuroconv/datainterfaces/__init__.py @@ -94,6 +94,7 @@ # Text from .text.csv.csvtimeintervalsinterface import CsvTimeIntervalsInterface from .text.excel.exceltimeintervalsinterface import ExcelTimeIntervalsInterface +from .text.timeintervalsinterface import TimeIntervalsInterface interface_list = [ # Ecephys @@ -162,6 +163,7 @@ # Text CsvTimeIntervalsInterface, ExcelTimeIntervalsInterface, + TimeIntervalsInterface, ] interfaces_by_category = dict( diff --git a/src/neuroconv/datainterfaces/text/csv/csvtimeintervalsinterface.py b/src/neuroconv/datainterfaces/text/csv/csvtimeintervalsinterface.py index 4904e1c19..b1ce92fcd 100644 --- a/src/neuroconv/datainterfaces/text/csv/csvtimeintervalsinterface.py +++ b/src/neuroconv/datainterfaces/text/csv/csvtimeintervalsinterface.py @@ -1,10 +1,10 @@ import pandas as pd -from ..timeintervalsinterface import TimeIntervalsInterface +from ..timeintervalsinterface import BaseTimeIntervalsInterface from ....utils.types import FilePathType -class CsvTimeIntervalsInterface(TimeIntervalsInterface): +class CsvTimeIntervalsInterface(BaseTimeIntervalsInterface): """Interface for adding data from a .csv file as a TimeIntervals object.""" display_name = "CSV time interval table" diff --git a/src/neuroconv/datainterfaces/text/excel/exceltimeintervalsinterface.py b/src/neuroconv/datainterfaces/text/excel/exceltimeintervalsinterface.py index 96a375d50..5851af99e 100644 --- a/src/neuroconv/datainterfaces/text/excel/exceltimeintervalsinterface.py +++ b/src/neuroconv/datainterfaces/text/excel/exceltimeintervalsinterface.py @@ -2,11 +2,11 @@ import pandas as pd -from ..timeintervalsinterface import TimeIntervalsInterface +from ..timeintervalsinterface import BaseTimeIntervalsInterface from ....utils.types import FilePathType -class ExcelTimeIntervalsInterface(TimeIntervalsInterface): +class ExcelTimeIntervalsInterface(BaseTimeIntervalsInterface): """Interface for adding data from an Excel file to NWB as a TimeIntervals object.""" display_name = "Excel time interval table" diff --git a/src/neuroconv/datainterfaces/text/time_intervals.schema.json b/src/neuroconv/datainterfaces/text/time_intervals.schema.json new file mode 100644 index 000000000..32c3a891a --- /dev/null +++ b/src/neuroconv/datainterfaces/text/time_intervals.schema.json @@ -0,0 +1,50 @@ +{ + "type": "object", + "patternProperties": { + "^[a-zA-Z0-9_]+$": { + "type": "object", + "properties": { + "description": { + "type": "string" + }, + "columns": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + } + }, + "required": [ + "name" + ], + "additionalProperties": false + } + }, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "start_time": { + "type": "number" + }, + "stop_time": { + "type": "number" + } + }, + "required": [ + "start_time", + "stop_time" + ], + "additionalProperties": true + } + } + } + } + } +} diff --git a/src/neuroconv/datainterfaces/text/timeintervalsinterface.py b/src/neuroconv/datainterfaces/text/timeintervalsinterface.py index dd7f491f8..a68957e10 100644 --- a/src/neuroconv/datainterfaces/text/timeintervalsinterface.py +++ b/src/neuroconv/datainterfaces/text/timeintervalsinterface.py @@ -4,14 +4,15 @@ import numpy as np from pynwb import NWBFile +from pynwb.epoch import TimeIntervals from ...basedatainterface import BaseDataInterface from ...tools.text import convert_df_to_time_intervals -from ...utils.dict import load_dict_from_file +from ...utils.dict import DeepDict, load_dict_from_file from ...utils.types import FilePathType -class TimeIntervalsInterface(BaseDataInterface): +class BaseTimeIntervalsInterface(BaseDataInterface): """Abstract Interface for time intervals.""" keywords = ("table", "trials", "epochs", "time intervals") @@ -154,3 +155,58 @@ def add_to_nwbfile( @abstractmethod def _read_file(self, file_path: FilePathType, **read_kwargs): pass + + +class TimeIntervalsInterface(BaseDataInterface): + + display_name = "Time Intervals" + keywords = ("table", "trials", "epochs", "time intervals") + associated_suffixes = () + info = "Interface for time intervals data added to metadata." + + @classmethod + def add_to_nwbfile(cls, nwbfile: NWBFile, metadata: DeepDict) -> None: + + for table_name, table_metadata in metadata["TimeIntervals"].items(): + time_intervals = TimeIntervals( + name=table_name, description=table_metadata.get("description", "no description") + ) + + # add custom columns + for column_name, column_metadata in table_metadata["columns"].items(): + if column_name not in ("start_time", "stop_time"): + time_intervals.add_column( + name=column_name, description=column_metadata.get("description", "no description") + ) + # handle any custom columns that were not defined in columns object + for column_name in table_metadata["data"][0]: + if column_name not in time_intervals.colnames + ("start_time", "stop_time"): + time_intervals.add_column(name=key, description="no description") + # add data + for row in table_metadata["data"]: + time_intervals.add_interval(**row) + + # add to nwbfile + if table_name == "trials": + nwbfile.trials = time_intervals + else: + nwbfile.add_time_intervals(time_intervals) + + def get_metadata_schema(self) -> dict: + metadata_schema = super().get_metadata_schema() + + time_intervals_schema = load_dict_from_file(Path(__file__).parent / "time_intervals.schema.json") + metadata_schema["properties"]["TimeIntervals"] = time_intervals_schema + return metadata_schema + + def get_metadata(self) -> dict: + metadata = super().get_metadata() + metadata["TimeIntervals"]["trials"] = dict( + columns=dict( + start_time=dict(description="start time of the trial"), + stop_time=dict(description="stop time of the trial"), + ), + data=[], + ) + + return metadata diff --git a/tests/test_minimal/test_tools/test_importing.py b/tests/test_minimal/test_tools/test_importing.py index 9f800f0e7..1fcff6c45 100644 --- a/tests/test_minimal/test_tools/test_importing.py +++ b/tests/test_minimal/test_tools/test_importing.py @@ -27,5 +27,3 @@ def test_guide_attributes(): assert isinstance( value, tuple ), f"{name} incorrectly specified GUIDE-related attribute 'associated_suffixes' (must be tuple)." - if isinstance(value, tuple): - assert len(value) > 0, f"{name} is missing entries in GUIDE related attribute {key}." diff --git a/tests/test_text/test_time_intervals.py b/tests/test_text/test_time_intervals.py index 761879bd8..698dde9a1 100644 --- a/tests/test_text/test_time_intervals.py +++ b/tests/test_text/test_time_intervals.py @@ -8,6 +8,7 @@ from neuroconv.datainterfaces import ( CsvTimeIntervalsInterface, ExcelTimeIntervalsInterface, + TimeIntervalsInterface, ) from neuroconv.tools.nwb_helpers import make_nwbfile_from_metadata from neuroconv.tools.text import convert_df_to_time_intervals @@ -96,3 +97,40 @@ def test_csv_round_trip_rename(tmp_path): def test_get_metadata_schema(): interface = CsvTimeIntervalsInterface(trials_csv_path) interface.get_metadata_schema() + + +def test_trials(): + metadata = dict( + NWBFile=dict( + session_start_time=datetime.now().astimezone(), + ), + TimeIntervals=dict( + trials=dict( + columns=dict( + start_time=dict(description="start time of the trial"), + stop_time=dict(description="stop time of the trial"), + correct=dict(description="correct or not"), + ), + data=[ + dict(start_time=0.0, stop_time=1.0, correct=True), + dict(start_time=1.0, stop_time=2.0, correct=False), + ], + ), + new_table=dict( + columns=dict( + stim_id=dict(description="stimulus ID"), + ), + data=[ + dict(start_time=0.0, stop_time=1.0, stim_id=0), + dict(start_time=1.0, stop_time=2.0, stim_id=1), + ], + ), + ), + ) + + nwbfile = TimeIntervalsInterface().create_nwbfile(metadata) + assert nwbfile.trials.correct.description == "correct or not" + assert_array_equal(nwbfile.trials.correct[:], [True, False]) + assert_array_equal(nwbfile.trials.start_time[:], [0.0, 1.0]) + + assert_array_equal(nwbfile.intervals["new_table"].stim_id[:], [0, 1])