Commit Graph

43731 Commits

Author SHA1 Message Date
PyTorch MergeBot
dbb55b448b Revert "[7/N] Fix Wextra-semi warning (#140225)"
This reverts commit ffb979032d.

Reverted https://github.com/pytorch/pytorch/pull/140225 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/140225#issuecomment-2469312229))
2024-11-12 00:02:06 +00:00
Tugsbayasgalan Manlaibaatar
0af38b1034 Remove temp table to post autograd IR (#140085)
This table is not needed

Differential Revision: [D64553397](https://our.internmc.facebook.com/intern/diff/D64553397/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140085
Approved by: https://github.com/justinchuby, https://github.com/bdhirsh
2024-11-11 23:59:09 +00:00
Felix Zimmermann
c223e0642c Tighten type hints for tensor arithmetic (#135392)
Fixes #124015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135392
Approved by: https://github.com/ezyang
2024-11-11 23:55:27 +00:00
PyTorch MergeBot
222175b3d5 Revert "[Partitioner] Enumerate partitions by iterating partition ids (#136598)"
This reverts commit 2ede4c9a38.

Reverted https://github.com/pytorch/pytorch/pull/136598 on behalf of https://github.com/kit1980 due to breaking internal ExecuTorch tests ([comment](https://github.com/pytorch/pytorch/pull/136598#issuecomment-2469294995))
2024-11-11 23:42:51 +00:00
PyTorch MergeBot
412df50454 Revert "[dynamo] Remove dead code path for capturing __class__ in UserFunctionVariable (#140034)"
This reverts commit de40a23f6c.

Reverted https://github.com/pytorch/pytorch/pull/140034 on behalf of https://github.com/kit1980 due to breaking internal tests, see D65755044 ([comment](https://github.com/pytorch/pytorch/pull/140034#issuecomment-2469290205))
2024-11-11 23:38:00 +00:00
Animesh Jain
5eb1ccadc2 [dynamo][user-defined] Walk __mro__ to get the member descriptor source (#140300)
Fixes https://github.com/pytorch/pytorch/issues/140266

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140300
Approved by: https://github.com/williamwen42
2024-11-11 23:16:48 +00:00
Animesh Jain
5f7ea7ca6a [invoke_subgraph] Support symint/int as inputs (#140058)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140058
Approved by: https://github.com/ydwu4, https://github.com/eellison
ghstack dependencies: #139162
2024-11-11 22:26:43 +00:00
Xuan Zhang
d4cdc09881 ILP for auto FSDP wrapping (#140298)
This PR presents a mixed integer linear programming (MILP) formulation that can be utilized to determine, under a memory budget, which modules to wrap as FSDP units. Similar to the auto SAC MILP introduced in https://github.com/pytorch/pytorch/pull/137908, the MILP uses information collected from MemTracker, Runtime Estimator, and SAC Estimator, introduced in these PRs:
* https://github.com/pytorch/pytorch/pull/124688
* https://github.com/pytorch/pytorch/pull/134243
* https://github.com/pytorch/pytorch/pull/135208

End-to-end example and its sample output:

```
import copy
from typing import Tuple

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from torch.distributed._tools.ilp_utils import (
    aggregate_stats,
    get_peak_memory_runtime_baseline,
    parse_module_info,
)
from torch.distributed._tools.mem_tracker import _ModState, MemTracker
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.distributed._tools.sac_estimator import SACEstimator
from torch.distributed._tools.fsdp_ilp import fsdp_milp, CommType, CommParams
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)

def _init_model_input_optimizer() -> (
    Tuple[torch.nn.Module, torch.optim.Optimizer, torch.Tensor]
):
    bsz = 2
    model_args = ModelArgs(
        n_layers=6,
        n_heads=12,
        vocab_size=8192,
        max_seq_len=1024,
        dim=6144,
        dropout_p=0.1,
    )
    with torch.device(torch.cuda.current_device()):
        model = Transformer(model_args)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
    inp = torch.randint(
        0,
        model_args.vocab_size,
        (bsz, model_args.max_seq_len),
        device=torch.cuda.current_device(),
    )
    return (model, optimizer, inp)

def _run_and_get_mem_tracker(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inp: torch.Tensor,
) -> MemTracker:
    mem_tracker = MemTracker()
    mem_tracker.track_external(model, optimizer)
    with mem_tracker as mt:
        for iter_idx in range(2):  # running twice to initialize optimizer
            output = model(inp)
            output.sum().backward()
            if iter_idx == 1:
                last_snapshot = mt.get_tracker_snapshot("current")
            optimizer.step()
            optimizer.zero_grad()
            if iter_idx == 0:
                mt.reset_mod_stats()
    assert last_snapshot is not None
    for mod_stats in mem_tracker.memory_tracking.values():
        if _ModState.POST_BW not in mod_stats.snapshots.keys():
            mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
                copy.deepcopy(last_snapshot)
            )
    return mem_tracker

def _run_and_get_runtime_estimator(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inp: torch.Tensor,
) -> RuntimeEstimator:
    def _run_one_step() -> None:
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    # Initializing optimizer states and warm-up
    _run_one_step()

    runtime_estimator = RuntimeEstimator()
    with runtime_estimator(estimate_mode_type="operator-level-cost-model"):
        _run_one_step()  # We use only one iteration for estimation
    return runtime_estimator

def _run_and_get_sac_estimator(
    model: torch.nn.Module,
    inp: torch.Tensor,
) -> SACEstimator:
    sac_estimator = SACEstimator()
    with sac_estimator(estimate_mode_type="operator-level-cost-model"):
        loss = model(inp).sum()
    loss.backward()
    return sac_estimator

def main():
    with FakeTensorMode():
        model, optimizer, inp = _init_model_input_optimizer()
        mem_tracker = _run_and_get_mem_tracker(model, optimizer, inp)
        runtime_estimator = _run_and_get_runtime_estimator(model, optimizer, inp)
        sac_estimator = _run_and_get_sac_estimator(model, inp)
        mod_info = aggregate_stats(
            model,
            mem_tracker,
            runtime_estimator,
            sac_estimator,
            torch.device(torch.cuda.current_device()),
        )
        g = parse_module_info(mod_info)

        peak_mem, compute_time = get_peak_memory_runtime_baseline(g)
        print("=== WITHOUT FSDP ===")
        print(f"peak_mem: {round(peak_mem / 2**30, 2)} GiB")
        print(f"compute_time: {round(compute_time, 2)} ms")

        fsdp_decisions, exposed_comm_time, peak_mem = fsdp_milp(
            g,
            world_size=8,
            memory_budget=15,
            comm_params={
                CommType.ALL_GATHER: CommParams(latency=0.01, bandwidth=2 * 1e8),
                CommType.REDUCE_SCATTER: CommParams(latency=0.01, bandwidth=2 * 1e8),
            },
        )
        print("=== WITH FSDP on 8 ranks ===")
        print(f"fsdp units: {sorted(fsdp_decisions)}")
        print(f"peak_mem: {round(peak_mem / 2**30, 2)} GiB")
        print(f"exposed communication time: {round(exposed_comm_time, 2)} ms")

if __name__ == "__main__":
    main()
```

```
=== WITHOUT FSDP ===
peak_mem: 20.92 GiB
compute_time: 1375.49 ms
=== WITH FSDP on 8 ranks ===
fsdp units: ['Transformer', 'Transformer.layers.0.attention.wk', 'Transformer.layers.0.attention.wo', 'Transformer.layers.0.attention.wq', 'Transformer.layers.0.attention.wv', 'Transformer.layers.0.feed_forward.w1', 'Transformer.layers.0.feed_forward.w2', 'Transformer.layers.1', 'Transformer.layers.2', 'Transformer.layers.3', 'Transformer.layers.4', 'Transformer.layers.5', 'Transformer.output', 'Transformer.pos_embeddings']
peak_mem: 13.63 GiB
exposed communication time: 1.02 ms
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140298
Approved by: https://github.com/weifengpy
2024-11-11 22:02:39 +00:00
Bin Bao
2c77352fe2 [AOTI][refactor] Clean up call chain in wrapper codegen (#136531)
Summary: For cpp wrapper, generate_kernel_call and define_kernel need to handle both cpu and gpu kernels. Refactor the code to remove nested super() calls.

Differential Revision: [D65639095](https://our.internmc.facebook.com/intern/diff/D65639095)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136531
Approved by: https://github.com/frank-wei
2024-11-11 22:00:42 +00:00
Justin Chu
780b28f67e [ONNX] Update docstring typo in building (#140281)
The oprecorder docstring mistakenly referred to TorchScript when it should say ONNX IR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140281
Approved by: https://github.com/titaiwangms
2024-11-11 21:01:27 +00:00
Rachel Guo
10e40dd5ca [aoti][tooling] Add support to debug printing for all AOTI model run input args (#140064)
Summary:
Add debug printing around: `void AOTInductorModel::run_impl()`

Example:
```
void AOTInductorModel::run_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                        // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    DeviceStreamType stream,
    AOTIProxyExecutorHandle proxy_executor
) {

    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 3);
    auto arg0_1 = std::move(inputs[0]);
    auto arg1_1 = std::move(inputs[1]);
    auto arg2_1 = std::move(inputs[2]);
    aoti_torch_print_tensor_handle(arg0_1, "aoti_model_inputs - arg0_1");
    aoti_torch_print_tensor_handle(arg1_1, "aoti_model_inputs - arg1_1");
    aoti_torch_print_tensor_handle(arg2_1, "aoti_model_inputs - arg2_1");
```

Differential Revision: D65616590

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140064
Approved by: https://github.com/chenyang78
2024-11-11 20:10:35 +00:00
Joel Schlosser
e7ec294c10 NJT OpInfo tests v2 (#138370)
This PR updates OpInfo-based tests for NJTs:
* Adds extensive coverage across non-contiguous NJTs (both non-contiguous transposed and non-contiguous with holes)
    * The `_sample_njts()` helper that `sample_input_func`s utilize now produces non-contig NJTs as well
* Utilizes a `SampleInput`-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.
    * I decided on adding `SampleInput`s and utilizing this system over using test parametrization for two reasons:
        * Test perf - adding `SampleInput`s is faster than generating entire new tests
        * Avoiding the possibility of `sample_input_func`s not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility for `SampleInput` generation firmly within each `OpInfo`'s `sample_input_func` means weirdness like this isn't possible
* Improves `SampleInput` naming for a bunch of `sample_input_func`s. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new `_describe_njt()` helper to get a string repr that uniquely defines the type of NJT being passed to the op
* Adds appropriate `XFailRule`s to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixed

```python
# Represents a rule indicating how to xfail a particular test. It allows granularity
# at the device, dtype, op, and individual sample levels. This flexibility allows entire
# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
# test cases across multiple ops.
@dataclass
class XFailRule:
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"
    # function to indicate whether the rule applies; return True if so
    match_fn: Callable[[torch.device, torch.dtype, OpInfo, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""

    def match(self, device, dtype, op, sample) -> bool:
        return self.match_fn(device, dtype, op, sample)
```

Example:
```python
    # Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
    # tensor with 1 in ragged dim.
    XFailRule(
        error_type=RuntimeError,
        error_msg="cannot call binary pointwise function .* with inputs of shapes",
        match_fn=lambda device, dtype, op, sample: (
            isinstance(op, BinaryUfuncInfo)
            and "noncontig_holes" in sample.name
            and "broadcasting 1 over ragged" in sample.name
        ),
        name="binary_noncontig_holes_broadcasting_1_over_ragged",
    ),
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138370
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #140160
2024-11-11 19:35:24 +00:00
Yifu Wang
0a0915fb5e [SymmetricMemory] improve the API for stream_write_value32 (#139934)
This PR updates the binding for `stream_write_value32` to be consistent with `memset32`, which IMO makes more sense for this type of utility:
- Changed the API to take a uint32 tensor as an argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of an object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
2024-11-11 18:49:22 +00:00
PyTorch MergeBot
5f4a21dc58 Revert "[SymmetricMemory] improve the API for stream_write_value32 (#139934)"
This reverts commit 2f3a5a15ef.

Reverted https://github.com/pytorch/pytorch/pull/139934 on behalf of https://github.com/malfet due to Broke distributed tests, see https://github.com/pytorch/pytorch/actions/runs/11770673088/job/32784210441 ([comment](https://github.com/pytorch/pytorch/pull/139934#issuecomment-2468641512))
2024-11-11 17:02:07 +00:00
Richard Zou
04b5b4a94e Add base class for single-subgraph inductor HOPs (#139898)
This PR adds "PrimHOPBase", which is intended to be a base class that
one can extend to create new HOPs that match some criteria:
- they take one subgraph as input, and their semantics are running the
  subgraph on some operands
- the HOP stays alive until Inductor

The motivation is that we are seeing a lot more HOPs (invoke_subgraph,
invoke_quant) that have this property and there can be a lot of shared
code between them.

Future:
- Migrate invoke_subgraph to use this
- There are some TODOs in the code

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139898
Approved by: https://github.com/anijain2305, https://github.com/ydwu4
2024-11-11 16:12:35 +00:00
David Berard
d4b8857e51 [codecache][triton 3.2] hash -> base64 conversion for triton 3.2 (#140190)
In old triton versions, you take the hash of the triton kernel and use it in the filepath for the cached kernel. In Triton 3.2 (after https://github.com/triton-lang/triton/pull/4553), the filepath will use the base-64-encoded representation of the hash in the path.

This PR checks whether the `_base64` function exists in triton, and if so, uses the base-64-encoded representation in the path.
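
A minimal sketch of this kind of feature detection; the helper name comes from the description above, but the module that exposes `_base64` is an assumption (here `triton.runtime.cache`):

```python
# Hedged sketch: fall back to the raw hash when the helper is missing.
try:
    from triton.runtime.cache import _base64  # Triton >= 3.2 (assumed location)
except ImportError:
    _base64 = None  # older Triton encodes the raw hash in the path

def cached_kernel_path_component(kernel_hash: str) -> str:
    # Newer Triton names the cached-kernel directory after the base64-encoded
    # hash; older versions use the hash string itself.
    return _base64(kernel_hash) if _base64 is not None else kernel_hash
```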

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140190
Approved by: https://github.com/ezyang
2024-11-11 15:32:28 +00:00
Sam Larsen
cb15c15157 [logging] Overhaul dynamo_timed and CompilationMetrics logging. (#139849)
Here's the overview:

There's a new contextmanager singleton called MetricsContext. Entering the MetricsContext is how we demarcate the boundary on which we'll create a single CompilationMetrics object, and therefore, a single dynamo_compile log entry. While we're inside the MetricsContext, we can update/set many different metrics. Most importantly: `dynamo_timed` can also update the in-progress MetricsContext. In the proposal here, we tell `dynamo_timed` that we want it to do so by providing the name of the MetricsContext field to increment. There can be many `dynamo_timed` calls in different parts of the code updating different fields. Then when the MetricsContext exits, that's when the logging of everything gathered finally happens. One potential footgun is trying to use `dynamo_timed` when we haven't entered the MetricsContext, but we assert on that problem. Another problem is that we re-enter the context recursively, but we watch for that and do the logging only when the outermost exits.
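
Here is a simplified sketch of that pattern (illustrative names and fields, not the actual torch._dynamo implementation); the specifics of the real change follow below:

```python
import contextlib
import time

class MetricsContext:
    def __init__(self, on_exit):
        self._on_exit = on_exit  # e.g. a record_compilation_metrics-style callback
        self._depth = 0
        self._metrics = {}

    def __enter__(self):
        if self._depth == 0:
            self._metrics = {}  # start a fresh CompilationMetrics-like record
        self._depth += 1
        return self

    def __exit__(self, *exc):
        self._depth -= 1
        if self._depth == 0:  # only the outermost exit logs
            self._on_exit(self._metrics)

    def increment(self, field, value):
        assert self._depth > 0, "dynamo_timed used outside of MetricsContext"
        self._metrics[field] = self._metrics.get(field, 0) + value

_METRICS = MetricsContext(on_exit=print)

@contextlib.contextmanager
def dynamo_timed(field):
    # Time a region and fold the duration into the in-progress metrics.
    start = time.time()
    try:
        yield
    finally:
        _METRICS.increment(field, time.time() - start)

with _METRICS:  # outer compile boundary; recursive re-entry would not log twice
    with dynamo_timed("inductor_compile_time_s"):  # field name is illustrative
        time.sleep(0.01)  # stand-in for compilation work
```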

Some specifics:
* Introduce MetricsContext - a context manager that on exit, records the CompilationMetrics (which also logs to dynamo_compile).
* Completely remove the concept of frame_phase_timing. Instead, update the MetricsContext during compilation, either directly or via dynamo_timed.
* Remove some globals we previously used to accumulate counters to later populate a CompilationMetrics. We use CompilationMetrics set/update/increment APIs instead.
* `record_compilation_metrics` is now called on exit from MetricsContext.
* Populate legacy CompilationMetrics fields right before logging, inside `record_compilation_metrics`.
* Remove the one-off `add_remote_cache_time_saved` helper; capture that timing directly into the MetricsContext.

And specifically, several changes to dynamo_timed:
* "Modernize" the parameters and update all callsites accordingly.
* Move the backwards logging of the CompilationMetrics to the backwards compile location.
* Add a parameter for which CompilationMetrics field to update

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139849
Approved by: https://github.com/ezyang
ghstack dependencies: #140094
2024-11-11 14:24:23 +00:00
Xiaodong Wang
565a7942ee Recover non-standard bool test for msort (#139870)
Summary:
I was looking into why non-standard bool values fail for msort. It makes sense for argsort and sort to fail, because we're randomly generating uint8 values, so the order will be different (and thus the indices will be different). But msort should work.

After some digging, it turns out that even though scalar_t is bool, when the actual value is a uint8_t the comparison treats it as signed. With lhs=255 and rhs=0, lhs < rhs is evaluated as -1 < 0, which is true (but it's supposed to be false).

Therefore we add an explicit type cast.
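
A small standalone illustration of the signedness issue (pure Python, not the ATen code itself), reinterpreting the byte 0xFF the way the old comparator effectively did:

```python
import struct

# A non-standard bool stores the byte 255; compared as a signed 8-bit value
# it becomes -1 and incorrectly sorts before 0.
as_signed = struct.unpack("b", bytes([255]))[0]
print(as_signed)            # -1
print(as_signed < 0)        # True  -- the buggy ordering (255 "less than" 0)
print(bool(255) < bool(0))  # False -- ordering after an explicit bool cast
```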

Test Plan: Remove the test skip

Differential Revision: D65472170

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139870
Approved by: https://github.com/Skylion007, https://github.com/davidberard98
2024-11-11 02:00:34 +00:00
Yifu Wang
2f3a5a15ef [SymmetricMemory] improve the API for stream_write_value32 (#139934)
This PR updates the binding for `stream_write_value32` to be consistent with `memset32`, which IMO makes more sense for this type of utility:
- Changed the API to take a uint32 tensor as an argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of an object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
2024-11-11 01:54:35 +00:00
cyy
ffb979032d [7/N] Fix Wextra-semi warning (#140225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140225
Approved by: https://github.com/ezyang
2024-11-10 14:28:10 +00:00
CaoE
94c9bb73c0 [Inductor] [CPP] Update BRGEMM parameters for Half cpp gemm template (#140116)
Update BRGEMM parameters for the Half cpp gemm template, as the BRGEMM API was changed in https://github.com/pytorch/pytorch/pull/138184.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140116
Approved by: https://github.com/jansel
2024-11-10 06:37:10 +00:00
cyy
7d4f5f7508 [Environment Variable][6/N] Use thread-safe getenv functions (#140200)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140200
Approved by: https://github.com/ezyang
2024-11-09 15:05:51 +00:00
fduwjj
5107d244ee [c10d][Logging] Remove args and kwargs from c10d logging (#140169)
This PR is trying to reland https://github.com/pytorch/pytorch/pull/139804

We no longer want to log args and kwargs directly because, if they contain a tensor or tensor subclass, converting them to strings can take a long time or may not even be supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140169
Approved by: https://github.com/wz337, https://github.com/kwen2501
2024-11-09 13:57:32 +00:00
Yu, Guangye
052b67e2b4 Add torch.version.xpu (#139466)
# Motivation
We add a new attribute `torch.version.xpu` to facilitate problem diagnosis and version control.

# Additional Context
It is aligned with `torch.version.cuda` and `torch.version.hip`.
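
A quick usage sketch; the printed values are illustrative and depend on how PyTorch was built:

```python
import torch

# Each attribute is a version string when PyTorch was built with that backend,
# and None otherwise.
print(torch.version.cuda)  # e.g. "12.4" or None
print(torch.version.hip)   # e.g. "6.2" or None
print(torch.version.xpu)   # the new attribute added by this PR
```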

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139466
Approved by: https://github.com/EikanWang, https://github.com/ezyang, https://github.com/atalman, https://github.com/malfet
ghstack dependencies: #139258
2024-11-09 13:31:21 +00:00
xinan.lin
191971e01d [AOTI] Introduce an extensibility mechanism for the c shim codegen to make it easy to produce c shims for out-of-tree OP kernels as well. Add c_shim for XPU. (#136742)
[AOTI] Introduce an extensibility mechanism for the c shim codegen to make it easy to produce c shims for out-of-tree OP kernels as well. Add c shim for XPU.

### Motivation
The current c shim codegen only produces C wrappers for ops registered in `aten/src/ATen/native/native_functions.yaml`. For the same backend, some out-of-tree ops are not registered in that file but are registered externally instead, for example in `third_party/torch-xpu-ops/yaml/native_functions.yaml`. In that case, the existing codegen cannot extend the already-produced in-tree c shims with those out-of-tree ops.

### Design
To extend the c shim with more ops for a backend from out-of-tree, the PR adds a boolean option `--aoti-extend` to indicate that the codegen is extending the c shim with out-of-tree ops.
The generated c shim is stored in the `extend` subdirectory, for example:
```
torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.cpp
torch/include/torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.h
torch/include/torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.cpp
```
example usage:
`python -m torchgen.gen --source-path third_party/torch-xpu-ops/yaml/ --xpu --aoti-extend --update-aoti-c-shim  `
`--xpu`: generate the c shim for XPU
`--aoti-extend`: extend the in-tree ops (defined in `aten/src/ATen/native/native_functions.yaml`) with out-of-tree ops (defined in `third_party/torch-xpu-ops/yaml/native_functions.yaml`)
`--update-aoti-c-shim`: always generate `c_shim_xpu.h` for the extended c shim.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136742
Approved by: https://github.com/EikanWang, https://github.com/desertfire
ghstack dependencies: #139025
2024-11-09 13:19:52 +00:00
Boyuan Feng
e2e425b4f3 [CUDAGraph] Add dynamo timer to checkpoint, warmup, and record (#139818)
Summary: Add timing logs to cudagraph, including `create deferred_cudagraphify wrapper`, `warmup`, `record`, and `checkpoint`.

Test Plan:
1. buck2 run fbcode//mode/opt //pytorch/benchmark:run -- resnet50 -d cuda -t train --inductor --pt2-triton-cudagraph

2. Found the result in [scuba table](https://fburl.com/scuba/pt2_compile_events/0oik8nu9).

Differential Revision: D65505659

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139818
Approved by: https://github.com/eellison
2024-11-09 05:27:11 +00:00
cyy
ab55a99283 Use TORCH_DECLARE_XXX (#139952)
Because those files use TORCH_API

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139952
Approved by: https://github.com/ezyang
2024-11-09 04:56:28 +00:00
Kefei Lu
d2d1258b1b Speed up AMD AOT Inductor lowering by memoizing hipify trie to regex logic (#140156)
Summary:
AMD lowering takes 1.55x longer than on H100. Profiling shows that hipification-related functions account for 22% of overall lowering time.

This diff cuts that time by safely memoizing the trie-to-regex logic. The trick is to incrementally build a state for the trie as it is constructed; the state is the hash of all the words added to the trie.
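
A hedged sketch of the memoization idea; the real hipify trie has a different API and these names are illustrative. The regex is cached keyed on an incrementally maintained hash of the inserted words, so repeated conversions of an unchanged trie are free:

```python
import hashlib
import re

class MemoizedTrie:
    def __init__(self):
        self.words = []
        self._state = hashlib.sha256()  # incremental hash of all added words
        self._regex_cache = {}          # state digest -> compiled regex

    def add(self, word: str) -> None:
        self.words.append(word)
        self._state.update(word.encode())

    def export_to_regex(self):
        key = self._state.hexdigest()
        if key not in self._regex_cache:
            # Building the pattern is the expensive step; do it once per state.
            alternation = "|".join(
                re.escape(w) for w in sorted(self.words, key=len, reverse=True)
            )
            self._regex_cache[key] = re.compile(alternation)
        return self._regex_cache[key]

trie = MemoizedTrie()
for w in ("cudaMalloc", "cudaFree", "cublasCreate"):
    trie.add(w)
pattern = trie.export_to_regex()          # built once
assert pattern is trie.export_to_regex()  # second call hits the memo
```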

Differential Revision: D65659445

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140156
Approved by: https://github.com/ColinPeppler

Co-authored-by: Kefei Lu <kefeilu@meta.com>
2024-11-09 04:28:58 +00:00
Michael Lazos
8b2e3855a9 Make size a property with an assertion (#139794)
Fixes https://github.com/pytorch/pytorch/issues/120568

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139794
Approved by: https://github.com/williamwen42
2024-11-09 03:39:41 +00:00
cyy
032135f8a2 [2/N] Turn inline static functions into static (#140068)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140068
Approved by: https://github.com/ezyang
2024-11-09 03:31:24 +00:00
Bob Ren
3b8470c461 add special case for __round__ constant variables (#139583)
Fixes `PYTORCH_TEST_WITH_INDUCTOR=1 tlp python test/test_torch.py TestTorchDeviceTypeCUDA.test_cauchy_cuda_float64` when specialize_float=False

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139583
Approved by: https://github.com/ezyang
ghstack dependencies: #139569, #139457, #139568, #139572, #139846, #139454, #139896, #139935, #139587
2024-11-09 03:25:53 +00:00
Florian (Feuermagier)
f915409c26 FlopCounterMode: Decompose ops for inference mode (#138508)
Fixes #126268

I've basically followed @ezyang's suggestion (I think) to use `func.decompose(...)`. Since `__torch_dispatch__` won't be called a second time for the same op, I've added a second `TorchDispatchMode` (`_DecomposedCounterMode`) that simply dispatches to the parent flop counter. Using `self` as the inner context manager is not possible, since the second call to `__enter__` would re-initialize the counter's tracking state.

Let me know if there's something wrong with this implementation, since I'm quite unsure how the decomposition thing actually works :D

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138508
Approved by: https://github.com/ezyang
2024-11-09 03:13:53 +00:00
Bob Ren
4488e23763 Fix another item memo loss location + bool specialization bug (#139587)
This fix was a bit more involved:
1) It fixes an item_memo loss location.
2) It updates a test to use eager instead of aot_eager, since aot_eager reveals a very obscure bug related to replacements that's not worth solving, because in practice Inductor will regenerate the runtime asserts anyway.
3) It updates tensorify to specialize in more places now that the aforementioned bug is fixed.

Fixes `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=6 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCPU.test_comprehensive_linalg_norm_cpu_float16` when `specialize_float=False`

while ensuring `python test/dynamo/test_dynamic_shapes.py DynamicShapesMiscTests.test_runtime_assert_replacement_dynamic_shapes` doesn't regress

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139587
Approved by: https://github.com/ezyang
ghstack dependencies: #139569, #139457, #139568, #139572, #139846, #139454, #139896, #139935
2024-11-09 03:11:19 +00:00
Zhou, Lingzhi
2ede4c9a38 [Partitioner] Enumerate partitions by iterating partition ids (#136598)
Currently, we get all partition ids by iterating `assignment`, whose size is the same as the number of nodes in the graph. But we can reach the same result by iterating `partitions_by_id`, whose size is much smaller than the number of nodes. Assuming the number of nodes is N and the number of partitions is P, the time complexity decreases from O(N * N) to O(N * P) after this patch.
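
An illustrative-only sketch of the difference, using hypothetical `assignment` (node -> partition id) and `partitions_by_id` (partition id -> partition) maps that follow the description above:

```python
# assignment: one entry per node (size N); partitions_by_id: one per partition (size P)
assignment = {"n0": 0, "n1": 0, "n2": 1, "n3": 1, "n4": 2}
partitions_by_id = {0: ["n0", "n1"], 1: ["n2", "n3"], 2: ["n4"]}

# Before: collect partition ids by walking the per-node assignment -- O(N) per query
ids_old = set(assignment.values())

# After: the ids are already the keys of partitions_by_id -- O(P) per query
ids_new = set(partitions_by_id.keys())

# Same result; repeated once per node, the overall cost drops from O(N*N) to O(N*P).
assert ids_old == ids_new
```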

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136598
Approved by: https://github.com/ezyang

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2024-11-09 01:31:46 +00:00
Joel Schlosser
9c678af9f9 Misc. non-contig NJT fixes (#140160)
This PR contains several fixes related to non-contiguous NJTs:
1. Propagates `lengths` through op calls appropriately (see desc of #138098)
    * SDPA now calls `nested_view_from_values_offsets_lengths()` instead of `nested_view_from_values_offsets()`
2. Allows non-contig NJTs in unsqueeze / transpose / select
3. Expands padded dense -> NJT conversion to support non-contig NJTs
4. (unrelated sorry) Updates `split` / `split_with_sizes` to allow for optional `dim`, matching the ATen signature
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140160
Approved by: https://github.com/cpuhrsch
2024-11-09 01:18:26 +00:00
Ryan Guo
de40a23f6c [dynamo] Remove dead code path for capturing __class__ in UserFunctionVariable (#140034)
This was introduced in https://github.com/pytorch/torchdynamo/commit/d0c10341
as limited support for pre-existing cells, since we know `__class__` wouldn't be modified
in most cases. It's no longer needed now that we have much more support for these cells.

Example:
```python
class Foo():
    def __init__(self):
        super().__init__()

print(Foo.__init__.__code__.co_freevars) # ('__class__',)
print(Foo.__init__.__closure__)          # (<cell at 0x1011fb310: type object at 0x10fe185b0>,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140034
Approved by: https://github.com/williamwen42, https://github.com/anijain2305, https://github.com/jansel
ghstack dependencies: #140033
2024-11-09 01:03:24 +00:00
Ryan Guo
0b8652a999 [dynamo] Remove NestedUserFunctionVariable.closure_scope (#140033)
This was no longer needed after https://github.com/pytorch/torchdynamo/commit/663e4d92,
which removed the uses of `closure_scope` but not the field itself.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140033
Approved by: https://github.com/williamwen42, https://github.com/anijain2305, https://github.com/jansel
2024-11-09 01:03:24 +00:00
Peter Steinbach
090b778b8a Clarify meaning of rate parameter in Gamma distribution (#134847)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134847
Approved by: https://github.com/fritzo
2024-11-09 00:22:13 +00:00
PyTorch MergeBot
7eb66173e2 Revert "Fix split decomp returning self (#140065)"
This reverts commit 9d99dceb53.

Reverted https://github.com/pytorch/pytorch/pull/140065 on behalf of https://github.com/ZainRizvi due to Diff been imported internally, but merged externally. And the internal diff has been updated so the diff and PR are now mismatched.  Reverting this PR to get things back into a consistent state. See D65635070 ([comment](https://github.com/pytorch/pytorch/pull/140065#issuecomment-2465928027))
2024-11-09 00:16:26 +00:00
PyTorch MergeBot
1400fedf76 Revert "add supports_coalescing property in c10d::Backend to determine whether backend supports coalescing (#135338)"
This reverts commit e5574445b0.

Reverted https://github.com/pytorch/pytorch/pull/135338 on behalf of https://github.com/ZainRizvi due to Sorry but this is failing internally. Please see D65663382 for more details ([comment](https://github.com/pytorch/pytorch/pull/135338#issuecomment-2465911854))
2024-11-08 23:52:49 +00:00
Michael Lazos
ea0f60ecfa [Dynamo] allow dynamic callables on tensor variables (#137940)
Fixes https://github.com/pytorch/pytorch/issues/134844

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137940
Approved by: https://github.com/williamwen42
2024-11-08 23:49:34 +00:00
PyTorch MergeBot
beae7725be Revert "Tighten type hints for tensor arithmetic (#135392)"
This reverts commit d378819068.

Reverted https://github.com/pytorch/pytorch/pull/135392 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. See D65641103 for more details ([comment](https://github.com/pytorch/pytorch/pull/135392#issuecomment-2465906839))
2024-11-08 23:44:41 +00:00
Haifeng Jin
2af5172774 fix dynamo tracking numpy 2 ops (#138686)
Fixes #136559
With the upgrade to NumPy 2, torch incorrectly filtered out `numpy.random` ops as unsupported in dynamo tracing.
This PR changes the filtering rules to include them while keeping the behavior with NumPy 1 unchanged.

Before this PR, the following tests failed:

```
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_functions.py -k FunctionTests.test_numpy_random
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/dynamo/test_unspec.py -k UnspecTests.test_to_tensor
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k FakeTensorTest.test_export_numpy
PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_fake_tensor.py -k PropagateRealTensorsFakeTensorTest.test_export_numpy_propagate_real_tensors
```

With this PR, the supported/unsupported ops in NumPy 1 are not changed.
For NumPy 2, only the `numpy.random` ops that are already supported with NumPy 1 are added to the supported list.

I used the following script to check the differences before and after the change for both NumPy 1 & 2.
The output is empty for NumPy 1 since there is no change.
For NumPy 2, the output is the list of `numpy.random` ops that are now considered supported.

```py
from torch._dynamo import trace_rules
import numpy as np

def new_numpy_function_ids():
    unsupported_funcs = {"seed", "ranf", "get_bit_generator", "RandomState", "set_bit_generator", "sample"}

    def is_supported(k, v, mod):
        if not callable(v):
            return False
        if not getattr(v, "__module__", None):
            return True
        if v.__module__ == mod.__name__:
            return True
        if v.__module__ == "numpy.random.mtrand" and mod.__name__== "numpy.random" and k not in unsupported_funcs:
            return True
        return False
    rv = {}
    for mod in trace_rules.NP_SUPPORTED_MODULES:
        for k, v in mod.__dict__.items():
            if is_supported(k, v, mod):
                rv[id(v)] = f"{mod.__name__}.{k}"
    return rv

def old_numpy_function_ids():
    rv = {}
    for mod in trace_rules.NP_SUPPORTED_MODULES:
        rv.update(
            {
                id(v): f"{mod.__name__}.{k}"
                for k, v in mod.__dict__.items()
                if callable(v)
                and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
            }
        )
    return rv

rv1 = set(old_numpy_function_ids().values())
rv2 = set(new_numpy_function_ids().values())

for v in (rv1 - rv2):
    print(v)
print("****")
for v in (rv2 - rv1):
    print(v)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138686
Approved by: https://github.com/williamwen42
2024-11-08 23:38:53 +00:00
Yifu Wang
1659e241c8 [experimental] async-tp impl with cutlass-based, progress aware kernel (#139227)
This PR introduces the following:

### torch.ops.symm_mem._async_input_mm

`_async_input_mm(Tensor a, Tensor b, Tensor a_chunk_signals, int a_chunk_pivot) -> Tensor`

An mm impl that supports consuming asynchronous input. It guarantees the following rasterization order, and that the corresponding signal arrives before an input chunk is consumed.
```
num_chunks = a_chunk_signals.numel()
for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
    chunk_idx = chunk_idx % num_chunks
    wait_signal(a_chunk_signals, chunk_idx)
    # Compute output tiles that consume the input chunk
```

### PersistentAsyncInputScheduler

This is a forked version of PersistentScheduler that supports consuming asynchronous input. This tile scheduler introduces the following arguments:

- `tiles_per_chunk_m` – Specifies the size of an M chunk. Chunks are the granularity at which the asynchronous input becomes ready. It must be an integer multiple of the size of an M tile.
- `chunk_signals` – `chunk_signals[i] == 1` indicates that chunk i is ready. Before returning a work tile, get_current_work() waits for the signal to ensure that the corresponding chunk is ready.
- `tile_idx_pivot_m` – After applying swizzling, apply `pivot(m) => (m + tile_idx_pivot_m) % tiles_m` to `m`. In a distributed setting, this allows different ranks to process different m indices at the same time, thus avoiding communication hotspots.

Note that this scheduler currently only supports the `KernelTmaWarpSpecializedCooperative` kernel schedule. This is enforced via the template argument `KernelSchedule`.

Usage:
```
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
   Shape<int, int, int, int>,
   CollectiveMainloop,
   CollectiveEpilogue,
   cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
```

### _fused_all_gather_matmul_native
An ag-mm impl that combines `torch.ops.symm_mem._async_input_mm` and progress-aware all-gather. This is not yet enabled via the async-tp passes. We will use it as a backend to optimize the current decomposition-based async-tp impl.

## Benchmarks

### 4096x3584x8192
- cublas + nccl: 539us
- decomp-based async-tp w/o cuda graph: 694us
- decomp-based async-tp w/ cuda graph: 478us
- new cutlass kernel: 408us

<img width="478" alt="image" src="https://github.com/user-attachments/assets/39f316ab-36c5-4b41-af77-07854a385dfc">

### 2048x3584x8192
- cublas + nccl: 301us
- decomp-based async-tp w/o cuda graph: 687us
- decomp-based async-tp w/ cuda graph: 356us
- new cutlass kernel: 276us

<img width="441" alt="image" src="https://github.com/user-attachments/assets/9e23ce21-863b-43dd-a562-fb05d3a5a144">

## Next Steps
- Add tuning logic
- Use `_fused_all_gather_matmul_native` as a backend for the decomp-based async-tp impl

Differential Revision: [D65623152](https://our.internmc.facebook.com/intern/diff/D65623152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139227
Approved by: https://github.com/weifengpy, https://github.com/Chillee
2024-11-08 23:28:25 +00:00
Natalia Gimelshein
1cdaf1d85f correctly keep track of processed tensors for foreach reductions (#140103)
Fixes #140066

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140103
Approved by: https://github.com/janeyx99

Co-authored-by: Jane Xu <janeyx@meta.com>
2024-11-08 23:04:53 +00:00
Gabriel Ferns
95198f8299 Remove uses of deleted operations (#139447)
resolves: https://github.com/pytorch/pytorch/issues/138721

Summary:

Delete the uses of deleted nodes. The double for-loop is icky here, but N should
be pretty small, and removing it requires refactoring the data structures
involved, which is a bigger endeavor.

Test Plan:

Normal test coverage should be sufficient. There were a couple of spots in the
scheduler code that didn't check whether users had been deleted, so I'll run a perf test to see
what impact that has, and to make sure the N^2 behavior doesn't affect compile times.

Perf:
https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2029%20Oct%202024%2017%3A41%3A36%20GMT&stopTime=Tue%2C%2005%20Nov%202024%2018%3A41%3A36%20GMT&granularity=hour&suite=torchbench&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=exclamaforte/prune-deleted-users&lCommit=5cb1aa6f7d8a52acdae0c7cf36b8c2d536d7f0d1&rBranch=main&rCommit=f4ee5a243dbb31e6310e5632b1c87898b299df2c
off of nov4 nightly

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139447
Approved by: https://github.com/eellison
2024-11-08 22:21:53 +00:00
PyTorch MergeBot
347f96061f Revert "[cpu] Modify inductor opt flag --- ftree-loop-vectorize (#136827)"
This reverts commit cf0bb6c435.

Reverted https://github.com/pytorch/pytorch/pull/136827 on behalf of https://github.com/ZainRizvi due to Sorry but this breaks internally. See D65605094 for more details ([comment](https://github.com/pytorch/pytorch/pull/136827#issuecomment-2465805271))
2024-11-08 21:52:33 +00:00
PyTorch MergeBot
a7724518c0 Revert "[Inductor][CPU] Fuse SmoothQuant int8 linear pattern (#139595)"
This reverts commit d72a308e77.

Reverted https://github.com/pytorch/pytorch/pull/139595 on behalf of https://github.com/ZainRizvi due to Sorry but the newly added tests in test_mkldnn_pattern_matcher.py fail internally. See D65661038 for more details ([comment](https://github.com/pytorch/pytorch/pull/139595#issuecomment-2465797016))
2024-11-08 21:45:52 +00:00
PyTorch MergeBot
80d0356b11 Revert "Make Context to be Device-agnostic Step by Step (2/N) (#136526)"
This reverts commit c03324de2d.

Reverted https://github.com/pytorch/pytorch/pull/136526 on behalf of https://github.com/ZainRizvi due to This fails to build internally. See D65604944 for more details ([comment](https://github.com/pytorch/pytorch/pull/136526#issuecomment-2465790157))
2024-11-08 21:40:10 +00:00
Zain Rizvi
411203e7c1 Revert D65490202 (#140142)
Summary:
This diff reverts D65490202
D65490202 is causing tests to fail in open source CI. See distributed/test_c10d_logger.py::C10dErrorLoggerTest::test_exception_logger [GH job link](https://github.com/pytorch/pytorch/actions/runs/11736922614/job/32697709457) [HUD commit link](ba9645f6e5)

Test Plan: NA

Differential Revision: D65663063

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140142
Approved by: https://github.com/malfet, https://github.com/huydhn
2024-11-08 21:22:32 +00:00