[Inductor] Added aten.normal_ decomp (#91207)
Fixes #91085

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91207
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/lezcano

This commit is contained in:
parent 092e28f17f
commit b6df987671
@@ -396,6 +396,7 @@ dtensor_fails = {
     xfail("norm", "nuc"),
     xfail("normal"),
     xfail("normal", "number_mean"),
+    xfail("normal", "in_place"),
     xfail("ormqr"),
     xfail("ones"),
     xfail("pca_lowrank"),
@@ -957,16 +957,6 @@ aten::nll_loss2d_forward
 aten::nll_loss2d_forward.output
 aten::nonzero
 aten::nonzero.out
-aten::normal.Tensor_Tensor
-aten::normal.Tensor_Tensor_out
-aten::normal.Tensor_float
-aten::normal.Tensor_float_out
-aten::normal.float_Tensor
-aten::normal.float_Tensor_out
-aten::normal.float_float
-aten::normal.float_float_out
-aten::normal.out
-aten::normal_
 aten::normal_functional
 aten::ones.names
 aten::ones.names_out
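The ten aten::normal.* overloads drop out of this list of ops without decompositions now that the new ref covers them; aten::normal_functional stays behind the TODO in the decomposition further down. For orientation, a quick way to enumerate the overloads of the packet (a sanity-check sketch, not part of the PR):

    import torch

    # OpOverloadPacket.overloads() lists the overload names, e.g.
    # 'Tensor_Tensor', 'Tensor_float', 'float_Tensor', 'float_float', ...
    print(torch.ops.aten.normal.overloads())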
@@ -2345,6 +2345,7 @@ aot_autograd_failures = {
     xfail('cov'),
     xfail('chalf'),  # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
     xfail('sparse.sampled_addmm'),
+    xfail('normal', 'number_mean'),  # TypeError: randn_like(): argument 'input' (position 1) must be Tensor, not float
     xfail('sparse.mm', 'reduce'),
     skip('nn.functional.binary_cross_entropy_with_logits'),  # seems to fail sometimes?
     skip('nn.functional.margin_ranking_loss'),  # seems flaky
@@ -2491,7 +2492,6 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.smooth_l1_loss', ''),  # could not find kernel
     xfail('nn.functional.unfold', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('norm', 'nuc'),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('normal', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('normal', 'number_mean'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('ormqr', ''),  # aten.ormqr.default - couldn't find symbolic meta function/decomposition
     xfail('pca_lowrank', ''),  # could not find kernel
@@ -262,6 +262,7 @@ inductor_expected_failures_single_sample["cpu"] = {
     "exponential": {f16},
     "geometric": {f16},
     "log_normal": {f16},
+    "normal.in_place": {f16, f32, f64},
     "uniform": {f16},
     "unique": {b8, f32, f64, i32, i64},
     "unique_consecutive": {b8, f32, f64, i32, i64},
@@ -336,6 +337,7 @@ inductor_expected_failures_single_sample["cuda"] = {
     "cauchy": {f16, f32, f64},
     "exponential": {f16, f32, f64},
     "geometric": {f16, f32, f64, i32, i64},
+    "normal.in_place": {f16, f32, f64},
     "log_normal": {f16, f32, f64},
     "uniform": {f16, f32, f64},
     "unique": {b8, f16, f32, f64, i32, i64},
@@ -1349,7 +1349,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
     xfail('nn.functional.smooth_l1_loss', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nonzero', ''),  # aten.nonzero.default - couldn't find symbolic meta function/decomposition
-    xfail('normal', ''),  # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition
     xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
     xfail('ormqr', ''),  # aten.ormqr.default - couldn't find symbolic meta function/decomposition
     xfail('pca_lowrank', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition
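Entries in symbolic_tensor_failures are exercised by tracing each op with symbolic shapes; with the decomposition registered, aten.normal.Tensor_Tensor no longer needs its own symbolic meta function, so the plain normal xfail is dropped. A hedged repro sketch (the make_fx call and get_decompositions usage are illustrative, not from the PR):

    import torch
    from torch._decomp import get_decompositions
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(mean, std):
        return torch.normal(mean, std)  # hits aten.normal.Tensor_Tensor

    # Tracing with the decomposition table replaces aten.normal with a
    # randn_like-based graph; before this PR, symbolic tracing raised
    # "couldn't find symbolic meta function/decomposition".
    gm = make_fx(f, decomposition_table=get_decompositions([torch.ops.aten.normal]),
                 tracing_mode="symbolic")(torch.zeros(3), torch.ones(3))
    print(gm.graph)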
@@ -387,6 +387,9 @@ extra_random_decomps = get_decompositions(
         aten.exponential_,
         aten.geometric,
         aten.geometric_,
+        aten.normal,
+        aten.normal_,
+        aten.normal_functional,
         aten.log_normal,
         aten.log_normal_,
         aten.uniform_,
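With aten.normal, aten.normal_, and aten.normal_functional in extra_random_decomps, Inductor can lower Tensor.normal_ through the decomposition instead of falling back to eager. A hedged sketch of exercising it (assumes a PyTorch 2.x build where torch.compile is available; the function name fill_normal is illustrative):

    import torch

    @torch.compile
    def fill_normal(x):
        return x.normal_(mean=0.0, std=2.0)  # lowered via the aten.normal_ decomp

    out = fill_normal(torch.empty(1024))
    print(out.mean().item(), out.std().item())  # roughly 0.0 and 2.0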
@@ -5330,6 +5330,19 @@ def log_normal(self, mean=1, std=2, generator=None):
     return torch.exp(std * torch.randn_like(self) + mean)
 
 
+# TODO: add support for functionalization aten.normal_functional
+@register_decomposition(aten.normal)
+@out_wrapper()
+@elementwise_type_promotion_wrapper(
+    type_promoting_args=("self",),
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+)
+def normal(self, mean=0, std=1, generator=None):
+    assert generator is None
+    utils.check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
+    return std * torch.randn_like(self) + mean
+
+
 # inplace
 abs_ = _make_inplace(abs)
 acos_ = _make_inplace(acos)
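The decomposition rewrites a normal draw as an affine transform of a standard normal sample, std * torch.randn_like(self) + mean, which is exact in distribution. A small numeric check of that identity (illustrative only; the element count and seed are arbitrary):

    import torch

    torch.manual_seed(0)
    t = torch.empty(100_000)
    sample = 0.5 * torch.randn_like(t) + 3.0  # what the decomposition computes
    print(sample.mean().item())  # close to 3.0
    print(sample.std().item())   # close to 0.5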
@@ -5421,6 +5434,7 @@ xlogy_ = _make_inplace(xlogy)
 cauchy_ = _make_inplace(cauchy)
 exponential_ = _make_inplace(exponential)
 geometric_ = _make_inplace(geometric)
+normal_ = _make_inplace(normal)
 log_normal_ = _make_inplace(log_normal)
 zero_ = _make_inplace(zero)
 
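_make_inplace is a private _refs helper that derives each in-place variant from its out-of-place ref. A simplified sketch of that pattern (an assumption for illustration; make_inplace_sketch is not torch._refs' actual implementation):

    def make_inplace_sketch(outofplace_fn):
        # Compute with the out-of-place ref, then copy the result into self,
        # which is how an in-place op like normal_ can reuse normal's decomp.
        def inplace_fn(self, *args, **kwargs):
            return self.copy_(outofplace_fn(self, *args, **kwargs))
        return inplace_fn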
@@ -796,6 +796,24 @@ def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs):
     for shape in shapes:
         yield SampleInput(input=shape, kwargs=dict(dtype=dtype, device=device, requires_grad=requires_grad))
 
+def sample_inputs_normal(op, device, dtype, requires_grad, **kwargs):
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
+    samples = (
+        ((S, S), 0, 5),
+        ((S, S, S), -2, 0.5),
+    )
+    for shape, mean, std in samples:
+        yield SampleInput(make_arg(shape), args=(mean, std))
+
+def error_inputs_normal(op, device, **kwargs):
+    t = torch.zeros([10], device=device)
+    invalid_std = -1
+    yield ErrorInput(
+        SampleInput(t, args=(0, invalid_std)),
+        error_type=RuntimeError,
+        error_regex=r"normal expects std >= 0.0, but found std {}".format(invalid_std),
+    )
+
 def sample_inputs_cauchy(op, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
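error_inputs_normal pins the negative-std error message that the decomposition's utils.check raises; eager Tensor.normal_ produces the same RuntimeError. A quick check (illustrative, matching the ErrorInput above):

    import torch

    try:
        torch.zeros([10]).normal_(0, -1)
    except RuntimeError as e:
        print(e)  # normal expects std >= 0.0, but found std -1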
@@ -9068,6 +9086,36 @@ op_db: List[OpInfo] = [
            DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
            DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
        )),
+    OpInfo('normal',
+           variant_test_name='in_place',
+           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.normal_, inp, *args, **kwargs),
+           inplace_variant=torch.Tensor.normal_,
+           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_autograd=False,
+           sample_inputs_func=sample_inputs_normal,
+           error_inputs_func=error_inputs_normal,
+           skips=(
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
+
+               # Tests that assume input tensor has a meaningful effect on output tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # AssertionError: Tensor-likes are not close!
+               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # vmap: calling random operator not supported
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+           )),
     OpInfo('uniform',
            op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.uniform_, inp, *args, **kwargs),
            method_variant=None,
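The in_place OpInfo routes every call through wrapper_set_seed so the otherwise nondeterministic Tensor.normal_ gives reproducible draws when the suite compares variants. A hedged sketch of the idea (call_with_seed and the seed value are illustrative, not the real helper):

    import torch

    def call_with_seed(fn, inp, *args, **kwargs):
        torch.manual_seed(42)  # reseed before every invocation
        return fn(inp, *args, **kwargs)

    a = call_with_seed(torch.Tensor.normal_, torch.empty(4), 0, 5)
    b = call_with_seed(torch.Tensor.normal_, torch.empty(4), 0, 5)
    print(torch.equal(a, b))  # True: identical draws under identical seeds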
@@ -15710,7 +15758,13 @@ op_db: List[OpInfo] = [
                # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes
                DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestFwdGradients'),
                DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'),
-               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+               # RuntimeError: Difference from {dtype} is larger with decomposition
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'),
+               # The inplace variant (Tensor.normal_) is different from torch.normal
+               # inplace variant Tensor.normal_ is decomposed using randn_like()
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'))),
     OpInfo('normal',
            # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here
            variant_test_name='number_mean',
@@ -15731,7 +15785,21 @@ op_db: List[OpInfo] = [
                # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes
                DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestFwdGradients'),
                DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'),
-               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+               # The inplace variant (Tensor.normal_) is different from torch.normal
+               # inplace variant Tensor.normal_ is decomposed using randn_like()
+               # TypeError: randn_like(): argument 'input' (position 1) must be Tensor, not float
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_amp'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'))),
     OpInfo('bernoulli',
            op=lambda inp, *args, **kwargs:
            wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs),
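The number_mean variant exists because aten.normal.float_Tensor takes a Python float mean rather than a tensor, so a decomposition built on randn_like has no tensor input to mirror; that is the TypeError the skips above cite. An illustration of that overload (the values are arbitrary):

    import torch

    out = torch.normal(0.5, torch.ones(3))  # float mean, Tensor std
    print(out.shape)  # torch.Size([3])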
@@ -17986,6 +18054,35 @@ python_ref_db = [
             DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
         )
     ),
+    PythonRefInfo(
+        "_refs.normal",
+        torch_opinfo_name="normal",
+        torch_opinfo_variant_name="in_place",
+        supports_out=True,
+        decorators=(
+            # TODO: RuntimeError: no _refs support for torch.rand_like
+            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
+                         'TestCommon',
+                         'test_python_ref'),
+
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+            DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+        )
+    ),
     PythonRefInfo(
         "_refs.arange",
         torch_opinfo_name="arange",