Commit Graph

1058 Commits

Author SHA1 Message Date
albanD
6ffc94fa62 Fix cpp node instance check (#125875)
Mostly visible when calling multi_grad_hook and thus using this to test it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125875
Approved by: https://github.com/jackiexu1992, https://github.com/ezyang
2024-05-11 21:31:12 +00:00
soulitzer
fab5bd5359 [checkpoint] Improve error message when use_reentrant=True is used with .grad() (#125155)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125155
Approved by: https://github.com/albanD
2024-04-29 18:57:35 +00:00
Chen, Zejun
b1984237a0 [Profiler] Unify the device(CUDA, XPU, PrivateUse1) in torch profiler post processing (#123247)
This PR unifies the CUDA, XPU and PrivateUse1 in the torch profiler. Now CUDA, XPU and PrivateUse1 can together use string object `use_device` to distinguish each other and share one device path for calculating kineto time durations and memory statistics for post processing.

#suppress-api-compatibility-check

Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123247
Approved by: https://github.com/aaronenyeshi
2024-04-22 01:26:55 +00:00
Aaron Gokaslan
c5fafe9f48 [BE]: TRY002 - Ban raising vanilla exceptions (#124570)
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime exception, value errors, type errors or some other errors. There are hundreds of instance of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get commiters to rethink what exception type they should raise when they submit a PR.

I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed overtime and our exception typing can be improved.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
2024-04-21 22:26:40 +00:00
Aaron Gokaslan
5a1216bb2e [BE]: Update ruff to 0.4.1 (#124549)
Update ruff to 0.4.1 .
This version fixes a lot false negatives/false positives, is 20-40% faster, and has various other bug fixes.

Below is a before and after table showing the execution time of ruff lint and ruff format in milliseconds courtesy of https://astral.sh/blog/ruff-v0.4.0

| Repository                                         | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7         | 251.8         | 351.1            | 274.9            |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
2024-04-21 14:06:23 +00:00
PyTorch MergeBot
520bc1080e Revert "[Profiler] Unify the device(CUDA, XPU, PrivateUse1) in torch profiler post processing (#123247)"
This reverts commit 768ce2cdda.

Reverted https://github.com/pytorch/pytorch/pull/123247 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/123247#issuecomment-2066152611))
2024-04-19 09:09:03 +00:00
Chen, Zejun
768ce2cdda [Profiler] Unify the device(CUDA, XPU, PrivateUse1) in torch profiler post processing (#123247)
This PR unifies the CUDA, XPU and PrivateUse1 in the torch profiler. Now CUDA, XPU and PrivateUse1 can together use string object `use_device` to distinguish each other and share one device path for calculating kineto time durations and memory statistics for post processing.

#suppress-api-compatibility-check

Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123247
Approved by: https://github.com/aaronenyeshi, https://github.com/gujinghui
2024-04-19 03:31:13 +00:00
Yuanhao Ji
21f7cbdc1c Enable UFMT on test/test_autograd.py (#124141)
Part of: #123062

Ran lintrunner on:

- `test/test_autograd.py`

Detail:

```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124141
Approved by: https://github.com/soulitzer
2024-04-18 00:16:23 +00:00
rzou
3d2d7ba19d Delete torch.autograd.function.traceable APIs (#122817)
We deprecated them in 2.3 with plans to delete in 2.4. Very few OSS
repos use this flag at all and it also does nothing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122817
Approved by: https://github.com/albanD
2024-03-28 18:24:15 +00:00
Joel Schlosser
cd6bfc7965 Proper view support for jagged layout NestedTensor (#113279)
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
2024-03-22 02:12:36 +00:00
PyTorch MergeBot
224beecee6 Revert "Proper view support for jagged layout NestedTensor (#113279)"
This reverts commit 5855c490f0.

Reverted https://github.com/pytorch/pytorch/pull/113279 on behalf of https://github.com/jbschlosser due to Need to fix BC thing ([comment](https://github.com/pytorch/pytorch/pull/113279#issuecomment-2013899762))
2024-03-21 22:03:01 +00:00
Joel Schlosser
5855c490f0 Proper view support for jagged layout NestedTensor (#113279)
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
2024-03-20 23:45:34 +00:00
albanD
6791b0c09e Change default torch_function behavior to be disabled when torch_dispatch is defined (take 2) (#120632)
This does not introduce a new test but is tested by checking that all the classes we already have still behave as before now that they don't explicitly disable torch_function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
2024-03-09 01:08:37 +00:00
rzou
b52e0bf131 Deprecate torch.autograd.function.traceable, is_traceable (#121413)
- There are no usages of this internally.
- There are very few usages of this in OSS (most of these are forks of old
repositories).
- This flag doesn't do anything.

We're deprecating it to prevent confusion. I will delete it immediately
after the branch cut.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121413
Approved by: https://github.com/albanD, https://github.com/soulitzer
2024-03-08 18:41:07 +00:00
Yu, Guangye
c2b2e57032 Intel GPU Runtime Upstreaming for Guard (#118523)
# Motivation
According to [[RFC] Intel GPU Runtime Upstreaming](https://github.com/pytorch/pytorch/issues/114842), the 5th runtime component we would like to upstream is `Guard`. We will cover device guard and stream guard in this PR.

# Design
Device guard is used mainly for op dispatcher in PyTorch. Currently, PyTorch already has a device guard abstraction `c10::impl::DeviceGuardImplInterface`. In our design, we will introduce an `XPUGuardImpl` class inherits from `c10::impl::DeviceGuardImplInterface`. Register `XPUGuardImpl` to PyTorch after we implement the device switch management mechanism in `XPUGuardImpl`. Besides, we will introduce `XPUGuard`, `OptionalXPUGuard`, `XPUStreamGuard`, and `OptionalXPUStreamGuard`. They are all following the design of CUDA's counterpart. The corresponding C++ file should be placed in c10/xpu/ folder.

# Additional Context
It is unnecessary to add `Guard` code to PyTorch frontend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118523
Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/jgong5, https://github.com/malfet
ghstack dependencies: #120315
2024-02-22 14:07:21 +00:00
soulitzer
450339ab2d Test for fatal signal in test_pynode_destruction_deadlock (#120279)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120279
Approved by: https://github.com/albanD
2024-02-21 21:53:51 +00:00
Joel Schlosser
9ec8dd2467 Reify view_func() closures as ViewFuncs (#118404)
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.

```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() {}
  /// Returns any SymInts in the saved state.
  virtual std::vector<c10::SymInt> get_symints() const { return {}; }
  /// Returns the number of SymInts in the saved state.
  virtual size_t num_symints() const { return 0; }
  /// Returns any tensors in the saved state.
  virtual std::vector<at::Tensor> get_tensors() const { return {}; }
  /// Returns the number of tensors in the saved state.
  virtual size_t num_tensors() const { return 0; }
  /// Reapplies the view on the given base using the saved state.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;
  /// Returns a clone of this ViewFunc, optionally with the specified saved state.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;

protected:
  /// Sets the values of any SymInts in the saved state. The input vector size must
  /// match the number of SymInts in the saved state (i.e. the size of the list
  /// returned by get_symints()).
  virtual void set_symints(std::vector<c10::SymInt>) {}
  /// Sets the values of any Tensors in the saved state. The input vector size must
  /// match the number of Tensors in the saved state (i.e. the size of the list
  /// returned by get_tensors()).
  virtual void set_tensors(std::vector<at::Tensor>) {}
};
```

New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`

The templates for these also contains impls for `ChainedViewFunc` and `ErroringViewFunc` which are used in a few places within autograd.

Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
  SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
  {};
  virtual ~SliceTensorViewFunc() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  int64_t dim;
  c10::optional<c10::SymInt> start;
  c10::optional<c10::SymInt> end;
  c10::SymInt step;
};
...

// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
  ::std::vector<c10::SymInt> symints;
  symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
  if(start.has_value()) symints.insert(symints.end(), *(start));
  if(end.has_value()) symints.insert(symints.end(), *(end));
  symints.push_back(step);
  return symints;
}

size_t SliceTensorViewFunc::num_symints() const {
  return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}

void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
  TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
  auto i = 0;
  if(start.has_value()) start = symints[i];
  i += (start.has_value() ? 1 : 0);
  if(end.has_value()) end = symints[i];
  i += (end.has_value() ? 1 : 0);
  step = symints[i];
}

std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
  ::std::vector<at::Tensor> tensors;
  return tensors;
}

size_t SliceTensorViewFunc::num_tensors() const {
  return static_cast<size_t>(0);
}

void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
  TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());

}

