[DTensor] Fix local_map with multi-threading (#149070)
Using `nonlocal device_mesh` is not safe with multi-threading.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149070
Approved by: https://github.com/wanchaol
parent df60500ab8
commit a8b1767ae5
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
+import functools
 from collections.abc import Sequence
 from typing import Callable, Optional, Union
 
@@ -126,7 +127,7 @@ def local_map(
     .. note:: This API is currently experimental and subject to change
     """
 
-    def wrapped(*args, **kwargs):
+    def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs):
         # process input args
         flat_args, args_spec = pytree.tree_flatten(args)
         if in_placements is not None:
@@ -137,7 +138,6 @@ def local_map(
 
             # we assume every DTensor object is placed on the same device mesh
             flat_local_args = []
-            nonlocal device_mesh  # access var device_mesh from the outer scope
             seen_dtensor_arg = False
             for idx, arg in enumerate(flat_args):
                 if isinstance(arg, DTensor):
@@ -232,4 +232,4 @@ def local_map(
         else:
             return out
 
-    return wrapped
+    return functools.partial(wrapped, device_mesh)
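The change above removes the shared `nonlocal device_mesh` capture and instead passes the mesh as an explicit first parameter that is bound once with `functools.partial`. The code below is a minimal, hypothetical sketch (the helper names and the string "meshes" are made up for illustration and are not DTensor APIs) of why the old pattern races under multi-threading and how the partial-binding pattern avoids it:

import functools


def make_wrapper_nonlocal(device_mesh=None):
    # Old pattern: one closure slot is shared by every caller of the returned
    # wrapper, so the first thread to write it fixes the mesh for all threads.
    def wrapped(inferred_mesh, *args, **kwargs):
        nonlocal device_mesh
        if device_mesh is None:
            device_mesh = inferred_mesh
        return device_mesh

    return wrapped


def make_wrapper_partial(device_mesh=None):
    # Fixed pattern: the mesh is an explicit argument bound at creation time,
    # so each call resolves its own mesh and no outer state is mutated.
    def wrapped(device_mesh, inferred_mesh, *args, **kwargs):
        if device_mesh is None:
            device_mesh = inferred_mesh
        return device_mesh

    return functools.partial(wrapped, device_mesh)


w = make_wrapper_nonlocal()
assert w("mesh_a") == "mesh_a"
assert w("mesh_b") == "mesh_a"  # later callers silently inherit the first mesh

p = make_wrapper_partial()
assert p("mesh_a") == "mesh_a"
assert p("mesh_b") == "mesh_b"  # each call resolves its own mesh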