[MPSInductor] Fix silent correctness in bitcast (#151272)
By using Metal `as_type`, which according to the documentation does exactly that:

> Metal adds an as_type<type-id> operator to allow any scalar or vector data type (that is not a pointer) to be reinterpreted as another scalar or vector data type of the same size. The bits in the operand are returned directly without modification as the new type. The usual type promotion for function arguments is not performed.

Using `reinterpret_cast` created a potential silent correctness error when dtypes of different sizes were bitcast to each other.

Also add an explicit cast to `src_dtype` to avoid errors due to type promotion: something like `(x + 1).view(dtype=torch.float16)` works correctly in eager mode for an int16 input, but would fail in compile, as arithmetic operations promote int16 to int32.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151272
Approved by: https://github.com/dcci
ghstack dependencies: #151224, #151246
parent 508b882513
commit 070357b61a
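Before the diff, here is a minimal repro sketch of the scenario the commit message describes, assuming a machine with an MPS device; the helper `fn`, the input values, and the assertion are illustrative and not taken from this PR:

```python
import torch

def fn(x):
    # In eager mode, adding a Python scalar keeps the int16 dtype, so the
    # bitcast to float16 reinterprets exactly the original 16 bits per element.
    # Inside a compiled Metal kernel the intermediate can be promoted to int32,
    # which is why the codegen must cast back to the source dtype before as_type.
    return (x + 1).view(dtype=torch.float16)

if torch.backends.mps.is_available():
    x = torch.arange(4, dtype=torch.int16, device="mps")
    eager = fn(x)
    compiled = torch.compile(fn)(x)
    # With the reinterpret_cast-based bitcast this comparison could silently
    # diverge; with as_type plus an explicit cast to src_dtype it should match.
    torch.testing.assert_close(eager, compiled)
```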
@@ -162,7 +162,6 @@ class MPSBasicTests(TestCase):
 
 # Copy tests
 for test_name in [
-    "test_min_max_reduction",
     "test_add_complex4",
     "test_add_const_int",
     "test_add_inplace_permuted",
@@ -177,6 +176,7 @@ for test_name in [
     "test_avg_pool2d8",
     "test_batch_norm_2d_2",
     "test_bernoulli1",
+    "test_bfloat16_to_int16",
     "test_builtins_round",
     "test_builtins_round_float_ndigits_neg",
     "test_cat_empty",
@@ -210,6 +210,7 @@ for test_name in [
     "test_max_pool2d2",
     "test_multilayer_prime_size",
     "test_multilayer_var_lowp",
+    "test_min_max_reduction",
     "test_min_max_reduction_nan",
     "test_nan_to_num",
     "test_neg_max_uint8",
@@ -12064,6 +12064,8 @@ class CommonTemplate:
             x_view = x.view(dtype=torch.int16)
             return x_view.mul(2) + x_view.bitwise_and(2)
 
+        if not self.is_dtype_supported(torch.bfloat16):
+            raise unittest.SkipTest(f"bfloat16 is not supported on {self.device}")
         a = torch.ones(4, dtype=torch.bfloat16, device=self.device)
         b = torch.ones(4, dtype=torch.bfloat16, device=self.device)
         ref = fn(a, b)
@@ -134,7 +134,7 @@ class MetalOverrides(OpOverrides):
     def to_dtype_bitcast(
         x: CSEVariable, dtype: torch.dtype, src_dtype: torch.dtype
     ) -> str:
-        return f"*reinterpret_cast<thread {DTYPE_TO_METAL[dtype]}*>(&{x})"
+        return f"as_type<{DTYPE_TO_METAL[dtype]}>(static_cast<{DTYPE_TO_METAL[src_dtype]}>({x}))"
 
     @staticmethod
     def constant(val: Union[bool, float, int], dtype: torch.dtype) -> str:
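As a rough illustration of what the one-line codegen change means for the emitted kernel source, the sketch below mimics the two format strings from `to_dtype_bitcast`; the `half`/`short` mapping and the variable name `tmp0` are assumptions for illustration, not values taken from this diff:

```python
# Stand-ins for the codegen inputs: `tmp0` plays the role of a CSE variable whose
# int16 value may have been promoted to a 32-bit int by earlier arithmetic.
metal_dst = "half"   # assumed DTYPE_TO_METAL[torch.float16]
metal_src = "short"  # assumed DTYPE_TO_METAL[torch.int16]
x = "tmp0"

# Old emission: dereferences 2 of the promoted value's 4 bytes in place,
# a size-mismatched, silently wrong bitcast.
old = f"*reinterpret_cast<thread {metal_dst}*>(&{x})"

# New emission: narrows back to 16 bits first, then reinterprets those exact
# bits as half via Metal's as_type, which performs no further promotion.
new = f"as_type<{metal_dst}>(static_cast<{metal_src}>({x}))"

print(old)  # *reinterpret_cast<thread half*>(&tmp0)
print(new)  # as_type<half>(static_cast<short>(tmp0))
```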