Extend torch function support to ALL arguments, not just scalar type (but not insides of list) (#145089)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145089
Approved by: https://github.com/albanD, https://github.com/zou3519
parent 195b5c2e27
commit 03b254e49f
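What this enables, as a minimal sketch: an unrelated object (not a Tensor subclass) that implements __torch_function__ can now trigger dispatch from any argument position, not only from dtype (ScalarType) positions; elements inside a list are still not inspected. The MyNumber class and its unwrapping logic below are illustrative, not part of the commit.

    import torch

    class MyNumber:
        # Hypothetical wrapper: not a Tensor subclass, just an unrelated
        # object implementing the __torch_function__ protocol.
        def __init__(self, value):
            self.value = value

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}

            def unwrap(a):
                return a.value if isinstance(a, MyNumber) else a

            # Unwrap MyNumber instances and re-dispatch to the original op.
            return func(*[unwrap(a) for a in args],
                        **{k: unwrap(v) for k, v in kwargs.items()})

    # fill_value is a Scalar parameter, not a ScalarType; with this change
    # the arg parser notices __torch_function__ here and dispatches instead
    # of raising a TypeError.
    t = torch.full((2, 2), MyNumber(3.0))
    assert bool(t.eq(3.0).all())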
test/test_fx.py

@@ -4660,7 +4660,6 @@ class TestFunctionalTracing(JitTestCase):
         "linear": BUILT_IN_FUNC,
         "logsigmoid": BUILT_IN_FUNC,
         "one_hot": BUILT_IN_FUNC,
-        "pad": ARG_TYPE_MISMATCH,
         "pairwise_distance": BUILT_IN_FUNC,
         "pdist": BUILT_IN_FUNC,
         "pixel_shuffle": BUILT_IN_FUNC,
@@ -4693,12 +4692,6 @@ class TestFunctionalTracing(JitTestCase):
         "max_unpool3d": PROXY_ITERATED,
         "fold": PROXY_ITERATED,
         "unfold": PROXY_ITERATED,
-        "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
-        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
-        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
-        "layer_norm": ARG_TYPE_MISMATCH,
-        "rms_norm": ARG_TYPE_MISMATCH,
-        "lp_pool1d": ARG_TYPE_MISMATCH,
         "affine_grid": CONTROL_FLOW,
         "alpha_dropout": CONTROL_FLOW,
         "batch_norm": CONTROL_FLOW,
@@ -4732,9 +4725,6 @@ class TestFunctionalTracing(JitTestCase):
         "leaky_relu": CONTROL_FLOW,
         "local_response_norm": CONTROL_FLOW,
         "margin_ranking_loss": CONTROL_FLOW,
-        "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
-        "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
-        "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
         "mse_loss": CONTROL_FLOW,
         "multi_head_attention_forward": CONTROL_FLOW,
         "multi_margin_loss": CONTROL_FLOW,
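The entries removed above are functionals whose torch.fx tracing previously failed with an argument-type error; fx Proxy objects dispatch through __torch_function__, so the parser now accepts them in any argument position and these cases trace. A rough sketch of the effect, assuming the tracing behavior implied by the removed ARG_TYPE_MISMATCH entries:

    import torch
    import torch.fx
    import torch.nn.functional as F

    class M(torch.nn.Module):
        def forward(self, x):
            # lp_pool1d(input, norm_type, kernel_size) was listed under
            # ARG_TYPE_MISMATCH above; with the parser change it traces.
            return F.lp_pool1d(x, 2.0, 3)

    traced = torch.fx.symbolic_trace(M())
    print(traced.graph)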
torch/csrc/utils/python_arg_parser.cpp

@@ -938,6 +938,27 @@ auto FunctionParameter::check(
     std::vector<PyObject*>& overloaded_args,
     int argnum,
     int64_t* failed_idx) -> bool {
+  if (_check(obj, overloaded_args, argnum, failed_idx)) {
+    return true;
+  }
+  // NB: This will not detect torch function inside elements of a list. So
+  // you still have to handle that manually
+  // NB: torch function on Tensor subclasses NOT eligible here, you handled
+  // that internally
+  if (check_has_torch_function(obj, /*ignore_mode*/ true) &&
+      !THPVariable_Check(obj)) {
+    // unrelated objects with __torch_function__
+    append_overloaded_arg(&overloaded_args, obj, /*obj_is_type*/ false);
+    return true;
+  }
+  return false;
+}
+
+auto FunctionParameter::_check(
+    PyObject* obj,
+    std::vector<PyObject*>& overloaded_args,
+    int argnum,
+    int64_t* failed_idx) -> bool {
   switch (type_) {
     case ParameterType::TENSOR: {
       if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
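Both NB comments describe user-visible limits of the new generic path: only top-level arguments are inspected, and Tensor subclasses are skipped here because each ParameterType already handles them inside _check. Continuing the hypothetical MyNumber sketch from above:

    # A top-level argument dispatches via the new generic path.
    torch.full((2, 2), MyNumber(3.0))

    # An element inside a list is NOT detected (first NB comment); the
    # int-list parser would likely reject it with a TypeError instead.
    # torch.ones([MyNumber(2), 2])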
@@ -1013,15 +1034,7 @@ auto FunctionParameter::check(
     case ParameterType::PYOBJECT:
       return true;
     case ParameterType::SCALARTYPE:
-      if (THPDtype_Check(obj) || THPPythonScalarType_Check(obj)) {
-        return true;
-      }
-      if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
-        // tensor subclasses and unrelated objects with __torch_function__
-        append_overloaded_arg(&overloaded_args, obj, /*obj_is_type*/ false);
-        return true;
-      }
-      return false;
+      return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
     case ParameterType::LAYOUT:
       return THPLayout_Check(obj);
     case ParameterType::MEMORY_FORMAT:
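The deleted SCALARTYPE branch is now redundant: check() performs the same has-torch-function test for every parameter type before falling back to _check, so dtype-position dispatch keeps working through the generic path. A minimal sketch; FakeDtype and its float64 resolution are illustrative:

    import torch

    class FakeDtype:
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = dict(kwargs or {})
            kwargs["dtype"] = torch.float64  # hypothetical resolution
            return func(*args, **kwargs)

    # Still dispatches to __torch_function__, now via check() rather than
    # the removed SCALARTYPE special case.
    assert torch.zeros(2, dtype=FakeDtype()).dtype is torch.float64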
torch/csrc/utils/python_arg_parser.h

@@ -322,6 +322,12 @@ struct FunctionParameter {
       int argnum,
       int64_t* failed_idx = nullptr);
 
+  bool _check(
+      PyObject* obj,
+      std::vector<PyObject*>& overloaded_args,
+      int argnum,
+      int64_t* failed_idx = nullptr);
+
   void set_default_str(const std::string& str);
   TORCH_PYTHON_API std::string type_name() const;