from dataclasses import dataclass
from typing import Optional, Sequence, List, Tuple

from tools.codegen.api.types import *
from tools.codegen.model import *

# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # Name of the saved attribute.
    # A suffix is appended if it's a derived property, e.g.: `other_scalar_type`.
    name: str

    # The C++ type string.
    # TODO: change from raw string to model.Type
    type: str

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str
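
# A minimal illustrative sketch (not part of the original file): the derived
# property `other.scalar_type()` mentioned above could be modeled roughly as
#
#     SavedAttribute(
#         name='other_scalar_type',     # suffixed name of the derived property
#         type='ScalarType',            # raw C++ type string (see TODO above)
#         expr='other.scalar_type()',   # expression evaluated at save time
#     )
#
# The exact type string is an assumption here; only the shape of the data
# model is taken from this file.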

# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (a legitimate C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #   raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #          here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: Tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: Tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: Tuple[SavedAttribute, ...]
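
# An illustrative sketch (assumed values, not from the original file) of the
# mul example above, after `other.scalar_type()` in the raw formula has been
# rewritten to the saved attribute `other_scalar_type`:
#
#     Derivative(
#         formula='mul_tensor_backward(grad, self, other_scalar_type)',
#         var_names=('other',),   # this formula differentiates w.r.t. `other`
#         saved_inputs=(
#             SavedAttribute(name='self', type='Tensor', expr='self'),
#             SavedAttribute(name='other_scalar_type', type='ScalarType',
#                            expr='other.scalar_type()'),
#         ),
#         saved_outputs=(),       # the formula references no forward outputs
#     )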

# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunctions with the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function.
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction with the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one whose name matches the
    # derivatives.yaml entry. If there is no exact match, we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: Optional[str]

    # The derivative formulae for this function.
    derivatives: Sequence[Derivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by
    # argument order in the function schema.
    args_with_derivatives: Sequence[CppArgument]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: Optional[List[bool]]
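
# An illustrative, hypothetical sketch (names and values assumed, not from the
# original file) of how the pieces could fit together for a `mul` entry:
#
#     DifferentiabilityInfo(
#         name='mul',
#         func=...,                        # NativeFunction matched via the schema string
#         op='MulBackward0',               # assumed generated autograd function name
#         derivatives=[d],                 # d: the Derivative sketched above
#         all_saved_inputs=list(d.saved_inputs),
#         all_saved_outputs=list(d.saved_outputs),
#         args_with_derivatives=[...],     # CppArguments for `self` and `other`
#         non_differentiable_arg_names=[],
#         output_differentiability=None,   # key absent from the yaml entry
#     )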