From ce30e676c05b3232c7603f88482d7d306c46c17b Mon Sep 17 00:00:00 2001 From: janezlapajne Date: Wed, 28 Aug 2024 17:34:53 +0200 Subject: [PATCH] feat: Add load_from_parquet and save_to_parquet methods to Signatures class --- siapy/entities/signatures.py | 8 ++++++++ tests/entities/test_entities_signatures.py | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/siapy/entities/signatures.py b/siapy/entities/signatures.py index 04c282e..ccbe606 100644 --- a/siapy/entities/signatures.py +++ b/siapy/entities/signatures.py @@ -93,6 +93,11 @@ def from_dataframe(cls, dataframe: pd.DataFrame) -> "Signatures": signals = Signals(dataframe.drop(columns=[Pixels.coords.U, Pixels.coords.V])) return cls._create(pixels, signals) + @classmethod + def load_from_parquet(cls, filepath: str | Path) -> "Signatures": + df = pd.read_parquet(filepath) + return cls.from_dataframe(df) + @property def pixels(self) -> Pixels: return self._pixels @@ -109,3 +114,6 @@ def to_numpy(self) -> np.ndarray: def filter(self) -> SignaturesFilter: return SignaturesFilter(self.pixels, self.signals) + + def save_to_parquet(self, filepath: str | Path) -> None: + self.to_dataframe().to_parquet(filepath, index=True) diff --git a/tests/entities/test_entities_signatures.py b/tests/entities/test_entities_signatures.py index 364e849..9a9c4e2 100644 --- a/tests/entities/test_entities_signatures.py +++ b/tests/entities/test_entities_signatures.py @@ -156,3 +156,15 @@ def test_signatures_from_dataframe(): with pytest.raises(ValueError): Signatures.from_dataframe(df_missing_v) + + +def test_signatures_save_and_load_to_parquet(): + df = pd.DataFrame({"u": [0, 1], "v": [0, 1], "0": [1, 2], "1": [3, 4]}) + signatures = Signatures.from_dataframe(df) + with TemporaryDirectory() as tmpdir: + parquet_file = Path(tmpdir, "test_signatures.parquet") + signatures.save_to_parquet(parquet_file) + assert os.path.exists(parquet_file) + loaded_signatures = Signatures.load_from_parquet(parquet_file) + assert isinstance(loaded_signatures, Signatures) + assert loaded_signatures.to_dataframe().equals(signatures.to_dataframe())