mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Fix dot/mm for conj_tensors (#150157)
- Distinguish between conjugated/non_conjugated inputs by appending conjugation to the operator key - For matmul or dot, add `conjugateWithTensor:name:` calls before running the op - Enable testing for conjugated ops by passing `include_conjugated_inputs` to opinfo - Filter `include_conjugated_inputs` argument from `sample_inputs_window` (probably should have landed as separate PR) - Preserve conj property when gathering the views, that fixes `cov` operator Fixes https://github.com/pytorch/pytorch/issues/148156 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150157 Approved by: https://github.com/dcci
This commit is contained in:
parent
9092dd2e82
commit
7c65911b11
|
|
@ -325,13 +325,15 @@ std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, boo
|
|||
str += "Scalar";
|
||||
} else {
|
||||
if (exclude_shape) {
|
||||
str += "[-1]";
|
||||
str += "-1";
|
||||
} else {
|
||||
str +=
|
||||
std::string([[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","].UTF8String);
|
||||
}
|
||||
}
|
||||
str += "]";
|
||||
if (tensor.is_conj())
|
||||
str += "_conj";
|
||||
} else {
|
||||
str += "Undefined";
|
||||
}
|
||||
|
|
@ -543,7 +545,12 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
|
|||
if ((!src.is_contiguous() || src.storage_offset()) && gatherTensorData) {
|
||||
Tensor emptyShell = Tensor();
|
||||
// use "_tensor" from Placeholder to retain view's output during its usage in other ops
|
||||
_tensor = gatherViewTensor(src, emptyShell);
|
||||
// And preserve conjugated property here
|
||||
if (!src.is_conj()) {
|
||||
_tensor = gatherViewTensor(src, emptyShell);
|
||||
} else {
|
||||
_tensor = gatherViewTensor(src.conj(), emptyShell).conj();
|
||||
}
|
||||
if (!_tensor.has_storage()) {
|
||||
// if we cannot gather, we make the tensor contiguous implicitly, and keep
|
||||
// it in placeholder to be able to retrieve it when we return from constructor
|
||||
|
|
|
|||
|
|
@ -81,6 +81,12 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
|
|||
castSelf = selfTensor;
|
||||
castOther = otherTensor;
|
||||
}
|
||||
if (self.is_conj()) {
|
||||
castSelf = [mpsGraph conjugateWithTensor:selfTensor name:nil];
|
||||
}
|
||||
if (other.is_conj()) {
|
||||
castOther = [mpsGraph conjugateWithTensor:otherTensor name:nil];
|
||||
}
|
||||
|
||||
MPSGraphTensor* dot = [mpsGraph multiplicationWithPrimaryTensor:castSelf
|
||||
secondaryTensor:castOther
|
||||
|
|
|
|||
|
|
@ -118,10 +118,12 @@ std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* gr
|
|||
dataType:getMPSDataType(self)];
|
||||
return {nil, nil, output};
|
||||
}
|
||||
auto selfTensor = mpsGraphRankedPlaceHolder(graph, self);
|
||||
auto otherTensor = mpsGraphRankedPlaceHolder(graph, other);
|
||||
auto selfTensor_ = mpsGraphRankedPlaceHolder(graph, self);
|
||||
auto otherTensor_ = mpsGraphRankedPlaceHolder(graph, other);
|
||||
auto selfTensor = self.is_conj() ? [graph conjugateWithTensor:selfTensor_ name:nil] : selfTensor_;
|
||||
auto otherTensor = other.is_conj() ? [graph conjugateWithTensor:otherTensor_ name:nil] : otherTensor_;
|
||||
auto output = [graph matrixMultiplicationWithPrimaryTensor:selfTensor secondaryTensor:otherTensor name:nil];
|
||||
return {selfTensor, otherTensor, output};
|
||||
return {selfTensor_, otherTensor_, output};
|
||||
}
|
||||
|
||||
bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
|
|
|
|||
|
|
@ -405,6 +405,7 @@ def mps_ops_modifier(ops):
|
|||
'constant_pad_nd',
|
||||
'cos',
|
||||
'cosh',
|
||||
'cov',
|
||||
'count_nonzero',
|
||||
'diff',
|
||||
'div',
|
||||
|
|
@ -12455,8 +12456,16 @@ MPS_GRAD_DTYPES = [torch.float32, torch.float16]
|
|||
|
||||
def transform_opinfo_sample_to_mps(sample):
|
||||
"""Transforms opinfo.core.SampleInput from CPU to MPS"""
|
||||
mps_sample = sample.transform(
|
||||
lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)
|
||||
def transform_sample(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
requires_grad = x.requires_grad
|
||||
conjugated = x.is_conj()
|
||||
rc = x.detach()
|
||||
rc = rc.to("mps") if not conjugated else x.conj().to("mps").conj()
|
||||
return rc.requires_grad_(x.requires_grad)
|
||||
|
||||
mps_sample = sample.transform(transform_sample)
|
||||
|
||||
# Transform kwargs `device="cpu"` to `device="mps"`
|
||||
if mps_sample.kwargs.get("device", "") == "cpu":
|
||||
|
|
@ -12575,12 +12584,14 @@ class TestConsistency(TestCaseMPS):
|
|||
@ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES)
|
||||
def test_output_match(self, device, dtype, op):
|
||||
self.assertEqual(device, "cpu")
|
||||
include_conjugated_inputs = dtype.is_complex and op.test_conjugated_samples
|
||||
|
||||
def get_samples():
|
||||
return op.sample_inputs(
|
||||
device,
|
||||
dtype,
|
||||
requires_grad=(dtype.is_floating_point or dtype.is_complex),
|
||||
include_conjugated_inputs=include_conjugated_inputs,
|
||||
# TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
|
||||
set_seed=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs)
|
|||
additional keyword arguments.
|
||||
"""
|
||||
|
||||
# Remove include_conjugated_inputs from kwargs
|
||||
kwargs.pop("include_conjugated_inputs", None)
|
||||
# Tests window sizes up to 5 samples.
|
||||
for size, sym in product(range(6), (True, False)):
|
||||
yield SampleInput(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user