diff --git a/nbs/lag_transforms.ipynb b/nbs/lag_transforms.ipynb index 3ff6ac59..2eb96f58 100644 --- a/nbs/lag_transforms.ipynb +++ b/nbs/lag_transforms.ipynb @@ -541,7 +541,9 @@ "outputs": [], "source": [ "#| hide\n", - "import operator" + "import operator\n", + "\n", + "from mlforecast.grouped_array import GroupedArray as MLGroupedArray" ] }, { @@ -593,8 +595,14 @@ " tfm._set_core_tfm(1)\n", " tfm._get_name(1)\n", " tfm.transform(ga)\n", - " tfm.update(ga)\n", - " tfm.update_samples" + " updates = tfm.update(ga)\n", + " upd_samples = tfm.update_samples\n", + " if upd_samples > -1:\n", + " sliced_ga = MLGroupedArray(ga.data, ga.indptr).take_from_groups(slice(-upd_samples, None))\n", + " ga2 = CoreGroupedArray(sliced_ga.data, sliced_ga.indptr)\n", + " tfm.transform(ga) # to reset state\n", + " updates2 = tfm.update(ga2)\n", + " np.testing.assert_allclose(updates, updates2)" ] } ],