mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert "[inductor] Handle the case where kwargs contains tensor (#88215)"
This reverts commit 983c0e7f31.
Reverted https://github.com/pytorch/pytorch/pull/88215 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but I think it breaks trunk https://github.com/pytorch/pytorch/actions/runs/3380662072/jobs/5613987333 with a failure in test_torchinductor_opinfo.py
This commit is contained in:
parent
7354368fd5
commit
a8561c4571
|
|
@ -4083,20 +4083,6 @@ class CommonTemplate:
|
|||
else:
|
||||
self.assertEqual(len(inps), 0)
|
||||
|
||||
@unittest.skipIf(HAS_CUDA, "histogramdd only supports cpu")
|
||||
def test_kwargs(self):
|
||||
def fn(x, y):
|
||||
return torch.histogramdd(
|
||||
x,
|
||||
bins=[3, 3],
|
||||
weight=y,
|
||||
)
|
||||
|
||||
self.common(
|
||||
fn,
|
||||
[torch.randn((4, 2)), torch.randn((4))],
|
||||
)
|
||||
|
||||
|
||||
if HAS_CPU:
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import textwrap
|
|||
from collections import OrderedDict
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -2237,8 +2236,7 @@ class ExternKernel(InputsKernel):
|
|||
|
||||
@classmethod
|
||||
def process_kernel(cls, kernel, *args, **kwargs):
|
||||
binded_args = signature(kernel).bind(*args, **kwargs).arguments
|
||||
args_flat, args_spec = pytree.tree_flatten(binded_args)
|
||||
args_flat, args_spec = pytree.tree_flatten(args)
|
||||
|
||||
is_arg_tensor = []
|
||||
tensor_args = []
|
||||
|
|
@ -2251,16 +2249,15 @@ class ExternKernel(InputsKernel):
|
|||
non_tensor_args.append(arg)
|
||||
|
||||
def unflatten_args(new_tensor_args, new_non_tensor_args):
|
||||
result = []
|
||||
new_args = []
|
||||
it_tensors = iter(new_tensor_args)
|
||||
it_non_tensors = iter(new_non_tensor_args)
|
||||
for is_tensor in is_arg_tensor:
|
||||
if is_tensor:
|
||||
result.append(next(it_tensors))
|
||||
new_args.append(next(it_tensors))
|
||||
else:
|
||||
result.append(next(it_non_tensors))
|
||||
result = pytree.tree_unflatten(result, args_spec)
|
||||
return result.get("args", []), result.get("kwargs", {})
|
||||
new_args.append(next(it_non_tensors))
|
||||
return pytree.tree_unflatten(new_args, args_spec)
|
||||
|
||||
tensor_args = [cls.realize_input(x) for x in tensor_args]
|
||||
|
||||
|
|
@ -2286,8 +2283,9 @@ class ExternKernel(InputsKernel):
|
|||
).zero_()
|
||||
example_args.append(arg)
|
||||
|
||||
new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
|
||||
example_output = kernel(*new_args, **new_kwargs)
|
||||
example_output = kernel(
|
||||
*unflatten_args(example_args, non_tensor_args), **kwargs
|
||||
)
|
||||
|
||||
return example_output, tensor_args, non_tensor_args, unflatten_args
|
||||
|
||||
|
|
@ -2880,13 +2878,15 @@ class FallbackKernel(ExternKernelAlloc):
|
|||
def __repr__(self):
|
||||
return self.ref
|
||||
|
||||
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
|
||||
constant_args = [Shim(repr(x)) for x in self.constant_args]
|
||||
|
||||
def gen_kwarg(k, v):
|
||||
return f"{k}={repr(v)}"
|
||||
|
||||
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
|
||||
constant_args = [Shim(repr(x)) for x in self.constant_args]
|
||||
args, kwargs = self.unflatten_args(tensor_args, constant_args)
|
||||
return list(map(repr, args)) + list(gen_kwarg(k, v) for k, v in kwargs.items())
|
||||
kwargs = list(gen_kwarg(k, v) for k, v in self.kwargs.items())
|
||||
|
||||
return list(map(repr, self.unflatten_args(tensor_args, constant_args))) + kwargs
|
||||
|
||||
@classmethod
|
||||
def create(cls, kernel, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -164,10 +164,8 @@ def _register_lowering(
|
|||
args = args[0]
|
||||
# Only look at args that are Tensors
|
||||
indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
|
||||
# kwargs tensors not supported yet unless it's a fallback op
|
||||
assert not any(isinstance(x, TensorBox) for x in kwargs.values()) or all(
|
||||
fn in fallbacks for fn in aten_fn
|
||||
)
|
||||
# kwargs tensors not supported yet
|
||||
assert not any(isinstance(x, TensorBox) for x in kwargs.values())
|
||||
|
||||
if (type_promotion_kind or convert_input_to_bool) and indices:
|
||||
if convert_input_to_bool:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user