pytorch/tools/autograd/gen_view_funcs.py
Joel Schlosser d5a6762263 Reify view_func() closures as ViewFuncs (#118404)
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.

```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() {}
  /// Returns any SymInts in the saved state.
  virtual std::vector<c10::SymInt> get_symints() const { return {}; }
  /// Returns the number of SymInts in the saved state.
  virtual size_t num_symints() const { return 0; }
  /// Returns any tensors in the saved state.
  virtual std::vector<at::Tensor> get_tensors() const { return {}; }
  /// Returns the number of tensors in the saved state.
  virtual size_t num_tensors() const { return 0; }
  /// Reapplies the view on the given base using the saved state.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;
  /// Returns a clone of this ViewFunc, optionally with the specified saved state.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;

protected:
  /// Sets the values of any SymInts in the saved state. The input vector size must
  /// match the number of SymInts in the saved state (i.e. the size of the list
  /// returned by get_symints()).
  virtual void set_symints(std::vector<c10::SymInt>) {}
  /// Sets the values of any Tensors in the saved state. The input vector size must
  /// match the number of Tensors in the saved state (i.e. the size of the list
  /// returned by get_tensors()).
  virtual void set_tensors(std::vector<at::Tensor>) {}
};
```
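For intuition, hot-swapping saved state and replaying on a new base might look like this (an illustrative sketch, not code from this PR; `view_fn` and `new_base` are placeholder names):

```cpp
// Pull out the saved SymInts, swap in new values, and replay the view.
std::vector<c10::SymInt> symints = view_fn.get_symints();
// ... replace entries with symbolic values / fake state here ...
auto cloned = view_fn.clone_and_set(std::move(symints), /*tensors=*/c10::nullopt);
at::Tensor replayed = (*cloned)(new_base);
```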

New codegen files:
* `torch/csrc/autograd/generated/ViewFuncs.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`

The templates for these also contain implementations of `ChainedViewFunc` and `ErroringViewFunc`, which are used in a few places within autograd.
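As a rough mental model (a sketch, not the literal implementation; the `first` / `second` member names are assumptions), `ChainedViewFunc` holds two `ViewFunc`s and applies them in sequence:

```cpp
// Conceptual core of ChainedViewFunc: replay the first view on the base,
// then replay the second view on the result.
at::Tensor ChainedViewFunc::operator()(const at::Tensor& base) const {
  return (*second)((*first)(base));
}
```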

Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
  SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
  {};
  virtual ~SliceTensorViewFunc() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  int64_t dim;
  c10::optional<c10::SymInt> start;
  c10::optional<c10::SymInt> end;
  c10::SymInt step;
};
...

// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
  ::std::vector<c10::SymInt> symints;
  symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
  if(start.has_value()) symints.insert(symints.end(), *(start));
  if(end.has_value()) symints.insert(symints.end(), *(end));
  symints.push_back(step);
  return symints;
}

size_t SliceTensorViewFunc::num_symints() const {
  return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}

void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
  TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
  auto i = 0;
  if(start.has_value()) start = symints[i];
  i += (start.has_value() ? 1 : 0);
  if(end.has_value()) end = symints[i];
  i += (end.has_value() ? 1 : 0);
  step = symints[i];
}

std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
  ::std::vector<at::Tensor> tensors;
  return tensors;
}

size_t SliceTensorViewFunc::num_tensors() const {
  return static_cast<size_t>(0);
}

void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
  TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());

}

at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
  return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}

std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
  if (symints.has_value()) {
    output->set_symints(std::move(*(symints)));
  }
  if (tensors.has_value()) {
    output->set_tensors(std::move(*(tensors)));
  }
  return output;
}
```
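Applying the generated struct replays the view on whatever base it is given. Illustratively (`new_base` is a placeholder):

```cpp
// Equivalent to new_base.slice(/*dim=*/0, /*start=*/1, /*end=*/3, /*step=*/1),
// since operator() dispatches back to at::_ops::slice_Tensor::call.
SliceTensorViewFunc view_fn(/*dim=*/0, /*start=*/1, /*end=*/3, /*step=*/1);
at::Tensor new_view = view_fn(new_base);
```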

The `_view_func()` / `_view_func_unsafe()` methods now accept two additional optional args, `symint_visitor_fn` / `tensor_visitor_fn`. If provided, each is expected to be a Python callable that operates on a single SymInt / tensor and returns a new one. This allows for the hot-swapping needed during fake-ification.
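For example, replaying a view with hot-swapped SymInts might look like this (a sketch based on the description above; the shapes and visitors are made up):

```python
import torch

base = torch.randn(2, 6)
view = base.narrow(1, 1, 2)  # saved SymInt state: start=1, length=2

# Replay the view on a new base; the visitors may replace each saved
# SymInt / tensor. Here every saved SymInt is doubled, so the replayed
# view is narrow(1, 2, 4). narrow saves no tensors, so the tensor
# visitor is never invoked.
new_base = torch.randn(2, 6)
replayed = view._view_func(new_base, lambda s: s * 2, lambda t: t)
assert replayed.shape == (2, 4)
```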

For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping works correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
2024-02-09 18:51:36 +00:00


# Generates ViewFuncs.h/cpp
#
# NOTE: If any changes are being made to the ViewFunc codegen please also check
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync.

from typing import List, Tuple

import torchgen.api.dispatcher as dispatcher
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    NamedCType,
    SymIntT,
    tensorT,
    VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, NativeFunction, OptionalType
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import (
    CALL_DISPATCH,
    extract_bindings,
    get_view_info,
    modifies_arguments,
    use_derived,
)

FUNCTION_DECLARATION = CodeTemplate(
    """\
#define ${uppercase_op}_AVAILABLE
struct ${op} : public ${superclass} {
  ${op}(${constructor_args}) ${initializer_list}
  {};
  virtual ~${op}() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  ${state}
};

"""
)

FUNCTION_DEFINITION = CodeTemplate(
    """\
std::vector<c10::SymInt> ${op}::get_symints() const {
  ${get_symints}
}

size_t ${op}::num_symints() const {
  return static_cast<size_t>(${num_symints});
}

void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) {
  TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints());
  ${set_symints}
}

std::vector<at::Tensor> ${op}::get_tensors() const {
  ${get_tensors}
}

size_t ${op}::num_tensors() const {
  return static_cast<size_t>(${num_tensors});
}

void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) {
  TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors());
  ${set_tensors}
}

at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const {
  return ${op_call};
}

std::unique_ptr<ViewFunc> ${op}::clone_and_set(
    std::optional<std::vector<c10::SymInt>> ${symints_vec},
    std::optional<std::vector<at::Tensor>> ${tensors_vec}) const {
  auto output = std::make_unique<${op}>(${clone_args});
  if (${symints_vec}.has_value()) {
    output->set_symints(std::move(*(${symints_vec})));
  }
  if (${tensors_vec}.has_value()) {
    output->set_tensors(std::move(*(${tensors_vec})));
  }
  return output;
}

"""
)


# e.g. as_strided -> AsStridedViewFunc for camel case or
# as_strided_view_func otherwise
def view_func_name(
    f: NativeFunction, include_namespace: bool = False, camel_case: bool = True
) -> str:
    name = f.func.name.unambiguous_name()
    view_func_name = f"{name.replace('.', '_')}_view_func"
    if camel_case:
        is_private = view_func_name.startswith("_")
        view_func_name = "".join(
            [p.title() for p in view_func_name.replace(".", "_").split("_")]
        )
        if is_private:
            # put the leading underscore back in
            view_func_name = f"_{view_func_name}"
    namespace = "torch::autograd::generated::" if include_namespace else ""
    return f"{namespace}{view_func_name}"


def is_symint_or_tensor(arg: Argument) -> bool:
    return arg.type.is_tensor_like() or arg.type.is_symint_like()


def remove_const_ref(binding: Binding) -> Binding:
    return Binding(
        name=binding.name,
        nctype=binding.nctype.remove_const_ref(),
        argument=binding.argument,
        default=binding.default,
    )
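
# Multi-output views (e.g. unbind) return a list of tensor views; the generated
# ViewFunc for such an op additionally stores a view_idx identifying which output
# the view corresponds to (see process_function below).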
def returns_multi_tensor(fn: NativeFunction) -> bool:
    returns = fn.func.returns
    assert len(returns) == 1
    returns_list_like = returns[0].type.is_list_like() is not None
    returns_tensor_like = returns[0].type.is_tensor_like()
    return returns_list_like and returns_tensor_like


# Generates strings with logic for getting / setting state of a particular type.
#
# Args:
#     bindings (list): List of state bindings of interest (may be empty)
#     state_vec_type (NamedCType): Type of vector to either return or copy from
#
# Returns:
#     tuple: (list of getter logic strings, list of setter logic strings, string
#     with num items expression)
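#
# For example, an optional SymInt binding `start` plus a plain SymInt `step`
# produce getter lines like those in the slice.Tensor codegen above:
#     if(start.has_value()) symints.insert(symints.end(), *(start));
#     symints.push_back(step);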
def generate_state_getter_setter(
    bindings: List[Binding],
    state_vec_type: NamedCType,
) -> Tuple[List[str], List[str], str]:
    getter_logic = []
    setter_logic = []

    state_vec = state_vec_type.name
    getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};")
    if len(bindings) > 0:
        setter_logic.append("auto i = 0;")

    num_exprs = []
    for i, b in enumerate(bindings):
        assert isinstance(b.argument, Argument)
        if b.argument.type.is_list_like():
            # Handle list-likes.
            num_expr = f"{b.name}.size()"
            num_exprs.append(num_expr)
            getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());"
            setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());"
        elif isinstance(b.argument.type, OptionalType):
            # Handle optionals.
            num_expr = f"({b.name}.has_value() ? 1 : 0)"
            num_exprs.append(num_expr)
            conditional = f"if({b.name}.has_value())"
            getter = (
                f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));"
            )
            setter = f"{conditional} {b.name} = {state_vec}[i];"
        else:
            num_expr = "1"
            num_exprs.append(num_expr)
            getter = f"{state_vec}.push_back({b.name});"
            setter = f"{b.name} = {state_vec}[i];"

        getter_logic.append(getter)
        setter_logic.append(setter)
        if i < len(bindings) - 1:
            setter_logic.append(f"i += {num_expr};")

    # Reserve / assert based on the total number of items expression.
    num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs)
    if len(bindings) > 0:
        getter_logic.insert(1, f"{state_vec}.reserve({num_items});")

    getter_logic.append(f"return {state_vec};")

    return getter_logic, setter_logic, num_items


def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
    bindings = extract_bindings(fn)
    non_self_bindings = [b for b in bindings if b.name != "self"]

    non_self_args = fn.func.arguments.flat_all[1:]
    non_self_value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
    ]

    # Generate constructor / clone args for the generated struct.
    constructor_args = [b.defn() for b in non_self_bindings]
    clone_args = [b.name for b in non_self_bindings]

    # Generate state variable declarations for the generated struct.
    state_variables = [
        f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings
    ]

    # Generate initializer list expressions for the generated struct.
    # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as
    # vector<SymInt>s.
    init_exprs = translate(
        non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True
    )
    initializers = []
    for b, init_expr in zip(non_self_bindings, init_exprs):
        name = b.nctype.name
        assert isinstance(name, str)
        initializers.append(f"{name}({init_expr.expr})")

    # Generate call to underlying view op
    call_input_name = "input_base"
    op_call_args = [call_input_name, *(b.name for b in non_self_bindings)]
    op_call = CALL_DISPATCH.substitute(
        unambiguous_name=fn.func.name.unambiguous_name(),
        unpacked_args=op_call_args,
    )

    # Multi-output views additionally require a view_idx for disambiguation.
    if returns_multi_tensor(fn):
        view_idx_name = "view_idx"
        view_idx_typename = "int64_t"
        view_idx_decl = f"{view_idx_typename} {view_idx_name}"
        constructor_args.append(view_idx_decl)
        clone_args.append(view_idx_name)
        state_variables.append(f"{view_idx_decl};")
        initializers.append(f"{view_idx_name}({view_idx_name})")
        op_call += f"[{view_idx_name}]"

    # Generate initializer list for the generated struct.
    initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else ""

    # Generate getter / setter logic for any symints.
    symint_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_symint_like()
    ]
    symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT)))
    get_symints, set_symints, num_symints = generate_state_getter_setter(
        symint_bindings, symints_vec_type
    )

    # Generate getter / setter logic for any tensors.
    tensor_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like()
    ]
    tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT)))
    get_tensors, set_tensors, num_tensors = generate_state_getter_setter(
        tensor_bindings, tensors_vec_type
    )

    return template.substitute(
        op=view_func_name(fn),
        uppercase_op=view_func_name(fn, camel_case=False).upper(),
        superclass="torch::autograd::ViewFunc",
        initializer_list=initializer_list,
        state=state_variables,
        constructor_args=constructor_args,
        clone_args=clone_args,
        symints_vec=symints_vec_type.name,
        get_symints=get_symints,
        set_symints=set_symints,
        num_symints=num_symints,
        tensors_vec=tensors_vec_type.name,
        get_tensors=get_tensors,
        set_tensors=set_tensors,
        num_tensors=num_tensors,
        call_input_name=call_input_name,
        op_call=op_call,
    )


def gen_view_funcs(
    out: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    # don't need the info parts, just the function
    fns = [fn.func for fn in fns_with_infos if use_derived(fn)]
    # only want out-of-place views
    view_fns = [
        fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn)
    ]

    declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns]
    definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns]
    ops_headers = [f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in view_fns]

    file_basename = "ViewFuncs"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "view_func_declarations": declarations,
                "view_func_definitions": definitions,
                "ops_headers": ops_headers,
            },
        )