From a8b1767ae5b37f4473f09fe575f9216a025a3e60 Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Thu, 13 Mar 2025 10:58:59 +0000
Subject: [PATCH] [DTensor] Fix `local_map` with multi-threading (#149070)

Using `nonlocal device_mesh` is not safe with multi-threading

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149070
Approved by: https://github.com/wanchaol
---
 torch/distributed/tensor/experimental/_func_map.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torch/distributed/tensor/experimental/_func_map.py b/torch/distributed/tensor/experimental/_func_map.py
index 51861141af5..d0f94d36cae 100644
--- a/torch/distributed/tensor/experimental/_func_map.py
+++ b/torch/distributed/tensor/experimental/_func_map.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
+import functools
 from collections.abc import Sequence
 from typing import Callable, Optional, Union
 
@@ -126,7 +127,7 @@ def local_map(
     .. note:: This API is currently experimental and subject to change
     """
 
-    def wrapped(*args, **kwargs):
+    def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs):
         # process input args
         flat_args, args_spec = pytree.tree_flatten(args)
         if in_placements is not None:
@@ -137,7 +138,6 @@ def local_map(
 
         # we assume every DTensor object is placed on the same device mesh
         flat_local_args = []
-        nonlocal device_mesh  # access var device_mesh from the outer scope
         seen_dtensor_arg = False
         for idx, arg in enumerate(flat_args):
             if isinstance(arg, DTensor):
@@ -232,4 +232,4 @@ def local_map(
     else:
         return out
 
-    return wrapped
+    return functools.partial(wrapped, device_mesh)
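
Note (not part of the patch): the sketch below is a minimal standalone illustration of the hazard the commit message describes. A closure that assigns to a `nonlocal` variable shares one cell across every thread that calls it, whereas binding the value with `functools.partial` and receiving it as a parameter keeps it local to each call, which is the pattern the patch adopts. The helper names (`make_wrapped_nonlocal`, `make_wrapped_partial`, `race`, `mesh_for_thread`) are invented for this example and do not appear in the PyTorch source.

import functools
import threading


def make_wrapped_nonlocal(device_mesh=None):
    def wrapped(arg):
        nonlocal device_mesh  # one shared cell for every calling thread
        if device_mesh is None:
            device_mesh = arg  # the first thread's write is visible to the others
        return device_mesh

    return wrapped


def make_wrapped_partial(device_mesh=None):
    def wrapped(device_mesh, arg):
        if device_mesh is None:
            device_mesh = arg  # plain local variable, private to this call
        return device_mesh

    return functools.partial(wrapped, device_mesh)


def race(fn):
    results = {}

    def worker(name, mesh_for_thread):
        results[name] = fn(mesh_for_thread)

    threads = [
        threading.Thread(target=worker, args=(f"t{i}", f"mesh{i}")) for i in range(4)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


# With `nonlocal`, whichever thread runs first fixes the mesh for everyone,
# e.g. {'t0': 'mesh0', 't1': 'mesh0', 't2': 'mesh0', 't3': 'mesh0'}.
print(race(make_wrapped_nonlocal()))
# With `functools.partial`, each call resolves its own mesh:
# {'t0': 'mesh0', 't1': 'mesh1', 't2': 'mesh2', 't3': 'mesh3'}.
print(race(make_wrapped_partial()))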