mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55334 The goal of this PR is to clean up some of the autograd codegen to compare C++ types using `CType` objects instead of raw strings. My last PR in the stack made that string comparison a little more fragile, since the raw C++ strings needed to be namespace-aware. I confirmed byte-for-byte no codegen changes vs. the last PR (which added namespaces to the codegen) by running `diff -qr ../pytorch-common_test/torch/csrc/autograd/generated/ ../pytorch-callgrind_test_after2/torch/csrc/autograd/generated/` and `diff -qr ../pytorch-common_test/build/aten/src/ATen/ ../pytorch-callgrind_test_after2/build/aten/src/ATen/` Note that a better end-state for the autograd codegen would be to do all of its type pattern matching directly off of JIT types, instead of off of CTypes (which are really just generated from JIT types, incorporating C++ specific semantics). That looks like it’ll require a pretty substantial change though, so I’m not doing it in this PR. As part of this change (and after talking with ezyang), I split off the `CType` data class into a separate `NamedCType` class, which holds a name and a `CType`. This way, `CType` only knows about actual C++ types, making it easier to compare CTypes to each other in the codegen when we only care about the type. The core change is in `types.py`, but it required a bunch of downstream changes to update all of the places where we create `CType`s to create `NamedCType`s instead. The main change in the autograd codegen was that I updated `SavedAttribute` to store a `NamedCType`. The other autograd changes all pretty much came from that change. Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D27708347 Pulled By: bdhirsh fbshipit-source-id: 3e07c80569c7b229c638f389e76e319bff6315f9
111 lines
4.6 KiB
Python
111 lines
4.6 KiB
Python
from tools.codegen.model import (Argument, FunctionSchema, Return,
|
|
SelfArgument, TensorOptionsArguments, Type,
|
|
assert_never)
|
|
|
|
from tools.codegen.api.types import (ArgName, BaseCType, Binding,
|
|
ConstRefCType, NamedCType, CType, MutRefCType, ListCType,
|
|
OptionalCType, tensorT, scalarT, layoutT,
|
|
deviceT, boolT, scalarTypeT)
|
|
from tools.codegen.api import cpp
|
|
|
|
from typing import Union, Sequence, List, Optional
|
|
|
|
# This file describes the translation of JIT schema to the native functions API.
|
|
# This looks a lot like the C++ API (which makes historical sense, because the
|
|
# idea was you wrote native functions to implement functions in the C++ API),
|
|
# but over time we have evolved the C++ API without actually changing our
|
|
# native:: kernels. The intention is to make native API and dispatcher API
|
|
# line up as closely as possible, since this results in the least overhead
|
|
# (no translation is needed from dispatcher API to native API).
|
|
|
|
def name(func: FunctionSchema) -> str:
    """Return the native:: function name for ``func``.

    The result is the base operator name, then ``_out`` when ``func`` is an
    out= variant, then ``_<overload_name>`` when the schema carries a
    non-empty overload name.
    """
    base = str(func.name.name)
    # TODO: delete this!
    out_suffix = '_out' if func.is_out_fn() else ''
    overload = func.name.overload_name
    overload_suffix = f'_{overload}' if overload else ''
    return base + out_suffix + overload_suffix
|
|
|
|
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate a JIT argument type into the native-API C++ type bound to ``binds``.

    Four schema spellings get native-specific treatment:

    * ``Tensor?``   -> optional tensor, by mutable reference when ``mutable``
      is true, otherwise by const reference;
    * ``Tensor?[]`` -> const reference to a list of optional tensors;
    * ``Scalar``    -> const reference to a scalar;
    * ``Scalar?``   -> const reference to an optional scalar.

    Every other type is delegated unchanged to ``cpp.argumenttype_type``.
    """
    spelled = str(t)
    if spelled == 'Tensor?':
        optional_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        wrapped: CType = MutRefCType(optional_tensor) if mutable else ConstRefCType(optional_tensor)
        return NamedCType(binds, wrapped)
    if spelled == 'Tensor?[]':
        return NamedCType(binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))))
    if spelled == 'Scalar':
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    if spelled == 'Scalar?':
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
    # No native-specific override applies; fall back to the C++ API translation.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
|
|
|
|
def returns_type(rs: Sequence[Return]) -> CType:
    """Return the C++ return type for a native function whose returns are ``rs``.

    The native API applies no overrides of its own here; the translation is
    taken verbatim from ``cpp.returns_type``.
    """
    translated = cpp.returns_type(rs)
    return translated
|
|
|
|
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Return the native-API C++ type of argument ``a``, named ``binds``.

    Mutability comes from ``a.is_write``; the type translation itself is
    done by ``argumenttype_type``.
    """
    schema_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(schema_type, mutable=is_mutable, binds=binds)
|
|
|
|
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool) -> List[Binding]:
    """Translate one schema argument into its list of native-API bindings.

    * ``Argument`` -> a single binding named after the argument. It gets a
      default (via ``cpp.default_expr``) only when ``is_out`` is false and
      the schema declares a default.
    * ``SelfArgument`` -> handled exactly like the ``Argument`` it wraps.
    * ``TensorOptionsArguments`` -> four optional bindings, in order
      ``dtype``, ``layout``, ``device``, ``pin_memory``. Each defaults to
      ``'{}'`` unless ``is_out`` is true.

    Any other input is rejected by ``assert_never``.
    """
    # Ideally, we NEVER default native functions. However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing. So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out

    if isinstance(a, Argument):
        arg_default: Optional[str] = None
        if should_default and a.default is not None:
            arg_default = cpp.default_expr(a.default, a.type)
        single = Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=arg_default,
            argument=a,
        )
        return [single]

    if isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)

    if isinstance(a, TensorOptionsArguments):
        options_default: Optional[str] = '{}' if should_default else None
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces. It seems
        # to matter
        option_fields = (
            ('dtype', scalarTypeT),
            ('layout', layoutT),
            ('device', deviceT),
            ('pin_memory', boolT),
        )
        return [
            Binding(
                nctype=NamedCType(field_name, OptionalCType(BaseCType(field_base))),
                name=field_name,
                default=options_default,
                argument=a,
            )
            for field_name, field_base in option_fields
        ]

    assert_never(a)
|
|
|
|
def arguments(func: FunctionSchema) -> List[Binding]:
    """Return the flattened native-API bindings for all arguments of ``func``.

    Non-out arguments come first, followed by out arguments. Every argument
    is translated with the same ``is_out`` flag, taken from
    ``func.is_out_fn()``. As a result, in an out= variant even the non-out
    arguments receive no defaults.
    """
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    is_out = func.is_out_fn()
    bindings: List[Binding] = []
    for arg in args:
        bindings.extend(argument(arg, is_out=is_out))
    return bindings
|