at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
  return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}

std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
  if (symints.has_value()) {
    output->set_symints(std::move(*(symints)));
  }
  if (tensors.has_value()) {
    output->set_tensors(std::move(*(tensors)));
  }
  return output;
}
```

The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.

For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
2024-02-14 22:00:43 +00:00
PyTorch MergeBot
24bdd03d23 Revert "Reify view_func() closures as ViewFuncs (#118404)"
This reverts commit d5a6762263.

Reverted https://github.com/pytorch/pytorch/pull/118404 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/118404#issuecomment-1938600260))
2024-02-12 12:38:51 +00:00
Joel Schlosser
d5a6762263 Reify view_func() closures as ViewFuncs (#118404)
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.

```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() {}
  /// Returns any SymInts in the saved state.
  virtual std::vector<c10::SymInt> get_symints() const { return {}; }
  /// Returns the number of SymInts in the saved state.
  virtual size_t num_symints() const { return 0; }
  /// Returns any tensors in the saved state.
  virtual std::vector<at::Tensor> get_tensors() const { return {}; }
  /// Returns the number of tensors in the saved state.
  virtual size_t num_tensors() const { return 0; }
  /// Reapplies the view on the given base using the saved state.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;
  /// Returns a clone of this ViewFunc, optionally with the specified saved state.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;

