[DTensor] Fix local_map with multi-threading (#149070)

Using `nonlocal device_mesh` is not safe with multi-threading: every call to the wrapped function shares a single closure variable, so concurrent threads can overwrite the mesh another thread is about to use. Instead, bind the mesh to the wrapper with `functools.partial` so each call receives it as an ordinary argument.
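
The gist of the problem, as a minimal self-contained sketch (illustrative Python, not the DTensor code itself): a wrapper that caches state through `nonlocal` has exactly one closure cell, shared by every thread that calls it.

```python
import threading

def make_wrapper():
    device_mesh = None  # one closure cell, shared by every caller

    def wrapped(arg_mesh):
        nonlocal device_mesh
        device_mesh = arg_mesh   # thread A writes its mesh here ...
        # ... but thread B may overwrite it before A reads it back:
        return device_mesh

    return wrapped

wrapped = make_wrapper()
results = [None] * 8

def call(i):
    results[i] = wrapped(i)

threads = [threading.Thread(target=call, args=(i,)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# Under an unlucky interleaving, results[i] != i for some i.
print(results)
```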

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149070
Approved by: https://github.com/wanchaol
Andrew Gu 2025-03-13 10:58:59 +00:00 committed by PyTorch MergeBot
parent df60500ab8
commit a8b1767ae5

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
+import functools
 from collections.abc import Sequence
 from typing import Callable, Optional, Union
@@ -126,7 +127,7 @@ def local_map(
     .. note:: This API is currently experimental and subject to change
     """
-    def wrapped(*args, **kwargs):
+    def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs):
         # process input args
         flat_args, args_spec = pytree.tree_flatten(args)
         if in_placements is not None:
@@ -137,7 +138,6 @@ def local_map(
         # we assume every DTensor object is placed on the same device mesh
         flat_local_args = []
-        nonlocal device_mesh  # access var device_mesh from the outer scope
         seen_dtensor_arg = False
         for idx, arg in enumerate(flat_args):
             if isinstance(arg, DTensor):
@@ -232,4 +232,4 @@ def local_map(
         else:
             return out
-    return wrapped
+    return functools.partial(wrapped, device_mesh)
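
For reference, a simplified sketch of the pattern this change adopts (the decorator name and the mesh-discovery logic below are stand-ins, not the actual PyTorch API): the mesh is bound to the wrapper once via `functools.partial`, and any mesh discovered from the arguments rebinds only a per-call local, so concurrent calls cannot interfere.

```python
import functools

def local_map_sketch(func, device_mesh=None):
    """Simplified stand-in for the thread-safe local_map pattern."""

    def wrapped(device_mesh, *args, **kwargs):
        # `device_mesh` is a per-call local. If it was not supplied at
        # decoration time, infer it from the first argument that carries a
        # `.device_mesh` attribute (a stand-in for a DTensor argument).
        # Rebinding it here touches only this call's local variable,
        # never shared closure state.
        if device_mesh is None:
            device_mesh = next(
                (a.device_mesh for a in args if hasattr(a, "device_mesh")),
                None,
            )
        return func(*args, **kwargs)

    # Bind the (possibly None) mesh at decoration time, mirroring
    # `return functools.partial(wrapped, device_mesh)` in the diff above.
    return functools.partial(wrapped, device_mesh)
```

Because `functools.partial` fixes the mesh as the first positional argument, callers keep the original call signature while the mesh travels as ordinary per-call data instead of mutable closure state.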