diff --git a/functorch/_src/partitioners.py b/functorch/_src/partitioners.py
index 0fa99cfb843..aec79f3db33 100644
--- a/functorch/_src/partitioners.py
+++ b/functorch/_src/partitioners.py
@@ -231,7 +231,7 @@ def _count_ops(graph):
 
 
 def min_cut_rematerialization_partition(
-    joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser"
+    joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser", recomputable_ops=None,
 ) -> Tuple[fx.GraphModule, fx.GraphModule]:
     """
     Partitions the joint graph such that the backward recomputes the forward.
@@ -247,6 +247,12 @@ def min_cut_rematerialization_partition(
     Args:
         joint_module(fx.GraphModule): The joint forward and backward graph. This is
             the result of AOT Autograd tracing.
+        _joint_inputs: The inputs to the joint graph. This is unused.
+        compiler: This option determines the default set of recomputable ops.
+            Currently, there are two options: ``nvfuser`` and ``inductor``.
+        recomputable_ops: This is an optional set of recomputable ops. If this
+            is not None, then this set of ops will be used instead of the
+            default set of ops.
 
     Returns:
         Returns the generated forward and backward Fx graph modules.
@@ -299,13 +305,14 @@ def min_cut_rematerialization_partition(
 
     aten = torch.ops.aten
     prims = torch.ops.prims
-    recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.alias, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy]  # noqa: E501
+    # compiler == "nvfuser" is the default set of recomputable ops
+    default_recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.alias, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy]  # noqa: E501
     if compiler == "inductor":
-        recomputable_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.minimum, aten.arange, aten.bitwise_and, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.isnan, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum, aten.bitwise_or, aten.logical_and, aten.logical_or]  # noqa: E501
+        default_recomputable_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.minimum, aten.arange, aten.bitwise_and, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.isnan, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum, aten.bitwise_or, aten.logical_and, aten.logical_or]  # noqa: E501
     # Natalia said that we should allow recomputing indexing :)
-    recomputable_ops += [aten.index]
+    default_recomputable_ops += [aten.index]
 
-    recomputable_ops = set(recomputable_ops)
+    recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
 
     random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
     compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward]  # noqa: E501
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index df25f90e55f..822e7b351d9 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -757,6 +757,62 @@ class TestPartitioning(AOTTestCase):
         ins, outs = get_ins_outs(fw_graph)
         self.assertEqual(outs[1].target, torch.ops.aten.mm.default)
 
+    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
+    def test_min_cut_partitioner_recomputable_ops(self):
+        def f(x):
+            return x * x * x
+
+        recomputable_ops = []
+        partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops)
+
+        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn)
+        # Expected forward graph:
+        # opcode         name       target           args                        kwargs
+        # -------------  ---------  ---------------  --------------------------  --------
+        # placeholder    primals_1  primals_1        ()                          {}
+        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)      {}
+        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)            {}
+        # output         output     output           ([mul_1, primals_1, mul],)  {}
+        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
+        # Expected backward graph:
+        # opcode         name        target           args                     kwargs
+        # -------------  ----------  ---------------  -----------------------  --------
+        # placeholder    primals_1   primals_1        ()                       {}
+        # placeholder    mul         mul              ()                       {}
+        # placeholder    tangents_1  tangents_1       ()                       {}
+        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
+        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
+        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
+        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
+        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
+        # output         output      output           ([add_1],)               {}
+        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
+
+        recomputable_ops = [torch.ops.aten.mul]
+        partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops)
+        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn)
+        # Expected forward graph:
+        # opcode         name       target           args                    kwargs
+        # -------------  ---------  ---------------  ----------------------  --------
+        # placeholder    primals_1  primals_1        ()                      {}
+        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)  {}
+        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)        {}
+        # output         output     output           ([mul_1, primals_1],)   {}
+        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
+        # Expected backward graph:
+        # opcode         name        target           args                     kwargs
+        # -------------  ----------  ---------------  -----------------------  --------
+        # placeholder    primals_1   primals_1        ()                       {}
+        # placeholder    tangents_1  tangents_1       ()                       {}
+        # call_function  mul         aten.mul.Tensor  (primals_1, primals_1)   {}  # RECOMPUTED
+        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
+        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
+        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
+        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
+        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
+        # output         output      output           ([add_1],)               {}
+        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
+
     def test_contiguous(self):
         # The test simulates the condition where transpose followed by view
         # happens in the backward pass.
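
For reference, the sketch below shows how the new keyword can be exercised outside the test harness by threading a partially applied partitioner through aot_function. It is not part of the patch: the exact functorch.compile import path and the print_compiler callback are illustrative assumptions, and min_cut_rematerialization_partition still requires networkx to be installed.

from functools import partial

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition


def f(x):
    return (x * x * x).sum()


# Hypothetical compiler callback: print each partitioned graph and run it unchanged.
def print_compiler(gm, _example_inputs):
    print(gm.graph)
    return gm


# Only aten.mul may be recomputed in the backward pass; any other forward value the
# backward needs is saved as a forward output instead (matching the second test case).
partition_fn = partial(
    min_cut_rematerialization_partition,
    recomputable_ops=[torch.ops.aten.mul],
)

compiled_f = aot_function(
    f,
    fw_compiler=print_compiler,
    bw_compiler=print_compiler,
    partition_fn=partition_fn,
)
compiled_f(torch.randn(3, requires_grad=True)).backward()

Passing recomputable_ops=[] reproduces the first test case above: nothing is recomputed, so the forward graph returns extra saved tensors and the backward graph takes them as placeholders.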