Summary: I'm working through clang-tidy warnings. This PR addresses the `hicpp-use-override` check, which warns that `virtual` + `override` is redundant, since `override` already signifies that a function overrides a virtual function. Where there was `virtual` + `override` I removed the `virtual`; where there was `virtual` and no `override` I removed `virtual` and added `override`. ezyang apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/9335 Differential Revision: D8807082 Pulled By: goldsborough fbshipit-source-id: e0a261053f6540a22cc56ec160a24aa285af6319
217 lines
7.1 KiB
Python
# Generates C++ autograd functions for the derivatives of ATen operations
#
# This writes two files:
#  Functions.h/cpp: subclasses of autograd::Function
#  python_functions.h/cpp: Python bindings for the above classes
#
import re

from .utils import nested_dict, CodeTemplate, write
from .gen_autograd import VIEW_FUNCTIONS
from .utils import IDENT_REGEX
FUNCTION_DECLARATION = CodeTemplate("""\
struct ${op} : public ${superclass} {
  using ${superclass}::${superclass};
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "${op}"; }
  void release_variables() override {
    ${release_variables}
  }
  ${will_release_variables}
  ${saved_variables}
  ${saved_list_sizes}
};
""")
WILL_RELEASE_VARIABLES = CodeTemplate("""\
bool retain_variables = true;
void will_release_variables() override {
  retain_variables = false;
}
""")
FUNCTION_DEFINITION = CodeTemplate("""\
variable_list ${op}::apply(variable_list&& grads) {
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}
""")
PY_FUNCTION_DEFINITION = CodeTemplate("""\
static PyTypeObject ${op}Class;
addClass<${op}>(${op}Class, "${op}");
""")
GRAD_INPUT_MASK = CodeTemplate("""\
  auto grad_input_mask = std::array<bool, ${n}>{
    ${masks}
  };\
""")
DERIVATIVE_SINGLE = CodeTemplate("""\
if (should_compute_output({ ${name}_ix })) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
""")
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate("""\
  if (should_compute_output({ ${name}_ix })) {
    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
  }
""")
DERIVATIVE_MULTI = CodeTemplate("""\
if (should_compute_output({ ${idx_ranges} })) {
${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}
}
""")
# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def gen_autograd_functions(out, autograd_functions, template_path):
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Function
    for every differentiable torch function.
    """

    FUNCTIONS_H = CodeTemplate.from_file(template_path + '/Functions.h')
    FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/Functions.cpp')
    PY_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_functions.h')
    PY_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_functions.cpp')

    function_definitions = []
    function_declarations = []
    py_function_initializers = []

    for func in autograd_functions:
        env = process_function(func)

        function_declarations.append(FUNCTION_DECLARATION.substitute(env))
        function_definitions.append(FUNCTION_DEFINITION.substitute(env))
        py_function_initializers.append(PY_FUNCTION_DEFINITION.substitute(env))

    top_env = {
        'autograd_function_definitions': function_definitions,
        'autograd_function_declarations': function_declarations,
        'py_function_initializers': py_function_initializers,
    }

    write(out, 'Functions.h', FUNCTIONS_H, top_env)
    write(out, 'Functions.cpp', FUNCTIONS_CPP, top_env)
    write(out, 'python_functions.h', PY_FUNCTIONS_H, top_env)
    write(out, 'python_functions.cpp', PY_FUNCTIONS_CPP, top_env)
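# The expected caller is the autograd code-gen driver (tools/autograd/gen_autograd.py);
# a rough sketch of the call, with paths and variable names that are assumptions
# rather than fixed by this file:
#
#   gen_autograd_functions(out=derived_dir,
#                          autograd_functions=parsed_derivatives,
#                          template_path='tools/autograd/templates')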
def process_function(func):
    env = {}
    saved_variables = []
    release_variables = []
    saved_list_sizes = []
    unpack = []

    env['compute_index_ranges'] = []
    for arg in func['args_with_gradients']:
        if arg['type'] == 'TensorList':
            size = '{}_size_'.format(arg['name'])
            saved_list_sizes.append('size_t {}_size_;'.format(arg['name']))
        else:
            size = '1'
        env['compute_index_ranges'].append('auto {}_ix = gen.range({});'.format(arg['name'], size))
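    # For example, a hypothetical op with args `self: Tensor` and
    # `tensors: TensorList` would emit roughly:
    #   auto self_ix = gen.range(1);
    #   auto tensors_ix = gen.range(tensors_size_);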
    def save_arg(arg, is_output):
        name = arg['name']
        if arg['type'] == 'Tensor' or (arg['type'] == 'Scalar' and is_output):
            saved_variables.append('SavedVariable {}_;'.format(name))
            release_variables.append('{}_.reset_data();'.format(name))
            ptr = 'shared_from_this()' if is_output else ''
            unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr))
        elif arg['type'] == 'TensorList':
            saved_variables.append('std::vector<SavedVariable> {}_;'.format(name))
            release_variables.append('{}_.clear();'.format(name))
            unpack.append('auto {} = unpack_list({}_);'.format(name, name))
        elif arg['type'] == 'IntList':
            saved_variables.append('std::vector<int64_t> {};'.format(name))
        else:
            saved_variables.append('{} {};'.format(arg['type'], name))
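    # Each saved input/output recorded below becomes a member of the generated
    # struct; save_arg above decides its C++ type, how release_variables() frees
    # it, and how apply() unpacks it back into a usable value.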
    for arg in func['saved_inputs']:
        save_arg(arg, is_output=False)
    for arg in func['saved_outputs']:
        save_arg(arg, is_output=True)
    env['saved_variables'] = saved_variables
    env['release_variables'] = release_variables
    env['saved_list_sizes'] = saved_list_sizes

    if uses_retain_variables(func):
        env['will_release_variables'] = WILL_RELEASE_VARIABLES.substitute()
    else:
        env['will_release_variables'] = ''

    body = []

    if uses_single_grad(func):
        body.append('auto& grad = grads[0];')
    def emit_derivative(derivative):
        formula = derivative['formula']
        var_names = derivative['var_names']
        if len(var_names) == 1:
            return DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula)
        else:
            if 'grad_input_mask' in formula:
                masks = ['should_compute_output({{ {}_ix }}),'.format(n) for n in var_names]
                grad_input_mask = GRAD_INPUT_MASK.substitute(masks=masks, n=len(var_names))
            else:
                grad_input_mask = ''
            idx_ranges = ', '.join("{}_ix".format(n) for n in var_names)
            copy_ranges = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
            return DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges, copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask)
    body.extend(unpack)
    for derivative in func['derivatives']:
        body.append(emit_derivative(derivative))

    env['body'] = body
    if func['name'] in UNTRACEABLE_FUNCTIONS:
        env['superclass'] = 'Function'
    else:
        env['superclass'] = 'TraceableFunction'
    return nested_dict(env, func)
def uses_ident(func, ident):
    if func is None:
        return False
    for derivative in func['derivatives']:
        formula = derivative['formula']
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False


def uses_retain_variables(func):
    return uses_ident(func, 'retain_variables')


def uses_single_grad(func):
    return uses_ident(func, 'grad')
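# Note: IDENT_REGEX (imported from .utils) is used above to match `ident` as a
# whole identifier inside a derivative formula, so e.g. 'grad' matches in
# "grad.mul(other)" but not in "grads" or "grad_output" (a rough description;
# see .utils for the exact pattern).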