Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49735

This is the final wave of the autograd codegen data model migration. After this PR:
- autograd codegen no longer depends on Declarations.yaml;
- autograd codegen sources are fully type annotated and pass the mypy-strict check.

To avoid potential merge conflicts with other pending PRs, some structural changes are intentionally avoided, e.g. inner methods were not moved out and were not changed to stop reading the outer function's variables.

Confirmed byte-for-byte compatible with the old codegen:
```
# Run it before and after this PR:
.jenkins/pytorch/codegen-test.sh <baseline_output_dir>
.jenkins/pytorch/codegen-test.sh <test_output_dir>

# Then run diff to compare the generated files:
diff -Naur <baseline_output_dir> <test_output_dir>
```

Confirmed clean mypy-strict run:
```
mypy --config mypy-strict.ini
```

Test Plan: Imported from OSS

Reviewed By: ezyang, bhosmer

Differential Revision: D25678879

Pulled By: ljk53

fbshipit-source-id: ba6e2eb6b9fb744208f7f79a922d933fcc3bde9f
from dataclasses import dataclass
from typing import Optional, Sequence, List, Tuple

from tools.codegen.api.types import *
from tools.codegen.model import *

# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # Name of the saved attribute.
    # Suffix is appended if it's a derived property, e.g.: `other_scalar_type`
    name: str

    # The cpp type string.
    # TODO: change from raw string to model.Type
    type: str

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str

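# Editor's note: the block below is an illustrative sketch, not part of the original
# file. It shows how a SavedAttribute for the derived property `other.scalar_type()`
# (the example in the comment above) might look; the cpp type string is an assumption.
def _example_saved_attribute() -> SavedAttribute:
    return SavedAttribute(
        name='other_scalar_type',    # input name plus a suffix for the derived property
        type='at::ScalarType',       # cpp type string (assumed)
        expr='other.scalar_type()',  # expression evaluated at save time
    )
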
# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #   raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #          here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: Tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: Tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: Tuple[SavedAttribute, ...]

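# Editor's note: illustrative sketch, not part of the original file. It builds a
# Derivative for the `mul_tensor_backward` example in the comment above; the
# var_names and saved attributes are assumptions, and a real entry would typically
# also save `self` as a whole tensor.
def _example_derivative() -> Derivative:
    other_scalar_type = SavedAttribute(
        name='other_scalar_type',
        type='at::ScalarType',        # cpp type string (assumed)
        expr='other.scalar_type()',
    )
    return Derivative(
        formula='mul_tensor_backward(grad, self, other_scalar_type)',
        var_names=('other',),         # assumed: this formula differentiates w.r.t. `other`
        saved_inputs=(other_scalar_type,),
        saved_outputs=(),
    )
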
# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunctions having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: Optional[str]

    # The derivative formulae for this function.
    derivatives: Sequence[Derivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: Optional[List[bool]]

    @property
    def has_derivatives(self) -> bool:
        return len(self.args_with_derivatives) > 0

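# Editor's note: illustrative sketch, not part of the original file. It shows one way
# the 'all_saved_inputs' union described above could be computed from 'derivatives':
# a duplicate-free union keyed by attribute name, preserving first-seen order. The
# helper name is hypothetical.
def _union_of_saved_inputs(derivatives: Sequence[Derivative]) -> Sequence[SavedAttribute]:
    seen_names: List[str] = []
    union: List[SavedAttribute] = []
    for d in derivatives:
        for saved in d.saved_inputs:
            if saved.name not in seen_names:
                seen_names.append(saved.name)
                union.append(saved)
    return union
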
# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's processed Arguments which are differentiable and only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str

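# Editor's note: illustrative sketch, not part of the original file. It constructs a
# DifferentiableInput for a Tensor argument named `other`; `BaseType(BaseTy.Tensor)`
# and the cpp type string are assumptions about tools.codegen.model and the cpp
# signature rules.
def _example_differentiable_input() -> DifferentiableInput:
    return DifferentiableInput(
        name='other',
        type=BaseType(BaseTy.Tensor),  # assumed Type constructor from tools.codegen.model
        cpp_type='const Tensor &',     # kept only for byte-for-byte compatibility (assumed value)
    )
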
# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's processed Returns which are differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and are only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str

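# Editor's note: illustrative sketch, not part of the original file. It shows how the
# optional 'output_differentiability' mask from derivatives.yaml (see
# DifferentiabilityInfo above) could be applied to a candidate list of outputs; the
# helper name and signature are hypothetical.
def _filter_differentiable_outputs(
    outputs: Sequence[DifferentiableOutput],
    output_differentiability: Optional[List[bool]],
) -> List[DifferentiableOutput]:
    if output_differentiability is None:
        # No mask specified: treat every candidate output as differentiable.
        return list(outputs)
    assert len(output_differentiability) == len(outputs)
    return [out for out, keep in zip(outputs, output_differentiability) if keep]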