Revert "functionalization: add support for zero_()"
This reverts commit 7d44b3675b.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76375
Approved by: https://github.com/datumbox, https://github.com/albanD
This commit is contained in:
parent
368430036e
commit
40d96f0afd
|
|
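For context: the reverted commit taught the functionalization pass to rewrite the in-place zero_() into an out-of-place zero() implemented as zeros_like(). A minimal sketch of that value-level equivalence in eager PyTorch (illustration only, not part of this commit):

import torch

t = torch.arange(6.0).reshape(2, 3)

# In-place: mutates the tensor's storage and returns the same tensor object.
a = t.clone()
a.zero_()

# Out-of-place: allocates a fresh zero tensor with t's shape/dtype/device.
b = torch.zeros_like(t)

print(torch.equal(a, b))  # True: same values, different aliasing/mutation behavior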
@@ -132,12 +132,6 @@ Tensor& zero_cpu_(Tensor &self, int64_t nelements) {
   return self;
 }
 
-// Needed for functionalization: we need to convert zero_() directly to zero()
-// when removing mutations.
-Tensor zero(const Tensor& self) {
-  return at::zeros_like(self);
-}
-
 Tensor& zero_(Tensor &self) {
   int64_t nelements = c10::multiply_integers(self.sizes());
   if (self.device() == at::kCPU &&

@@ -4922,13 +4922,6 @@
 - func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
 
 - func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
-  dispatch:
-    # Why give zeros_like() a derivative formula?
-    # For the functionalization pass.
-    # zeros_like() runs *underneath* functionalization (we transform zero_ -> zero -> zeros_like),
-    # But the decomposite for zeros_like() calls into more inplace ops!
-    # Instead, we want zeros_like() to be a primitive w.r.t tracing (which we can do by adding an autograd formula).
-    CompositeExplicitAutograd: zeros_like
 
 - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
   variants: function
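The removed YAML comment captures the motivation: by default zeros_like() decomposes into ops that include in-place mutation, so the original commit made it a tracing primitive with its own dispatch entry and derivative formula. A hypothetical way to inspect how zeros_like() is recorded by a tracer (illustration only; the exact graph contents depend on the PyTorch build and the decompositions in effect):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # zeros_like uses only x's metadata (shape/dtype/device), not its values
    return torch.zeros_like(x) + x

gm = make_fx(f)(torch.ones(2))
print(gm.graph)  # shows whether zeros_like appears as a single aten node or as a decomposition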
@@ -5148,9 +5141,6 @@
     SparseCPU, SparseCUDA: resize_as_sparse_
     SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_
 
-- func: zero(Tensor self) -> Tensor
-  variants: function
-
 - func: zero_(Tensor(a!) self) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   variants: method, function

@@ -116,6 +116,7 @@ ALLOW_LIST = [
     ("prim::infer_squeeze_size.dim", datetime.date(9999, 1, 1)),
     ("prim::infer_squeeze_size", datetime.date(9999, 1, 1)),
     ("aten::_cat", datetime.date(2022, 5, 15)),
+    ("aten::zero", datetime.date(2022, 5, 15)),
 ]
 
 ALLOW_LIST_COMPILED = [

@@ -446,23 +446,6 @@ $1 = torch._ops.aten._to_copy.default($0, dtype=6, layout=0, device=device(type=
 $2 = torch._ops.aten.expand_copy.default($1, [2])
 $3 = torch._ops.aten.add.Tensor($2, $0)""")
 
-    # zero_ gets its own test because of the newly added at::zero operator.
-    def test_zero_(self):
-        def f(x):
-            y = x + x
-            z = y.diagonal()
-            z.zero_()
-            return y
-
-        self.assert_functionalization(f, torch.ones(2, 2))
-        logs = self.get_logs(f, torch.ones(2, 2))
-        # zero() should decompose into zeros_like(), which will show up in the trace
-        self.assertExpectedInline('\n'.join(logs), """\
-$0 = input('input')
-$1 = torch._ops.aten.add.Tensor($0, $0)
-$2 = torch._ops.aten.diagonal_copy.default($1)
-$3 = torch._ops.aten.zeros_like.default($2)""")
-
     def test_fill_(self):
         def f(x):
             y = x + x
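The deleted test checked the end-to-end behavior: functionalization replaces the in-place zero_() on a view with out-of-place ops while preserving the eager result. A minimal sketch of the same idea (torch.func.functionalize is used here as an illustration and is an assumption about the available API; the PR's test relied on the suite's internal assert_functionalization/get_logs helpers):

import torch
from torch.func import functionalize  # available in recent PyTorch releases

def f(x):
    y = x + x
    z = y.diagonal()
    z.zero_()          # in-place mutation through a view of y
    return y

x = torch.ones(2, 2)
eager_out = f(x.clone())
functional_out = functionalize(f)(x.clone())  # mutation-free version of f
print(torch.equal(eager_out, functional_out))  # expected: True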
@@ -1694,10 +1694,6 @@
   self: zeros_like(grad)
   result: auto_linear
 
-- name: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
-  self: zeros_like(grad)
-  result: auto_linear
-
 - name: sparse_mask(Tensor self, Tensor mask) -> Tensor
   self: grad.to_dense().sparse_mask(mask).to_dense()
   mask: non_differentiable

@@ -649,14 +649,11 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
         copy_ranges: List[str] = []
         for i, n in enumerate(var_names):
             copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
-        return (
-            False,
-            DERIVATIVE_MULTI.substitute(
-                idx_ranges=idx_ranges,
-                copy_ranges=copy_ranges,
-                derivative=formula,
-                grad_input_mask=grad_input_mask,
-            ),
+        return False, DERIVATIVE_MULTI.substitute(
+            idx_ranges=idx_ranges,
+            copy_ranges=copy_ranges,
+            derivative=formula,
+            grad_input_mask=grad_input_mask,
         )
 
     body.extend(unpack)

@@ -151,7 +151,6 @@ _SKIP_PYTHON_BINDINGS = [
     "_has_same_storage_numel",  # used for forward AD internals
     "_reshape_alias",
     "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
-    "zero",  # only used by the functionalization pass, doesn't need to be exposed to python
     "copy",  # only used by the functionalization pass
     "fill.Tensor",  # only used by the functionalization pass
     "fill.Scalar",  # only used by the functionalization pass

@@ -97,6 +97,7 @@ from typing import Callable, List, Optional, Sequence, Tuple, Union, Dict
 DONT_REQUIRE_DERIVATIVE = {
     # These only depend on the input Tensor's shape and device, not the data
     "ones_like",
+    "zeros_like",
     "rand_like",
     "randn_like",
     # These are only implemented on integral types

@@ -170,7 +171,6 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     "triu",
     "chunk",
     "zero_",
-    "zeros_like",
     "eq_",
     "ne_",
     "add",
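With the revert, zeros_like goes back into DONT_REQUIRE_DERIVATIVE: its result depends only on the input's shape, dtype, and device, so autograd does not need to track it. A quick eager-mode illustration (not part of the commit):

import torch

x = torch.ones(3, requires_grad=True)
y = torch.zeros_like(x)  # value does not depend on x's data
print(y.requires_grad)   # False: no gradient needs to flow back to x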
@@ -15,6 +15,7 @@ from torchgen.api.autograd import (
 )
 from torchgen.api.types import (
     Binding,
+    CppSignatureGroup,
     NamedCType,
     BaseCType,
     VectorCType,

@@ -31,7 +32,6 @@ from torchgen.api.types import (
     stringT,
 )
 from torchgen.api import cpp
-from torchgen.api import dispatcher
 from torchgen.gen import parse_native_yaml, get_grouped_by_view_native_functions
 from torchgen.context import with_native_function
 from torchgen.model import (

@@ -130,8 +130,8 @@ def load_derivatives(
 
 
 @with_native_function
-def dispatcher_arguments(f: NativeFunction) -> Sequence[Binding]:
-    return dispatcher.arguments(f.func)
+def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
+    return CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
 
 
 def create_derivative(

@@ -142,7 +142,7 @@ def create_derivative(
 ) -> Derivative:
     original_formula = formula
     arguments: List[NamedCType] = [
-        a.nctype.remove_const_ref() for a in dispatcher_arguments(f)
+        a.nctype.remove_const_ref() for a in cpp_arguments(f)
     ]
 
     return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))

@@ -470,7 +470,7 @@ def create_differentiability_info(
     non_differentiable_arg_names: List[str] = []
     args_with_derivatives_set: Set[str] = set()
 
-    all_arg_names = [a.name for a in dispatcher_arguments(f)]
+    all_arg_names = [a.name for a in cpp_arguments(f)]
     all_ret_names = [
         r.name for r in f.func.returns
     ]  # only used for the assert below

@@ -529,7 +529,7 @@ def create_differentiability_info(
     # TODO: do we need eagerly calculate and save it here? Can it be derived
     # from NativeFunction and `derivatives` on callsites instead?
     args_with_derivatives = [
-        a for a in dispatcher_arguments(f) if a.name in args_with_derivatives_set
+        a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
     ]
 
     # Postprocess forward derivatives definitions now that we know the differentiable arguments

@@ -603,14 +603,14 @@ def create_differentiability_info(
     )
 
     canonical = canonical_function(functions, defn_name)
-    if "grad_input_mask" in (a.name for a in dispatcher_arguments(canonical)):
+    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
         raise RuntimeError(
             f"Schema for {defn_name} has an argument named grad_input_mask, "
             "but this name would be shadowed by our codegen. "
             "Please use a different name in native_functions.yaml."
         )
 
-    if "result" in (a.name for a in dispatcher_arguments(canonical)):
+    if "result" in (a.name for a in cpp_arguments(canonical)):
         raise RuntimeError(
             f"Schema for {defn_name} has an argument named result, "
             "but this is only allowed for outputs."

@@ -15041,6 +15041,10 @@ op_db: List[OpInfo] = [
         # JIT has issue when op is passed as lambda
         # AssertionError: JIT Test does not execute any logic
         DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+        # Fails due to a limitation of gradgradcheck
+        # https://github.com/pytorch/pytorch/issues/59137
+        DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
+        DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_gradgrad'),
         DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
         DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_backward'),
         DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
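The restored expectedFailure entries point at a known gradgradcheck limitation (issue 59137). For reference, a small self-contained illustration of what gradgradcheck verifies, shown on an unrelated op where it does work (not part of the commit):

import torch
from torch.autograd import gradgradcheck

x = torch.randn(3, dtype=torch.double, requires_grad=True)
# Numerically checks second-order gradients (gradients of gradients) against autograd.
print(gradgradcheck(torch.sin, (x,)))  # True when analytic and numeric results agree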
@@ -161,12 +161,9 @@ def argumenttype_ivalue_convert(
 def _gen_code_base_type(
     arg_name: str, out_name: str, ctype: CType
 ) -> Tuple[List[str], List[str]]:
-    return (
-        [
-            f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
-        ],
-        [],
-    )
+    return [
+        f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
+    ], []
 
 
 def _gen_code_optional_type(