mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127124 Approved by: https://github.com/Skylion007 ghstack dependencies: #127122, #127123
207 lines
6.7 KiB
Python
207 lines
6.7 KiB
Python
# Owner(s): ["module: codegen"]
|
|
|
|
import textwrap
|
|
import unittest
|
|
from typing import cast
|
|
|
|
import expecttest
|
|
import yaml
|
|
|
|
import torchgen.dest as dest
|
|
import torchgen.gen as gen
|
|
from torchgen.gen import LineLoader, parse_native_yaml_struct
|
|
from torchgen.model import (
|
|
Annotation,
|
|
CustomClassType,
|
|
DispatchKey,
|
|
NativeFunctionsGroup,
|
|
Type,
|
|
)
|
|
|
|
|
|
class TestCodegenModel(expecttest.TestCase):
    """Tests for the assertion messages raised by torchgen on bad ufunc YAML.

    Each test builds a small native-function YAML snippet, feeds it to
    ``parse_native_yaml_struct`` (and, for the ufunc tests, to the
    ``dest.compute_ufunc_*`` generators) and compares the resulting
    ``AssertionError`` message with an inline expectation.
    """

    @staticmethod
    def _summarize_error(err: AssertionError) -> str:
        """Return the leading part of ``err``'s message, re-wrapped.

        Everything from the first ``" in "`` onwards is dropped (hack to strip
        out the context), and the remainder is passed through
        ``textwrap.wrap`` and joined with newlines.

        ``str.partition`` is used instead of ``split(..., 2)`` followed by a
        two-name unpacking: the latter raises ``ValueError`` whenever the
        separator occurs more than once (three parts) or not at all (one part).
        """
        # NOTE(review): the separator is reproduced as it appears in this copy
        # of the file; runs of spaces may have been collapsed in transit --
        # confirm against upstream.
        msg, _sep, _context = str(err).partition(" in ")
        return "\n".join(textwrap.wrap(msg))

    def assertParseErrorInline(self, yaml_str: str, expect: str) -> None:
        """Assert that parsing ``yaml_str`` raises ``AssertionError`` with ``expect``.

        Args:
            yaml_str: native-function YAML to load and parse.
            expect: expected error message after the context has been stripped.

        Fails the test if parsing does not raise ``AssertionError``.
        """
        # The callers visible in this file only pass in-file fixture strings,
        # so yaml.load is not being fed external data here.
        es = yaml.load(yaml_str, Loader=LineLoader)
        try:
            parse_native_yaml_struct(es, set())
        except AssertionError as e:
            # skip=1 is forwarded to expecttest unchanged; presumably it makes
            # expecttest look one frame up for the inline literal -- confirm.
            self.assertExpectedInline(self._summarize_error(e), expect, skip=1)
        else:
            self.fail(msg="Did not raise when expected to")

    def assertUfuncErrorInline(self, yaml_str: str, expect: str) -> None:
        """Assert that ufunc codegen for ``yaml_str`` raises with ``expect``.

        Args:
            yaml_str: native-function YAML describing exactly one structured
                group whose ``out`` variant has a ``ufunc_inner_loop``.
            expect: expected error message after the context has been stripped.

        Fails the test if none of the ``dest.compute_ufunc_*`` calls raises
        ``AssertionError``.
        """
        # parse a single structured group out of the yaml to g
        es = yaml.load(yaml_str, Loader=LineLoader)
        parsed_yaml = parse_native_yaml_struct(es, set())
        native_functions = parsed_yaml.native_functions
        backend_indices = parsed_yaml.backend_indices
        grouped_native_functions = gen.get_grouped_native_functions(native_functions)
        assert len(grouped_native_functions) == 1
        g = grouped_native_functions[0]
        assert isinstance(g, NativeFunctionsGroup)
        assert g.out.ufunc_inner_loop
        # this is not ufunc codegen per se, but it does some basic sanity tests for
        # ufunc generation
        gen.compute_meta_function_declaration(g)
        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CPU])
        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CUDA])
        try:
            # the real kahuna
            dest.compute_ufunc_cpu(g)
            dest.compute_ufunc_cpu_kernel(g)
            dest.compute_ufunc_cuda(g)
        except AssertionError as e:
            # skip=1 is forwarded to expecttest unchanged; presumably it makes
            # expecttest look one frame up for the inline literal -- confirm.
            self.assertExpectedInline(self._summarize_error(e), expect, skip=1)
        else:
            self.fail(msg="Did not raise when expected to")

    # NB: indent is hardcoded to be two here, so format your yaml accordingly
    binop_out = (
        "func: binop.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
    )
    ti_binop_out = f"""{binop_out}
  structured: True
  structured_inherits: TensorIteratorBase"""
    ti_binop = """func: binop(Tensor self, Tensor other) -> Tensor
  structured_delegate: binop.out
"""

    ti_unop_out = """func: unop.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase"""
    ti_unop = """func: unop(Tensor self) -> Tensor
  structured_delegate: unop.out
"""

    def test_nonstructured_ufunc(self) -> None:
        """A ufunc_inner_loop on an entry without ``structured`` is rejected."""
        yaml_str = f"""\
- {self.binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
ufunc must be structured""",
        )

    def test_overlapping_ufunc_and_dispatch(self) -> None:
        """A ufunc entry that also lists a CPU dispatch kernel is rejected."""
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
  dispatch:
    CPU: binop_cpu
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
ufunc should not have explicit dispatch entry for CPU""",
        )

    # See https://github.com/pytorch/pytorch/pull/65851#discussion_r810238456
    @unittest.expectedFailure
    def test_scalaronly_shadowed(self) -> None:
        """Generic and ScalarOnly with the same name; marked as expected failure."""
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
    ScalarOnly: binop (Bool)
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
""",
        )

    def test_conflicting_ufunc(self) -> None:
        """Generic and ScalarOnly naming different ufuncs is rejected."""
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
    ScalarOnly: binop_scalar (Bool)
