[HigherOrderOp] change signature of map_impl (#117161)

Summary:
X-link: https://github.com/pytorch/executorch/pull/1580

This PR changes the schema of map_impl from map_impl(f, num_mapped, *operands) to map_impl(f, mapped_args: Tuple, operands: Tuple). This prepares for turning on dynamo for eager-mode map, where we want to get rid of the num_mapped scalar.
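For illustration only, here is a minimal sketch of what the argument packing change looks like from a caller's perspective. The functions map_impl_old and map_impl_new are hypothetical plain-Python stand-ins written for this example; the real operator is torch.ops.higher_order.map_impl, a HigherOrderOperator, and its dense implementation lives in map_dense below.

```python
import torch

# Hypothetical stand-ins that only illustrate the argument packing;
# they are NOT the real torch.ops.higher_order.map_impl operator.

def map_impl_old(f, num_mapped, *operands):
    # Old schema: a num_mapped scalar, then mapped args and extra
    # operands flattened together into *operands.
    xs, pos_args = operands[:num_mapped], operands[num_mapped:]
    return torch.stack(
        [f(*(x[i] for x in xs), *pos_args) for i in range(xs[0].shape[0])]
    )

def map_impl_new(f, mapped_args, operands):
    # New schema: mapped args and extra operands arrive as two sequences,
    # so no scalar has to be threaded through the graph.
    return torch.stack(
        [f(*(x[i] for x in mapped_args), *operands) for i in range(mapped_args[0].shape[0])]
    )

def body(x, y):
    return x + y

xs = torch.randn(3, 4)  # mapped over its leading dimension
y = torch.randn(4)      # passed unchanged to every iteration

assert torch.equal(map_impl_old(body, 1, xs, y), map_impl_new(body, [xs], [y]))
```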

Test Plan: Existing tests.

Differential Revision: D52495413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117161
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
Author: Yidi Wu, 2024-01-13 02:50:46 +00:00 (committed by PyTorch MergeBot)
Parent: f2f47c6848
Commit: 2bc7da1ab7
6 changed files with 66 additions and 59 deletions

View File

@@ -1121,7 +1121,7 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
l_xs_ = L_xs_
l_y_ = L_y_
map_body_1 = self.map_body_1
- map_impl = torch.ops.higher_order.map_impl(map_body_1, 1, l_xs_, l_y_); map_body_1 = l_xs_ = l_y_ = None
+ map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None
getitem_1 = map_impl[0]; map_impl = None
return (getitem_1,)""",
)
@@ -1131,7 +1131,7 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
def forward(self, getitem, l_y_):
getitem_1 = getitem[0]
map_body_0 = self.map_body_0
- map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, getitem, l_y_); map_body_0 = getitem = l_y_ = None
+ map_impl = torch.ops.higher_order.map_impl(map_body_0, [getitem], [l_y_]); map_body_0 = getitem = l_y_ = None
getitem_2 = map_impl[0]; map_impl = None
return (getitem_2,)""",
)
@@ -1152,7 +1152,7 @@ def forward(self, getitem, l_y_):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
map_body_0 = self.map_body_0
- map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_); map_body_0 = l_x_ = None
+ map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
getitem_1 = map_impl[0]
getitem_2 = map_impl[1]; map_impl = None
return (getitem_1, getitem_2)""",
@@ -1188,7 +1188,7 @@ def forward(self, getitem):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
map_body_0 = self.map_body_0
- map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_); map_body_0 = l_x_ = None
+ map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
getitem_1 = map_impl[0]
getitem_2 = map_impl[1]
getitem_3 = map_impl[2]
@@ -1237,7 +1237,7 @@ def forward(self, getitem):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
map_body_0 = self.map_body_0
- map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
+ map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
getitem_1 = map_impl[0]; map_impl = None
return (getitem_1,)""",
)
@@ -1271,7 +1271,7 @@ def forward(self, getitem, const):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
map_body_0 = self.map_body_0
- map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
+ map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
getitem_1 = map_impl[0]; map_impl = None
return (getitem_1,)""",
)

View File

@@ -1451,8 +1451,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
self.assertExpectedInline(gm.code.strip(), """\
def forward(self, pred_1, x_1):
body_graph_0 = self.body_graph_0
- map_impl = torch.ops.higher_order.map_impl(body_graph_0, 1, x_1, pred_1);\
- body_graph_0 = x_1 = pred_1 = None
+ map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
getitem = map_impl[0]; map_impl = None
return getitem""")
self.assertExpectedInline(gm.body_graph_0.code.strip(), """\

View File

@@ -792,8 +792,8 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
p_args = (
body_node,
- 1, # right now we only supports num_mapped = 1
- *([arg.as_proxy() for arg in args[1:]] + list(body_lifted_freevars.keys())),
+ [args[1].as_proxy()],
+ [arg.as_proxy() for arg in args[2:]] + list(body_lifted_freevars.keys()),
)
return _call_function_and_unflatten_output(
tx, torch.ops.higher_order.map_impl, p_args, {}, body_r, body_spec

View File

@@ -200,8 +200,8 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
pred, true_fn, false_fn, inputs = args
return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
elif target == torch.ops.higher_order.map_impl:
- f, num_args, *rest = args # type: ignore[assignment]
- return self.callback.call_map(f, num_args, list(rest), meta)
+ f, mapped_args, operands = args # type: ignore[assignment]
+ return self.callback.call_map(f, mapped_args, operands, meta)
# For other unregistered HigherOrderOps, just interpret them blindly
elif isinstance(target, torch._ops.HigherOrderOperator):
return self.callback._fx(
@@ -357,18 +357,17 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
def call_map(
self,
f: torch.fx.GraphModule,
- num_args: int,
- args: List[ProxyValue],
+ mapped_args: List[ProxyValue],
+ operands: List[ProxyValue],
meta: NodeMetadata,
) -> ProxyValue:
- xs = _unstack_pytree([arg.data for arg in args[:num_args]])[0]
- pos_args = args[num_args:]
- f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in pos_args]))
+ xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
+ f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
assert f_branch is not None
return self._fx(
"call_function",
torch.ops.higher_order.map_impl,
- (f_branch.graph_module, num_args, *args),
+ (f_branch.graph_module, mapped_args, operands),
{},
meta,
)

View File

@@ -194,7 +194,7 @@ def map_wrapper(f, xs, *args):
return flat_out
return pytree.tree_unflatten(
- map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec # type: ignore[arg-type]
+ map_impl(flat_fn, flat_xs, args), out_spec # type: ignore[arg-type]
)
@@ -205,7 +205,11 @@ class MapAutogradOp(torch.autograd.Function):
ctx._joint_graph = joint_graph
ctx._num_mapped_args = num_mapped_args
with torch._C._AutoDispatchBelowAutograd():
- return (*map_impl(fw_graph, num_mapped_args, *flat_args),)
+ return (
+ *map_impl(
+ fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
+ ),
+ )
@staticmethod
def backward(ctx, *flat_grads):
@@ -215,17 +219,13 @@ class MapAutogradOp(torch.autograd.Function):
grads = map_impl(
ctx._joint_graph,
- ctx._num_mapped_args + len(flat_grads),
- *fw_mapped_args,
- *flat_grads,
- *pos_args,
+ fw_mapped_args + flat_grads,
+ pos_args,
)
return None, None, None, *grads
- def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
- xs = list(args[:num_mapped])
- pos_args = list(args[num_mapped:])
+ def trace_map(proxy_mode, func_overload, f, xs, pos_args):
leading_dim_size = xs[0].shape[0]
example_input = _unstack_pytree(xs)[0]
@@ -254,7 +254,7 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
expanded_outs = pytree.tree_map(expand_tensor, example_outs)
- node_args = (body_graph, num_mapped, *args)
+ node_args = (body_graph, list(xs), list(pos_args))
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="map_impl"
@@ -311,9 +311,7 @@ def _stack_pytree(pytrees):
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
- def map_dense(f, num_mapped_args, *args):
- xs = args[:num_mapped_args]
- pos_args = args[num_mapped_args:]
+ def map_dense(f, xs, pos_args):
pytrees = []
for inp in _unstack_pytree(xs):
pytrees.append(f(*inp, *pos_args))
@@ -321,30 +319,29 @@ def map_dense(f, num_mapped_args, *args):
@map_impl.py_impl(DispatchKey.Autograd)
- def map_autograd(f, num_mapped_args, *args):
- fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
- flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
+ def map_autograd(f, xs, pos_args):
+ num_mapped_args = len(xs)
+ fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args)
+ flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args)
return flat_out
@map_impl.py_impl(ProxyTorchDispatchMode)
- def map_proxy_torch_dispatch_mode(mode, f, num_mapped, *args):
+ def map_proxy_torch_dispatch_mode(mode, f, xs, args):
if mode.enable_tracing:
- return trace_map(mode, map_impl, f, num_mapped, *args)
+ return trace_map(mode, map_impl, f, xs, args)
else:
- return map_impl(f, num_mapped, *args)
+ return map_impl(f, xs, args)
@map_impl.py_impl(FakeTensorMode)
- def map_fake_tensor_mode(mode, f, num_mapped, *args):
+ def map_fake_tensor_mode(mode, f, xs, args):
with mode:
- return map_dense(f, num_mapped, *args)
+ return map_dense(f, xs, args)
@map_impl.py_functionalize_impl
- def map_functionalize(ctx, f, num_mapped, *args):
- xs = args[:num_mapped]
- pos_args = args[num_mapped:]
+ def map_functionalize(ctx, f, xs, pos_args):
unwrapped_xs = ctx.unwrap_tensors(xs)
unwrapped_args = ctx.unwrap_tensors(pos_args)
wrapped_fn = ctx.functionalize(f)
@@ -358,5 +355,5 @@ def map_functionalize(ctx, f, num_mapped, *args):
if _has_potential_branch_input_alias(f, example_inputs):
raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
- map_return = map_impl(wrapped_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
+ map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
return ctx.wrap_tensors(map_return)

View File

@@ -137,29 +137,41 @@ def _unlift(
buffers_to_mutate,
)
if node.op == "call_function" and node.target.__name__ == "map_impl":
- body_graph, num_mapped, *operands = node.args
+ body_graph, mapped_args, operands = node.args
+ num_mapped = len(mapped_args)
body_gm = getattr(gm, body_graph.name)
inp_pos_to_buffer_name_for_submod = {}
- real_operands = []
# TODO Fix situation here to replace dot with underscore...
state_dict_for_lookup = {
key.replace(".", "_"): value for key, value in state_dict.items()
}
- for ix, operand in enumerate(operands):
- if operand.target in inp_pos_to_param_buffer_name.values():
- inp_pos_to_buffer_name_for_submod[ix] = operand.target
- if operand.target in state_dict_for_lookup:
- value = state_dict_for_lookup[operand.target]
- elif operand.target in tensor_constants:
- value = tensor_constants[operand.target]
- else:
- raise RuntimeError(f"Unable to find value for {operand.target}")
- body_gm.register_buffer(operand.target, value)
- else:
- real_operands.append(operand)
- node.args = (body_graph, num_mapped, *real_operands)
- _, in_spec = pytree.tree_flatten(real_operands)
+ def _find_real_operands(operands, start_ix):
+ real_operands = []
+ for ix, operand in enumerate(operands):
+ if operand.target in inp_pos_to_param_buffer_name.values():
+ inp_pos_to_buffer_name_for_submod[
+ ix + start_ix
+ ] = operand.target
+ if operand.target in state_dict_for_lookup:
+ value = state_dict_for_lookup[operand.target]
+ elif operand.target in tensor_constants:
+ value = tensor_constants[operand.target]
+ else:
+ raise RuntimeError(
+ f"Unable to find value for {operand.target}"
+ )
+ body_gm.register_buffer(operand.target, value)
+ else:
+ real_operands.append(operand)
+ return real_operands
+ real_mapped_args = _find_real_operands(mapped_args, 0)
+ real_mapped_operands = _find_real_operands(operands, num_mapped)
+ node.args = (body_graph, real_mapped_args, real_mapped_operands)
+ _, in_spec = pytree.tree_flatten(real_mapped_args + real_mapped_operands)
_unlift(
body_gm,