[MPS] Fix torch.std/torch.var default/correction handling (#91203)

If `torch.std` or `torch.var` is invoked without any optional arguments, `unbiased` should be assumed to be `True`.

Also, if the `correction` parameter is specified, it should be used in the correction computation.
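
For illustration, a minimal sketch of the intended semantics (not part of this PR; it assumes an available MPS device and a build that exposes the `correction=` keyword overload):

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(8, device=device)

# Default call: Bessel-corrected (unbiased) estimate, i.e. the sum of squared
# deviations divided by N - 1.
expected_unbiased = ((x - x.mean()) ** 2).sum() / (x.numel() - 1)
assert torch.allclose(x.var().cpu(), expected_unbiased.cpu())

# Explicit correction: divide by N - correction; correction=0 gives the biased estimate.
expected_biased = ((x - x.mean()) ** 2).mean()
assert torch.allclose(x.var(correction=0).cpu(), expected_biased.cpu())
```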

Tested by adding `std` and `var` to the consistency tests.

Fixes https://github.com/pytorch/pytorch/issues/91198

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91203
Approved by: https://github.com/kit1980
Nikita Shulga, 2022-12-21 02:23:50 +00:00, committed by PyTorch MergeBot
parent e670c261c5
commit dd735b96df
2 changed files with 10 additions and 7 deletions


@@ -602,8 +602,7 @@ Tensor std_var_common_impl_mps(
}
}
- bool use_correction = correction.has_value();
- const auto correction_value = use_correction ? correction.value() : false;
+ const auto correction_value = correction.has_value() ? correction.value() : 1;
int64_t correction_n = 1;
native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
@@ -742,14 +741,14 @@ Tensor std_var_common_impl_mps(
return output_t;
}
- double bessel_correction = ((double) correction_n) / ((double) (correction_n-1));
+ double bessel_correction = static_cast<double>(correction_n) / static_cast<double>(correction_n - correction_value);
auto stream = at::mps::getCurrentMPSStream();
@autoreleasepool {
string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","];
- string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased ";
+ string bessel_corrected = "correction_value=" + to_string(correction_value);
string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0";
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
string key = op_key + use_dim_info + ":" + keepdim_info + ":" + string([ns_key UTF8String]) + ":" + native_mps::getTensorsStringKey(input_t) + ":" + bessel_corrected;
@@ -771,7 +770,7 @@ Tensor std_var_common_impl_mps(
name:nil];
MPSGraphTensor *outputTensor;
- if (use_correction && correction_value)
+ if (correction_value)
{
MPSGraphTensor *besselTensor= [mpsGraph constantWithScalar:bessel_correction
dataType:MPSDataTypeFloat32];

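For reference, the rescaling applied in the hunk above turns the biased estimator (divide by n) into the corrected one (divide by n - correction). A small standalone sketch of that factor (illustrative only, not code from the PR):

```python
def bessel_factor(n: int, correction: int) -> float:
    # Factor applied to the biased variance (sum of squared deviations / n)
    # to obtain the corrected estimate (equivalent to dividing by n - correction).
    return n / (n - correction)

print(bessel_factor(10, 1))  # ~1.1111: classic unbiased (Bessel) correction
print(bessel_factor(10, 0))  # 1.0: correction=0 leaves the biased estimate unchanged
```

This matches the `if (correction_value)` guard above: when the correction is zero, the factor is 1 and the multiplication is skipped.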

@@ -7583,6 +7583,7 @@ class TestConsistency(TestCase):
'square': ['f16', 'f32'],
'squeeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+ 'std': ['f32'],
'sub': ['f32', 'i16', 'i32', 'i64'],
'sum_to_size': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'svd': ['f32'],
@@ -7602,6 +7603,7 @@ class TestConsistency(TestCase):
'unbind': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'unflatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'unsqueeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+ 'var': ['f32'],
'view': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'view_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -7613,7 +7615,8 @@ class TestConsistency(TestCase):
'logical_and': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
- 'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8']}
+ 'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8']
+ }
ALLOWLIST_OP_GRAD = {
@@ -7783,7 +7786,8 @@ class TestConsistency(TestCase):
'view_as': ['f16', 'f32'],
'vsplit': ['f16', 'f32'],
'vstack': ['f16', 'f32'],
- 'zero_': ['f16', 'f32']}
+ 'zero_': ['f16', 'f32']
+ }
# These ops that are problematic. So never run them even when
# generating the new allowlist.