Refactoring/remove is empty batch #1144

Merged
4 commits merged on Jul 18, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -87,6 +87,7 @@ instead of just `nn.Module`. #1032
Can be considered a bugfix. #1063
- The methods `to_numpy` and `to_torch` are no longer in-place
(use `to_numpy_` or `to_torch_` instead). #1098, #1117
- The method `Batch.is_empty` has been removed. Instead, check a `Batch` for emptiness with `len(batch.get_keys()) == 0` (no keys at all) or `len(batch) == 0` (no scalar/tensor leaf values). #1144
- Logging:
- `BaseLogger.prepare_dict_for_logging` is now abstract. #1074
- Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074
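A minimal migration sketch for the changelog entry above, assuming an existing `Batch` instance `batch` (the same replacement is applied mechanically throughout this PR):

# before (removed in #1144)
#   batch.is_empty()               # True iff the Batch has no keys at all
#   batch.is_empty(recurse=True)   # True iff it has no scalar/tensor leaf values
# after
len(batch.get_keys()) == 0         # direct emptiness: no keys at all
len(batch) == 0                    # recursive emptiness: no scalar/tensor leaf values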
18 changes: 9 additions & 9 deletions docs/01_tutorials/03_batch.rst
@@ -324,35 +324,35 @@ Still, we can use a tree (in the right) to show the structure of ``Batch`` objec

Reserved keys mean that values will eventually be attached to them. The values can be scalars, tensors, or even **Batch** objects. Understanding this is critical to understanding the behavior of ``Batch`` when dealing with heterogeneous Batches.

The introduction of reserved keys gives rise to the need to check if a key is reserved. Tianshou provides ``Batch.is_empty`` to achieve this.
The introduction of reserved keys gives rise to the need to check if a key is reserved.

.. raw:: html

<details>
<summary>Examples of Batch.is_empty</summary>
<summary>Examples of checking whether Batch is empty</summary>

.. code-block:: python

>>> Batch().is_empty()
>>> len(Batch().get_keys()) == 0
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
>>> len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) == 0
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
>>> len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0
True
>>> Batch(d=1).is_empty()
>>> len(Batch(d=1).get_keys()) == 0
False
>>> Batch(a=np.float64(1.0)).is_empty()
>>> len(Batch(a=np.float64(1.0)).get_keys()) == 0
False

.. raw:: html

</details><br>

The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
To check whether a Batch is empty, use ``len(batch.get_keys()) == 0`` to test for direct emptiness (just a ``Batch()`` with no keys) or ``len(batch) == 0`` to test for recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
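For instance, a short sketch in the same doctest style (assuming ``numpy as np`` and ``Batch`` are imported as in the examples above; the ``obs``/``info`` key names are purely illustrative):

>>> b = Batch(obs=np.zeros((3, 4)), info=Batch())
>>> len(b.info.get_keys()) == 0   # `info` is reserved: no keys attached yet
True
>>> len(b.info) == 0              # and it holds no scalar/tensor leaves either
True
>>> len(b.get_keys()) == 0        # the top-level Batch does have keys
False
>>> len(b)                        # length is determined by the `obs` leaf alone
3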

.. note::

Do not get confused with ``Batch.is_empty`` and ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
Do not confuse these emptiness checks with ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
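As a hedged sketch of that difference (assuming the reset semantics described in the API documentation, i.e. numeric leaves are zeroed in place):

>>> b = Batch(a=np.array([1.0, 2.0]))
>>> _ = b.empty_()                # resets the leaf values in place
>>> np.allclose(b.a, [0.0, 0.0])
True
>>> len(b.get_keys()) == 0        # the key `a` itself is still present
False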


Length and Shape
36 changes: 18 additions & 18 deletions test/base/test_batch.py
@@ -15,23 +15,23 @@

def test_batch() -> None:
assert list(Batch()) == []
assert Batch().is_empty()
assert not Batch(b={"c": {}}).is_empty()
assert Batch(b={"c": {}}).is_empty(recurse=True)
assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
assert not Batch(d=1).is_empty()
assert not Batch(a=np.float64(1.0)).is_empty()
assert len(Batch().get_keys()) == 0
assert len(Batch(b={"c": {}}).get_keys()) != 0
assert len(Batch(b={"c": {}})) == 0
assert len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) != 0
assert len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0
assert len(Batch(d=1).get_keys()) != 0
assert len(Batch(a=np.float64(1.0)).get_keys()) != 0
assert len(Batch(a=[1, 2, 3], b={"c": {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
assert len(Batch(a=[1, 2, 3]).get_keys()) != 0
b = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None])
assert b.c.dtype == object
b = Batch(d=[None], e=[starmap], f=Batch)
assert b.d.dtype == b.e.dtype == object
assert b.f == Batch
b = Batch()
b.update()
assert b.is_empty()
assert len(b.get_keys()) == 0
b.update(c=[3, 5])
assert np.allclose(b.c, [3, 5])
# mimic the behavior of dict.update, where kwargs can overwrite keys
@@ -141,7 +141,7 @@ def test_batch() -> None:
assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
assert batch2_sum.a.d.f.is_empty()
assert len(batch2_sum.a.d.f.get_keys()) == 0
with pytest.raises(TypeError):
batch2 += [1] # type: ignore # error is raised explicitly
batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))})
@@ -255,7 +255,7 @@ def test_batch_cat_and_stack() -> None:
ans = Batch.cat([a, b, a])
assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
assert ans.a.t.is_empty()
assert len(ans.a.t.get_keys()) == 0

b1.stack_([b2])
assert isinstance(b1.a.d.e, np.ndarray)
@@ -296,7 +296,7 @@ def test_batch_cat_and_stack() -> None:
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
)
assert ans.a.is_empty()
assert len(ans.a.get_keys()) == 0
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)

@@ -325,7 +325,7 @@ def test_batch_cat_and_stack() -> None:
assert np.allclose(d.d, [0, 6, 9])

# test stack with empty Batch()
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
assert len(Batch.stack([Batch(), Batch(), Batch()]).get_keys()) == 0
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
@@ -334,12 +334,12 @@ def test_batch_cat_and_stack() -> None:
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
assert len(d.e.get_keys()) == 0
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
assert len(test.a.get_keys()) == 0
assert len(test.b.get_keys()) == 0
assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1))

b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))

# exceptions
batch_cat: Batch = Batch.cat([])
assert batch_cat.is_empty()
assert len(batch_cat.get_keys()) == 0
batch_stack: Batch = Batch.stack([])
assert batch_stack.is_empty()
assert len(batch_stack.get_keys()) == 0
b1 = Batch(e=[4, 5], d=6)
b2 = Batch(e=[4, 6])
with pytest.raises(ValueError):
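The asserts above all hinge on the distinction between "has keys" and "has leaf values"; a compact restatement of the case from the top of test_batch (assuming the imports of the test module):

b = Batch(a=[1, 2, 3], b={"c": {}})
assert len(b) == 3                  # length comes from the leaf array `a`
assert len(b.b.get_keys()) != 0     # `b.b` carries the reserved key `c` ...
assert len(b.b) == 0                # ... but no scalar/tensor leaves, so length 0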
4 changes: 2 additions & 2 deletions test/base/test_buffer.py
@@ -1378,5 +1378,5 @@ def test_custom_key() -> None:
sampled_batch.__dict__[key],
Batch,
):
assert batch.__dict__[key].is_empty()
assert sampled_batch.__dict__[key].is_empty()
assert len(batch.__dict__[key].get_keys()) == 0
assert len(sampled_batch.__dict__[key].get_keys()) == 0
78 changes: 19 additions & 59 deletions tianshou/data/batch.py
@@ -190,7 +190,7 @@ def alloc_by_keys_diff(
if key in meta.get_keys():
if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
alloc_by_keys_diff(meta[key], batch[key], size, stack)
elif isinstance(meta[key], Batch) and meta[key].is_empty():
elif isinstance(meta[key], Batch) and len(meta[key].get_keys()) == 0:
meta[key] = create_value(batch[key], size, stack)
else:
meta[key] = create_value(batch[key], size, stack)
@@ -393,9 +393,6 @@ def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
def __len__(self) -> int:
...

def is_empty(self, recurse: bool = False) -> bool:
...

def split(
self,
size: int,
@@ -514,7 +511,7 @@ def __getitem__(self, index: str | IndexType) -> Any:
if len(batch_items) > 0:
new_batch = Batch()
for batch_key, obj in batch_items:
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
new_batch.__dict__[batch_key] = Batch()
else:
new_batch.__dict__[batch_key] = obj[index]
@@ -574,13 +571,13 @@ def __iadd__(self, other: Self | Number | np.number) -> Self:
other.__dict__.values(),
strict=True,
): # TODO are keys consistent?
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] += value
return self
if _is_number(other):
for batch_key, obj in self.items():
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] += other
return self
@@ -594,7 +591,7 @@ def __imul__(self, value: Number | np.number) -> Self:
"""Algebraic multiplication with a scalar value in-place."""
assert _is_number(value), "Only multiplication by a number is supported."
for batch_key, obj in self.__dict__.items():
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] *= value
return self
@@ -607,7 +604,7 @@ def __itruediv__(self, value: Number | np.number) -> Self:
"""Algebraic division with a scalar value in-place."""
assert _is_number(value), "Only division by a number is supported."
for batch_key, obj in self.__dict__.items():
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] /= value
return self
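In practice these skips mean that reserved (empty) sub-batches pass through in-place arithmetic untouched; a small sketch (the key names are illustrative, numpy assumed imported as np):

b = Batch(a=np.ones(3), info=Batch())   # `info` is a reserved key
b += 1                                  # `info` is skipped, only `a` is updated
assert np.allclose(b.a, [2.0, 2.0, 2.0])
assert len(b.info.get_keys()) == 0      # still an empty sub-batch afterwards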
@@ -722,7 +719,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
{
batch_key
for batch_key, obj in batch.items()
if not (isinstance(obj, Batch) and obj.is_empty())
if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
}
for batch in batches
]
@@ -753,7 +750,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
if key not in batch.__dict__:
continue
value = batch.get(key)
if isinstance(value, Batch) and value.is_empty():
if isinstance(value, Batch) and len(value.get_keys()) == 0:
continue
try:
self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
@@ -771,28 +768,27 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
if len(batch) > 0:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if not batch.is_empty():
if len(batch.get_keys()) != 0:
batch_list.append(batch)
else:
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
if len(batch_list) == 0:
return
batches = batch_list
try:
# x.is_empty(recurse=True) here means x is a nested empty batch
# len(batch) == 0 here means batch is a nested empty batch,
# e.g. Batch(a=Batch()), and we have to treat it as length zero and
# keep it.
lens = [0 if batch.is_empty(recurse=True) else len(batch) for batch in batches]
lens = [0 if len(batch) == 0 else len(batch) for batch in batches]
except TypeError as exception:
raise ValueError(
"Batch.cat_ meets an exception. Maybe because there is any "
f"scalar in {batches} but Batch.cat_ does not support the "
"concatenation of scalar.",
) from exception
if not self.is_empty():
if len(self.get_keys()) != 0:
batches = [self, *list(batches)]
lens = [0 if self.is_empty(recurse=True) else len(self), *lens]
lens = [0 if len(self) == 0 else len(self), *lens]
self.__cat(batches, lens)

@staticmethod
@@ -809,22 +805,21 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
if len(batch) > 0:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if not batch.is_empty():
if len(batch.get_keys()) != 0:
batch_list.append(batch)
else:
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
if len(batch_list) == 0:
return
batches = batch_list
if not self.is_empty():
if len(self.get_keys()) != 0:
batches = [self, *batches]
# collect non-empty keys
keys_map = [
{
batch_key
for batch_key, obj in batch.items()
if not (isinstance(obj, BatchProtocol) and obj.is_empty())
if not (isinstance(obj, BatchProtocol) and len(obj.get_keys()) == 0)
}
for batch in batches
]
@@ -870,7 +865,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
# TODO: fix code/annotations s.t. the ignores can be removed
if (
isinstance(value, BatchProtocol) # type: ignore
and value.is_empty() # type: ignore
and len(value.get_keys()) == 0 # type: ignore
):
continue # type: ignore
try:
@@ -930,7 +925,7 @@ def __len__(self) -> int:
# TODO: causes inconsistent behavior between a batch containing empty batches
# and a batch containing empty sequences of another type. Remove, but only after
# Buffer and Collectors have been improved to no longer rely on this
if isinstance(obj, Batch) and obj.is_empty(recurse=True):
if isinstance(obj, Batch) and len(obj) == 0:
continue
if hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
lens.append(len(obj))
@@ -940,45 +935,10 @@ def __len__(self) -> int:
return 0
return min(lens)

def is_empty(self, recurse: bool = False) -> bool:
"""Test if a Batch is empty.

If ``recurse=True``, it further tests the values of the object; else
it only tests the existence of any key.

``b.is_empty(recurse=True)`` is mainly used to distinguish
``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise
exceptions when applied to ``len()``, but the former can be used in
``cat``, while the latter is a scalar and cannot be used in ``cat``.

Another usage is in ``__len__``, where we have to skip checking the
length of recursively empty Batch.
::

>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
"""
if len(self.__dict__) == 0:
return True
if not recurse:
return False
return all(
False if not isinstance(obj, Batch) else obj.is_empty(recurse=True)
for obj in self.values()
)

@property
def shape(self) -> list[int]:
"""Return self.shape."""
if self.is_empty():
if len(self.get_keys()) == 0:
return []
data_shape = []
for obj in self.__dict__.values():
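Taken together, the changes in this file keep the previous semantics for reserved keys: cat/stack still carry them through as empty sub-batches, only the emptiness test is now spelled with get_keys/len. A sketch of the resulting behaviour (shapes chosen for illustration):

b1 = Batch(a=Batch(), b=np.zeros((2, 3)))
b2 = Batch(a=Batch(), b=np.ones((1, 3)))
out = Batch.cat([b1, b2])
assert len(out.a.get_keys()) == 0   # the reserved key `a` survives as Batch()
assert out.b.shape == (3, 3)        # leaf arrays are concatenated along axis 0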
6 changes: 3 additions & 3 deletions tianshou/data/buffer/base.py
@@ -208,7 +208,7 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
self._index = (self._index + 1) % self.maxsize
self._size = min(self._size + 1, self.maxsize)
to_indices = np.array(to_indices)
if self._meta.is_empty():
if len(self._meta.get_keys()) == 0:
self._meta = create_value(buffer._meta, self.maxsize, stack=False) # type: ignore
self._meta[to_indices] = buffer._meta[from_indices]
return to_indices
@@ -284,7 +284,7 @@ def add(
batch.done = batch.done.astype(bool)
batch.terminated = batch.terminated.astype(bool)
batch.truncated = batch.truncated.astype(bool)
if self._meta.is_empty():
if len(self._meta.get_keys()) == 0:
self._meta = create_value(batch, self.maxsize, stack) # type: ignore
else: # dynamic key pops up in batch
alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
@@ -377,7 +377,7 @@ def get(
return np.stack(stack, axis=indices.ndim)

except IndexError as exception:
if not (isinstance(val, Batch) and val.is_empty()):
if not (isinstance(val, Batch) and len(val.get_keys()) == 0):
raise exception # val != Batch()
return Batch()

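The same len(get_keys()) == 0 test is what lets the buffer allocate its storage lazily: _meta starts out as an empty Batch and is only sized when the first sample arrives. A simplified sketch of that pattern (the `storage` name and the 8-slot capacity are hypothetical, not the actual ReplayBuffer API):

import numpy as np
from tianshou.data import Batch

storage = Batch()                        # like ReplayBuffer._meta right after __init__
first = Batch(obs=np.zeros(4), rew=np.array(0.0))

if len(storage.get_keys()) == 0:         # nothing allocated yet
    # allocate one slot per buffer position, shaped like the incoming sample
    storage = Batch(obs=np.zeros((8, 4)), rew=np.zeros(8))
storage[0] = first                       # subsequent adds write into the preallocated arrays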