Using Philox4 as PRNG
Test plan (other that CI)
Run
```python
mport torch
from torch._inductor.utils import run_and_get_code
from contextlib import nullcontext
def foo(x):
return x * torch.randn_like(x)
foo_c = torch.compile(foo)
x = torch.ones(100, 100, device="mps")
y = foo_c(x)
print(y.mean().item(), y.std().item())
for i in range(25):
print(y[i].mean(), y[i].std())
```
And observe that printed values are close to 0 and 1
TODO: Better `randint` algorithm for large ranges
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145705
Approved by: https://github.com/dcci, https://github.com/jansel
May be to be later reused from eager op as well
Also, didn't know that Metal already have type_traits
And use `metal::isunorderder(a, b)` instead of `metal::isnan(a + b)` is it is defined as function that is equivalent `a != a || b != b`, but I suspect it might have a best native implementation for the specific architecture
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145157
Approved by: https://github.com/dcci
`metal::isnan` is only defined for floats, so provide a generic wrapper
that is false for integral types
TODO: Figure out why type propagantion is not working (or should it?)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144665
Approved by: https://github.com/dcci
Now error message looks as follows:
```
% python ../test/inductor/test_torchinductor.py -v -k test_cat_unbacked_2d_mps
test_cat_unbacked_2d_mps (__main__.GPUTests) ... inline_call []
stats [('calls_captured', 6)]
inductor [('extern_calls', 2), ('fxgraph_cache_miss', 1)]
aot_autograd [('total', 1), ('autograd_cache_bypass', 1), ('not_ok', 1)]
ERROR
======================================================================
ERROR: test_cat_unbacked_2d_mps (__main__.GPUTests)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/malfet/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3126, in wrapper
method(*args, **kwargs)
File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 12254, in new_test
return value(self)
File "/Users/malfet/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 5885, in test_cat_unbacked_2d
self.common(
File "/Users/malfet/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 620, in check_model_gpu
check_model(
File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 461, in check_model
actual = run(*example_inputs, **kwargs)
File "/Users/malfet/git/pytorch/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 704, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 689, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 1149, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 1064, in codegen_and_compile
compiled_fn = graph.compile_to_module().call
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/graph.py", line 1977, in compile_to_module
return self._compile_to_module()
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/graph.py", line 2018, in _compile_to_module
mod = PyCodeCache.load_by_key_path(
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/codecache.py", line 2768, in load_by_key_path
mod = _reload_python_module(key, path)
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/runtime/compile_tasks.py", line 51, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmpmyfz2ju8/lt/cltm34ognlgcc6oxoe6bexvtbwcdtdfgnkjj5miz7vhkemitacp7.py", line 40, in <module>
File "/var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmpmyfz2ju8/lt/cltm34ognlgcc6oxoe6bexvtbwcdtdfgnkjj5miz7vhkemitacp7.py", line 32, in _compile_mps_shader
torch._inductor.exc.InductorError: SyntaxError: failed to compile
kernel void generated_kernel(
device float* out_ptr0,
constant float* in_ptr0,
uint xindex [[thread_position_in_grid]]
) {
long x1 = (xindex) / (3);
auto tmp0 = x1;
auto tmp1 = static_cast<long>(tmp0);
auto tmp2 = 0;
auto tmp3 = tmp1 >= tmp2;
auto tmp4 = 2;
auto tmp5 = tmp1 < tmp4;
long x0 = (xindex) % (3);
auto tmp6 = in_ptr0[x0 + 3*(x1)];
auto tmp7 = tmp5 ? tmp6 : 0.0;
auto tmp8 = tmp1 >= tmp4;
auto tmp9 = 2 + ks0;
auto tmp10 = static_cast<long>(tmp9);
auto tmp11 = tmp1 < tmp10;
auto tmp12 = 1.0;
auto tmp13 = tmp8 ? tmp12 : 0.0;
auto tmp14 = tmp5 ? tmp7 : tmp13;
long x2 = xindex;
out_ptr0[x2] = static_cast<float>(tmp14);
}
with program_source:18:25: error: use of undeclared identifier 'ks0'
auto tmp9 = 2 + ks0;
^
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
To execute this test, run the following from the base repo dir:
python test/inductor/test_torchinductor.py GPUTests.test_cat_unbacked_2d_mps
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 1 test in 0.472s
FAILED (errors=1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144649
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #144647, #144648
Just pass them as kernel arguments
After this change `pytest test/inductor/test_torchinduct.py -v -k _mps` reports 330 failed, 429 passed after and 335 failed, 424 passed before
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144662
Approved by: https://github.com/jansel
Move constant to value logic to `value_to_metal` function (similar to `value_to_cpp`)
Call it from `constant` as well as `where` ops (which is in turn being called from `masked` op
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144500
Approved by: https://github.com/dcci
397 -> 395 tests failing. `static_cast<>` is because there are several overloads of `fmod()` that's otherwise ambiguous. I wonder if we should take in account NaN propagation (maybe it's not tested).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144449
Approved by: https://github.com/malfet
Not sure why CI did not complain about it, but it my local runs it clearly says
```
Advice (FLAKE8) E226
missing whitespace around arithmetic operator
See https://www.flake8rules.com/rules/E226.html
268 | with code.indent():
269 | if len(idx_var_names) > 1:
270 | for idx, name in enumerate(idx_var_names):
>>> 271 | code.writeline(f"auto {name} = thread_pos.{chr(120+idx)};")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144282
Approved by: https://github.com/Skylion007
Alas, PythonPrinter would not work here, not would CppPrinter, so start building MetalPrinter.
`pytest test/inductor/test_torchinductor.py -k _mps` score is 474 failed, 277 passed, 32 skipped
Before this change:
`pytest test/inductor/test_torchinductor.py -k _mps` reported 506 failed, 245 passed, 32 skipped
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143973
Approved by: https://github.com/jansel
ghstack dependencies: #143948, #143949