Skip to content

Commit

Permalink
feat: Add load_from_parquet and save_to_parquet methods to Signals class
Browse files Browse the repository at this point in the history
  • Loading branch information
janezlapajne committed Aug 28, 2024
1 parent 36981a5 commit 47ccb2c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
9 changes: 9 additions & 0 deletions siapy/entities/signatures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
Expand All @@ -13,6 +14,11 @@
class Signals:
_data: pd.DataFrame

@classmethod
def load_from_parquet(cls, filepath: str | Path) -> "Signals":
df = pd.read_parquet(filepath)
return cls(df)

@property
def df(self) -> pd.DataFrame:
return self._data
Expand All @@ -23,6 +29,9 @@ def to_numpy(self) -> np.ndarray:
def mean(self) -> np.ndarray:
return np.nanmean(self.to_numpy(), axis=0)

def save_to_parquet(self, filepath: str | Path) -> None:
self.df.to_parquet(filepath, index=True)


@dataclass
class SignaturesFilter:
Expand Down
16 changes: 16 additions & 0 deletions tests/entities/test_entities_signatures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import pandas as pd
import pytest
Expand All @@ -22,6 +26,18 @@ def test_signals_mean():
assert signals_mean.shape == (4,)


def test_signals_save_and_load_to_parquet():
signals_df = pd.DataFrame([[1, 2, 4, 6], [3, 4, 3, 5]])
signals = Signals(signals_df)
with TemporaryDirectory() as tmpdir:
parquet_file = Path(tmpdir, "test_signals.parquet")
signals.save_to_parquet(parquet_file)
assert os.path.exists(parquet_file)
loaded_signals = Signals.load_from_parquet(parquet_file)
assert isinstance(loaded_signals, Signals)
assert loaded_signals.df.equals(signals.df)


def test_signatures_filter_create():
pixels_df = pd.DataFrame({"u": [0, 1], "v": [0, 1]})
signals_df = pd.DataFrame([[1, 2], [3, 4]])
Expand Down

0 comments on commit 47ccb2c

Please sign in to comment.