[HigherOrderOp] change signature of map_impl (#117161)

Summary:
X-link: https://github.com/pytorch/executorch/pull/1580

This PR changes the schema of map_impl from map_impl(f, num_mapped, *operands) to map_impl(f, mapped_args: Tuple, operands: Tuple). This prepares for turning on dynamo for eager-mode map, where we want to get rid of the num_mapped scalar.
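
For illustration only (not part of the original summary), a minimal sketch of the two calling conventions, assuming the public functorch.experimental.control_flow.map wrapper as the entry point; the body function and tensor shapes below are made up:

    # Hypothetical example; illustrates the schema change, not code from this PR.
    import torch
    from functorch.experimental import control_flow

    def body(x, y):
        # x is one slice of xs along dim 0; y is a pass-through operand
        return x.sin() + y

    xs = torch.randn(3, 4)  # mapped over the leading dimension
    y = torch.randn(4)      # passed unchanged to every iteration

    out = control_flow.map(body, xs, y)  # result has shape (3, 4)

    # In the traced graph, the underlying higher-order op call changes from
    #   torch.ops.higher_order.map_impl(body_graph, 1, xs, y)      # old: leading num_mapped scalar
    # to
    #   torch.ops.higher_order.map_impl(body_graph, [xs], [y])     # new: grouped mapped args and operands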

Test Plan: Existing tests.

Differential Revision: D52495413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117161
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
Author: Yidi Wu, 2024-01-13 02:50:46 +00:00 (committed by PyTorch MergeBot)
Parent: f2f47c6848
Commit: 2bc7da1ab7
6 changed files with 66 additions and 59 deletions


@@ -1121,7 +1121,7 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
     l_xs_ = L_xs_
     l_y_ = L_y_
     map_body_1 = self.map_body_1
-    map_impl = torch.ops.higher_order.map_impl(map_body_1, 1, l_xs_, l_y_); map_body_1 = l_xs_ = l_y_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None
     getitem_1 = map_impl[0]; map_impl = None
     return (getitem_1,)""",
 )

@@ -1131,7 +1131,7 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
 def forward(self, getitem, l_y_):
     getitem_1 = getitem[0]
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, getitem, l_y_); map_body_0 = getitem = l_y_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [getitem], [l_y_]); map_body_0 = getitem = l_y_ = None
     getitem_2 = map_impl[0]; map_impl = None
     return (getitem_2,)""",
 )

@@ -1152,7 +1152,7 @@ def forward(self, getitem, l_y_):
 def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_); map_body_0 = l_x_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
     getitem_1 = map_impl[0]
     getitem_2 = map_impl[1]; map_impl = None
     return (getitem_1, getitem_2)""",

@@ -1188,7 +1188,7 @@ def forward(self, getitem):
 def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_); map_body_0 = l_x_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
     getitem_1 = map_impl[0]
     getitem_2 = map_impl[1]
     getitem_3 = map_impl[2]

@@ -1237,7 +1237,7 @@ def forward(self, getitem):
 def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
     getitem_1 = map_impl[0]; map_impl = None
     return (getitem_1,)""",
 )

@@ -1271,7 +1271,7 @@ def forward(self, getitem, const):
 def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
     getitem_1 = map_impl[0]; map_impl = None
     return (getitem_1,)""",
 )


@@ -1451,8 +1451,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         self.assertExpectedInline(gm.code.strip(), """\
 def forward(self, pred_1, x_1):
     body_graph_0 = self.body_graph_0
-    map_impl = torch.ops.higher_order.map_impl(body_graph_0, 1, x_1, pred_1);\
-  body_graph_0 = x_1 = pred_1 = None
+    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
     getitem = map_impl[0]; map_impl = None
     return getitem""")
         self.assertExpectedInline(gm.body_graph_0.code.strip(), """\


@@ -792,8 +792,8 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
         p_args = (
             body_node,
-            1,  # right now we only supports num_mapped = 1
-            *([arg.as_proxy() for arg in args[1:]] + list(body_lifted_freevars.keys())),
+            [args[1].as_proxy()],
+            [arg.as_proxy() for arg in args[2:]] + list(body_lifted_freevars.keys()),
         )
         return _call_function_and_unflatten_output(
             tx, torch.ops.higher_order.map_impl, p_args, {}, body_r, body_spec


@@ -200,8 +200,8 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
                 pred, true_fn, false_fn, inputs = args
                 return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
             elif target == torch.ops.higher_order.map_impl:
-                f, num_args, *rest = args  # type: ignore[assignment]
-                return self.callback.call_map(f, num_args, list(rest), meta)
+                f, mapped_args, operands = args  # type: ignore[assignment]
+                return self.callback.call_map(f, mapped_args, operands, meta)
             # For other unregistered HigherOrderOps, just interpret them blindly
             elif isinstance(target, torch._ops.HigherOrderOperator):
                 return self.callback._fx(

@@ -357,18 +357,17 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
     def call_map(
         self,
         f: torch.fx.GraphModule,
-        num_args: int,
-        args: List[ProxyValue],
+        mapped_args: List[ProxyValue],
+        operands: List[ProxyValue],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        xs = _unstack_pytree([arg.data for arg in args[:num_args]])[0]
-        pos_args = args[num_args:]
-        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in pos_args]))
+        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
+        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
         assert f_branch is not None
         return self._fx(
             "call_function",
             torch.ops.higher_order.map_impl,
-            (f_branch.graph_module, num_args, *args),
+            (f_branch.graph_module, mapped_args, operands),
             {},
             meta,
         )


@@ -194,7 +194,7 @@ def map_wrapper(f, xs, *args):
         return flat_out

     return pytree.tree_unflatten(
-        map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec  # type: ignore[arg-type]
+        map_impl(flat_fn, flat_xs, args), out_spec  # type: ignore[arg-type]
     )

@@ -205,7 +205,11 @@ class MapAutogradOp(torch.autograd.Function):
         ctx._joint_graph = joint_graph
         ctx._num_mapped_args = num_mapped_args
         with torch._C._AutoDispatchBelowAutograd():
-            return (*map_impl(fw_graph, num_mapped_args, *flat_args),)
+            return (
+                *map_impl(
+                    fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
+                ),
+            )

     @staticmethod
     def backward(ctx, *flat_grads):

@@ -215,17 +219,13 @@ class MapAutogradOp(torch.autograd.Function):
         grads = map_impl(
             ctx._joint_graph,
-            ctx._num_mapped_args + len(flat_grads),
-            *fw_mapped_args,
-            *flat_grads,
-            *pos_args,
+            fw_mapped_args + flat_grads,
+            pos_args,
         )
         return None, None, None, *grads


-def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
-    xs = list(args[:num_mapped])
-    pos_args = list(args[num_mapped:])
+def trace_map(proxy_mode, func_overload, f, xs, pos_args):
     leading_dim_size = xs[0].shape[0]

     example_input = _unstack_pytree(xs)[0]

@@ -254,7 +254,7 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
     expanded_outs = pytree.tree_map(expand_tensor, example_outs)

-    node_args = (body_graph, num_mapped, *args)
+    node_args = (body_graph, list(xs), list(pos_args))
     proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
     out_proxy = proxy_mode.tracer.create_proxy(
         "call_function", func_overload, proxy_args, {}, name="map_impl"

@@ -311,9 +311,7 @@ def _stack_pytree(pytrees):
 @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
-def map_dense(f, num_mapped_args, *args):
-    xs = args[:num_mapped_args]
-    pos_args = args[num_mapped_args:]
+def map_dense(f, xs, pos_args):
     pytrees = []
     for inp in _unstack_pytree(xs):
         pytrees.append(f(*inp, *pos_args))

@@ -321,30 +319,29 @@ def map_dense(f, num_mapped_args, *args):
 @map_impl.py_impl(DispatchKey.Autograd)
-def map_autograd(f, num_mapped_args, *args):
-    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
-    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
+def map_autograd(f, xs, pos_args):
+    num_mapped_args = len(xs)
+    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args)
+    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args)
     return flat_out


 @map_impl.py_impl(ProxyTorchDispatchMode)
-def map_proxy_torch_dispatch_mode(mode, f, num_mapped, *args):
+def map_proxy_torch_dispatch_mode(mode, f, xs, args):
     if mode.enable_tracing:
-        return trace_map(mode, map_impl, f, num_mapped, *args)
+        return trace_map(mode, map_impl, f, xs, args)
     else:
-        return map_impl(f, num_mapped, *args)
+        return map_impl(f, xs, args)


 @map_impl.py_impl(FakeTensorMode)
-def map_fake_tensor_mode(mode, f, num_mapped, *args):
+def map_fake_tensor_mode(mode, f, xs, args):
     with mode:
-        return map_dense(f, num_mapped, *args)
+        return map_dense(f, xs, args)


 @map_impl.py_functionalize_impl
-def map_functionalize(ctx, f, num_mapped, *args):
-    xs = args[:num_mapped]
-    pos_args = args[num_mapped:]
+def map_functionalize(ctx, f, xs, pos_args):
     unwrapped_xs = ctx.unwrap_tensors(xs)
     unwrapped_args = ctx.unwrap_tensors(pos_args)
     wrapped_fn = ctx.functionalize(f)

@@ -358,5 +355,5 @@ def map_functionalize(ctx, f, num_mapped, *args):
     if _has_potential_branch_input_alias(f, example_inputs):
         raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

-    map_return = map_impl(wrapped_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
+    map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
     return ctx.wrap_tensors(map_return)


@@ -137,29 +137,41 @@ def _unlift(
                 buffers_to_mutate,
             )
         if node.op == "call_function" and node.target.__name__ == "map_impl":
-            body_graph, num_mapped, *operands = node.args
+            body_graph, mapped_args, operands = node.args
+            num_mapped = len(mapped_args)
             body_gm = getattr(gm, body_graph.name)
             inp_pos_to_buffer_name_for_submod = {}
-            real_operands = []
             # TODO Fix situation here to replace dot with underscore...
             state_dict_for_lookup = {
                 key.replace(".", "_"): value for key, value in state_dict.items()
             }
-            for ix, operand in enumerate(operands):
-                if operand.target in inp_pos_to_param_buffer_name.values():
-                    inp_pos_to_buffer_name_for_submod[ix] = operand.target
-                    if operand.target in state_dict_for_lookup:
-                        value = state_dict_for_lookup[operand.target]
-                    elif operand.target in tensor_constants:
-                        value = tensor_constants[operand.target]
-                    else:
-                        raise RuntimeError(f"Unable to find value for {operand.target}")
-                    body_gm.register_buffer(operand.target, value)
-                else:
-                    real_operands.append(operand)
-            node.args = (body_graph, num_mapped, *real_operands)
-            _, in_spec = pytree.tree_flatten(real_operands)
+
+            def _find_real_operands(operands, start_ix):
+                real_operands = []
+                for ix, operand in enumerate(operands):
+                    if operand.target in inp_pos_to_param_buffer_name.values():
+                        inp_pos_to_buffer_name_for_submod[
+                            ix + start_ix
+                        ] = operand.target
+                        if operand.target in state_dict_for_lookup:
+                            value = state_dict_for_lookup[operand.target]
+                        elif operand.target in tensor_constants:
+                            value = tensor_constants[operand.target]
+                        else:
+                            raise RuntimeError(
+                                f"Unable to find value for {operand.target}"
+                            )
+                        body_gm.register_buffer(operand.target, value)
+                    else:
+                        real_operands.append(operand)
+                return real_operands
+
+            real_mapped_args = _find_real_operands(mapped_args, 0)
+            real_mapped_operands = _find_real_operands(operands, num_mapped)
+            node.args = (body_graph, real_mapped_args, real_mapped_operands)
+            _, in_spec = pytree.tree_flatten(real_mapped_args + real_mapped_operands)
             _unlift(
                 body_gm,