[dynamo] Replace unimplemented with unimplemented_v2 in torch/_dynamo/variables/torch.py (#157344)

Fixes part of #147913

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157344
Approved by: https://github.com/williamwen42

Co-authored-by: William Wen <william.wen42@gmail.com>
This commit is contained in:
zeshengzong 2025-07-08 00:46:56 +00:00 committed by PyTorch MergeBot
parent 85111cd165
commit df72078fe1
4 changed files with 525 additions and 88 deletions

View File

@ -671,13 +671,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
fn(p) fn(p)
self.assertFalse(True) # must raise error before this self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e: except torch._dynamo.exc.Unsupported as e:
msg = """ self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e))
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_constant`
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_nonstrict_trace_nested_custom_class_error(self): def test_nonstrict_trace_nested_custom_class_error(self):
class Point: class Point:
@ -723,13 +717,7 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
fn(torch.ones(10), torch.ones(1)) fn(torch.ones(10), torch.ones(1))
self.assertFalse(True) # must raise error before this self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e: except torch._dynamo.exc.Unsupported as e:
msg = """ self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e))
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_nested_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
* `torch.utils._pytree.register_constant`
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_nonstrict_newly_constructed_trace_register_constant_type_error(self): def test_nonstrict_newly_constructed_trace_register_constant_type_error(self):
class State: class State:
@ -766,12 +754,10 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
fn(x) fn(x)
self.assertFalse(True) # must raise error before this self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e: except torch._dynamo.exc.Unsupported as e:
msg = """ self.assertIn(
You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <DecoratorTests.test_nonstrict_newly_constructed_trace_register_constant_type_error.<locals>.State>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. "Input marked with `pytree.register_constant` constructed in the `torch.compile` region",
str(e),
Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub. )
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_nonstrict_trace_object_in_context_error(self): def test_nonstrict_trace_object_in_context_error(self):
class Point: class Point:
@ -814,17 +800,9 @@ Please construct the object _outside_ the `torch.compile` region, or submit an i
fn(x, y) fn(x, y)
self.assertFalse(True) # must raise error before this self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e: except torch._dynamo.exc.Unsupported as e:
msg = """ self.assertIn(
You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point> into the context. "Invalid use of pytree_flatten with nonstrict_trace-ed function", str(e)
)
Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point>
* `torch.utils._pytree.register_constant`
* `torch.utils._pytree.register_dataclass`
* `torch.utils._pytree.register_pytree_node`
If the above doesn't work, please subtmit an issue to GitHub.
""" # NOQA: B950
self.assertIn(msg, str(e))
def test_graph_break(self): def test_graph_break(self):
cnts = torch._dynamo.testing.CompileCounter() cnts = torch._dynamo.testing.CompileCounter()

View File

@ -232,7 +232,7 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
self.assertRaisesRegex( self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, torch._dynamo.exc.Unsupported,
"Popping from an empty torch function mode stack", "Attempted to pop from empty torch function mode stack",
lambda: fn(torch.ones(2, 2)), lambda: fn(torch.ones(2, 2)),
) )

View File

@ -2172,9 +2172,6 @@
"Hints": [ "Hints": [
"Don't mutate `.data` on this tensor, or move ", "Don't mutate `.data` on this tensor, or move ",
"the mutation out of `torch.compile` region" "the mutation out of `torch.compile` region"
],
"Additional_Info": [
"INFO"
] ]
} }
], ],
@ -2199,5 +2196,278 @@
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
] ]
} }
],
"GB0223": [
{
"Gb_type": "torch.compile call with > 1 args",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.",
"Hints": [
"Remove the torch.compile call or its additional args.",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0224": [
{
"Gb_type": "Attempted to call torch in-graph function on only torch.SymInt arguments",
"Context": "fn={self.value}, args={args}, kwargs={kwargs}",
"Explanation": "Attempted to call {str(self.value)} (that should be put in the FX graph) on only torch.SymInt arguments. Dynamo does not support this.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0225": [
{
"Gb_type": "Attempted to use tensor creation function with requires_grad=True",
"Context": "fn={self.value}, args={args}, kwargs={kwargs}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"Create the tensor outside the compiled region.",
"Do not set `requires_grad=True`.",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0226": [
{
"Gb_type": "`torch.nn.Parameter()` with unsupported data type",
"Context": "data={data}",
"Explanation": "Called `torch.nn.Parameter()` with non-Tensor argument.",
"Hints": [
"Ensure the argument to `torch.nn.Parameter()` is a `torch.Tensor`.",
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0227": [
{
"Gb_type": "Attempted to use torch.nn.Parameter constructor with tensor subclass",
"Context": "str(data)",
"Explanation": "Dynamo does not support this.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0228": [
{
"Gb_type": "`torch.nn.Parameter`: cannot convert to traceable tracable",
"Context": "",
"Explanation": "convert_tracable_parameter is set to False.",
"Hints": [
"Check usage of context manager: do_not_convert_to_tracable_parameter",
"This graph break may be difficult to debug. Please report an issue to PyTorch for assistance."
]
}
],
"GB0229": [
{
"Gb_type": "Unexpected type of data placeholder op for parameter construction",
"Context": "data_node.op={data_node.op}",
"Explanation": "Data node op should be placeholder or get_attr.",
"Hints": [
"This graph break may be difficult to debug. Please report an issue to PyTorch for assistance."
]
}
],
"GB0230": [
{
"Gb_type": "Attempted to use torch.use_deterministic_algorithms(warn_only=True)",
"Context": "mode={mode}, warn_only={warn_only}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"Remove param warn_only in function call torch.use_deterministic_algorithms.",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0231": [
{
"Gb_type": "call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`",
"Context": "trace_numpy={config.trace_numpy}",
"Explanation": "Attempted to call `torch.from_numpy` with config `torch._dynamo.config.trace_numpy` set to `False`.",
"Hints": [
"Change `torch._dynamo.config.trace_numpy` to `True`."
]
}
],
"GB0232": [
{
"Gb_type": "`torch.from_numpy` with NumPy unavailable",
"Context": "",
"Explanation": "Attempted to call `torch.numpy` but NumPy could not be imported.",
"Hints": [
"Check NumPy version and installation in your environment.",
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0233": [
{
"Gb_type": "Attempted to use strided NestedTensor",
"Context": "layout={layout}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"Change layout=torch.jagged.",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0234": [
{
"Gb_type": "Attempted to pop from empty torch function mode stack",
"Context": "",
"Explanation": "Called `torch._C._pop_torch_function_stack` when torch function mode stack is empty.",
"Hints": [
"Do not pop from empty torch function mode stack.",
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0235": [
{
"Gb_type": "`torch.nn.Parameter` with non-constant Tensor attributes",
"Context": "data={data}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"Ensure the Tensor argument's shape, dtype, and device are correct.",
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0236": [
{
"Gb_type": "Invalid input type for nonstrict_trace-ed function",
"Context": "Encountered input of type <{type_name}>.",
"Explanation": "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) or pytree containers of those are allowed as inputs. The provided argument contains an unsupported type.",
"Hints": [
"Use one of the following to register the type with pytree:\n",
"* `torch.utils._pytree.register_constant`\n",
"* `torch.utils._pytree.register_dataclass`\n",
"* `torch.utils._pytree.register_pytree_node`"
]
}
],
"GB0237": [
{
"Gb_type": "non-constant `requires_grad` argument to `torch.nn.Parameter`",
"Context": "requires_grad={requires_grad}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"Change `requires_grad` to be a bool.",
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0238": [
{
"Gb_type": "Input marked with `pytree.register_constant` constructed in the `torch.compile` region",
"Context": "Input={input_spec_vt}, offending type <{type_name}>.",
"Explanation": "Calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. This is not supported.",
"Hints": [
"Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0239": [
{
"Gb_type": "Invalid use of pytree_flatten with nonstrict_trace-ed function",
"Context": "Input={input_spec_vt}, offending type <{type_name}>.",
"Explanation": "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered with a `pytree_flatten` that places an object of type <{type_name}> into the context.",
"Hints": [
"Modifying the `pytree_flatten` to avoid placing the object into the context.",
"Apply one of the following to <{type_name}>:\n",
"* `torch.utils._pytree.register_constant`\n",
"* `torch.utils._pytree.register_dataclass`\n",
"* `torch.utils._pytree.register_pytree_node`",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0240": [
{
"Gb_type": "Shape mismatch with out= list of tensor variants",
"Context": "fn={self.value}, args={args}, kwargs={kwargs}",
"Explanation": "Shape mismatch when calling {self.value} with `out=`. Provided `out=` shape: {saved_out_shape}. Actual shape: {fake_out.shape}.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0241": [
{
"Gb_type": "Attempted to call op with non-contiguous `out=` list of tensors",
"Context": "self.value={self.value}, args={args}, kwargs={kwargs}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0242": [
{
"Gb_type": "Attempted to call op with non-contiguous `out=` tensor",
"Context": "self.value={self.value}, args={args}, kwargs={kwargs}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0243": [
{
"Gb_type": "Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type",
"Context": "value={value}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"Change use of _ntuple with argument as constant or tensor."
]
}
],
"GB0244": [
{
"Gb_type": "Attempted to use `torch.nn.Parameter()` with export",
"Context": "",
"Explanation": "Dynamo does not support this.",
"Hints": [
"Do not use `torch.nn.Parameter()` with export.",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0245": [
{
"Gb_type": "Attempted to use `nested_tensor` with non-list input",
"Context": "tensor_list={tensor_list}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"Change `nested_tensor` with list input.",
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0246": [
{
"Gb_type": "Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape",
"Context": "args={args}, kwargs={kwargs}",
"Explanation": "Dynamo does not support this.",
"Hints": [
"Explicitly set the `num_classes` param of the function call ",
"`torch.nn.functional.one_hot` to something other than -1."
]
}
],
"GB0247": [
{
"Gb_type": "Shape mismatch with out= tensor variant",
"Context": "fn={self.value}, args={args}, kwargs={kwargs}",
"Explanation": "Shape mismatch when calling {self.value} with `out=`. Provided `out=` shape: {saved_out_shapes}. Actual shape: {fake_out.shape}.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
] ]
} }

View File

@ -52,7 +52,7 @@ from ..create_parameter_op import (
tracable_create_parameter, tracable_create_parameter,
) )
from ..device_interface import get_registered_device_interfaces from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented, unimplemented_v2 from ..exc import unimplemented_v2
from ..guards import GuardBuilder, install_guard from ..guards import GuardBuilder, install_guard
from ..source import CallFunctionNoArgsSource, SyntheticLocalSource from ..source import CallFunctionNoArgsSource, SyntheticLocalSource
from ..utils import ( from ..utils import (
@ -588,7 +588,15 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# torch.compile is a no-op in dynamo # torch.compile is a no-op in dynamo
return args[0] return args[0]
unimplemented("torch.compile is used as a decorator in the compiled frame") unimplemented_v2(
gb_type="torch.compile call with > 1 args",
context=f"args={args}, kwargs={kwargs}",
explanation="Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.",
hints=[
"Remove the torch.compile call or its additional args.",
*graph_break_hints.SUPPORTABLE,
],
)
@register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD) @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD)
def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input): def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input):
@ -615,7 +623,15 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
self, tx: "InstructionTranslator", mode, warn_only=False self, tx: "InstructionTranslator", mode, warn_only=False
): ):
if warn_only and warn_only.as_python_constant(): if warn_only and warn_only.as_python_constant():
unimplemented("torch.use_deterministic_algorithms(warn_only=True)") unimplemented_v2(
gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)",
context=f"mode={mode}, warn_only={warn_only}",
explanation="Dynamo does not support this.",
hints=[
"Remove param warn_only in function call torch.use_deterministic_algorithms.",
*graph_break_hints.SUPPORTABLE,
],
)
return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant()) return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant())
@register(torch.are_deterministic_algorithms_enabled) @register(torch.are_deterministic_algorithms_enabled)
@ -666,9 +682,27 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
@register(torch.from_numpy) @register(torch.from_numpy)
def handle_from_numpy(self, tx: "InstructionTranslator", *args): def handle_from_numpy(self, tx: "InstructionTranslator", *args):
if not config.trace_numpy: if not config.trace_numpy:
unimplemented("torch.from_numpy. config.trace_numpy is False") unimplemented_v2(
gb_type="call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`",
context=f"trace_numpy={config.trace_numpy}",
explanation=(
"Attempted to call `torch.from_numpy` with config "
"`torch._dynamo.config.trace_numpy` set to `False`."
),
hints=[
"Change `torch._dynamo.config.trace_numpy` to `True`.",
],
)
if not np: if not np:
unimplemented("torch.from_numpy. NumPy is not available") unimplemented_v2(
gb_type="`torch.from_numpy` with NumPy unavailable",
context="",
explanation="Attempted to call `torch.numpy` but NumPy could not be imported.",
hints=[
"Check NumPy version and installation in your environment.",
*graph_break_hints.USER_ERROR,
],
)
return wrap_fx_proxy_cls( return wrap_fx_proxy_cls(
target_cls=TensorVariable, target_cls=TensorVariable,
tx=tx, tx=tx,
@ -880,9 +914,25 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
from .lists import BaseListVariable from .lists import BaseListVariable
if layout and layout.as_python_constant() == torch.strided: if layout and layout.as_python_constant() == torch.strided:
unimplemented("torch.compile does not support strided NestedTensor") unimplemented_v2(
gb_type="Attempted to use strided NestedTensor",
context=f"layout={layout}",
explanation="Dynamo does not support this.",
hints=[
"Change layout=torch.jagged.",
*graph_break_hints.SUPPORTABLE,
],
)
if not isinstance(tensor_list, BaseListVariable): if not isinstance(tensor_list, BaseListVariable):
unimplemented("nested_tensor with non-list input") unimplemented_v2(
gb_type="Attempted to use `nested_tensor` with non-list input",
context=f"tensor_list={tensor_list}",
explanation="Dynamo does not support this.",
hints=[
"Change `nested_tensor` with list input.",
*graph_break_hints.USER_ERROR,
],
)
@register(torch.nn.functional.one_hot) @register(torch.nn.functional.one_hot)
def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs):
@ -891,8 +941,14 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
and args[1].is_python_constant() and args[1].is_python_constant()
and args[1].as_python_constant() == -1 and args[1].as_python_constant() == -1
): ):
unimplemented( unimplemented_v2(
"torch.nn.functional.one_hot with data-dependent output shape" gb_type="Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape",
context=f"args={args}, kwargs={kwargs}",
explanation="Dynamo does not support this.",
hints=[
"Explicitly set the `num_classes` param of the function call "
"`torch.nn.functional.one_hot` to something other than -1.",
],
) )
@register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious) @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious)
@ -1061,7 +1117,15 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
): ):
assert not args and not kwargs assert not args and not kwargs
if not tx.symbolic_torch_function_state.mode_stack: if not tx.symbolic_torch_function_state.mode_stack:
raise unimplemented("Popping from an empty torch function mode stack") unimplemented_v2(
gb_type="Attempted to pop from empty torch function mode stack",
context="",
explanation="Called `torch._C._pop_torch_function_stack` when torch function mode stack is empty.",
hints=[
"Do not pop from empty torch function mode stack.",
*graph_break_hints.USER_ERROR,
],
)
TorchFunctionModeStackVariable.register_mutation(tx) TorchFunctionModeStackVariable.register_mutation(tx)
return tx.symbolic_torch_function_state.pop_torch_function_mode() return tx.symbolic_torch_function_state.pop_torch_function_mode()
@ -1152,13 +1216,20 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
arg_type = flat_arg_vt.python_type() arg_type = flat_arg_vt.python_type()
if not is_graphable_type(arg_type): if not is_graphable_type(arg_type):
type_name = flat_arg_vt.python_type().__qualname__ type_name = flat_arg_vt.python_type().__qualname__
unimplemented( unimplemented_v2(
f""" gb_type="Invalid input type for nonstrict_trace-ed function",
For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <{type_name}>, please use one of the following to register the type with pytree: context=f"Encountered input of type <{type_name}>.",
* `torch.utils._pytree.register_constant` explanation=(
* `torch.utils._pytree.register_dataclass` "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) "
* `torch.utils._pytree.register_pytree_node` "or pytree containers of those are allowed as inputs. The provided argument contains "
""" # NOQA: B950 "an unsupported type."
),
hints=[
"Use one of the following to register the type with pytree:\n"
"* `torch.utils._pytree.register_constant`\n"
"* `torch.utils._pytree.register_dataclass`\n"
"* `torch.utils._pytree.register_pytree_node`",
],
) )
# Since we checked with `is_graphable` above, `as_proxy` on the # Since we checked with `is_graphable` above, `as_proxy` on the
@ -1179,25 +1250,37 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
if pytree.is_constant_class(typ): if pytree.is_constant_class(typ):
unimplemented( unimplemented_v2(
f""" gb_type="Input marked with `pytree.register_constant` constructed in the `torch.compile` region",
You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. context=f"Input={input_spec_vt}, offending type <{type_name}>.",
explanation=(
Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub. "Calling a `nonstrict_trace`-ed function with an input that contains an object "
""" # NOQA: B950 f"of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object "
"was constructed _inside_ the `torch.compile` region. This is not supported."
),
hints=[
"Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.",
*graph_break_hints.SUPPORTABLE,
],
from_exc=e,
) )
else: else:
unimplemented( unimplemented_v2(
f""" gb_type="Invalid use of pytree_flatten with nonstrict_trace-ed function",
You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <{type_name}> into the context. context=f"Input={input_spec_vt}, offending type <{type_name}>.",
explanation=(
Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <{type_name}> "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered "
* `torch.utils._pytree.register_constant` f"with a `pytree_flatten` that places an object of type <{type_name}> into the context."
* `torch.utils._pytree.register_dataclass` ),
* `torch.utils._pytree.register_pytree_node` hints=[
"Modifying the `pytree_flatten` to avoid placing the object into the context.",
If the above doesn't work, please subtmit an issue to GitHub. f"Apply one of the following to <{type_name}>:\n"
""" # NOQA: B950 "* `torch.utils._pytree.register_constant`\n"
"* `torch.utils._pytree.register_dataclass`\n"
"* `torch.utils._pytree.register_pytree_node`",
*graph_break_hints.SUPPORTABLE,
],
from_exc=e,
) )
fn = self.value fn = self.value
@ -1308,7 +1391,17 @@ To support this behavior, we need to allow const-propping tensors that store sym
For now, dynamo will explicitly graph break when it encounters user code with this behavior. For now, dynamo will explicitly graph break when it encounters user code with this behavior.
""" """
log.warning(msg) log.warning(msg)
unimplemented(msg) unimplemented_v2(
gb_type="Attempted to call torch in-graph function on only torch.SymInt arguments",
context=f"fn={self.value}, args={args}, kwargs={kwargs}",
explanation=(
f"Attempted to call {str(self.value)} (that should be put in the FX graph) on only torch.SymInt arguments. "
"Dynamo does not support this."
),
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
# TODO(voz): Replace w/ dynamic shape rewrite table. # TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination # Ideally, we would be able to do this at ctor time, but alas we need a combination
@ -1366,9 +1459,15 @@ For now, dynamo will explicitly graph break when it encounters user code with th
and "requires_grad" in kwargs and "requires_grad" in kwargs
and kwargs["requires_grad"].as_python_constant() and kwargs["requires_grad"].as_python_constant()
): ):
unimplemented( unimplemented_v2(
"""factory functions that return tensors that require grad are not supported. gb_type="Attempted to use tensor creation function with requires_grad=True",
Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" context=f"fn={self.value}, args={args}, kwargs={kwargs}",
explanation="Dynamo does not support this.",
hints=[
"Create the tensor outside the compiled region.",
"Do not set `requires_grad=True`.",
*graph_break_hints.SUPPORTABLE,
],
) )
# Handle e.g., `torch.add(a, b, out=result)` # Handle e.g., `torch.add(a, b, out=result)`
@ -1400,12 +1499,27 @@ Either create the tensor outside the compiled region, or do not set the tensor t
if saved_out_shape != fake_out.shape: if saved_out_shape != fake_out.shape:
# It's hard to get out variants with resizing on graph inputs work # It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back. # properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs") unimplemented_v2(
gb_type="Shape mismatch with out= list of tensor variants",
context=f"fn={self.value}, args={args}, kwargs={kwargs}",
explanation=(
f"Shape mismatch when calling {self.value} with `out=`. "
f"Provided `out=` shape: {saved_out_shape}. Actual shape: {fake_out.shape}."
),
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
if not torch._prims_common.is_contiguous(fake_out): if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization # It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument # when calling an out= op with a non-contiguous out argument
unimplemented( unimplemented_v2(
"out= op was called where output tensor was non-contiguous" gb_type="Attempted to call op with non-contiguous `out=` list of tensors",
context=f"self.value={self.value}, args={args}, kwargs={kwargs}",
explanation="Dynamo does not support this.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
) )
else: else:
assert isinstance(out_kwarg_vt, TensorVariable) assert isinstance(out_kwarg_vt, TensorVariable)
@ -1414,12 +1528,27 @@ Either create the tensor outside the compiled region, or do not set the tensor t
if saved_out_shapes != fake_out.shape: if saved_out_shapes != fake_out.shape:
# It's hard to get out variants with resizing on graph inputs work # It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back. # properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs") unimplemented_v2(
gb_type="Shape mismatch with out= tensor variant",
context=f"fn={self.value}, args={args}, kwargs={kwargs}",
explanation=(
f"Shape mismatch when calling {self.value} with `out=`. "
f"Provided `out=` shape: {saved_out_shapes}. Actual shape: {fake_out.shape}."
),
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
if not torch._prims_common.is_contiguous(fake_out): if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization # It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument # when calling an out= op with a non-contiguous out argument
unimplemented( unimplemented_v2(
"out= op was called where output tensor was non-contiguous" gb_type="Attempted to call op with non-contiguous `out=` tensor",
context=f"self.value={self.value}, args={args}, kwargs={kwargs}",
explanation="Dynamo does not support this.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
) )
return tensor_variable return tensor_variable
@ -1444,7 +1573,14 @@ Either create the tensor outside the compiled region, or do not set the tensor t
torch.nn.modules.utils._ntuple(count)(value.as_python_constant()), torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
) )
else: else:
unimplemented(f"torch.nn.modules.utils._ntuple({value})") unimplemented_v2(
gb_type="Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type",
context=f"value={value}",
explanation="Dynamo does not support this.",
hints=[
"Change use of _ntuple with argument as constant or tensor.",
],
)
if self.value is torch.nn.modules.utils._ntuple: if self.value is torch.nn.modules.utils._ntuple:
return variables.LambdaVariable(handle_ntuple) return variables.LambdaVariable(handle_ntuple)
@ -1455,16 +1591,40 @@ Either create the tensor outside the compiled region, or do not set the tensor t
def call_nn_parameter(cls, tx, data=None, requires_grad=True): def call_nn_parameter(cls, tx, data=None, requires_grad=True):
"""A call to torch.nn.Parameter() gets lifted to before the graph""" """A call to torch.nn.Parameter() gets lifted to before the graph"""
if tx.export: if tx.export:
unimplemented("nn parameter construction not supported with export") unimplemented_v2(
gb_type="Attempted to use `torch.nn.Parameter()` with export",
context="",
explanation="Dynamo does not support this.",
hints=[
"Do not use `torch.nn.Parameter()` with export.",
*graph_break_hints.SUPPORTABLE,
],
)
if isinstance(requires_grad, variables.VariableTracker): if isinstance(requires_grad, variables.VariableTracker):
try: try:
requires_grad = requires_grad.as_python_constant() requires_grad = requires_grad.as_python_constant()
except NotImplementedError: except NotImplementedError:
unimplemented("Parameter(requires_grad=...) not constant") unimplemented_v2(
gb_type="non-constant `requires_grad` argument to `torch.nn.Parameter`",
context=f"requires_grad={requires_grad}",
explanation="Dynamo does not support this.",
hints=[
"Change `requires_grad` to be a bool.",
*graph_break_hints.USER_ERROR,
],
)
if not isinstance(data, variables.TensorVariable): if not isinstance(data, variables.TensorVariable):
unimplemented(f"Parameter(data={data}) not implemented") unimplemented_v2(
gb_type="`torch.nn.Parameter()` with unsupported data type",
context=f"data={data}",
explanation="Called `torch.nn.Parameter()` with non-Tensor argument.",
hints=[
"Ensure the argument to `torch.nn.Parameter()` is a `torch.Tensor`.",
*graph_break_hints.USER_ERROR,
],
)
# this results in cleaner graphs, but only works for inputs # this results in cleaner graphs, but only works for inputs
if data.source: if data.source:
@ -1473,17 +1633,41 @@ Either create the tensor outside the compiled region, or do not set the tensor t
if isinstance( if isinstance(
data, TensorWithTFOverrideVariable data, TensorWithTFOverrideVariable
) or is_traceable_wrapper_subclass_type(data.class_type): ) or is_traceable_wrapper_subclass_type(data.class_type):
unimplemented("Parameter constructor with tensor subclass NYI") unimplemented_v2(
gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass",
context=str(data),
explanation="Dynamo does not support this.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
if not can_convert_to_tracable_parameter(): if not can_convert_to_tracable_parameter():
unimplemented("Workaround for issues with nn_parameter construction") unimplemented_v2(
gb_type="`torch.nn.Parameter`: cannot convert to traceable tracable",
context="",
explanation="convert_tracable_parameter is set to False.",
hints=[
"Check usage of context manager: do_not_convert_to_tracable_parameter",
*graph_break_hints.DIFFICULT,
],
)
try: try:
shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
dtype = data.var_getattr(tx, "dtype").as_python_constant() dtype = data.var_getattr(tx, "dtype").as_python_constant()
device = data.var_getattr(tx, "device").as_python_constant() device = data.var_getattr(tx, "device").as_python_constant()
except NotImplementedError as e: except NotImplementedError as e:
unimplemented(f"Parameter not python_constant: {e}") unimplemented_v2(
gb_type="`torch.nn.Parameter` with non-constant Tensor attributes",
context=f"data={data}",
explanation="Dynamo does not support this.",
hints=[
"Ensure the Tensor argument's shape, dtype, and device are correct.",
*graph_break_hints.USER_ERROR,
],
from_exc=e,
)
placeholder = tx.output.synthetic_graph_input( placeholder = tx.output.synthetic_graph_input(
new_parameter_placeholder, [shape, dtype, device, requires_grad] new_parameter_placeholder, [shape, dtype, device, requires_grad]
@ -1535,8 +1719,13 @@ Either create the tensor outside the compiled region, or do not set the tensor t
data_node = data.as_proxy().node data_node = data.as_proxy().node
if data_node.op not in ("placeholder", "get_attr"): if data_node.op not in ("placeholder", "get_attr"):
unimplemented( unimplemented_v2(
"Unexpected type of data placeholder op for parameter construction" gb_type="Unexpected type of data placeholder op for parameter construction",
context=f"data_node.op={data_node.op}",
explanation="Data node op should be placeholder or get_attr.",
hints=[
*graph_break_hints.DIFFICULT,
],
) )
# add the newly constructed nn.Parameter as a graph input # add the newly constructed nn.Parameter as a graph input