diff --git a/test/expect/TestJit.test_broadcast_fusion_cuda.expect b/test/expect/TestJit.test_broadcast_fusion_cuda.expect index 47147b7a1aa..27719ce9a2e 100644 --- a/test/expect/TestJit.test_broadcast_fusion_cuda.expect +++ b/test/expect/TestJit.test_broadcast_fusion_cuda.expect @@ -4,11 +4,11 @@ graph(%0 : Float(*, *) %3 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %0, %1) return (%3); } -with prim::FusionGroup_0 = graph(%1 : Float(*) - %4 : Float(*, *) - %5 : Float(*)) { - %6 : Float(*, *) = aten::mul(%4, %5) - %2 : int = prim::Constant[value=1]() - %3 : Float(*, *) = aten::add(%6, %1, %2) - return (%3); +with prim::FusionGroup_0 = graph(%0 : Float(*) + %1 : Float(*, *) + %2 : Float(*)) { + %3 : Float(*, *) = aten::mul(%1, %2) + %4 : int = prim::Constant[value=1]() + %5 : Float(*, *) = aten::add(%3, %0, %4) + return (%5); } diff --git a/test/expect/TestJit.test_concat_fusion.expect b/test/expect/TestJit.test_concat_fusion.expect index 9cc1b2dae65..7d0f36fcbca 100644 --- a/test/expect/TestJit.test_concat_fusion.expect +++ b/test/expect/TestJit.test_concat_fusion.expect @@ -3,11 +3,11 @@ graph(%0 : Float(*, *) %2 : Float(*, *) = prim::FusionGroup_0[device=0](%0, %1) return (%2); } -with prim::FusionGroup_0 = graph(%3 : Float(*, *) - %4 : Float(*, *)) { - %6 : int = prim::Constant[value=1]() - %7 : Float(*, *) = aten::add(%3, %4, %6) - %5 : Float(*, *) = aten::mul(%3, %4) - %2 : Float(*, *) = prim::FusedConcat[dim=0](%7, %5) - return (%2); +with prim::FusionGroup_0 = graph(%0 : Float(*, *) + %1 : Float(*, *)) { + %2 : int = prim::Constant[value=1]() + %3 : Float(*, *) = aten::add(%0, %1, %2) + %4 : Float(*, *) = aten::mul(%0, %1) + %5 : Float(*, *) = prim::FusedConcat[dim=0](%3, %4) + return (%5); } diff --git a/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect b/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect index a362483f239..748f7cd227b 100644 --- a/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect +++ b/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect @@ -6,12 +6,12 @@ graph(%0 : Float(*, *) %5 : Float(*, *) = aten::add(%4, %2, %3) return (%5); } -with prim::FusionGroup_0 = graph(%3 : Float(*, *) - %4 : Float(*, *)) { - %7 : int = prim::Constant[value=1]() - %8 : Float(*, *) = aten::add(%3, %4, %7) - %5 : int = prim::Constant[value=1]() - %6 : Float(*, *) = aten::sub(%3, %4, %5) - %2 : Float(*, *) = prim::FusedConcat[dim=0](%8, %6) - return (%2); +with prim::FusionGroup_0 = graph(%0 : Float(*, *) + %1 : Float(*, *)) { + %2 : int = prim::Constant[value=1]() + %3 : Float(*, *) = aten::add(%0, %1, %2) + %4 : int = prim::Constant[value=1]() + %5 : Float(*, *) = aten::sub(%0, %1, %4) + %6 : Float(*, *) = prim::FusedConcat[dim=0](%3, %5) + return (%6); } diff --git a/test/expect/TestJit.test_cpp.expect b/test/expect/TestJit.test_cpp.expect index 08119f5277b..54a3c16d459 100644 --- a/test/expect/TestJit.test_cpp.expect +++ b/test/expect/TestJit.test_cpp.expect @@ -64,10 +64,10 @@ graph(%0 : Dynamic %2 : Dynamic %3 : Dynamic %4 : Dynamic) { - %23 : Dynamic, %24 : Dynamic = prim::GraphExecutor_0(%0, %3, %1, %4, %2) + %23 : Dynamic, %24 : Dynamic = prim::DifferentiableGraph_0(%0, %3, %1, %4, %2) return (%24, %23); } -with prim::GraphExecutor_0 = graph(%1 : Dynamic +with prim::DifferentiableGraph_0 = graph(%1 : Dynamic %2 : Dynamic %4 : Dynamic %5 : Dynamic diff --git a/test/expect/TestJit.test_fuse_last_device.expect b/test/expect/TestJit.test_fuse_last_device.expect index 9cd143c4481..e59c112dd5c 100644 --- 
a/test/expect/TestJit.test_fuse_last_device.expect +++ b/test/expect/TestJit.test_fuse_last_device.expect @@ -3,14 +3,14 @@ graph(%0 : Float(*) %2 : Float(*) = prim::FusionGroup_0[device=1](%0, %1) return (%2); } -with prim::FusionGroup_0 = graph(%5 : Float(*) - %10 : Float(*)) { - %11 : int = prim::Constant[value=1]() - %12 : Float(*) = aten::add(%5, %10, %11) - %9 : Float(*) = aten::mul(%5, %12) - %6 : int = prim::Constant[value=1]() - %7 : Float(*) = aten::add(%9, %5, %6) - %3 : Float(*) = aten::tanh(%7) - %1 : Float(*) = aten::sigmoid(%3) - return (%1); +with prim::FusionGroup_0 = graph(%0 : Float(*) + %1 : Float(*)) { + %2 : int = prim::Constant[value=1]() + %3 : Float(*) = aten::add(%0, %1, %2) + %4 : Float(*) = aten::mul(%0, %3) + %5 : int = prim::Constant[value=1]() + %6 : Float(*) = aten::add(%4, %0, %5) + %7 : Float(*) = aten::tanh(%6) + %8 : Float(*) = aten::sigmoid(%7) + return (%8); } diff --git a/test/expect/TestJit.test_fusion_distribute.expect b/test/expect/TestJit.test_fusion_distribute.expect index 1240535dc09..4cd475eab9d 100644 --- a/test/expect/TestJit.test_fusion_distribute.expect +++ b/test/expect/TestJit.test_fusion_distribute.expect @@ -6,14 +6,14 @@ graph(%0 : Float(*, *) %6 : Float(*, *) = prim::FusionGroup_0[device=0](%4, %5) return (%6); } -with prim::FusionGroup_0 = graph(%11 : Dynamic - %14 : Dynamic) { - %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%14) - %12 : Float(*, *), %13 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%11) - %9 : int = prim::Constant[value=1]() - %10 : Float(*, *) = aten::add(%13, %16, %9) - %5 : int = prim::Constant[value=1]() - %6 : Float(*, *) = aten::add(%12, %15, %5) - %2 : Float(*, *) = aten::mul(%6, %10) - return (%2); +with prim::FusionGroup_0 = graph(%0 : Dynamic + %1 : Dynamic) { + %2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%1) + %4 : Float(*, *), %5 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%0) + %6 : int = prim::Constant[value=1]() + %7 : Float(*, *) = aten::add(%5, %3, %6) + %8 : int = prim::Constant[value=1]() + %9 : Float(*, *) = aten::add(%4, %2, %8) + %10 : Float(*, *) = aten::mul(%9, %7) + return (%10); } diff --git a/test/expect/TestJit.test_lstm_fusion_concat.expect b/test/expect/TestJit.test_lstm_fusion_concat.expect index 9262c9563c1..8e349095e05 100644 --- a/test/expect/TestJit.test_lstm_fusion_concat.expect +++ b/test/expect/TestJit.test_lstm_fusion_concat.expect @@ -16,29 +16,29 @@ graph(%0 : Float(*, *) %16 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %14, %15) return (%16); } -with prim::FusionGroup_0 = graph(%15 : Float(*, *) - %41 : Dynamic - %46 : Dynamic) { - %47 : Float(*, *), %48 : Float(*, *), %49 : Float(*, *), %50 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%46) - %42 : Float(*, *), %43 : Float(*, *), %44 : Float(*, *), %45 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%41) - %39 : int = prim::Constant[value=1]() - %40 : Float(*, *) = aten::add(%42, %47, %39) - %35 : int = prim::Constant[value=1]() - %36 : Float(*, *) = aten::add(%43, %48, %35) - %31 : int = prim::Constant[value=1]() - %32 : Float(*, *) = aten::add(%44, %49, %31) - %27 : int = prim::Constant[value=1]() - %28 : Float(*, *) = aten::add(%45, %50, %27) - %24 : Float(*, *) = aten::sigmoid(%40) - %22 : Float(*, *) = aten::sigmoid(%36) - %20 : Float(*, *) = aten::tanh(%32) - %18 : Float(*, *) = aten::sigmoid(%28) - %16 : Float(*, *) = aten::mul(%22, %15) - %13 : Float(*, *) = aten::mul(%24, %20) - %9 : int = prim::Constant[value=1]() - %10 : 
Float(*, *) = aten::add(%16, %13, %9) - %6 : Float(*, *) = aten::tanh(%10) - %5 : Float(*, *) = aten::mul(%18, %6) - %2 : Float(*, *) = prim::FusedConcat[dim=0](%5, %10) - return (%2); +with prim::FusionGroup_0 = graph(%0 : Float(*, *) + %1 : Dynamic + %2 : Dynamic) { + %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) + %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) + %11 : int = prim::Constant[value=1]() + %12 : Float(*, *) = aten::add(%7, %3, %11) + %13 : int = prim::Constant[value=1]() + %14 : Float(*, *) = aten::add(%8, %4, %13) + %15 : int = prim::Constant[value=1]() + %16 : Float(*, *) = aten::add(%9, %5, %15) + %17 : int = prim::Constant[value=1]() + %18 : Float(*, *) = aten::add(%10, %6, %17) + %19 : Float(*, *) = aten::sigmoid(%12) + %20 : Float(*, *) = aten::sigmoid(%14) + %21 : Float(*, *) = aten::tanh(%16) + %22 : Float(*, *) = aten::sigmoid(%18) + %23 : Float(*, *) = aten::mul(%20, %0) + %24 : Float(*, *) = aten::mul(%19, %21) + %25 : int = prim::Constant[value=1]() + %26 : Float(*, *) = aten::add(%23, %24, %25) + %27 : Float(*, *) = aten::tanh(%26) + %28 : Float(*, *) = aten::mul(%22, %27) + %29 : Float(*, *) = prim::FusedConcat[dim=0](%28, %26) + return (%29); } diff --git a/test/expect/TestJit.test_lstm_fusion_cuda.expect b/test/expect/TestJit.test_lstm_fusion_cuda.expect index b9d885ade1f..f9539d6f241 100644 --- a/test/expect/TestJit.test_lstm_fusion_cuda.expect +++ b/test/expect/TestJit.test_lstm_fusion_cuda.expect @@ -16,28 +16,28 @@ graph(%0 : Float(*, *) %16 : Float(*, *), %17 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %14, %15) return (%16, %17); } -with prim::FusionGroup_0 = graph(%13 : Float(*, *) - %39 : Dynamic - %44 : Dynamic) { - %45 : Float(*, *), %46 : Float(*, *), %47 : Float(*, *), %48 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%44) - %40 : Float(*, *), %41 : Float(*, *), %42 : Float(*, *), %43 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%39) - %37 : int = prim::Constant[value=1]() - %38 : Float(*, *) = aten::add(%40, %45, %37) - %33 : int = prim::Constant[value=1]() - %34 : Float(*, *) = aten::add(%41, %46, %33) - %29 : int = prim::Constant[value=1]() - %30 : Float(*, *) = aten::add(%42, %47, %29) +with prim::FusionGroup_0 = graph(%0 : Float(*, *) + %1 : Dynamic + %2 : Dynamic) { + %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) + %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) + %11 : int = prim::Constant[value=1]() + %12 : Float(*, *) = aten::add(%7, %3, %11) + %13 : int = prim::Constant[value=1]() + %14 : Float(*, *) = aten::add(%8, %4, %13) + %15 : int = prim::Constant[value=1]() + %16 : Float(*, *) = aten::add(%9, %5, %15) + %17 : int = prim::Constant[value=1]() + %18 : Float(*, *) = aten::add(%10, %6, %17) + %19 : Float(*, *) = aten::sigmoid(%12) + %20 : Float(*, *) = aten::sigmoid(%14) + %21 : Float(*, *) = aten::tanh(%16) + %22 : Float(*, *) = aten::sigmoid(%18) + %23 : Float(*, *) = aten::mul(%20, %0) + %24 : Float(*, *) = aten::mul(%19, %21) %25 : int = prim::Constant[value=1]() - %26 : Float(*, *) = aten::add(%43, %48, %25) - %22 : Float(*, *) = aten::sigmoid(%38) - %20 : Float(*, *) = aten::sigmoid(%34) - %18 : Float(*, *) = aten::tanh(%30) - %16 : Float(*, *) = aten::sigmoid(%26) - %14 : Float(*, *) = aten::mul(%20, %13) - %11 : Float(*, *) = aten::mul(%22, 
%18) - %7 : int = prim::Constant[value=1]() - %8 : Float(*, *) = aten::add(%14, %11, %7) - %4 : Float(*, *) = aten::tanh(%8) - %2 : Float(*, *) = aten::mul(%16, %4) - return (%2, %8); + %26 : Float(*, *) = aten::add(%23, %24, %25) + %27 : Float(*, *) = aten::tanh(%26) + %28 : Float(*, *) = aten::mul(%22, %27) + return (%28, %26); } diff --git a/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect b/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect index 0168dc99453..c7e0e11b92d 100644 --- a/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect +++ b/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect @@ -1,6 +1,6 @@ graph(%0 : Double(3, 4) %1 : Double(4, 5)) { %2 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule - %3 : Double(3, 4) = aten::neg(%2), scope: TracedModule/traced_fn + %3 : Double(*, *) = aten::neg(%2), scope: TracedModule/traced_fn return (%3); } diff --git a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect index ef682b2a031..684f1d2523c 100644 --- a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect @@ -1,5 +1,5 @@ graph(%0 : Double(3, 4)) { - %1 : Double(3, 4) = aten::neg(%0), scope: traced_fn1 + %1 : Double(*, *) = aten::neg(%0), scope: traced_fn1 %2 : Long() = prim::Constant[value={1}]() %3 : int = prim::Constant[value=1]() %4 : Double(3, 4) = aten::add(%1, %2, %3) diff --git a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect index 7235862571c..4be8b877307 100644 --- a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect +++ b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect @@ -1,6 +1,6 @@ graph(%0 : Double(3, 4)) { %1 : Double(4, 3) = prim::Constant[value=](), scope: TracedModule[TracedModule] - %2 : Double(3, 3) = aten::mm(%0, %1), scope: TracedModule[TracedModule] + %2 : Double(*, *) = aten::mm(%0, %1), scope: TracedModule[TracedModule] %3 : Double() = prim::Constant[value={1}]() %4 : int = prim::Constant[value=1]() %5 : Double(3, 3) = aten::add(%2, %3, %4) diff --git a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect index 8658dd73d0c..33895f4d0cb 100644 --- a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect +++ b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect @@ -2,7 +2,7 @@ graph(%0 : Double(3, 4) %1 : Double(4, 5) %2 : Double(5, 7)) { %3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule - %4 : Double(3, 7) = aten::mm(%3, %2), scope: TracedModule/TracedModule[TracedModule1][mod] + %4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule/TracedModule[TracedModule1][mod] %5 : Double() = prim::Constant[value={1}](), scope: TracedModule %6 : int = prim::Constant[value=1](), scope: TracedModule %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule diff --git a/test/expect/TestScript.test_chunk_fusion_cuda.expect b/test/expect/TestScript.test_chunk_fusion_cuda.expect index 706fd5e44c3..c4bff15b156 100644 --- a/test/expect/TestScript.test_chunk_fusion_cuda.expect +++ b/test/expect/TestScript.test_chunk_fusion_cuda.expect @@ -2,10 +2,10 @@ graph(%x : Float(*, *)) { %1 : Float(*, *) = prim::FusionGroup_0[device=0](%x) return (%1); } -with prim::FusionGroup_0 
= graph(%7 : Float(*, *)) { - %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=3, dim=1](%7) - %6 : Float(*, *) = aten::mul(%8, %9) - %2 : int = prim::Constant[value=1]() - %3 : Float(*, *) = aten::add(%6, %10, %2) - return (%3); +with prim::FusionGroup_0 = graph(%0 : Float(*, *)) { + %1 : Float(*, *), %2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=3, dim=1](%0) + %4 : Float(*, *) = aten::mul(%1, %2) + %5 : int = prim::Constant[value=1]() + %6 : Float(*, *) = aten::add(%4, %3, %5) + return (%6); } diff --git a/test/expect/TestScript.test_chunk_multiple_fusion_cuda.expect b/test/expect/TestScript.test_chunk_multiple_fusion_cuda.expect index 84ba0050e04..56fe6d02ed5 100644 --- a/test/expect/TestScript.test_chunk_multiple_fusion_cuda.expect +++ b/test/expect/TestScript.test_chunk_multiple_fusion_cuda.expect @@ -5,26 +5,26 @@ graph(%s : Float(*, *, *) %4 : Float(*, *, *) = prim::FusionGroup_0[device=0](%s, %y, %x, %z) return (%4); } -with prim::FusionGroup_0 = graph(%24 : Float(*, *, *) - %28 : Float(*, *, *) - %31 : Float(*, *, *) - %35 : Float(*, *, *)) { - %36 : Float(*, *, *), %37 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=2](%35) - %32 : Float(*, *, *), %33 : Float(*, *, *), %34 : Float(*, *, *) = prim::ConstantChunk[chunks=3, dim=1](%31) - %29 : Float(*, *, *), %30 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=0](%28) - %26 : int = prim::Constant[value=1]() - %27 : Float(*, *, *) = aten::add(%24, %32, %26) - %22 : int = prim::Constant[value=1]() - %23 : Float(*, *, *) = aten::add(%27, %33, %22) - %18 : int = prim::Constant[value=1]() - %19 : Float(*, *, *) = aten::add(%23, %34, %18) - %14 : int = prim::Constant[value=1]() - %15 : Float(*, *, *) = aten::add(%19, %29, %14) - %10 : int = prim::Constant[value=1]() - %11 : Float(*, *, *) = aten::add(%15, %30, %10) - %6 : int = prim::Constant[value=1]() - %7 : Float(*, *, *) = aten::add(%11, %36, %6) - %2 : int = prim::Constant[value=1]() - %3 : Float(*, *, *) = aten::add(%7, %37, %2) - return (%3); +with prim::FusionGroup_0 = graph(%0 : Float(*, *, *) + %1 : Float(*, *, *) + %2 : Float(*, *, *) + %3 : Float(*, *, *)) { + %4 : Float(*, *, *), %5 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=2](%3) + %6 : Float(*, *, *), %7 : Float(*, *, *), %8 : Float(*, *, *) = prim::ConstantChunk[chunks=3, dim=1](%2) + %9 : Float(*, *, *), %10 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=0](%1) + %11 : int = prim::Constant[value=1]() + %12 : Float(*, *, *) = aten::add(%0, %6, %11) + %13 : int = prim::Constant[value=1]() + %14 : Float(*, *, *) = aten::add(%12, %7, %13) + %15 : int = prim::Constant[value=1]() + %16 : Float(*, *, *) = aten::add(%14, %8, %15) + %17 : int = prim::Constant[value=1]() + %18 : Float(*, *, *) = aten::add(%16, %9, %17) + %19 : int = prim::Constant[value=1]() + %20 : Float(*, *, *) = aten::add(%18, %10, %19) + %21 : int = prim::Constant[value=1]() + %22 : Float(*, *, *) = aten::add(%20, %4, %21) + %23 : int = prim::Constant[value=1]() + %24 : Float(*, *, *) = aten::add(%22, %5, %23) + return (%24); } diff --git a/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect b/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect index f3b9696c654..50d7d0799c8 100644 --- a/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect +++ b/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect @@ -7,9 +7,9 @@ graph(%0 : Float(*, *) %6 : Dynamic %7 : Dynamic %8 : Dynamic - %x : Float(*, *) - %hx : Float(*, *) - %cx : Float(*, *) + %9 : Float(*, 
*) + %10 : Float(*, *) + %11 : Float(*, *) %12 : Float(*, *) %13 : Float(*, *) %ingate : Float(*, *) @@ -18,58 +18,67 @@ graph(%0 : Float(*, *) %outgate : Float(*, *) %18 : Float(*, *)) { %19 : int = prim::Constant[value=1]() - %20 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %18, %0) - %21 : Float(*, *) = aten::mul(%20, %19) - %22 : Float(*, *) = aten::t(%hx) - %23 : Float(*, *) = aten::mm(%22, %21) - %24 : Float(*, *) = aten::t(%23) - %25 : Float(*, *) = aten::t(%x) - %26 : Float(*, *) = aten::mm(%25, %20) - %27 : Float(*, *) = aten::t(%26) - return (%27, %24, %21, %21); + %20 : Float(*, *), %21 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %11, %1, %18, %0) + %22 : Float(*, *) = aten::mul(%20, %19) + %23 : Float(*, *) = aten::t(%12) + %24 : Float(*, *) = aten::mm(%20, %23) + %25 : Float(*, *) = aten::mul(%24, %19) + %26 : Float(*, *) = aten::t(%10) + %27 : Float(*, *) = aten::mm(%26, %20) + %28 : Float(*, *) = aten::mul(%27, %19) + %29 : Float(*, *) = aten::t(%13) + %30 : Float(*, *) = aten::mm(%22, %29) + %31 : Float(*, *) = aten::t(%9) + %32 : Float(*, *) = aten::mm(%31, %22) + %33 : Float(*, *) = aten::t(%32) + %34 : Float(*, *) = aten::t(%28) + return (%34, %33, %30, %25, %22, %22, %21); } -with prim::FusionGroup_0 = graph(%9 : Float(*, *) - %19 : Float(*, *) - %33 : Float(*, *) - %39 : Float(*, *) - %46 : Float(*, *) - %53 : Float(*, *) - %65 : Float(*, *) - %67 : Float(*, *)) { - %69 : Float(*, *) = aten::mul(%67, %65) - %68 : Float(*, *) = aten::mul(%67, %39) - %66 : Float(*, *) = aten::mul(%65, %65) - %64 : Float(*, *) = aten::neg(%66) - %61 : int = prim::Constant[value=1]() - %62 : Float(*, *) = aten::add(%64, %61, %61) - %59 : Float(*, *) = aten::mul(%68, %62) - %55 : int = prim::Constant[value=1]() - %56 : Float(*, *) = aten::add(%53, %59, %55) - %51 : int = prim::Constant[value=1]() - %52 : Float(*, *) = aten::mul(%56, %51) - %50 : Float(*, *) = aten::mul(%52, %33) - %49 : Float(*, *) = aten::mul(%52, %9) - %47 : Float(*, *) = aten::mul(%56, %46) - %44 : Float(*, *) = aten::neg(%39) - %42 : int = prim::Constant[value=1]() - %43 : Float(*, *) = aten::add(%44, %42, %42) - %40 : Float(*, *) = aten::mul(%69, %39) - %37 : Float(*, *) = aten::mul(%40, %43) - %34 : Float(*, *) = aten::mul(%33, %33) - %32 : Float(*, *) = aten::neg(%34) - %29 : int = prim::Constant[value=1]() - %30 : Float(*, *) = aten::add(%32, %29, %29) - %27 : Float(*, *) = aten::mul(%49, %30) - %24 : Float(*, *) = aten::neg(%19) - %22 : int = prim::Constant[value=1]() - %23 : Float(*, *) = aten::add(%24, %22, %22) - %20 : Float(*, *) = aten::mul(%47, %19) - %17 : Float(*, *) = aten::mul(%20, %23) - %14 : Float(*, *) = aten::neg(%9) +with prim::FusionGroup_0 = graph(%0 : Float(*, *) + %1 : Float(*, *) + %2 : Float(*, *) + %3 : Float(*, *) + %4 : Float(*, *) + %5 : Float(*, *) + %6 : Float(*, *) + %7 : Float(*, *)) { + %8 : Float(*, *) = aten::mul(%5, %3) + %9 : Float(*, *) = aten::mul(%6, %6) + %10 : Float(*, *) = aten::neg(%9) + %11 : int = prim::Constant[value=1]() %12 : int = prim::Constant[value=1]() - %13 : Float(*, *) = aten::add(%14, %12, %12) - %10 : Float(*, *) = aten::mul(%50, %9) - %7 : Float(*, *) = aten::mul(%10, %13) - %4 : Float(*, *) = prim::FusedConcat[dim=1](%7, %17, %27, %37) - return (%4); + %13 : Float(*, *) = aten::add(%10, %12, %12) + %14 : Float(*, *) = aten::mul(%8, %13) + %15 : int = prim::Constant[value=1]() + %16 : int = prim::Constant[value=1]() + %17 : Float(*, *) = aten::add(%7, %14, %16) 
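For reference, the fused groups in these LSTM expect files all compute the standard LSTM cell update; writing the four chunked gate pre-activations (after the bias adds) as a_i, a_f, a_g, a_o and the previous cell state as c, the forward graphs correspond to

  i = \sigma(a_i), \quad f = \sigma(a_f), \quad g = \tanh(a_g), \quad o = \sigma(a_o)
  c' = f \odot c + i \odot g, \qquad h' = o \odot \tanh(c')

and the backward fusion group around this hunk is the symbolic derivative of those same expressions.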
+ %18 : Float(*, *) = aten::mul(%17, %1) + %19 : Float(*, *) = aten::mul(%5, %6) + %20 : int = prim::Constant[value=1]() + %21 : Float(*, *) = aten::mul(%17, %20) + %22 : Float(*, *) = aten::mul(%21, %2) + %23 : Float(*, *) = aten::mul(%21, %0) + %24 : Float(*, *) = aten::mul(%17, %4) + %25 : Float(*, *) = aten::neg(%3) + %26 : int = prim::Constant[value=1]() + %27 : Float(*, *) = aten::add(%25, %26, %26) + %28 : Float(*, *) = aten::mul(%19, %3) + %29 : Float(*, *) = aten::mul(%28, %27) + %30 : Float(*, *) = aten::mul(%2, %2) + %31 : Float(*, *) = aten::neg(%30) + %32 : int = prim::Constant[value=1]() + %33 : Float(*, *) = aten::add(%31, %32, %32) + %34 : Float(*, *) = aten::mul(%23, %33) + %35 : Float(*, *) = aten::neg(%1) + %36 : int = prim::Constant[value=1]() + %37 : Float(*, *) = aten::add(%35, %36, %36) + %38 : Float(*, *) = aten::mul(%24, %1) + %39 : Float(*, *) = aten::mul(%38, %37) + %40 : Float(*, *) = aten::neg(%0) + %41 : int = prim::Constant[value=1]() + %42 : Float(*, *) = aten::add(%40, %41, %41) + %43 : Float(*, *) = aten::mul(%22, %0) + %44 : Float(*, *) = aten::mul(%43, %42) + %45 : Float(*, *) = prim::FusedConcat[dim=1](%44, %39, %34, %29) + return (%45, %18); } diff --git a/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect b/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect index 6f0d19b5016..84905033f5a 100644 --- a/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect +++ b/test/expect/TestScript.test_lstm_fusion_cuda-forward.expect @@ -1,44 +1,54 @@ -graph(%x.1 : Float(*, *) - %hx.1 : Float(*, *) - %cx.1 : Float(*, *) +graph(%x : Float(*, *) + %hx : Float(*, *) + %cx : Float(*, *) %w_ih : Float(*, *) %w_hh : Float(*, *) %b_ih : Float(*) %b_hh : Float(*)) { - %7 : Float(*, *) = aten::t(%w_ih) - %8 : Float(*, *) = aten::t(%w_hh) - %9 : Float(*, *) = aten::mm(%hx.1, %8) + %7 : Float(*, *), %8 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %w_hh, %hx, %x, %b_ih, %b_hh, %cx) + return (%8, %7); +} +with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) + %1 : Float(*, *) + %2 : Float(*, *) + %3 : Float(*, *) + %4 : Float(*) + %5 : Float(*) + %6 : Float(*, *)) { + %7 : Float(*, *) = aten::t(%0) + %8 : Float(*, *) = aten::t(%1) + %9 : Float(*, *) = aten::mm(%2, %8) %10 : int = prim::Constant[value=1]() - %11 : Float(*, *) = aten::addmm(%9, %x.1, %7, %10, %10) - %12 : Float(*, *) = aten::add(%11, %b_ih, %10) - %13 : Dynamic[] = prim::ListConstruct(%12, %b_hh) + %11 : Float(*, *) = aten::addmm(%9, %3, %7, %10, %10) + %12 : Float(*, *) = aten::add(%11, %4, %10) + %13 : Dynamic[] = prim::ListConstruct(%12, %5) %14 : Dynamic[] = aten::broadcast_tensors(%13) %15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%14) - %hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0[device=0](%cx.1, %15, %16) - return (%hy, %cy, %7, %8, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18); + %hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0[device=0](%6, %15, %16) + return (%cy, %hy, %7, %8, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18); } -with prim::FusionGroup_0 = graph(%13 : Float(*, *) - %39 : Dynamic - %44 : Dynamic) { - %45 : Dynamic, %46 : Dynamic, %47 : Dynamic, %48 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%44) - %40 : Dynamic, %41 : Dynamic, %42 : Dynamic, %43 : Dynamic = 
prim::FusedChunk[chunks=4, dim=1](%39) - %37 : int = prim::Constant[value=1]() - %38 : Float(*, *) = aten::add(%40, %45, %37) - %33 : int = prim::Constant[value=1]() - %34 : Float(*, *) = aten::add(%41, %46, %33) - %29 : int = prim::Constant[value=1]() - %30 : Float(*, *) = aten::add(%42, %47, %29) +with prim::FusionGroup_0 = graph(%0 : Float(*, *) + %1 : Dynamic + %2 : Dynamic) { + %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) + %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) + %11 : int = prim::Constant[value=1]() + %12 : Float(*, *) = aten::add(%7, %3, %11) + %13 : int = prim::Constant[value=1]() + %14 : Float(*, *) = aten::add(%8, %4, %13) + %15 : int = prim::Constant[value=1]() + %16 : Float(*, *) = aten::add(%9, %5, %15) + %17 : int = prim::Constant[value=1]() + %18 : Float(*, *) = aten::add(%10, %6, %17) + %ingate.1 : Float(*, *) = aten::sigmoid(%12) + %forgetgate.1 : Float(*, *) = aten::sigmoid(%14) + %cellgate.1 : Float(*, *) = aten::tanh(%16) + %outgate.1 : Float(*, *) = aten::sigmoid(%18) + %23 : Float(*, *) = aten::mul(%forgetgate.1, %0) + %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) %25 : int = prim::Constant[value=1]() - %26 : Float(*, *) = aten::add(%43, %48, %25) - %ingate.1 : Float(*, *) = aten::sigmoid(%38) - %forgetgate.1 : Float(*, *) = aten::sigmoid(%34) - %cellgate.1 : Float(*, *) = aten::tanh(%30) - %outgate.1 : Float(*, *) = aten::sigmoid(%26) - %14 : Float(*, *) = aten::mul(%forgetgate.1, %13) - %11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) - %7 : int = prim::Constant[value=1]() - %cy : Float(*, *) = aten::add(%14, %11, %7) - %4 : Float(*, *) = aten::tanh(%cy) - %hy : Float(*, *) = aten::mul(%outgate.1, %4) - return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); + %cy : Float(*, *) = aten::add(%23, %24, %25) + %27 : Float(*, *) = aten::tanh(%cy) + %hy : Float(*, *) = aten::mul(%outgate.1, %27) + return (%hy, %27, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); } diff --git a/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect b/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect index ac8953c2dd9..5a099e3b3e6 100644 --- a/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect +++ b/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect @@ -10,12 +10,12 @@ graph(%0 : Float(*, *) %9 : Dynamic %10 : Dynamic %11 : Dynamic - %x : Float(*, *) - %hx : Float(*, *) - %cx : Float(*, *) - %alpha : Float(*) - %beta_i : Float(*) - %beta_h : Float(*) + %12 : Float(*, *) + %13 : Float(*, *) + %14 : Float(*) + %15 : Float(*) + %16 : Float(*) + %17 : Float(*, *) %18 : Float(*, *) %Wx : Float(*, *) %20 : Float(*, *) @@ -26,85 +26,92 @@ graph(%0 : Float(*, *) %cellgate : Float(*, *) %outgate : Float(*, *) %27 : Float(*, *)) { - %28 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %27, %0) - %29 : Float(*, *), %30 : Float(*, *), %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *) = prim::FusionGroup_1[device=0](%alpha, %beta_i, %Wx, %28, %Uz, %22, %beta_h) - %35 : Float(*, *) = aten::t(%hx) - %36 : Float(*, *) = aten::mm(%35, %31) - %37 : Float(*, *) = aten::t(%36) - %38 : Float(*, *) = aten::t(%x) - %39 : Float(*, *) = aten::mm(%38, %29) + %28 : Float(*, *), %29 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %17, %1, %27, %0) + %30 : Float(*, *), %31 : 
Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *), %35 : Float(*, *) = prim::FusionGroup_1[device=0](%14, %15, %Wx, %28, %Uz, %22, %16) + %36 : Float(*, *) = aten::t(%20) + %37 : Float(*, *) = aten::mm(%32, %36) + %38 : Float(*, *) = aten::t(%13) + %39 : Float(*, *) = aten::mm(%38, %32) %40 : Float(*, *) = aten::t(%39) - return (%40, %37, %30, %32, %33, %34); + %41 : Float(*, *) = aten::t(%18) + %42 : Float(*, *) = aten::mm(%30, %41) + %43 : Float(*, *) = aten::t(%12) + %44 : Float(*, *) = aten::mm(%43, %30) + %45 : Float(*, *) = aten::t(%44) + return (%45, %42, %40, %37, %31, %33, %34, %35, %29); } -with prim::FusionGroup_0 = graph(%9 : Float(*, *) - %19 : Float(*, *) - %33 : Float(*, *) - %39 : Float(*, *) - %46 : Float(*, *) - %53 : Float(*, *) - %65 : Float(*, *) - %67 : Float(*, *)) { - %69 : Float(*, *) = aten::mul(%67, %65) - %68 : Float(*, *) = aten::mul(%67, %39) - %66 : Float(*, *) = aten::mul(%65, %65) - %64 : Float(*, *) = aten::neg(%66) - %61 : int = prim::Constant[value=1]() - %62 : Float(*, *) = aten::add(%64, %61, %61) - %59 : Float(*, *) = aten::mul(%68, %62) - %55 : int = prim::Constant[value=1]() - %56 : Float(*, *) = aten::add(%53, %59, %55) - %51 : int = prim::Constant[value=1]() - %52 : Float(*, *) = aten::mul(%56, %51) - %50 : Float(*, *) = aten::mul(%52, %33) - %49 : Float(*, *) = aten::mul(%52, %9) - %47 : Float(*, *) = aten::mul(%56, %46) - %44 : Float(*, *) = aten::neg(%39) - %42 : int = prim::Constant[value=1]() - %43 : Float(*, *) = aten::add(%44, %42, %42) - %40 : Float(*, *) = aten::mul(%69, %39) - %37 : Float(*, *) = aten::mul(%40, %43) - %34 : Float(*, *) = aten::mul(%33, %33) - %32 : Float(*, *) = aten::neg(%34) - %29 : int = prim::Constant[value=1]() - %30 : Float(*, *) = aten::add(%32, %29, %29) - %27 : Float(*, *) = aten::mul(%49, %30) - %24 : Float(*, *) = aten::neg(%19) - %22 : int = prim::Constant[value=1]() - %23 : Float(*, *) = aten::add(%24, %22, %22) - %20 : Float(*, *) = aten::mul(%47, %19) - %17 : Float(*, *) = aten::mul(%20, %23) - %14 : Float(*, *) = aten::neg(%9) +with prim::FusionGroup_0 = graph(%0 : Float(*, *) + %1 : Float(*, *) + %2 : Float(*, *) + %3 : Float(*, *) + %4 : Float(*, *) + %5 : Float(*, *) + %6 : Float(*, *) + %7 : Float(*, *)) { + %8 : Float(*, *) = aten::mul(%5, %3) + %9 : Float(*, *) = aten::mul(%6, %6) + %10 : Float(*, *) = aten::neg(%9) + %11 : int = prim::Constant[value=1]() %12 : int = prim::Constant[value=1]() - %13 : Float(*, *) = aten::add(%14, %12, %12) - %10 : Float(*, *) = aten::mul(%50, %9) - %7 : Float(*, *) = aten::mul(%10, %13) - %4 : Float(*, *) = prim::FusedConcat[dim=1](%7, %17, %27, %37) - return (%4); -} -with prim::FusionGroup_1 = graph(%5 : Float(*) - %8 : Float(*) - %10 : Float(*, *) - %12 : Float(*, *) - %13 : Float(*, *) - %20 : Float(*, *) - %22 : Float(*)) { - %30 : int = prim::Constant[value=1]() - %29 : int = prim::Constant[value=1]() - %28 : int = prim::Constant[value=1]() + %13 : Float(*, *) = aten::add(%10, %12, %12) + %14 : Float(*, *) = aten::mul(%8, %13) + %15 : int = prim::Constant[value=1]() + %16 : int = prim::Constant[value=1]() + %17 : Float(*, *) = aten::add(%7, %14, %16) + %18 : Float(*, *) = aten::mul(%17, %1) + %19 : Float(*, *) = aten::mul(%5, %6) + %20 : int = prim::Constant[value=1]() + %21 : Float(*, *) = aten::mul(%17, %20) + %22 : Float(*, *) = aten::mul(%21, %2) + %23 : Float(*, *) = aten::mul(%21, %0) + %24 : Float(*, *) = aten::mul(%17, %4) + %25 : Float(*, *) = aten::neg(%3) %26 : int = prim::Constant[value=1]() - %27 : Float(*, *) = aten::mul(%12, 
%26) - %25 : Float(*, *) = aten::mul(%27, %13) - %24 : Float(*, *) = aten::mul(%27, %10) - %23 : Float(*, *) = aten::mul(%27, %22) - %21 : Float(*, *) = aten::mul(%12, %20) - %19 : int = prim::Constant[value=1]() - %17 : int = prim::Constant[value=1]() - %18 : Float(*, *) = aten::add(%23, %21, %17) - %14 : Float(*, *) = aten::mul(%12, %13) - %11 : Float(*, *) = aten::mul(%14, %10) - %9 : Float(*, *) = aten::mul(%27, %8) - %6 : Float(*, *) = aten::mul(%14, %5) - %2 : int = prim::Constant[value=1]() - %3 : Float(*, *) = aten::add(%9, %6, %2) - return (%3, %11, %18, %24, %25, %27); + %27 : Float(*, *) = aten::add(%25, %26, %26) + %28 : Float(*, *) = aten::mul(%19, %3) + %29 : Float(*, *) = aten::mul(%28, %27) + %30 : Float(*, *) = aten::mul(%2, %2) + %31 : Float(*, *) = aten::neg(%30) + %32 : int = prim::Constant[value=1]() + %33 : Float(*, *) = aten::add(%31, %32, %32) + %34 : Float(*, *) = aten::mul(%23, %33) + %35 : Float(*, *) = aten::neg(%1) + %36 : int = prim::Constant[value=1]() + %37 : Float(*, *) = aten::add(%35, %36, %36) + %38 : Float(*, *) = aten::mul(%24, %1) + %39 : Float(*, *) = aten::mul(%38, %37) + %40 : Float(*, *) = aten::neg(%0) + %41 : int = prim::Constant[value=1]() + %42 : Float(*, *) = aten::add(%40, %41, %41) + %43 : Float(*, *) = aten::mul(%22, %0) + %44 : Float(*, *) = aten::mul(%43, %42) + %45 : Float(*, *) = prim::FusedConcat[dim=1](%44, %39, %34, %29) + return (%45, %18); +} +with prim::FusionGroup_1 = graph(%0 : Float(*) + %1 : Float(*) + %2 : Float(*, *) + %3 : Float(*, *) + %4 : Float(*, *) + %5 : Float(*, *) + %6 : Float(*)) { + %7 : int = prim::Constant[value=1]() + %8 : int = prim::Constant[value=1]() + %9 : int = prim::Constant[value=1]() + %10 : int = prim::Constant[value=1]() + %11 : Float(*, *) = aten::mul(%3, %10) + %12 : Float(*, *) = aten::mul(%11, %4) + %13 : Float(*, *) = aten::mul(%11, %2) + %14 : Float(*, *) = aten::mul(%11, %6) + %15 : Float(*, *) = aten::mul(%3, %5) + %16 : int = prim::Constant[value=1]() + %17 : int = prim::Constant[value=1]() + %18 : Float(*, *) = aten::add(%14, %15, %17) + %19 : Float(*, *) = aten::mul(%3, %4) + %20 : Float(*, *) = aten::mul(%19, %2) + %21 : Float(*, *) = aten::mul(%11, %1) + %22 : Float(*, *) = aten::mul(%19, %0) + %23 : int = prim::Constant[value=1]() + %24 : Float(*, *) = aten::add(%21, %22, %23) + return (%24, %20, %18, %13, %12, %11); } diff --git a/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect b/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect index 6520c205334..b9ab84ced12 100644 --- a/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect +++ b/test/expect/TestScript.test_milstm_fusion_cuda-forward.expect @@ -1,60 +1,73 @@ -graph(%x.1 : Float(*, *) - %hx.1 : Float(*, *) - %cx.1 : Float(*, *) +graph(%x : Float(*, *) + %hx : Float(*, *) + %cx : Float(*, *) %w_ih : Float(*, *) %w_hh : Float(*, *) - %alpha.1 : Float(*) - %beta_i.1 : Float(*) - %beta_h.1 : Float(*) + %alpha : Float(*) + %beta_i : Float(*) + %beta_h : Float(*) %bias : Float(*)) { - %9 : Float(*, *) = aten::t(%w_ih) - %Wx.1 : Float(*, *) = aten::mm(%x.1, %9) - %11 : Float(*, *) = aten::t(%w_hh) - %Uz.1 : Float(*, *) = aten::mm(%hx.1, %11) - %13 : Float(*, *), %14 : Float(*, *) = prim::FusionGroup_0[device=0](%beta_h.1, %Uz.1, %beta_i.1, %Wx.1, %alpha.1) - %15 : Dynamic[] = prim::ListConstruct(%13, %bias) - %16 : Dynamic[] = aten::broadcast_tensors(%15) - %17 : Dynamic, %18 : Dynamic = prim::ListUnpack(%16) - %hy : Float(*, *), %20 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : 
Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1[device=0](%cx.1, %17, %18) - return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %20); + %9 : Float(*, *), %10 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %x, %w_hh, %hx, %alpha, %beta_i, %beta_h, %bias, %cx) + return (%10, %9); } -with prim::FusionGroup_0 = graph(%4 : Float(*) - %5 : Float(*, *) - %11 : Float(*) - %12 : Float(*, *) - %16 : Float(*)) { - %17 : Float(*, *) = aten::mul(%16, %12) - %15 : Float(*, *) = aten::mul(%17, %5) - %13 : Float(*, *) = aten::mul(%11, %12) - %9 : int = prim::Constant[value=1]() - %10 : Float(*, *) = aten::add(%15, %13, %9) - %6 : Float(*, *) = aten::mul(%4, %5) - %2 : int = prim::Constant[value=1]() - %3 : Float(*, *) = aten::add(%10, %6, %2) - return (%3, %17); +with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *) + %1 : Float(*, *) + %2 : Float(*, *) + %3 : Float(*, *) + %4 : Float(*) + %5 : Float(*) + %6 : Float(*) + %7 : Float(*) + %8 : Float(*, *)) { + %9 : Float(*, *) = aten::t(%0) + %Wx.1 : Float(*, *) = aten::mm(%1, %9) + %11 : Float(*, *) = aten::t(%2) + %Uz.1 : Float(*, *) = aten::mm(%3, %11) + %13 : int = prim::Constant[value=1]() + %14 : Float(*, *), %15 : Float(*, *) = prim::FusionGroup_0[device=0](%6, %Uz.1, %5, %Wx.1, %4) + %16 : Dynamic[] = prim::ListConstruct(%14, %7) + %17 : Dynamic[] = aten::broadcast_tensors(%16) + %18 : Dynamic, %19 : Dynamic = prim::ListUnpack(%17) + %hy : Float(*, *), %21 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1[device=0](%8, %18, %19) + return (%cy, %hy, %9, %Wx.1, %11, %Uz.1, %15, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %21); } -with prim::FusionGroup_1 = graph(%13 : Float(*, *) - %39 : Dynamic - %44 : Dynamic) { - %45 : Dynamic, %46 : Dynamic, %47 : Dynamic, %48 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%44) - %40 : Dynamic, %41 : Dynamic, %42 : Dynamic, %43 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%39) - %37 : int = prim::Constant[value=1]() - %38 : Float(*, *) = aten::add(%40, %45, %37) - %33 : int = prim::Constant[value=1]() - %34 : Float(*, *) = aten::add(%41, %46, %33) - %29 : int = prim::Constant[value=1]() - %30 : Float(*, *) = aten::add(%42, %47, %29) +with prim::FusionGroup_0 = graph(%0 : Float(*) + %1 : Float(*, *) + %2 : Float(*) + %3 : Float(*, *) + %4 : Float(*)) { + %5 : Float(*, *) = aten::mul(%4, %3) + %6 : Float(*, *) = aten::mul(%5, %1) + %7 : Float(*, *) = aten::mul(%2, %3) + %8 : int = prim::Constant[value=1]() + %9 : Float(*, *) = aten::add(%6, %7, %8) + %10 : Float(*, *) = aten::mul(%0, %1) + %11 : int = prim::Constant[value=1]() + %12 : Float(*, *) = aten::add(%9, %10, %11) + return (%12, %5); +} +with prim::FusionGroup_1 = graph(%0 : Float(*, *) + %1 : Dynamic + %2 : Dynamic) { + %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2) + %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1) + %11 : int = prim::Constant[value=1]() + %12 : Float(*, *) = aten::add(%7, %3, %11) + %13 : int = prim::Constant[value=1]() + %14 : Float(*, *) = aten::add(%8, %4, %13) + %15 : int = prim::Constant[value=1]() + %16 : Float(*, *) = aten::add(%9, %5, %15) + %17 : int = prim::Constant[value=1]() + %18 : Float(*, *) = aten::add(%10, %6, %17) + %ingate.1 : Float(*, *) = aten::sigmoid(%12) + %forgetgate.1 : 
Float(*, *) = aten::sigmoid(%14) + %cellgate.1 : Float(*, *) = aten::tanh(%16) + %outgate.1 : Float(*, *) = aten::sigmoid(%18) + %23 : Float(*, *) = aten::mul(%forgetgate.1, %0) + %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) %25 : int = prim::Constant[value=1]() - %26 : Float(*, *) = aten::add(%43, %48, %25) - %ingate.1 : Float(*, *) = aten::sigmoid(%38) - %forgetgate.1 : Float(*, *) = aten::sigmoid(%34) - %cellgate.1 : Float(*, *) = aten::tanh(%30) - %outgate.1 : Float(*, *) = aten::sigmoid(%26) - %14 : Float(*, *) = aten::mul(%forgetgate.1, %13) - %11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) - %7 : int = prim::Constant[value=1]() - %cy : Float(*, *) = aten::add(%14, %11, %7) - %4 : Float(*, *) = aten::tanh(%cy) - %hy : Float(*, *) = aten::mul(%outgate.1, %4) - return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); + %cy : Float(*, *) = aten::add(%23, %24, %25) + %27 : Float(*, *) = aten::tanh(%cy) + %hy : Float(*, *) = aten::mul(%outgate.1, %27) + return (%hy, %27, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1); } diff --git a/test/expect/TestScript.test_tensor_scalar_fusion_cuda-1.expect b/test/expect/TestScript.test_tensor_scalar_fusion_cuda-1.expect index 8e1bfb179d1..0bbd1a7025a 100644 --- a/test/expect/TestScript.test_tensor_scalar_fusion_cuda-1.expect +++ b/test/expect/TestScript.test_tensor_scalar_fusion_cuda-1.expect @@ -4,8 +4,8 @@ graph(%x : Float(*, *)) { } with prim::FusionGroup_0 = graph(%0 : Float(*, *)) { %z : float = prim::Constant[value=3]() - %4 : int = prim::Constant[value=1]() - %y : Float(*, *) = aten::add(%0, %z, %4) - %2 : Float(*, *) = aten::mul(%0, %y) - return (%2); + %2 : int = prim::Constant[value=1]() + %y : Float(*, *) = aten::add(%0, %z, %2) + %4 : Float(*, *) = aten::mul(%0, %y) + return (%4); } diff --git a/test/test_jit.py b/test/test_jit.py index 899bdcb973d..e8b85a0b6a2 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -152,12 +152,19 @@ def get_fn(file_name, script_path): def get_execution_plan(graph_executor_state): - execution_plans = graph_executor_state.execution_plans.values() + execution_plans = list(graph_executor_state.execution_plans.values()) num_plans = len(execution_plans) if num_plans != 1: raise RuntimeError('This test assumes this GraphExecutor should ' 'only have one execution plan, got: {}'.format(num_plans)) - return list(execution_plans)[0] + return execution_plans[0] + + +def get_grad_executor(plan_state): + if len(list(plan_state.graph.nodes())) != 1: + raise RuntimeError("Can't get a grad_executor for a non-differentiable graph") + grad_executors = list(plan_state.code.grad_executors()) + return grad_executors[0] def backward_graph(script_module): @@ -165,10 +172,8 @@ def backward_graph(script_module): raise RuntimeError('Expected ScriptModule') ge_state = script_module.get_debug_state() fwd_plan = get_execution_plan(ge_state) - if fwd_plan.grad_executor is None: - raise RuntimeError('Error: tried to get grad_executor of function ' - 'that hasn\'t run backward yet.') - bwd_plan = get_execution_plan(fwd_plan.grad_executor) + grad_executor = get_grad_executor(fwd_plan) + bwd_plan = get_execution_plan(grad_executor.get_debug_state()) # Running JIT passes requires that we own the graph (with a shared_ptr). # The debug state struct does not own its graph so we make a copy of it. 
return bwd_plan.graph.copy() @@ -2807,7 +2812,6 @@ a") @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @skipIfRocm - @unittest.skip("Temporarily broken") def test_lstm_fusion_cuda(self): inputs = get_lstm_inputs('cuda', training=True) module = self.checkScript(LSTMCellS, inputs) @@ -2820,7 +2824,6 @@ a") @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @skipIfRocm - @unittest.skip("Temporarily broken") def test_milstm_fusion_cuda(self): inputs = get_milstm_inputs('cuda', training=True) module = self.checkScript(MiLSTMCell, inputs) diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index e421441f872..7dfa951439c 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -15,7 +15,7 @@ struct InputMetadata { InputMetadata(const at::Type& type, at::IntList shape, const int64_t device) : type_{&type} , shape_{shape}, device_{device} { } - InputMetadata(const at::Tensor& t) + InputMetadata(const at::Tensor& t) : InputMetadata(t.type(), t.sizes(), t.is_cuda() ? t.get_device() : - 1) { } bool is_valid() const { @@ -35,6 +35,10 @@ struct InputMetadata { return device_; } + at::Tensor zeros_like() const { + return at::zeros(shape_, at::TensorOptions(*type_, static_cast(device_))); + } + private: const at::Type* type_ = nullptr; at::DimVector shape_; diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h index e1d168901e5..5f02c8ae4f0 100644 --- a/torch/csrc/jit/argument_spec.h +++ b/torch/csrc/jit/argument_spec.h @@ -24,6 +24,7 @@ struct ArgumentInfo { int device() const { return device_; } + // XXX: It is guaranteed that this will return false when called on non-tensor arguments bool requires_grad() const { return requires_grad_; } @@ -42,7 +43,8 @@ struct ArgumentInfo { private: unsigned is_tensor_ : 1; unsigned defined_ : 1; - unsigned requires_grad_ : 6; + unsigned requires_grad_ : 1; + unsigned : 5; unsigned dim_ : 8; int device_ : 8; // NOTE: this needs to be signed because we use -1 to represent CPU unsigned type_ : 8; @@ -59,6 +61,9 @@ struct ArgumentSpec { int32_t num_inputs = inputs.size(); for (int32_t i = 0; i < num_inputs; ++i) { auto & arg = args[i]; + // Initialize all fields to 0. This is convenient, because e.g. + // requires_grad() can be checked even on tensors. 
+ std::memset(&arg, 0, sizeof(ArgumentInfo)); arg.is_tensor_ = static_cast(inputs[i].isTensor()); if (arg.is_tensor_) { at::Tensor t = inputs[i].toTensor(); diff --git a/torch/csrc/jit/attributes.h b/torch/csrc/jit/attributes.h index 53b87af9ef9..7199e0ae578 100644 --- a/torch/csrc/jit/attributes.h +++ b/torch/csrc/jit/attributes.h @@ -176,6 +176,12 @@ struct Attributes { #undef CREATE_ACCESSOR + // Our Graphs are not very const-correct, so we need to allow returning + // non-const references too + GraphAttr::ValueType& g(Symbol name) { + return get(name); + } + // does not use CREATE_ACCESSOR because we need additional asserts Derived* t_(Symbol name, TensorAttr::ConstructorType v) { JIT_ASSERT(!v.defined() || !v.is_variable()); diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index 09b1ff8548a..3d53ad3967c 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -35,6 +35,7 @@ bool isDifferentiable(Node * n) { "aten::neg(Tensor self) -> Tensor", "aten::type_as(Tensor self, Tensor other) -> Tensor", "aten::unsqueeze(Tensor self, int dim) -> Tensor", + "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor", "aten::mm(Tensor self, Tensor mat2) -> Tensor", "aten::lt(Tensor self, Tensor other) -> Tensor", "aten::le(Tensor self, Tensor other) -> Tensor", @@ -97,6 +98,9 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { return {grads.at(0), grads.at(0) * node->namedInput(attr::alpha), nullptr}; + } else if (node->kind() == prim::AutogradAdd) { + return {grads.at(0), grads.at(0)}; + } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr}; @@ -136,29 +140,14 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) { return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr}; + } else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) { + return {grads.at(0) * node->namedInput(attr::beta), + grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha), + inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha), + nullptr, nullptr}; + } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { - SymbolicVariable dmat1, dmat2; - if (auto type = inputs.at(0).value()->type()->cast()) { - auto sizes = type->sizes(), strides = type->strides(); - if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { - dmat1 = inputs.at(1).mm(grads.at(0).t()).t(); - } else { - dmat1 = grads.at(0).mm(inputs.at(1).t()); - } - } else { - dmat1 = grads.at(0).mm(inputs.at(1).t()); - } - if (auto type = inputs.at(1).value()->type()->cast()) { - auto sizes = type->sizes(), strides = type->strides(); - if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) { - dmat2 = grads.at(0).t().mm(inputs.at(0)).t(); - } else { - dmat2 = inputs.at(0).t().mm(grads.at(0)); - } - } else { - dmat2 = inputs.at(0).t().mm(grads.at(0)); - } - return {dmat1, dmat2}; + return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))}; } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) { const auto& input_sizes = inputs.at(0).sizes(); @@ -364,7 +353,9 @@ static ReverseDetails addReverseInline(Gradient& grad_desc, 
JIT_ASSERT(grad_inputs.size() == node->inputs().size()); for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) { if (!requires_grad(inputs[i])) continue; - JIT_ASSERT(grad_inputs[i]); + // NB: Not returning a gradient w.r.t. a value that requires grad is normal if the + // input is non-differentiable. This happens e.g. in the aten::type_as case. + if (!grad_inputs[i]) continue; set_grad(inputs[i], grad_inputs[i]); } } @@ -374,6 +365,11 @@ static ReverseDetails addReverseInline(Gradient& grad_desc, Value * input = inputs[i]; if (!requires_grad(input)) continue; + // NB: Not having a gradient defined w.r.t. an input to the graph which requires grad + // can happen and is not an error. It might have been used only in non-differentiable + // contexts (e.g. as second input to aten::type_as). In that case we simply ignore it + // as an output, because it won't ever produce any meaningful values. + if (grad_map.count(input) == 0) continue; reverse_block->registerOutput(get_grad(input)); grad_desc.df_output_vjps.push_back(i); } @@ -557,7 +553,7 @@ Gradient differentiate(std::shared_ptr& graph, const std::vector& r // addReverseInline has to call gradientForNode if *any* of the outputs // require grad, but it will emit vjps for *all* outputs. Use DCE to remove // unnecessary nodes. - EliminateDeadCode(grad_desc.f); + EliminateDeadCode(rev_info.reverse_block); // Fills in f, df, f_real_outputs, df_input_captures, // modifies df_input_vjps (new vjps are added for temporaries) lambdaLiftReverse(grad_desc, rev_info); diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index a448733fdca..6c603789a26 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -24,6 +24,7 @@ #include "torch/csrc/jit/passes/constant_propagation.h" #include "torch/csrc/jit/symbolic_variable.h" #include "torch/csrc/jit/ivalue.h" +#include "torch/csrc/jit/custom_operator.h" #include "torch/csrc/autograd/edge.h" #include "torch/csrc/autograd/function.h" @@ -45,14 +46,34 @@ using tensor_list = std::vector; using Variable = autograd::Variable; using autograd::variable_list; -// this type is in ExecutionPlan to run its Gradient if it is -// specified. It has a list of inputs captured by ExecutionPlan that -// it concats with inputs to form the full set of inputs to graph. -// see struct Gradient for a description of how the derivative graph -// is constructed and what variables are captured. 
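The new symbolic derivative for aten::addmm in the autodiff.cpp hunk above encodes the closed forms d_self = beta * grad, d_mat1 = alpha * grad @ mat2^T and d_mat2 = alpha * mat1^T @ grad, while aten::mm keeps just the two matrix terms with alpha = 1. A standalone sanity check against eager-mode autograd (illustrative only, not part of the patch):

# Sanity check (not part of the patch) for the closed-form addmm/mm
# derivatives emitted by gradientForNode above.
import torch

beta, alpha = 0.5, 2.0
self_ = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
mat1 = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
mat2 = torch.randn(4, 5, dtype=torch.double, requires_grad=True)

out = torch.addmm(self_, mat1, mat2, beta=beta, alpha=alpha)
grad = torch.randn_like(out)
d_self, d_mat1, d_mat2 = torch.autograd.grad(out, (self_, mat1, mat2), grad)

# out = beta * self + alpha * mat1 @ mat2
assert torch.allclose(d_self, beta * grad)
assert torch.allclose(d_mat1, alpha * grad.mm(mat2.t()))
assert torch.allclose(d_mat2, alpha * mat1.t().mm(grad))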
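The two NB comments added to addReverseInline above cover inputs that require grad but are only used non-differentiably. A quick eager-mode illustration of the aten::type_as case they mention (not part of the patch):

import torch

x = torch.randn(3, dtype=torch.double, requires_grad=True)
y = torch.randn(3, requires_grad=True)          # only consulted for its dtype
out = x.type_as(y).sum()

gx, gy = torch.autograd.grad(out, (x, y), allow_unused=True)
print(gx)   # ones(3): the cast is differentiable w.r.t. x
print(gy)   # None: no gradient flows to y, exactly the case the comments allow for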
-struct ExecutionPlanAutogradFunction : public autograd::Function { - ExecutionPlanAutogradFunction(GraphExecutor graph, size_t capture_size) - : graph(std::move(graph)) { +struct ExecutionPlan { + ExecutionPlan() = default; + ExecutionPlan(std::shared_ptr graph) + : code(graph) + , graph(std::move(graph)) {} + + void run(Stack& stack) const { + return InterpreterState(code).runOneStage(stack); + } + + operator bool() const { + return static_cast(graph); + } + + ExecutionPlanState getDebugState() { + ExecutionPlanState state; + state.code = &code; + state.graph = graph.get(); + return state; + } + + Code code; + std::shared_ptr graph; +}; + +struct DifferentiableGraphBackward : public autograd::Function { + DifferentiableGraphBackward(GraphExecutor executor, size_t capture_size) + : executor(std::move(executor)) { is_var_capture.reserve(capture_size); var_captures.reserve(capture_size); ivalue_captures.reserve(capture_size); @@ -74,10 +95,28 @@ struct ExecutionPlanAutogradFunction : public autograd::Function { ++ivalue_capture_it; } } - graph.run(stack); - return fmap(stack, [](IValue & val) { - return autograd::Variable(std::move(val).toTensor()); - }); + + executor.run(stack); + JIT_ASSERT(stack.size() == num_outputs()); + + variable_list outputs; + outputs.reserve(num_outputs()); + for (size_t i = 0; i < num_outputs(); ++i) { + if (should_compute_output(i)) { + auto output = std::move(stack[i]).toTensor(); + const auto & edge = next_edge(i); + if (output.defined()) { + outputs.push_back(std::move(output)); + } else if (edge.is_valid()) { + outputs.push_back(edge.function->input_metadata(edge.input_nr).zeros_like()); + } else { + outputs.emplace_back(); + } + } else { + outputs.emplace_back(); + } + } + return outputs; } void capture(const IValue & val, bool is_output) { @@ -91,7 +130,7 @@ struct ExecutionPlanAutogradFunction : public autograd::Function { } private: friend struct ExecutionPlan; - GraphExecutor graph; + GraphExecutor executor; // INVARIANT: is_var_capture.size() == var_captures.size() + ivalue_captures.size() std::vector is_var_capture; @@ -104,74 +143,17 @@ private: // This will unwrap Variables, run the plan, and re-wrap them. // It can optionally also have a gradient which is hooked up // to the output Variables if present. -struct ExecutionPlan { - ExecutionPlan(std::shared_ptr& graph) - : f(graph), - graph(graph), - num_inputs(graph->inputs().size()), - num_outputs(graph->outputs().size()) {} - ExecutionPlan(std::shared_ptr& graph, Gradient grad) - : f(graph), - graph(graph), +struct DifferentiableGraphOp { + DifferentiableGraphOp(Gradient grad) + : f(grad.f), grad(std::move(grad)), grad_executor(this->grad.df), - num_inputs(graph->inputs().size()), - num_outputs(graph->outputs().size()) {} - - void run(Stack & stack) const { - if (grad) { - return runWithGrad(stack); - } - InterpreterState(f).runOneStage(stack); - } - - std::shared_ptr get_graph() const { - return graph; - } - - ExecutionPlanState getDebugState() { - ExecutionPlanState state; - state.f = &f; - state.graph = graph.get(); - if (grad) { - state.grad = &grad; - state.grad_executor = std::unique_ptr( - new GraphExecutorState(grad_executor.getDebugState())); - } else { - state.grad = nullptr; - state.grad_executor.reset(); - } - return state; - } - -private: - void detachVariables(Stack & stack) const { - // It would be nice to use an ArrayRef here, but unfortunately those can only - // return const references, so we need to do a bunch of indexing ourselves. 
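DifferentiableGraphBackward::apply above never hands autograd an undefined gradient on a live edge; it materializes zeros from the consumer's InputMetadata via the zeros_like helper added in input_metadata.h. A hypothetical Python rendering of that rule, with made-up names and only the defined / zeros / undefined cases (not the actual C++):

from typing import Optional, Tuple
import torch

def materialize_output(grad: Optional[torch.Tensor],
                       edge_is_valid: bool,
                       metadata: Tuple[torch.Size, torch.dtype, torch.device]) -> Optional[torch.Tensor]:
    # Case 1: the backward graph produced a defined gradient - pass it through.
    if grad is not None:
        return grad
    # Case 2: no gradient, but a consumer edge exists - hand it explicit zeros
    # shaped like the consumer's recorded input metadata.
    if edge_is_valid:
        shape, dtype, device = metadata
        return torch.zeros(shape, dtype=dtype, device=device)
    # Case 3: no consumer - keep the gradient undefined.
    return None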
- const int64_t stack_size = stack.size(); - const int64_t stack_offset = stack_size - num_inputs; - for (int64_t i = stack_offset; i < stack_size; ++i) { - auto & v = stack[i]; - if (!v.isTensor()) continue; - auto t = std::move(v).toTensor(); - v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() : std::move(t)}; - } - } - // Capture (save) inputs that would be required to subsequently run backwards - void captureInputs(ExecutionPlanAutogradFunction & grad_fn, at::ArrayRef inputs) const { - for (size_t offset : grad.df_input_captured_inputs) { - grad_fn.capture(inputs[offset], /*is_output*/false); - } - } - void captureOutputs(ExecutionPlanAutogradFunction & grad_fn, at::ArrayRef outputs) const { - for (size_t offset : grad.df_input_captured_outputs) { - grad_fn.capture(outputs[offset], /*is_output*/true); - } - } + num_inputs(this->grad.f->inputs().size()), + num_outputs(this->grad.f->outputs().size()) {} // XXX: keep in mind that stack can be larger than the inputs we need! - void runWithGrad(Stack & stack) const { - auto grad_fn = std::make_shared(grad_executor, + int operator()(Stack & stack) const { + auto grad_fn = std::make_shared(grad_executor, grad.df_input_captured_inputs.size() + grad.df_input_captured_outputs.size()); { @@ -179,7 +161,7 @@ private: // hook up the outputs of df to the gradient functions of the inputs that require gradients for(auto idx : grad.df_output_vjps) { auto v = Variable(inputs[idx].toTensor()); - grad_fn->add_next_edge(v.gradient_edge()); + grad_fn->add_next_edge(v.defined() ? v.gradient_edge() : autograd::Edge{}); } captureInputs(*grad_fn, inputs); } @@ -201,8 +183,15 @@ private: // reallocate variables that were already created in wrapTensors. We // should add an API for this. Variable output = outputs[idx].toTensor(); - autograd::create_gradient_edge(output, grad_fn); - output.set_requires_grad(true); + // NB: since our requires_grad setting is only a heuristic we might end up + // wanting to differentiate through integral tensors, which is generally a + // hard error in autograd. + if (at::isFloatingType(output.type().scalarType())) { + autograd::create_gradient_edge(output, grad_fn); + output.set_requires_grad(true); + } else { + grad_fn->add_input_metadata(autograd::Function::undefined_input{}); + } } captureOutputs(*grad_fn, outputs); // drop the temporary outputs so that we return the same number of @@ -210,22 +199,89 @@ private: const size_t num_temporary_outputs = num_outputs - grad.f_real_outputs; stack.erase(stack.end() - num_temporary_outputs, stack.end()); } + return 0; + } + +private: + friend GraphExecutor* detail::getGradExecutor(Operation& op); + + void detachVariables(Stack & stack) const { + // It would be nice to use an ArrayRef here, but unfortunately those can only + // return const references, so we need to do a bunch of indexing ourselves. + const int64_t stack_size = stack.size(); + const int64_t stack_offset = stack_size - num_inputs; + for (int64_t i = stack_offset; i < stack_size; ++i) { + auto & v = stack[i]; + if (!v.isTensor()) continue; + auto t = std::move(v).toTensor(); + v = IValue{t.defined() ? 
autograd::as_variable_ref(t).detach() : std::move(t)}; + } + } + // Capture (save) inputs that would be required to subsequently run backwards + void captureInputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef inputs) const { + for (size_t offset : grad.df_input_captured_inputs) { + grad_fn.capture(inputs[offset], /*is_output*/false); + } + } + void captureOutputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef outputs) const { + for (size_t offset : grad.df_input_captured_outputs) { + grad_fn.capture(outputs[offset], /*is_output*/true); + } } Code f; - // optimized graph for debugging and testing - std::shared_ptr graph; - // description of gradient as a graph - Gradient grad; // if(grad) is false when this is unused - // executor for df, including code caches + Gradient grad; GraphExecutor grad_executor; const size_t num_inputs; const size_t num_outputs; }; +void packGradient(Gradient gradient, Node *dnode) { + JIT_ASSERT(dnode->kind() == prim::DifferentiableGraph); + dnode->g_(attr::Subgraph, gradient.f) + ->g_(attr::ReverseSubgraph, gradient.df) + ->i_(attr::f_real_outputs, gradient.f_real_outputs) + ->is_(attr::df_input_vjps, fmap(gradient.df_input_vjps)) + ->is_(attr::df_input_captured_inputs, fmap(gradient.df_input_captured_inputs)) + ->is_(attr::df_input_captured_outputs, fmap(gradient.df_input_captured_outputs)) + ->is_(attr::df_output_vjps, fmap(gradient.df_output_vjps)); +} + +Gradient getGradient(Node *n) { + JIT_ASSERT(n->kind() == prim::DifferentiableGraph); + Gradient grad; + grad.f = n->g(attr::Subgraph); + grad.df = n->g(attr::ReverseSubgraph); + grad.f_real_outputs = n->i(attr::f_real_outputs); + grad.df_input_vjps = fmap(n->is(attr::df_input_vjps)); + grad.df_input_captured_inputs = fmap(n->is(attr::df_input_captured_inputs)); + grad.df_input_captured_outputs = fmap(n->is(attr::df_input_captured_outputs)); + grad.df_output_vjps = fmap(n->is(attr::df_output_vjps)); + return grad; +} + } // anonymous namespace +RegisterOperators reg_graph_executor_ops({ + Operator( + prim::DifferentiableGraph, + [](Node *n) -> Operation { + return DifferentiableGraphOp(getGradient(n)); + }) +}); + +namespace detail { + +GraphExecutor* getGradExecutor(Operation& op) { + if (auto diff_op = op.target()) { + return &diff_op->grad_executor; + } + return nullptr; +} + +} // namespace detail + // a Graph can be created via tracing, or via a language-based frontend // GraphExecutor runs it. It can run the same graph on many different sizes // and different requires_grad states, and handles specializations for each situation. @@ -233,68 +289,49 @@ private: // tracing concerns separated. 
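The RegisterOperators entry above decodes the packed Gradient attributes once, when the operator is constructed for a node, and the returned Operation only captures the already-decoded state. A minimal stand-in sketch of that factory pattern (FakeNode, makeDifferentiableOp and the string attribute key are illustrative, not part of this patch):

    #include <cassert>
    #include <functional>
    #include <map>
    #include <string>
    #include <vector>

    // Minimal stand-ins: a "node" is just an attribute bag, an Operation is
    // a callable over a stack of ints.
    using Stack = std::vector<int>;
    using Operation = std::function<int(Stack&)>;

    struct FakeNode {
      std::string kind;
      std::map<std::string, std::vector<int>> int_list_attrs;
    };

    // The factory runs once per compiled node; the attribute decoding happens
    // here, and the returned callable only captures the already-decoded state
    // (mirroring the getGradient + DifferentiableGraphOp split above).
    Operation makeDifferentiableOp(const FakeNode& n) {
      std::vector<int> captured = n.int_list_attrs.at("df_input_captured_inputs");
      return [captured](Stack& stack) -> int {
        for (int offset : captured)
          stack.push_back(offset);   // pretend to "capture" inputs by offset
        return 0;                    // Operations conventionally return 0
      };
    }

    int main() {
      FakeNode n{"DifferentiableGraph", {{"df_input_captured_inputs", {0, 2}}}};
      Operation op = makeDifferentiableOp(n);
      Stack s{10, 20, 30};
      op(s);
      assert((s == Stack{10, 20, 30, 0, 2}));
    }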
struct GraphExecutorImpl { - GraphExecutorImpl(std::shared_ptr graph, bool optimize, bool symbolically_differentiable) - : graph(std::move(graph)) - , optimize(optimize) - , num_inputs(this->graph->inputs().size()) - , num_outputs(this->graph->outputs().size()) - , symbolically_differentiable(symbolically_differentiable) - , may_introduce_gradient(calcMayIntroduceGradient(this->graph->block())) {} + static std::shared_ptr prepareGraph(std::shared_ptr graph) { + auto copy = graph->copy(); + EraseShapeInformation(*copy); + return copy; + } + GraphExecutorImpl(std::shared_ptr graph, bool optimize) - : GraphExecutorImpl(graph, optimize, isDifferentiable(*graph)) {} + : graph(prepareGraph(graph)) + , optimize(optimize) + , num_inputs(this->graph->inputs().size()) + , num_outputs(this->graph->outputs().size()) {} // entry point where execution begins void run(Stack & stack) { - if(stack.size() < num_inputs) { - std::stringstream ss; - ss << "expected " << num_inputs << " inputs but got " << stack.size() << " inputs"; - throw std::runtime_error(ss.str()); - } - auto inputs = last(stack, num_inputs); + AT_CHECK(stack.size() >= num_inputs, "expected ", num_inputs, " inputs, but got only ", stack.size()); - // the tracer has called a graph executor - // there is no need to optimize, but we do need to splice the graph of - // this excutor into the trace. Otherwise we might unroll control-flow - // operations. if(tracer::isTracing()) { return runTraced(stack); } - // this is the fallback pathway, when we cannot differentiate - if(!optimize || (!symbolically_differentiable && needsGradient(inputs))) { - return runFallback(stack); - } - - // either we can symbolically differentiate, or we do not need a gradient. - // go down the route where we treat the inputs as tensors - // and fully optimize - auto & implementation = getOrCompile(inputs); - return implementation.run(stack); + auto & execution_plan = optimize ? 
getOrCompile(stack) : getOrCompileFallback(); + return execution_plan.run(stack); } std::shared_ptr graphFor(const Stack& stack) const { auto inputs = last(stack, num_inputs); ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs); - if (!optimize || (!symbolically_differentiable && needsGradient(inputs))) { - JIT_ASSERTM(autograd_fallback_graph, "No graph found for given inputs"); - return autograd_fallback_graph; + if (!optimize) { + AT_CHECK(fallback, "No graph found for given inputs"); + return fallback.graph; } auto it = plan_cache.find(spec); - JIT_ASSERTM(it != plan_cache.end(), "No graph found for given inputs"); - return it->second.get_graph(); + AT_CHECK(it != plan_cache.end(), "No graph found for given inputs"); + return it->second.graph; } GraphExecutorState getDebugState() { GraphExecutorState state; state.graph = graph.get(); - if (autograd_fallback) { - state.autograd_fallback = &autograd_fallback; - state.autograd_fallback_graph = autograd_fallback_graph.get(); - } else { - state.autograd_fallback = nullptr; - state.autograd_fallback_graph = nullptr; + if (fallback) { + state.fallback = fallback.getDebugState(); } for (auto & entry : plan_cache) { state.execution_plans.emplace(entry.first, entry.second.getDebugState()); @@ -305,6 +342,121 @@ struct GraphExecutorImpl { private: friend struct GraphExecutor; + const ExecutionPlan & getOrCompileFallback() { + std::lock_guard lock(compile_mutex); + if(!fallback) { + auto graph_ = graph->copy(); + runRequiredPasses(graph_); + fallback = ExecutionPlan(graph_); + } + return fallback; + } + + const ExecutionPlan & getOrCompile(const Stack& stack) { + // outside lock guard, to minimize the time holding the lock on the fast path + // ArgumentSpec even computes its hashCode here. + ArgumentSpec spec(autograd::GradMode::is_enabled(), last(stack, num_inputs)); + { + std::lock_guard lock(compile_mutex); + auto it = plan_cache.find(spec); + if (it != plan_cache.end()) + return it->second; + auto plan = compileSpec(spec); + auto r = plan_cache.emplace(std::move(spec), std::move(plan)); + return r.first->second; + } + } + + ExecutionPlan compileSpec(const ArgumentSpec & spec) { + auto opt_graph = graph->copy(); + + // Phase 1. Specialize to input definedness (this is very important for + // gradient graphs), and run required passes to bring the graph + // to an executable form. + specializeGrad(opt_graph, spec); + runRequiredPasses(opt_graph); + + // Phase 2. Propagate detailed information about the spec through the + // graph (enabled more specializations in later passes). + PropagateInputShapes(*opt_graph, spec); + + // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites that + // we can still execute using autograd). + runOptimization(opt_graph, spec); + + // Phase 4. If this graph will be differentiated, we need to slice out the + // symbolically differentiable subgraphs for further optimizations. + // Phase 5. Apply non-differentiable optimizations to the graphs we've found + // (or the whole grpah if we know we won't need its derivative). + if (needsGradient(opt_graph, spec)) { + auto diff_nodes = CreateAutodiffSubgraphs(*opt_graph); + for (Node * dnode : diff_nodes) { + // XXX: we don't have requires_grad information on the intermediate values, + // so we conservatively assume it's always true (on tensor inputs). + auto diff_graph = std::move(dnode->g(attr::Subgraph)); + auto requires_grads = fmap(diff_graph->inputs(), [](Value* v) { + // NB: only floating-point inputs can have requires_grad=True. 
If we + // don't have type information, we have to assume that it's true. + if (auto tensor_type = v->type()->cast()) { + return at::isFloatingType(tensor_type->scalarType()); + } + return v->type()->isSubtypeOf(DynamicType::get()); + }); + Gradient gradient = differentiate(diff_graph, requires_grads); + runNondiffOptimization(gradient.f); + packGradient(gradient, dnode); + } + } else { + runNondiffOptimization(opt_graph); + } + return ExecutionPlan(opt_graph); + } + + void specializeGrad(std::shared_ptr& graph, const ArgumentSpec& spec) { + std::vector defined; + for (size_t i = 0; i < spec.size(); ++i) + defined.push_back(spec.at(i).defined()); + specializeUndef(*graph, defined); + } + + void runOptimization(std::shared_ptr& graph, const ArgumentSpec& spec) { + EliminateDeadCode(graph); + EliminateCommonSubexpression(graph); + UnrollLoops(graph); + ConstantPropagation(graph); + PeepholeOptimize(graph); + CheckInplace(graph); + BatchMM(graph); + } + + void runNondiffOptimization(std::shared_ptr& graph) { + FuseGraph(graph); + } + + static bool needsGradient(const std::shared_ptr& graph, const ArgumentSpec& spec) { + if (!autograd::GradMode::is_enabled()) + return false; + if (mayIntroduceGradient(graph->block())) + return true; + for(size_t i = 0; i < spec.size(); ++i) { + if(spec.at(i).requires_grad()) + return true; + } + return false; + } + + static bool mayIntroduceGradient(const Block* b) { + for (const Node* n : b->nodes()) { + if (n->kind() == prim::PythonOp) + return true; + for (const Block* bb : n->blocks()) { + if (mayIntroduceGradient(bb)) + return true; + } + } + return false; + } + void runTraced(Stack & stack) { auto state = tracer::getTracingState(); auto inputs = last(stack, num_inputs); @@ -313,25 +465,18 @@ private: }); ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs); - runFallback(stack); + // NB: we could just run the fallback in here and call it a day, but that would loose all + // the control flow information we have in the graph. Thus, we run the fallback to + // get the correct output values, but we will override the tracing states later. + getOrCompileFallback().run(stack); - auto all_dynamic = [](const at::ArrayRef xs) { - for(Value* x : xs) { - if(x->type()->kind() != TypeKind::DynamicType) - return false; - } - return true; - }; // Traces always have types propagated through them, so we make sure to // also propagate types through the graph we are inserting here. // However, this->graph itself may already have been generated with // tracing and so we only do the type propgation if no concrete types have // been set. 
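getOrCompile above keys compiled plans by ArgumentSpec and holds compile_mutex only for the lookup-or-compile step; the spec (and its hash) is computed before the lock is taken. A self-contained sketch of that caching pattern with stand-in types (FakeSpec, FakePlan and PlanCache are illustrative, not part of this patch):

    #include <cassert>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    // Stand-ins for ArgumentSpec / ExecutionPlan: the key must be hashable and
    // equality-comparable, the plan is whatever gets cached per key.
    using FakeSpec = std::string;
    struct FakePlan { int id; };

    struct PlanCache {
      // The key is built (and hashed) by the caller before this is entered,
      // so the lock only covers the lookup-or-compile step.
      const FakePlan& getOrCompile(const FakeSpec& spec) {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = plans_.find(spec);
        if (it != plans_.end())
          return it->second;
        FakePlan plan{next_id_++};            // stands in for compileSpec(spec)
        return plans_.emplace(spec, plan).first->second;
      }

      std::unordered_map<FakeSpec, FakePlan> plans_;
      std::mutex mutex_;
      int next_id_ = 0;
    };

    int main() {
      PlanCache cache;
      const FakePlan& a = cache.getOrCompile("float32[2,3] requires_grad");
      const FakePlan& b = cache.getOrCompile("float32[2,3] requires_grad");
      assert(&a == &b);                       // same spec reuses the cached plan
      assert(cache.getOrCompile("int64[4]").id != a.id);
    }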
- auto local_graph = this->graph; - if(all_dynamic(local_graph->inputs()) && all_dynamic(local_graph->outputs())) { - local_graph = this->graph->copy(); - PropagateInputShapes(*local_graph, spec); - } + auto local_graph = this->graph->copy(); + PropagateInputShapes(*local_graph, spec); auto output_values = script::inlineCallTo(*state->graph, *local_graph, input_values); auto outputs = last(stack, num_outputs); @@ -343,147 +488,32 @@ private: } } - void runFallback(Stack & stack) { - auto & fb = getOrCreateAutogradFallback(); - InterpreterState(fb).runOneStage(stack); - } - - static bool calcMayIntroduceGradient(Block* b) { - for(Node* n : b->nodes()) { - if(n->kind() == prim::PythonOp) - return true; - for(Block* bb : n->blocks()) { - if(calcMayIntroduceGradient(bb)) - return true; - } - } - return false; - } - bool needsGradient(at::ArrayRef inputs) const { - if (!autograd::GradMode::is_enabled()) { - return false; - } - if (may_introduce_gradient) - return true; - for (const IValue & value : inputs) { - if (!value.isTensor()) continue; - auto t = value.toTensor(); - if (t.defined() && autograd::as_variable_ref(t).requires_grad()) - return true; - } - return false; - } - - const Code & getOrCreateAutogradFallback() { - std::lock_guard lock(compile_mutex); - if(autograd_fallback) { - return autograd_fallback; - } - auto graph_ = graph->copy(); - runRequiredPasses(graph_); - if(optimize) { - runOptimization(graph_, /*graphMustSupportVariables=*/true); - if(!isDifferentiable(*graph_)) { - EraseShapeInformation(*graph_); - CreateAutodiffSubgraphs(*graph_); - } - } - autograd_fallback_graph = graph_; - autograd_fallback = Code(graph_); - return autograd_fallback; - } - const ExecutionPlan & getOrCompile(at::ArrayRef inputs) { - // outside lock guard, to minimize the time holding the lock on the fast path - // ArgumentSpec even computes its hashCode here. - ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs); - { - std::lock_guard lock(compile_mutex); - auto it = plan_cache.find(spec); - if(it != plan_cache.end()) - return it->second; - auto plan = compileSpec(spec); - auto r = plan_cache.emplace(std::move(spec), std::move(plan)); - return r.first->second; - } - } - - bool argumentSpecRequiresGradient(const ArgumentSpec & spec) { - for(size_t i = 0; i < spec.size(); ++i) { - if(spec.at(i).requires_grad()) - return true; - } - return false; - } - - ExecutionPlan compileSpec(const ArgumentSpec & spec) { - auto graph_ = graph->copy(); - - specializeToSpec(graph_, spec); - - if(!argumentSpecRequiresGradient(spec)) { - runOptimization(graph_, /*graphMustSupportVariables=*/false); - return ExecutionPlan(graph_); - } - JIT_ASSERT(symbolically_differentiable); - - std::vector requires_grads; - requires_grads.reserve(spec.size()); - for(size_t i = 0; i < spec.size(); i++) - requires_grads.push_back(spec.at(i).requires_grad()); - - Gradient gradient = differentiate(graph_, requires_grads); - graph_ = gradient.f; - runOptimization(graph_, /*graphMustSupportVariables=*/false); - return ExecutionPlan(graph_, std::move(gradient)); - } - // the unoptimized starting graph - // this is never mutated + // The unoptimized starting graph. This field is effectively const, but we can't make it so + // because Graph::copy() is not const (and making it const is not that easy at this point). std::shared_ptr graph; - // true - do everything we can to make this graph run fast - // false - do not modifiy the graph at all and just use the interpreter - // to run the graph. 
Useful for debugging correctness issues in the implementation + // If false, we'll run the graph as we get it, without any optimizations. Useful + // for debugging. const bool optimize; const size_t num_inputs; const size_t num_outputs; - // GraphExecutor optimizes more aggresively when we _know_ the graph will be - // symbolically differentiable. - bool symbolically_differentiable; + // Populated only when optimize is false (and in that case plan_cache will be unused). + // The compiled version of graph. + ExecutionPlan fallback; - // some ops, including python operations, can intorduce requires_grad=True - // variables even though no inputs to this graph are availiable, if - // the graph includes those operators then needGradient must be true - // regardles of input state. - bool may_introduce_gradient; - - // when this graph has some parts that are not symbolically_differentable, - // but some input does require a derivative, we create and use autograd_fallback, - // which wraps up the fully differentiable subgraphs, and then runs the outer - // graph through autograd. - // Since we can't optimize black box functions anyway, there is only one fallback path, - // and it must work on all sizes (so no optimizations that inspect sizes can run on it) - std::shared_ptr autograd_fallback_graph; - Code autograd_fallback; - - // optimizable code paths, used when we can differentiate or when no derivative is needed - // Spec describes input conditions, Plan describes how to execute them. + // Mapping from argument configurations to optimized versions of the graph that are + // specialized to the spec. std::unordered_map plan_cache; - // GraphExecutor can be accessed from multiple thread so - // anytime we are checking or updating the autograd_fallback or - // plan_cache, we must hold the compile mutex. - // along the fast path (no compilation) code should - // hold this for as little time as possible. + // GraphExecutors can be accessed from multiple threads, so this thread needs to be + // held every time we access the fallback or plan_cache. std::mutex compile_mutex; }; GraphExecutor::GraphExecutor(std::shared_ptr graph, bool optimize) : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {} -GraphExecutor::GraphExecutor(std::shared_ptr graph, bool optimize, bool symbolically_differentiable) -: pImpl(new GraphExecutorImpl(std::move(graph), optimize, symbolically_differentiable)) {} - void GraphExecutor::run(Stack & inputs) { return pImpl->run(inputs); } @@ -509,49 +539,7 @@ void runRequiredPasses(const std::shared_ptr& g) { // add valid expand nodes when the shapes are stable RemoveExpands(g); CanonicalizeOps(g); -} - -void specializeToSpec(const std::shared_ptr& graph, const ArgumentSpec& spec) { - // clean up GradOf and AutogradAdd nodes - // this must be first because later passes do not know what GradOfs are - std::vector defined; - for(size_t i = 0; i < spec.size(); ++i) { - defined.push_back(spec.at(i).defined()); - } - specializeUndef(*graph, defined); - - // required passes shared with autograd fallback - runRequiredPasses(graph); - - // clean up dead constants from specialization - EliminateDeadCode(graph); - // calculate all input shapes - PropagateInputShapes(*graph, spec); -} - -void runOptimization(std::shared_ptr & graph, bool graphMustSupportVariables) { - - // these optimizations must run in the presence of variables - // and when shape information is not statically known. 
- EliminateDeadCode(graph); - CheckInplace(graph); - EliminateCommonSubexpression(graph); - - if (!graphMustSupportVariables) { - // These optimizations can introduce operators like FusionGroup that - // do not work on variables - - // They also may assume that concrete sizes/strides are availiable - UnrollLoops(graph); - ConstantPropagation(graph); - //TODO: create peephole optimizations that are safe to run - // when we are using variables, and when we do not know sizes. - PeepholeOptimize(graph); - // TODO: remove mandatory size checking in BatchMM, otherwise - // it works fine on variables. - BatchMM(graph); - FuseGraph(graph); - } + EliminateDeadCode(g); } }} diff --git a/torch/csrc/jit/graph_executor.h b/torch/csrc/jit/graph_executor.h index 2693af50af1..7e644273a5b 100644 --- a/torch/csrc/jit/graph_executor.h +++ b/torch/csrc/jit/graph_executor.h @@ -15,29 +15,20 @@ struct GraphExecutorState; // They is only valid only right after you call getDebugState() and should never // be used again once another GraphExecutor function is called. struct ExecutionPlanState { - Code* f; - Graph* graph; - - // Those two fields are optional - Gradient* grad; - std::shared_ptr grad_executor; // shared_ptr to break the cycle... + Code* code = nullptr; + const Graph* graph = nullptr; }; struct GraphExecutorState { - Graph* graph; + const Graph* graph; + ExecutionPlanState fallback; // XXX: members of this field are optional std::unordered_map execution_plans; - - // Those two fields are optional - Code* autograd_fallback; - Graph* autograd_fallback_graph; }; struct GraphExecutorImpl; struct TORCH_API GraphExecutor { GraphExecutor() = default; GraphExecutor(std::shared_ptr graph, bool optimize = true); - // note: if not specified, symbolically_differentiable is computed from the graph. - GraphExecutor(std::shared_ptr graph, bool optimize, bool symbolically_differentiable); void run(Stack & inputs); explicit operator bool() const { return pImpl != nullptr; @@ -53,15 +44,11 @@ private: // regardless of whether sizes have been specialized or not. TORCH_API void runRequiredPasses(const std::shared_ptr& g); -// specialize 'graph' to the types, sizes, and other properties described in spec -// this prepares the graph for execution, including running runRequiredPasses, -// but the execution only remains valid for tensors whose properties match spec -// otherwise running the graph will have undefined results. -TORCH_API void specializeToSpec(const std::shared_ptr& graph, const ArgumentSpec& spec); +namespace detail { + +GraphExecutor* getGradExecutor(Operation& op); + +} // namespace detail -// apply standard optimizations. 
if graphMustSupportVariables=false then -// then the passes are allowed to modify the graph in ways that make it no longer -// work with tensors that have requires_grad=True -TORCH_API void runOptimization(std::shared_ptr & graph, bool graphMustSupportVariables); }} diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 54af121e339..e816961d86d 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -129,8 +129,8 @@ void initJITBindings(PyObject *module) { }); py::class_(m, "ArgumentSpec"); py::class_(m, "Code") - .def("executors", [](Code& c) { - return py::make_iterator(c.executors().begin(), c.executors().end()); + .def("grad_executors", [](Code& c) { + return py::make_iterator(c.grad_executors().begin(), c.grad_executors().end()); }); py::class_(m, "ExecutionPlanState") @@ -138,10 +138,7 @@ void initJITBindings(PyObject *module) { return s.graph; }) .def_property_readonly("code", [](ExecutionPlanState& s) { - return s.f; - }) - .def_property_readonly("grad_executor", [](ExecutionPlanState& s) { - return s.grad_executor.get(); + return s.code; }); py::class_(m, "Gradient") @@ -174,11 +171,8 @@ void initJITBindings(PyObject *module) { .def_property_readonly("execution_plans", [](GraphExecutorState& s) { return s.execution_plans; }) - .def_property_readonly("autograd_fallback", [](GraphExecutorState& s) { - return s.autograd_fallback; - }) - .def_property_readonly("autograd_fallback_graph", [](GraphExecutorState& s) { - return s.autograd_fallback_graph; + .def_property_readonly("fallback", [](GraphExecutorState& s) { + return s.fallback; }); py::class_(m, "GraphExecutor", py::dynamic_attr()) @@ -204,7 +198,7 @@ void initJITBindings(PyObject *module) { .def_property_readonly("graph", [](GraphExecutor& ge) { return ge.graph(); }) - .def("get_debug_state", [](GraphExecutor& ge) { + .def("get_debug_state", [](GraphExecutor& ge) { return ge.getDebugState(); }) .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object { diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h index 5fc97e1c67e..899c3c80636 100644 --- a/torch/csrc/jit/interned_strings.h +++ b/torch/csrc/jit/interned_strings.h @@ -24,7 +24,7 @@ namespace torch { namespace jit { _(prim, Eval) \ _(prim, Expand) /* onnx */ \ _(prim, FusionGroup) \ - _(prim, GraphExecutor) \ + _(prim, DifferentiableGraph) \ _(prim, If) \ _(prim, Jump) /* debug */ \ _(prim, JumpNZ) /* debug */ \ @@ -87,6 +87,12 @@ namespace torch { namespace jit { _(onnx, Not) \ FORALL_ATTR_BASE_SYMBOLS(_) \ _(attr, Subgraph) \ + _(attr, ReverseSubgraph) \ + _(attr, f_real_outputs) \ + _(attr, df_input_vjps) \ + _(attr, df_input_captured_inputs) \ + _(attr, df_input_captured_outputs) \ + _(attr, df_output_vjps) \ _(attr, axes) \ _(attr, axis) \ _(attr, broadcast) \ diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 80acb34dd5f..1132b5cd3f5 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -515,7 +515,7 @@ struct CodeImpl { size_t insertInstruction(Node * n) { auto inst = insertInstruction(n->kind(), n->getSourceLocation(), n->inputs(), moveFlags(n) , n->outputs()); - instructions[inst].callback = getInterpreterOperation(n); + instructions[inst].callback = getOperation(n); return inst; } size_t insertInstruction(Symbol sym, @@ -603,27 +603,16 @@ struct CodeImpl { return r; } - // Returns a function implementing functionality of a given node, - // or nullptr if it's a no-op for autograd. 
- Operation getInterpreterOperation(jit::Node* node) { - if(node->kind() != prim::GraphExecutor) { - return getOperation(node); + const std::vector& grad_executors() { + if (!grad_executors_) { + grad_executors_.emplace(); + for (Instruction & instr : instructions) { + if (auto executor = detail::getGradExecutor(instr.callback)) { + grad_executors_->push_back(executor); + } + } } - // recursive graph executors cannot be Operators because they - // have to register themselves with the interpreter so that - // we can provide useful debugging information - - auto executor = std::make_shared(node->g(attr::Subgraph)); - graph_executors.emplace_back(executor.get()); - return [=](Stack& stack) mutable { - autograd::profiler::RecordFunction record("GraphExecutor"); - executor->run(stack); - return 0; - }; - } - - const std::vector& executors() { - return graph_executors; + return *grad_executors_; } void dumpInstruction(std::ostream & out, size_t pc) const { @@ -664,7 +653,7 @@ struct CodeImpl { // It is also very useful for debugging interpreter problems to // keep this around. std::shared_ptr graph; - std::vector graph_executors; // for debugging + at::optional> grad_executors_; PreprocessGraph preprocess; std::unordered_map unique_to_reg; // map from unique of nodes to register in register table @@ -771,8 +760,8 @@ Code::Code(std::shared_ptr& graph) : pImpl(new CodeImpl(graph)) {} Code::~Code() = default; -const std::vector& Code::executors() { - return pImpl->executors(); +const std::vector& Code::grad_executors() { + return pImpl->grad_executors(); } InterpreterState::InterpreterState(const Code & code) diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index 8c4832e7ec6..151a980d76a 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -29,8 +29,7 @@ struct TORCH_API Code { Code(std::shared_ptr& graph); ~Code(); - // Returns pointers to GraphExecutors created to run GraphExecutor nodes in the given graph. 
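grad_executors() above relies on std::function::target<T>() to recover the DifferentiableGraphOp state hidden behind a type-erased Operation, which is what detail::getGradExecutor does per callback. A self-contained sketch of that lookup with stand-in types (FakeDiffOp and collectGradExecutors are illustrative, not part of this patch):

    #include <cassert>
    #include <functional>
    #include <vector>

    using Stack = std::vector<int>;
    using Operation = std::function<int(Stack&)>;

    // Stand-in for DifferentiableGraphOp: a function object that carries the
    // state we want to find again later (here, just an int tag).
    struct FakeDiffOp {
      int grad_executor_tag;
      int operator()(Stack&) const { return 0; }
    };

    // Mirrors the lazy grad_executors() scan: walk all recorded callbacks and
    // collect pointers into the ones that are actually FakeDiffOps.
    std::vector<const int*> collectGradExecutors(std::vector<Operation>& callbacks) {
      std::vector<const int*> found;
      for (Operation& op : callbacks) {
        // std::function::target<T>() returns nullptr unless the stored
        // callable is exactly of type T, so plain lambdas are skipped.
        if (const FakeDiffOp* diff_op = op.target<FakeDiffOp>())
          found.push_back(&diff_op->grad_executor_tag);
      }
      return found;
    }

    int main() {
      std::vector<Operation> callbacks;
      callbacks.push_back([](Stack&) { return 0; });   // ordinary op: ignored
      callbacks.push_back(FakeDiffOp{42});             // differentiable graph op
      auto execs = collectGradExecutors(callbacks);
      assert(execs.size() == 1 && *execs[0] == 42);
    }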
- const std::vector& executors(); + const std::vector& grad_executors(); explicit operator bool() const { return pImpl != nullptr; diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 3f5bef3812d..9053b45e2e4 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -186,7 +186,7 @@ std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::v IR_ELSE() if(n->hasAttribute(attr::Subgraph) && groups) { out << n->kind().toQualString() << "_" << groups->size(); - if (n->numAttributes() > 1) { + if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) { printAttributes(out, n, /*ignore_subgraph=*/true); } groups->push_back(n); diff --git a/torch/csrc/jit/passes/canonicalize.cpp b/torch/csrc/jit/passes/canonicalize.cpp index 12e94393540..e5dda3ec4c5 100644 --- a/torch/csrc/jit/passes/canonicalize.cpp +++ b/torch/csrc/jit/passes/canonicalize.cpp @@ -24,6 +24,9 @@ std::shared_ptr Canonicalize(const std::shared_ptr& graph) { r_outputs.at(i)->setStage(outputs.at(i)->stage()); rn_env[outputs.at(i)] = r_outputs.at(i); } + if (node->hasAttribute(attr::Subgraph)) { + r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph))); + } } for (auto* output : graph->outputs()) { r->registerOutput(rn_fn(output)); diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index d37ff6dfea5..3554c22ddc7 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -17,7 +17,7 @@ namespace { // right before nodes[0] (i.e. it will not create cycles and all uses of // new node will be after this position). // prereq: nodes are in topological order -void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef nodes) { +Node* mergeNodes(Block * block, Symbol group_node_kind, ArrayRef nodes) { JIT_ASSERT(nodes.size() > 0); std::unordered_map value_map; Graph * graph = block->owningGraph(); @@ -66,11 +66,12 @@ void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef nodes) { nodes[i - 1]->destroy(); } JIT_ASSERT(isDifferentiable(*new_graph)); + return group_node; } } -void CreateAutodiffSubgraphs(Block * block, size_t threshold) { +void CreateAutodiffSubgraphs(Block * block, size_t threshold, std::vector& diff_graphs) { // This implementation is not optimal, but it is simple. 
// It just scans through the list in order looking for runs of // differentiable ops, and then grouping them together when @@ -93,21 +94,23 @@ void CreateAutodiffSubgraphs(Block * block, size_t threshold) { groupable.push_back(node); } else { if(groupable.size() >= threshold) { - mergeNodes(block, prim::GraphExecutor, groupable); + diff_graphs.push_back(mergeNodes(block, prim::DifferentiableGraph, groupable)); } groupable.clear(); for (Block * sub_block : node->blocks()) { - CreateAutodiffSubgraphs(sub_block, threshold); + CreateAutodiffSubgraphs(sub_block, threshold, diff_graphs); } } } if(groupable.size() >= threshold) { - mergeNodes(block, prim::GraphExecutor, groupable); + diff_graphs.push_back(mergeNodes(block, prim::DifferentiableGraph, groupable)); } } -void CreateAutodiffSubgraphs(Graph & graph, size_t threshold) { - CreateAutodiffSubgraphs(graph.block(), threshold); +std::vector CreateAutodiffSubgraphs(Graph & graph, size_t threshold) { + std::vector diff_nodes; + CreateAutodiffSubgraphs(graph.block(), threshold, diff_nodes); + return diff_nodes; } diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.h b/torch/csrc/jit/passes/create_autodiff_subgraphs.h index b76ee82e529..44a6683dc4c 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.h +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.h @@ -8,6 +8,7 @@ struct Graph; // insert GraphExecutor nodes that group together // subgraphs that are differentiable by the jit's autodiff passes // threshold - minimum number of nodes that will appear in a block -TORCH_API void CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2); +// returns all differentiable blocks that have been found +TORCH_API std::vector CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2); }}
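The pass above scans each block for maximal runs of differentiable nodes, merges a run into a prim::DifferentiableGraph node once it reaches the threshold, and now also returns the merged nodes to the caller. A self-contained sketch of that run-grouping logic with stand-in types (groupRuns and its string/bool node encoding are illustrative, not part of this patch):

    #include <cassert>
    #include <string>
    #include <utility>
    #include <vector>

    // Group maximal runs of "differentiable" items of length >= threshold and
    // return one group per run, so the caller can post-process each group
    // (mirroring the new std::vector<Node*> return value).
    std::vector<std::vector<std::string>> groupRuns(
        const std::vector<std::pair<std::string, bool>>& nodes,
        size_t threshold) {
      std::vector<std::vector<std::string>> groups;
      std::vector<std::string> run;
      auto flush = [&]() {
        if (run.size() >= threshold)
          groups.push_back(run);
        run.clear();
      };
      for (const auto& n : nodes) {
        if (n.second)
          run.push_back(n.first);   // differentiable: extend the current run
        else
          flush();                  // non-differentiable: close the run
      }
      flush();                      // a trailing run also counts
      return groups;
    }

    int main() {
      auto groups = groupRuns(
          {{"add", true}, {"mul", true}, {"print", false}, {"tanh", true}},
          /*threshold=*/2);
      // Only the {add, mul} run reaches the threshold; the lone tanh does not.
      assert(groups.size() == 1 && groups[0].size() == 2);
    }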