From 328c7b520488ce4dda49441fde3f0aa3fc8b98b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Mon, 11 Nov 2024 12:31:48 -0600 Subject: [PATCH] one more test --- nbs/lag_transforms.ipynb | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/nbs/lag_transforms.ipynb b/nbs/lag_transforms.ipynb index 3ff6ac59..2eb96f58 100644 --- a/nbs/lag_transforms.ipynb +++ b/nbs/lag_transforms.ipynb @@ -541,7 +541,9 @@ "outputs": [], "source": [ "#| hide\n", - "import operator" + "import operator\n", + "\n", + "from mlforecast.grouped_array import GroupedArray as MLGroupedArray" ] }, { @@ -593,8 +595,14 @@ " tfm._set_core_tfm(1)\n", " tfm._get_name(1)\n", " tfm.transform(ga)\n", - " tfm.update(ga)\n", - " tfm.update_samples" + " updates = tfm.update(ga)\n", + " upd_samples = tfm.update_samples\n", + " if upd_samples > -1:\n", + " sliced_ga = MLGroupedArray(ga.data, ga.indptr).take_from_groups(slice(-upd_samples, None))\n", + " ga2 = CoreGroupedArray(sliced_ga.data, sliced_ga.indptr)\n", + " tfm.transform(ga) # to reset state\n", + " updates2 = tfm.update(ga2)\n", + " np.testing.assert_allclose(updates, updates2)" ] } ],