pytorch/torch/_dynamo/variables
ydwu4 fc5cde7579 [dynamo] constant fold torch.cuda.get_device_properties to avoid graph break (#118422)
Before this PR, we had a graph break for code like this:
```python
import torch


def test_get_device_properties_tensor_device(a):
    x = a.to("cuda")
    # This call used to cause a graph break under torch.compile.
    prop = torch.cuda.get_device_properties(x.device)
    if prop.major == 8:
        return x + prop.multi_processor_count
    return x + prop.max_threads_per_multi_processor
```
This PR constant folds torch.cuda.get_device_properties, so we get the following Dynamo graph:
```python
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]     def forward(self, L_a_ : torch.Tensor):
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         l_a_ = L_a_
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         # File: /home/yidi/local/pytorch/test/dynamo/test_functions.py:544 in test_get_device_properties_tensor_device, code: x = a.to("cuda")
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         x = l_a_.to('cuda');  l_a_ = None
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         # File: /home/yidi/local/pytorch/test/dynamo/test_functions.py:547 in test_get_device_properties_tensor_device, code: return x + prop.multi_processor_count
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         add = x + 108;  x = None
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]         return (add,)
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
```
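
To make the folding concrete, here is a minimal pure-Python sketch of what it means for the graph above. It is not Dynamo's implementation: `FakeDeviceProperties`, `fake_get_device_properties`, and `trace` are hypothetical names, and the property values are made up for illustration. The point is only that the property query and the branch on `prop.major` are evaluated at trace time, so the recorded graph contains a literal instead of a call.

```python
from dataclasses import dataclass


# Hypothetical stand-in for the object returned by
# torch.cuda.get_device_properties; the values are illustrative only.
@dataclass(frozen=True)
class FakeDeviceProperties:
    major: int
    multi_processor_count: int
    max_threads_per_multi_processor: int


FAKE_PROPS = {0: FakeDeviceProperties(8, 108, 2048)}


def fake_get_device_properties(index: int) -> FakeDeviceProperties:
    return FAKE_PROPS[index]


def trace(device_index: int) -> list:
    """Toy tracer: record only tensor ops, baking property values in as literals."""
    # Evaluated eagerly during tracing; never recorded in the graph.
    prop = fake_get_device_properties(device_index)
    graph = ["x = l_a_.to('cuda')"]
    # The branch is resolved at trace time, so only one arm is recorded.
    if prop.major == 8:
        graph.append(f"add = x + {prop.multi_processor_count}")
    else:
        graph.append(f"add = x + {prop.max_threads_per_multi_processor}")
    graph.append("return (add,)")
    return graph


graph = trace(0)
print("\n".join(graph))
```

With the made-up values, the recorded graph has the same shape as the log above: the `get_device_properties` call is gone and only `x + 108` remains.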

The signature of get_device_properties is:
```python
def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
```
I think it's safe to constant fold get_device_properties() in both call patterns:
1. torch.cuda.get_device_properties(tensor.device). In this case, tensor.device.index is guarded in _check_tensor.
2. torch.cuda.get_device_properties(device_int_id). We don't expect the GPU properties for a particular index to change during a torch.compile run, and it makes sense to specialize on the properties for a concrete device_int_id.
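
The safety argument above can be sketched as a toy specialization cache. This is an assumption-laden illustration, not how Dynamo's guards are implemented: `FAKE_SM_COUNT`, `compile_for`, and `run` are hypothetical names with made-up values. A folded constant is valid only for the device index it was computed for, so the index acts as a guard and a different index triggers a fresh specialization.

```python
from typing import Callable, Dict

# Made-up per-device property values, keyed by device index.
FAKE_SM_COUNT = {0: 108, 1: 80}

compile_count = 0
_cache: Dict[int, Callable[[float], float]] = {}


def compile_for(device_index: int) -> Callable[[float], float]:
    """Build a function specialized on one device's (constant) property."""
    global compile_count
    compile_count += 1
    sm_count = FAKE_SM_COUNT[device_index]  # folded into the closure
    return lambda x: x + sm_count


def run(x: float, device_index: int) -> float:
    # Guard: reuse a specialization only when the device index matches.
    if device_index not in _cache:
        _cache[device_index] = compile_for(device_index)
    return _cache[device_index](x)


r1 = run(1.0, 0)  # first call on device 0: specializes
r2 = run(2.0, 0)  # same index: guard passes, cached version is reused
r3 = run(1.0, 1)  # different index: guard fails, specializes again
```

The folded value is never reused across device indices, which is the property that points 1 and 2 rely on.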

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118422
Approved by: https://github.com/yanboliang, https://github.com/jansel
2024-01-29 20:26:40 +00:00

| File | Last commit | Date |
| --- | --- | --- |
| `__init__.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `base.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `builder.py` | [14/N][Dynamo] Make trace_rules.lookup only handle function + callable type (#118366) | 2024-01-27 23:02:44 +00:00 |
| `builtin.py` | [14/N][Dynamo] Make trace_rules.lookup only handle function + callable type (#118366) | 2024-01-27 23:02:44 +00:00 |
| `constant.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `ctx_manager.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `dicts.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `distributed.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `functions.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `higher_order_ops.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `iter.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `lazy.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `lists.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `misc.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `nn_module.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `optimizer.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `sdpa.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `tensor.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `torch_function.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |
| `torch.py` | [dynamo] constant fold torch.cuda.get_device_properties to avoid graph break (#118422) | 2024-01-29 20:26:40 +00:00 |
| `user_defined.py` | Unify MYPYINDUCTOR and MYPY (#118432) | 2024-01-27 17:23:20 +00:00 |