[MPS] Fix dot/mm for conj_tensors (#150157)

- Distinguish between conjugated and non-conjugated inputs by appending the conjugation state to the operator key
- For matmul and dot, add `conjugateWithTensor:name:` calls before running the op
- Enable testing of conjugated ops by passing `include_conjugated_inputs` to opinfo
- Filter the `include_conjugated_inputs` argument out of `sample_inputs_window`'s kwargs (probably should have landed as a separate PR)
- Preserve the conj property when gathering views, which fixes the `cov` operator
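
For background, `conj()` in PyTorch is lazy: it returns a zero-copy view that only flips a conjugation bit, so every backend kernel has to consult `is_conj()` and apply the conjugation itself. A minimal illustration:

```python
import torch

# conj() returns a view over the same storage; the conjugation lives in
# a tensor flag rather than in memory.
x = torch.tensor([1 + 2j, 3 - 4j])
y = x.conj()
assert y.is_conj() and y.data_ptr() == x.data_ptr()

# A kernel that ignores the flag computes sum(x_i * x_i) instead of the
# expected sum(conj(x_i) * x_i) = |1+2j|^2 + |3-4j|^2 = 30.
print(torch.dot(y, x))  # -> 30+0j
```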

Fixes https://github.com/pytorch/pytorch/issues/148156
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150157
Approved by: https://github.com/dcci
Authored by Nikita Shulga on 2025-03-28 12:14:11 -07:00; committed by PyTorch MergeBot
parent 9092dd2e82
commit 7c65911b11
5 changed files with 35 additions and 7 deletions

View File

@@ -325,13 +325,15 @@ std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
         str += "Scalar";
       } else {
         if (exclude_shape) {
-          str += "[-1]";
+          str += "-1";
         } else {
           str +=
               std::string([[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","].UTF8String);
         }
       }
       str += "]";
+      if (tensor.is_conj())
+        str += "_conj";
     } else {
       str += "Undefined";
     }
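
The `_conj` suffix matters because MPS graphs are cached under this string key; without it, a conjugated call could replay a graph that was built for the non-conjugated case. A hypothetical Python mirror of the key logic (`tensor_key` is illustrative, not the C++ API):

```python
import torch

def tensor_key(t):
    # Illustrative stand-in for getTensorsStringKey: dtype, shape, and
    # now the conjugation bit all participate in the lookup key.
    shape = ",".join(str(d) for d in t.shape)
    key = f"{t.dtype}[{shape}]"
    if t.is_conj():
        key += "_conj"
    return key

x = torch.tensor([1 + 1j, 2 - 2j])
assert tensor_key(x) != tensor_key(x.conj())  # no more cache collision
```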
@@ -543,7 +545,12 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
   if ((!src.is_contiguous() || src.storage_offset()) && gatherTensorData) {
     Tensor emptyShell = Tensor();
     // use "_tensor" from Placeholder to retain view's output during its usage in other ops
-    _tensor = gatherViewTensor(src, emptyShell);
+    // And preserve conjugated property here
+    if (!src.is_conj()) {
+      _tensor = gatherViewTensor(src, emptyShell);
+    } else {
+      _tensor = gatherViewTensor(src.conj(), emptyShell).conj();
+    }
     if (!_tensor.has_storage()) {
       // if we cannot gather, we make the tensor contiguous implicitly, and keep
       // it in placeholder to be able to retrieve it when we return from constructor
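
`gatherViewTensor` itself ignores the conj bit, so the gather runs on the un-conjugated view and the bit is re-applied afterwards. The same dance, sketched in Python with `contiguous()` standing in for the gather:

```python
import torch

x = torch.tensor([[1 + 2j, 3 - 4j]]).t()  # non-contiguous source
v = x.conj()                              # conj view of it

# Un-conjugate, materialize, re-conjugate: values are gathered once and
# the result still advertises is_conj().
gathered = v.conj().contiguous().conj()
assert gathered.is_conj()
torch.testing.assert_close(gathered.resolve_conj(), v.resolve_conj())
```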

View File

@@ -81,6 +81,12 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
         castSelf = selfTensor;
         castOther = otherTensor;
       }
+      if (self.is_conj()) {
+        castSelf = [mpsGraph conjugateWithTensor:selfTensor name:nil];
+      }
+      if (other.is_conj()) {
+        castOther = [mpsGraph conjugateWithTensor:otherTensor name:nil];
+      }
 
       MPSGraphTensor* dot = [mpsGraph multiplicationWithPrimaryTensor:castSelf
                                                       secondaryTensor:castOther
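
A quick parity check this enables, assuming an MPS device is present:

```python
import torch

if torch.backends.mps.is_available():
    a = torch.randn(8, dtype=torch.cfloat)
    b = torch.randn(8, dtype=torch.cfloat)
    # The conj bit on the MPS operand is now folded into the graph via
    # conjugateWithTensor:name:, so this matches the CPU reference.
    expected = torch.dot(a.conj(), b)
    actual = torch.dot(a.to("mps").conj(), b.to("mps"))
    torch.testing.assert_close(expected, actual.cpu())
```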

View File

@@ -118,10 +118,12 @@ std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
                                    dataType:getMPSDataType(self)];
     return {nil, nil, output};
   }
-  auto selfTensor = mpsGraphRankedPlaceHolder(graph, self);
-  auto otherTensor = mpsGraphRankedPlaceHolder(graph, other);
+  auto selfTensor_ = mpsGraphRankedPlaceHolder(graph, self);
+  auto otherTensor_ = mpsGraphRankedPlaceHolder(graph, other);
+  auto selfTensor = self.is_conj() ? [graph conjugateWithTensor:selfTensor_ name:nil] : selfTensor_;
+  auto otherTensor = other.is_conj() ? [graph conjugateWithTensor:otherTensor_ name:nil] : otherTensor_;
   auto output = [graph matrixMultiplicationWithPrimaryTensor:selfTensor secondaryTensor:otherTensor name:nil];
-  return {selfTensor, otherTensor, output};
+  return {selfTensor_, otherTensor_, output};
 }
 
 bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
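
And the matmul analogue, again assuming MPS availability:

```python
import torch

if torch.backends.mps.is_available():
    a = torch.randn(3, 4, dtype=torch.cfloat)
    b = torch.randn(4, 5, dtype=torch.cfloat)
    # do_mm conjugates inside the graph while still returning the raw
    # placeholders, so the data bindings stay on the unconjugated buffers.
    expected = torch.mm(a.conj(), b.conj())
    actual = torch.mm(a.to("mps").conj(), b.to("mps").conj())
    torch.testing.assert_close(expected, actual.cpu())
```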

View File

@@ -405,6 +405,7 @@ def mps_ops_modifier(ops):
         'constant_pad_nd',
         'cos',
         'cosh',
+        'cov',
         'count_nonzero',
         'diff',
         'div',
@@ -12455,8 +12456,16 @@ MPS_GRAD_DTYPES = [torch.float32, torch.float16]
 
 def transform_opinfo_sample_to_mps(sample):
     """Transforms opinfo.core.SampleInput from CPU to MPS"""
-    mps_sample = sample.transform(
-        lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)
+    def transform_sample(x):
+        if not isinstance(x, torch.Tensor):
+            return x
+        requires_grad = x.requires_grad
+        conjugated = x.is_conj()
+        rc = x.detach()
+        rc = rc.to("mps") if not conjugated else rc.conj().to("mps").conj()
+        return rc.requires_grad_(requires_grad)
+
+    mps_sample = sample.transform(transform_sample)
 
     # Transform kwargs `device="cpu"` to `device="mps"`
     if mps_sample.kwargs.get("device", "") == "cpu":
@@ -12575,12 +12584,14 @@ class TestConsistency(TestCaseMPS):
     @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES)
     def test_output_match(self, device, dtype, op):
         self.assertEqual(device, "cpu")
+        include_conjugated_inputs = dtype.is_complex and op.test_conjugated_samples
 
         def get_samples():
             return op.sample_inputs(
                 device,
                 dtype,
                 requires_grad=(dtype.is_floating_point or dtype.is_complex),
+                include_conjugated_inputs=include_conjugated_inputs,
                 # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                 set_seed=False,
             )

View File

@@ -29,6 +29,8 @@ def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
     additional keyword arguments.
     """
+    # Remove include_conjugated_inputs from kwargs
+    kwargs.pop("include_conjugated_inputs", None)
     # Tests window sizes up to 5 samples.
     for size, sym in product(range(6), (True, False)):
         yield SampleInput(