protected:
  /// Sets the values of any SymInts in the saved state. The input vector size must
  /// match the number of SymInts in the saved state (i.e. the size of the list
  /// returned by get_symints()).
  virtual void set_symints(std::vector<c10::SymInt>) {}
  /// Sets the values of any Tensors in the saved state. The input vector size must
  /// match the number of Tensors in the saved state (i.e. the size of the list
  /// returned by get_tensors()).
  virtual void set_tensors(std::vector<at::Tensor>) {}
};
```

New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`

The templates for these also contains impls for `ChainedViewFunc` and `ErroringViewFunc` which are used in a few places within autograd.

Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
  SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
  {};
  virtual ~SliceTensorViewFunc() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  int64_t dim;
  c10::optional<c10::SymInt> start;
  c10::optional<c10::SymInt> end;
  c10::SymInt step;
};
...

// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
  ::std::vector<c10::SymInt> symints;
  symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
  if(start.has_value()) symints.insert(symints.end(), *(start));
  if(end.has_value()) symints.insert(symints.end(), *(end));
  symints.push_back(step);
  return symints;
}

size_t SliceTensorViewFunc::num_symints() const {
  return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}

void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
  TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
  auto i = 0;
  if(start.has_value()) start = symints[i];
  i += (start.has_value() ? 1 : 0);
  if(end.has_value()) end = symints[i];
  i += (end.has_value() ? 1 : 0);
  step = symints[i];
}

std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
  ::std::vector<at::Tensor> tensors;
  return tensors;
}

size_t SliceTensorViewFunc::num_tensors() const {
  return static_cast<size_t>(0);
}

void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
  TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());

}

at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
  return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}

