From 47ccb2cc6ba582c6affeb4f5c06582a5666eca81 Mon Sep 17 00:00:00 2001 From: janezlapajne Date: Wed, 28 Aug 2024 17:23:50 +0200 Subject: [PATCH] feat: Add load_from_parquet and save_to_parquet methods to Signals class --- siapy/entities/signatures.py | 9 +++++++++ tests/entities/test_entities_signatures.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/siapy/entities/signatures.py b/siapy/entities/signatures.py index 7cfb88f..04c282e 100644 --- a/siapy/entities/signatures.py +++ b/siapy/entities/signatures.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from pathlib import Path from typing import Any import numpy as np @@ -13,6 +14,11 @@ class Signals: _data: pd.DataFrame + @classmethod + def load_from_parquet(cls, filepath: str | Path) -> "Signals": + df = pd.read_parquet(filepath) + return cls(df) + @property def df(self) -> pd.DataFrame: return self._data @@ -23,6 +29,9 @@ def to_numpy(self) -> np.ndarray: def mean(self) -> np.ndarray: return np.nanmean(self.to_numpy(), axis=0) + def save_to_parquet(self, filepath: str | Path) -> None: + self.df.to_parquet(filepath, index=True) + @dataclass class SignaturesFilter: diff --git a/tests/entities/test_entities_signatures.py b/tests/entities/test_entities_signatures.py index 71aa5ee..364e849 100644 --- a/tests/entities/test_entities_signatures.py +++ b/tests/entities/test_entities_signatures.py @@ -1,3 +1,7 @@ +import os +from pathlib import Path +from tempfile import TemporaryDirectory + import numpy as np import pandas as pd import pytest @@ -22,6 +26,18 @@ def test_signals_mean(): assert signals_mean.shape == (4,) +def test_signals_save_and_load_to_parquet(): + signals_df = pd.DataFrame([[1, 2, 4, 6], [3, 4, 3, 5]]) + signals = Signals(signals_df) + with TemporaryDirectory() as tmpdir: + parquet_file = Path(tmpdir, "test_signals.parquet") + signals.save_to_parquet(parquet_file) + assert os.path.exists(parquet_file) + loaded_signals = Signals.load_from_parquet(parquet_file) + assert isinstance(loaded_signals, Signals) + assert loaded_signals.df.equals(signals.df) + + def test_signatures_filter_create(): pixels_df = pd.DataFrame({"u": [0, 1], "v": [0, 1]}) signals_df = pd.DataFrame([[1, 2], [3, 4]])