Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add trials metadata #999

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/neuroconv/datainterfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
# Text
from .text.csv.csvtimeintervalsinterface import CsvTimeIntervalsInterface
from .text.excel.exceltimeintervalsinterface import ExcelTimeIntervalsInterface
from .text.timeintervalsinterface import TimeIntervalsInterface

interface_list = [
# Ecephys
Expand Down Expand Up @@ -162,6 +163,7 @@
# Text
CsvTimeIntervalsInterface,
ExcelTimeIntervalsInterface,
TimeIntervalsInterface,
]

interfaces_by_category = dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pandas as pd

from ..timeintervalsinterface import TimeIntervalsInterface
from ..timeintervalsinterface import BaseTimeIntervalsInterface
from ....utils.types import FilePathType


class CsvTimeIntervalsInterface(TimeIntervalsInterface):
class CsvTimeIntervalsInterface(BaseTimeIntervalsInterface):
"""Interface for adding data from a .csv file as a TimeIntervals object."""

display_name = "CSV time interval table"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import pandas as pd

from ..timeintervalsinterface import TimeIntervalsInterface
from ..timeintervalsinterface import BaseTimeIntervalsInterface
from ....utils.types import FilePathType


class ExcelTimeIntervalsInterface(TimeIntervalsInterface):
class ExcelTimeIntervalsInterface(BaseTimeIntervalsInterface):
"""Interface for adding data from an Excel file to NWB as a TimeIntervals object."""

display_name = "Excel time interval table"
Expand Down
50 changes: 50 additions & 0 deletions src/neuroconv/datainterfaces/text/time_intervals.schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"type": "object",
"patternProperties": {
"^[a-zA-Z0-9_]+$": {
"type": "object",
"properties": {
"description": {
"type": "string"
},
"columns": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"description": {
"type": "string"
}
},
"required": [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't description be required as well?]

from pynwb.testing.mock.file import mock_NWBFile

nwbfile = mock_NWBFile()
nwbfile.add_trial_column(name='correct')

Throws an error.

"name"
],
"additionalProperties": false
}
},
"data": {
"type": "array",
"items": {
"type": "object",
"properties": {
"start_time": {
"type": "number"
},
"stop_time": {
"type": "number"
}
},
"required": [
"start_time",
"stop_time"
],
"additionalProperties": true
}
}
}
}
}
}
60 changes: 58 additions & 2 deletions src/neuroconv/datainterfaces/text/timeintervalsinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import numpy as np
from pynwb import NWBFile
from pynwb.epoch import TimeIntervals

from ...basedatainterface import BaseDataInterface
from ...tools.text import convert_df_to_time_intervals
from ...utils.dict import load_dict_from_file
from ...utils.dict import DeepDict, load_dict_from_file
from ...utils.types import FilePathType


class TimeIntervalsInterface(BaseDataInterface):
class BaseTimeIntervalsInterface(BaseDataInterface):
"""Abstract Interface for time intervals."""

keywords = ("table", "trials", "epochs", "time intervals")
Expand Down Expand Up @@ -154,3 +155,58 @@ def add_to_nwbfile(
@abstractmethod
def _read_file(self, file_path: FilePathType, **read_kwargs):
pass


class TimeIntervalsInterface(BaseDataInterface):

display_name = "Time Intervals"
keywords = ("table", "trials", "epochs", "time intervals")
associated_suffixes = ()
info = "Interface for time intervals data added to metadata."

@classmethod
def add_to_nwbfile(cls, nwbfile: NWBFile, metadata: DeepDict) -> None:

for table_name, table_metadata in metadata["TimeIntervals"].items():
time_intervals = TimeIntervals(
name=table_name, description=table_metadata.get("description", "no description")
)

# add custom columns
for column_name, column_metadata in table_metadata["columns"].items():
if column_name not in ("start_time", "stop_time"):
time_intervals.add_column(
name=column_name, description=column_metadata.get("description", "no description")
)
# handle any custom columns that were not defined in columns object
for column_name in table_metadata["data"][0]:
if column_name not in time_intervals.colnames + ("start_time", "stop_time"):
time_intervals.add_column(name=key, description="no description")
# add data
for row in table_metadata["data"]:
time_intervals.add_interval(**row)

# add to nwbfile
if table_name == "trials":
nwbfile.trials = time_intervals
else:
nwbfile.add_time_intervals(time_intervals)

def get_metadata_schema(self) -> dict:
metadata_schema = super().get_metadata_schema()

time_intervals_schema = load_dict_from_file(Path(__file__).parent / "time_intervals.schema.json")
metadata_schema["properties"]["TimeIntervals"] = time_intervals_schema
return metadata_schema

def get_metadata(self) -> dict:
metadata = super().get_metadata()
metadata["TimeIntervals"]["trials"] = dict(
columns=dict(
start_time=dict(description="start time of the trial"),
stop_time=dict(description="stop time of the trial"),
),
data=[],
)

return metadata
2 changes: 0 additions & 2 deletions tests/test_minimal/test_tools/test_importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,3 @@ def test_guide_attributes():
assert isinstance(
value, tuple
), f"{name} incorrectly specified GUIDE-related attribute 'associated_suffixes' (must be tuple)."
if isinstance(value, tuple):
assert len(value) > 0, f"{name} is missing entries in GUIDE related attribute {key}."
38 changes: 38 additions & 0 deletions tests/test_text/test_time_intervals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from neuroconv.datainterfaces import (
CsvTimeIntervalsInterface,
ExcelTimeIntervalsInterface,
TimeIntervalsInterface,
)
from neuroconv.tools.nwb_helpers import make_nwbfile_from_metadata
from neuroconv.tools.text import convert_df_to_time_intervals
Expand Down Expand Up @@ -96,3 +97,40 @@ def test_csv_round_trip_rename(tmp_path):
def test_get_metadata_schema():
interface = CsvTimeIntervalsInterface(trials_csv_path)
interface.get_metadata_schema()


def test_trials():
metadata = dict(
NWBFile=dict(
session_start_time=datetime.now().astimezone(),
),
TimeIntervals=dict(
trials=dict(
columns=dict(
start_time=dict(description="start time of the trial"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this overwritten by the schema by default? the description of start_time and stop_time?

stop_time=dict(description="stop time of the trial"),
correct=dict(description="correct or not"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting that you had to pass the name inside the dictionary as well from reading the schema, but here, the name here is only passed as the dictionary key.

),
data=[
dict(start_time=0.0, stop_time=1.0, correct=True),
dict(start_time=1.0, stop_time=2.0, correct=False),
],
),
new_table=dict(
columns=dict(
stim_id=dict(description="stimulus ID"),
),
data=[
dict(start_time=0.0, stop_time=1.0, stim_id=0),
dict(start_time=1.0, stop_time=2.0, stim_id=1),
],
),
),
)

nwbfile = TimeIntervalsInterface().create_nwbfile(metadata)
assert nwbfile.trials.correct.description == "correct or not"
assert_array_equal(nwbfile.trials.correct[:], [True, False])
assert_array_equal(nwbfile.trials.start_time[:], [0.0, 1.0])

assert_array_equal(nwbfile.intervals["new_table"].stim_id[:], [0, 1])
Loading