diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 6b0dc67ec51..78feaf11eea 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -702,13 +702,21 @@ class TestFXExperimental(JitTestCase): torch.testing.assert_close(loaded(x), mttm(x)) def test_proxy_tensor(self): - def f(x): + def f_grad(x): val = x.cos().cos().sum() return torch.autograd.grad(val, x) - traced_graph = make_fx(f)(torch.randn(3, requires_grad=True)) - inp = torch.randn(3, requires_grad=True) - torch.testing.assert_close(traced_graph(inp), f(inp)) + def f_backward(x): + val = x.cos().cos().sum() + val.backward() + return x.grad + + for f in [f_grad, f_backward]: + traced_graph = make_fx(f)(torch.randn(3, requires_grad=True)) + inp = torch.randn(3, requires_grad=True) + traced_graph_out = traced_graph(inp) + assert inp.grad is None + torch.testing.assert_close(traced_graph_out, f(inp)) def test_mode_tracing_factory_function(self): def f(x): diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 9f144d932ad..d45584d2fd9 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -15,7 +15,7 @@ from contextlib import contextmanager from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode -__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"] +__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict"] aten = torch.ops.aten CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {} @@ -40,6 +40,11 @@ def decompose(decomposition_table): finally: CURRENT_DECOMPOSITION_TABLE = old_decomposition_table +# Checks whether we try to convert the tensor into a scalar +IS_STRICT = True +def enable_strict(val): + global IS_STRICT + IS_STRICT = val def wrap_output(real_out, proxy_out): def wrap_with_proxy(e, proxy): @@ -68,7 +73,8 @@ def proxy_call(func_overload, args, kwargs=None): return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs) if func_overload == aten._local_scalar_dense.default: raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! " - "It's likely that this is caused by data-dependent control flow or similar.") + "It's likely that this is caused by data-dependent control flow or similar." + "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check") def unwrap_proxy(e): return e.proxy if isinstance(e, ProxyTensor) else e @@ -92,18 +98,24 @@ class ProxyTensor(torch.Tensor): proxy: fx.Proxy @staticmethod - def __new__(cls, elem, proxy): + def __new__(cls, elem, proxy, *, requires_grad=None): # Hack to deal with super().__new__ not working for sparse tensors - if elem.is_sparse: - proxy.node.meta['tensor_meta'] = {} - r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + if elem.is_sparse or requires_grad is not None: + r = torch.Tensor._make_subclass(cls, elem, requires_grad) else: r = super().__new__(cls, elem) # type: ignore[call-arg] + + if elem.is_sparse: + proxy.node.meta['tensor_meta'] = {} + else: proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r) r.proxy = proxy # type: ignore[attr-defined] return r + def __deepcopy__(self, memo): + return self.clone() + def __repr__(self): with no_dispatch(): return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})" # type: ignore[arg-type] @@ -172,7 +184,7 @@ def wrap_key(f, inps): for idx, arg in enumerate(flat_args): if isinstance(flat_inps[idx], torch.Tensor): with no_dispatch(): - flat_args[idx] = ProxyTensor(flat_inps[idx], arg) + flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=flat_inps[idx].is_leaf) else: flat_args[idx] = flat_inps[idx]