diff --git a/functorch/_src/monkey_patching.py b/functorch/_src/monkey_patching.py
index 1b507d908..95ae3a87f 100644
--- a/functorch/_src/monkey_patching.py
+++ b/functorch/_src/monkey_patching.py
@@ -98,3 +98,34 @@ def _backward(*args, **kwargs):
 
 
 setattr(torch.Tensor, 'backward', _backward)
+
+
+# Monkeypatch .numpy() to fetch the underlying tensor and call .numpy() on it
+_old_numpy = torch.Tensor.numpy
+
+
+@functools.wraps(_old_numpy)
+def _numpy(tensor):
+    level = _C.maybe_get_level(tensor)
+    if level == -1:
+        return _old_numpy(tensor)
+
+    if _C.is_functionaltensor(tensor):
+        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
+        # that it's up to date first
+        torch._sync(tensor)
+
+    value = _C.get_unwrapped(tensor)
+    dl_enabled = _C.tls_set_is_included()
+    try:
+        # Temporarily remove kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey from the included dispatch keys
+        if dl_enabled:
+            _C._set_dynamic_layer_keys_included(False)
+        return value.numpy()
+    finally:
+        # Re-add kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey to the included dispatch keys
+        if dl_enabled:
+            _C._set_dynamic_layer_keys_included(True)
+
+
+setattr(torch.Tensor, 'numpy', _numpy)
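
For context, a minimal usage sketch (illustrative only, not part of the patch; it assumes that importing functorch applies these monkey patches on import, as functorch did at the time):

# Hypothetical sketch: exercises the plain-tensor fast path of the patched .numpy().
import torch
import functorch  # noqa: F401  (importing applies the monkey patches)

x = torch.randn(3)
# A plain tensor has no functorch level (maybe_get_level returns -1), so the
# patched .numpy() falls straight through to the original implementation.
print(x.numpy())

# Inside a functorch transform the tensor is wrapped instead: the patched
# .numpy() syncs functional tensors, unwraps one layer, and converts with the
# dynamic layer dispatch keys temporarily excluded so the call to
# value.numpy() doesn't re-enter functorch's dispatching.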