Skip to content

Commit

Permalink
Forgot to remove use of is_empty
Browse files Browse the repository at this point in the history
  • Loading branch information
dantp-ai authored and MischaPanch committed Jul 18, 2024
1 parent d1e2b33 commit 330f5cc
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def alloc_by_keys_diff(
if key in meta.get_keys():
if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
alloc_by_keys_diff(meta[key], batch[key], size, stack)
elif isinstance(meta[key], Batch) and meta[key].is_empty():
elif isinstance(meta[key], Batch) and len(meta[key].get_keys()) == 0:
meta[key] = create_value(batch[key], size, stack)
else:
meta[key] = create_value(batch[key], size, stack)
Expand Down Expand Up @@ -768,7 +768,6 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
if len(batch) > 0:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if len(batch.get_keys()) != 0:
batch_list.append(batch)
else:
Expand All @@ -777,7 +776,7 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
return
batches = batch_list
try:
# x.is_empty(recurse=True) here means x is a nested empty batch
# len(batch) here means batch is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
# keep it.
lens = [0 if len(batch) == 0 else len(batch) for batch in batches]
Expand Down Expand Up @@ -806,7 +805,6 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
if len(batch) > 0:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if len(batch.get_keys()) != 0:
batch_list.append(batch)
else:
Expand All @@ -821,7 +819,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
{
batch_key
for batch_key, obj in batch.items()
if not (isinstance(obj, BatchProtocol) and obj.is_empty())
if not (isinstance(obj, BatchProtocol) and len(obj.get_keys()) == 0)
}
for batch in batches
]
Expand Down Expand Up @@ -867,7 +865,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
# TODO: fix code/annotations s.t. the ignores can be removed
if (
isinstance(value, BatchProtocol) # type: ignore
and value.is_empty() # type: ignore
and len(value.get_keys()) == 0 # type: ignore
):
continue # type: ignore
try:
Expand Down

0 comments on commit 330f5cc

Please sign in to comment.