Fixing tensor.numpy on wrapped tensors
Fixes pytorch#626

Description:
- Fixing tensor.numpy on wrapped tensors
vfdev-5 committed Mar 29, 2022
1 parent 9d6ee76 commit a878536
Showing 1 changed file with 31 additions and 0 deletions.
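For context, a hedged sketch of the kind of call this commit is meant to unblock. It assumes, per the commit title and pytorch#626, that calling .numpy() on a functorch-wrapped tensor (for example inside vmap) previously failed; the exact reproducer in the issue may differ.

    import torch
    from functorch import vmap

    def f(t):
        # Inside vmap, `t` is a wrapper (a BatchedTensor) around a plain tensor.
        # Before this patch, `t.numpy()` failed on the wrapper; with the patch,
        # the underlying tensor is unwrapped and converted instead (so the
        # array seen here includes the batch dimension).
        arr = t.numpy()
        return t * 2

    out = vmap(f)(torch.randn(3))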
functorch/_src/monkey_patching.py (31 additions, 0 deletions)
@@ -98,3 +98,34 @@ def _backward(*args, **kwargs):


setattr(torch.Tensor, 'backward', _backward)


# Monkeypatch Tensor.numpy() to fetch the underlying tensor and call .numpy() on it
_old_numpy = torch.Tensor.numpy


@functools.wraps(_old_numpy)
def _numpy(tensor):
    level = _C.maybe_get_level(tensor)
    if level == -1:
        # Not a functorch-wrapped tensor; use the original implementation
        return _old_numpy(tensor)

    if _C.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make
        # sure that it's up to date first
        torch._sync(tensor)

    value = _C.get_unwrapped(tensor)
    dl_enabled = _C.tls_set_is_included()
    try:
        # Temporarily exclude kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey
        # from the included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(False)
        return value.numpy()
    finally:
        # Re-include kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(True)


setattr(torch.Tensor, 'numpy', _numpy)
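The try/finally around value.numpy() restores the thread-local dispatch state even if the conversion raises. A minimal sketch of the same guard factored into a context manager; the _C calls are the ones used in the diff, while the import path and wrapper name are illustrative assumptions:

    import contextlib
    import functorch._C as _C  # assumed import path for the private extension

    @contextlib.contextmanager
    def _dynamic_layer_keys_excluded():
        # Drop kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey from the
        # included dispatch keys for the duration of the block, then restore.
        enabled = _C.tls_set_is_included()
        if enabled:
            _C._set_dynamic_layer_keys_included(False)
        try:
            yield
        finally:
            if enabled:
                _C._set_dynamic_layer_keys_included(True)

    # The patched body above is then equivalent to:
    #     with _dynamic_layer_keys_excluded():
    #         return value.numpy()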
