Revert "functionalization: add support for zero_()"

This reverts commit 7d44b3675b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76375

Approved by: https://github.com/datumbox, https://github.com/albanD
Brian Hirsh 2022-04-26 09:57:46 -07:00 committed by PyTorch MergeBot
parent 368430036e
commit 40d96f0afd
11 changed files with 22 additions and 61 deletions


@@ -132,12 +132,6 @@ Tensor& zero_cpu_(Tensor &self, int64_t nelements) {
return self;
}
// Needed for functionalization: we need to convert zero_() directly to zero()
// when removing mutations.
Tensor zero(const Tensor& self) {
return at::zeros_like(self);
}
Tensor& zero_(Tensor &self) {
int64_t nelements = c10::multiply_integers(self.sizes());
if (self.device() == at::kCPU &&
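To make the intent of the removed zero() helper concrete, here is a small hand-written Python sketch (an illustration only, not the functionalization pass itself) of the rewrite it enabled: the in-place zero_() call is replaced by an out-of-place zeros_like(), so the rewritten program contains no mutations.

import torch

def with_mutation(x):
    y = x + x
    y.zero_()                 # in-place mutation that functionalization must remove
    return y

def functionalized(x):
    y = x + x
    y = torch.zeros_like(y)   # zero_ -> zero -> zeros_like, no mutation left
    return y

x = torch.ones(2, 2)
assert torch.equal(with_mutation(x.clone()), functionalized(x.clone()))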


@@ -4922,13 +4922,6 @@
- func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
dispatch:
# Why give zeros_like() a derivative formula?
# For the functionalization pass.
# zeros_like() runs *underneath* functionalization (we transform zero_ -> zero -> zeros_like),
# but the decomposition for zeros_like() calls into more inplace ops!
# Instead, we want zeros_like() to be a primitive w.r.t. tracing (which we can do by adding an autograd formula).
CompositeExplicitAutograd: zeros_like
- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
variants: function
@@ -5148,9 +5141,6 @@
SparseCPU, SparseCUDA: resize_as_sparse_
SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_
- func: zero(Tensor self) -> Tensor
variants: function
- func: zero_(Tensor(a!) self) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method, function
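The comment above about giving zeros_like() a derivative formula can be illustrated with a rough Python sketch (the real composite kernel is C++; this only shows the assumed shape of its decomposition): if zeros_like() were decomposed during tracing, an in-place op would reappear, which is exactly what functionalization is trying to remove.

import torch

def zeros_like_decomposed(x):
    out = torch.empty_like(x)
    out.zero_()               # the decomposition reintroduces an in-place op
    return out

x = torch.ones(2, 2)
assert torch.equal(zeros_like_decomposed(x), torch.zeros_like(x))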


@@ -116,6 +116,7 @@ ALLOW_LIST = [
("prim::infer_squeeze_size.dim", datetime.date(9999, 1, 1)),
("prim::infer_squeeze_size", datetime.date(9999, 1, 1)),
("aten::_cat", datetime.date(2022, 5, 15)),
("aten::zero", datetime.date(2022, 5, 15)),
]
ALLOW_LIST_COMPILED = [


@@ -446,23 +446,6 @@ $1 = torch._ops.aten._to_copy.default($0, dtype=6, layout=0, device=device(type=
$2 = torch._ops.aten.expand_copy.default($1, [2])
$3 = torch._ops.aten.add.Tensor($2, $0)""")
# zero_ gets its own test because of the newly added at::zero operator.
def test_zero_(self):
def f(x):
y = x + x
z = y.diagonal()
z.zero_()
return y
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
# zero() should decompose into zeros_like(), which will show up in the trace
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, $0)
$2 = torch._ops.aten.diagonal_copy.default($1)
$3 = torch._ops.aten.zeros_like.default($2)""")
def test_fill_(self):
def f(x):
y = x + x


@@ -1694,10 +1694,6 @@
self: zeros_like(grad)
result: auto_linear
- name: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
self: zeros_like(grad)
result: auto_linear
- name: sparse_mask(Tensor self, Tensor mask) -> Tensor
self: grad.to_dense().sparse_mask(mask).to_dense()
mask: non_differentiable
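As a hedged illustration of what the removed derivatives.yaml entry expressed, the sketch below mirrors "self: zeros_like(grad)" with a custom autograd function (MyZerosLike is a made-up name used only here): the gradient flowing back to self is identically zero.

import torch

class MyZerosLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        return torch.zeros_like(inp)

    @staticmethod
    def backward(ctx, grad):
        return torch.zeros_like(grad)   # mirrors "self: zeros_like(grad)"

x = torch.ones(3, requires_grad=True)
MyZerosLike.apply(x).sum().backward()
assert torch.equal(x.grad, torch.zeros(3))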


@@ -649,14 +649,11 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
copy_ranges: List[str] = []
for i, n in enumerate(var_names):
copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
return (
False,
DERIVATIVE_MULTI.substitute(
idx_ranges=idx_ranges,
copy_ranges=copy_ranges,
derivative=formula,
grad_input_mask=grad_input_mask,
),
return False, DERIVATIVE_MULTI.substitute(
idx_ranges=idx_ranges,
copy_ranges=copy_ranges,
derivative=formula,
grad_input_mask=grad_input_mask,
)
body.extend(unpack)


@@ -151,7 +151,6 @@ _SKIP_PYTHON_BINDINGS = [
"_has_same_storage_numel", # used for forward AD internals
"_reshape_alias",
"replace_", # only used by the functionalization pass, doesn't need to be exposed to python
"zero", # only used by the functionalization pass, doesn't need to be exposed to python
"copy", # only used by the functionalization pass
"fill.Tensor", # only used by the functionalization pass
"fill.Scalar", # only used by the functionalization pass


@@ -97,6 +97,7 @@ from typing import Callable, List, Optional, Sequence, Tuple, Union, Dict
DONT_REQUIRE_DERIVATIVE = {
# These only depend on the input Tensor's shape and device, not the data
"ones_like",
"zeros_like",
"rand_like",
"randn_like",
# These are only implemented on integral types
@@ -170,7 +171,6 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"triu",
"chunk",
"zero_",
"zeros_like",
"eq_",
"ne_",
"add",


@@ -15,6 +15,7 @@ from torchgen.api.autograd import (
)
from torchgen.api.types import (
Binding,
CppSignatureGroup,
NamedCType,
BaseCType,
VectorCType,
@@ -31,7 +32,6 @@ from torchgen.api.types import (
stringT,
)
from torchgen.api import cpp
from torchgen.api import dispatcher
from torchgen.gen import parse_native_yaml, get_grouped_by_view_native_functions
from torchgen.context import with_native_function
from torchgen.model import (
@@ -130,8 +130,8 @@ def load_derivatives(
@with_native_function
def dispatcher_arguments(f: NativeFunction) -> Sequence[Binding]:
return dispatcher.arguments(f.func)
def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
return CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
def create_derivative(
@@ -142,7 +142,7 @@
) -> Derivative:
original_formula = formula
arguments: List[NamedCType] = [
a.nctype.remove_const_ref() for a in dispatcher_arguments(f)
a.nctype.remove_const_ref() for a in cpp_arguments(f)
]
return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
@@ -470,7 +470,7 @@ def create_differentiability_info(
non_differentiable_arg_names: List[str] = []
args_with_derivatives_set: Set[str] = set()
all_arg_names = [a.name for a in dispatcher_arguments(f)]
all_arg_names = [a.name for a in cpp_arguments(f)]
all_ret_names = [
r.name for r in f.func.returns
] # only used for the assert below
@@ -529,7 +529,7 @@
# TODO: do we need to eagerly calculate and save it here? Can it be derived
# from NativeFunction and `derivatives` on callsites instead?
args_with_derivatives = [
a for a in dispatcher_arguments(f) if a.name in args_with_derivatives_set
a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
]
# Postprocess forward derivatives definitions now that we know the differentiable arguments
@@ -603,14 +603,14 @@
)
canonical = canonical_function(functions, defn_name)
if "grad_input_mask" in (a.name for a in dispatcher_arguments(canonical)):
if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
raise RuntimeError(
f"Schema for {defn_name} has an argument named grad_input_mask, "
"but this name would be shadowed by our codegen. "
"Please use a different name in native_functions.yaml."
)
if "result" in (a.name for a in dispatcher_arguments(canonical)):
if "result" in (a.name for a in cpp_arguments(canonical)):
raise RuntimeError(
f"Schema for {defn_name} has an argument named result, "
"but this is only allowed for outputs."


@@ -15041,6 +15041,10 @@ op_db: List[OpInfo] = [
# JIT has issue when op is passed as lambda
# AssertionError: JIT Test does not execute any logic
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
# Fails due to a limitation of gradgradcheck
# https://github.com/pytorch/pytorch/issues/59137
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_gradgrad'),
DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),


@@ -161,12 +161,9 @@ def argumenttype_ivalue_convert(
def _gen_code_base_type(
arg_name: str, out_name: str, ctype: CType
) -> Tuple[List[str], List[str]]:
return (
[
f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
],
[],
)
return [
f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], []
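For readers unfamiliar with this codegen helper, the standalone sketch below (the torchgen CType is mocked and the argument names are hypothetical) shows the kind of C++ conversion line _gen_code_base_type returns, along with its empty cleanup list.

class FakeCType:
    # stands in for a torchgen CType whose C++ spelling is int64_t
    def cpp_type(self, strip_ref=True):
        return "int64_t"

def gen_code_base_type(arg_name, out_name, ctype):
    return [
        f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
    ], []

print(gen_code_base_type("stack_arg", "dim", FakeCType())[0][0])
# -> int64_t dim = stack_arg.to<int64_t>();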
def _gen_code_optional_type(