diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 387b6bf59b1..8005d6e3a28 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1361,6 +1361,14 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase): self._check_recompiles(fn, (nt,), (nt2,), False) self._check_recompiles(fn, (nt,), (nt3,), True) + def test_inline_nested_tensor_from_jagged(self): + nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) + + def fn(x): + return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets()) + + torch.compile(fn, fullgraph=True, backend="aot_eager")(nt) + def _get_views(self): # Test all cases with both an NT base and a dense base # Subclass -> Subclass diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 8e4e31718d9..e149ea379b8 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -634,6 +634,7 @@ class TestExecutionTrace(TestCase): found_root_node = True assert found_root_node + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500") def test_execution_trace_nested_tensor(self): fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) fp.close() diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 763f6482cb9..daeb8626c10 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -173,6 +173,7 @@ manual_torch_name_rule_map = { "torch.nn.Parameter": TorchInGraphFunctionVariable, "torch._nested_tensor_from_mask": SkipFunctionVariable, "torch._nested_from_padded": SkipFunctionVariable, + "torch.nested.nested_tensor_from_jagged": UserFunctionVariable, # symbol operators implemented in Python "torch.sym_not": TorchInGraphFunctionVariable, "torch.sym_float": TorchInGraphFunctionVariable,