Extend torch function support to ALL arguments, not just scalar type (but not insides of list) (#145089)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145089
Approved by: https://github.com/albanD, https://github.com/zou3519
This commit is contained in:
Edward Z. Yang 2025-08-07 13:20:47 -07:00 committed by PyTorch MergeBot
parent 195b5c2e27
commit 03b254e49f
3 changed files with 28 additions and 19 deletions

View File

@ -4660,7 +4660,6 @@ class TestFunctionalTracing(JitTestCase):
"linear": BUILT_IN_FUNC,
"logsigmoid": BUILT_IN_FUNC,
"one_hot": BUILT_IN_FUNC,
"pad": ARG_TYPE_MISMATCH,
"pairwise_distance": BUILT_IN_FUNC,
"pdist": BUILT_IN_FUNC,
"pixel_shuffle": BUILT_IN_FUNC,
@ -4693,12 +4692,6 @@ class TestFunctionalTracing(JitTestCase):
"max_unpool3d": PROXY_ITERATED,
"fold": PROXY_ITERATED,
"unfold": PROXY_ITERATED,
"adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
"fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
"fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
"layer_norm": ARG_TYPE_MISMATCH,
"rms_norm": ARG_TYPE_MISMATCH,
"lp_pool1d": ARG_TYPE_MISMATCH,
"affine_grid": CONTROL_FLOW,
"alpha_dropout": CONTROL_FLOW,
"batch_norm": CONTROL_FLOW,
@ -4732,9 +4725,6 @@ class TestFunctionalTracing(JitTestCase):
"leaky_relu": CONTROL_FLOW,
"local_response_norm": CONTROL_FLOW,
"margin_ranking_loss": CONTROL_FLOW,
"max_pool1d_with_indices": ARG_TYPE_MISMATCH,
"max_pool2d_with_indices": ARG_TYPE_MISMATCH,
"max_pool3d_with_indices": ARG_TYPE_MISMATCH,
"mse_loss": CONTROL_FLOW,
"multi_head_attention_forward": CONTROL_FLOW,
"multi_margin_loss": CONTROL_FLOW,

View File

@ -938,6 +938,27 @@ auto FunctionParameter::check(
std::vector<PyObject*>& overloaded_args,
int argnum,
int64_t* failed_idx) -> bool {
if (_check(obj, overloaded_args, argnum, failed_idx)) {
return true;
}
// NB: This will not detect torch function inside elements of a list. So
// you still have to handle that manually
// NB: torch function on Tensor subclasses NOT eligible here, you handled
// that internally
if (check_has_torch_function(obj, /*ignore_mode*/ true) &&
!THPVariable_Check(obj)) {
// unrelated objects with __torch_function__
append_overloaded_arg(&overloaded_args, obj, /*obj_is_type*/ false);
return true;
}
return false;
}
auto FunctionParameter::_check(
PyObject* obj,
std::vector<PyObject*>& overloaded_args,
int argnum,
int64_t* failed_idx) -> bool {
switch (type_) {
case ParameterType::TENSOR: {
if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
@ -1013,15 +1034,7 @@ auto FunctionParameter::check(
case ParameterType::PYOBJECT:
return true;
case ParameterType::SCALARTYPE:
if (THPDtype_Check(obj) || THPPythonScalarType_Check(obj)) {
return true;
}
if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
// tensor subclasses and unrelated objects with __torch_function__
append_overloaded_arg(&overloaded_args, obj, /*obj_is_type*/ false);
return true;
}
return false;
return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
case ParameterType::LAYOUT:
return THPLayout_Check(obj);
case ParameterType::MEMORY_FORMAT:

View File

@ -322,6 +322,12 @@ struct FunctionParameter {
int argnum,
int64_t* failed_idx = nullptr);
bool _check(
PyObject* obj,
std::vector<PyObject*>& overloaded_args,
int argnum,
int64_t* failed_idx = nullptr);
void set_default_str(const std::string& str);
TORCH_PYTHON_API std::string type_name() const;