[MPSInductor] Fix silent correctness in bitcast (#151272)

Fix this by using Metal's `as_type`, which according to the documentation does exactly that:
> Metal adds an as_type<type-id> operator to allow any scalar or vector data type (that is not a pointer) to be reinterpreted as another scalar or vector data type of the same size. The bits in the operand are returned directly without modification as the new type. The usual type promotion for function arguments is not performed.
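
For reference, the PyTorch-side counterpart of this bit reinterpretation is `Tensor.view(dtype=...)`, which is (roughly) what Inductor lowers to the `to_dtype_bitcast` op changed below. A minimal sketch, not part of this change, with illustrative values:

```python
# Sketch only: Tensor.view(dtype=...) reinterprets the same bits as another
# dtype of equal element size, just like Metal's as_type does in the shader.
import torch

h = torch.tensor([1.0], dtype=torch.float16)
bits = h.view(torch.int16)            # IEEE-754 half 1.0 is 0x3C00
print(hex(bits.item()))               # -> 0x3c00
assert bits.view(torch.float16).item() == 1.0  # round-trips without modification
```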

Using `reinterpret_cast` created a potential silent correctness error when dtypes of different sizes were bitcast to each other.

Also add an explicit cast to the source dtype to avoid errors due to type promotion: something like `(x + 1).view(dtype=torch.float16)` works correctly in eager mode for an int16 `x`, but would fail in compile, as arithmetic operations promote int16 to int32.
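
A hedged repro of the scenario described above (the tensor values and the `mps` device are assumptions for the example):

```python
# Illustration of the eager-vs-compile discrepancy described above; assumes an
# MPS device and torch.compile using the Inductor backend.
import torch

def fn(x):
    # In eager mode (x + 1) stays int16, so the 16-bit float view is a clean bitcast.
    return (x + 1).view(dtype=torch.float16)

x = torch.arange(4, dtype=torch.int16, device="mps")
eager = fn(x)
compiled = torch.compile(fn)(x)
# With the explicit cast back to the source dtype, both paths agree.
torch.testing.assert_close(eager, compiled)
```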

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151272
Approved by: https://github.com/dcci
ghstack dependencies: #151224, #151246
Nikita Shulga 2025-04-14 16:05:06 -07:00 committed by PyTorch MergeBot
parent 508b882513
commit 070357b61a
3 changed files with 5 additions and 2 deletions

@@ -162,7 +162,6 @@ class MPSBasicTests(TestCase):
 # Copy tests
 for test_name in [
-    "test_min_max_reduction",
     "test_add_complex4",
     "test_add_const_int",
     "test_add_inplace_permuted",
@@ -177,6 +176,7 @@ for test_name in [
     "test_avg_pool2d8",
     "test_batch_norm_2d_2",
     "test_bernoulli1",
+    "test_bfloat16_to_int16",
     "test_builtins_round",
     "test_builtins_round_float_ndigits_neg",
     "test_cat_empty",
@@ -210,6 +210,7 @@ for test_name in [
     "test_max_pool2d2",
     "test_multilayer_prime_size",
     "test_multilayer_var_lowp",
+    "test_min_max_reduction",
     "test_min_max_reduction_nan",
     "test_nan_to_num",
     "test_neg_max_uint8",

@@ -12064,6 +12064,8 @@ class CommonTemplate:
             x_view = x.view(dtype=torch.int16)
             return x_view.mul(2) + x_view.bitwise_and(2)
 
+        if not self.is_dtype_supported(torch.bfloat16):
+            raise unittest.SkipTest(f"bfloat16 is not supported on {self.device}")
         a = torch.ones(4, dtype=torch.bfloat16, device=self.device)
         b = torch.ones(4, dtype=torch.bfloat16, device=self.device)
         ref = fn(a, b)

@@ -134,7 +134,7 @@ class MetalOverrides(OpOverrides):
     def to_dtype_bitcast(
         x: CSEVariable, dtype: torch.dtype, src_dtype: torch.dtype
     ) -> str:
-        return f"*reinterpret_cast<thread {DTYPE_TO_METAL[dtype]}*>(&{x})"
+        return f"as_type<{DTYPE_TO_METAL[dtype]}>(static_cast<{DTYPE_TO_METAL[src_dtype]}>({x}))"
 
     @staticmethod
     def constant(val: Union[bool, float, int], dtype: torch.dtype) -> str:
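
For illustration, a standalone sketch of the new codegen rule; the function name `to_dtype_bitcast_expr`, the variable `tmp0`, and the Metal type names passed in are assumptions for the example, not the actual Inductor internals:

```python
# Minimal sketch: convert back to the source dtype first (so integer promotion
# cannot widen the operand), then reinterpret the bits with as_type.
def to_dtype_bitcast_expr(x: str, dst_metal: str, src_metal: str) -> str:
    return f"as_type<{dst_metal}>(static_cast<{src_metal}>({x}))"

# e.g. an int16 value being viewed as float16:
print(to_dtype_bitcast_expr("tmp0", "half", "short"))
# -> as_type<half>(static_cast<short>(tmp0))
# The old form, *reinterpret_cast<thread half*>(&tmp0), still compiled when tmp0
# had been promoted to int32 and then silently reinterpreted only part of its bits.
```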