Calling .numpy() on wrapped tensors, e.g. GradTrackingTensor, BatchedTensor

How to reproduce

Context: discovered when benchmarking functorch transforms on detr: https://github.com/pytorch/pytorch/blob/58f78ff4e08a6d6a1fc0844dd19bb92fb139bbac/benchmarks/functional_autograd_benchmark/torchvision_models.py#L802-L803
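The linked benchmark is the original context; purely for illustration, a minimal hypothetical sketch along these lines is expected to hit the same problem (the exact error message may vary by version):

import torch
import functorch

def uses_numpy(x):
    # Inside vmap, x is a BatchedTensor; calling .numpy() on it is expected
    # to raise a RuntimeError because the wrapper cannot be handed to NumPy.
    return torch.from_numpy(x.numpy()) + 1

functorch.vmap(uses_numpy)(torch.randn(3, 4))  # fails without the patch below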
EDIT:

Monkey patching like the snippet below could fix the problem, similarly to what is already done for repr:

# Monkeypatch .numpy() to fetch the underlying tensor and call .numpy() on it
import functools
import torch
from functorch import _C  # import path assumed; the original snippet only references _C

_old_numpy = torch.Tensor.numpy

@functools.wraps(_old_numpy)
def _numpy(tensor):
    level = _C.maybe_get_level(tensor)
    if level == -1:
        # Not a functorch-wrapped tensor: fall back to the original implementation
        return _old_numpy(tensor)
    if _C.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)
    value = _C.get_unwrapped(tensor)
    dl_enabled = _C.tls_set_is_included()
    try:
        # Temporarily disable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(False)
        return value.numpy()
    finally:
        # Re-enable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(True)

setattr(torch.Tensor, 'numpy', _numpy)
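With the patch applied, .numpy() on a GradTrackingTensor should simply unwrap to the underlying tensor instead of erroring out; a small usage sketch (hypothetical, assuming the patch above has been run):

import torch
import functorch

def f(x):
    # x is a GradTrackingTensor here; the patched .numpy() unwraps it
    print(x.numpy())
    return (x * x).sum()

functorch.grad(f)(torch.randn(3))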
In the case of vmap, the ndarray obtained this way is the full batched array, not a per-example slice without the batch dimension.
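As an illustration of that caveat (a hypothetical sketch, assuming the patch above is installed):

import torch
import functorch

def g(x):
    print(x.shape)          # torch.Size([4]), the per-example view vmap presents
    print(x.numpy().shape)  # (3, 4), the underlying batched tensor with the batch dimension included
    return x.sum()

functorch.vmap(g)(torch.randn(3, 4))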