This PR updates OpInfo-based tests for NJTs:
* Adds extensive coverage across non-contiguous NJTs (both non-contiguous transposed and non-contiguous with holes)
* The `_sample_njts()` helper that `sample_input_func`s utilize now produces non-contig NJTs as well
* Utilizes a `SampleInput`-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.
* I decided on adding `SampleInput`s and utilizing this system over using test parametrization for two reasons:
* Test perf - adding `SampleInput`s is faster than generating entire new tests
* Avoiding the possibility of `sample_input_func`s not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility for `SampleInput` generation firmly within each `OpInfo`'s `sample_input_func` means weirdness like this isn't possible
* Improves `SampleInput` naming for a bunch of `sample_input_func`s. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new `_describe_njt()` helper to get a string repr that uniquely defines the type of NJT being passed to the op
* Adds appropriate `XFailRule`s to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixed
```python
# Represents a rule indicating how to xfail a particular test. It allows granularity
# at the device, dtype, op, and individual sample levels. This flexibility allows entire
# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
# test cases across multiple ops.
@dataclass
class XFailRule:
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
# function to indicate whether the rule applies; return True if so
match_fn: Callable[[torch.device, torch.dtype, OpInfo, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
def match(self, device, dtype, op, sample) -> bool:
return self.match_fn(device, dtype, op, sample)
```
Example:
```python
# Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
# tensor with 1 in ragged dim.
XFailRule(
error_type=RuntimeError,
error_msg="cannot call binary pointwise function .* with inputs of shapes",
match_fn=lambda device, dtype, op, sample: (
isinstance(op, BinaryUfuncInfo)
and "noncontig_holes" in sample.name
and "broadcasting 1 over ragged" in sample.name
),
name="binary_noncontig_holes_broadcasting_1_over_ragged",
),
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138370
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #140160
This PR updates the binding for `stream_write_value32` to be consistent with `memset32` which IMO makes more sense for this type of utilities:
- Changed the API to take a uint32 tensor as argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of a object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
This PR adds "PrimHOPBase", which is intended to be a base class that
one can extend to create new HOPs that match some criteria:
- they take one subgraph as input, and their semantics are running the
subgraph on some operands
- the HOP stays alive until Inductor
The motivation is that we are seeing a lot more HOPs (invoke_subgraph,
invoke_quant) that have this property and there can be a lot of shared
code between them.
Future:
- Migrate invoke_subgraph to use this
- There are some TODOs in the code
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139898
Approved by: https://github.com/anijain2305, https://github.com/ydwu4
In old triton versions, you take the hash of the triton kernel and use it in the filepath for the cached kernel. In Triton 3.2 (after https://github.com/triton-lang/triton/pull/4553), the filepath will use the base-64-encoded representation of the hash in the path.
This PR checks whether the `_base64` function exists in triton, and if so, uses the base-64-encoded represenatation in the path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140190
Approved by: https://github.com/ezyang
Here's the overview:
There's a new contextmanager singleton called MetricsContext. Entering the MetricsContext is how we demarcate the boundary on which we'll create a single CompilationMetrics object, and therefore, a single dynamo_compile log entry. While we're inside the MetricsContext, we can update/set many different metrics. Most importantly: `dynamo_timed` can also update the in-progress MetricsContext. In the proposal here, we tell `dynamo_timed` that we want it to do so by providing the name of the MetricsContext field to increment. There can be many `dynamo_timed` calls in different parts of the code updating different fields. Then when the MetricsContext exits, that's when the logging of everything gathered finally happens. One potential footgun is trying to use `dynamo_timed` when we haven't entered the MetricsContext, but we assert on that problem. Another problem is that we re-enter the context recursively, but we watch for that and do the logging only when the outermost exits.
Some specifics:
* Introduce MetricsContext - a context manager that on exit, records the CompilationMetrics (which also logs to dynamo_compile).
* Completely remove the concept of frame_phase_timing. Instead, update the MetricsContext during compilation, either directly or via dynamo_timed.
* Remove some globals we previously used to accumulate counters to later populate a CompilationMetrics. We use CompilationMetrics set/update/increment APIs instead.
* `record_compilation_metrics` is now called on exit from MetricsContext.
* Populate legacy CompilationMetrics fields right before logging, inside `record_compilation_metrics`.
* Remove the one-off `add_remote_cache_time_saved` helper; capture that timing directly into the MetricsContext.
And specifically, several changes to dynamo_timed:
* "Modernize" the parameters and update all callsites accordingly.
* Move the backwards logging of the CompilationMetrics to the backwards compile location.
* Add a parameter for which CompilationMetrics field to update
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139849
Approved by: https://github.com/ezyang
ghstack dependencies: #140094
Summary:
I was looking into why the non-standard bool value will fail for msort - it makes sense for argsort and sort to fail, because we're randomly generating uint8 so the order will be different (and thus the indices will be different). But msort should work.
After some digging, it's interesting that even though scalar_t is bool, when the actual value is a uint8_t, the comparison will treat them as signed. I tried lhs=255 and rhs=0: lhs < rhs is equivalent to -1 < 0 which is true (but it's supposed to be False)
Therefore we add an explicit type cast.
Test Plan: Remove the test skip
Differential Revision: D65472170
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139870
Approved by: https://github.com/Skylion007, https://github.com/davidberard98
This PR updates the binding for `stream_write_value32` to be consistent with `memset32` which IMO makes more sense for this type of utilities:
- Changed the API to take a uint32 tensor as argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of a object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
[AOTI] Introduce an extensibility mechanism for the c shim codegen to make it easy to produce c shims for out-of-tree OP kernels as well. Add c shim for XPU.
### Motivation
Since the current c shim codegen will only produce C wrappers for Op's registered in `aten/src/ATen/native/native_functions.yaml`, for the same backend, when a portion of out-of-tree OP's are not registered in that file, but are registered externally. For example, `third_party/torch-xpu-ops/yaml/native_functions.yaml` , in this case, the existing codegen can't fulfill the need to do extensions for the c shims from the out-of-tree OPs for the in-tree that has already been produced.
### Design
To extend the c shim with more OP for a backend from out-of-tree.
The PR provided a bool option `--aoti-extend` to indicate the codegen is to extend c shim from out-of-tree.
The generated c shim is stored in the `extend` subdirectory , for example:
```
torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.cpp
torch/include/torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.h
torch/include/torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.cpp
```
example usage:
`python -m torchgen.gen --source-path third_party/torch-xpu-ops/yaml/ --xpu --aoti-extend --update-aoti-c-shim `
`--xpu`: generate c shim for XPU
`--aoti-extend `: this is an out-of-tree OPs(defined in `third_party/torch-xpu-ops/yaml/native_functions.yaml`) extend for in-tree ops(defined in `aten/src/ATen/native/native_functions.yaml`)
`--update-aoti-c-shim`: always generate c_shim_xpu.h for the extend c_shim.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136742
Approved by: https://github.com/EikanWang, https://github.com/desertfire
ghstack dependencies: #139025
Summary: Add time log to cudagraph, including `create deferred_cudagraphify wrapper`, `warmup`, `record`, and `checkpoint`.
Test Plan:
1. buck2 run fbcode//mode/opt //pytorch/benchmark:run -- resnet50 -d cuda -t train --inductor --pt2-triton-cudagraph
2. Found the result in [scuba table](https://fburl.com/scuba/pt2_compile_events/0oik8nu9).
{F1954034920}
Differential Revision: D65505659
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139818
Approved by: https://github.com/eellison
Summary:
AMD lowering duration is 1.55x longer than H100. Profiling shows hipification related functions took 22% of overall lowering time.
This diff cuts that time by safely memoize the trie to regex logic. The trick is to incrementally build a state of the trie during the trie construction. The state is the hash of all the words added to the trie.
Differential Revision: D65659445
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140156
Approved by: https://github.com/ColinPeppler
Co-authored-by: Kefei Lu <kefeilu@meta.com>
Fixes#126268
I've basically followed @ezyang suggestion (I think) to use `func.decompose(...)`. Since `__torch_dispatch__` won't be called a second time for the same op, I've added a second `TorchDispatchMode` (`_DecomposedCounterMode`) that simpy dispatches to the parent flop counter. Using `self` as the inner context manager is not possible, since the second call to `__enter__` would re-initialize the counter's tracking state.
Let me know if there's something wrong with this implementation, since I'm quite unsure how the decomposition thing actually works :D
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138508
Approved by: https://github.com/ezyang
This fix was a bit more involved:
1) It fixes a item_memo loss place.
2) It updates a test to be eager instead of aot_eager since it reveals a very obscure bug related to replacements that's not worth solving since in practice inductor will regenerate the runtime asserts anyways
3) It updates tensorify to specialize more places now that the aforementioned bug is fixed.
Fixes `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=6 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCPU.test_comprehensive_linalg_norm_cpu_float16` when `specialize_float=False`
while ensuring `python test/dynamo/test_dynamic_shapes.py DynamicShapesMiscTests.test_runtime_assert_replacement_dynamic_shapes` doesn't regress
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139587
Approved by: https://github.com/ezyang
ghstack dependencies: #139569, #139457, #139568, #139572, #139846, #139454, #139896, #139935
Currently, we get all partition id by iterating assignment whose size is same as the number of nodes in graph. But we can reach same results by iterating partitions_by_id whose size is much smaller than the nodes number. Assume the number of nodes is N, the number of partitions is P, the time complexity decrease from O(N * N) to O(N * P) after this patch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136598
Approved by: https://github.com/ezyang
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
This PR contains several fixes related to non-contiguous NJTs:
1. Propagates `lengths` through op calls appropriately (see desc of #138098)
* SDPA now calls `nested_view_from_values_offsets_lengths()` instead of `nested_view_from_values_offsets()`
2. Allows non-contig NJTs in unsqueeze / transpose / select
3. Expands padded dense -> NJT conversion to support non-contig NJTs
4. (unrelated sorry) Updates `split` / `split_with_sizes` to allow for optional `dim`, matching the ATen signature
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140160
Approved by: https://github.com/cpuhrsch
Fixes#136559
As we upgrade to NumPy 2, torch falsely filtered out `numpy.random` as unsupported in dynamo tracking.
This PR changes the filtering rules to include them while keeping behavior with numpy 1 unchanged.
Before this PR, the following tests failed:
```
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_functions.py -k FunctionTests.test_numpy_random
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_unspec.py -k UnspecTests.test_to_tensor
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k FakeTensorTest.test_export_numpy
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k PropagateRealTensorsFakeTensorTest.test_export_numpy_propagate_real_tensors
```
With this PR, the supported/unsupported ops in NumPy 1 are not changed.
For NumPy 2, only the `numpy.random` ops that are already supported with NumPy 1 are added to the supported list.
I used the following scripts to check the differences before and after the change for both NumPy 1 & 2.
The output is empty for NumPy 1 since there is no change.
The output is a list of `numpy.random` that considered supported for NumPy 2.
```py
from torch._dynamo import trace_rules
import numpy as np
def new_numpy_function_ids():
unsupported_funcs = {"seed", "ranf", "get_bit_generator", "RandomState", "set_bit_generator", "sample"}
def is_supported(k, v, mod):
if not callable(v):
return False
if not getattr(v, "__module__", None):
return True
if v.__module__ == mod.__name__:
return True
if v.__module__ == "numpy.random.mtrand" and mod.__name__== "numpy.random" and k not in unsupported_funcs:
return True
return False
rv = {}
for mod in trace_rules.NP_SUPPORTED_MODULES:
for k, v in mod.__dict__.items():
if is_supported(k, v, mod):
rv[id(v)] = f"{mod.__name__}.{k}"
return rv
def old_numpy_function_ids():
rv = {}
for mod in trace_rules.NP_SUPPORTED_MODULES:
rv.update(
{
id(v): f"{mod.__name__}.{k}"
for k, v in mod.__dict__.items()
if callable(v)
and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
}
)
return rv
rv1 = set(old_numpy_function_ids().values())
rv2 = set(new_numpy_function_ids().values())
for v in (rv1 - rv2):
print(v)
print("****")
for v in (rv2 - rv1):
print(v)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138686
Approved by: https://github.com/williamwen42
This PR introduces the following:
### torch.ops.symm_mem._async_input_mm
`_async_input_mm(Tensor a, Tensor b, Tensor a_chunk_signals, int a_chunk_pivot) -> Tensor`
An mm impl that supports consuming asynchronous input. It guarantees the following rasterization order, and that the corresponding signal arrives before an input chunk is consumed.
```
num_chunks = a_chunks_signals.numel()
for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
chunk_idx = chunk_idx % num_chunks
wait_signal(a_chunk_signals, chunk_idx)
# Compute output tiles that consumes the input chunk
```
### PersistentAsyncInputScheduler
This is a forked version of PersistentScheduler that supports consuming asynchronous input. This tile scheduler introduces the following arguments:
- `tiles_per_chunk_m` – Specifies the size of an M chunk. Chunks are the granularity at which the asynchronous input becomes ready. It must be an interger multiple of the size of an M tile.
- `chunk_signals` – `chunk_signals[i] == 1` indicates that chunk i is ready. Before returning a work tile, get_current_work() waits for the signal to ensure that the corresponding chunk is ready.
- `tile_idx_pivot_m` – After applying swizzling, apply `pivot(m) => (m + tile_idx_pivot_m) % tiles_m` to `m`. In a distributed setting, this allows different ranks to process different m indices at the same time, thus avoiding communication hotspots.
Note that this scheduler currently only supports the `KernelTmaWarpSpecializedCooperative` kernel schedule. This is enforced via the template argument `KernelSchedule`.
Usage:
```
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue,
cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
```
### _fused_all_gather_matmul_native
An ag-mm impl that combines `torch.ops.symm_mem._async_input_mm` and progress-aware all-gather. This is not yet enabled via the async-tp passes. We will use it as a backend to optimize the current decomposition-based async-tp impl.
## Benchmarks
### 4096x3584x8192
- cublas + nccl: 539us
- decomp-based async-tp w/o cuda graph: 694us
- decomp-based async-tp w/ cuda graph: 478us
- new cutlass kernel: 408us
<img width="478" alt="image" src="https://github.com/user-attachments/assets/39f316ab-36c5-4b41-af77-07854a385dfc">
### 2048x3584x8192
- cublas + nccl: 301us
- decomp-based async-tp w/o cuda graph: 687us
- decomp-based async-tp w/ cuda graph: 356us
- new cutlass kernel: 276us
<img width="441" alt="image" src="https://github.com/user-attachments/assets/9e23ce21-863b-43dd-a562-fb05d3a5a144">
## Next Steps
- Add tuning logic
- Use `_fused_all_gather_matmul_native` as a backend for the decomp-based async-tp impl
Differential temp Revision: [D65623152](https://our.internmc.facebook.com/intern/diff/D65623152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139227
Approved by: https://github.com/weifengpy, https://github.com/Chillee