- {self.ti_binop}
"""
        self.assertUfuncErrorInline(
            yaml_str,
            """\
ScalarOnly and Generic must have same ufunc name""",
        )

    def test_invalid_cudafunctoronself_for_binary_op(self) -> None:
        """CUDAFunctorOnSelf on the single-input ``unop`` fixture is rejected."""
        yaml_str = f"""\
- {self.ti_unop_out}
  ufunc_inner_loop:
    Generic: unop (All)
    CUDAFunctorOnSelf: unop_self_cuda (All)
- {self.ti_unop}
"""
        self.assertUfuncErrorInline(
            yaml_str,
            """\
cannot use CUDAFunctorOnSelf on non-binary function""",
        )

    def test_parse_custom_class_type(self) -> None:
        """``Type.parse`` turns a ``__torch__.torch.classes.`` name into a CustomClassType."""
        custom_class_name = "namespace_foo.class_bar"
        custom_class_name_with_prefix = f"__torch__.torch.classes.{custom_class_name}"
        custom_class_type = cast(
            CustomClassType, Type.parse(custom_class_name_with_prefix)
        )
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(custom_class_type, CustomClassType)
        self.assertEqual(custom_class_name, custom_class_type.class_name)
        self.assertEqual(custom_class_name_with_prefix, str(custom_class_type))
|
|
|
|
|
|
class TestAnnotation(expecttest.TestCase):
    """Tests for ``Annotation.parse`` on alias-annotation strings.

    The expected tuples are written as explicit literals: ``tuple("a")`` only
    equals ``("a",)`` because the string happens to be one character long.
    """

    def test_single_alias_no_write(self) -> None:
        """``"a"`` parses to alias set ``("a",)``, not a write, no after-set."""
        a = Annotation.parse("a")
        self.assertEqual(a.alias_set, ("a",))
        self.assertFalse(a.is_write)
        self.assertEqual(a.alias_set_after, ())

    def test_single_alias_is_write(self) -> None:
        """``"a!"`` parses to alias set ``("a",)`` flagged as a write."""
        a = Annotation.parse("a!")
        self.assertEqual(a.alias_set, ("a",))
        self.assertTrue(a.is_write)
        self.assertEqual(a.alias_set_after, ())

    def test_single_alias_is_write_to_wildcard(self) -> None:
        """``"a! -> *"`` parses to a write with after-set ``("*",)``."""
        a = Annotation.parse("a! -> *")
        self.assertEqual(a.alias_set, ("a",))
        self.assertTrue(a.is_write)
        self.assertEqual(a.alias_set_after, ("*",))

    def test_alias_set(self) -> None:
        """``"a|b"`` parses to the two-element alias set ``("a", "b")``."""
        a = Annotation.parse("a|b")
        self.assertEqual(a.alias_set, ("a", "b"))

    def test_alias_set_is_write_raises_exception(self) -> None:
        """A write marker on a multi-element alias set raises AssertionError."""
        with self.assertRaisesRegex(
            AssertionError, r"alias set larger than 1 is not mutable"
        ):
            Annotation.parse("a|b!")

    def test_single_alias_is_write_to_alias_set(self) -> None:
        """``"a! -> a|b"`` parses to a write with after-set ``("a", "b")``."""
        a = Annotation.parse("a! -> a|b")
        self.assertEqual(a.alias_set, ("a",))
        self.assertTrue(a.is_write)
        self.assertEqual(a.alias_set_after, ("a", "b"))

    def test_before_and_after_alias_set_larger_than_1_raises_exception(self) -> None:
        """Multi-element sets on both sides of ``->`` raise AssertionError."""
        with self.assertRaisesRegex(
            AssertionError,
            r"before alias set and after alias set cannot be larger than 1 at the same time",
        ):
            Annotation.parse("a|b -> c|d")
|
|
|
|
|
|
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|