Correctly specify size of sparse_csr tensors in maskedtensor binary ops (#134335)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134335
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
This commit is contained in:
Benjamin Glass 2024-12-02 22:39:45 +00:00 committed by PyTorch MergeBot
parent 08db735629
commit f911361de1
2 changed files with 2 additions and 34 deletions

View File

@ -1,9 +1,7 @@
# Owner(s): ["module: masked operators"]
import torch
import unittest
from torch.testing._internal.common_utils import (
decorateIf,
TestCase,
run_tests,
make_tensor,
@ -957,37 +955,6 @@ class TestOperators(TestCase):
@ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES) # type: ignore[arg-type]
@parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
# FIXME:
# Result is just wrong; production logic should be fixed
@decorateIf(
unittest.expectedFailure,
lambda params: (
params["op"].name == "add" and
params["dtype"] in [torch.float16, torch.float32] and
params["device"] == "cpu" and
params["layout"] == torch.sparse_csr
)
)
# Result is just wrong; production logic should be fixed
@decorateIf(
unittest.expectedFailure,
lambda params: (
params["op"].name == "sub" and
params["dtype"] in [torch.float16, torch.float32] and
params["device"] == "cpu" and
params["layout"] == torch.sparse_csr
)
)
# Result is just wrong; production logic should be fixed
@decorateIf(
unittest.expectedFailure,
lambda params: (
params["op"].name == "eq" and
params["dtype"] == torch.float64 and
params["device"] == "cpu" and
params["layout"] == torch.sparse_csr
)
)
def test_binary_core(self, device, dtype, op, layout):
self._test_unary_binary_equality(device, dtype, op, layout)

View File

@ -139,9 +139,10 @@ def _binary_helper(fn, args, kwargs, inplace):
crow = data_args[0].crow_indices()
col = data_args[0].col_indices()
size = data_args[0].size()
data_args[0] = data_args[0].values()
v = fn(*data_args)
result_data = torch.sparse_csr_tensor(crow, col, v)
result_data = torch.sparse_csr_tensor(crow, col, v, size)
else:
result_data = fn(*data_args)