Commit Graph

165 Commits

Author SHA1 Message Date
IvanKobzarev
a37afd23fa [custom_ops][perf] Move expensive pytree traversals of tensors to C++ (#148555)
(benchmark for 1 call)

Before:
```
└─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py
DO_BENCH mutate: 77.72445678710938 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json
DO_BENCH no_mutate: 64.61143493652344 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json
DO_BENCH direct_mutate: 11.682510375976562 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json
DO_BENCH direct_no_mutate: 18.596649169921875 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json
```

After:
```
└─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py
DO_BENCH mutate: 47.6837158203125 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json
DO_BENCH no_mutate: 31.709671020507812 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json
DO_BENCH direct_mutate: 10.967254638671875 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json
DO_BENCH direct_no_mutate: 10.728836059570312 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148555
Approved by: https://github.com/zou3519
2025-04-01 18:45:48 +00:00
PyTorch MergeBot
d256b2dcb2 Revert "[custom_ops][perf] Move expensive pytree traversals of tensors to C++ (#148555)"
This reverts commit d686d04c2f.

Reverted https://github.com/pytorch/pytorch/pull/148555 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/148555#issuecomment-2753283221))
2025-03-26 05:27:52 +00:00
IvanKobzarev
d686d04c2f [custom_ops][perf] Move expensive pytree traversals of tensors to C++ (#148555)
(benchmark for 1 call)

Before:
```
└─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py
DO_BENCH mutate: 77.72445678710938 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json
DO_BENCH no_mutate: 64.61143493652344 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json
DO_BENCH direct_mutate: 11.682510375976562 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json
DO_BENCH direct_no_mutate: 18.596649169921875 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json
```

After:
```
└─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py
DO_BENCH mutate: 47.6837158203125 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json
DO_BENCH no_mutate: 31.709671020507812 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json
DO_BENCH direct_mutate: 10.967254638671875 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json
DO_BENCH direct_no_mutate: 10.728836059570312 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148555
Approved by: https://github.com/zou3519
2025-03-19 17:16:57 +00:00
William Wen
16e202a38e [dynamo] improved graph break messages for some common graph break sites [1/N] (#146525)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146525
Approved by: https://github.com/jansel
2025-02-20 00:08:13 +00:00
rzou
98b5d455fd [opcheck] Improve error reporting; allow atol/rtol overrides (#146488)
This PR improves opcheck to:
1. directly use torch.testing.assert_close (without a msg override).
   This allows it to print the absolute and relative differences and the
   number of mismatched elements.
2. take in an atol/rtol tolerance (for if someone just wants to use
   opcheck in their testing).

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146488
Approved by: https://github.com/williamwen42
2025-02-05 21:25:06 +00:00
Aaron Orenstein
2f24f2eb46 Make sure to evaluate annotation strings in the context of where the prototype was created (#145667)
This was incorrectly evaluating the annotation in the context of infer_schema - make sure to evaluate annotation strings in the context of where the prototype was created instead.

Fixes #145481

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145667
Approved by: https://github.com/zou3519
2025-01-29 00:14:45 +00:00
Yanbo Liang
ec91b7720f [Custom Ops] Add a new API to allow users to register an autocast for the custom op (#145588)
Fixes #137033

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145588
Approved by: https://github.com/zou3519
2025-01-27 19:22:43 +00:00
Aaron Orenstein
a79100ab11 PEP585 update - torch/_dynamo (#145105)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145105
Approved by: https://github.com/bobrenjc93
2025-01-18 20:47:11 +00:00
Edward Z. Yang
fd8b217fcd Pass allow_rhs_unbacked to the stride test in metadata test too (#143040)
Fixes https://github.com/pytorch/pytorch/issues/142410

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143040
Approved by: https://github.com/bobrenjc93
2024-12-19 09:37:50 +00:00
Tom Ritchford
d8c8ba2440 Fix unused Python variables in test/[e-z]* (#136964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964
Approved by: https://github.com/justinchuby, https://github.com/albanD
2024-12-18 23:02:30 +00:00
zeshengzong
cb71bcc542 Replace clone.detach with detach.clone (#140264)
Fixes #64532

As state in issue, replace `clone.detach` by `detach.clone`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140264
Approved by: https://github.com/soulitzer
2024-11-13 07:01:02 +00:00
rzou
85c3c4132d no-op torch.library.custom_op APIs on torch.deploy (#139509)
We forgot this case in the previous PR. Fixes
https://github.com/pytorch/pytorch/issues/137536

Test Plan:
- better tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139509
Approved by: https://github.com/williamwen42
2024-11-04 18:01:08 +00:00
Simon Fan
99608ceed6 Scoped extension building for C++ backed custom ops tests (#136695)
FIXES #125579 #131103 #133197 #133283 #134738 #135369 #135685

Tests that create C++ extensions can cause flakiness in CI due to library namespace conflict and test ordering. We can build them in temp dirs to ensure isolation.

An alternative is to build these as part of the build process and have build time errors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136695
Approved by: https://github.com/zou3519
2024-10-26 07:41:00 +00:00
rzou
f500cb43bb Fix torch.library.register_vmap (#137306)
We didn't support multiple levels of vmap. The main problem is, during
the batching rule, we need to exclude the vmap dispatch key
(FuncTorchBatched) like how our C++ batching rules do it.

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137306
Approved by: https://github.com/Chillee
2024-10-04 03:46:35 +00:00
rzou
e4d32d2194 Improve data-dependent-output meta kernel error message (#136671)
Test Plan:
- code reading
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136671
Approved by: https://github.com/williamwen42
2024-09-26 03:46:04 +00:00
rzou
d0456b4274 noop on torch.library APIs under torch::deploy (multipy) (#136645)
Fixes https://github.com/pytorch/pytorch/issues/136177

The motivation is that torch::deploy doesn't handle this well. The
workaround for users is to use C++ custom ops.

All torch.library APIs ultimately go through the torch.library.Library
object, so we add checks to noop for torch::deploy there.

Test Plan:
- new test
- going to test this internally and hope nothing breaks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136645
Approved by: https://github.com/ezyang
2024-09-26 02:34:34 +00:00
Aaron Orenstein
8c356ce3da Fix lint errors in fbcode (#135614)
Summary: Fixed a bunch of fbcode imports that happened to work but confused autodeps.  After this autodeps still suggests "improvements" to TARGETS (which breaks our builds) but at least it can find all the imports.

Test Plan:
```
fbpython fbcode/tools/build/buck/linters/lint_autoformat.py --linter=autodeps --default-exec-timeout=1800 -- fbcode/caffe2/TARGETS fbcode/caffe2/test/TARGETS
```
Before:
```
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/testing.py:229) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fbur$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export.py:87) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_serdes.py:9) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fb$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_serdes.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_retraceability.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https:$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_retraceability.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See ht$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_nonstrict.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See http$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_nonstrict.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See $
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:8) when processing rule "test_export". Please make sure it's listed in the srcs parameter of an$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of anoth$
ERROR while processing caffe2/test/TARGETS: Found "//python/typeshed_internal:typeshed_internal_library" owner for "cv2" but it is protected by visibility rules: [] (from caffe2/test/test_bundled_images.py:7) when processing rule "test_bundled_$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "caffe2.test.profiler_test_cpp_thread_lib" (from caffe2/test/profiler/test_cpp_thread.py:29) when processing rule "profiler_test_cpp_thread". Please make sure it's listed in t$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_custom_ops.py:23) when processing rule "custom_ops". Please make sure it's listed in the srcs parameter of anoth$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_public_bindings.py:13) when processing rule "public_bindings". Please make sure it's listed in the srcs paramete$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.symbolize_tracebacks" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another $
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.gather_traceback" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another rule$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for include <torch/csrc/autograd/profiler_kineto.h> (from caffe2/test/profiler/test_cpp_thread.cpp:2) when processing profiler_test_cpp_thread_lib.  Some things to try:
```

Differential Revision: D62049222

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135614
Approved by: https://github.com/oulgen, https://github.com/laithsakka
2024-09-13 02:04:34 +00:00
rzou
f65a564fa2 [inductor] Flip custom_op_default_layout_constraint (#135239)
By default, Inductor should respect the stride order of input Tensors to
custom operators.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135239
Approved by: https://github.com/albanD
ghstack dependencies: #135391
2024-09-10 14:27:43 +00:00
rzou
ad29a2c0dc Add Inductor config for default stride behavior (#135238)
By default, Inductor is allowed to manipulate the layout
(strides+storage offset) of input tensors to custom operators.

We want to change it so that the default is that Inductor should respect
the stride order of input tensors to custom operators.

This PR adds a config to toggle the behavior, in the next PR up we'll
change the default. We also make the following changes:
- We add a new operator Tag (flexible_layout), which means that
inductor is allowed to manipulate the layout. When we flip the default,
users can specify they want the old behavior by using this tag.

This is a reland of https://github.com/pytorch/pytorch/pull/126986,
which was previously reverted due to silent incorrectness. We've since
fixed the silent incorrectness
(https://github.com/pytorch/pytorch/pull/133639)

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135238
Approved by: https://github.com/albanD
2024-09-06 14:48:24 +00:00
Xuehai Pan
4226ed1585 [BE] Format uncategorized Python files with ruff format (#132576)
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #132574
2024-08-04 17:13:31 +00:00
Oguz Ulgen
221350e3a4 Add None return type to init -- tests (#132352)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352
Approved by: https://github.com/ezyang
ghstack dependencies: #132335, #132351
2024-08-01 15:44:51 +00:00
rzou
e393c7fa05 Tighten torch.library.infer_schema input types (#130705)
Made the following changes:
- mutates_args is now keyword-only and mandatory. This is to align with
  torch.library.custom_op (which makes it mandatory because it's easy to
  miss)
- op_name is now keyword-only. This helps the readability of the API
- updated all usages of infer_schema

This change is not BC-breaking because we introduced
torch.library.infer_schema a couple of days ago.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130705
Approved by: https://github.com/yushangdi
ghstack dependencies: #131777
2024-07-29 16:01:19 +00:00
Shangdi Yu
68c725a094 [custom ops] Add register_vmap for custom ops (#130589)
Fixes #130284
Fixes #130653

- Add `torch.library.register_vmap` to custom ops
- Add `register_vmap` for operators in ops in custom_op_db.
- Make `torch.autograd.Function` support kwarg-only kwargs for vmap
- test operators in op_db with `tests/test_vmap`.
- change `test_vmap` to allow custom `out_dim` and allow "None" in `out_dim` when testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130589
Approved by: https://github.com/zou3519
2024-07-23 17:48:38 +00:00
PyTorch MergeBot
b435d84261 Revert "[custom ops] Add register_vmap for custom ops (#130589)"
This reverts commit 074b420641.

Reverted https://github.com/pytorch/pytorch/pull/130589 on behalf of https://github.com/atalman due to Please fix lint and reland ([comment](https://github.com/pytorch/pytorch/pull/130589#issuecomment-2244092174))
2024-07-23 01:44:44 +00:00
Shangdi Yu
074b420641 [custom ops] Add register_vmap for custom ops (#130589)
Fixes #130284
Fixes #130653

- Add `torch.library.register_vmap` to custom ops
- Add `register_vmap` for operators in ops in custom_op_db.
- Make `torch.autograd.Function` support kwarg-only kwargs for vmap
- test operators in op_db with `tests/test_vmap`.
- change `test_vmap` to allow custom `out_dim` and allow "None" in `out_dim` when testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130589
Approved by: https://github.com/zou3519
2024-07-23 00:54:52 +00:00
Xuehai Pan
ba48cf6535 [BE][Easy][6/19] enforce style for empty lines in import segments in test/ (#129757)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129757
Approved by: https://github.com/ezyang
2024-07-17 06:42:37 +00:00
PyTorch MergeBot
68a4f2a3df Revert "Tighten torch.library.infer_schema input types (#130705)"
This reverts commit ca2d424c6e.

Reverted https://github.com/pytorch/pytorch/pull/130705 on behalf of https://github.com/atalman due to Failing internal CI ([comment](https://github.com/pytorch/pytorch/pull/130705#issuecomment-2230821876))
2024-07-16 12:57:11 +00:00
rzou
ca2d424c6e Tighten torch.library.infer_schema input types (#130705)
Made the following changes:
- mutates_args is now keyword-only and mandatory. This is to align with
  torch.library.custom_op (which makes it mandatory because it's easy to
  miss)
- op_name is now keyword-only. This helps the readability of the API
- updated all usages of infer_schema

This change is not BC-breaking because we introduced
torch.library.infer_schema a couple of days ago.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130705
Approved by: https://github.com/yushangdi
2024-07-15 16:43:57 +00:00
rzou
9c69684af8 [custom_ops] expose torch.library.register_torch_dispatch (#130261)
This is the API for defining the interaction between a torch_dispatch
class and a custom op. Taking API bikeshedding.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130261
Approved by: https://github.com/albanD
ghstack dependencies: #130064
2024-07-12 14:13:01 +00:00
rzou
ba941769b5 Add API for open registration between operators and subclasses (and modes) (#130064)
We add torch.library.Library._register_torch_dispatch_rule. Here, a user
can provide us a specific rule to run for a specific
(torch_dispatch_class, operator) pair. The motivation is that a user
might want to extend a subclass/mode but may not have access to the
source code of the subclass/mode.

I'll make this public in a follow-up PR if we think the approach and API
is good.

Keep in mind that many subclasses will likely deliver their own open
registration solution (DTensor has register_sharding_prop_rule and NJT
has register_jagged_op); _register_torch_dispatch_rule is meant as a
catch-all open registration mechanism for when the subclass hasn't
provided anything more specific.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130064
Approved by: https://github.com/albanD
2024-07-12 14:13:01 +00:00
Shangdi Yu
a4576dad34 [reland][custom ops] infer schema (#130079)
Fixes #129617

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130079
Approved by: https://github.com/zou3519
2024-07-11 03:39:07 +00:00
PyTorch MergeBot
ce499eee0c Revert "Add API for open registration between operators and subclasses (and modes) (#130064)"
This reverts commit c23d103afa.

Reverted https://github.com/pytorch/pytorch/pull/130064 on behalf of https://github.com/izaitsevfb due to fails internal builds, see [D59553526](https://www.internalfb.com/diff/D59553526) ([comment](https://github.com/pytorch/pytorch/pull/130064#issuecomment-2221587575))
2024-07-10 21:50:32 +00:00
PyTorch MergeBot
86bca69c5f Revert "[custom_ops] expose torch.library.register_torch_dispatch (#130261)"
This reverts commit bb9a73f767.

Reverted https://github.com/pytorch/pytorch/pull/130261 on behalf of https://github.com/izaitsevfb due to depends on #130064 which needs to be reverted ([comment](https://github.com/pytorch/pytorch/pull/130261#issuecomment-2221569707))
2024-07-10 21:43:28 +00:00
PyTorch MergeBot
e14a0f45ed Revert "[reland][custom ops] infer schema (#130079)"
This reverts commit bef085bdfa.

Reverted https://github.com/pytorch/pytorch/pull/130079 on behalf of https://github.com/izaitsevfb due to depends on #130064 which needs to be reverted ([comment](https://github.com/pytorch/pytorch/pull/130079#issuecomment-2221561483))
2024-07-10 21:40:16 +00:00
Shangdi Yu
bef085bdfa [reland][custom ops] infer schema (#130079)
Fixes #129617

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130079
Approved by: https://github.com/zou3519
2024-07-10 16:18:36 +00:00
rzou
bb9a73f767 [custom_ops] expose torch.library.register_torch_dispatch (#130261)
This is the API for defining the interaction between a torch_dispatch
class and a custom op. Taking API bikeshedding.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130261
Approved by: https://github.com/albanD
ghstack dependencies: #130064
2024-07-09 21:11:27 +00:00
rzou
c23d103afa Add API for open registration between operators and subclasses (and modes) (#130064)
We add torch.library.Library._register_torch_dispatch_rule. Here, a user
can provide us a specific rule to run for a specific
(torch_dispatch_class, operator) pair. The motivation is that a user
might want to extend a subclass/mode but may not have access to the
source code of the subclass/mode.

I'll make this public in a follow-up PR if we think the approach and API
is good.

Keep in mind that many subclasses will likely deliver their own open
registration solution (DTensor has register_sharding_prop_rule and NJT
has register_jagged_op); _register_torch_dispatch_rule is meant as a
catch-all open registration mechanism for when the subclass hasn't
provided anything more specific.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130064
Approved by: https://github.com/albanD
2024-07-09 21:11:27 +00:00
Shangdi Yu
cab90b0049 [custom ops] disable kernel temporarily (#130190)
Fixes #128621

Sometimes we want to disable the backend implementation for testing/benchmarking purposes.

For example:

```python
@custom_op("mylib::f", mutates_args=())
def f(x: Tensor) -> Tensor:
    return torch.zeros(1)

print(f(torch.randn(1))) # tensor([0.])

@f.register_kernel("cpu")
def _(x):
    return torch.ones(1)

print(f(torch.randn(1))). # tensor([1.])

with f.set_kernel_enabled("cpu", enabled = False):
    print(f(0)) # tensor([0.])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130190
Approved by: https://github.com/williamwen42, https://github.com/zou3519
2024-07-09 16:13:50 +00:00
PyTorch MergeBot
d44c30e2f9 Revert "Add API for open registration between operators and subclasses (and modes) (#130064)"
This reverts commit 922d2737d5.

Reverted https://github.com/pytorch/pytorch/pull/130064 on behalf of https://github.com/huydhn due to Sorry for reverting your change but test_profiler_tree is failing in trunk after this lands 922d2737d5, maybe a landrace ([comment](https://github.com/pytorch/pytorch/pull/130064#issuecomment-2216135497))
2024-07-09 01:48:38 +00:00
rzou
922d2737d5 Add API for open registration between operators and subclasses (and modes) (#130064)
We add torch.library.Library._register_torch_dispatch_rule. Here, a user
can provide us a specific rule to run for a specific
(torch_dispatch_class, operator) pair. The motivation is that a user
might want to extend a subclass/mode but may not have access to the
source code of the subclass/mode.

I'll make this public in a follow-up PR if we think the approach and API
is good.

Keep in mind that many subclasses will likely deliver their own open
registration solution (DTensor has register_sharding_prop_rule and NJT
has register_jagged_op); _register_torch_dispatch_rule is meant as a
catch-all open registration mechanism for when the subclass hasn't
provided anything more specific.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130064
Approved by: https://github.com/albanD
2024-07-08 22:13:05 +00:00
PyTorch MergeBot
44a773c121 Revert "[custom ops] infer schema (#130079)"
This reverts commit 3fe324ffb6.

Reverted https://github.com/pytorch/pytorch/pull/130079 on behalf of https://github.com/huydhn due to The test_public_bindings failure looks legit 3fe324ffb6 ([comment](https://github.com/pytorch/pytorch/pull/130079#issuecomment-2215420957))
2024-07-08 22:02:29 +00:00
Shangdi Yu
3fe324ffb6 [custom ops] infer schema (#130079)
Fixes #129617

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130079
Approved by: https://github.com/zou3519
2024-07-08 20:46:23 +00:00
Shangdi Yu
2fe7c1fe04 [custom ops] Support factory function (#129978)
Fixes #129389

If a user registers a device-specific implementation for an operator that accepts no Tensors, then we require the operator to have a "device: torch.device argument"

We switch on the device argument to select the correct backend to dispatch to.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129978
Approved by: https://github.com/zou3519
2024-07-04 00:10:52 +00:00
rzou
872d972e41 [custom_op] better error message on no returns (#129896)
I run into this a lot. I can imagine that it would look opaque to users,
so made it more friendly

Old error message: "ValueError: infer_schema(func): Return has unsupported type <class 'inspect._empty'>."

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129896
Approved by: https://github.com/yushangdi
2024-07-02 23:34:23 +00:00
Shangdi Yu
aa0352ca38 [custom ops] add default value support for device types (#129792)
Fixes #129371

I think the first case in Issue #129371 is already supported in the current code? Since it takes care of string default values. This PR adds support for device type default values.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129792
Approved by: https://github.com/zou3519
2024-07-02 23:31:29 +00:00
Shangdi Yu
9fb2dec7a6 [custom ops] Add unknown arg (#129614)
Fixes #129372

Add a mutated_args="unknown" that pessimistically assumes that all inputs to the operator are being mutates.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129614
Approved by: https://github.com/zou3519
2024-07-02 16:10:14 +00:00
Shangdi Yu
deaab33f3f [custom op] add error message (#129417)
Fixes [#129370](https://github.com/pytorch/pytorch/issues/129370)

Suggest correct a List type annotation when input is in Tuple type. To avoid confusion, we only suggest a type if the type is supported.

Example:
Tuple[int, int] -> List[int]
Tuple[Tensor, Tensor, Optional[Tensor]] -> List[Optional[Tensor]]
Tuple[int, ...] -> List[int]

ValueError: infer_schema(func): Parameter y has unsupported type typing.Tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor]]. Tuple type annotation is not supported. Please try to use a List instead. For example, typing.List[typing.Optional[torch.Tensor]].
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129417
Approved by: https://github.com/zou3519
2024-06-28 01:03:14 +00:00
rzou
856541c701 [custom_op] support default dtype values (#129189)
This PR:
- moves some of the dtype-string utilities into ScalarType.{h, cpp}
- adds a new utility to get a mapping from dtype name to the C++ dtype
- the perser now checks if the string is a dtype name; if it is then it
  pulls the c++ dtype from the mapping.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129189
Approved by: https://github.com/albanD
ghstack dependencies: #129177, #129178, #129179
2024-06-23 00:13:23 +00:00
rzou
5d8e23b49c [custom_op] Support string default values in schema (#129179)
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129179
Approved by: https://github.com/albanD
ghstack dependencies: #129177, #129178
2024-06-21 13:31:40 +00:00
rzou
9972e5f447 Rename impl_abstract to register_fake, part 2/2 (#123938)
This PR renames the implementation details of register_fake to align
more with the new name. It is in its own PR because this is risky
(torch.package sometimes depends on private library functions and
implementation details).

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123938
Approved by: https://github.com/williamwen42
2024-06-14 14:37:24 +00:00