mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] Handle extracted unbound tensor methods (#137227)
fixes2 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137227 Approved by: https://github.com/williamwen42, https://github.com/anijain2305 ghstack dependencies: #137114, #137115, #137116, #137117, #137120
This commit is contained in:
parent
b3f30c9bc3
commit
0a304d9048
|
|
@ -938,6 +938,16 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
else:
|
||||
return x - 1
|
||||
|
||||
@make_test
|
||||
def test_tensor_size(x):
|
||||
fn = torch.Tensor.size
|
||||
return fn(x + 1)
|
||||
|
||||
@make_test
|
||||
def test_tensor_dim(x):
|
||||
fn = torch.Tensor.dim
|
||||
return fn(x + 1)
|
||||
|
||||
@make_test
|
||||
def test_tensor_is_inference(x):
|
||||
if x.is_inference():
|
||||
|
|
|
|||
|
|
@ -772,7 +772,7 @@ class TensorVariable(VariableTracker):
|
|||
self._warn_capture_scalar_outputs()
|
||||
unimplemented("Tensor.item")
|
||||
|
||||
def method_getitem(self, *args, **kwargs):
|
||||
def method___getitem__(self, *args, **kwargs):
|
||||
from ..symbolic_convert import InstructionTranslator
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
|
|
|
|||
|
|
@ -871,10 +871,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
@register(torch._C.TensorBase.__getitem__)
|
||||
def handle_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
return args[0].call_method(tx, "getitem", args[1:], kwargs)
|
||||
|
||||
return handlers
|
||||
|
||||
def call_function(
|
||||
|
|
@ -904,6 +900,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
),
|
||||
)
|
||||
|
||||
if self.is_tensor_method():
|
||||
return self.call_tensor_method(tx, args, kwargs)
|
||||
|
||||
special_handler = self._get_handlers().get(self.value)
|
||||
if special_handler:
|
||||
result = special_handler(self, tx, *args, **kwargs)
|
||||
|
|
@ -1176,6 +1175,16 @@ Either create the tensor outside the compiled region, or do not set the tensor t
|
|||
)
|
||||
return result
|
||||
|
||||
def call_tensor_method(self, tx, args, kwargs):
|
||||
return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs)
|
||||
|
||||
def is_tensor_method(self):
|
||||
return (
|
||||
inspect.ismethoddescriptor(self.get_function())
|
||||
and hasattr(self.get_function(), "__objclass__")
|
||||
and self.get_function().__objclass__ == torch._C.TensorBase
|
||||
)
|
||||
|
||||
def torch_function_override_enabled(self, tx, args, kwargs):
|
||||
return (
|
||||
self.get_function() in get_overridable_functions()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user