[Dynamo] Handle extracted unbound tensor methods (#137227)

fixes2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137227
Approved by: https://github.com/williamwen42, https://github.com/anijain2305
ghstack dependencies: #137114, #137115, #137116, #137117, #137120
This commit is contained in:
Michael Lazos 2024-10-08 14:11:04 -07:00 committed by PyTorch MergeBot
parent b3f30c9bc3
commit 0a304d9048
15 changed files with 24 additions and 5 deletions

View File

@ -938,6 +938,16 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
else:
return x - 1
@make_test
def test_tensor_size(x):
fn = torch.Tensor.size
return fn(x + 1)
@make_test
def test_tensor_dim(x):
fn = torch.Tensor.dim
return fn(x + 1)
@make_test
def test_tensor_is_inference(x):
if x.is_inference():

View File

@ -772,7 +772,7 @@ class TensorVariable(VariableTracker):
self._warn_capture_scalar_outputs()
unimplemented("Tensor.item")
def method_getitem(self, *args, **kwargs):
def method___getitem__(self, *args, **kwargs):
from ..symbolic_convert import InstructionTranslator
from .builder import wrap_fx_proxy

View File

@ -871,10 +871,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
return ConstantVariable.create(None)
@register(torch._C.TensorBase.__getitem__)
def handle_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
return args[0].call_method(tx, "getitem", args[1:], kwargs)
return handlers
def call_function(
@ -904,6 +900,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
),
)
if self.is_tensor_method():
return self.call_tensor_method(tx, args, kwargs)
special_handler = self._get_handlers().get(self.value)
if special_handler:
result = special_handler(self, tx, *args, **kwargs)
@ -1176,6 +1175,16 @@ Either create the tensor outside the compiled region, or do not set the tensor t
)
return result
def call_tensor_method(self, tx, args, kwargs):
return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs)
def is_tensor_method(self):
return (
inspect.ismethoddescriptor(self.get_function())
and hasattr(self.get_function(), "__objclass__")
and self.get_function().__objclass__ == torch._C.TensorBase
)
def torch_function_override_enabled(self, tx, args, kwargs):
return (
self.get_function() in get_overridable_functions()