Commit Graph

74 Commits

Author SHA1 Message Date
Oguz Ulgen
1df14f1bf8 Move has_triton to top level triton utils so that dynamo can also access it without creating cyclic dependencies (#109832)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109832
Approved by: https://github.com/zou3519
2023-09-22 19:33:41 +00:00
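A minimal usage sketch for the relocated helper described in the commit above, assuming it lands under `torch.utils._triton` (an assumption based on the PR title); the fallback covers builds that predate the move.

```python
# Hedged sketch: guard Triton-dependent code behind the shared helper instead of
# importing from inductor-internal modules. The module path is an assumption;
# older builds fall back to reporting no Triton.
try:
    from torch.utils._triton import has_triton  # assumed post-move location
except ImportError:
    def has_triton() -> bool:
        return False

if has_triton():
    print("Triton is available; Triton-backed kernels can be enabled.")
else:
    print("Triton not available; using eager/ATen fallbacks.")
```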
Wanchao Liang
2fa063e1e0 [device_mesh][BE] remove allgather from DM (#105614)
For reasons similar to https://github.com/pytorch/pytorch/pull/105605
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105614
Approved by: https://github.com/rohan-varma, https://github.com/wz337, https://github.com/fduwjj
2023-07-27 01:33:05 +00:00
Wanchao Liang
4a49f1f46e [device mesh][BE] remove allreduce from DM (#105605)
This PR removes allreduce from DM and uses functional collectives instead. The rationale is that we don't want to maintain yet another set of collective APIs, and since DM's collectives are now thin wrappers around functional collectives, they don't really need to live in DM.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105605
Approved by: https://github.com/kumpera, https://github.com/wz337, https://github.com/fduwjj
2023-07-27 01:33:02 +00:00
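The direction described in the commit above, roughly: callers invoke functional collectives directly rather than a DeviceMesh wrapper method. A hedged sketch, assuming `torch.distributed._functional_collectives` is importable and a process group is already initialized; the exact argument types accepted for `group` may differ by release.

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def all_reduce_sum(t: torch.Tensor) -> torch.Tensor:
    # Functional collectives return a new tensor rather than mutating in place,
    # which is what makes them friendly to tracing/compilation.
    ranks = list(range(dist.get_world_size()))
    return funcol.all_reduce(t, reduceOp="sum", group=ranks)
```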
Justin Chu
232b96b6e2 [BE] Enable ruff's UP rules and autoformat distributed/ (#105433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105433
Approved by: https://github.com/albanD
2023-07-19 14:27:11 +00:00
Chien-Chin Huang
2f04aab140 [SPMD] Disable all SPMD tests (#104784)
SPMD is not actively developed and is out of sync with the PyTorch compiler code. Disable the tests for now.

Differential Revision: [D47296840](https://our.internmc.facebook.com/intern/diff/D47296840/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104784
Approved by: https://github.com/fduwjj
2023-07-07 23:31:54 +00:00
Yeonju Ro
06f656c5d1 [distributed] implemented find_all_descendants (#102138)
Fixes #100397

Implemented the find_all_descendants function, which identifies the list of nodes that need to be moved. Added a unit test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102138
Approved by: https://github.com/fegin
2023-05-24 21:47:59 +00:00
Wanchao Liang
d316a2dd5c [spmd] Enable data parallel to work with non 0 batch dim (#100073)
This PR enables data parallel to work with a non-0 batch dim. The only thing we need to do is expose input_batch_dim to DataParallelMode; the data parallel expansion then works automatically, since the batch dim analysis already handles things correctly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100073
Approved by: https://github.com/mrshenli
2023-05-24 17:55:10 +00:00
Wanchao Liang
dd1f295201 [spmd] Improve activation handling, factory ops and batch dim reduction (#100853)
This PR improves the activation handling logic of data parallel to support cases where tensor factory ops do not depend on any input node. Such ops still produce activations, either as a sharded act (i.e. if the output shape includes the batch size) or as a replicated act.

It also significantly simplifies the full reduction logic: we no longer need a separate full reduction detection pass; we only need to detect full reductions while computing the batch dim and mark them as sharded.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100853
Approved by: https://github.com/mrshenli
2023-05-24 17:55:09 +00:00
Wanchao Liang
4d55ea8548 [spmd] enhance batch dim analysis of data parallel (#100852)
This PR enhances the batch dim analysis of data parallel to better understand cases where the batch dim gets flattened or split. Using DTensor's view ops, we can track a batch dim even when it is transformed in non-trivial ways.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100852
Approved by: https://github.com/mrshenli
2023-05-24 17:55:07 +00:00
Wanchao Liang
b2eaba6b62 [spmd] by default average gradients for nccl backend (#99964)
This PR averages gradients by default for the NCCL backend, which allows SPMD's data parallel to match DDP/FSDP results.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99964
Approved by: https://github.com/mrshenli
2023-05-24 17:55:06 +00:00
Edward Z. Yang
3318a832b3 Tighten FakeTensor reentrancy asserts, add debugging (#102091)
When investigating failures in https://github.com/pytorch/pytorch/pull/100017 I realized that we were reentering FakeTensorMode even though there was already one on the stack. Although we have attempted assert for these cases in the past, e.g., as in https://github.com/pytorch/pytorch/pull/97186 it seems that the existing protections were insufficient.

In this particular case, the reapplication of FakeTensorMode was due to an interaction with NotImplemented multiple dispatch handling. If proxy tensor mode detects an unrecognized tensor type (this includes FakeTensor, if it is not tracked with a proxy), it will return NotImplemented to give this tensor a chance to unpack itself into a proxyable operation. However, this is never the right thing for FakeTensor, where no unpacking is possible. Yet today, FakeTensor attempts to reapply the FakeTensorMode, resulting in FakeTensorMode being on the stack twice.

This PR does a number of things:

* It adds an assert in `FakeTensorMode.__torch_dispatch__` that you must not already have this mode on the stack, this is ALWAYS an error
* It modifies `FakeTensor.__torch_dispatch__` to return `NotImplemented` if the mode is already active. This prevents us from re-adding the mode to the stack
* It adds a new logging artifact `not_implemented` which you can use to get debug logs about all of the times a `__torch_dispatch__` handler returned NotImplemented and why it did so. Your subclass has to manually opt into this logging, but I inserted the necessary logs for ProxyTensorMode and FakeTensor(Mode)
* `with fake_mode` now no-ops if the fake mode is already on the stack, which is what users want anyway
* I am BREAKING pre-autograd tracing, because it is currently doing something weird with the original C++ mode stack. Brian is going to follow up with a fix next week.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102091
Approved by: https://github.com/thiagocrepaldi, https://github.com/eellison, https://github.com/wanchaol, https://github.com/bdhirsh
2023-05-24 05:37:51 +00:00
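A stripped-down illustration of the first two bullets in the commit above (not the actual FakeTensor code): the mode asserts it is never re-entered, and a subclass's `__torch_dispatch__` returns `NotImplemented` when its paired mode is already active, so dispatch falls through to the mode instead of pushing it onto the stack again.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

_mode_active = False  # stand-in for "is the paired mode already on the stack?"


class ToyMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        global _mode_active
        assert not _mode_active, "mode must never be re-entered"  # first bullet
        _mode_active = True
        try:
            return func(*args, **(kwargs or {}))
        finally:
            _mode_active = False


class ToyTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if _mode_active:
            # Second bullet: defer to the already-active mode instead of
            # re-applying it, which would put it on the stack twice.
            return NotImplemented
        return func(*args, **(kwargs or {}))
```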
Jack Taylor
187eb7ca88 Enable default workflow PyT 2.0 UTs on ROCm stack (#100981)
PR to enable default workflow PyTorch 2.0 unit tests for the ROCm stack.

- Enables all the dynamo unit test suites
- Enables some of the inductor unit test suites
       - `test_config`
       - `test_cpp_wrapper` (cpu only)
       - `test_minifier`
       - `test_standalone_compile`
       - `test_torchinductor_dynamic_shapes`
       - `test_torchinductor_opinfo`
       - `test_torchinductor`
       - `test_triton_wrapper`
- Introduces TEST_WITH_ROCM conditions for unit test skip/fail dictionaries in test_torchinductor_dynamic_shapes.py and test_torchinductor_opinfo.py

Note this PR follows on from the discussions on the previous UT enablement PR https://github.com/pytorch/pytorch/pull/97988; we have opted to enable only a few inductor suites at the moment to ease the upstreaming effort, as these files are changing very quickly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100981
Approved by: https://github.com/jithunnair-amd, https://github.com/malfet
2023-05-15 23:45:04 +00:00
Shen Li
af841f38bd [SPMD] Allow Override.replacement to have a global view (#101427)
It's easier for users to implement one Override that takes care of all target submodules of different types, instead of specifying one mapping pair for each FQN/type. For example, when calculating sharding for sparse layers, the decision needs to be made globally. In this case, it's helpful to allow the user Override to access all submodules and make replacement decisions accordingly.

Differential Revision: [D45879732](https://our.internmc.facebook.com/intern/diff/D45879732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101427
Approved by: https://github.com/fegin
2023-05-15 21:27:41 +00:00
Chien-Chin Huang
49c8a0cad0 [SPMD][BE] Remove the legacy tracing code (#100858)
Remove the legacy tracing code as it causes several test and benchmark issues.

Differential Revision: [D45649123](https://our.internmc.facebook.com/intern/diff/D45649123/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100858
Approved by: https://github.com/wanchaol
2023-05-11 23:08:27 +00:00
Shen Li
2ebb48ff28 [SPMD] add FQN argument to Override.replacement (#100473)
Differential Revision: [D45486089](https://our.internmc.facebook.com/intern/diff/D45486089)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100473
Approved by: https://github.com/wanchaol
2023-05-03 14:20:01 +00:00
Shen Li
9439cb0e11 Avoid using einsum for torch.cat DTensor propagation (#100251)
DTensor was reusing `einop_rule` to propagate sharding for torch.cat.
However, einsum only supports up to 52 subscripts (i.e., input tensors).
We have encountered use cases where one cat operator has more than 60
input tensors. Therefore, this commit reimplements the sharding prop rule for cat without using einsum.

Differential Revision: [D45435232](https://our.internmc.facebook.com/intern/diff/D45435232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100251
Approved by: https://github.com/wanchaol
2023-05-03 01:56:18 +00:00
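A toy version of the idea in the commit above (placements and types are simplified; this is not the real DTensor rule): the output of `cat` keeps the inputs' sharding as long as that sharding is not on the concatenated dim, and since no einsum subscripts are involved, the number of inputs is unbounded.

```python
from typing import List, Tuple, Union

Placement = Union[Tuple[str], Tuple[str, int]]  # ("replicate",) or ("shard", dim)


def cat_sharding_rule(inputs: List[Placement], cat_dim: int) -> Placement:
    first = inputs[0]
    # All inputs must agree on placement for the cat to stay a purely local op.
    if any(p != first for p in inputs):
        return ("replicate",)
    # Sharding along the concatenated dim cannot be preserved locally.
    if first[0] == "shard" and first[1] == cat_dim:
        return ("replicate",)
    return first


# Works the same for 2 inputs or 200 -- no 52-subscript einsum limit.
print(cat_sharding_rule([("shard", 0)] * 200, cat_dim=1))  # ('shard', 0)
```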
Chien-Chin Huang
e0a2b49f0b [SPMD] Introduce prerequisites to graph_optimization_pass (#99970)
Some optimization passes require prerequisite passes. It is hard to debug why an optimization pass fails to apply when its prerequisite condition is not met. Adding this check makes it easier to discover the error.

Differential Revision: [D45255377](https://our.internmc.facebook.com/intern/diff/D45255377/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99970
Approved by: https://github.com/lessw2020
2023-04-28 18:38:01 +00:00
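A hedged sketch of the check described above (decorator and attribute names here are illustrative, not the exact `torch.distributed._spmd` API): each pass records itself on the graph module, and a pass with unmet prerequisites fails with an explicit error instead of silently not matching.

```python
import functools


def graph_optimization_pass(prerequisites=()):
    def decorator(pass_fn):
        @functools.wraps(pass_fn)
        def wrapper(gm, *args, **kwargs):
            applied = getattr(gm, "_applied_passes", set())
            missing = [p for p in prerequisites if p not in applied]
            if missing:
                raise RuntimeError(
                    f"{pass_fn.__name__} requires {missing} to be applied first"
                )
            out = pass_fn(gm, *args, **kwargs)
            applied.add(pass_fn.__name__)
            gm._applied_passes = applied
            return out
        return wrapper
    return decorator


@graph_optimization_pass(prerequisites=("comm_fusion_with_cat",))
def schedule_comm_wait(gm):
    return gm  # body elided; only the prerequisite check matters here
```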
Chien-Chin Huang
b94a0ba5bb [SPMD] Add embedding dense backward prop rule for postional embedding (#100038)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100038
Approved by: https://github.com/mrshenli
2023-04-27 16:31:51 +00:00
Wanchao Liang
fc6f2f6e4e [spmd] simplify data parallel tests (#99901)
As titled
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99901
Approved by: https://github.com/awgu, https://github.com/mrshenli
2023-04-25 19:31:00 +00:00
Wanchao Liang
c6949db481 [spmd] enable fully_shard fused_adam test (#99898)
This PR enables fully_shard fused adam tests, with some additional tweaks to how scalar tensors are handled. We now treat a scalar tensor as if it were just a scalar value and do not distribute it, as there is no need to shard a scalar tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99898
Approved by: https://github.com/mrshenli
2023-04-25 19:30:55 +00:00
Wanchao Liang
ad882c5210 [spmd] Use TupleStrategy and enable replicate fused_adam (#99374)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99374
Approved by: https://github.com/mrshenli
2023-04-25 19:30:53 +00:00
Wanchao Liang
9db6920635 [spmd] Add list handling to data parallel and add foreach tests (#99373)
This PR adds list handling logic to the new DataParallel expansion and adds foreach optimizer tests, currently testing SGD optimizers in foreach mode for both replicate and fully shard.

Next step:

Add fused optim tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99373
Approved by: https://github.com/mrshenli
2023-04-22 05:39:20 +00:00
Wanchao Liang
e9bf94149e [spmd] Introduce Compile Mode FSDP with DTensor (#99062)
This PR introduces compile mode Data Parallel (FSDP/DDP) using DTensor sharding.

Along with the algorithm, it also introduces a new DataParallelMode so that the `compile` API can take it and apply data parallel. This PR tries to preserve the DTensorExpand approach for now to avoid BC breakage; we shall discuss steps to remove DTensorExpand.

The data parallel mode uses heuristics to determine node types in the graphs and assign the corresponding sharding. The detailed algorithm is described in the design doc.

The benefits of this approach:
- Model parameters and optimizer states are all DTensors after `spmd.compile`, which is necessary for FSDP and also makes checkpointing much easier
- As model parameters/optim states are sharded in a per-parameter fashion, this composes more easily with sophisticated second-order optimizers (e.g. Shampoo).
- We leverage model parameter/grad information to derive the data parallel pattern. In this way we don't need to worry about DTensor op coverage anymore, as data parallel is just a special case of DTensor operation.
- Using dtensor_expand might work for DDP but isn't going to work for FSDP, as DTensor might choose to allgather activations, which would violate the native FSDP algorithm.
- The approach is general enough to support both DDP/FSDP and a mixed mode

Follow ups:
- Add the "default" data parallel mode which supports mixing of
replicate/fully shard
- Test more e2e models with different types of optimizers, etc.
- migrate the existing stack from the DTensorExpand mode
- build optimizations on top of this prototype

Differential Revision: [D45174400](https://our.internmc.facebook.com/intern/diff/D45174400)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99062
Approved by: https://github.com/mrshenli
2023-04-22 03:13:05 +00:00
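A hedged sketch of the per-parameter sharding that the first two bullets above describe, using DTensor primitives (module paths as of roughly that era; treat them as assumptions): FSDP-style fully-shard places each parameter as `Shard(0)`, while DDP-style replicate keeps a full copy per rank.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor


def distribute_parameter(param: torch.Tensor, mesh: DeviceMesh, fully_shard: bool):
    # Each parameter (and its optimizer state) becomes a DTensor, which is what
    # makes per-parameter checkpointing and second-order optimizers easier.
    placements = [Shard(0)] if fully_shard else [Replicate()]
    return distribute_tensor(param, mesh, placements)
```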
Edward Z. Yang
abdd1f4a38 Reuse tracing context and fake tensors from backwards in forwards (#99619)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99619
Approved by: https://github.com/wanchaol
2023-04-20 22:39:48 +00:00
Chien-Chin Huang
88c45a1954 [SPMD] Allow users to dynamically pass the last_iter to IterGraphModule (#99575)
The current design of IterGraphModule requires users to specify the concrete iteration count, which is not always possible and not very precise. This PR introduces `last_iter` to IterGraphModule.forward(), which allows users to dynamically specify the last iteration.

Differential Revision: [D45129585](https://our.internmc.facebook.com/intern/diff/D45129585/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99575
Approved by: https://github.com/lessw2020
2023-04-20 16:49:34 +00:00
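A simplified sketch of the interface change described above (not the real IterGraphModule): rather than fixing the total iteration count at construction time, the caller can flag the last iteration when invoking `forward`.

```python
import torch.nn as nn


class IterWrapper(nn.Module):
    """Toy stand-in for IterGraphModule's iteration bookkeeping."""

    def __init__(self, gm: nn.Module):
        super().__init__()
        self.gm = gm
        self._iter = 0

    def forward(self, *args, last_iter: bool = False, **kwargs):
        self._iter += 1
        out = self.gm(*args, **kwargs)
        if last_iter:
            # Flush any work (e.g. delayed grads/optimizer steps) that earlier
            # cross-iteration optimizations pushed to the final iteration.
            self._iter = 0
        return out
```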
Shen Li
e605b5df74 [SPMD] Add sym_stride to DSymInt (#99504)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99504
Approved by: https://github.com/fegin
2023-04-19 14:55:40 +00:00
Shen Li
2cb8a8d4cc [SPMD] Support DSymInt for slice_backward in SPMD expansion (#99501)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99501
Approved by: https://github.com/fegin
2023-04-19 14:55:40 +00:00
Shen Li
292296141a [SPMD] Support SymInt with non-op call_function nodes (#99420)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99420
Approved by: https://github.com/fegin
2023-04-19 14:55:37 +00:00
Shen Li
7c0c663a4c [SPMD] Add aten.stack and aten.select to DTensor prop (#99417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99417
Approved by: https://github.com/fegin
2023-04-19 14:55:34 +00:00
Chien-Chin Huang
41d7969590 [SPMD] Upstream iter_move_grads_and_optimizers (#98785)
This PR upstreams `iter_move_grads_and_optimizer`, which delays some of the gradients and the corresponding optimizer steps to the next iteration. D44512863 (credit to @lessw2020) is the internal implementation, which is only good for the old _SPMD expansion. This PR changes the implementation to use the new APIs.

Differential Revision: [D44836486](https://our.internmc.facebook.com/intern/diff/D44836486/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98785
Approved by: https://github.com/mrshenli
2023-04-19 06:40:33 +00:00
Rodrigo Kumpera
38e964056b Reland python ops (#99170)
Waiting for the revert to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99170
Approved by: https://github.com/albanD
2023-04-18 15:15:46 +00:00
PyTorch MergeBot
1c042a2137 Revert "Reland python ops (#99170)"
This reverts commit d4de64ae8d.

Reverted https://github.com/pytorch/pytorch/pull/99170 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-04-18 11:37:43 +00:00
Rodrigo Kumpera
d4de64ae8d Reland python ops (#99170)
Waiting for the revert to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99170
Approved by: https://github.com/albanD
2023-04-17 21:53:41 +00:00
Chien-Chin Huang
148d49260a [SPMD] Implement split_fused_optimizer to split one fused_optimizer node to two (#98784)
Several optimization passes require the ability to split the fused_optimizer. This PR adds the API to support those use cases.

Differential Revision: [D44806450](https://our.internmc.facebook.com/intern/diff/D44806450/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98784
Approved by: https://github.com/mrshenli
2023-04-17 10:02:07 +00:00
Shen Li
c69d54885a [SPMD][BE] Generalize factory ops support in SPMD expansion (#99233)
Differential Revision: [D45028740](https://our.internmc.facebook.com/intern/diff/D45028740)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99233
Approved by: https://github.com/yifuwang
2023-04-16 00:07:27 +00:00
Shen Li
6bb20822f5 [SPMD][BE] Remove deprecated aten.sym_numel branch (#99232)
Differential Revision: [D45028732](https://our.internmc.facebook.com/intern/diff/D45028732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99232
Approved by: https://github.com/yifuwang
2023-04-16 00:07:27 +00:00
Shen Li
544cd8e134 [SPMD] Refactor DSize to DSymInt to enable sym_numel (#99206)
This commit uses `aten.arange.default` and `aten.arange.start` to
test `aten.sym_numel`.

Differential Revision: [D45028715](https://our.internmc.facebook.com/intern/diff/D45028715)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99206
Approved by: https://github.com/yifuwang
2023-04-16 00:07:21 +00:00
Shen Li
bafb984022 [SPMD] Enable aten.full.default with SymInt on sharded dims (#99190)
Differential Revision: [D45028686](https://our.internmc.facebook.com/intern/diff/D45028686)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99190
Approved by: https://github.com/yifuwang
2023-04-16 00:07:18 +00:00
Rodrigo Kumpera
a910045add [PATCH] Back out "Move functional collectives implementation to python. (#98595) (#99168)
Summary:
Original commit changeset: ba36f8751adc

Original Phabricator Diff: D44788697

Test Plan: model loading is fine after reverting the diff

Reviewed By: zyan0, sayitmemory

Differential Revision: D44921259

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99168
Approved by: https://github.com/izaitsevfb
2023-04-14 23:48:19 +00:00
Shen Li
40aaacd4fa Respect sharded dimensions when aten expand/view consumes SymInt values (#99058)
Currently, aten.expand always expands to the global dimension. Then, it
introduces additional slice and clone ops before running compute on
the expanded tensor with a local tensor.

In this commit, if we detect that the op consumes a SymInt size, it respects both the local size and the dimension placements from which the SymInt was extracted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99058
Approved by: https://github.com/wanchaol
2023-04-14 13:54:05 +00:00
Chien-Chin Huang
bce7308881 [SPMD] Upstream partial_lower (#99069)
Several ops cannot be lowered to Inductor. This PR copies the internal implementation of partial_lower (credit to @yifuwang) to torch.distributed._spmd to unblock OSS usage. The internal version will be kept until it is mature, at which point it will replace this version.

Differential Revision: [D44970278](https://our.internmc.facebook.com/intern/diff/D44970278/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99069
Approved by: https://github.com/mrshenli, https://github.com/lessw2020
2023-04-14 08:32:05 +00:00
Shen Li
75f55ca63b Support FQN as SPMD module override key (#98966)
Differential Revision: [D44940232](https://our.internmc.facebook.com/intern/diff/D44940232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98966
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-04-13 00:45:48 +00:00
Chien-Chin Huang
07a1378f52 [SPMD] Introduce schedule_comm_wait (#98578)
`schedule_comm_wait` delays wait_tensor ops as late as possible. Note that this optimization currently does not reorder the computation ops. For `foreach`-based optimizers, we observe that reordering the computation ops is required to achieve good performance.

Differential Revision: [D44761487](https://our.internmc.facebook.com/intern/diff/D44761487/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98578
Approved by: https://github.com/mrshenli
2023-04-12 00:51:19 +00:00
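A hedged sketch of the scheduling idea above (not the actual `_spmd` pass): each wait-style node is moved to sit directly before its first user, so the collective it waits on can overlap with the unrelated compute in between. Computation order itself is left untouched, matching the note above.

```python
import torch.fx as fx


def delay_waits(gm: fx.GraphModule, is_wait=lambda n: "wait" in str(n.target)) -> fx.GraphModule:
    for node in list(gm.graph.nodes):
        if not is_wait(node) or not node.users:
            continue
        users = set(node.users)
        for candidate in gm.graph.nodes:
            if candidate in users:
                # Node.prepend moves `node` to sit immediately before `candidate`.
                candidate.prepend(node)
                break
    gm.graph.lint()
    gm.recompile()
    return gm
```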
Chien-Chin Huang
dd3e2ddc0a [SPMD] Introduce graph_optimization_pass and comm_fusion_with_cat (#98285)
This PR adds the `graph_optimization_pass` decorator, which should wrap all graph optimization passes. It also introduces the first graph optimization, `comm_fusion_with_cat`, as the first use case of `graph_optimization_pass`.

Differential Revision: [D44661608](https://our.internmc.facebook.com/intern/diff/D44661608/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98285
Approved by: https://github.com/yifuwang
2023-04-12 00:51:16 +00:00
Chien-Chin Huang
2de67eaaee [SPMD] Add a dump_graphs_to_files utils to facilitate graph transformation debug (#98284)
Throughout compilation, multiple graphs are generated. This PR adds a utility to dump the resulting graphs to a folder.

Differential Revision: [D44661599](https://our.internmc.facebook.com/intern/diff/D44661599/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98284
Approved by: https://github.com/mrshenli
2023-04-11 23:14:12 +00:00
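A minimal sketch of such a debugging helper (illustrative; the upstream utility's name and signature may differ): write a readable form of each named graph module into a folder.

```python
import os

import torch.fx as fx


def dump_graphs_to_files(graphs: dict, folder: str = "graph_dumps") -> None:
    os.makedirs(folder, exist_ok=True)
    for name, gm in graphs.items():
        # print_readable(print_output=False) returns the generated code as a string.
        text = gm.print_readable(print_output=False) if isinstance(gm, fx.GraphModule) else str(gm)
        with open(os.path.join(folder, f"{name}.txt"), "w") as f:
            f.write(text)
```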
Chien-Chin Huang
06c206cea3 [SPMD] Add the default graph module transformation that is applied after tracing and expansion (#98182)
This PR adds the GraphModuleTransformation class that can be used as the default transformation after `train_step()` is traced and expanded. The current implementation includes:
1. Wrapping the input graph module with IterGraphModule. This enables the further graph optimizations, which are all implemented on top of IterGraphModule.
2. The ability to lower the graph module to Inductor. To achieve this goal, `lower_to_inductor()` is implemented.

TODO:
1. `override` and `gm_transformation` have overlapping functions -- `override.transform` can be used to achieve the same thing as `gm_transformation`. However, the current semantics of `override` is to override and transform partial graphs, while `gm_transformation` transforms the entire expanded GM. The final UX of `compile()` needs some discussion.

2. The current `lower_to_inductor()` assumes that the entire graph can be lowered to Inductor. This assumption is okay for the integration of graph optimizations but is too restrictive for many models. We should upstream `partial_lowering()`.

Differential Revision: [D44616783](https://our.internmc.facebook.com/intern/diff/D44616783/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98182
Approved by: https://github.com/mrshenli
2023-04-11 21:12:49 +00:00
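A hedged sketch of the whole-graph lowering mentioned in item 2 above (assuming `torch._inductor.compile_fx.compile_fx` as the entry point; the real `lower_to_inductor()` wiring may differ): it only works if every node in the graph is something Inductor can handle, which is exactly the restriction the TODO calls out.

```python
import torch
from torch._inductor.compile_fx import compile_fx


def lower_whole_graph(gm: torch.fx.GraphModule, example_inputs):
    # Returns a compiled callable; fails if the graph contains ops Inductor
    # cannot lower, hence the need for a partial_lowering() upstream.
    return compile_fx(gm, example_inputs)
```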
Shen Li
3fcc5ff0d6 Avoid passing buffers to optimizers during spmd rematerialization (#98714)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98714
Approved by: https://github.com/fegin
2023-04-10 17:09:15 +00:00
Shen Li
54b168484d Support LayerNorm without weight or bias parameters (#98687)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98687
Approved by: https://github.com/yifuwang
2023-04-09 02:13:10 +00:00
Shen Li
1be3549a27 Enable replicated embedding in SPMD for NLP models (#98686)
For models like NanoGPT, embeddings are replicated and input ids
are sharded. In this case, output lookups should be sharded to
match ids.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98686
Approved by: https://github.com/yifuwang
2023-04-09 02:13:10 +00:00
Shen Li
d255c8e1ad Add NLLLoss to DTensor prop rule (#98512)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98512
Approved by: https://github.com/wanchaol
2023-04-08 01:22:36 +00:00