std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
  if (symints.has_value()) {
    output->set_symints(std::move(*(symints)));
  }
  if (tensors.has_value()) {
    output->set_tensors(std::move(*(tensors)));
  }
  return output;
}
```

The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.

For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
2024-02-09 18:51:36 +00:00
Andrew Gu
d6b556bd98 Added "any" mode to register_multi_grad_hook (#117984)
This is a re-open of https://github.com/pytorch/pytorch/pull/115628/. This PR adds an `"any"` option to `register_multi_grad_hook` that runs the hook when the gradient of _any_ of the input tensors is computed. The existing functionality is folded under the default `"all"` mode.

The multi-threaded test case is based on the existing one for `register_multi_grad_hook`. I would appreciate a closer look on that. ~~I am not sure about the hook signature (i.e. why we see two gradients in the hook that runs instead of just one, as [`register_hook`](https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html) docs suggest).~~ It was because I was iterating over the 2 elements in the single tensor 😢 .

I did not update the `notes/autograd.rst`, which currently has a [blurb](https://pytorch.org/docs/stable/notes/autograd.html#special-hooks) on `register_multi_grad_hook`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117984
Approved by: https://github.com/soulitzer
ghstack dependencies: #117994, #118186
2024-01-25 16:25:52 +00:00
soulitzer
67300a11cb Support custom autograd Function forward AD return non-Tensor in forward (#118234)
Fixes https://github.com/pytorch/pytorch/issues/117491

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118234
Approved by: https://github.com/albanD
ghstack dependencies: #117552
2024-01-25 03:24:29 +00:00
soulitzer
5b819d9ef0 Properly move retains_grad hook on in-place over view for base (#117552)
Fixes https://github.com/pytorch/pytorch/issues/117366
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117552
Approved by: https://github.com/albanD
2024-01-25 00:27:13 +00:00
rzou
db1a6eda9e [codemod] markDynamoStrictTest batch 22 (#117729)
[codemod] markDynamoStrictTest test_autograd
[codemod] markDynamoStrictTest test_ao_sparsity
[codemod] markDynamoStrictTest test_jit
[codemod] markDynamoStrictTest test_quantization
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117729
Approved by: https://github.com/bdhirsh
2024-01-18 16:59:26 +00:00
soulitzer
5866284d4a Make not passing use_reentrant back to warning instead of erroring and clarify docs (#116710)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116710
Approved by: https://github.com/albanD
ghstack dependencies: #116523
2024-01-09 20:58:49 +00:00
Joel Schlosser
3c21264c9b Introduce reverse view_funcs (#115894)
Part 2 of implementation for general [subclass view fake-ification](https://docs.google.com/document/d/1C5taWiplmX7nKiURXDOAZG2W5VNJ2iV0fQFq92H0Cxw).

Details:
* Codegen `rev_view_func()` alongside `view_func()`
    * Reverse view_func gives you a "base" from a "view": `rev_view_func(new_view) -> new_base` AKA it plays the original view backwards
* Utilizes the functional inverses defined in `FunctionalInverses.cpp`, passing `InverseReturnMode::AlwaysView`
* Manually implements functional inverses for `narrow()` and `chunk()`
* **NB: Multi-output views now set view_func() / rev_view_func() for each of the output views!**
    * Due to this, the `as_view()` overload that operates on a list of views is scrapped in favor of iteration via codegen

Example codegen in `ADInplaceOrViewTypeN.cpp`:
```cpp
at::Tensor narrow(c10::DispatchKeySet ks, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) {
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::_ops::narrow::redispatch(ks & c10::after_ADInplaceOrView_keyset, self, dim, start, length);
  })();
  std::function<at::Tensor(const at::Tensor&)> func=nullptr;
  std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
  if (false || !self.unsafeGetTensorImpl()->support_as_strided() ||
      c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
    func = [=](const at::Tensor& input_base) {
      return at::_ops::narrow::call(input_base, dim, start, length);
    };
    rev_func = [=](const at::Tensor& input_view) {
      // NB: args from narrow() signature are passed along to the inverse
      return at::functionalization::FunctionalInverses::narrow_copy_inverse(self, input_view, at::functionalization::InverseReturnMode::AlwaysView, dim, start, length);
    };
  }
  auto result = as_view(/* base */ self, /* output */ _tmp, /* is_bw_differentiable */ true, /* is_fw_differentiable */ true, /* view_func */ func, /* rev_view_func */ rev_func, /* creation_meta */ InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT : CreationMeta::NO_GRAD_MODE));
  return result;
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115894
Approved by: https://github.com/soulitzer
2024-01-05 16:48:12 +00:00
Aaron Gokaslan
3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parenthesis i.e.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):`->`if x > y or y < z:`
  - And `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
