mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] Handle extracted unbound tensor methods (#137227)
fixes2 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137227 Approved by: https://github.com/williamwen42, https://github.com/anijain2305 ghstack dependencies: #137114, #137115, #137116, #137117, #137120
This commit is contained in:
parent
b3f30c9bc3
commit
0a304d9048
|
|
@ -938,6 +938,16 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||||
else:
|
else:
|
||||||
return x - 1
|
return x - 1
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_tensor_size(x):
|
||||||
|
fn = torch.Tensor.size
|
||||||
|
return fn(x + 1)
|
||||||
|
|
||||||
|
@make_test
|
||||||
|
def test_tensor_dim(x):
|
||||||
|
fn = torch.Tensor.dim
|
||||||
|
return fn(x + 1)
|
||||||
|
|
||||||
@make_test
|
@make_test
|
||||||
def test_tensor_is_inference(x):
|
def test_tensor_is_inference(x):
|
||||||
if x.is_inference():
|
if x.is_inference():
|
||||||
|
|
|
||||||
|
|
@ -772,7 +772,7 @@ class TensorVariable(VariableTracker):
|
||||||
self._warn_capture_scalar_outputs()
|
self._warn_capture_scalar_outputs()
|
||||||
unimplemented("Tensor.item")
|
unimplemented("Tensor.item")
|
||||||
|
|
||||||
def method_getitem(self, *args, **kwargs):
|
def method___getitem__(self, *args, **kwargs):
|
||||||
from ..symbolic_convert import InstructionTranslator
|
from ..symbolic_convert import InstructionTranslator
|
||||||
from .builder import wrap_fx_proxy
|
from .builder import wrap_fx_proxy
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -871,10 +871,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||||
|
|
||||||
return ConstantVariable.create(None)
|
return ConstantVariable.create(None)
|
||||||
|
|
||||||
@register(torch._C.TensorBase.__getitem__)
|
|
||||||
def handle_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
|
|
||||||
return args[0].call_method(tx, "getitem", args[1:], kwargs)
|
|
||||||
|
|
||||||
return handlers
|
return handlers
|
||||||
|
|
||||||
def call_function(
|
def call_function(
|
||||||
|
|
@ -904,6 +900,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.is_tensor_method():
|
||||||
|
return self.call_tensor_method(tx, args, kwargs)
|
||||||
|
|
||||||
special_handler = self._get_handlers().get(self.value)
|
special_handler = self._get_handlers().get(self.value)
|
||||||
if special_handler:
|
if special_handler:
|
||||||
result = special_handler(self, tx, *args, **kwargs)
|
result = special_handler(self, tx, *args, **kwargs)
|
||||||
|
|
@ -1176,6 +1175,16 @@ Either create the tensor outside the compiled region, or do not set the tensor t
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def call_tensor_method(self, tx, args, kwargs):
|
||||||
|
return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs)
|
||||||
|
|
||||||
|
def is_tensor_method(self):
|
||||||
|
return (
|
||||||
|
inspect.ismethoddescriptor(self.get_function())
|
||||||
|
and hasattr(self.get_function(), "__objclass__")
|
||||||
|
and self.get_function().__objclass__ == torch._C.TensorBase
|
||||||
|
)
|
||||||
|
|
||||||
def torch_function_override_enabled(self, tx, args, kwargs):
|
def torch_function_override_enabled(self, tx, args, kwargs):
|
||||||
return (
|
return (
|
||||||
self.get_function() in get_overridable_functions()
|
self.get_function() in get_overridable_functions()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user