mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Correctly specify size of sparse_csr tensors in maskedtensor binary ops (#134335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134335 Approved by: https://github.com/amjames, https://github.com/cpuhrsch
This commit is contained in:
parent
08db735629
commit
f911361de1
|
|
@ -1,9 +1,7 @@
|
|||
# Owner(s): ["module: masked operators"]
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
from torch.testing._internal.common_utils import (
|
||||
decorateIf,
|
||||
TestCase,
|
||||
run_tests,
|
||||
make_tensor,
|
||||
|
|
@ -957,37 +955,6 @@ class TestOperators(TestCase):
|
|||
|
||||
@ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES) # type: ignore[arg-type]
|
||||
@parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
|
||||
# FIXME:
|
||||
# Result is just wrong; production logic should be fixed
|
||||
@decorateIf(
|
||||
unittest.expectedFailure,
|
||||
lambda params: (
|
||||
params["op"].name == "add" and
|
||||
params["dtype"] in [torch.float16, torch.float32] and
|
||||
params["device"] == "cpu" and
|
||||
params["layout"] == torch.sparse_csr
|
||||
)
|
||||
)
|
||||
# Result is just wrong; production logic should be fixed
|
||||
@decorateIf(
|
||||
unittest.expectedFailure,
|
||||
lambda params: (
|
||||
params["op"].name == "sub" and
|
||||
params["dtype"] in [torch.float16, torch.float32] and
|
||||
params["device"] == "cpu" and
|
||||
params["layout"] == torch.sparse_csr
|
||||
)
|
||||
)
|
||||
# Result is just wrong; production logic should be fixed
|
||||
@decorateIf(
|
||||
unittest.expectedFailure,
|
||||
lambda params: (
|
||||
params["op"].name == "eq" and
|
||||
params["dtype"] == torch.float64 and
|
||||
params["device"] == "cpu" and
|
||||
params["layout"] == torch.sparse_csr
|
||||
)
|
||||
)
|
||||
def test_binary_core(self, device, dtype, op, layout):
|
||||
self._test_unary_binary_equality(device, dtype, op, layout)
|
||||
|
||||
|
|
|
|||
|
|
@ -139,9 +139,10 @@ def _binary_helper(fn, args, kwargs, inplace):
|
|||
|
||||
crow = data_args[0].crow_indices()
|
||||
col = data_args[0].col_indices()
|
||||
size = data_args[0].size()
|
||||
data_args[0] = data_args[0].values()
|
||||
v = fn(*data_args)
|
||||
result_data = torch.sparse_csr_tensor(crow, col, v)
|
||||
result_data = torch.sparse_csr_tensor(crow, col, v, size)
|
||||
|
||||
else:
|
||||
result_data = fn(*data_args)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user