soulitzer
4d6a1ad400 Activation checkpoint and checkpoint_sequential errors if use_reentrant not passed explicitly (#115868)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115868
Approved by: https://github.com/albanD
ghstack dependencies: #115438
2023-12-20 15:23:44 +00:00
soulitzer
cfb3cd11c1 Add basic autograd TORCH_LOGS support (#115438)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115438
Approved by: https://github.com/albanD
2023-12-20 15:23:44 +00:00
Aaron Gokaslan
794545c11f [BE]: Enable RUF015 codebase wide (#115507)
Constant time access of first value in collection. This is a constant time operation instead of converting the item to a list to get the first item which is linear. The rule is turned on which automatically autofixes and enforces this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115507
Approved by: https://github.com/malfet
2023-12-11 15:51:01 +00:00
Xuehai Pan
55064a4ef9 [BE] add parentheses to kwargs unpacking func(*args, **(kwargs or {})) (#115026)
This PR adds parentheses to kwargs unpacking `func(*args, **(kwargs or {}))` for better code readability.

With/without the parentheses are semantic equivalent because they produce the same bytecode.

```console
$ echo "func(*args, **kwargs or {})" | python3 -m dis -
  0           0 RESUME                   0

  1           2 PUSH_NULL
              4 LOAD_NAME                0 (func)
              6 LOAD_NAME                1 (args)
              8 BUILD_MAP                0
             10 LOAD_NAME                2 (kwargs)
             12 JUMP_IF_TRUE_OR_POP      1 (to 16)
             14 BUILD_MAP                0
        >>   16 DICT_MERGE               1
             18 CALL_FUNCTION_EX         1
             20 POP_TOP
             22 LOAD_CONST               0 (None)
             24 RETURN_VALUE

$ echo "func(*args, **(kwargs or {}))" | python3 -m dis -
  0           0 RESUME                   0

  1           2 PUSH_NULL
              4 LOAD_NAME                0 (func)
              6 LOAD_NAME                1 (args)
              8 BUILD_MAP                0
             10 LOAD_NAME                2 (kwargs)
             12 JUMP_IF_TRUE_OR_POP      1 (to 16)
             14 BUILD_MAP                0
        >>   16 DICT_MERGE               1
             18 CALL_FUNCTION_EX         1
             20 POP_TOP
             22 LOAD_CONST               0 (None)
             24 RETURN_VALUE
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115026
Approved by: https://github.com/Skylion007
2023-12-03 20:03:26 +00:00
Tobias Ringwald
b6df841460 Fixed an issue where a user-specified default device clashed with the… (#114560)
… device placement of the RNG. This PR now ignores the user-specified default device, allocates the tensor on the CPU and then moves the tensor to the device of the input tensor. This was more or less already the standard procedure in case the default device wasn't set.

Fixes #114536.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114560
Approved by: https://github.com/soulitzer
2023-11-29 17:45:49 +00:00
Ying Liu
85b97605ab Enable set sequence nr (#114120)
Summary:
In some cases (especially those involving collective calls) - we would want to always kick off a collective call first before running going down another path.

For  example:

```
tbe lookup -> a2a ->
                     overarch
dense ------------->
```

if the forward code is written as
a2a_out = a2a
dense = dense_net
out = overarch(a2a_out, dense)
out.backward()

The current default is running backwards in the opposite order the forward is called. However, there is no data dependency between a2a and dense, so in reality either of them could be run first. We would like the a2a to run first because it provides optimal (on average) overlap.

Changing the seq_nr of a2a_out to something large enough would allow autograd engine to kick it off first.

Test Plan: Tests incoming

Differential Revision: D51445261

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114120
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-11-21 19:47:28 +00:00
soulitzer
c1d9d4a2b5 checkpoint_sequential warns if use_reentrant not passed explicitly (#114158)
Use warning text for deprecation message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114158
Approved by: https://github.com/albanD
2023-11-20 23:08:44 +00:00
soulitzer
c435b8c10a Fix autograd engine callback error propagation from device thread (#113702)
The existing try-catch doesn't work because it doesn't call err.persist(). This is in contrast to the try-catch for evaluate_function which does work because it calls into python_engine's thread_on_exception which calls persist.

Calling persist on a python_error stashes the PyErr state from the thread-local PyThreadState onto the python_error object, so that when this error object is stored onto the future and passed back to the calling cpu thread, python_engine's execute try-catch can then err.restore() the error state. Finally, the python_engine's execute would re-raise so that this is re-caught by the HANDLE_TH_ERRORS macro.

Fixes https://github.com/pytorch/pytorch/issues/75750

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113702
Approved by: https://github.com/albanD
2023-11-17 20:17:02 +00:00
soulitzer
3e3c6cc05e Do not error when printing view created in no-grad modified in-place in no-grad (#113716)
Fixes https://github.com/pytorch/pytorch/issues/99968

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113716
Approved by: https://github.com/albanD
2023-11-16 18:57:56 +00:00
Jon Chuang
5ccd22502f [contextlib] Wrapping a function with set_grad_enabled will consume its global mutation (#113359)
Fixes https://github.com/pytorch/pytorch/issues/113298

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113359
Approved by: https://github.com/soulitzer, https://github.com/jansel
2023-11-09 19:16:20 +00:00
PyTorch MergeBot
b0087b4cf7 Revert "record_function: remove legacy internal operators (#72303)"
This reverts commit 0be84bb41e.

Reverted https://github.com/pytorch/pytorch/pull/72303 on behalf of https://github.com/izaitsevfb due to Apparently _record_function_enter is still used internally at Meta in several places and in lots of internal tests. ([comment](https://github.com/pytorch/pytorch/pull/72303#issuecomment-1777942975))
2023-10-24 20:01:14 +00:00
Peter Bell
0be84bb41e record_function: remove legacy internal operators (#72303)
These operators have not been used since #76420 but were preserved for TorchScript backward compatibility

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72303
Approved by: https://github.com/albanD
ghstack dependencies: #104535
2023-10-23 22:55:05 +00:00
albanD
5e8be63e99 Allow specifiying inputs as GradientEdge in autograd APIs (#110867)
This can be useful for advanced users (like AOTAutograd) who don't want to keep the corresponding Tensor alive (for memory reasons for example) or when inplace op will change the Tensor's grad_fn (but gradients wrt to the original value is needed).

I went minimal API change but open to suggestions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110867
Approved by: https://github.com/soulitzer
2023-10-12 04:08:44 +00:00
soulitzer
73f4c1a406 [reland2] Update custom Function preserve torch function when inputs … (#110895)
…returned as-is

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110895
Approved by: https://github.com/albanD
2023-10-11 21:37:19 +00:00
soulitzer
c9eb8d8d90 Add set_checkpoint_debug_enabled that overrides local setting (#110728)
People access activation checkpoint through many layers of config and it is not always guaranteed that all the layers of wrapping around checkpoint properly propagate all the kwargs, e.g. debug mode. This context manager offers an alternative way to enable debug mode that bypasses the need for all layers to propagate kwargs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110728
Approved by: https://github.com/albanD
ghstack dependencies: #110673, #110674, #110675, #110676
2023-10-11 02:12:31 +00:00
PyTorch MergeBot
d1c157c598 Revert "[reland] Update custom Function preserve torch function when inputs r… (#110679)"
This reverts commit 563728f61c.

Reverted https://github.com/pytorch/pytorch/pull/110679 on behalf of https://github.com/kit1980 due to The diff has Meta-internal changes, please land from Phabricator ([comment](https://github.com/pytorch/pytorch/pull/110679#issuecomment-1753523182))
2023-10-09 19:09:01 +00:00
soulitzer
563728f61c [reland] Update custom Function preserve torch function when inputs r… (#110679)
…eturned as-is

reland of https://github.com/pytorch/pytorch/pull/109825#issuecomment-1749803837

Opening this without ghstack to do codev. In our PR, we changed the signature of `_wrap_outputs`. There is some internal code that calls `_wrap_outputs` directly, so we also need to update that callsite.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110679
Approved by: https://github.com/albanD
2023-10-07 00:27:45 +00:00
PyTorch MergeBot
236afe73a2 Revert "Update custom Function preserve torch function when inputs returned as-is (#109825)"
This reverts commit 4e73eee93f.

Reverted https://github.com/pytorch/pytorch/pull/109825 on behalf of https://github.com/PaliC due to causing a plethora of internal failures ([comment](https://github.com/pytorch/pytorch/pull/109825#issuecomment-1749802739))
2023-10-05 23:49:41 +00:00
soulitzer
4e73eee93f Update custom Function preserve torch function when inputs returned as-is (#109825)
Fixes https://github.com/pytorch/pytorch/issues/109805
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109825
Approved by: https://github.com/albanD
2023-10-04 22:45:11 +00:00
FFFrog
70f2adaec3 Setup_context does not contain default values of forward() (#108561)
Fixes #108529

As the title shown.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108561
Approved by: https://github.com/soulitzer
2023-09-19 16:23:52 +00:00
soulitzer
3efc1882e8 Update CopySlices to not internal assert when grad_output is undefined (#108353)
Fixes https://github.com/pytorch/pytorch/issues/107928

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108353
Approved by: https://github.com/albanD
ghstack dependencies: #107296, #107349
2023-09-11 16:26:05 +00:00
Jane Xu
6e71ad0509 Add tensor post accumulate grad hook API (#107063)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107063
Approved by: https://github.com/albanD, https://github.com/soulitzer
2023-08-24 00:19:35 +00:00
PyTorch MergeBot
432fce4e0d Revert "Add tensor post accumulate grad hook API (#107063)"
This reverts commit 3f655277d4.

Reverted https://github.com/pytorch/pytorch/pull/107063 on behalf of https://github.com/ZainRizvi due to Diff train weirdness. Need to temporarily revert this PR and will right land it soon afterwards ([comment](https://github.com/pytorch/pytorch/pull/107063#issuecomment-1690799057))
2023-08-24 00:12:34 +00:00
Jane Xu
3f655277d4 Add tensor post accumulate grad hook API (#107063)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107063
Approved by: https://github.com/albanD, https://github.com/soulitzer
2023-08-22 15:15:57 +00:00