Commit Graph

132 Commits

Author SHA1 Message Date
Gufan Yin
5d963474aa Replace enforce_dtype with dtype in ShardedTensor.gather (#110561)
Summary:
Sometimes `local_shards` is empty on some ranks while `out.dtype` is float16, which causes an error when `enforce_dtype` is True because `data` will be float32.

Callers know best what dtype they want, so we can just let callers decide.

Temporarily keep `enforce_dtype` for backward compatibility.
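
A minimal usage sketch (the helper is hypothetical; it assumes the `gather(dst, out, dtype=...)` signature this PR adds):

```python
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor

# Hypothetical helper: the caller picks the gather dtype up front, so ranks
# whose local_shards are empty still agree with the float16 output buffer.
def gather_fp16(st: ShardedTensor, dst: int = 0) -> Optional[torch.Tensor]:
    out = None
    if dist.get_rank() == dst:
        out = torch.empty(st.size(), dtype=torch.float16)
    st.gather(dst=dst, out=out, dtype=torch.float16)  # dtype arg added in this PR
    return out
```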

Test Plan: Run local and MAST job

Reviewed By: uciyc123

Differential Revision: D46886551

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110561
Approved by: https://github.com/wanchaol, https://github.com/malfet
2023-10-05 23:16:23 +00:00
Fabrice Pont
053367b1ed fix: flake8-bugbear code B024 (#107265)
See #106571 item B024

This fix concerns the addition of `abstractmethod` to methods declared inside abstract classes.
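
For context, a minimal sketch of the kind of change B024 asks for (names are illustrative); the rule flags abstract base classes that define no abstract methods:

```python
from abc import ABC, abstractmethod

class Loader(ABC):       # without any abstract method, flake8-bugbear emits B024
    @abstractmethod      # the fix: mark the method subclasses must override
    def load(self) -> bytes:
        ...
```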

Should I also include PEP8-compliant reformatting of the files I had to modify?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107265
Approved by: https://github.com/kit1980
2023-10-04 23:52:52 +00:00
Brian
e20c35a53b Allow public access for imports (#108914)
Fixes #108776

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108914
Approved by: https://github.com/wanchaol
2023-09-28 06:05:59 +00:00
Brian
806c52b4c9 Update chunk_sharding_spec.py (#108915)
Fixes #108869

Implements the first solution proposed in the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108915
Approved by: https://github.com/wanchaol, https://github.com/wz337
2023-09-15 21:43:15 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unused loop values in Python dictionary iteration. Automated fix from Ruff master.
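
A sketch of the pattern this fixes:

```python
d = {"a": 1, "b": 2}

# Before: `v` is bound on every iteration but never used
for k, v in d.items():
    print(k)

# After: iterate over the keys directly
for k in d:
    print(k)
```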

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Justin Chu
232b96b6e2 [BE] Enable ruff's UP rules and autoformat distributed/ (#105433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105433
Approved by: https://github.com/albanD
2023-07-19 14:27:11 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

Those were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP-484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional); a sketch of the pattern follows the lists below.
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
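
A sketch of the PEP-484 pattern being fixed (the function is illustrative):

```python
from typing import Optional

# Rejected by mypy 1.4.1 without implicit Optional: the default is None,
# but the annotation does not allow it.
# def lookup(key: str, table: dict = None) -> int: ...

# Fixed form: the None default is reflected in the annotations.
def lookup(key: str, table: Optional[dict] = None) -> Optional[int]:
    if table is None:
        return None
    return table.get(key)
```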

Unrelated: to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to `.ci/docker/install_conda.sh` that squashes the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

Those were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP-484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Chao Yang
367b0ad062 enforce dtype (reland) (#102996)
Summary: The original diff didn't break the test.

Test Plan: N/A

Differential Revision: D46448488

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102996
Approved by: https://github.com/malfet, https://github.com/wanchaol
2023-06-06 00:35:04 +00:00
PyTorch MergeBot
ecb191683e Revert "enforce dtype (#102802)"
This reverts commit 8e2a86c2a5.

Reverted https://github.com/pytorch/pytorch/pull/102802 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/102802#issuecomment-1577099676))
2023-06-05 16:21:28 +00:00
Chao Yang
8e2a86c2a5 enforce dtype (#102802)
Summary: Add a flag to enforce the gather data dtype. For backward compatibility, the default is False.

Test Plan: local and MAST jobs

Reviewed By: zyan0, strisunshinewentingwang

Differential Revision: D46295190

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102802
Approved by: https://github.com/mrshenli
2023-06-05 02:04:09 +00:00
Thibaut Durand
01da732691 Fix type annotation of torch.split (#100655)
The type annotation indicates `list`, but the returned type is `tuple`:
```python
>>> import torch
>>> type(torch.arange(10).split(4))
<class 'tuple'>
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100655
Approved by: https://github.com/kit1980
2023-05-16 21:35:41 +00:00
Ivan Kobzarev
4582ceb2c4 [distributed][sharded_tensor] Move local_shards check from ShardedTensorBase to ShardedTensor (#100197)
Differential Revision: [D45369211](https://our.internmc.facebook.com/intern/diff/D45369211)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100197
Approved by: https://github.com/fduwjj
2023-05-02 12:42:24 +00:00
Aaron Gokaslan
e2a3817dfd [BE] Enable C419 rule for any all shortcircuiting (#99890)
Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.jit accept simple generator expressions, which allows us to enable rules that replace unnecessary list comprehensions with generators in any/all calls. This was originally part of #99280, but I split it off into this PR so that it can be easily reverted should anything break.
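
A sketch of what C419 rewrites:

```python
xs = range(10**6)

# Before: the list comprehension is fully built before any() sees it
print(any([x > 10 for x in xs]))

# After: the generator lets any() short-circuit at the first match
print(any(x > 10 for x in xs))
```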

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
2023-04-25 15:02:13 +00:00
Iris
ca8625f456 [BE][1/N]Add sharding spec logger for ShardedTensor (#99748)
Set up a `logging.NullHandler()` on the OSS side.
Next step is to set up the counterpart internally.

This is part of the effort toward ShardedTensor deprecation. We want to log internal use cases of the different sharding specs.
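
A minimal sketch of the OSS-side setup (the logger name is illustrative): attaching a `NullHandler` keeps the library silent unless the application configures logging.

```python
import logging

# Library-side pattern: no output (and no "no handler" warning) by default;
# applications opt in by configuring logging themselves.
logger = logging.getLogger("torch.distributed._shard.sharding_spec")  # illustrative name
logger.addHandler(logging.NullHandler())
```
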
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99748
Approved by: https://github.com/H-Huang, https://github.com/fegin
2023-04-22 04:05:21 +00:00
Kazuaki Ishizaki
6514d71add Fix typos under torch/distributed directory (#98225)
This PR fixes typos in comments and messages of `.py` files under the `torch/distributed` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98225
Approved by: https://github.com/soulitzer, https://github.com/kit1980
2023-04-05 00:21:33 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under the `torch/distributed` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
fduwjj
b11ce4bbca Bring back tensor_has_compatible_shallow_copy_type (#97455)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97455
Approved by: https://github.com/clee2000
2023-03-24 06:43:20 +00:00
fduwjj
5cc2e4d7c9 [10/N] Remove ST init ops (#96985)
Differential Revision: [D44158326](https://our.internmc.facebook.com/intern/diff/D44158326)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96985
Approved by: https://github.com/wz337, https://github.com/wanchaol
2023-03-22 20:26:18 +00:00
fduwjj
546835c45a [9/N] Remove ST multiple ops (#96989)
Differential Revision: [D44158327](https://our.internmc.facebook.com/intern/diff/D44158327)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96989
Approved by: https://github.com/wz337, https://github.com/wanchaol
2023-03-22 20:02:58 +00:00
Xuehai Pan
80e8e41ca7 Fix type hint for torch.Tensor.grad_fn (#96804)
Fix type hint for `torch.Tensor.grad_fn`, which can be a `torch.autograd.graph.Node` or `None`.

This is a regression in `torch` 2.0. It causes `mypy` failures in downstream projects.
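
A small illustration of why the hint must be Optional (assuming the post-fix annotation):

```python
import torch

x = torch.ones(3, requires_grad=True)  # leaf tensor: grad_fn is None
y = (x * 2).sum()                      # non-leaf: grad_fn is an autograd Node

assert x.grad_fn is None
assert isinstance(y.grad_fn, torch.autograd.graph.Node)
```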

Ref:

- https://github.com/pytorch/pytorch/issues/94937#issuecomment-1469344993
- metaopt/torchopt#149
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96804
Approved by: https://github.com/Skylion007
2023-03-15 17:14:05 +00:00
fduwjj
28aa2efd14 [7/N][BE] Remove Partial Tensor and its dependency (#95949)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95949
Approved by: https://github.com/wanchaol
2023-03-06 19:57:46 +00:00
fduwjj
6dddc0d689 [6/N][BE] Remove Sharded Linear Op for ShardedTensor (#95948)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95948
Approved by: https://github.com/wanchaol
2023-03-06 19:57:19 +00:00
fduwjj
4e396a54e8 [5/N][BE] Remove Replicated Tensor class (#95947)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95947
Approved by: https://github.com/wanchaol
2023-03-06 19:50:17 +00:00
Jane Xu
e5b9d98752 Rephrase zero_grad docs (#95643)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95643
Approved by: https://github.com/albanD
2023-02-28 22:04:23 +00:00
fduwjj
38fdd28db4 [4/N][Deprecate ST][BE] Move warnings of Partial Tensor to functions (#95631)
To solve https://github.com/pytorch/pytorch/issues/95623
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95631
Approved by: https://github.com/wanchaol
2023-02-27 22:28:04 +00:00
fduwjj
fa7f17799a [3/N][BE][ST Deprecate] Remove Replicated Tensor (#95453)
Please use distributed tensor instead. We are deprecating ShardedTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95453
Approved by: https://github.com/wanchaol
2023-02-26 06:18:31 +00:00
fduwjj
b4c8186774 [BE][1/N] Add deprecate msg to Sharded Partial and Replicate Tensor (#94928)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94928
Approved by: https://github.com/wanchaol
2023-02-16 03:23:53 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the set call.
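
A sketch of the useless-generator case:

```python
b = [1, 2, 2, 3]

s1 = set(a for a in b)  # useless generator: an identity pass for no benefit
s2 = set(b)             # resolved into just the set call
assert s1 == s2
```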

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite calls to the Python built-in `super()`. Only non-semantic changes are applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```
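
The core rewrite is the zero-argument form of `super()`; a minimal sketch:

```python
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        # Before: super(MyModule, self).__init__()  -- Python 2 style
        super().__init__()  # after: same behavior, no repeated class name
```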

Cases where the rewrite would change semantics are kept unchanged, e.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Aaron Gokaslan
1e2d82b8e4 [BE] Merge isinstance calls together (#94419)
Simplifies and speeds up `isinstance` calls by checking for multiple types at the same time.
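
A sketch of the merge:

```python
x = 3.0

# Before: two separate isinstance calls
if isinstance(x, int) or isinstance(x, float):
    print("number")

# After: one call with a tuple of types
if isinstance(x, (int, float)):
    print("number")
```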

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94419
Approved by: https://github.com/ezyang
2023-02-09 00:47:26 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes explicit inheritance from `object` and removes unused `__future__` imports.
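
A sketch of both changes:

```python
# Before (Python 2 compatibility leftovers):
# from __future__ import absolute_import
# class Model(object):
#     pass

# After pyupgrade:
class Model:
    pass
```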

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Jane Xu
b90496eef5 [nn] zero_grad() set_to_none default True (#92731)
Attempts to fix #92656

BC-breaking! This changes the default of `zero_grad()` in optim and in nn to set grads to None instead of zero tensors. We are changing the default because there are proven perf wins and existing code has typically not regressed due to this change. (Will probably have to flesh out this note more.)
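
A small illustration of the new default and how to opt back out (standard `torch.optim` API):

```python
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(1, 4)).sum().backward()
opt.zero_grad()                   # new default (set_to_none=True): grads become None
assert model.weight.grad is None

model(torch.randn(1, 4)).sum().backward()
opt.zero_grad(set_to_none=False)  # old behavior: grads become zero tensors
assert model.weight.grad.eq(0).all()
```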

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92731
Approved by: https://github.com/ngimel
2023-01-26 01:04:28 +00:00
Connor Baker
2066523508 Fix ShardedTensorMetadata.tensor_properties for Python 3.11 (#91795)
The `tensor_properties` field of the `ShardedTensorMetadata` dataclass is a reference to a `TensorProperties` object. However, the field is set to `field(default=TensorProperties())` instead of `field(default_factory=TensorProperties)`. This causes an error when using Python 3.11 or later:

```python
ValueError: mutable default <class 'torch.distributed._shard.sharded_tensor.metadata.TensorProperties'> for field tensor_properties is not allowed: use default_factory
```

This change in dataclass behavior was introduced in [bpo-44674: Use unhashability as a proxy for mutability for default dataclass __init__ arguments](https://github.com/python/cpython/pull/29867).

The current use of `default` instead of `default_factory` also means that all `ShardedTensorMetadata` objects created without specifying `tensor_properties` will share the same `TensorProperties` object.
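
A minimal sketch of the fix (stand-in classes; the real ones live in `torch.distributed._shard.sharded_tensor.metadata`):

```python
from dataclasses import dataclass, field

@dataclass
class TensorProperties:  # stand-in for the real mutable dataclass
    requires_grad: bool = False

@dataclass
class ShardedTensorMetadata:
    # Before: field(default=TensorProperties()) -- one shared instance across
    # all ShardedTensorMetadata objects, and a ValueError on Python 3.11+.
    # After: default_factory builds a fresh TensorProperties per object.
    tensor_properties: TensorProperties = field(default_factory=TensorProperties)
```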

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91795
Approved by: https://github.com/fduwjj
2023-01-19 04:21:05 +00:00
Samantha Andow
a7749ae177 [reland] rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218) (#89221)
Summary: First half of #87990. This doesn't change any of the behavior and is just a rename.

#88218 got reverted for internal breakages. This is the reland of started from internal

Differential Revision: D41268423

LaMa Project: L1098534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89221
Approved by: https://github.com/meliy-meyada, https://github.com/zou3519
2023-01-04 18:32:49 +00:00
joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Sergii Dymchenko
365071c73c Fix non-existing parameters in docstrings in torch/distributed (#91116)
This is a continuation of https://github.com/pytorch/pytorch/pull/90505
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91116
Approved by: https://github.com/huydhn
2022-12-22 02:37:31 +00:00
Sergii Dymchenko
f51f6aa387 Fix non-existing parameters in docstrings (#90505)
Continuation after https://github.com/pytorch/pytorch/pull/90163.

Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in the presence of *args/**kwargs or decorators):

_Edit:_
I've realized that the indentation was wrong for the last `break` in the script as originally posted, so it only gave output for a function if the first docstring argument was wrong; the version below indents the `break` into the `if` so any offending function is reported. I'll create a separate PR if I find more issues with the corrected script.

``` python
import ast
import os
import docstring_parser

for root, dirs, files in os.walk('.'):
    for name in files:
        if root.startswith("./.git/") or root.startswith("./third_party/"):
            continue
        if name.endswith(".py"):
            full_name = os.path.join(root, name)
            with open(full_name, "r") as source:
                tree = ast.parse(source.read())
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        all_node_args = list(node.args.args)  # copy so the AST node's own args are not mutated below
                        if node.args.vararg is not None:
                            all_node_args.append(node.args.vararg)
                        if node.args.kwarg is not None:
                            all_node_args.append(node.args.kwarg)
                        if node.args.posonlyargs is not None:
                            all_node_args.extend(node.args.posonlyargs)
                        if node.args.kwonlyargs is not None:
                            all_node_args.extend(node.args.kwonlyargs)
                        args = [a.arg for a in all_node_args]
                        docstring = docstring_parser.parse(ast.get_docstring(node) or "")  # get_docstring() may return None
                        doc_args = [a.arg_name for a in docstring.params]
                        clean_doc_args = []
                        for a in doc_args:
                            clean_a = ""
                            for c in a.split()[0]:
                                if c.isalnum() or c == '_':
                                    clean_a += c
                            if clean_a:
                                clean_doc_args.append(clean_a)
                        doc_args = clean_doc_args
                        for a in doc_args:
                            if a not in args:
                                print(full_name, node.lineno, args, doc_args)
                                break

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
2022-12-09 21:43:09 +00:00
PyTorch MergeBot
cba96366a2 Revert "remove torch.equal usages (#89527)"
This reverts commit 4095ef8b80.

Reverted https://github.com/pytorch/pytorch/pull/89527 on behalf of https://github.com/clee2000 due to broke periodic multigpu tests 4095ef8b80 https://github.com/pytorch/pytorch/actions/runs/3592806602/jobs/6049368502
2022-12-02 21:36:13 +00:00
Philip Meier
4095ef8b80 remove torch.equal usages (#89527)
Preparation for the next PR in this stack: #89559.

I replaced

- `self.assertTrue(torch.equal(...))` with `self.assertEqual(..., rtol=0, atol=0, exact_device=True)`,
- the same for `self.assertFalse(...)` with `self.assertNotEqual(...)`, and
- `assert torch.equal(...)` with `torch.testing.assert_close(..., rtol=0, atol=0)` (note that we don't need to set `check_device=True` here since that is the default).

There were a few instances where the result of `torch.equal` is used directly. In those cases I've replaced it with `(... == ...).all().item()`, sometimes also dropping the `.item()` depending on the context.
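
For reference, the replacement forms side by side (standard `torch.testing` API):

```python
import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0])

torch.testing.assert_close(a, b, rtol=0, atol=0)  # exact-match assertion
assert (a == b).all().item()                      # direct boolean use
```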

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89527
Approved by: https://github.com/mruberry
2022-12-01 11:22:52 +00:00
Colin Taylor
24b9890f03 [torchrec] [composable] update ShardedEmbeddingBagCollection to be use registered EBCs with shardedTensors as registered modules (#758) (#88026)
Summary:
X-link: https://github.com/pytorch/torchrec/pull/758

This PR fixes a bug in FSDP/DDP where ShardedTensors are not supported even if passed in as params to ignore.
This is important for composability because TorchRec's named_parameters() will return FQNs of ShardedTensors (as defined in the goals).
It defines the device of a ShardedTensor to be None when local_tensor() does not exist on the rank.

Update ShardedEmbeddingBagCollection to be composable according to https://docs.google.com/document/d/1TBJSd5zgEg6cRcXv3Okuj7bBkqQwGS2IPh4TLWNNzFI/edit

Differential Revision: D40458625

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88026
Approved by: https://github.com/wanchaol, https://github.com/rohan-varma
2022-11-17 04:26:13 +00:00
Iris
aee96bbf5a [PT-D][Checkpointing] Move distributed checkpointing from torch.distributed._shard.checkpoint to torch.distributed.checkpoint (#88698)
Context in RFC: https://github.com/pytorch/pytorch/issues/86620

The .rst file will be finalized in subsequent PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88698
Approved by: https://github.com/wanchaol
2022-11-16 21:06:38 +00:00
PyTorch MergeBot
ba4d5aae06 Revert "rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)"
This reverts commit 7f28be10e5.

Reverted https://github.com/pytorch/pytorch/pull/88218 on behalf of https://github.com/izaitsevfb due to BC-breaking change, D41211901
2022-11-11 19:13:05 +00:00
samdow
7f28be10e5 rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)
First half of #87990. This doesn't change any of the behavior and is just a rename.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88218
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-11-10 14:51:13 +00:00
Kurt Mohler
ee28b865ee Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)
Part of #85302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303
Approved by: https://github.com/ezyang
2022-11-08 18:11:01 +00:00
Chien-Chin Huang
bd1e95ce30 Improve the performance of validate_non_overlapping_shards_metadata (#85639)
`validate_non_overlapping_shards_metadata()` uses a quadratic algorithm to check for overlaps. However, in some cases (when only one dimension is sharded), an O(n log n) algorithm can easily be implemented. This PR changes the implementation of `validate_non_overlapping_shards_metadata()` accordingly; a sketch of the idea follows.
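
A hedged sketch of the O(n log n) special case, assuming the shards differ only along one dimension (the helper name is illustrative):

```python
# Sort the shards by their offset along the sharded dimension, then a single
# pass over adjacent pairs detects any overlap.
def non_overlapping_1d(shards):  # shards: list of (offset, size) along that dim
    shards = sorted(shards)
    return all(off_a + size_a <= off_b
               for (off_a, size_a), (off_b, _) in zip(shards, shards[1:]))
```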

Differential Revision: [D39681725](https://our.internmc.facebook.com/intern/diff/D39681725/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85639
Approved by: https://github.com/wanchaol
2022-10-20 23:51:48 +00:00
Andrew Gu
56c0c0af5b [ShardedTensor] Add is_floating_point (#85483)
This adds `is_floating_point()` support to `ShardedTensor`. This is needed for `ShardedTensor` + FSDP.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85483
Approved by: https://github.com/wanchaol
2022-09-23 04:48:03 +00:00