Change specialization rules in GraphExecutors (#10977)

Summary:
**Review last commit only.** Stacked on top of #10949.

This commit fixes a number of issues connected to caching the
differentiability status of graphs inside graph executors,
and changes the rules for optimizing differentiable subgraphs.
Previously, every differentiable subgraph was instantiated as a separate
graph executor; now they are simply more heavily optimized graph regions,
and graph executors are only instantiated for their backward passes.

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10977

Differential Revision: D9600626

Pulled By: apaszke

fbshipit-source-id: dad09a0f586e396afbd5406319c1cd54fbb8a3d3
Authored by Adam Paszke on 2018-08-30 22:06:19 -07:00; committed by Facebook Github Bot
parent a320e5cbd3
commit 00df09b65d
34 changed files with 786 additions and 763 deletions
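
As a rough illustration of the summary above (not part of the commit): after this change the optimized forward graph embeds prim::DifferentiableGraph nodes directly (the expect-file diffs below show prim::GraphExecutor_0 becoming prim::DifferentiableGraph_0), and only the backward of such a region is backed by a nested GraphExecutor. A minimal, hypothetical Python sketch of how this can be observed, assuming the graph_for debug helper used elsewhere in test_jit.py and inputs that require grad:

import torch

@torch.jit.script
def cell(x, y):
    # A small differentiable region; with inputs that require grad it may be
    # grouped into a single prim::DifferentiableGraph node after this change.
    return torch.tanh(x * y + x)

a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(3, 4, requires_grad=True)
cell(a, b)
# The specialized graph should now show prim::DifferentiableGraph_0 instead of
# a nested prim::GraphExecutor_0 node; its backward is compiled lazily by a
# GraphExecutor owned by the DifferentiableGraph operator.
print(cell.graph_for(a, b))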


@ -4,11 +4,11 @@ graph(%0 : Float(*, *)
%3 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %0, %1) %3 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %0, %1)
return (%3); return (%3);
} }
with prim::FusionGroup_0 = graph(%1 : Float(*) with prim::FusionGroup_0 = graph(%0 : Float(*)
%4 : Float(*, *) %1 : Float(*, *)
%5 : Float(*)) { %2 : Float(*)) {
%6 : Float(*, *) = aten::mul(%4, %5) %3 : Float(*, *) = aten::mul(%1, %2)
%2 : int = prim::Constant[value=1]() %4 : int = prim::Constant[value=1]()
%3 : Float(*, *) = aten::add(%6, %1, %2) %5 : Float(*, *) = aten::add(%3, %0, %4)
return (%3); return (%5);
} }


@ -3,11 +3,11 @@ graph(%0 : Float(*, *)
%2 : Float(*, *) = prim::FusionGroup_0[device=0](%0, %1) %2 : Float(*, *) = prim::FusionGroup_0[device=0](%0, %1)
return (%2); return (%2);
} }
with prim::FusionGroup_0 = graph(%3 : Float(*, *) with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%4 : Float(*, *)) { %1 : Float(*, *)) {
%6 : int = prim::Constant[value=1]() %2 : int = prim::Constant[value=1]()
%7 : Float(*, *) = aten::add(%3, %4, %6) %3 : Float(*, *) = aten::add(%0, %1, %2)
%5 : Float(*, *) = aten::mul(%3, %4) %4 : Float(*, *) = aten::mul(%0, %1)
%2 : Float(*, *) = prim::FusedConcat[dim=0](%7, %5) %5 : Float(*, *) = prim::FusedConcat[dim=0](%3, %4)
return (%2); return (%5);
} }


@ -6,12 +6,12 @@ graph(%0 : Float(*, *)
%5 : Float(*, *) = aten::add(%4, %2, %3) %5 : Float(*, *) = aten::add(%4, %2, %3)
return (%5); return (%5);
} }
with prim::FusionGroup_0 = graph(%3 : Float(*, *) with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%4 : Float(*, *)) { %1 : Float(*, *)) {
%7 : int = prim::Constant[value=1]() %2 : int = prim::Constant[value=1]()
%8 : Float(*, *) = aten::add(%3, %4, %7) %3 : Float(*, *) = aten::add(%0, %1, %2)
%5 : int = prim::Constant[value=1]() %4 : int = prim::Constant[value=1]()
%6 : Float(*, *) = aten::sub(%3, %4, %5) %5 : Float(*, *) = aten::sub(%0, %1, %4)
%2 : Float(*, *) = prim::FusedConcat[dim=0](%8, %6) %6 : Float(*, *) = prim::FusedConcat[dim=0](%3, %5)
return (%2); return (%6);
} }


@ -64,10 +64,10 @@ graph(%0 : Dynamic
%2 : Dynamic %2 : Dynamic
%3 : Dynamic %3 : Dynamic
%4 : Dynamic) { %4 : Dynamic) {
%23 : Dynamic, %24 : Dynamic = prim::GraphExecutor_0(%0, %3, %1, %4, %2) %23 : Dynamic, %24 : Dynamic = prim::DifferentiableGraph_0(%0, %3, %1, %4, %2)
return (%24, %23); return (%24, %23);
} }
with prim::GraphExecutor_0 = graph(%1 : Dynamic with prim::DifferentiableGraph_0 = graph(%1 : Dynamic
%2 : Dynamic %2 : Dynamic
%4 : Dynamic %4 : Dynamic
%5 : Dynamic %5 : Dynamic


@ -3,14 +3,14 @@ graph(%0 : Float(*)
%2 : Float(*) = prim::FusionGroup_0[device=1](%0, %1) %2 : Float(*) = prim::FusionGroup_0[device=1](%0, %1)
return (%2); return (%2);
} }
with prim::FusionGroup_0 = graph(%5 : Float(*) with prim::FusionGroup_0 = graph(%0 : Float(*)
%10 : Float(*)) { %1 : Float(*)) {
%11 : int = prim::Constant[value=1]() %2 : int = prim::Constant[value=1]()
%12 : Float(*) = aten::add(%5, %10, %11) %3 : Float(*) = aten::add(%0, %1, %2)
%9 : Float(*) = aten::mul(%5, %12) %4 : Float(*) = aten::mul(%0, %3)
%6 : int = prim::Constant[value=1]() %5 : int = prim::Constant[value=1]()
%7 : Float(*) = aten::add(%9, %5, %6) %6 : Float(*) = aten::add(%4, %0, %5)
%3 : Float(*) = aten::tanh(%7) %7 : Float(*) = aten::tanh(%6)
%1 : Float(*) = aten::sigmoid(%3) %8 : Float(*) = aten::sigmoid(%7)
return (%1); return (%8);
} }


@ -6,14 +6,14 @@ graph(%0 : Float(*, *)
%6 : Float(*, *) = prim::FusionGroup_0[device=0](%4, %5) %6 : Float(*, *) = prim::FusionGroup_0[device=0](%4, %5)
return (%6); return (%6);
} }
with prim::FusionGroup_0 = graph(%11 : Dynamic with prim::FusionGroup_0 = graph(%0 : Dynamic
%14 : Dynamic) { %1 : Dynamic) {
%15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%14) %2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%1)
%12 : Float(*, *), %13 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%11) %4 : Float(*, *), %5 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%0)
%9 : int = prim::Constant[value=1]() %6 : int = prim::Constant[value=1]()
%10 : Float(*, *) = aten::add(%13, %16, %9) %7 : Float(*, *) = aten::add(%5, %3, %6)
%5 : int = prim::Constant[value=1]() %8 : int = prim::Constant[value=1]()
%6 : Float(*, *) = aten::add(%12, %15, %5) %9 : Float(*, *) = aten::add(%4, %2, %8)
%2 : Float(*, *) = aten::mul(%6, %10) %10 : Float(*, *) = aten::mul(%9, %7)
return (%2); return (%10);
} }


@ -16,29 +16,29 @@ graph(%0 : Float(*, *)
%16 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %14, %15) %16 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %14, %15)
return (%16); return (%16);
} }
with prim::FusionGroup_0 = graph(%15 : Float(*, *) with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%41 : Dynamic %1 : Dynamic
%46 : Dynamic) { %2 : Dynamic) {
%47 : Float(*, *), %48 : Float(*, *), %49 : Float(*, *), %50 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%46) %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%42 : Float(*, *), %43 : Float(*, *), %44 : Float(*, *), %45 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%41) %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%39 : int = prim::Constant[value=1]() %11 : int = prim::Constant[value=1]()
%40 : Float(*, *) = aten::add(%42, %47, %39) %12 : Float(*, *) = aten::add(%7, %3, %11)
%35 : int = prim::Constant[value=1]() %13 : int = prim::Constant[value=1]()
%36 : Float(*, *) = aten::add(%43, %48, %35) %14 : Float(*, *) = aten::add(%8, %4, %13)
%31 : int = prim::Constant[value=1]() %15 : int = prim::Constant[value=1]()
%32 : Float(*, *) = aten::add(%44, %49, %31) %16 : Float(*, *) = aten::add(%9, %5, %15)
%27 : int = prim::Constant[value=1]() %17 : int = prim::Constant[value=1]()
%28 : Float(*, *) = aten::add(%45, %50, %27) %18 : Float(*, *) = aten::add(%10, %6, %17)
%24 : Float(*, *) = aten::sigmoid(%40) %19 : Float(*, *) = aten::sigmoid(%12)
%22 : Float(*, *) = aten::sigmoid(%36) %20 : Float(*, *) = aten::sigmoid(%14)
%20 : Float(*, *) = aten::tanh(%32) %21 : Float(*, *) = aten::tanh(%16)
%18 : Float(*, *) = aten::sigmoid(%28) %22 : Float(*, *) = aten::sigmoid(%18)
%16 : Float(*, *) = aten::mul(%22, %15) %23 : Float(*, *) = aten::mul(%20, %0)
%13 : Float(*, *) = aten::mul(%24, %20) %24 : Float(*, *) = aten::mul(%19, %21)
%9 : int = prim::Constant[value=1]() %25 : int = prim::Constant[value=1]()
%10 : Float(*, *) = aten::add(%16, %13, %9) %26 : Float(*, *) = aten::add(%23, %24, %25)
%6 : Float(*, *) = aten::tanh(%10) %27 : Float(*, *) = aten::tanh(%26)
%5 : Float(*, *) = aten::mul(%18, %6) %28 : Float(*, *) = aten::mul(%22, %27)
%2 : Float(*, *) = prim::FusedConcat[dim=0](%5, %10) %29 : Float(*, *) = prim::FusedConcat[dim=0](%28, %26)
return (%2); return (%29);
} }


@ -16,28 +16,28 @@ graph(%0 : Float(*, *)
%16 : Float(*, *), %17 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %14, %15) %16 : Float(*, *), %17 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %14, %15)
return (%16, %17); return (%16, %17);
} }
with prim::FusionGroup_0 = graph(%13 : Float(*, *) with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%39 : Dynamic %1 : Dynamic
%44 : Dynamic) { %2 : Dynamic) {
%45 : Float(*, *), %46 : Float(*, *), %47 : Float(*, *), %48 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%44) %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%40 : Float(*, *), %41 : Float(*, *), %42 : Float(*, *), %43 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%39) %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%37 : int = prim::Constant[value=1]() %11 : int = prim::Constant[value=1]()
%38 : Float(*, *) = aten::add(%40, %45, %37) %12 : Float(*, *) = aten::add(%7, %3, %11)
%33 : int = prim::Constant[value=1]() %13 : int = prim::Constant[value=1]()
%34 : Float(*, *) = aten::add(%41, %46, %33) %14 : Float(*, *) = aten::add(%8, %4, %13)
%29 : int = prim::Constant[value=1]() %15 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%42, %47, %29) %16 : Float(*, *) = aten::add(%9, %5, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%10, %6, %17)
%19 : Float(*, *) = aten::sigmoid(%12)
%20 : Float(*, *) = aten::sigmoid(%14)
%21 : Float(*, *) = aten::tanh(%16)
%22 : Float(*, *) = aten::sigmoid(%18)
%23 : Float(*, *) = aten::mul(%20, %0)
%24 : Float(*, *) = aten::mul(%19, %21)
%25 : int = prim::Constant[value=1]() %25 : int = prim::Constant[value=1]()
%26 : Float(*, *) = aten::add(%43, %48, %25) %26 : Float(*, *) = aten::add(%23, %24, %25)
%22 : Float(*, *) = aten::sigmoid(%38) %27 : Float(*, *) = aten::tanh(%26)
%20 : Float(*, *) = aten::sigmoid(%34) %28 : Float(*, *) = aten::mul(%22, %27)
%18 : Float(*, *) = aten::tanh(%30) return (%28, %26);
%16 : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%20, %13)
%11 : Float(*, *) = aten::mul(%22, %18)
%7 : int = prim::Constant[value=1]()
%8 : Float(*, *) = aten::add(%14, %11, %7)
%4 : Float(*, *) = aten::tanh(%8)
%2 : Float(*, *) = aten::mul(%16, %4)
return (%2, %8);
} }


@ -1,6 +1,6 @@
graph(%0 : Double(3, 4) graph(%0 : Double(3, 4)
%1 : Double(4, 5)) { %1 : Double(4, 5)) {
%2 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule %2 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
%3 : Double(3, 4) = aten::neg(%2), scope: TracedModule/traced_fn %3 : Double(*, *) = aten::neg(%2), scope: TracedModule/traced_fn
return (%3); return (%3);
} }


@ -1,5 +1,5 @@
graph(%0 : Double(3, 4)) { graph(%0 : Double(3, 4)) {
%1 : Double(3, 4) = aten::neg(%0), scope: traced_fn1 %1 : Double(*, *) = aten::neg(%0), scope: traced_fn1
%2 : Long() = prim::Constant[value={1}]() %2 : Long() = prim::Constant[value={1}]()
%3 : int = prim::Constant[value=1]() %3 : int = prim::Constant[value=1]()
%4 : Double(3, 4) = aten::add(%1, %2, %3) %4 : Double(3, 4) = aten::add(%1, %2, %3)


@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)) { graph(%0 : Double(3, 4)) {
%1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: TracedModule[TracedModule] %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: TracedModule[TracedModule]
%2 : Double(3, 3) = aten::mm(%0, %1), scope: TracedModule[TracedModule] %2 : Double(*, *) = aten::mm(%0, %1), scope: TracedModule[TracedModule]
%3 : Double() = prim::Constant[value={1}]() %3 : Double() = prim::Constant[value={1}]()
%4 : int = prim::Constant[value=1]() %4 : int = prim::Constant[value=1]()
%5 : Double(3, 3) = aten::add(%2, %3, %4) %5 : Double(3, 3) = aten::add(%2, %3, %4)


@ -2,7 +2,7 @@ graph(%0 : Double(3, 4)
%1 : Double(4, 5) %1 : Double(4, 5)
%2 : Double(5, 7)) { %2 : Double(5, 7)) {
%3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule %3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
%4 : Double(3, 7) = aten::mm(%3, %2), scope: TracedModule/TracedModule[TracedModule1][mod] %4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule/TracedModule[TracedModule1][mod]
%5 : Double() = prim::Constant[value={1}](), scope: TracedModule %5 : Double() = prim::Constant[value={1}](), scope: TracedModule
%6 : int = prim::Constant[value=1](), scope: TracedModule %6 : int = prim::Constant[value=1](), scope: TracedModule
%7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule %7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule


@ -2,10 +2,10 @@ graph(%x : Float(*, *)) {
%1 : Float(*, *) = prim::FusionGroup_0[device=0](%x) %1 : Float(*, *) = prim::FusionGroup_0[device=0](%x)
return (%1); return (%1);
} }
with prim::FusionGroup_0 = graph(%7 : Float(*, *)) { with prim::FusionGroup_0 = graph(%0 : Float(*, *)) {
%8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=3, dim=1](%7) %1 : Float(*, *), %2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=3, dim=1](%0)
%6 : Float(*, *) = aten::mul(%8, %9) %4 : Float(*, *) = aten::mul(%1, %2)
%2 : int = prim::Constant[value=1]() %5 : int = prim::Constant[value=1]()
%3 : Float(*, *) = aten::add(%6, %10, %2) %6 : Float(*, *) = aten::add(%4, %3, %5)
return (%3); return (%6);
} }


@ -5,26 +5,26 @@ graph(%s : Float(*, *, *)
%4 : Float(*, *, *) = prim::FusionGroup_0[device=0](%s, %y, %x, %z) %4 : Float(*, *, *) = prim::FusionGroup_0[device=0](%s, %y, %x, %z)
return (%4); return (%4);
} }
with prim::FusionGroup_0 = graph(%24 : Float(*, *, *) with prim::FusionGroup_0 = graph(%0 : Float(*, *, *)
%28 : Float(*, *, *) %1 : Float(*, *, *)
%31 : Float(*, *, *) %2 : Float(*, *, *)
%35 : Float(*, *, *)) { %3 : Float(*, *, *)) {
%36 : Float(*, *, *), %37 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=2](%35) %4 : Float(*, *, *), %5 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=2](%3)
%32 : Float(*, *, *), %33 : Float(*, *, *), %34 : Float(*, *, *) = prim::ConstantChunk[chunks=3, dim=1](%31) %6 : Float(*, *, *), %7 : Float(*, *, *), %8 : Float(*, *, *) = prim::ConstantChunk[chunks=3, dim=1](%2)
%29 : Float(*, *, *), %30 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=0](%28) %9 : Float(*, *, *), %10 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=0](%1)
%26 : int = prim::Constant[value=1]() %11 : int = prim::Constant[value=1]()
%27 : Float(*, *, *) = aten::add(%24, %32, %26) %12 : Float(*, *, *) = aten::add(%0, %6, %11)
%22 : int = prim::Constant[value=1]() %13 : int = prim::Constant[value=1]()
%23 : Float(*, *, *) = aten::add(%27, %33, %22) %14 : Float(*, *, *) = aten::add(%12, %7, %13)
%18 : int = prim::Constant[value=1]() %15 : int = prim::Constant[value=1]()
%19 : Float(*, *, *) = aten::add(%23, %34, %18) %16 : Float(*, *, *) = aten::add(%14, %8, %15)
%14 : int = prim::Constant[value=1]() %17 : int = prim::Constant[value=1]()
%15 : Float(*, *, *) = aten::add(%19, %29, %14) %18 : Float(*, *, *) = aten::add(%16, %9, %17)
%10 : int = prim::Constant[value=1]() %19 : int = prim::Constant[value=1]()
%11 : Float(*, *, *) = aten::add(%15, %30, %10) %20 : Float(*, *, *) = aten::add(%18, %10, %19)
%6 : int = prim::Constant[value=1]() %21 : int = prim::Constant[value=1]()
%7 : Float(*, *, *) = aten::add(%11, %36, %6) %22 : Float(*, *, *) = aten::add(%20, %4, %21)
%2 : int = prim::Constant[value=1]() %23 : int = prim::Constant[value=1]()
%3 : Float(*, *, *) = aten::add(%7, %37, %2) %24 : Float(*, *, *) = aten::add(%22, %5, %23)
return (%3); return (%24);
} }


@ -7,9 +7,9 @@ graph(%0 : Float(*, *)
%6 : Dynamic %6 : Dynamic
%7 : Dynamic %7 : Dynamic
%8 : Dynamic %8 : Dynamic
%x : Float(*, *) %9 : Float(*, *)
%hx : Float(*, *) %10 : Float(*, *)
%cx : Float(*, *) %11 : Float(*, *)
%12 : Float(*, *) %12 : Float(*, *)
%13 : Float(*, *) %13 : Float(*, *)
%ingate : Float(*, *) %ingate : Float(*, *)
@ -18,58 +18,67 @@ graph(%0 : Float(*, *)
%outgate : Float(*, *) %outgate : Float(*, *)
%18 : Float(*, *)) { %18 : Float(*, *)) {
%19 : int = prim::Constant[value=1]() %19 : int = prim::Constant[value=1]()
%20 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %18, %0) %20 : Float(*, *), %21 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %11, %1, %18, %0)
%21 : Float(*, *) = aten::mul(%20, %19) %22 : Float(*, *) = aten::mul(%20, %19)
%22 : Float(*, *) = aten::t(%hx) %23 : Float(*, *) = aten::t(%12)
%23 : Float(*, *) = aten::mm(%22, %21) %24 : Float(*, *) = aten::mm(%20, %23)
%24 : Float(*, *) = aten::t(%23) %25 : Float(*, *) = aten::mul(%24, %19)
%25 : Float(*, *) = aten::t(%x) %26 : Float(*, *) = aten::t(%10)
%26 : Float(*, *) = aten::mm(%25, %20) %27 : Float(*, *) = aten::mm(%26, %20)
%27 : Float(*, *) = aten::t(%26) %28 : Float(*, *) = aten::mul(%27, %19)
return (%27, %24, %21, %21); %29 : Float(*, *) = aten::t(%13)
%30 : Float(*, *) = aten::mm(%22, %29)
%31 : Float(*, *) = aten::t(%9)
%32 : Float(*, *) = aten::mm(%31, %22)
%33 : Float(*, *) = aten::t(%32)
%34 : Float(*, *) = aten::t(%28)
return (%34, %33, %30, %25, %22, %22, %21);
} }
with prim::FusionGroup_0 = graph(%9 : Float(*, *) with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%19 : Float(*, *) %1 : Float(*, *)
%33 : Float(*, *) %2 : Float(*, *)
%39 : Float(*, *) %3 : Float(*, *)
%46 : Float(*, *) %4 : Float(*, *)
%53 : Float(*, *) %5 : Float(*, *)
%65 : Float(*, *) %6 : Float(*, *)
%67 : Float(*, *)) { %7 : Float(*, *)) {
%69 : Float(*, *) = aten::mul(%67, %65) %8 : Float(*, *) = aten::mul(%5, %3)
%68 : Float(*, *) = aten::mul(%67, %39) %9 : Float(*, *) = aten::mul(%6, %6)
%66 : Float(*, *) = aten::mul(%65, %65) %10 : Float(*, *) = aten::neg(%9)
%64 : Float(*, *) = aten::neg(%66) %11 : int = prim::Constant[value=1]()
%61 : int = prim::Constant[value=1]()
%62 : Float(*, *) = aten::add(%64, %61, %61)
%59 : Float(*, *) = aten::mul(%68, %62)
%55 : int = prim::Constant[value=1]()
%56 : Float(*, *) = aten::add(%53, %59, %55)
%51 : int = prim::Constant[value=1]()
%52 : Float(*, *) = aten::mul(%56, %51)
%50 : Float(*, *) = aten::mul(%52, %33)
%49 : Float(*, *) = aten::mul(%52, %9)
%47 : Float(*, *) = aten::mul(%56, %46)
%44 : Float(*, *) = aten::neg(%39)
%42 : int = prim::Constant[value=1]()
%43 : Float(*, *) = aten::add(%44, %42, %42)
%40 : Float(*, *) = aten::mul(%69, %39)
%37 : Float(*, *) = aten::mul(%40, %43)
%34 : Float(*, *) = aten::mul(%33, %33)
%32 : Float(*, *) = aten::neg(%34)
%29 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%32, %29, %29)
%27 : Float(*, *) = aten::mul(%49, %30)
%24 : Float(*, *) = aten::neg(%19)
%22 : int = prim::Constant[value=1]()
%23 : Float(*, *) = aten::add(%24, %22, %22)
%20 : Float(*, *) = aten::mul(%47, %19)
%17 : Float(*, *) = aten::mul(%20, %23)
%14 : Float(*, *) = aten::neg(%9)
%12 : int = prim::Constant[value=1]() %12 : int = prim::Constant[value=1]()
%13 : Float(*, *) = aten::add(%14, %12, %12) %13 : Float(*, *) = aten::add(%10, %12, %12)
%10 : Float(*, *) = aten::mul(%50, %9) %14 : Float(*, *) = aten::mul(%8, %13)
%7 : Float(*, *) = aten::mul(%10, %13) %15 : int = prim::Constant[value=1]()
%4 : Float(*, *) = prim::FusedConcat[dim=1](%7, %17, %27, %37) %16 : int = prim::Constant[value=1]()
return (%4); %17 : Float(*, *) = aten::add(%7, %14, %16)
%18 : Float(*, *) = aten::mul(%17, %1)
%19 : Float(*, *) = aten::mul(%5, %6)
%20 : int = prim::Constant[value=1]()
%21 : Float(*, *) = aten::mul(%17, %20)
%22 : Float(*, *) = aten::mul(%21, %2)
%23 : Float(*, *) = aten::mul(%21, %0)
%24 : Float(*, *) = aten::mul(%17, %4)
%25 : Float(*, *) = aten::neg(%3)
%26 : int = prim::Constant[value=1]()
%27 : Float(*, *) = aten::add(%25, %26, %26)
%28 : Float(*, *) = aten::mul(%19, %3)
%29 : Float(*, *) = aten::mul(%28, %27)
%30 : Float(*, *) = aten::mul(%2, %2)
%31 : Float(*, *) = aten::neg(%30)
%32 : int = prim::Constant[value=1]()
%33 : Float(*, *) = aten::add(%31, %32, %32)
%34 : Float(*, *) = aten::mul(%23, %33)
%35 : Float(*, *) = aten::neg(%1)
%36 : int = prim::Constant[value=1]()
%37 : Float(*, *) = aten::add(%35, %36, %36)
%38 : Float(*, *) = aten::mul(%24, %1)
%39 : Float(*, *) = aten::mul(%38, %37)
%40 : Float(*, *) = aten::neg(%0)
%41 : int = prim::Constant[value=1]()
%42 : Float(*, *) = aten::add(%40, %41, %41)
%43 : Float(*, *) = aten::mul(%22, %0)
%44 : Float(*, *) = aten::mul(%43, %42)
%45 : Float(*, *) = prim::FusedConcat[dim=1](%44, %39, %34, %29)
return (%45, %18);
} }


@ -1,44 +1,54 @@
graph(%x.1 : Float(*, *) graph(%x : Float(*, *)
%hx.1 : Float(*, *) %hx : Float(*, *)
%cx.1 : Float(*, *) %cx : Float(*, *)
%w_ih : Float(*, *) %w_ih : Float(*, *)
%w_hh : Float(*, *) %w_hh : Float(*, *)
%b_ih : Float(*) %b_ih : Float(*)
%b_hh : Float(*)) { %b_hh : Float(*)) {
%7 : Float(*, *) = aten::t(%w_ih) %7 : Float(*, *), %8 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %w_hh, %hx, %x, %b_ih, %b_hh, %cx)
%8 : Float(*, *) = aten::t(%w_hh) return (%8, %7);
%9 : Float(*, *) = aten::mm(%hx.1, %8) }
with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*)
%5 : Float(*)
%6 : Float(*, *)) {
%7 : Float(*, *) = aten::t(%0)
%8 : Float(*, *) = aten::t(%1)
%9 : Float(*, *) = aten::mm(%2, %8)
%10 : int = prim::Constant[value=1]() %10 : int = prim::Constant[value=1]()
%11 : Float(*, *) = aten::addmm(%9, %x.1, %7, %10, %10) %11 : Float(*, *) = aten::addmm(%9, %3, %7, %10, %10)
%12 : Float(*, *) = aten::add(%11, %b_ih, %10) %12 : Float(*, *) = aten::add(%11, %4, %10)
%13 : Dynamic[] = prim::ListConstruct(%12, %b_hh) %13 : Dynamic[] = prim::ListConstruct(%12, %5)
%14 : Dynamic[] = aten::broadcast_tensors(%13) %14 : Dynamic[] = aten::broadcast_tensors(%13)
%15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%14) %15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%14)
%hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0[device=0](%cx.1, %15, %16) %hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0[device=0](%6, %15, %16)
return (%hy, %cy, %7, %8, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18); return (%cy, %hy, %7, %8, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18);
} }
with prim::FusionGroup_0 = graph(%13 : Float(*, *) with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%39 : Dynamic %1 : Dynamic
%44 : Dynamic) { %2 : Dynamic) {
%45 : Dynamic, %46 : Dynamic, %47 : Dynamic, %48 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%44) %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%40 : Dynamic, %41 : Dynamic, %42 : Dynamic, %43 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%39) %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%37 : int = prim::Constant[value=1]() %11 : int = prim::Constant[value=1]()
%38 : Float(*, *) = aten::add(%40, %45, %37) %12 : Float(*, *) = aten::add(%7, %3, %11)
%33 : int = prim::Constant[value=1]() %13 : int = prim::Constant[value=1]()
%34 : Float(*, *) = aten::add(%41, %46, %33) %14 : Float(*, *) = aten::add(%8, %4, %13)
%29 : int = prim::Constant[value=1]() %15 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%42, %47, %29) %16 : Float(*, *) = aten::add(%9, %5, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%10, %6, %17)
%ingate.1 : Float(*, *) = aten::sigmoid(%12)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
%cellgate.1 : Float(*, *) = aten::tanh(%16)
%outgate.1 : Float(*, *) = aten::sigmoid(%18)
%23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
%24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%25 : int = prim::Constant[value=1]() %25 : int = prim::Constant[value=1]()
%26 : Float(*, *) = aten::add(%43, %48, %25) %cy : Float(*, *) = aten::add(%23, %24, %25)
%ingate.1 : Float(*, *) = aten::sigmoid(%38) %27 : Float(*, *) = aten::tanh(%cy)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%34) %hy : Float(*, *) = aten::mul(%outgate.1, %27)
%cellgate.1 : Float(*, *) = aten::tanh(%30) return (%hy, %27, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
%outgate.1 : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%forgetgate.1, %13)
%11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%7 : int = prim::Constant[value=1]()
%cy : Float(*, *) = aten::add(%14, %11, %7)
%4 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %4)
return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
} }


@ -10,12 +10,12 @@ graph(%0 : Float(*, *)
%9 : Dynamic %9 : Dynamic
%10 : Dynamic %10 : Dynamic
%11 : Dynamic %11 : Dynamic
%x : Float(*, *) %12 : Float(*, *)
%hx : Float(*, *) %13 : Float(*, *)
%cx : Float(*, *) %14 : Float(*)
%alpha : Float(*) %15 : Float(*)
%beta_i : Float(*) %16 : Float(*)
%beta_h : Float(*) %17 : Float(*, *)
%18 : Float(*, *) %18 : Float(*, *)
%Wx : Float(*, *) %Wx : Float(*, *)
%20 : Float(*, *) %20 : Float(*, *)
@ -26,85 +26,92 @@ graph(%0 : Float(*, *)
%cellgate : Float(*, *) %cellgate : Float(*, *)
%outgate : Float(*, *) %outgate : Float(*, *)
%27 : Float(*, *)) { %27 : Float(*, *)) {
%28 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %27, %0) %28 : Float(*, *), %29 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %17, %1, %27, %0)
%29 : Float(*, *), %30 : Float(*, *), %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *) = prim::FusionGroup_1[device=0](%alpha, %beta_i, %Wx, %28, %Uz, %22, %beta_h) %30 : Float(*, *), %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *), %35 : Float(*, *) = prim::FusionGroup_1[device=0](%14, %15, %Wx, %28, %Uz, %22, %16)
%35 : Float(*, *) = aten::t(%hx) %36 : Float(*, *) = aten::t(%20)
%36 : Float(*, *) = aten::mm(%35, %31) %37 : Float(*, *) = aten::mm(%32, %36)
%37 : Float(*, *) = aten::t(%36) %38 : Float(*, *) = aten::t(%13)
%38 : Float(*, *) = aten::t(%x) %39 : Float(*, *) = aten::mm(%38, %32)
%39 : Float(*, *) = aten::mm(%38, %29)
%40 : Float(*, *) = aten::t(%39) %40 : Float(*, *) = aten::t(%39)
return (%40, %37, %30, %32, %33, %34); %41 : Float(*, *) = aten::t(%18)
%42 : Float(*, *) = aten::mm(%30, %41)
%43 : Float(*, *) = aten::t(%12)
%44 : Float(*, *) = aten::mm(%43, %30)
%45 : Float(*, *) = aten::t(%44)
return (%45, %42, %40, %37, %31, %33, %34, %35, %29);
} }
with prim::FusionGroup_0 = graph(%9 : Float(*, *) with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%19 : Float(*, *) %1 : Float(*, *)
%33 : Float(*, *) %2 : Float(*, *)
%39 : Float(*, *) %3 : Float(*, *)
%46 : Float(*, *) %4 : Float(*, *)
%53 : Float(*, *) %5 : Float(*, *)
%65 : Float(*, *) %6 : Float(*, *)
%67 : Float(*, *)) { %7 : Float(*, *)) {
%69 : Float(*, *) = aten::mul(%67, %65) %8 : Float(*, *) = aten::mul(%5, %3)
%68 : Float(*, *) = aten::mul(%67, %39) %9 : Float(*, *) = aten::mul(%6, %6)
%66 : Float(*, *) = aten::mul(%65, %65) %10 : Float(*, *) = aten::neg(%9)
%64 : Float(*, *) = aten::neg(%66) %11 : int = prim::Constant[value=1]()
%61 : int = prim::Constant[value=1]()
%62 : Float(*, *) = aten::add(%64, %61, %61)
%59 : Float(*, *) = aten::mul(%68, %62)
%55 : int = prim::Constant[value=1]()
%56 : Float(*, *) = aten::add(%53, %59, %55)
%51 : int = prim::Constant[value=1]()
%52 : Float(*, *) = aten::mul(%56, %51)
%50 : Float(*, *) = aten::mul(%52, %33)
%49 : Float(*, *) = aten::mul(%52, %9)
%47 : Float(*, *) = aten::mul(%56, %46)
%44 : Float(*, *) = aten::neg(%39)
%42 : int = prim::Constant[value=1]()
%43 : Float(*, *) = aten::add(%44, %42, %42)
%40 : Float(*, *) = aten::mul(%69, %39)
%37 : Float(*, *) = aten::mul(%40, %43)
%34 : Float(*, *) = aten::mul(%33, %33)
%32 : Float(*, *) = aten::neg(%34)
%29 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%32, %29, %29)
%27 : Float(*, *) = aten::mul(%49, %30)
%24 : Float(*, *) = aten::neg(%19)
%22 : int = prim::Constant[value=1]()
%23 : Float(*, *) = aten::add(%24, %22, %22)
%20 : Float(*, *) = aten::mul(%47, %19)
%17 : Float(*, *) = aten::mul(%20, %23)
%14 : Float(*, *) = aten::neg(%9)
%12 : int = prim::Constant[value=1]() %12 : int = prim::Constant[value=1]()
%13 : Float(*, *) = aten::add(%14, %12, %12) %13 : Float(*, *) = aten::add(%10, %12, %12)
%10 : Float(*, *) = aten::mul(%50, %9) %14 : Float(*, *) = aten::mul(%8, %13)
%7 : Float(*, *) = aten::mul(%10, %13) %15 : int = prim::Constant[value=1]()
%4 : Float(*, *) = prim::FusedConcat[dim=1](%7, %17, %27, %37) %16 : int = prim::Constant[value=1]()
return (%4); %17 : Float(*, *) = aten::add(%7, %14, %16)
} %18 : Float(*, *) = aten::mul(%17, %1)
with prim::FusionGroup_1 = graph(%5 : Float(*) %19 : Float(*, *) = aten::mul(%5, %6)
%8 : Float(*) %20 : int = prim::Constant[value=1]()
%10 : Float(*, *) %21 : Float(*, *) = aten::mul(%17, %20)
%12 : Float(*, *) %22 : Float(*, *) = aten::mul(%21, %2)
%13 : Float(*, *) %23 : Float(*, *) = aten::mul(%21, %0)
%20 : Float(*, *) %24 : Float(*, *) = aten::mul(%17, %4)
%22 : Float(*)) { %25 : Float(*, *) = aten::neg(%3)
%30 : int = prim::Constant[value=1]()
%29 : int = prim::Constant[value=1]()
%28 : int = prim::Constant[value=1]()
%26 : int = prim::Constant[value=1]() %26 : int = prim::Constant[value=1]()
%27 : Float(*, *) = aten::mul(%12, %26) %27 : Float(*, *) = aten::add(%25, %26, %26)
%25 : Float(*, *) = aten::mul(%27, %13) %28 : Float(*, *) = aten::mul(%19, %3)
%24 : Float(*, *) = aten::mul(%27, %10) %29 : Float(*, *) = aten::mul(%28, %27)
%23 : Float(*, *) = aten::mul(%27, %22) %30 : Float(*, *) = aten::mul(%2, %2)
%21 : Float(*, *) = aten::mul(%12, %20) %31 : Float(*, *) = aten::neg(%30)
%19 : int = prim::Constant[value=1]() %32 : int = prim::Constant[value=1]()
%17 : int = prim::Constant[value=1]() %33 : Float(*, *) = aten::add(%31, %32, %32)
%18 : Float(*, *) = aten::add(%23, %21, %17) %34 : Float(*, *) = aten::mul(%23, %33)
%14 : Float(*, *) = aten::mul(%12, %13) %35 : Float(*, *) = aten::neg(%1)
%11 : Float(*, *) = aten::mul(%14, %10) %36 : int = prim::Constant[value=1]()
%9 : Float(*, *) = aten::mul(%27, %8) %37 : Float(*, *) = aten::add(%35, %36, %36)
%6 : Float(*, *) = aten::mul(%14, %5) %38 : Float(*, *) = aten::mul(%24, %1)
%2 : int = prim::Constant[value=1]() %39 : Float(*, *) = aten::mul(%38, %37)
%3 : Float(*, *) = aten::add(%9, %6, %2) %40 : Float(*, *) = aten::neg(%0)
return (%3, %11, %18, %24, %25, %27); %41 : int = prim::Constant[value=1]()
%42 : Float(*, *) = aten::add(%40, %41, %41)
%43 : Float(*, *) = aten::mul(%22, %0)
%44 : Float(*, *) = aten::mul(%43, %42)
%45 : Float(*, *) = prim::FusedConcat[dim=1](%44, %39, %34, %29)
return (%45, %18);
}
with prim::FusionGroup_1 = graph(%0 : Float(*)
%1 : Float(*)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*, *)
%5 : Float(*, *)
%6 : Float(*)) {
%7 : int = prim::Constant[value=1]()
%8 : int = prim::Constant[value=1]()
%9 : int = prim::Constant[value=1]()
%10 : int = prim::Constant[value=1]()
%11 : Float(*, *) = aten::mul(%3, %10)
%12 : Float(*, *) = aten::mul(%11, %4)
%13 : Float(*, *) = aten::mul(%11, %2)
%14 : Float(*, *) = aten::mul(%11, %6)
%15 : Float(*, *) = aten::mul(%3, %5)
%16 : int = prim::Constant[value=1]()
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%14, %15, %17)
%19 : Float(*, *) = aten::mul(%3, %4)
%20 : Float(*, *) = aten::mul(%19, %2)
%21 : Float(*, *) = aten::mul(%11, %1)
%22 : Float(*, *) = aten::mul(%19, %0)
%23 : int = prim::Constant[value=1]()
%24 : Float(*, *) = aten::add(%21, %22, %23)
return (%24, %20, %18, %13, %12, %11);
} }


@ -1,60 +1,73 @@
graph(%x.1 : Float(*, *) graph(%x : Float(*, *)
%hx.1 : Float(*, *) %hx : Float(*, *)
%cx.1 : Float(*, *) %cx : Float(*, *)
%w_ih : Float(*, *) %w_ih : Float(*, *)
%w_hh : Float(*, *) %w_hh : Float(*, *)
%alpha.1 : Float(*) %alpha : Float(*)
%beta_i.1 : Float(*) %beta_i : Float(*)
%beta_h.1 : Float(*) %beta_h : Float(*)
%bias : Float(*)) { %bias : Float(*)) {
%9 : Float(*, *) = aten::t(%w_ih) %9 : Float(*, *), %10 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %x, %w_hh, %hx, %alpha, %beta_i, %beta_h, %bias, %cx)
%Wx.1 : Float(*, *) = aten::mm(%x.1, %9) return (%10, %9);
%11 : Float(*, *) = aten::t(%w_hh)
%Uz.1 : Float(*, *) = aten::mm(%hx.1, %11)
%13 : Float(*, *), %14 : Float(*, *) = prim::FusionGroup_0[device=0](%beta_h.1, %Uz.1, %beta_i.1, %Wx.1, %alpha.1)
%15 : Dynamic[] = prim::ListConstruct(%13, %bias)
%16 : Dynamic[] = aten::broadcast_tensors(%15)
%17 : Dynamic, %18 : Dynamic = prim::ListUnpack(%16)
%hy : Float(*, *), %20 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1[device=0](%cx.1, %17, %18)
return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %20);
} }
with prim::FusionGroup_0 = graph(%4 : Float(*) with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
%5 : Float(*, *) %1 : Float(*, *)
%11 : Float(*) %2 : Float(*, *)
%12 : Float(*, *) %3 : Float(*, *)
%16 : Float(*)) { %4 : Float(*)
%17 : Float(*, *) = aten::mul(%16, %12) %5 : Float(*)
%15 : Float(*, *) = aten::mul(%17, %5) %6 : Float(*)
%13 : Float(*, *) = aten::mul(%11, %12) %7 : Float(*)
%9 : int = prim::Constant[value=1]() %8 : Float(*, *)) {
%10 : Float(*, *) = aten::add(%15, %13, %9) %9 : Float(*, *) = aten::t(%0)
%6 : Float(*, *) = aten::mul(%4, %5) %Wx.1 : Float(*, *) = aten::mm(%1, %9)
%2 : int = prim::Constant[value=1]() %11 : Float(*, *) = aten::t(%2)
%3 : Float(*, *) = aten::add(%10, %6, %2) %Uz.1 : Float(*, *) = aten::mm(%3, %11)
return (%3, %17); %13 : int = prim::Constant[value=1]()
%14 : Float(*, *), %15 : Float(*, *) = prim::FusionGroup_0[device=0](%6, %Uz.1, %5, %Wx.1, %4)
%16 : Dynamic[] = prim::ListConstruct(%14, %7)
%17 : Dynamic[] = aten::broadcast_tensors(%16)
%18 : Dynamic, %19 : Dynamic = prim::ListUnpack(%17)
%hy : Float(*, *), %21 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1[device=0](%8, %18, %19)
return (%cy, %hy, %9, %Wx.1, %11, %Uz.1, %15, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %21);
} }
with prim::FusionGroup_1 = graph(%13 : Float(*, *) with prim::FusionGroup_0 = graph(%0 : Float(*)
%39 : Dynamic %1 : Float(*, *)
%44 : Dynamic) { %2 : Float(*)
%45 : Dynamic, %46 : Dynamic, %47 : Dynamic, %48 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%44) %3 : Float(*, *)
%40 : Dynamic, %41 : Dynamic, %42 : Dynamic, %43 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%39) %4 : Float(*)) {
%37 : int = prim::Constant[value=1]() %5 : Float(*, *) = aten::mul(%4, %3)
%38 : Float(*, *) = aten::add(%40, %45, %37) %6 : Float(*, *) = aten::mul(%5, %1)
%33 : int = prim::Constant[value=1]() %7 : Float(*, *) = aten::mul(%2, %3)
%34 : Float(*, *) = aten::add(%41, %46, %33) %8 : int = prim::Constant[value=1]()
%29 : int = prim::Constant[value=1]() %9 : Float(*, *) = aten::add(%6, %7, %8)
%30 : Float(*, *) = aten::add(%42, %47, %29) %10 : Float(*, *) = aten::mul(%0, %1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%9, %10, %11)
return (%12, %5);
}
with prim::FusionGroup_1 = graph(%0 : Float(*, *)
%1 : Dynamic
%2 : Dynamic) {
%3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%7, %3, %11)
%13 : int = prim::Constant[value=1]()
%14 : Float(*, *) = aten::add(%8, %4, %13)
%15 : int = prim::Constant[value=1]()
%16 : Float(*, *) = aten::add(%9, %5, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%10, %6, %17)
%ingate.1 : Float(*, *) = aten::sigmoid(%12)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
%cellgate.1 : Float(*, *) = aten::tanh(%16)
%outgate.1 : Float(*, *) = aten::sigmoid(%18)
%23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
%24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%25 : int = prim::Constant[value=1]() %25 : int = prim::Constant[value=1]()
%26 : Float(*, *) = aten::add(%43, %48, %25) %cy : Float(*, *) = aten::add(%23, %24, %25)
%ingate.1 : Float(*, *) = aten::sigmoid(%38) %27 : Float(*, *) = aten::tanh(%cy)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%34) %hy : Float(*, *) = aten::mul(%outgate.1, %27)
%cellgate.1 : Float(*, *) = aten::tanh(%30) return (%hy, %27, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
%outgate.1 : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%forgetgate.1, %13)
%11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%7 : int = prim::Constant[value=1]()
%cy : Float(*, *) = aten::add(%14, %11, %7)
%4 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %4)
return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
} }


@ -4,8 +4,8 @@ graph(%x : Float(*, *)) {
} }
with prim::FusionGroup_0 = graph(%0 : Float(*, *)) { with prim::FusionGroup_0 = graph(%0 : Float(*, *)) {
%z : float = prim::Constant[value=3]() %z : float = prim::Constant[value=3]()
%4 : int = prim::Constant[value=1]() %2 : int = prim::Constant[value=1]()
%y : Float(*, *) = aten::add(%0, %z, %4) %y : Float(*, *) = aten::add(%0, %z, %2)
%2 : Float(*, *) = aten::mul(%0, %y) %4 : Float(*, *) = aten::mul(%0, %y)
return (%2); return (%4);
} }


@ -152,12 +152,19 @@ def get_fn(file_name, script_path):
def get_execution_plan(graph_executor_state): def get_execution_plan(graph_executor_state):
execution_plans = graph_executor_state.execution_plans.values() execution_plans = list(graph_executor_state.execution_plans.values())
num_plans = len(execution_plans) num_plans = len(execution_plans)
if num_plans != 1: if num_plans != 1:
raise RuntimeError('This test assumes this GraphExecutor should ' raise RuntimeError('This test assumes this GraphExecutor should '
'only have one execution plan, got: {}'.format(num_plans)) 'only have one execution plan, got: {}'.format(num_plans))
return list(execution_plans)[0] return execution_plans[0]
def get_grad_executor(plan_state):
if len(list(plan_state.graph.nodes())) != 1:
raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
grad_executors = list(plan_state.code.grad_executors())
return grad_executors[0]
def backward_graph(script_module): def backward_graph(script_module):
@ -165,10 +172,8 @@ def backward_graph(script_module):
raise RuntimeError('Expected ScriptModule') raise RuntimeError('Expected ScriptModule')
ge_state = script_module.get_debug_state() ge_state = script_module.get_debug_state()
fwd_plan = get_execution_plan(ge_state) fwd_plan = get_execution_plan(ge_state)
if fwd_plan.grad_executor is None: grad_executor = get_grad_executor(fwd_plan)
raise RuntimeError('Error: tried to get grad_executor of function ' bwd_plan = get_execution_plan(grad_executor.get_debug_state())
'that hasn\'t run backward yet.')
bwd_plan = get_execution_plan(fwd_plan.grad_executor)
# Running JIT passes requires that we own the graph (with a shared_ptr). # Running JIT passes requires that we own the graph (with a shared_ptr).
# The debug state struct does not own its graph so we make a copy of it. # The debug state struct does not own its graph so we make a copy of it.
return bwd_plan.graph.copy() return bwd_plan.graph.copy()
@ -2807,7 +2812,6 @@ a")
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm @skipIfRocm
@unittest.skip("Temporarily broken")
def test_lstm_fusion_cuda(self): def test_lstm_fusion_cuda(self):
inputs = get_lstm_inputs('cuda', training=True) inputs = get_lstm_inputs('cuda', training=True)
module = self.checkScript(LSTMCellS, inputs) module = self.checkScript(LSTMCellS, inputs)
@ -2820,7 +2824,6 @@ a")
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm @skipIfRocm
@unittest.skip("Temporarily broken")
def test_milstm_fusion_cuda(self): def test_milstm_fusion_cuda(self):
inputs = get_milstm_inputs('cuda', training=True) inputs = get_milstm_inputs('cuda', training=True)
module = self.checkScript(MiLSTMCell, inputs) module = self.checkScript(MiLSTMCell, inputs)
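
A hedged sketch of how the updated helpers above might be used (module name and shapes are hypothetical, and it assumes running alongside this test file so that backward_graph is in scope): backward_graph now reaches the backward through code.grad_executors() on the forward execution plan rather than a grad_executor field on the plan, so the forward body must collapse into a single differentiable region and the backward must have run once before inspection.

import torch

class MyCell(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x, y):
        return torch.sigmoid(x * y + x)

m = MyCell()
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
out = m(a, b)            # records the forward execution plan
out.sum().backward()     # lets the backward GraphExecutor record its own plan
bwd = backward_graph(m)  # helper from the diff above: forward plan ->
                         # grad_executors() -> backward plan -> graph copy
print(bwd)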


@ -35,6 +35,10 @@ struct InputMetadata {
return device_; return device_;
} }
at::Tensor zeros_like() const {
return at::zeros(shape_, at::TensorOptions(*type_, static_cast<int32_t>(device_)));
}
private: private:
const at::Type* type_ = nullptr; const at::Type* type_ = nullptr;
at::DimVector shape_; at::DimVector shape_;


@ -24,6 +24,7 @@ struct ArgumentInfo {
int device() const { int device() const {
return device_; return device_;
} }
// XXX: It is guaranteed that this will return false when called on non-tensor arguments
bool requires_grad() const { bool requires_grad() const {
return requires_grad_; return requires_grad_;
} }
@ -42,7 +43,8 @@ struct ArgumentInfo {
private: private:
unsigned is_tensor_ : 1; unsigned is_tensor_ : 1;
unsigned defined_ : 1; unsigned defined_ : 1;
unsigned requires_grad_ : 6; unsigned requires_grad_ : 1;
unsigned : 5;
unsigned dim_ : 8; unsigned dim_ : 8;
int device_ : 8; // NOTE: this needs to be signed because we use -1 to represent CPU int device_ : 8; // NOTE: this needs to be signed because we use -1 to represent CPU
unsigned type_ : 8; unsigned type_ : 8;
@ -59,6 +61,9 @@ struct ArgumentSpec {
int32_t num_inputs = inputs.size(); int32_t num_inputs = inputs.size();
for (int32_t i = 0; i < num_inputs; ++i) { for (int32_t i = 0; i < num_inputs; ++i) {
auto & arg = args[i]; auto & arg = args[i];
// Initialize all fields to 0. This is convenient, because e.g.
// requires_grad() can be checked even on tensors.
std::memset(&arg, 0, sizeof(ArgumentInfo));
arg.is_tensor_ = static_cast<unsigned>(inputs[i].isTensor()); arg.is_tensor_ = static_cast<unsigned>(inputs[i].isTensor());
if (arg.is_tensor_) { if (arg.is_tensor_) {
at::Tensor t = inputs[i].toTensor(); at::Tensor t = inputs[i].toTensor();


@ -176,6 +176,12 @@ struct Attributes {
#undef CREATE_ACCESSOR #undef CREATE_ACCESSOR
// Our Graphs are not very const-correct, so we need to allow returning
// non-const references too
GraphAttr::ValueType& g(Symbol name) {
return get<GraphAttr>(name);
}
// does not use CREATE_ACCESSOR because we need additional asserts // does not use CREATE_ACCESSOR because we need additional asserts
Derived* t_(Symbol name, TensorAttr::ConstructorType v) { Derived* t_(Symbol name, TensorAttr::ConstructorType v) {
JIT_ASSERT(!v.defined() || !v.is_variable()); JIT_ASSERT(!v.defined() || !v.is_variable());


@ -35,6 +35,7 @@ bool isDifferentiable(Node * n) {
"aten::neg(Tensor self) -> Tensor", "aten::neg(Tensor self) -> Tensor",
"aten::type_as(Tensor self, Tensor other) -> Tensor", "aten::type_as(Tensor self, Tensor other) -> Tensor",
"aten::unsqueeze(Tensor self, int dim) -> Tensor", "aten::unsqueeze(Tensor self, int dim) -> Tensor",
"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
"aten::mm(Tensor self, Tensor mat2) -> Tensor", "aten::mm(Tensor self, Tensor mat2) -> Tensor",
"aten::lt(Tensor self, Tensor other) -> Tensor", "aten::lt(Tensor self, Tensor other) -> Tensor",
"aten::le(Tensor self, Tensor other) -> Tensor", "aten::le(Tensor self, Tensor other) -> Tensor",
@ -97,6 +98,9 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
return {grads.at(0), grads.at(0) * node->namedInput(attr::alpha), nullptr}; return {grads.at(0), grads.at(0) * node->namedInput(attr::alpha), nullptr};
} else if (node->kind() == prim::AutogradAdd) {
return {grads.at(0), grads.at(0)};
} else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr}; return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr};
@ -136,29 +140,14 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
} else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) { } else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr}; return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};
} else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
return {grads.at(0) * node->namedInput(attr::beta),
grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
nullptr, nullptr};
} else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
SymbolicVariable dmat1, dmat2; return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))};
if (auto type = inputs.at(0).value()->type()->cast<CompleteTensorType>()) {
auto sizes = type->sizes(), strides = type->strides();
if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) {
dmat1 = inputs.at(1).mm(grads.at(0).t()).t();
} else {
dmat1 = grads.at(0).mm(inputs.at(1).t());
}
} else {
dmat1 = grads.at(0).mm(inputs.at(1).t());
}
if (auto type = inputs.at(1).value()->type()->cast<CompleteTensorType>()) {
auto sizes = type->sizes(), strides = type->strides();
if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) {
dmat2 = grads.at(0).t().mm(inputs.at(0)).t();
} else {
dmat2 = inputs.at(0).t().mm(grads.at(0));
}
} else {
dmat2 = inputs.at(0).t().mm(grads.at(0));
}
return {dmat1, dmat2};
} else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) { } else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) {
const auto& input_sizes = inputs.at(0).sizes(); const auto& input_sizes = inputs.at(0).sizes();
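
For reference (not part of the commit): the simplified mm/addmm gradients in the hunk above are the standard matrix-product derivatives. With $Y = \beta C + \alpha A B$ (aten::addmm, where self $= C$, mat1 $= A$, mat2 $= B$) and incoming gradient $G = \partial\ell/\partial Y$:
$$\frac{\partial\ell}{\partial C} = \beta\,G, \qquad \frac{\partial\ell}{\partial A} = \alpha\,G B^{\top}, \qquad \frac{\partial\ell}{\partial B} = \alpha\,A^{\top} G.$$
For aten::mm the same mat1/mat2 formulas apply with $\alpha = 1$ and no $C$ term, i.e. grads.at(0).mm(inputs.at(1).t()) and inputs.at(0).t().mm(grads.at(0)). The stride-dependent branches removed above only chose between algebraically equivalent transposed forms of these products, e.g. $(B\,G^{\top})^{\top} = G\,B^{\top}$.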
@ -364,7 +353,9 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
JIT_ASSERT(grad_inputs.size() == node->inputs().size()); JIT_ASSERT(grad_inputs.size() == node->inputs().size());
for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) { for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
if (!requires_grad(inputs[i])) continue; if (!requires_grad(inputs[i])) continue;
JIT_ASSERT(grad_inputs[i]); // NB: Not returning a gradient w.r.t. a value that requires grad is normal if the
// input is non-differentiable. This happens e.g. in the aten::type_as case.
if (!grad_inputs[i]) continue;
set_grad(inputs[i], grad_inputs[i]); set_grad(inputs[i], grad_inputs[i]);
} }
} }
@ -374,6 +365,11 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
Value * input = inputs[i]; Value * input = inputs[i];
if (!requires_grad(input)) if (!requires_grad(input))
continue; continue;
// NB: Not having a gradient defined w.r.t. an input to the graph which requires grad
// can happen and is not an error. It might have been used only in non-differentiable
// contexts (e.g. as second input to aten::type_as). In that case we simply ignore it
// as an output, because it won't ever produce any meaningful values.
if (grad_map.count(input) == 0) continue;
reverse_block->registerOutput(get_grad(input)); reverse_block->registerOutput(get_grad(input));
grad_desc.df_output_vjps.push_back(i); grad_desc.df_output_vjps.push_back(i);
} }
@ -557,7 +553,7 @@ Gradient differentiate(std::shared_ptr<Graph>& graph, const std::vector<bool>& r
// addReverseInline has to call gradientForNode if *any* of the outputs // addReverseInline has to call gradientForNode if *any* of the outputs
// require grad, but it will emit vjps for *all* outputs. Use DCE to remove // require grad, but it will emit vjps for *all* outputs. Use DCE to remove
// unnecessary nodes. // unnecessary nodes.
EliminateDeadCode(grad_desc.f); EliminateDeadCode(rev_info.reverse_block);
// Fills in f, df, f_real_outputs, df_input_captures, // Fills in f, df, f_real_outputs, df_input_captures,
// modifies df_input_vjps (new vjps are added for temporaries) // modifies df_input_vjps (new vjps are added for temporaries)
lambdaLiftReverse(grad_desc, rev_info); lambdaLiftReverse(grad_desc, rev_info);


@ -24,6 +24,7 @@
#include "torch/csrc/jit/passes/constant_propagation.h" #include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/symbolic_variable.h" #include "torch/csrc/jit/symbolic_variable.h"
#include "torch/csrc/jit/ivalue.h" #include "torch/csrc/jit/ivalue.h"
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/autograd/edge.h" #include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h" #include "torch/csrc/autograd/function.h"
@ -45,14 +46,34 @@ using tensor_list = std::vector<at::Tensor>;
using Variable = autograd::Variable; using Variable = autograd::Variable;
using autograd::variable_list; using autograd::variable_list;
// this type is in ExecutionPlan to run its Gradient if it is struct ExecutionPlan {
// specified. It has a list of inputs captured by ExecutionPlan that ExecutionPlan() = default;
// it concats with inputs to form the full set of inputs to graph. ExecutionPlan(std::shared_ptr<Graph> graph)
// see struct Gradient for a description of how the derivative graph : code(graph)
// is constructed and what variables are captured. , graph(std::move(graph)) {}
struct ExecutionPlanAutogradFunction : public autograd::Function {
ExecutionPlanAutogradFunction(GraphExecutor graph, size_t capture_size) void run(Stack& stack) const {
: graph(std::move(graph)) { return InterpreterState(code).runOneStage(stack);
}
operator bool() const {
return static_cast<bool>(graph);
}
ExecutionPlanState getDebugState() {
ExecutionPlanState state;
state.code = &code;
state.graph = graph.get();
return state;
}
Code code;
std::shared_ptr<Graph> graph;
};
struct DifferentiableGraphBackward : public autograd::Function {
DifferentiableGraphBackward(GraphExecutor executor, size_t capture_size)
: executor(std::move(executor)) {
is_var_capture.reserve(capture_size); is_var_capture.reserve(capture_size);
var_captures.reserve(capture_size); var_captures.reserve(capture_size);
ivalue_captures.reserve(capture_size); ivalue_captures.reserve(capture_size);
@ -74,10 +95,28 @@ struct ExecutionPlanAutogradFunction : public autograd::Function {
++ivalue_capture_it; ++ivalue_capture_it;
} }
} }
graph.run(stack);
return fmap(stack, [](IValue & val) { executor.run(stack);
return autograd::Variable(std::move(val).toTensor()); JIT_ASSERT(stack.size() == num_outputs());
});
variable_list outputs;
outputs.reserve(num_outputs());
for (size_t i = 0; i < num_outputs(); ++i) {
if (should_compute_output(i)) {
auto output = std::move(stack[i]).toTensor();
const auto & edge = next_edge(i);
if (output.defined()) {
outputs.push_back(std::move(output));
} else if (edge.is_valid()) {
outputs.push_back(edge.function->input_metadata(edge.input_nr).zeros_like());
} else {
outputs.emplace_back();
}
} else {
outputs.emplace_back();
}
}
return outputs;
} }
void capture(const IValue & val, bool is_output) { void capture(const IValue & val, bool is_output) {
@ -91,7 +130,7 @@ struct ExecutionPlanAutogradFunction : public autograd::Function {
} }
private: private:
friend struct ExecutionPlan; friend struct ExecutionPlan;
GraphExecutor graph; GraphExecutor executor;
// INVARIANT: is_var_capture.size() == var_captures.size() + ivalue_captures.size() // INVARIANT: is_var_capture.size() == var_captures.size() + ivalue_captures.size()
std::vector<bool> is_var_capture; std::vector<bool> is_var_capture;
@ -104,74 +143,17 @@ private:
// This will unwrap Variables, run the plan, and re-wrap them. // This will unwrap Variables, run the plan, and re-wrap them.
// It can optionally also have a gradient which is hooked up // It can optionally also have a gradient which is hooked up
// to the output Variables if present. // to the output Variables if present.
struct ExecutionPlan { struct DifferentiableGraphOp {
ExecutionPlan(std::shared_ptr<Graph>& graph) DifferentiableGraphOp(Gradient grad)
: f(graph), : f(grad.f),
graph(graph),
num_inputs(graph->inputs().size()),
num_outputs(graph->outputs().size()) {}
ExecutionPlan(std::shared_ptr<Graph>& graph, Gradient grad)
: f(graph),
graph(graph),
grad(std::move(grad)), grad(std::move(grad)),
grad_executor(this->grad.df), grad_executor(this->grad.df),
num_inputs(graph->inputs().size()), num_inputs(this->grad.f->inputs().size()),
num_outputs(graph->outputs().size()) {} num_outputs(this->grad.f->outputs().size()) {}
void run(Stack & stack) const {
if (grad) {
return runWithGrad(stack);
}
InterpreterState(f).runOneStage(stack);
}
std::shared_ptr<Graph> get_graph() const {
return graph;
}
ExecutionPlanState getDebugState() {
ExecutionPlanState state;
state.f = &f;
state.graph = graph.get();
if (grad) {
state.grad = &grad;
state.grad_executor = std::unique_ptr<GraphExecutorState>(
new GraphExecutorState(grad_executor.getDebugState()));
} else {
state.grad = nullptr;
state.grad_executor.reset();
}
return state;
}
private:
void detachVariables(Stack & stack) const {
// It would be nice to use an ArrayRef here, but unfortunately those can only
// return const references, so we need to do a bunch of indexing ourselves.
const int64_t stack_size = stack.size();
const int64_t stack_offset = stack_size - num_inputs;
for (int64_t i = stack_offset; i < stack_size; ++i) {
auto & v = stack[i];
if (!v.isTensor()) continue;
auto t = std::move(v).toTensor();
v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() : std::move(t)};
}
}
// Capture (save) inputs that would be required to subsequently run backwards
void captureInputs(ExecutionPlanAutogradFunction & grad_fn, at::ArrayRef<IValue> inputs) const {
for (size_t offset : grad.df_input_captured_inputs) {
grad_fn.capture(inputs[offset], /*is_output*/false);
}
}
void captureOutputs(ExecutionPlanAutogradFunction & grad_fn, at::ArrayRef<IValue> outputs) const {
for (size_t offset : grad.df_input_captured_outputs) {
grad_fn.capture(outputs[offset], /*is_output*/true);
}
}
// XXX: keep in mind that stack can be larger than the inputs we need! // XXX: keep in mind that stack can be larger than the inputs we need!
void runWithGrad(Stack & stack) const { int operator()(Stack & stack) const {
auto grad_fn = std::make_shared<ExecutionPlanAutogradFunction>(grad_executor, auto grad_fn = std::make_shared<DifferentiableGraphBackward>(grad_executor,
grad.df_input_captured_inputs.size() + grad.df_input_captured_outputs.size()); grad.df_input_captured_inputs.size() + grad.df_input_captured_outputs.size());
{ {
@ -179,7 +161,7 @@ private:
// hook up the outputs of df to the gradient functions of the inputs that require gradients // hook up the outputs of df to the gradient functions of the inputs that require gradients
for(auto idx : grad.df_output_vjps) { for(auto idx : grad.df_output_vjps) {
auto v = Variable(inputs[idx].toTensor()); auto v = Variable(inputs[idx].toTensor());
grad_fn->add_next_edge(v.gradient_edge()); grad_fn->add_next_edge(v.defined() ? v.gradient_edge() : autograd::Edge{});
} }
captureInputs(*grad_fn, inputs); captureInputs(*grad_fn, inputs);
} }
@ -201,8 +183,15 @@ private:
// reallocate variables that were already created in wrapTensors. We // reallocate variables that were already created in wrapTensors. We
// should add an API for this. // should add an API for this.
Variable output = outputs[idx].toTensor(); Variable output = outputs[idx].toTensor();
autograd::create_gradient_edge(output, grad_fn); // NB: since our requires_grad setting is only a heuristic we might end up
output.set_requires_grad(true); // wanting to differentiate through integral tensors, which is generally a
// hard error in autograd.
if (at::isFloatingType(output.type().scalarType())) {
autograd::create_gradient_edge(output, grad_fn);
output.set_requires_grad(true);
} else {
grad_fn->add_input_metadata(autograd::Function::undefined_input{});
}
} }
captureOutputs(*grad_fn, outputs); captureOutputs(*grad_fn, outputs);
// drop the temporary outputs so that we return the same number of // drop the temporary outputs so that we return the same number of
@ -210,22 +199,89 @@ private:
const size_t num_temporary_outputs = num_outputs - grad.f_real_outputs; const size_t num_temporary_outputs = num_outputs - grad.f_real_outputs;
stack.erase(stack.end() - num_temporary_outputs, stack.end()); stack.erase(stack.end() - num_temporary_outputs, stack.end());
} }
return 0;
}
private:
friend GraphExecutor* detail::getGradExecutor(Operation& op);
void detachVariables(Stack & stack) const {
// It would be nice to use an ArrayRef here, but unfortunately those can only
// return const references, so we need to do a bunch of indexing ourselves.
const int64_t stack_size = stack.size();
const int64_t stack_offset = stack_size - num_inputs;
for (int64_t i = stack_offset; i < stack_size; ++i) {
auto & v = stack[i];
if (!v.isTensor()) continue;
auto t = std::move(v).toTensor();
v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() : std::move(t)};
}
}
// Capture (save) inputs that would be required to subsequently run backwards
void captureInputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef<IValue> inputs) const {
for (size_t offset : grad.df_input_captured_inputs) {
grad_fn.capture(inputs[offset], /*is_output*/false);
}
}
void captureOutputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef<IValue> outputs) const {
for (size_t offset : grad.df_input_captured_outputs) {
grad_fn.capture(outputs[offset], /*is_output*/true);
}
} }
Code f; Code f;
// optimized graph for debugging and testing Gradient grad;
std::shared_ptr<Graph> graph;
// description of gradient as a graph
Gradient grad; // if(grad) is false when this is unused
// executor for df, including code caches
GraphExecutor grad_executor; GraphExecutor grad_executor;
const size_t num_inputs; const size_t num_inputs;
const size_t num_outputs; const size_t num_outputs;
}; };
void packGradient(Gradient gradient, Node *dnode) {
JIT_ASSERT(dnode->kind() == prim::DifferentiableGraph);
dnode->g_(attr::Subgraph, gradient.f)
->g_(attr::ReverseSubgraph, gradient.df)
->i_(attr::f_real_outputs, gradient.f_real_outputs)
->is_(attr::df_input_vjps, fmap<int64_t>(gradient.df_input_vjps))
->is_(attr::df_input_captured_inputs, fmap<int64_t>(gradient.df_input_captured_inputs))
->is_(attr::df_input_captured_outputs, fmap<int64_t>(gradient.df_input_captured_outputs))
->is_(attr::df_output_vjps, fmap<int64_t>(gradient.df_output_vjps));
}
Gradient getGradient(Node *n) {
JIT_ASSERT(n->kind() == prim::DifferentiableGraph);
Gradient grad;
grad.f = n->g(attr::Subgraph);
grad.df = n->g(attr::ReverseSubgraph);
grad.f_real_outputs = n->i(attr::f_real_outputs);
grad.df_input_vjps = fmap<size_t>(n->is(attr::df_input_vjps));
grad.df_input_captured_inputs = fmap<size_t>(n->is(attr::df_input_captured_inputs));
grad.df_input_captured_outputs = fmap<size_t>(n->is(attr::df_input_captured_outputs));
grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
return grad;
}
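A minimal sketch of the round trip (illustrative only; `graph` and `gradient` are assumed to already exist, with `gradient` produced by differentiate()): packGradient writes the Gradient onto the node's attributes, and getGradient reads the same data back when the operator is built.
Node* dnode = graph->create(prim::DifferentiableGraph, /*num_outputs=*/gradient.f->outputs().size());
packGradient(gradient, dnode);
Gradient recovered = getGradient(dnode);
JIT_ASSERT(recovered.f == gradient.f && recovered.df == gradient.df);
JIT_ASSERT(recovered.f_real_outputs == gradient.f_real_outputs);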
} // anonymous namespace } // anonymous namespace
RegisterOperators reg_graph_executor_ops({
Operator(
prim::DifferentiableGraph,
[](Node *n) -> Operation {
return DifferentiableGraphOp(getGradient(n));
})
});
namespace detail {
GraphExecutor* getGradExecutor(Operation& op) {
if (auto diff_op = op.target<DifferentiableGraphOp>()) {
return &diff_op->grad_executor;
}
return nullptr;
}
} // namespace detail
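A short sketch of how these two pieces fit together (illustrative; assumes `dnode` is a prim::DifferentiableGraph node that packGradient has already annotated): the operator lookup yields a DifferentiableGraphOp, and getGradExecutor unwraps the backward executor from it so the interpreter can expose it for debugging.
Operation op = getOperation(dnode);                      // same lookup the interpreter performs
GraphExecutor* backward = detail::getGradExecutor(op);   // non-null only for DifferentiableGraphOp
JIT_ASSERT(backward != nullptr);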
// a Graph can be created via tracing, or via a language-based frontend // a Graph can be created via tracing, or via a language-based frontend
// GraphExecutor runs it. It can run the same graph on many different sizes // GraphExecutor runs it. It can run the same graph on many different sizes
// and different requires_grad states, and handles specializations for each situation. // and different requires_grad states, and handles specializations for each situation.
@ -233,68 +289,49 @@ private:
// tracing concerns separated. // tracing concerns separated.
struct GraphExecutorImpl { struct GraphExecutorImpl {
GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize, bool symbolically_differentiable) static std::shared_ptr<Graph> prepareGraph(std::shared_ptr<Graph> graph) {
: graph(std::move(graph)) auto copy = graph->copy();
, optimize(optimize) EraseShapeInformation(*copy);
, num_inputs(this->graph->inputs().size()) return copy;
, num_outputs(this->graph->outputs().size()) }
, symbolically_differentiable(symbolically_differentiable)
, may_introduce_gradient(calcMayIntroduceGradient(this->graph->block())) {}
GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize) GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
: GraphExecutorImpl(graph, optimize, isDifferentiable(*graph)) {} : graph(prepareGraph(graph))
, optimize(optimize)
, num_inputs(this->graph->inputs().size())
, num_outputs(this->graph->outputs().size()) {}
// entry point where execution begins // entry point where execution begins
void run(Stack & stack) { void run(Stack & stack) {
if(stack.size() < num_inputs) { AT_CHECK(stack.size() >= num_inputs, "expected ", num_inputs, " inputs, but got only ", stack.size());
std::stringstream ss;
ss << "expected " << num_inputs << " inputs but got " << stack.size() << " inputs";
throw std::runtime_error(ss.str());
}
auto inputs = last(stack, num_inputs);
// the tracer has called a graph executor
// there is no need to optimize, but we do need to splice the graph of
// this excutor into the trace. Otherwise we might unroll control-flow
// operations.
if(tracer::isTracing()) { if(tracer::isTracing()) {
return runTraced(stack); return runTraced(stack);
} }
// this is the fallback pathway, when we cannot differentiate auto & execution_plan = optimize ? getOrCompile(stack) : getOrCompileFallback();
if(!optimize || (!symbolically_differentiable && needsGradient(inputs))) { return execution_plan.run(stack);
return runFallback(stack);
}
// either we can symbolically differentiate, or we do not need a gradient.
// go down the route where we treat the inputs as tensors
// and fully optimize
auto & implementation = getOrCompile(inputs);
return implementation.run(stack);
} }
std::shared_ptr<Graph> graphFor(const Stack& stack) const { std::shared_ptr<Graph> graphFor(const Stack& stack) const {
auto inputs = last(stack, num_inputs); auto inputs = last(stack, num_inputs);
ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs); ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs);
if (!optimize || (!symbolically_differentiable && needsGradient(inputs))) { if (!optimize) {
JIT_ASSERTM(autograd_fallback_graph, "No graph found for given inputs"); AT_CHECK(fallback, "No graph found for given inputs");
return autograd_fallback_graph; return fallback.graph;
} }
auto it = plan_cache.find(spec); auto it = plan_cache.find(spec);
JIT_ASSERTM(it != plan_cache.end(), "No graph found for given inputs"); AT_CHECK(it != plan_cache.end(), "No graph found for given inputs");
return it->second.get_graph(); return it->second.graph;
} }
GraphExecutorState getDebugState() { GraphExecutorState getDebugState() {
GraphExecutorState state; GraphExecutorState state;
state.graph = graph.get(); state.graph = graph.get();
if (autograd_fallback) { if (fallback) {
state.autograd_fallback = &autograd_fallback; state.fallback = fallback.getDebugState();
state.autograd_fallback_graph = autograd_fallback_graph.get();
} else {
state.autograd_fallback = nullptr;
state.autograd_fallback_graph = nullptr;
} }
for (auto & entry : plan_cache) { for (auto & entry : plan_cache) {
state.execution_plans.emplace(entry.first, entry.second.getDebugState()); state.execution_plans.emplace(entry.first, entry.second.getDebugState());
@ -305,6 +342,121 @@ struct GraphExecutorImpl {
private: private:
friend struct GraphExecutor; friend struct GraphExecutor;
const ExecutionPlan & getOrCompileFallback() {
std::lock_guard<std::mutex> lock(compile_mutex);
if(!fallback) {
auto graph_ = graph->copy();
runRequiredPasses(graph_);
fallback = ExecutionPlan(graph_);
}
return fallback;
}
const ExecutionPlan & getOrCompile(const Stack& stack) {
// outside lock guard, to minimize the time holding the lock on the fast path
// ArgumentSpec even computes its hashCode here.
ArgumentSpec spec(autograd::GradMode::is_enabled(), last(stack, num_inputs));
{
std::lock_guard<std::mutex> lock(compile_mutex);
auto it = plan_cache.find(spec);
if (it != plan_cache.end())
return it->second;
auto plan = compileSpec(spec);
auto r = plan_cache.emplace(std::move(spec), std::move(plan));
return r.first->second;
}
}
ExecutionPlan compileSpec(const ArgumentSpec & spec) {
auto opt_graph = graph->copy();
// Phase 1. Specialize to input definedness (this is very important for
// gradient graphs), and run required passes to bring the graph
// to an executable form.
specializeGrad(opt_graph, spec);
runRequiredPasses(opt_graph);
// Phase 2. Propagate detailed information about the spec through the
// graph (enabling more specializations in later passes). // graph (enabling more specializations in later passes).
PropagateInputShapes(*opt_graph, spec);
// Phase 3. Run differentiable optimizations (i.e. simple graph rewrites that
// we can still execute using autograd).
runOptimization(opt_graph, spec);
// Phase 4. If this graph will be differentiated, we need to slice out the
// symbolically differentiable subgraphs for further optimizations.
// Phase 5. Apply non-differentiable optimizations to the graphs we've found
// (or the whole graph if we know we won't need its derivative).
if (needsGradient(opt_graph, spec)) {
auto diff_nodes = CreateAutodiffSubgraphs(*opt_graph);
for (Node * dnode : diff_nodes) {
// XXX: we don't have requires_grad information on the intermediate values,
// so we conservatively assume it's always true (on tensor inputs).
auto diff_graph = std::move(dnode->g(attr::Subgraph));
auto requires_grads = fmap(diff_graph->inputs(), [](Value* v) {
// NB: only floating-point inputs can have requires_grad=True. If we
// don't have type information, we have to assume that it's true.
if (auto tensor_type = v->type()->cast<TensorType>()) {
return at::isFloatingType(tensor_type->scalarType());
}
return v->type()->isSubtypeOf(DynamicType::get());
});
Gradient gradient = differentiate(diff_graph, requires_grads);
runNondiffOptimization(gradient.f);
packGradient(gradient, dnode);
}
} else {
runNondiffOptimization(opt_graph);
}
return ExecutionPlan(opt_graph);
}
void specializeGrad(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
std::vector<bool> defined;
for (size_t i = 0; i < spec.size(); ++i)
defined.push_back(spec.at(i).defined());
specializeUndef(*graph, defined);
}
void runOptimization(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
EliminateDeadCode(graph);
EliminateCommonSubexpression(graph);
UnrollLoops(graph);
ConstantPropagation(graph);
PeepholeOptimize(graph);
CheckInplace(graph);
BatchMM(graph);
}
void runNondiffOptimization(std::shared_ptr<Graph>& graph) {
FuseGraph(graph);
}
static bool needsGradient(const std::shared_ptr<const Graph>& graph, const ArgumentSpec& spec) {
if (!autograd::GradMode::is_enabled())
return false;
if (mayIntroduceGradient(graph->block()))
return true;
for(size_t i = 0; i < spec.size(); ++i) {
if(spec.at(i).requires_grad())
return true;
}
return false;
}
static bool mayIntroduceGradient(const Block* b) {
for (const Node* n : b->nodes()) {
if (n->kind() == prim::PythonOp)
return true;
for (const Block* bb : n->blocks()) {
if (mayIntroduceGradient(bb))
return true;
}
}
return false;
}
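One consequence of gating on GradMode is worth spelling out with a sketch (illustrative; `executor` and `stack` are assumed to be a GraphExecutor and its input Stack): ArgumentSpec records GradMode::is_enabled(), so the same inputs map to different cache entries with and without grad mode, and the no-grad plan never slices out DifferentiableGraph regions.
Stack with_grad = stack;                 // run() consumes the inputs, so keep a copy
autograd::GradMode::set_enabled(false);
executor.run(stack);                     // needsGradient() is false; plain optimized plan
autograd::GradMode::set_enabled(true);
executor.run(with_grad);                 // separate plan; differentiable subgraphs may be created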
void runTraced(Stack & stack) { void runTraced(Stack & stack) {
auto state = tracer::getTracingState(); auto state = tracer::getTracingState();
auto inputs = last(stack, num_inputs); auto inputs = last(stack, num_inputs);
@ -313,25 +465,18 @@ private:
}); });
ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs); ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs);
runFallback(stack); // NB: we could just run the fallback in here and call it a day, but that would lose all
// the control flow information we have in the graph. Thus, we run the fallback to
// get the correct output values, but we will override the tracing states later.
getOrCompileFallback().run(stack);
auto all_dynamic = [](const at::ArrayRef<Value*> xs) {
for(Value* x : xs) {
if(x->type()->kind() != TypeKind::DynamicType)
return false;
}
return true;
};
// Traces always have types propagated through them, so we make sure to // Traces always have types propagated through them, so we make sure to
// also propagate types through the graph we are inserting here. // also propagate types through the graph we are inserting here.
// However, this->graph itself may already have been generated with // tracing and so we only do the type propagation if no concrete types have
// tracing and so we only do the type propagation if no concrete types have // been set.
// been set. // been set.
auto local_graph = this->graph; auto local_graph = this->graph->copy();
if(all_dynamic(local_graph->inputs()) && all_dynamic(local_graph->outputs())) { PropagateInputShapes(*local_graph, spec);
local_graph = this->graph->copy();
PropagateInputShapes(*local_graph, spec);
}
auto output_values = script::inlineCallTo(*state->graph, *local_graph, input_values); auto output_values = script::inlineCallTo(*state->graph, *local_graph, input_values);
auto outputs = last(stack, num_outputs); auto outputs = last(stack, num_outputs);
@ -343,147 +488,32 @@ private:
} }
} }
void runFallback(Stack & stack) { // The unoptimized starting graph. This field is effectively const, but we can't make it so
auto & fb = getOrCreateAutogradFallback(); // because Graph::copy() is not const (and making it const is not that easy at this point).
InterpreterState(fb).runOneStage(stack);
}
static bool calcMayIntroduceGradient(Block* b) {
for(Node* n : b->nodes()) {
if(n->kind() == prim::PythonOp)
return true;
for(Block* bb : n->blocks()) {
if(calcMayIntroduceGradient(bb))
return true;
}
}
return false;
}
bool needsGradient(at::ArrayRef<IValue> inputs) const {
if (!autograd::GradMode::is_enabled()) {
return false;
}
if (may_introduce_gradient)
return true;
for (const IValue & value : inputs) {
if (!value.isTensor()) continue;
auto t = value.toTensor();
if (t.defined() && autograd::as_variable_ref(t).requires_grad())
return true;
}
return false;
}
const Code & getOrCreateAutogradFallback() {
std::lock_guard<std::mutex> lock(compile_mutex);
if(autograd_fallback) {
return autograd_fallback;
}
auto graph_ = graph->copy();
runRequiredPasses(graph_);
if(optimize) {
runOptimization(graph_, /*graphMustSupportVariables=*/true);
if(!isDifferentiable(*graph_)) {
EraseShapeInformation(*graph_);
CreateAutodiffSubgraphs(*graph_);
}
}
autograd_fallback_graph = graph_;
autograd_fallback = Code(graph_);
return autograd_fallback;
}
const ExecutionPlan & getOrCompile(at::ArrayRef<IValue> inputs) {
// outside lock guard, to minimize the time holding the lock on the fast path
// ArgumentSpec even computes its hashCode here.
ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs);
{
std::lock_guard<std::mutex> lock(compile_mutex);
auto it = plan_cache.find(spec);
if(it != plan_cache.end())
return it->second;
auto plan = compileSpec(spec);
auto r = plan_cache.emplace(std::move(spec), std::move(plan));
return r.first->second;
}
}
bool argumentSpecRequiresGradient(const ArgumentSpec & spec) {
for(size_t i = 0; i < spec.size(); ++i) {
if(spec.at(i).requires_grad())
return true;
}
return false;
}
ExecutionPlan compileSpec(const ArgumentSpec & spec) {
auto graph_ = graph->copy();
specializeToSpec(graph_, spec);
if(!argumentSpecRequiresGradient(spec)) {
runOptimization(graph_, /*graphMustSupportVariables=*/false);
return ExecutionPlan(graph_);
}
JIT_ASSERT(symbolically_differentiable);
std::vector<bool> requires_grads;
requires_grads.reserve(spec.size());
for(size_t i = 0; i < spec.size(); i++)
requires_grads.push_back(spec.at(i).requires_grad());
Gradient gradient = differentiate(graph_, requires_grads);
graph_ = gradient.f;
runOptimization(graph_, /*graphMustSupportVariables=*/false);
return ExecutionPlan(graph_, std::move(gradient));
}
// the unoptimized starting graph
// this is never mutated
std::shared_ptr<Graph> graph; std::shared_ptr<Graph> graph;
// true - do everything we can to make this graph run fast // If false, we'll run the graph as we get it, without any optimizations. Useful
// false - do not modifiy the graph at all and just use the interpreter // for debugging.
// to run the graph. Useful for debugging correctness issues in the implementation
const bool optimize; const bool optimize;
const size_t num_inputs; const size_t num_inputs;
const size_t num_outputs; const size_t num_outputs;
// GraphExecutor optimizes more aggresively when we _know_ the graph will be // Populated only when optimize is false (and in that case plan_cache will be unused).
// symbolically differentiable. // The compiled version of graph.
bool symbolically_differentiable; ExecutionPlan fallback;
// some ops, including python operations, can intorduce requires_grad=True // Mapping from argument configurations to optimized versions of the graph that are
// variables even though no inputs to this graph are availiable, if // specialized to the spec.
// the graph includes those operators then needGradient must be true
// regardles of input state.
bool may_introduce_gradient;
// when this graph has some parts that are not symbolically_differentable,
// but some input does require a derivative, we create and use autograd_fallback,
// which wraps up the fully differentiable subgraphs, and then runs the outer
// graph through autograd.
// Since we can't optimize black box functions anyway, there is only one fallback path,
// and it must work on all sizes (so no optimizations that inspect sizes can run on it)
std::shared_ptr<Graph> autograd_fallback_graph;
Code autograd_fallback;
// optimizable code paths, used when we can differentiate or when no derivative is needed
// Spec describes input conditions, Plan describes how to execute them.
std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache; std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;
// GraphExecutor can be accessed from multiple thread so // GraphExecutors can be accessed from multiple threads, so this mutex needs to be
// anytime we are checking or updating the autograd_fallback or // held every time we access the fallback or plan_cache.
// plan_cache, we must hold the compile mutex.
// along the fast path (no compilation) code should
// hold this for as little time as possible.
std::mutex compile_mutex; std::mutex compile_mutex;
}; };
GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize) GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
: pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {} : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize, bool symbolically_differentiable)
: pImpl(new GraphExecutorImpl(std::move(graph), optimize, symbolically_differentiable)) {}
void GraphExecutor::run(Stack & inputs) { void GraphExecutor::run(Stack & inputs) {
return pImpl->run(inputs); return pImpl->run(inputs);
} }
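Typical C++-side usage after this change is just the two calls above: construct the executor once and feed it a Stack per call. A hedged sketch (assuming `graph` is a std::shared_ptr<Graph> taking two floating-point tensor inputs):
GraphExecutor executor(graph, /*optimize=*/true);
Stack stack;
stack.emplace_back(autograd::make_variable(at::ones({2, 3}), /*requires_grad=*/true));
stack.emplace_back(autograd::make_variable(at::ones({2, 3}), /*requires_grad=*/false));
executor.run(stack);                          // inputs are popped, outputs pushed in their place
at::Tensor result = stack.back().toTensor();  // last output of the graph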
@ -509,49 +539,7 @@ void runRequiredPasses(const std::shared_ptr<Graph>& g) {
// add valid expand nodes when the shapes are stable // add valid expand nodes when the shapes are stable
RemoveExpands(g); RemoveExpands(g);
CanonicalizeOps(g); CanonicalizeOps(g);
} EliminateDeadCode(g);
void specializeToSpec(const std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
// clean up GradOf and AutogradAdd nodes
// this must be first because later passes do not know what GradOfs are
std::vector<bool> defined;
for(size_t i = 0; i < spec.size(); ++i) {
defined.push_back(spec.at(i).defined());
}
specializeUndef(*graph, defined);
// required passes shared with autograd fallback
runRequiredPasses(graph);
// clean up dead constants from specialization
EliminateDeadCode(graph);
// calculate all input shapes
PropagateInputShapes(*graph, spec);
}
void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables) {
// these optimizations must run in the presence of variables
// and when shape information is not statically known.
EliminateDeadCode(graph);
CheckInplace(graph);
EliminateCommonSubexpression(graph);
if (!graphMustSupportVariables) {
// These optimizations can introduce operators like FusionGroup that
// do not work on variables
// They also may assume that concrete sizes/strides are availiable
UnrollLoops(graph);
ConstantPropagation(graph);
//TODO: create peephole optimizations that are safe to run
// when we are using variables, and when we do not know sizes.
PeepholeOptimize(graph);
// TODO: remove mandatory size checking in BatchMM, otherwise
// it works fine on variables.
BatchMM(graph);
FuseGraph(graph);
}
} }
}} }}

View File

@ -15,29 +15,20 @@ struct GraphExecutorState;
// They are only valid right after you call getDebugState() and should never // be used again once another GraphExecutor function is called.
// be used again once another GraphExecutor function is called. // be used again once another GraphExecutor function is called.
struct ExecutionPlanState { struct ExecutionPlanState {
Code* f; Code* code = nullptr;
Graph* graph; const Graph* graph = nullptr;
// Those two fields are optional
Gradient* grad;
std::shared_ptr<GraphExecutorState> grad_executor; // shared_ptr to break the cycle...
}; };
struct GraphExecutorState { struct GraphExecutorState {
Graph* graph; const Graph* graph;
ExecutionPlanState fallback; // XXX: members of this field are optional
std::unordered_map<ArgumentSpec, ExecutionPlanState> execution_plans; std::unordered_map<ArgumentSpec, ExecutionPlanState> execution_plans;
// Those two fields are optional
Code* autograd_fallback;
Graph* autograd_fallback_graph;
}; };
struct GraphExecutorImpl; struct GraphExecutorImpl;
struct TORCH_API GraphExecutor { struct TORCH_API GraphExecutor {
GraphExecutor() = default; GraphExecutor() = default;
GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true); GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
// note: if not specified, symbolically_differentiable is computed from the graph.
GraphExecutor(std::shared_ptr<Graph> graph, bool optimize, bool symbolically_differentiable);
void run(Stack & inputs); void run(Stack & inputs);
explicit operator bool() const { explicit operator bool() const {
return pImpl != nullptr; return pImpl != nullptr;
@ -53,15 +44,11 @@ private:
// regardless of whether sizes have been specialized or not. // regardless of whether sizes have been specialized or not.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g); TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
// specialize 'graph' to the types, sizes, and other properties described in spec namespace detail {
// this prepares the graph for execution, including running runRequiredPasses,
// but the execution only remains valid for tensors whose properties match spec GraphExecutor* getGradExecutor(Operation& op);
// otherwise running the graph will have undefined results.
TORCH_API void specializeToSpec(const std::shared_ptr<Graph>& graph, const ArgumentSpec& spec); } // namespace detail
// apply standard optimizations. if graphMustSupportVariables=false then
// then the passes are allowed to modify the graph in ways that make it no longer
// work with tensors that have requires_grad=True
TORCH_API void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables);
}} }}

View File

@ -129,8 +129,8 @@ void initJITBindings(PyObject *module) {
}); });
py::class_<ArgumentSpec>(m, "ArgumentSpec"); py::class_<ArgumentSpec>(m, "ArgumentSpec");
py::class_<Code>(m, "Code") py::class_<Code>(m, "Code")
.def("executors", [](Code& c) { .def("grad_executors", [](Code& c) {
return py::make_iterator(c.executors().begin(), c.executors().end()); return py::make_iterator(c.grad_executors().begin(), c.grad_executors().end());
}); });
py::class_<ExecutionPlanState>(m, "ExecutionPlanState") py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
@ -138,10 +138,7 @@ void initJITBindings(PyObject *module) {
return s.graph; return s.graph;
}) })
.def_property_readonly("code", [](ExecutionPlanState& s) { .def_property_readonly("code", [](ExecutionPlanState& s) {
return s.f; return s.code;
})
.def_property_readonly("grad_executor", [](ExecutionPlanState& s) {
return s.grad_executor.get();
}); });
py::class_<Gradient>(m, "Gradient") py::class_<Gradient>(m, "Gradient")
@ -174,11 +171,8 @@ void initJITBindings(PyObject *module) {
.def_property_readonly("execution_plans", [](GraphExecutorState& s) { .def_property_readonly("execution_plans", [](GraphExecutorState& s) {
return s.execution_plans; return s.execution_plans;
}) })
.def_property_readonly("autograd_fallback", [](GraphExecutorState& s) { .def_property_readonly("fallback", [](GraphExecutorState& s) {
return s.autograd_fallback; return s.fallback;
})
.def_property_readonly("autograd_fallback_graph", [](GraphExecutorState& s) {
return s.autograd_fallback_graph;
}); });
py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr()) py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
@ -204,7 +198,7 @@ void initJITBindings(PyObject *module) {
.def_property_readonly("graph", [](GraphExecutor& ge) { .def_property_readonly("graph", [](GraphExecutor& ge) {
return ge.graph(); return ge.graph();
}) })
.def("get_debug_state", [](GraphExecutor& ge) { .def("get_debug_state", [](GraphExecutor& ge) {
return ge.getDebugState(); return ge.getDebugState();
}) })
.def("__call__", [](GraphExecutor& ge, py::args args) -> py::object { .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {

View File

@ -24,7 +24,7 @@ namespace torch { namespace jit {
_(prim, Eval) \ _(prim, Eval) \
_(prim, Expand) /* onnx */ \ _(prim, Expand) /* onnx */ \
_(prim, FusionGroup) \ _(prim, FusionGroup) \
_(prim, GraphExecutor) \ _(prim, DifferentiableGraph) \
_(prim, If) \ _(prim, If) \
_(prim, Jump) /* debug */ \ _(prim, Jump) /* debug */ \
_(prim, JumpNZ) /* debug */ \ _(prim, JumpNZ) /* debug */ \
@ -87,6 +87,12 @@ namespace torch { namespace jit {
_(onnx, Not) \ _(onnx, Not) \
FORALL_ATTR_BASE_SYMBOLS(_) \ FORALL_ATTR_BASE_SYMBOLS(_) \
_(attr, Subgraph) \ _(attr, Subgraph) \
_(attr, ReverseSubgraph) \
_(attr, f_real_outputs) \
_(attr, df_input_vjps) \
_(attr, df_input_captured_inputs) \
_(attr, df_input_captured_outputs) \
_(attr, df_output_vjps) \
_(attr, axes) \ _(attr, axes) \
_(attr, axis) \ _(attr, axis) \
_(attr, broadcast) \ _(attr, broadcast) \

View File

@ -515,7 +515,7 @@ struct CodeImpl {
size_t insertInstruction(Node * n) { size_t insertInstruction(Node * n) {
auto inst = insertInstruction(n->kind(), n->getSourceLocation(), n->inputs(), moveFlags(n) , n->outputs()); auto inst = insertInstruction(n->kind(), n->getSourceLocation(), n->inputs(), moveFlags(n) , n->outputs());
instructions[inst].callback = getInterpreterOperation(n); instructions[inst].callback = getOperation(n);
return inst; return inst;
} }
size_t insertInstruction(Symbol sym, size_t insertInstruction(Symbol sym,
@ -603,27 +603,16 @@ struct CodeImpl {
return r; return r;
} }
// Returns a function implementing functionality of a given node, const std::vector<GraphExecutor*>& grad_executors() {
// or nullptr if it's a no-op for autograd. if (!grad_executors_) {
Operation getInterpreterOperation(jit::Node* node) { grad_executors_.emplace();
if(node->kind() != prim::GraphExecutor) { for (Instruction & instr : instructions) {
return getOperation(node); if (auto executor = detail::getGradExecutor(instr.callback)) {
grad_executors_->push_back(executor);
}
}
} }
// recursive graph executors cannot be Operators because they return *grad_executors_;
// have to register themselves with the interpreter so that
// we can provide useful debugging information
auto executor = std::make_shared<GraphExecutor>(node->g(attr::Subgraph));
graph_executors.emplace_back(executor.get());
return [=](Stack& stack) mutable {
autograd::profiler::RecordFunction record("GraphExecutor");
executor->run(stack);
return 0;
};
}
const std::vector<GraphExecutor*>& executors() {
return graph_executors;
} }
void dumpInstruction(std::ostream & out, size_t pc) const { void dumpInstruction(std::ostream & out, size_t pc) const {
@ -664,7 +653,7 @@ struct CodeImpl {
// It is also very useful for debugging interpreter problems to // It is also very useful for debugging interpreter problems to
// keep this around. // keep this around.
std::shared_ptr<Graph> graph; std::shared_ptr<Graph> graph;
std::vector<GraphExecutor*> graph_executors; // for debugging at::optional<std::vector<GraphExecutor*>> grad_executors_;
PreprocessGraph preprocess; PreprocessGraph preprocess;
std::unordered_map<size_t, int> unique_to_reg; // map from unique of nodes to register in register table std::unordered_map<size_t, int> unique_to_reg; // map from unique of nodes to register in register table
@ -771,8 +760,8 @@ Code::Code(std::shared_ptr<Graph>& graph)
: pImpl(new CodeImpl(graph)) {} : pImpl(new CodeImpl(graph)) {}
Code::~Code() = default; Code::~Code() = default;
const std::vector<GraphExecutor*>& Code::executors() { const std::vector<GraphExecutor*>& Code::grad_executors() {
return pImpl->executors(); return pImpl->grad_executors();
} }
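A sketch of how the renamed accessor can be used from C++ for debugging (illustrative; assumes `graph` is a std::shared_ptr<Graph> that has already been through the optimizing path, so it may contain prim::DifferentiableGraph nodes):
Code code(graph);
for (GraphExecutor* grad_exec : code.grad_executors()) {
  GraphExecutorState state = grad_exec->getDebugState();
  std::cout << *state.graph << "\n";          // dump each backward graph for inspection
}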
InterpreterState::InterpreterState(const Code & code) InterpreterState::InterpreterState(const Code & code)

View File

@ -29,8 +29,7 @@ struct TORCH_API Code {
Code(std::shared_ptr<Graph>& graph); Code(std::shared_ptr<Graph>& graph);
~Code(); ~Code();
// Returns pointers to GraphExecutors created to run GraphExecutor nodes in the given graph. const std::vector<GraphExecutor*>& grad_executors();
const std::vector<GraphExecutor*>& executors();
explicit operator bool() const { explicit operator bool() const {
return pImpl != nullptr; return pImpl != nullptr;

View File

@ -186,7 +186,7 @@ std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::v
IR_ELSE() IR_ELSE()
if(n->hasAttribute(attr::Subgraph) && groups) { if(n->hasAttribute(attr::Subgraph) && groups) {
out << n->kind().toQualString() << "_" << groups->size(); out << n->kind().toQualString() << "_" << groups->size();
if (n->numAttributes() > 1) { if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) {
printAttributes(out, n, /*ignore_subgraph=*/true); printAttributes(out, n, /*ignore_subgraph=*/true);
} }
groups->push_back(n); groups->push_back(n);

View File

@ -24,6 +24,9 @@ std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
r_outputs.at(i)->setStage(outputs.at(i)->stage()); r_outputs.at(i)->setStage(outputs.at(i)->stage());
rn_env[outputs.at(i)] = r_outputs.at(i); rn_env[outputs.at(i)] = r_outputs.at(i);
} }
if (node->hasAttribute(attr::Subgraph)) {
r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph)));
}
} }
for (auto* output : graph->outputs()) { for (auto* output : graph->outputs()) {
r->registerOutput(rn_fn(output)); r->registerOutput(rn_fn(output));

View File

@ -17,7 +17,7 @@ namespace {
// right before nodes[0] (i.e. it will not create cycles and all uses of // right before nodes[0] (i.e. it will not create cycles and all uses of
// new node will be after this position). // new node will be after this position).
// prereq: nodes are in topological order // prereq: nodes are in topological order
void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) { Node* mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
JIT_ASSERT(nodes.size() > 0); JIT_ASSERT(nodes.size() > 0);
std::unordered_map<Value*, Value*> value_map; std::unordered_map<Value*, Value*> value_map;
Graph * graph = block->owningGraph(); Graph * graph = block->owningGraph();
@ -66,11 +66,12 @@ void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
nodes[i - 1]->destroy(); nodes[i - 1]->destroy();
} }
JIT_ASSERT(isDifferentiable(*new_graph)); JIT_ASSERT(isDifferentiable(*new_graph));
return group_node;
} }
} }
void CreateAutodiffSubgraphs(Block * block, size_t threshold) { void CreateAutodiffSubgraphs(Block * block, size_t threshold, std::vector<Node*>& diff_graphs) {
// This implementation is not optimal, but it is simple. // This implementation is not optimal, but it is simple.
// It just scans through the list in order looking for runs of // It just scans through the list in order looking for runs of
// differentiable ops, and then grouping them together when // differentiable ops, and then grouping them together when
@ -93,21 +94,23 @@ void CreateAutodiffSubgraphs(Block * block, size_t threshold) {
groupable.push_back(node); groupable.push_back(node);
} else { } else {
if(groupable.size() >= threshold) { if(groupable.size() >= threshold) {
mergeNodes(block, prim::GraphExecutor, groupable); diff_graphs.push_back(mergeNodes(block, prim::DifferentiableGraph, groupable));
} }
groupable.clear(); groupable.clear();
for (Block * sub_block : node->blocks()) { for (Block * sub_block : node->blocks()) {
CreateAutodiffSubgraphs(sub_block, threshold); CreateAutodiffSubgraphs(sub_block, threshold, diff_graphs);
} }
} }
} }
if(groupable.size() >= threshold) { if(groupable.size() >= threshold) {
mergeNodes(block, prim::GraphExecutor, groupable); diff_graphs.push_back(mergeNodes(block, prim::DifferentiableGraph, groupable));
} }
} }
void CreateAutodiffSubgraphs(Graph & graph, size_t threshold) { std::vector<Node*> CreateAutodiffSubgraphs(Graph & graph, size_t threshold) {
CreateAutodiffSubgraphs(graph.block(), threshold); std::vector<Node*> diff_nodes;
CreateAutodiffSubgraphs(graph.block(), threshold, diff_nodes);
return diff_nodes;
} }

View File

@ -8,6 +8,7 @@ struct Graph;
// insert GraphExecutor nodes that group together // insert GraphExecutor nodes that group together
// subgraphs that are differentiable by the jit's autodiff passes // subgraphs that are differentiable by the jit's autodiff passes
// threshold - minimum number of nodes that will appear in a block // threshold - minimum number of nodes that will appear in a block
TORCH_API void CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2); // returns all differentiable blocks that have been found
TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2);
}} }}
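A usage sketch of the new signature (illustrative; it mirrors what GraphExecutorImpl::compileSpec does in graph_executor.cpp and assumes `graph` is a std::shared_ptr<Graph> with differentiate()/packGradient() visible):
std::vector<Node*> diff_nodes = CreateAutodiffSubgraphs(*graph, /*threshold=*/2);
for (Node* dnode : diff_nodes) {
  auto body = dnode->g(attr::Subgraph);
  std::vector<bool> requires_grads(body->inputs().size(), true);  // conservative: no type info
  Gradient gradient = differentiate(body, requires_grads);
  packGradient(gradient, dnode);                                  // stored as node attributes
}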