[dynamo] Replace unimplemented with unimplemented_v2 in torch/_dynamo/variables/torch.py (#157344)

Fixes part of #147913

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157344
Approved by: https://github.com/williamwen42
Co-authored-by: William Wen <william.wen42@gmail.com>
This commit is contained in:
parent 85111cd165
commit df72078fe1
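For orientation, the pattern this commit applies throughout the file: the old `unimplemented(msg)` helper raised `torch._dynamo.exc.Unsupported` with a free-form message, while `unimplemented_v2` takes structured fields (`gb_type`, `context`, `explanation`, `hints`, optional `from_exc`) that also feed the graph-break registry below. The following is a minimal standalone sketch of that call shape, not the actual torch/_dynamo/exc.py implementation; the formatting is illustrative.

# Standalone sketch of the unimplemented -> unimplemented_v2 call shape.
class Unsupported(RuntimeError):
    pass

def unimplemented_v2(gb_type, context, explanation, hints, from_exc=None):
    # gb_type is a stable key into the graph-break registry; hints render
    # as a bulleted list in the user-facing error.
    lines = [
        f"Graph break: {gb_type}",
        f"  Context: {context}",
        f"  Explanation: {explanation}",
    ]
    lines += [f"  Hint: {h}" for h in hints]
    err = Unsupported("\n".join(lines))
    if from_exc is not None:
        raise err from from_exc
    raise err

try:
    unimplemented_v2(
        gb_type="torch.compile call with > 1 args",
        context="args=(...), kwargs={}",
        explanation="Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.",
        hints=["Remove the torch.compile call or its additional args."],
    )
except Unsupported as e:
    # Tests in this commit assert on the short stable string, not the full text.
    assert "torch.compile call with > 1 args" in str(e)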
@@ -671,13 +671,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
             fn(p)
             self.assertFalse(True)  # must raise error before this
         except torch._dynamo.exc.Unsupported as e:
-            msg = """
-For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
-* `torch.utils._pytree.register_constant`
-* `torch.utils._pytree.register_dataclass`
-* `torch.utils._pytree.register_pytree_node`
-"""  # NOQA: B950
-            self.assertIn(msg, str(e))
+            self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e))

     def test_nonstrict_trace_nested_custom_class_error(self):
         class Point:
@@ -723,13 +717,7 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
             fn(torch.ones(10), torch.ones(1))
             self.assertFalse(True)  # must raise error before this
         except torch._dynamo.exc.Unsupported as e:
-            msg = """
-For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_nested_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
-* `torch.utils._pytree.register_constant`
-* `torch.utils._pytree.register_dataclass`
-* `torch.utils._pytree.register_pytree_node`
-"""  # NOQA: B950
-            self.assertIn(msg, str(e))
+            self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e))

     def test_nonstrict_newly_constructed_trace_register_constant_type_error(self):
         class State:
@@ -766,12 +754,10 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
             fn(x)
             self.assertFalse(True)  # must raise error before this
         except torch._dynamo.exc.Unsupported as e:
-            msg = """
-You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <DecoratorTests.test_nonstrict_newly_constructed_trace_register_constant_type_error.<locals>.State>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region.
-
-Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.
-"""  # NOQA: B950
-            self.assertIn(msg, str(e))
+            self.assertIn(
+                "Input marked with `pytree.register_constant` constructed in the `torch.compile` region",
+                str(e),
+            )

     def test_nonstrict_trace_object_in_context_error(self):
         class Point:
@@ -814,17 +800,9 @@ Please construct the object _outside_ the `torch.compile` region, or submit an i
             fn(x, y)
             self.assertFalse(True)  # must raise error before this
         except torch._dynamo.exc.Unsupported as e:
-            msg = """
-You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point> into the context.
-
-Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point>
-* `torch.utils._pytree.register_constant`
-* `torch.utils._pytree.register_dataclass`
-* `torch.utils._pytree.register_pytree_node`
-
-If the above doesn't work, please subtmit an issue to GitHub.
-"""  # NOQA: B950
-            self.assertIn(msg, str(e))
+            self.assertIn(
+                "Invalid use of pytree_flatten with nonstrict_trace-ed function", str(e)
+            )

     def test_graph_break(self):
         cnts = torch._dynamo.testing.CompileCounter()
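The test hunks above all follow one rule: instead of matching the full multi-line error text, assert on the short stable `gb_type` string. A hedged sketch of that assertion style in plain unittest, where `fake_graph_break` stands in for calling a compiled function that raises `torch._dynamo.exc.Unsupported`:

import unittest

class GraphBreakMessageTest(unittest.TestCase):
    def fake_graph_break(self):
        # Illustrative stand-in for a dynamo graph-break error.
        raise RuntimeError(
            "Graph break: Invalid input type for nonstrict_trace-ed function\n"
            "  Context: Encountered input of type <Point>."
        )

    def test_matches_stable_gb_type(self):
        try:
            self.fake_graph_break()
            self.assertFalse(True)  # must raise error before this
        except RuntimeError as e:
            # Match the stable gb_type substring, not the full message.
            self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e))

if __name__ == "__main__":
    unittest.main()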
@@ -232,7 +232,7 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):

         self.assertRaisesRegex(
             torch._dynamo.exc.Unsupported,
-            "Popping from an empty torch function mode stack",
+            "Attempted to pop from empty torch function mode stack",
             lambda: fn(torch.ones(2, 2)),
         )

@@ -2172,9 +2172,6 @@
       "Hints": [
         "Don't mutate `.data` on this tensor, or move ",
         "the mutation out of `torch.compile` region"
-      ],
-      "Additional_Info": [
-        "INFO"
       ]
     }
   ],
@@ -2199,5 +2196,278 @@
         "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
       ]
     }
-  ]
+  ],
+  "GB0223": [
+    {
+      "Gb_type": "torch.compile call with > 1 args",
+      "Context": "args={args}, kwargs={kwargs}",
+      "Explanation": "Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.",
+      "Hints": [
+        "Remove the torch.compile call or its additional args.",
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0224": [
+    {
+      "Gb_type": "Attempted to call torch in-graph function on only torch.SymInt arguments",
+      "Context": "fn={self.value}, args={args}, kwargs={kwargs}",
+      "Explanation": "Attempted to call {str(self.value)} (that should be put in the FX graph) on only torch.SymInt arguments. Dynamo does not support this.",
+      "Hints": [
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0225": [
+    {
+      "Gb_type": "Attempted to use tensor creation function with requires_grad=True",
+      "Context": "fn={self.value}, args={args}, kwargs={kwargs}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "Create the tensor outside the compiled region.",
+        "Do not set `requires_grad=True`.",
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0226": [
+    {
+      "Gb_type": "`torch.nn.Parameter()` with unsupported data type",
+      "Context": "data={data}",
+      "Explanation": "Called `torch.nn.Parameter()` with non-Tensor argument.",
+      "Hints": [
+        "Ensure the argument to `torch.nn.Parameter()` is a `torch.Tensor`.",
+        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
+      ]
+    }
+  ],
+  "GB0227": [
+    {
+      "Gb_type": "Attempted to use torch.nn.Parameter constructor with tensor subclass",
+      "Context": "str(data)",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0228": [
+    {
+      "Gb_type": "`torch.nn.Parameter`: cannot convert to traceable tracable",
+      "Context": "",
+      "Explanation": "convert_tracable_parameter is set to False.",
+      "Hints": [
+        "Check usage of context manager: do_not_convert_to_tracable_parameter",
+        "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance."
+      ]
+    }
+  ],
+  "GB0229": [
+    {
+      "Gb_type": "Unexpected type of data placeholder op for parameter construction",
+      "Context": "data_node.op={data_node.op}",
+      "Explanation": "Data node op should be placeholder or get_attr.",
+      "Hints": [
+        "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance."
+      ]
+    }
+  ],
+  "GB0230": [
+    {
+      "Gb_type": "Attempted to use torch.use_deterministic_algorithms(warn_only=True)",
+      "Context": "mode={mode}, warn_only={warn_only}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "Remove param warn_only in function call torch.use_deterministic_algorithms.",
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0231": [
+    {
+      "Gb_type": "call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`",
+      "Context": "trace_numpy={config.trace_numpy}",
+      "Explanation": "Attempted to call `torch.from_numpy` with config `torch._dynamo.config.trace_numpy` set to `False`.",
+      "Hints": [
+        "Change `torch._dynamo.config.trace_numpy` to `True`."
+      ]
+    }
+  ],
+  "GB0232": [
+    {
+      "Gb_type": "`torch.from_numpy` with NumPy unavailable",
+      "Context": "",
+      "Explanation": "Attempted to call `torch.numpy` but NumPy could not be imported.",
+      "Hints": [
+        "Check NumPy version and installation in your environment.",
+        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
+      ]
+    }
+  ],
+  "GB0233": [
+    {
+      "Gb_type": "Attempted to use strided NestedTensor",
+      "Context": "layout={layout}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "Change layout=torch.jagged.",
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0234": [
+    {
+      "Gb_type": "Attempted to pop from empty torch function mode stack",
+      "Context": "",
+      "Explanation": "Called `torch._C._pop_torch_function_stack` when torch function mode stack is empty.",
+      "Hints": [
+        "Do not pop from empty torch function mode stack.",
+        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
+      ]
+    }
+  ],
+  "GB0235": [
+    {
+      "Gb_type": "`torch.nn.Parameter` with non-constant Tensor attributes",
+      "Context": "data={data}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "Ensure the Tensor argument's shape, dtype, and device are correct.",
+        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
+      ]
+    }
+  ],
+  "GB0236": [
+    {
+      "Gb_type": "Invalid input type for nonstrict_trace-ed function",
+      "Context": "Encountered input of type <{type_name}>.",
+      "Explanation": "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) or pytree containers of those are allowed as inputs. The provided argument contains an unsupported type.",
+      "Hints": [
+        "Use one of the following to register the type with pytree:\n",
+        "* `torch.utils._pytree.register_constant`\n",
+        "* `torch.utils._pytree.register_dataclass`\n",
+        "* `torch.utils._pytree.register_pytree_node`"
+      ]
+    }
+  ],
+  "GB0237": [
+    {
+      "Gb_type": "non-constant `requires_grad` argument to `torch.nn.Parameter`",
+      "Context": "requires_grad={requires_grad}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "Change `requires_grad` to be a bool.",
+        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
+      ]
+    }
+  ],
+  "GB0238": [
+    {
+      "Gb_type": "Input marked with `pytree.register_constant` constructed in the `torch.compile` region",
+      "Context": "Input={input_spec_vt}, offending type <{type_name}>.",
+      "Explanation": "Calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. This is not supported.",
+      "Hints": [
+        "Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.",
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0239": [
+    {
+      "Gb_type": "Invalid use of pytree_flatten with nonstrict_trace-ed function",
+      "Context": "Input={input_spec_vt}, offending type <{type_name}>.",
+      "Explanation": "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered with a `pytree_flatten` that places an object of type <{type_name}> into the context.",
+      "Hints": [
+        "Modifying the `pytree_flatten` to avoid placing the object into the context.",
+        "Apply one of the following to <{type_name}>:\n",
+        "* `torch.utils._pytree.register_constant`\n",
+        "* `torch.utils._pytree.register_dataclass`\n",
+        "* `torch.utils._pytree.register_pytree_node`",
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0240": [
+    {
+      "Gb_type": "Shape mismatch with out= list of tensor variants",
+      "Context": "fn={self.value}, args={args}, kwargs={kwargs}",
+      "Explanation": "Shape mismatch when calling {self.value} with `out=`. Provided `out=` shape: {saved_out_shape}. Actual shape: {fake_out.shape}.",
+      "Hints": [
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0241": [
+    {
+      "Gb_type": "Attempted to call op with non-contiguous `out=` list of tensors",
+      "Context": "self.value={self.value}, args={args}, kwargs={kwargs}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0242": [
+    {
+      "Gb_type": "Attempted to call op with non-contiguous `out=` tensor",
+      "Context": "self.value={self.value}, args={args}, kwargs={kwargs}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0243": [
+    {
+      "Gb_type": "Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type",
+      "Context": "value={value}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "Change use of _ntuple with argument as constant or tensor."
+      ]
+    }
+  ],
+  "GB0244": [
+    {
+      "Gb_type": "Attempted to use `torch.nn.Parameter()` with export",
+      "Context": "",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "Do not use `torch.nn.Parameter()` with export.",
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ],
+  "GB0245": [
+    {
+      "Gb_type": "Attempted to use `nested_tensor` with non-list input",
+      "Context": "tensor_list={tensor_list}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "Change `nested_tensor` with list input.",
+        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
+      ]
+    }
+  ],
+  "GB0246": [
+    {
+      "Gb_type": "Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape",
+      "Context": "args={args}, kwargs={kwargs}",
+      "Explanation": "Dynamo does not support this.",
+      "Hints": [
+        "Explicitly set the `num_classes` param of the function call ",
+        "`torch.nn.functional.one_hot` to something other than -1."
+      ]
+    }
+  ],
+  "GB0247": [
+    {
+      "Gb_type": "Shape mismatch with out= tensor variant",
+      "Context": "fn={self.value}, args={args}, kwargs={kwargs}",
+      "Explanation": "Shape mismatch when calling {self.value} with `out=`. Provided `out=` shape: {saved_out_shapes}. Actual shape: {fake_out.shape}.",
+      "Hints": [
+        "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
+      ]
+    }
+  ]
 }
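Each new `unimplemented_v2` call site in this commit gets a numbered entry (GB0223 through GB0247) in the graph-break registry JSON above. As a sketch of how such a file could be consumed, assuming the registry has been saved locally as `graph_break_registry.json` (a hypothetical path for illustration):

import json

# Hypothetical local copy of the registry shown above.
with open("graph_break_registry.json") as f:
    registry = json.load(f)

# Each key (e.g. "GB0234") maps to a list of versions of the entry;
# index 0 is the current one.
entry = registry["GB0234"][0]
print(entry["Gb_type"])  # Attempted to pop from empty torch function mode stack
for hint in entry["Hints"]:
    print(" -", hint)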
@@ -52,7 +52,7 @@ from ..create_parameter_op import (
     tracable_create_parameter,
 )
 from ..device_interface import get_registered_device_interfaces
-from ..exc import unimplemented, unimplemented_v2
+from ..exc import unimplemented_v2
 from ..guards import GuardBuilder, install_guard
 from ..source import CallFunctionNoArgsSource, SyntheticLocalSource
 from ..utils import (
@@ -588,7 +588,15 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                 # torch.compile is a no-op in dynamo
                 return args[0]

-            unimplemented("torch.compile is used as a decorator in the compiled frame")
+            unimplemented_v2(
+                gb_type="torch.compile call with > 1 args",
+                context=f"args={args}, kwargs={kwargs}",
+                explanation="Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.",
+                hints=[
+                    "Remove the torch.compile call or its additional args.",
+                    *graph_break_hints.SUPPORTABLE,
+                ],
+            )

         @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD)
         def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input):
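A minimal repro sketch for this handler: inside a compiled frame, `torch.compile` is only a no-op when called with exactly one positional argument, so the decorator-factory form below would reach the new `unimplemented_v2` call. The exact surfacing behavior depends on dynamo state; with `fullgraph=True` the graph break is raised as `torch._dynamo.exc.Unsupported`.

import torch

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # torch.compile called with zero positional args while dynamo is tracing
    # this frame -> graph break GB0223.
    inner = torch.compile(backend="eager")(lambda y: y + 1)
    return inner(x)

# f(torch.ones(3))  # expected to raise torch._dynamo.exc.Unsupported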
@@ -615,7 +623,15 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             self, tx: "InstructionTranslator", mode, warn_only=False
         ):
             if warn_only and warn_only.as_python_constant():
-                unimplemented("torch.use_deterministic_algorithms(warn_only=True)")
+                unimplemented_v2(
+                    gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)",
+                    context=f"mode={mode}, warn_only={warn_only}",
+                    explanation="Dynamo does not support this.",
+                    hints=[
+                        "Remove param warn_only in function call torch.use_deterministic_algorithms.",
+                        *graph_break_hints.SUPPORTABLE,
+                    ],
+                )
             return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant())

         @register(torch.are_deterministic_algorithms_enabled)
@@ -666,9 +682,27 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         @register(torch.from_numpy)
         def handle_from_numpy(self, tx: "InstructionTranslator", *args):
             if not config.trace_numpy:
-                unimplemented("torch.from_numpy. config.trace_numpy is False")
+                unimplemented_v2(
+                    gb_type="call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`",
+                    context=f"trace_numpy={config.trace_numpy}",
+                    explanation=(
+                        "Attempted to call `torch.from_numpy` with config "
+                        "`torch._dynamo.config.trace_numpy` set to `False`."
+                    ),
+                    hints=[
+                        "Change `torch._dynamo.config.trace_numpy` to `True`.",
+                    ],
+                )
             if not np:
-                unimplemented("torch.from_numpy. NumPy is not available")
+                unimplemented_v2(
+                    gb_type="`torch.from_numpy` with NumPy unavailable",
+                    context="",
+                    explanation="Attempted to call `torch.numpy` but NumPy could not be imported.",
+                    hints=[
+                        "Check NumPy version and installation in your environment.",
+                        *graph_break_hints.USER_ERROR,
+                    ],
+                )
             return wrap_fx_proxy_cls(
                 target_cls=TensorVariable,
                 tx=tx,
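A repro sketch for the GB0231 path, assuming a standard PyTorch install with NumPy available; the hint's fix is to flip the config back to `True`:

import numpy as np
import torch
import torch._dynamo

torch._dynamo.config.trace_numpy = False  # opt out of NumPy tracing

@torch.compile(backend="eager", fullgraph=True)
def f(arr):
    # With trace_numpy disabled, this call reaches the graph break above.
    return torch.from_numpy(arr) * 2

# f(np.ones(3))  # expected to raise torch._dynamo.exc.Unsupported under fullgraph=True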
@@ -880,9 +914,25 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             from .lists import BaseListVariable

             if layout and layout.as_python_constant() == torch.strided:
-                unimplemented("torch.compile does not support strided NestedTensor")
+                unimplemented_v2(
+                    gb_type="Attempted to use strided NestedTensor",
+                    context=f"layout={layout}",
+                    explanation="Dynamo does not support this.",
+                    hints=[
+                        "Change layout=torch.jagged.",
+                        *graph_break_hints.SUPPORTABLE,
+                    ],
+                )
             if not isinstance(tensor_list, BaseListVariable):
-                unimplemented("nested_tensor with non-list input")
+                unimplemented_v2(
+                    gb_type="Attempted to use `nested_tensor` with non-list input",
+                    context=f"tensor_list={tensor_list}",
+                    explanation="Dynamo does not support this.",
+                    hints=[
+                        "Change `nested_tensor` with list input.",
+                        *graph_break_hints.USER_ERROR,
+                    ],
+                )

         @register(torch.nn.functional.one_hot)
         def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs):
@@ -891,8 +941,14 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                 and args[1].is_python_constant()
                 and args[1].as_python_constant() == -1
             ):
-                unimplemented(
-                    "torch.nn.functional.one_hot with data-dependent output shape"
+                unimplemented_v2(
+                    gb_type="Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape",
+                    context=f"args={args}, kwargs={kwargs}",
+                    explanation="Dynamo does not support this.",
+                    hints=[
+                        "Explicitly set the `num_classes` param of the function call "
+                        "`torch.nn.functional.one_hot` to something other than -1.",
+                    ],
                 )

         @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious)
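A sketch of the user-facing fix the new hint describes: with `num_classes=-1` the output shape depends on `labels.max()`, so passing a static class count keeps the op traceable.

import torch
import torch.nn.functional as F

@torch.compile(backend="eager", fullgraph=True)
def f(labels):
    # num_classes=-1 would make the output shape data-dependent -> graph break.
    # Fix per the hint: give a static class count instead.
    return F.one_hot(labels, num_classes=10)

print(f(torch.tensor([1, 3, 5])).shape)  # torch.Size([3, 10])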
@@ -1061,7 +1117,15 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         ):
             assert not args and not kwargs
             if not tx.symbolic_torch_function_state.mode_stack:
-                raise unimplemented("Popping from an empty torch function mode stack")
+                unimplemented_v2(
+                    gb_type="Attempted to pop from empty torch function mode stack",
+                    context="",
+                    explanation="Called `torch._C._pop_torch_function_stack` when torch function mode stack is empty.",
+                    hints=[
+                        "Do not pop from empty torch function mode stack.",
+                        *graph_break_hints.USER_ERROR,
+                    ],
+                )
             TorchFunctionModeStackVariable.register_mutation(tx)
             return tx.symbolic_torch_function_state.pop_torch_function_mode()

@@ -1152,13 +1216,20 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                 arg_type = flat_arg_vt.python_type()
                 if not is_graphable_type(arg_type):
                     type_name = flat_arg_vt.python_type().__qualname__
-                    unimplemented(
-                        f"""
-For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <{type_name}>, please use one of the following to register the type with pytree:
-* `torch.utils._pytree.register_constant`
-* `torch.utils._pytree.register_dataclass`
-* `torch.utils._pytree.register_pytree_node`
-"""  # NOQA: B950
+                    unimplemented_v2(
+                        gb_type="Invalid input type for nonstrict_trace-ed function",
+                        context=f"Encountered input of type <{type_name}>.",
+                        explanation=(
+                            "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) "
+                            "or pytree containers of those are allowed as inputs. The provided argument contains "
+                            "an unsupported type."
+                        ),
+                        hints=[
+                            "Use one of the following to register the type with pytree:\n"
+                            "* `torch.utils._pytree.register_constant`\n"
+                            "* `torch.utils._pytree.register_dataclass`\n"
+                            "* `torch.utils._pytree.register_pytree_node`",
+                        ],
                     )

                 # Since we checked with `is_graphable` above, `as_proxy` on the
@@ -1179,25 +1250,37 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types
                     import torch.utils._pytree as pytree

                     if pytree.is_constant_class(typ):
-                        unimplemented(
-                            f"""
-You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region.
-
-Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.
-"""  # NOQA: B950
+                        unimplemented_v2(
+                            gb_type="Input marked with `pytree.register_constant` constructed in the `torch.compile` region",
+                            context=f"Input={input_spec_vt}, offending type <{type_name}>.",
+                            explanation=(
+                                "Calling a `nonstrict_trace`-ed function with an input that contains an object "
+                                f"of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object "
+                                "was constructed _inside_ the `torch.compile` region. This is not supported."
+                            ),
+                            hints=[
+                                "Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.",
+                                *graph_break_hints.SUPPORTABLE,
+                            ],
+                            from_exc=e,
                         )
                     else:
-                        unimplemented(
-                            f"""
-You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <{type_name}> into the context.
-
-Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <{type_name}>
-* `torch.utils._pytree.register_constant`
-* `torch.utils._pytree.register_dataclass`
-* `torch.utils._pytree.register_pytree_node`
-
-If the above doesn't work, please subtmit an issue to GitHub.
-"""  # NOQA: B950
+                        unimplemented_v2(
+                            gb_type="Invalid use of pytree_flatten with nonstrict_trace-ed function",
+                            context=f"Input={input_spec_vt}, offending type <{type_name}>.",
+                            explanation=(
+                                "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered "
+                                f"with a `pytree_flatten` that places an object of type <{type_name}> into the context."
+                            ),
+                            hints=[
+                                "Modifying the `pytree_flatten` to avoid placing the object into the context.",
+                                f"Apply one of the following to <{type_name}>:\n"
+                                "* `torch.utils._pytree.register_constant`\n"
+                                "* `torch.utils._pytree.register_dataclass`\n"
+                                "* `torch.utils._pytree.register_pytree_node`",
+                                *graph_break_hints.SUPPORTABLE,
+                            ],
+                            from_exc=e,
                        )

             fn = self.value
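The hints in the two graph breaks above point at the pytree registration APIs. A sketch of the dataclass route, assuming a recent PyTorch where `torch.utils._pytree.register_dataclass` is available; the `Point` type is illustrative:

from dataclasses import dataclass

import torch
import torch.utils._pytree as pytree

@dataclass
class Point:  # illustrative user type
    x: torch.Tensor
    y: torch.Tensor

# Registering the dataclass makes it a pytree container, so
# nonstrict_trace-ed functions can accept it as an input.
pytree.register_dataclass(Point)

flat, spec = pytree.tree_flatten(Point(torch.ones(2), torch.zeros(2)))
print(len(flat))  # 2 tensor leaves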
@@ -1308,7 +1391,17 @@ To support this behavior, we need to allow const-propping tensors that store sym
 For now, dynamo will explicitly graph break when it encounters user code with this behavior.
 """
                 log.warning(msg)
-                unimplemented(msg)
+                unimplemented_v2(
+                    gb_type="Attempted to call torch in-graph function on only torch.SymInt arguments",
+                    context=f"fn={self.value}, args={args}, kwargs={kwargs}",
+                    explanation=(
+                        f"Attempted to call {str(self.value)} (that should be put in the FX graph) on only torch.SymInt arguments. "
+                        "Dynamo does not support this."
+                    ),
+                    hints=[
+                        *graph_break_hints.SUPPORTABLE,
+                    ],
+                )

         # TODO(voz): Replace w/ dynamic shape rewrite table.
         # Ideally, we would be able to do this at ctor time, but alas we need a combination
@@ -1366,9 +1459,15 @@ For now, dynamo will explicitly graph break when it encounters user code with th
                 and "requires_grad" in kwargs
                 and kwargs["requires_grad"].as_python_constant()
             ):
-                unimplemented(
-                    """factory functions that return tensors that require grad are not supported.
-Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
+                unimplemented_v2(
+                    gb_type="Attempted to use tensor creation function with requires_grad=True",
+                    context=f"fn={self.value}, args={args}, kwargs={kwargs}",
+                    explanation="Dynamo does not support this.",
+                    hints=[
+                        "Create the tensor outside the compiled region.",
+                        "Do not set `requires_grad=True`.",
+                        *graph_break_hints.SUPPORTABLE,
+                    ],
                 )

             # Handle e.g., `torch.add(a, b, out=result)`
@@ -1400,12 +1499,27 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                     if saved_out_shape != fake_out.shape:
                         # It's hard to get out variants with resizing on graph inputs work
                         # properly across dynamo/aot/inductor, just fall back.
-                        unimplemented("out variants with resizing on graph inputs")
+                        unimplemented_v2(
+                            gb_type="Shape mismatch with out= list of tensor variants",
+                            context=f"fn={self.value}, args={args}, kwargs={kwargs}",
+                            explanation=(
+                                f"Shape mismatch when calling {self.value} with `out=`. "
+                                f"Provided `out=` shape: {saved_out_shape}. Actual shape: {fake_out.shape}."
+                            ),
+                            hints=[
+                                *graph_break_hints.SUPPORTABLE,
+                            ],
+                        )
                     if not torch._prims_common.is_contiguous(fake_out):
                         # It's difficult to handle strides correctly in functionalization
                         # when calling an out= op with a non-contiguous out argument
-                        unimplemented(
-                            "out= op was called where output tensor was non-contiguous"
+                        unimplemented_v2(
+                            gb_type="Attempted to call op with non-contiguous `out=` list of tensors",
+                            context=f"self.value={self.value}, args={args}, kwargs={kwargs}",
+                            explanation="Dynamo does not support this.",
+                            hints=[
+                                *graph_break_hints.SUPPORTABLE,
+                            ],
                         )
             else:
                 assert isinstance(out_kwarg_vt, TensorVariable)
@@ -1414,12 +1528,27 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                     if saved_out_shapes != fake_out.shape:
                         # It's hard to get out variants with resizing on graph inputs work
                         # properly across dynamo/aot/inductor, just fall back.
-                        unimplemented("out variants with resizing on graph inputs")
+                        unimplemented_v2(
+                            gb_type="Shape mismatch with out= tensor variant",
+                            context=f"fn={self.value}, args={args}, kwargs={kwargs}",
+                            explanation=(
+                                f"Shape mismatch when calling {self.value} with `out=`. "
+                                f"Provided `out=` shape: {saved_out_shapes}. Actual shape: {fake_out.shape}."
+                            ),
+                            hints=[
+                                *graph_break_hints.SUPPORTABLE,
+                            ],
+                        )
                     if not torch._prims_common.is_contiguous(fake_out):
                         # It's difficult to handle strides correctly in functionalization
                         # when calling an out= op with a non-contiguous out argument
-                        unimplemented(
-                            "out= op was called where output tensor was non-contiguous"
+                        unimplemented_v2(
+                            gb_type="Attempted to call op with non-contiguous `out=` tensor",
+                            context=f"self.value={self.value}, args={args}, kwargs={kwargs}",
+                            explanation="Dynamo does not support this.",
+                            hints=[
+                                *graph_break_hints.SUPPORTABLE,
+                            ],
                         )

             return tensor_variable
@@ -1444,7 +1573,14 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                     torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
                 )
             else:
-                unimplemented(f"torch.nn.modules.utils._ntuple({value})")
+                unimplemented_v2(
+                    gb_type="Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type",
+                    context=f"value={value}",
+                    explanation="Dynamo does not support this.",
+                    hints=[
+                        "Change use of _ntuple with argument as constant or tensor.",
+                    ],
+                )

         if self.value is torch.nn.modules.utils._ntuple:
             return variables.LambdaVariable(handle_ntuple)
@@ -1455,16 +1591,40 @@ Either create the tensor outside the compiled region, or do not set the tensor t
     def call_nn_parameter(cls, tx, data=None, requires_grad=True):
         """A call to torch.nn.Parameter() gets lifted to before the graph"""
         if tx.export:
-            unimplemented("nn parameter construction not supported with export")
+            unimplemented_v2(
+                gb_type="Attempted to use `torch.nn.Parameter()` with export",
+                context="",
+                explanation="Dynamo does not support this.",
+                hints=[
+                    "Do not use `torch.nn.Parameter()` with export.",
+                    *graph_break_hints.SUPPORTABLE,
+                ],
+            )

         if isinstance(requires_grad, variables.VariableTracker):
             try:
                 requires_grad = requires_grad.as_python_constant()
             except NotImplementedError:
-                unimplemented("Parameter(requires_grad=...) not constant")
+                unimplemented_v2(
+                    gb_type="non-constant `requires_grad` argument to `torch.nn.Parameter`",
+                    context=f"requires_grad={requires_grad}",
+                    explanation="Dynamo does not support this.",
+                    hints=[
+                        "Change `requires_grad` to be a bool.",
+                        *graph_break_hints.USER_ERROR,
+                    ],
+                )

         if not isinstance(data, variables.TensorVariable):
-            unimplemented(f"Parameter(data={data}) not implemented")
+            unimplemented_v2(
+                gb_type="`torch.nn.Parameter()` with unsupported data type",
+                context=f"data={data}",
+                explanation="Called `torch.nn.Parameter()` with non-Tensor argument.",
+                hints=[
+                    "Ensure the argument to `torch.nn.Parameter()` is a `torch.Tensor`.",
+                    *graph_break_hints.USER_ERROR,
+                ],
+            )

         # this results in cleaner graphs, but only works for inputs
         if data.source:
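A sketch of the supported pattern the hints above converge on: constructing the `nn.Parameter` outside the compiled region avoids all of these graph breaks.

import torch

# Construct the Parameter outside the compiled region (supported).
w = torch.nn.Parameter(torch.randn(4, 4))

@torch.compile(backend="eager")
def f(x):
    # Only uses the already-constructed parameter; no Parameter() call
    # happens under tracing.
    return x @ w

print(f(torch.randn(2, 4)).shape)  # torch.Size([2, 4])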
@@ -1473,17 +1633,41 @@ Either create the tensor outside the compiled region, or do not set the tensor t
         if isinstance(
             data, TensorWithTFOverrideVariable
         ) or is_traceable_wrapper_subclass_type(data.class_type):
-            unimplemented("Parameter constructor with tensor subclass NYI")
+            unimplemented_v2(
+                gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass",
+                context=str(data),
+                explanation="Dynamo does not support this.",
+                hints=[
+                    *graph_break_hints.SUPPORTABLE,
+                ],
+            )

         if not can_convert_to_tracable_parameter():
-            unimplemented("Workaround for issues with nn_parameter construction")
+            unimplemented_v2(
+                gb_type="`torch.nn.Parameter`: cannot convert to traceable tracable",
+                context="",
+                explanation="convert_tracable_parameter is set to False.",
+                hints=[
+                    "Check usage of context manager: do_not_convert_to_tracable_parameter",
+                    *graph_break_hints.DIFFICULT,
+                ],
+            )

         try:
             shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
             dtype = data.var_getattr(tx, "dtype").as_python_constant()
             device = data.var_getattr(tx, "device").as_python_constant()
         except NotImplementedError as e:
-            unimplemented(f"Parameter not python_constant: {e}")
+            unimplemented_v2(
+                gb_type="`torch.nn.Parameter` with non-constant Tensor attributes",
+                context=f"data={data}",
+                explanation="Dynamo does not support this.",
+                hints=[
+                    "Ensure the Tensor argument's shape, dtype, and device are correct.",
+                    *graph_break_hints.USER_ERROR,
+                ],
+                from_exc=e,
+            )

         placeholder = tx.output.synthetic_graph_input(
             new_parameter_placeholder, [shape, dtype, device, requires_grad]
@@ -1535,8 +1719,13 @@ Either create the tensor outside the compiled region, or do not set the tensor t

         data_node = data.as_proxy().node
         if data_node.op not in ("placeholder", "get_attr"):
-            unimplemented(
-                "Unexpected type of data placeholder op for parameter construction"
+            unimplemented_v2(
+                gb_type="Unexpected type of data placeholder op for parameter construction",
+                context=f"data_node.op={data_node.op}",
+                explanation="Data node op should be placeholder or get_attr.",
+                hints=[
+                    *graph_break_hints.DIFFICULT,
+                ],
             )

         # add the newly constructed nn.Parameter as a graph input