from dataclasses import dataclass
import re
from typing import Optional, Sequence, List, Tuple

from tools.codegen.api import cpp
from tools.codegen.api.types import Binding
from tools.codegen.model import NativeFunction, Type, SchemaKind
from tools.codegen.utils import IDENT_REGEX

# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # Name of the saved attribute.
    # A suffix is appended if it's a derived property, e.g.: `other_scalar_type`
    name: str

    # The cpp type string.
    # TODO: change from raw string to model.Type
    type: str

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str
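
# For illustration only: saving the dtype of `other` rather than the whole tensor
# could look roughly like this (the exact cpp type string here is an assumption):
#   SavedAttribute(name='other_scalar_type', type='ScalarType', expr='other.scalar_type()')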

# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: Tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: Tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: Tuple[SavedAttribute, ...]
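
# For illustration only (field values are assumptions, not real codegen data):
# the `other.scalar_type()` example above might be represented roughly as
#   Derivative(formula='mul_tensor_backward(grad, self, other_scalar_type)',
#              var_names=('other',),
#              saved_inputs=(<SavedAttribute for `self`>, <SavedAttribute for `other_scalar_type`>),
#              saved_outputs=())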

# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunctions having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: Optional[str]

    # The derivatives formulae for this function.
    derivatives: Sequence[Derivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: Optional[List[bool]]

    @property
    def has_derivatives(self) -> bool:
        return len(self.args_with_derivatives) > 0
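
# For reference, a derivatives.yaml entry has roughly this shape (illustrative and
# abridged; see tools/autograd/derivatives.yaml for the real entries):
#   - name: mul.Tensor(Tensor self, Tensor other) -> Tensor
#     self: mul_tensor_backward(grad, other, self.scalar_type())
#     other: mul_tensor_backward(grad, self, other.scalar_type())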

def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
    if info is None:
        return False
    for derivative in info.derivatives:
        formula = derivative.formula
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False

def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
    return uses_ident(info, 'retain_variables')

def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
    return uses_ident(info, 'grad')
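
# For illustration only: IDENT_REGEX matches the identifier as a whole word, so a
# formula like 'mul_tensor_backward(grad, self, other_scalar_type)' would make
# uses_single_grad(info) return True, while a formula that only mentions 'grads[0]'
# would not (assuming IDENT_REGEX rejects trailing word characters).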

# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's a processed Argument which is differentiable and only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str
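
# For illustration only (the cpp type string is an assumption): the `self` argument
# of a typical Tensor method might be represented roughly as
#   DifferentiableInput(name='self', type=<Tensor type>, cpp_type='const Tensor &')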

# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's a processed Return which is differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str

@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    func: NativeFunction
    info: Optional[DifferentiabilityInfo]

# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """How are we going to call the underlying implementation of a
    declaration? There are two strategies:
        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance). Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors. If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance;
          e.g., size())
        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable.)
    """
    if fn.func.is_abstract or (fn.info is not None and fn.info.has_derivatives):
        # If the function is abstract (not implemented on at::Type), we must
        # call the implementation on the derived type with unpacked tensors.

        # If the function has a derivative specified and is concrete, we could
        # call either implementation. We prefer calling the derived
        # type's implementation with unpacked tensors because it is more
        # performant in some cases: any internal calls to other ATen functions
        # won't have the history tracked.

        # If the function has a type dispatched argument (i.e. is a factory),
        # we prefer calling the derived type's implementation both because it is
        # more performant and to ensure factory functions return tensors with _version
        # of 0 (probably not strictly necessary, but nice to have to keep versions
        # simple to understand).

        return 'use_derived'
    else:
        # If the function is concrete (we don't have to override it) and we
        # didn't declare it in derivatives.yaml, we'll assume that it is
        # actually implemented out of differentiable functions. (This
        # assumption might not hold, but then you'll see gradcheck fail.)
        return 'use_type'
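
# For illustration only (a sketch, not the real call site): downstream codegen
# (e.g. the VariableType generation under tools/autograd) would typically branch
# on this result, roughly like
#   if dispatch_strategy(fn) == 'use_derived':
#       ...  # emit code that unpacks tensors and calls the derived implementation
#   else:
#       ...  # emit a plain call that re-dispatches through VariableType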

def match_differentiability_info(
    native_functions: List[NativeFunction],
    differentiability_infos: Sequence[DifferentiabilityInfo],
) -> List[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to matching autograd function

    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """

    info_by_schema = {info.func.func: info for info in differentiability_infos}
    functional_info_by_signature = {
        info.func.func.signature(strip_default=True): info
        for info in differentiability_infos
        if info.func.func.kind() == SchemaKind.functional}

    def find_info(f: NativeFunction) -> Tuple[Optional[DifferentiabilityInfo], bool]:
        if f.func in info_by_schema:
            return info_by_schema[f.func], True

        # If there is no exact match, look for the out-of-place signature,
        # i.e. mul() for mul_() or mul_out().
        return functional_info_by_signature.get(f.func.signature(strip_default=True)), False

    result: List[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info, is_exact_match = find_info(f)

        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
        # 'self' derivatives of an inplace function, so we must check for this case.
        if f.func.kind() == SchemaKind.inplace and (info is not None):
            for derivative in info.derivatives:
                if 'self' in derivative.var_names:
                    for saved_input in derivative.saved_inputs:
                        assert 'strides_or_error' not in saved_input.expr, (
                            "Calling '.strides()' in the 'self' derivative formula of an "
                            f"in-place function is not supported: {f.func}")

        result.append(NativeFunctionWithDifferentiabilityInfo(
            func=f,
            info=info,
        ))

    return result
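
# For illustration only (a sketch, not the real call site): the autograd codegen
# would pair up parsed native functions with their derivative info roughly like
#   fns_with_infos = match_differentiability_info(native_functions, differentiability_infos)
#   for fn in fns_with_infos:
#       strategy = dispatch_strategy(fn)  # 'use_derived' or 'use_type'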

def is_differentiable(name: str, type: Type, info: Optional[DifferentiabilityInfo]) -> bool:
    return type.is_tensor_like() and (info is None or name not in info.non_differentiable_arg_names)

def gen_differentiable_outputs(fn: NativeFunctionWithDifferentiabilityInfo) -> List[DifferentiableOutput]:
    f = fn.func
    info = fn.info
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(name=name, type=ret.type, cpp_type=cpp.return_type(ret))
        for name, ret in zip(cpp.return_names(f), f.func.returns)]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        differentiable_outputs: List[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError("output_differentiability=False for inplace operation (version_counter won't get updated)")
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs
    candidate_differentiable_outputs = list(filter(lambda r: is_differentiable(r.name, r.type, info), outputs))
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
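
# For illustration only: with `output_differentiability: [True, False]` in derivatives.yaml,
# only the first return of the matched function is kept by gen_differentiable_outputs;
# without that field, every tensor-like, differentiable return is a candidate, trimmed
# to the first one when the formulas only reference the single 'grad'.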