Change capture_scalar_outputs to use SymInt/SymFloat rather than Tensor to model scalars (#93150)
Previously, Dynamo faked support for item() when `capture_scalar_outputs` was True by representing it internally as a Tensor. With dynamic shapes, this is no longer necessary; we can represent it directly as a SymInt/SymFloat. Do so. Doing this requires you to use dynamic shapes; in principle we could support scalar outputs WITHOUT dynamic shapes, but I won't do this unless someone hollers for it.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Differential Revision: [D42885775](https://our.internmc.facebook.com/intern/diff/D42885775)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93150
Approved by: https://github.com/voznesenskym
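As a user-facing illustration (a minimal sketch, not part of this patch; fn, the inputs, and the CompileCounter backend are illustrative), enabling both flags lets Dynamo trace a .item() call into the graph as a SymInt instead of breaking the graph:

import torch
import torch._dynamo as dynamo
from torch._dynamo.testing import CompileCounter
from unittest.mock import patch

def fn(a, b):
    # a.item() is captured as a SymInt under capture_scalar_outputs + dynamic_shapes
    return b * a.item()

cnt = CompileCounter()
with patch.object(torch._dynamo.config, "dynamic_shapes", True), patch.object(
    torch._dynamo.config, "capture_scalar_outputs", True
):
    opt_fn = dynamo.optimize(cnt)(fn)
    out = opt_fn(torch.tensor(3), torch.ones(4))

print(out)              # tensor([3., 3., 3., 3.])
print(cnt.frame_count)  # expected to stay at 1, i.e. no graph break at .item()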
This commit is contained in:
parent
76b683b008
commit
902b4dba75
@@ -320,6 +320,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_dupes_and_bypass_with_non_tensor_output(self):
        inp = torch.tensor([0.1, 0.1])
@@ -366,6 +367,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_zeroes_in_new_shape_scalar_out(self):
        inp = torch.zeros(10)
@@ -390,6 +392,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_zeroes_in_new_shape_scalar_out_permute(self):
        inp = torch.zeros(10)
@@ -414,6 +417,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
        inp = torch.zeros(10)
@@ -771,6 +775,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
@@ -1421,6 +1426,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
            f, (torch.randn(5)), aten_graph=False, tracing_mode="symbolic"
        )

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_export_with_module_layer(self):
        from functorch.experimental.control_flow import cond
@@ -1634,6 +1640,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
        )

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_dynamic_slicing_simple(self):
        def f(x):
            return x[slice(None, None, None)]
@@ -1645,6 +1652,8 @@ class ExportTests(torch._dynamo.test_case.TestCase):
        inp = torch.randn(6, 7)
        self.assertEqual(gm(inp), f(inp))

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_export_cond_in_aten_symbolic(self):
        class ConditionOp(torch.nn.Module):
            def __init__(self):
@@ -448,6 +448,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
            self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8
        )

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_tensor_item_capture(self):
        def fn(a, b):
@@ -462,6 +463,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 3)

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
    def test_tensor_item_no_capture(self):
        def fn(a, b):
@@ -2035,6 +2037,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
        opt_f(x, n)
        self.assertEqual(cnts.frame_count, 1)

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_item(self):
        class MyMod(torch.nn.Module):
@@ -2048,6 +2051,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(y, 11)

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_item_changes(self):
        class MyMod(torch.nn.Module):
@@ -2064,6 +2068,7 @@ class MiscTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(y, 11)
        self.assertEqual(z, 61)

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_item_changes_new_shape(self):
        class MyMod(torch.nn.Module):
@@ -29,6 +29,7 @@ except ImportError:
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided, requires_static_shapes, same
from torch._dynamo.utils import ifdyn
from torch.nn import functional as F
@@ -42,13 +43,6 @@ def is_fx_tracing_test() -> bool:
    return torch.nn.Module.__call__ is not _orig_module_call


def ifdyn(count1, count2):
    if torch._dynamo.config.dynamic_shapes:
        return count1
    else:
        return count2


def has_detectron2():
    try:
        from detectron2.layers.mask_ops import _paste_masks_tensor_shape
@@ -948,6 +942,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
    # uncomment/adjust the assertEqual below
    @unittest.expectedFailure
    @patch.object(torch._dynamo.config, "fake_tensor_propagation", True)
    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_maml_item_capture(self):
        a = torch.randn(5, 1, 28, 28)
@@ -966,6 +961,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
        self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))

    # see: https://github.com/pytorch/pytorch/issues/80067
    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
    def test_maml_no_item_capture(self):
        a = torch.randn(5, 1, 28, 28)
@@ -979,7 +975,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
        for _ in range(10):
            self.assertTrue(same(opt_model(a, b, c, d), correct))

        self.assertEqual(cnt.frame_count, ifdyn(5, 4))
        self.assertEqual(cnt.frame_count, 5)
        # TODO(jansel): figure out why op count depends on imports
        self.assertIn(cnt.op_count, (31, 36, 35, 34, 29, 28))
@@ -439,6 +439,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(opt_fn(x), fn(x))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_no_graph_break_on_item(self):
        def fn(a, b):
@@ -450,6 +451,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
        self._common(fn, 1, 6)

    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
    def test_graph_break_on_item(self):
        def fn(a, b):
@@ -141,6 +141,7 @@ repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
capture_scalar_outputs = False

# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
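A rough sketch of the graph-break behavior this comment describes (illustrative code, not from the patch; fn and the counter backend are made up): with capture_scalar_outputs left at its default of False, Dynamo splits the function at the .item() call.

import torch
import torch._dynamo as dynamo
from torch._dynamo.testing import CompileCounter

def fn(x):
    s = x.sum().item()  # graph break expected here while capture_scalar_outputs is False
    return x * s

cnt = CompileCounter()
opt_fn = dynamo.optimize(cnt)(fn)
opt_fn(torch.ones(4))
print(cnt.frame_count)  # typically 2: one graph before the break, one after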
@@ -1278,3 +1278,10 @@ def fqn(obj: Any):
    Returns the fully qualified name of the object.
    """
    return f"{obj.__module__}.{obj.__qualname__}"


def ifdyn(count1, count2):
    if torch._dynamo.config.dynamic_shapes:
        return count1
    else:
        return count2
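With ifdyn now living in torch._dynamo.utils, any test file can share it; a typical (hypothetical) use picks an expected counter value depending on whether dynamic shapes are enabled:

from torch._dynamo.utils import ifdyn

expected_frames = ifdyn(5, 4)  # 5 if torch._dynamo.config.dynamic_shapes else 4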
@@ -3,8 +3,6 @@ import dataclasses
import enum
import functools
import inspect
import math
import numbers
import operator
import re
import types
@@ -90,7 +88,6 @@ from .misc import (
from .nn_module import UnspecializedNNModuleVariable
from .tensor import (
    DynamicShapeVariable,
    FakeItemVariable,
    TensorVariable,
    TensorWithTFOverrideVariable,
    UnspecializedPythonVariable,
@@ -930,19 +927,6 @@ def wrap_fx_proxy_cls(
    ):
        proxy.node.meta["example_value"] = example_value
        return ConstantVariable(example_value, **options)
    elif (
        isinstance(example_value, numbers.Number)
        and (proxy.node.target == "item" or proxy.node.target in {math.sqrt, math.pow})
        and config.capture_scalar_outputs
    ):
        # item raw value should not be accessed
        return wrap_fx_proxy_cls(
            FakeItemVariable,
            tx=tx,
            proxy=proxy,
            example_value=torch.tensor(example_value),
            **options,
        )
    elif isinstance(example_value, (torch.SymInt, torch.SymFloat)):
        proxy.node.meta["example_value"] = example_value
        return DynamicShapeVariable(proxy, example_value, **options)
@@ -319,22 +319,16 @@ class TensorVariable(VariableTracker):
            unimplemented(f"Tensor.{name}")
        elif name == "nonzero" and not config.dynamic_shapes:
            unimplemented(f"Tensor.{name}")
        elif name == "item":
            if config.capture_scalar_outputs:
                example_value = get_fake_value(self.proxy.node, tx)
                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_method",
                        "item",
                        (self.as_proxy(),),
                        {},
                    ),
                    example_value=example_value,
                    **options,
                )
            else:
        elif name == "item" and not config.capture_scalar_outputs:
            unimplemented(f"Tensor.{name}")
        elif (
            name == "item"
            and config.capture_scalar_outputs
            and not config.dynamic_shapes
        ):
            raise AssertionError(
                "To capture_scalar_outputs, you must also set dynamic_shapes = True"
            )
        elif name == "__len__":
            return self.call_method(tx, "size", [ConstantVariable(0, **options)], {})
        elif name == "__setitem__":
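To see the new code path end to end (a hedged sketch; f and its inputs are invented, and this assumes the torch._dynamo.export API of this era), exporting a function that calls .item() records the call as a graph node backed by a SymFloat rather than a stand-in Tensor:

import torch
import torch._dynamo as dynamo
from unittest.mock import patch

def f(x):
    n = x.sum().item()  # wrapped via wrap_fx_proxy with a symbolic example value
    return torch.ones(3) * n

with patch.object(torch._dynamo.config, "dynamic_shapes", True), patch.object(
    torch._dynamo.config, "capture_scalar_outputs", True
):
    gm, guards = dynamo.export(f, torch.randn(4))

print(gm.graph)  # the item() call appears as a node in the traced graph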