[HigherOrderOp] change signature of map_impl (#117161)
Summary:
X-link: https://github.com/pytorch/executorch/pull/1580

This PR changes the schema of map_impl from map_impl(f, num_mapped, *operands) to map_impl(f, mapped_args: Tuple, operands: Tuple). This prepares for turning on dynamo for eager-mode map, where we want to get rid of the num_mapped scalar.

Test Plan: Existing tests.

Differential Revision: D52495413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117161
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
This commit is contained in:
parent f2f47c6848
commit 2bc7da1ab7
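For illustration only (not part of the diff): a minimal sketch of the schema change described in the summary, using the functorch control_flow.map entry point that the existing tests exercise. The body function, shapes, and variable names below are invented for the example.

import torch
from functorch.experimental import control_flow

def body(x, y):
    # x is one slice of xs along dim 0; y is a positional operand passed through unchanged
    return x + y

xs = torch.randn(3, 4)  # mapped argument, iterated over its leading dimension
y = torch.ones(4)       # positional operand

out = control_flow.map(body, xs, y)  # shape (3, 4)

# When traced, this call now lowers to a node of the form
#   torch.ops.higher_order.map_impl(map_body_0, [xs], [y])
# instead of the old
#   torch.ops.higher_order.map_impl(map_body_0, 1, xs, y)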
@@ -1121,7 +1121,7 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
     l_xs_ = L_xs_
     l_y_ = L_y_
     map_body_1 = self.map_body_1
-    map_impl = torch.ops.higher_order.map_impl(map_body_1, 1, l_xs_, l_y_); map_body_1 = l_xs_ = l_y_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None
     getitem_1 = map_impl[0]; map_impl = None
     return (getitem_1,)""",
         )
@@ -1131,7 +1131,7 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
 def forward(self, getitem, l_y_):
     getitem_1 = getitem[0]
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, getitem, l_y_); map_body_0 = getitem = l_y_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [getitem], [l_y_]); map_body_0 = getitem = l_y_ = None
     getitem_2 = map_impl[0]; map_impl = None
     return (getitem_2,)""",
         )
@@ -1152,7 +1152,7 @@ def forward(self, getitem, l_y_):
 def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_); map_body_0 = l_x_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
     getitem_1 = map_impl[0]
     getitem_2 = map_impl[1]; map_impl = None
     return (getitem_1, getitem_2)""",
@@ -1188,7 +1188,7 @@ def forward(self, getitem):
 def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_); map_body_0 = l_x_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
     getitem_1 = map_impl[0]
     getitem_2 = map_impl[1]
     getitem_3 = map_impl[2]
@@ -1237,7 +1237,7 @@ def forward(self, getitem):
 def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
     getitem_1 = map_impl[0]; map_impl = None
     return (getitem_1,)""",
         )
@@ -1271,7 +1271,7 @@ def forward(self, getitem, const):
 def forward(self, L_x_ : torch.Tensor):
     l_x_ = L_x_
     map_body_0 = self.map_body_0
-    map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
+    map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
     getitem_1 = map_impl[0]; map_impl = None
     return (getitem_1,)""",
         )
@@ -1451,8 +1451,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         self.assertExpectedInline(gm.code.strip(), """\
 def forward(self, pred_1, x_1):
     body_graph_0 = self.body_graph_0
-    map_impl = torch.ops.higher_order.map_impl(body_graph_0, 1, x_1, pred_1);\
-  body_graph_0 = x_1 = pred_1 = None
+    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
     getitem = map_impl[0]; map_impl = None
     return getitem""")
         self.assertExpectedInline(gm.body_graph_0.code.strip(), """\
@@ -792,8 +792,8 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
         p_args = (
             body_node,
-            1, # right now we only supports num_mapped = 1
-            *([arg.as_proxy() for arg in args[1:]] + list(body_lifted_freevars.keys())),
+            [args[1].as_proxy()],
+            [arg.as_proxy() for arg in args[2:]] + list(body_lifted_freevars.keys()),
         )
         return _call_function_and_unflatten_output(
             tx, torch.ops.higher_order.map_impl, p_args, {}, body_r, body_spec
@@ -200,8 +200,8 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
             pred, true_fn, false_fn, inputs = args
             return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
         elif target == torch.ops.higher_order.map_impl:
-            f, num_args, *rest = args  # type: ignore[assignment]
-            return self.callback.call_map(f, num_args, list(rest), meta)
+            f, mapped_args, operands = args  # type: ignore[assignment]
+            return self.callback.call_map(f, mapped_args, operands, meta)
         # For other unregistered HigherOrderOps, just interpret them blindly
         elif isinstance(target, torch._ops.HigherOrderOperator):
             return self.callback._fx(
@@ -357,18 +357,17 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
     def call_map(
         self,
         f: torch.fx.GraphModule,
-        num_args: int,
-        args: List[ProxyValue],
+        mapped_args: List[ProxyValue],
+        operands: List[ProxyValue],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        xs = _unstack_pytree([arg.data for arg in args[:num_args]])[0]
-        pos_args = args[num_args:]
-        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in pos_args]))
+        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
+        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
         assert f_branch is not None
         return self._fx(
             "call_function",
             torch.ops.higher_order.map_impl,
-            (f_branch.graph_module, num_args, *args),
+            (f_branch.graph_module, mapped_args, operands),
             {},
             meta,
         )
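A hedged sketch of what a downstream pass overriding call_map looks like after this change: mapped tensors and plain operands now arrive as two separate lists instead of a count plus a flat argument list. The import paths and the super().call_map delegation are assumptions based on the hunk above, not a verified API.

from typing import List

import torch
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, ProxyValue
from torch._export.pass_infra.node_metadata import NodeMetadata


class LoggingMapPass(_ExportPassBaseDeprecatedDoNotUse):
    """Records every mapped body while delegating to the default lowering."""

    def __init__(self):
        super().__init__()
        self.seen_map_bodies = []

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # New schema: no num_args scalar; the mapped/operand split is carried by the two lists.
        self.seen_map_bodies.append(f)
        return super().call_map(f, mapped_args, operands, meta)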
@@ -194,7 +194,7 @@ def map_wrapper(f, xs, *args):
         return flat_out

     return pytree.tree_unflatten(
-        map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec  # type: ignore[arg-type]
+        map_impl(flat_fn, flat_xs, args), out_spec  # type: ignore[arg-type]
     )

@@ -205,7 +205,11 @@ class MapAutogradOp(torch.autograd.Function):
         ctx._joint_graph = joint_graph
         ctx._num_mapped_args = num_mapped_args
         with torch._C._AutoDispatchBelowAutograd():
-            return (*map_impl(fw_graph, num_mapped_args, *flat_args),)
+            return (
+                *map_impl(
+                    fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
+                ),
+            )

     @staticmethod
     def backward(ctx, *flat_grads):
@@ -215,17 +219,13 @@ class MapAutogradOp(torch.autograd.Function):

         grads = map_impl(
             ctx._joint_graph,
-            ctx._num_mapped_args + len(flat_grads),
-            *fw_mapped_args,
-            *flat_grads,
-            *pos_args,
+            fw_mapped_args + flat_grads,
+            pos_args,
         )
         return None, None, None, *grads


-def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
-    xs = list(args[:num_mapped])
-    pos_args = list(args[num_mapped:])
+def trace_map(proxy_mode, func_overload, f, xs, pos_args):
     leading_dim_size = xs[0].shape[0]

     example_input = _unstack_pytree(xs)[0]
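A hedged usage sketch of autograd through map with the new two-list layout, relying on the Autograd path (MapAutogradOp / map_autograd) modified in this file; the body, values, and shapes are invented for the example.

import torch
from functorch.experimental import control_flow

xs = torch.randn(4, 3, requires_grad=True)  # mapped over dim 0
w = torch.randn(3, requires_grad=True)      # positional operand

# Each row of xs is multiplied by w; per-row outputs are re-stacked to shape (4, 3).
out = control_flow.map(lambda x, w: x * w, xs, w)
out.sum().backward()

assert xs.grad.shape == xs.shape
assert w.grad.shape == w.shape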
@@ -254,7 +254,7 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):

     expanded_outs = pytree.tree_map(expand_tensor, example_outs)

-    node_args = (body_graph, num_mapped, *args)
+    node_args = (body_graph, list(xs), list(pos_args))
     proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
     out_proxy = proxy_mode.tracer.create_proxy(
         "call_function", func_overload, proxy_args, {}, name="map_impl"
@@ -311,9 +311,7 @@ def _stack_pytree(pytrees):


 @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
-def map_dense(f, num_mapped_args, *args):
-    xs = args[:num_mapped_args]
-    pos_args = args[num_mapped_args:]
+def map_dense(f, xs, pos_args):
     pytrees = []
     for inp in _unstack_pytree(xs):
         pytrees.append(f(*inp, *pos_args))
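For reference (not part of the diff): a self-contained approximation of what map_dense computes under the new (f, xs, pos_args) signature, using plain torch calls in place of the internal _unstack_pytree/_stack_pytree helpers and assuming the body returns a flat tuple of tensors.

import torch

def map_dense_reference(f, xs, pos_args):
    # xs: tensors mapped over dim 0; pos_args: forwarded to every call unchanged.
    outs = [f(*slices, *pos_args) for slices in zip(*(x.unbind(0) for x in xs))]
    # Re-stack each output position along a new leading dim (the _stack_pytree step).
    return tuple(torch.stack(per_out) for per_out in zip(*outs))

xs = (torch.arange(6.0).reshape(3, 2),)
pos_args = (torch.tensor(10.0),)
(result,) = map_dense_reference(lambda x, c: (x * c,), xs, pos_args)
assert result.shape == (3, 2)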
@@ -321,30 +319,29 @@ def map_dense(f, num_mapped_args, *args):


 @map_impl.py_impl(DispatchKey.Autograd)
-def map_autograd(f, num_mapped_args, *args):
-    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
-    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
+def map_autograd(f, xs, pos_args):
+    num_mapped_args = len(xs)
+    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args)
+    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args)
     return flat_out


 @map_impl.py_impl(ProxyTorchDispatchMode)
-def map_proxy_torch_dispatch_mode(mode, f, num_mapped, *args):
+def map_proxy_torch_dispatch_mode(mode, f, xs, args):
     if mode.enable_tracing:
-        return trace_map(mode, map_impl, f, num_mapped, *args)
+        return trace_map(mode, map_impl, f, xs, args)
     else:
-        return map_impl(f, num_mapped, *args)
+        return map_impl(f, xs, args)


 @map_impl.py_impl(FakeTensorMode)
-def map_fake_tensor_mode(mode, f, num_mapped, *args):
+def map_fake_tensor_mode(mode, f, xs, args):
     with mode:
-        return map_dense(f, num_mapped, *args)
+        return map_dense(f, xs, args)


 @map_impl.py_functionalize_impl
-def map_functionalize(ctx, f, num_mapped, *args):
-    xs = args[:num_mapped]
-    pos_args = args[num_mapped:]
+def map_functionalize(ctx, f, xs, pos_args):
     unwrapped_xs = ctx.unwrap_tensors(xs)
     unwrapped_args = ctx.unwrap_tensors(pos_args)
     wrapped_fn = ctx.functionalize(f)
@@ -358,5 +355,5 @@ def map_functionalize(ctx, f, num_mapped, *args):
     if _has_potential_branch_input_alias(f, example_inputs):
         raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

-    map_return = map_impl(wrapped_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
+    map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
     return ctx.wrap_tensors(map_return)
@@ -137,29 +137,41 @@ def _unlift(
                 buffers_to_mutate,
             )
         if node.op == "call_function" and node.target.__name__ == "map_impl":
-            body_graph, num_mapped, *operands = node.args
+            body_graph, mapped_args, operands = node.args
+            num_mapped = len(mapped_args)
             body_gm = getattr(gm, body_graph.name)
             inp_pos_to_buffer_name_for_submod = {}
-            real_operands = []
             # TODO Fix situation here to replace dot with underscore...
             state_dict_for_lookup = {
                 key.replace(".", "_"): value for key, value in state_dict.items()
             }
-            for ix, operand in enumerate(operands):
-                if operand.target in inp_pos_to_param_buffer_name.values():
-                    inp_pos_to_buffer_name_for_submod[ix] = operand.target
-                    if operand.target in state_dict_for_lookup:
-                        value = state_dict_for_lookup[operand.target]
-                    elif operand.target in tensor_constants:
-                        value = tensor_constants[operand.target]
-                    else:
-                        raise RuntimeError(f"Unable to find value for {operand.target}")
-                    body_gm.register_buffer(operand.target, value)
-                else:
-                    real_operands.append(operand)
-            node.args = (body_graph, num_mapped, *real_operands)
-
-            _, in_spec = pytree.tree_flatten(real_operands)
+
+            def _find_real_operands(operands, start_ix):
+                real_operands = []
+                for ix, operand in enumerate(operands):
+                    if operand.target in inp_pos_to_param_buffer_name.values():
+                        inp_pos_to_buffer_name_for_submod[
+                            ix + start_ix
+                        ] = operand.target
+                        if operand.target in state_dict_for_lookup:
+                            value = state_dict_for_lookup[operand.target]
+                        elif operand.target in tensor_constants:
+                            value = tensor_constants[operand.target]
+                        else:
+                            raise RuntimeError(
+                                f"Unable to find value for {operand.target}"
+                            )
+
+                        body_gm.register_buffer(operand.target, value)
+                    else:
+                        real_operands.append(operand)
+                return real_operands
+
+            real_mapped_args = _find_real_operands(mapped_args, 0)
+            real_mapped_operands = _find_real_operands(operands, num_mapped)
+            node.args = (body_graph, real_mapped_args, real_mapped_operands)
+
+            _, in_spec = pytree.tree_flatten(real_mapped_args + real_mapped_operands)

             _unlift(
                 body_gm,
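A hypothetical migration helper (not part of this PR) showing the practical effect of the schema change on an FX graph: old-style map_impl nodes carrying a num_mapped scalar can be rewritten into the new two-list layout using only public torch.fx APIs.

import torch

def migrate_map_impl_nodes(gm: torch.fx.GraphModule) -> None:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.higher_order.map_impl:
            f, second, *rest = node.args
            if isinstance(second, int):  # old layout: (f, num_mapped, *operands)
                node.args = (f, list(rest[:second]), list(rest[second:]))
    gm.recompile()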