mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[MPS] Fix torch.std/torch.var default/correction handling (#91203)
If `torch.std`, `torch.var` are invoked without any arguments, it should be assumed that `unbiased` is `True`. Also, if `correction` parameter is specified it should be use in correction computation. Test by adding `std` and `var` to consistency tests Fixes https://github.com/pytorch/pytorch/issues/91198 Pull Request resolved: https://github.com/pytorch/pytorch/pull/91203 Approved by: https://github.com/kit1980
This commit is contained in:
parent
e670c261c5
commit
dd735b96df
|
|
@ -602,8 +602,7 @@ Tensor std_var_common_impl_mps(
|
|||
}
|
||||
}
|
||||
|
||||
bool use_correction = correction.has_value();
|
||||
const auto correction_value = use_correction ? correction.value() : false;
|
||||
const auto correction_value = correction.has_value() ? correction.value() : 1;
|
||||
int64_t correction_n = 1;
|
||||
|
||||
native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
|
||||
|
|
@ -742,14 +741,14 @@ Tensor std_var_common_impl_mps(
|
|||
return output_t;
|
||||
}
|
||||
|
||||
double bessel_correction = ((double) correction_n) / ((double) (correction_n-1));
|
||||
double bessel_correction = static_cast<double>(correction_n) / static_cast<double>(correction_n - correction_value);
|
||||
|
||||
auto stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
|
||||
NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased ";
|
||||
string bessel_corrected = "correction_value=" + to_string(correction_value);
|
||||
string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0";
|
||||
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
|
||||
string key = op_key + use_dim_info + ":" + keepdim_info + ":" + string([ns_key UTF8String]) + ":" + native_mps::getTensorsStringKey(input_t) + ":" + bessel_corrected;
|
||||
|
|
@ -771,7 +770,7 @@ Tensor std_var_common_impl_mps(
|
|||
name:nil];
|
||||
MPSGraphTensor *outputTensor;
|
||||
|
||||
if (use_correction && correction_value)
|
||||
if (correction_value)
|
||||
{
|
||||
MPSGraphTensor *besselTensor= [mpsGraph constantWithScalar:bessel_correction
|
||||
dataType:MPSDataTypeFloat32];
|
||||
|
|
|
|||
|
|
@ -7583,6 +7583,7 @@ class TestConsistency(TestCase):
|
|||
'square': ['f16', 'f32'],
|
||||
'squeeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'std': ['f32'],
|
||||
'sub': ['f32', 'i16', 'i32', 'i64'],
|
||||
'sum_to_size': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'svd': ['f32'],
|
||||
|
|
@ -7602,6 +7603,7 @@ class TestConsistency(TestCase):
|
|||
'unbind': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'unflatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'unsqueeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'var': ['f32'],
|
||||
'view': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'view_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
|
|
@ -7613,7 +7615,8 @@ class TestConsistency(TestCase):
|
|||
'logical_and': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8']}
|
||||
'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8']
|
||||
}
|
||||
|
||||
|
||||
ALLOWLIST_OP_GRAD = {
|
||||
|
|
@ -7783,7 +7786,8 @@ class TestConsistency(TestCase):
|
|||
'view_as': ['f16', 'f32'],
|
||||
'vsplit': ['f16', 'f32'],
|
||||
'vstack': ['f16', 'f32'],
|
||||
'zero_': ['f16', 'f32']}
|
||||
'zero_': ['f16', 'f32']
|
||||
}
|
||||
|
||||
# These ops that are problematic. So never run them even when
|
||||
# generating the new allowlist.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user