[HigherOrderOp] change signature of map_impl (#117161)
Summary:
X-link: https://github.com/pytorch/executorch/pull/1580

This PR changes the schema of map_impl from map_impl(f, num_mapped, *operands) to map_impl(f, mapped_args: Tuple, operands: Tuple). This prepares for turning on dynamo for eager-mode map, where we want to get rid of the num_mapped scalar.

Test Plan: Existing tests.

Differential Revision: D52495413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117161
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
This commit is contained in:
parent f2f47c6848
commit 2bc7da1ab7
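Before the individual hunks, a minimal sketch of the schema change described above, assuming a flat body function and a single mapped tensor (the variable names and the stacked-loop equivalence are illustrative, not part of the diff):

import torch

def body(x, y):
    # Hypothetical body: called once per slice of the mapped input.
    return x + y

xs = torch.randn(3, 4)  # mapped input, iterated over dim 0
y = torch.randn(4)      # positional operand, passed unchanged to every call

# Old schema: map_impl(f, num_mapped, *operands)
#   torch.ops.higher_order.map_impl(body, 1, xs, y)
# New schema: map_impl(f, mapped_args, operands), with explicit tuples/lists
#   torch.ops.higher_order.map_impl(body, [xs], [y])
# Either way, the dense semantics amount to:
out = torch.stack([body(x, y) for x in xs.unbind(0)])
print(out.shape)  # torch.Size([3, 4])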
@@ -1121,7 +1121,7 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
 l_xs_ = L_xs_
 l_y_ = L_y_
 map_body_1 = self.map_body_1
-map_impl = torch.ops.higher_order.map_impl(map_body_1, 1, l_xs_, l_y_); map_body_1 = l_xs_ = l_y_ = None
+map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None
 getitem_1 = map_impl[0]; map_impl = None
 return (getitem_1,)""",
 )
@@ -1131,7 +1131,7 @@ def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
 def forward(self, getitem, l_y_):
 getitem_1 = getitem[0]
 map_body_0 = self.map_body_0
-map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, getitem, l_y_); map_body_0 = getitem = l_y_ = None
+map_impl = torch.ops.higher_order.map_impl(map_body_0, [getitem], [l_y_]); map_body_0 = getitem = l_y_ = None
 getitem_2 = map_impl[0]; map_impl = None
 return (getitem_2,)""",
 )
@@ -1152,7 +1152,7 @@ def forward(self, getitem, l_y_):
 def forward(self, L_x_ : torch.Tensor):
 l_x_ = L_x_
 map_body_0 = self.map_body_0
-map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_); map_body_0 = l_x_ = None
+map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
 getitem_1 = map_impl[0]
 getitem_2 = map_impl[1]; map_impl = None
 return (getitem_1, getitem_2)""",
@@ -1188,7 +1188,7 @@ def forward(self, getitem):
 def forward(self, L_x_ : torch.Tensor):
 l_x_ = L_x_
 map_body_0 = self.map_body_0
-map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_); map_body_0 = l_x_ = None
+map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
 getitem_1 = map_impl[0]
 getitem_2 = map_impl[1]
 getitem_3 = map_impl[2]
@@ -1237,7 +1237,7 @@ def forward(self, getitem):
 def forward(self, L_x_ : torch.Tensor):
 l_x_ = L_x_
 map_body_0 = self.map_body_0
-map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
+map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
 getitem_1 = map_impl[0]; map_impl = None
 return (getitem_1,)""",
 )
@@ -1271,7 +1271,7 @@ def forward(self, getitem, const):
 def forward(self, L_x_ : torch.Tensor):
 l_x_ = L_x_
 map_body_0 = self.map_body_0
-map_impl = torch.ops.higher_order.map_impl(map_body_0, 1, l_x_, 3); map_body_0 = l_x_ = None
+map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
 getitem_1 = map_impl[0]; map_impl = None
 return (getitem_1,)""",
 )
@@ -1451,8 +1451,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
 self.assertExpectedInline(gm.code.strip(), """\
 def forward(self, pred_1, x_1):
 body_graph_0 = self.body_graph_0
-map_impl = torch.ops.higher_order.map_impl(body_graph_0, 1, x_1, pred_1);\
-body_graph_0 = x_1 = pred_1 = None
+map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
 getitem = map_impl[0]; map_impl = None
 return getitem""")
 self.assertExpectedInline(gm.body_graph_0.code.strip(), """\
@@ -792,8 +792,8 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
 
 p_args = (
 body_node,
-1,  # right now we only supports num_mapped = 1
-*([arg.as_proxy() for arg in args[1:]] + list(body_lifted_freevars.keys())),
+[args[1].as_proxy()],
+[arg.as_proxy() for arg in args[2:]] + list(body_lifted_freevars.keys()),
 )
 return _call_function_and_unflatten_output(
 tx, torch.ops.higher_order.map_impl, p_args, {}, body_r, body_spec
@@ -200,8 +200,8 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
 pred, true_fn, false_fn, inputs = args
 return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
 elif target == torch.ops.higher_order.map_impl:
-f, num_args, *rest = args  # type: ignore[assignment]
-return self.callback.call_map(f, num_args, list(rest), meta)
+f, mapped_args, operands = args  # type: ignore[assignment]
+return self.callback.call_map(f, mapped_args, operands, meta)
 # For other unregistered HigherOrderOps, just interpret them blindly
 elif isinstance(target, torch._ops.HigherOrderOperator):
 return self.callback._fx(
@@ -357,18 +357,17 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
 def call_map(
 self,
 f: torch.fx.GraphModule,
-num_args: int,
-args: List[ProxyValue],
+mapped_args: List[ProxyValue],
+operands: List[ProxyValue],
 meta: NodeMetadata,
 ) -> ProxyValue:
-xs = _unstack_pytree([arg.data for arg in args[:num_args]])[0]
-pos_args = args[num_args:]
-f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in pos_args]))
+xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
+f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
 assert f_branch is not None
 return self._fx(
 "call_function",
 torch.ops.higher_order.map_impl,
-(f_branch.graph_module, num_args, *args),
+(f_branch.graph_module, mapped_args, operands),
 {},
 meta,
 )
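Downstream export passes that override call_map (e.g. in ExecuTorch, per the X-link above) must adopt the same parameter split. A minimal sketch of such an override; the subclass name and print statement are illustrative, and the import paths for NodeMetadata and ProxyValue are assumptions:

from typing import List

import torch
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.pass_infra.node_metadata import NodeMetadata  # assumed path
from torch._export.pass_infra.proxy_value import ProxyValue      # assumed path


class LogMapPass(_ExportPassBaseDeprecatedDoNotUse):  # hypothetical pass
    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # The num_mapped scalar is gone; the mapped/positional split is explicit.
        print(f"map_impl: {len(mapped_args)} mapped arg(s), {len(operands)} operand(s)")
        return super().call_map(f, mapped_args, operands, meta)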
@@ -194,7 +194,7 @@ def map_wrapper(f, xs, *args):
 return flat_out
 
 return pytree.tree_unflatten(
-map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec  # type: ignore[arg-type]
+map_impl(flat_fn, flat_xs, args), out_spec  # type: ignore[arg-type]
 )
 
 
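The user-facing wrapper keeps its map(f, xs, *args) shape; only the internal map_impl call switches to the tuple-based schema. A usage sketch, assuming the experimental control_flow entry point available around this change behaves as the wrapper above suggests:

import torch
from functorch.experimental.control_flow import map as hop_map  # experimental API; path may differ by release

def body(x, y):
    return x + y

xs = torch.randn(3, 4)
y = torch.randn(4)

# map_wrapper flattens body and xs, then calls
# map_impl(flat_fn, flat_xs, args) under the hood.
out = hop_map(body, xs, y)
print(out.shape)  # torch.Size([3, 4])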
@@ -205,7 +205,11 @@ class MapAutogradOp(torch.autograd.Function):
 ctx._joint_graph = joint_graph
 ctx._num_mapped_args = num_mapped_args
 with torch._C._AutoDispatchBelowAutograd():
-return (*map_impl(fw_graph, num_mapped_args, *flat_args),)
+return (
+*map_impl(
+fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
+),
+)
 
 @staticmethod
 def backward(ctx, *flat_grads):
@@ -215,17 +219,13 @@ class MapAutogradOp(torch.autograd.Function):
 
 grads = map_impl(
 ctx._joint_graph,
-ctx._num_mapped_args + len(flat_grads),
-*fw_mapped_args,
-*flat_grads,
-*pos_args,
+fw_mapped_args + flat_grads,
+pos_args,
 )
 return None, None, None, *grads
 
 
-def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
-xs = list(args[:num_mapped])
-pos_args = list(args[num_mapped:])
+def trace_map(proxy_mode, func_overload, f, xs, pos_args):
 leading_dim_size = xs[0].shape[0]
 
 example_input = _unstack_pytree(xs)[0]
@@ -254,7 +254,7 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
 
 expanded_outs = pytree.tree_map(expand_tensor, example_outs)
 
-node_args = (body_graph, num_mapped, *args)
+node_args = (body_graph, list(xs), list(pos_args))
 proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
 out_proxy = proxy_mode.tracer.create_proxy(
 "call_function", func_overload, proxy_args, {}, name="map_impl"
@@ -311,9 +311,7 @@ def _stack_pytree(pytrees):
 
 
 @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
-def map_dense(f, num_mapped_args, *args):
-xs = args[:num_mapped_args]
-pos_args = args[num_mapped_args:]
+def map_dense(f, xs, pos_args):
 pytrees = []
 for inp in _unstack_pytree(xs):
 pytrees.append(f(*inp, *pos_args))
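For reference, the dense rule these dispatch implementations share can be written as a standalone sketch; the helper name and the single-tensor-output assumption are illustrative only (the real map_dense stacks arbitrary output pytrees):

import torch

def map_dense_sketch(f, xs, pos_args):
    # Unbind each mapped tensor along dim 0, call f once per slice with the
    # loop-invariant positional args appended, then stack the results.
    slices = zip(*(x.unbind(0) for x in xs))
    return torch.stack([f(*s, *pos_args) for s in slices])

out = map_dense_sketch(lambda x, y: x * y, [torch.randn(3, 4)], [torch.randn(4)])
print(out.shape)  # torch.Size([3, 4])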
@@ -321,30 +319,29 @@ def map_dense(f, num_mapped_args, *args):
 
 
 @map_impl.py_impl(DispatchKey.Autograd)
-def map_autograd(f, num_mapped_args, *args):
-fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
-flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
+def map_autograd(f, xs, pos_args):
+num_mapped_args = len(xs)
+fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args)
+flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args)
 return flat_out
 
 
 @map_impl.py_impl(ProxyTorchDispatchMode)
-def map_proxy_torch_dispatch_mode(mode, f, num_mapped, *args):
+def map_proxy_torch_dispatch_mode(mode, f, xs, args):
 if mode.enable_tracing:
-return trace_map(mode, map_impl, f, num_mapped, *args)
+return trace_map(mode, map_impl, f, xs, args)
 else:
-return map_impl(f, num_mapped, *args)
+return map_impl(f, xs, args)
 
 
 @map_impl.py_impl(FakeTensorMode)
-def map_fake_tensor_mode(mode, f, num_mapped, *args):
+def map_fake_tensor_mode(mode, f, xs, args):
 with mode:
-return map_dense(f, num_mapped, *args)
+return map_dense(f, xs, args)
 
 
 @map_impl.py_functionalize_impl
-def map_functionalize(ctx, f, num_mapped, *args):
-xs = args[:num_mapped]
-pos_args = args[num_mapped:]
+def map_functionalize(ctx, f, xs, pos_args):
 unwrapped_xs = ctx.unwrap_tensors(xs)
 unwrapped_args = ctx.unwrap_tensors(pos_args)
 wrapped_fn = ctx.functionalize(f)
@@ -358,5 +355,5 @@ def map_functionalize(ctx, f, num_mapped, *args):
 if _has_potential_branch_input_alias(f, example_inputs):
 raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
 
-map_return = map_impl(wrapped_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
+map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
 return ctx.wrap_tensors(map_return)
@@ -137,29 +137,41 @@ def _unlift(
 buffers_to_mutate,
 )
 if node.op == "call_function" and node.target.__name__ == "map_impl":
-body_graph, num_mapped, *operands = node.args
+body_graph, mapped_args, operands = node.args
+num_mapped = len(mapped_args)
 body_gm = getattr(gm, body_graph.name)
 inp_pos_to_buffer_name_for_submod = {}
-real_operands = []
 # TODO Fix situation here to replace dot with underscore...
 state_dict_for_lookup = {
 key.replace(".", "_"): value for key, value in state_dict.items()
 }
-for ix, operand in enumerate(operands):
-if operand.target in inp_pos_to_param_buffer_name.values():
-inp_pos_to_buffer_name_for_submod[ix] = operand.target
-if operand.target in state_dict_for_lookup:
-value = state_dict_for_lookup[operand.target]
-elif operand.target in tensor_constants:
-value = tensor_constants[operand.target]
-else:
-raise RuntimeError(f"Unable to find value for {operand.target}")
-body_gm.register_buffer(operand.target, value)
-else:
-real_operands.append(operand)
-node.args = (body_graph, num_mapped, *real_operands)
-
-_, in_spec = pytree.tree_flatten(real_operands)
+def _find_real_operands(operands, start_ix):
+real_operands = []
+for ix, operand in enumerate(operands):
+if operand.target in inp_pos_to_param_buffer_name.values():
+inp_pos_to_buffer_name_for_submod[
+ix + start_ix
+] = operand.target
+if operand.target in state_dict_for_lookup:
+value = state_dict_for_lookup[operand.target]
+elif operand.target in tensor_constants:
+value = tensor_constants[operand.target]
+else:
+raise RuntimeError(
+f"Unable to find value for {operand.target}"
+)
+
+body_gm.register_buffer(operand.target, value)
+else:
+real_operands.append(operand)
+return real_operands
+
+real_mapped_args = _find_real_operands(mapped_args, 0)
+real_mapped_operands = _find_real_operands(operands, num_mapped)
+node.args = (body_graph, real_mapped_args, real_mapped_operands)
+
+_, in_spec = pytree.tree_flatten(real_mapped_args + real_mapped_operands)
 
 _unlift(
 body_gm,