mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Generally wildcard imports are bad for the reasons described here: https://www.flake8rules.com/rules/F403.html This PR replaces wildcard imports with an explicit list of imported items where possible, and adds a `# noqa: F403` comment in the other cases (mostly re-exports in `__init__.py` files). This is a prerequisite for https://github.com/pytorch/pytorch/issues/55816, because currently [`tools/codegen/dest/register_dispatch_key.py` simply fails if you sort its imports](https://github.com/pytorch/pytorch/actions/runs/742505908). Pull Request resolved: https://github.com/pytorch/pytorch/pull/55838 Test Plan: CI. You can also run `flake8` locally. Reviewed By: jbschlosser Differential Revision: D27724232 Pulled By: samestep fbshipit-source-id: 269fb09cb4168f8a51fd65bfaacc6cda7fb87c34
76 lines · 2.4 KiB · Python
from typing import List, Union, Set, Any
|
|
|
|
from tools.codegen.context import with_native_function
|
|
from tools.codegen.utils import concatMap
|
|
from tools.codegen.model import (NativeFunction, NativeFunctionsGroup,
|
|
is_structured_dispatch_key)
|
|
import tools.codegen.api.meta as meta
|
|
import tools.codegen.api.native as native
|
|
import tools.codegen.api.structured as structured
|
|
|
|
@with_native_function
def gen_unstructured(f: NativeFunction) -> List[str]:
    """Return one `TORCH_API` C++ forward declaration per kernel name of `f`.

    Kernel names are taken from the values of `f.dispatch`. Each distinct
    name is declared once, in first-seen order. Names containing
    "legacy::" are skipped.

    Every declaration shares the return type and argument list derived from
    `f.func` through the `native` API. Those are computed once, and only if
    at least one name is emitted, rather than once per name.
    """
    # Sometimes a function name shows up multiple times (under different
    # dispatch keys); only generate it once!
    names: List[str] = []
    seen: Set[str] = set()
    for n in f.dispatch.values():
        if n in seen or "legacy::" in n:
            continue
        seen.add(n)
        names.append(n)

    if not names:
        # Nothing to declare; avoid computing the signature at all.
        return []

    # The signature depends only on f.func, so it is loop-invariant.
    returns_type = native.returns_type(f.func.returns)
    args_str = ', '.join(a.decl() for a in native.arguments(f.func))
    return [f"TORCH_API {returns_type} {n}({args_str});" for n in names]
|
|
|
|
@with_native_function
def gen_structured(g: NativeFunctionsGroup) -> List[str]:
    """Return C++ declarations for the kernels of a structured function group.

    Two kinds of declaration are produced, in this order:

    1. For each distinct kernel name mapped from a structured dispatch key in
       the out= variant's dispatch table (`g.out.dispatch`): a
       `struct TORCH_API structured_<name>` deriving from
       `at::meta::<meta.name(g)>` and declaring an `impl(...)` method.
    2. For each distinct kernel name mapped from a NON-structured dispatch key
       in any function of the group: a plain `TORCH_API` function
       declaration, using the signature of the first function that lists
       that name.

    The two kinds are deduplicated independently of each other.
    """
    # only out has dispatch (for structured kernels)
    meta_name = meta.name(g)
    out_args = structured.impl_arguments(g)

    rs: List[str] = []

    # The impl() argument list is the same for every structured kernel name.
    impl_args_str = ', '.join(a.decl() for a in out_args)
    structured_seen: Set[str] = set()
    for k, n in g.out.dispatch.items():
        if n in structured_seen or not is_structured_dispatch_key(k):
            continue
        structured_seen.add(n)
        rs.append(f"""\
struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
void impl({impl_args_str});
}};
""")

    # Separate dedup set: a name used for a structured kernel above may still
    # be declared here if it also appears under a non-structured key.
    unstructured_seen: Set[str] = set()
    for f in g.functions():
        # Signature is per-function and invariant across its dispatch entries.
        returns_type = native.returns_type(f.func.returns)
        args_str = ', '.join(a.decl() for a in native.arguments(f.func))
        for k, n in f.dispatch.items():
            if n in unstructured_seen or is_structured_dispatch_key(k):
                continue
            unstructured_seen.add(n)
            rs.append(f"TORCH_API {returns_type} {n}({args_str});")

    return rs
|
|
|
|
@with_native_function
def compute_native_function_declaration(g: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
    """Produce the NativeFunctions.h forward declarations for `g`.

    NativeFunctions.h is a list of forward declarations of all actual kernel
    definitions kept in aten/src/ATen/native/.

    - A lone `NativeFunction` is handled by `gen_unstructured`.
    - A structured `NativeFunctionsGroup` is handled by `gen_structured`.
    - An unstructured group is handled by applying `gen_unstructured` to each
      of its functions and concatenating the results.
    """
    if not isinstance(g, NativeFunctionsGroup):
        return gen_unstructured(g)
    if g.structured:
        return gen_structured(g)
    per_function = concatMap(gen_unstructured, g.functions())
    return list(per_function)
|