Revert "[inductor] Handle the case where kwargs contains tensor (#88215)"

This reverts commit 983c0e7f31.

Reverted https://github.com/pytorch/pytorch/pull/88215 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but I think it breaks trunk https://github.com/pytorch/pytorch/actions/runs/3380662072/jobs/5613987333 with a failure in test_torchinductor_opinfo.py
This commit is contained in:
PyTorch MergeBot 2022-11-02 23:33:15 +00:00
parent 7354368fd5
commit a8561c4571
3 changed files with 16 additions and 32 deletions

View File

@@ -4083,20 +4083,6 @@ class CommonTemplate:
else:
self.assertEqual(len(inps), 0)
@unittest.skipIf(HAS_CUDA, "histogramdd only supports cpu")
def test_kwargs(self):
def fn(x, y):
return torch.histogramdd(
x,
bins=[3, 3],
weight=y,
)
self.common(
fn,
[torch.randn((4, 2)), torch.randn((4))],
)
if HAS_CPU:

View File

@@ -8,7 +8,6 @@ import textwrap
from collections import OrderedDict
from enum import Enum
from functools import partial
from inspect import signature
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch
@@ -2237,8 +2236,7 @@ class ExternKernel(InputsKernel):
@classmethod
def process_kernel(cls, kernel, *args, **kwargs):
binded_args = signature(kernel).bind(*args, **kwargs).arguments
args_flat, args_spec = pytree.tree_flatten(binded_args)
args_flat, args_spec = pytree.tree_flatten(args)
is_arg_tensor = []
tensor_args = []
@@ -2251,16 +2249,15 @@ class ExternKernel(InputsKernel):
non_tensor_args.append(arg)
def unflatten_args(new_tensor_args, new_non_tensor_args):
result = []
new_args = []
it_tensors = iter(new_tensor_args)
it_non_tensors = iter(new_non_tensor_args)
for is_tensor in is_arg_tensor:
if is_tensor:
result.append(next(it_tensors))
new_args.append(next(it_tensors))
else:
result.append(next(it_non_tensors))
result = pytree.tree_unflatten(result, args_spec)
return result.get("args", []), result.get("kwargs", {})
new_args.append(next(it_non_tensors))
return pytree.tree_unflatten(new_args, args_spec)
tensor_args = [cls.realize_input(x) for x in tensor_args]
@@ -2286,8 +2283,9 @@ class ExternKernel(InputsKernel):
).zero_()
example_args.append(arg)
new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
example_output = kernel(*new_args, **new_kwargs)
example_output = kernel(
*unflatten_args(example_args, non_tensor_args), **kwargs
)
return example_output, tensor_args, non_tensor_args, unflatten_args
@@ -2880,13 +2878,15 @@ class FallbackKernel(ExternKernelAlloc):
def __repr__(self):
return self.ref
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
constant_args = [Shim(repr(x)) for x in self.constant_args]
def gen_kwarg(k, v):
return f"{k}={repr(v)}"
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
constant_args = [Shim(repr(x)) for x in self.constant_args]
args, kwargs = self.unflatten_args(tensor_args, constant_args)
return list(map(repr, args)) + list(gen_kwarg(k, v) for k, v in kwargs.items())
kwargs = list(gen_kwarg(k, v) for k, v in self.kwargs.items())
return list(map(repr, self.unflatten_args(tensor_args, constant_args))) + kwargs
@classmethod
def create(cls, kernel, *args, **kwargs):

View File

@@ -164,10 +164,8 @@ def _register_lowering(
args = args[0]
# Only look at args that are Tensors
indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
# kwargs tensors not supported yet unless it's a fallback op
assert not any(isinstance(x, TensorBox) for x in kwargs.values()) or all(
fn in fallbacks for fn in aten_fn
)
# kwargs tensors not supported yet
assert not any(isinstance(x, TensorBox) for x in kwargs.values())
if (type_promotion_kind or convert_input_to_bool) and indices:
if convert_input_to_bool: