Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Change specialization rules in GraphExecutors (#10977)
Summary: **Review last commit only.** Stacked on top of #10949. This commit fixes a number of issues connected to caching the differentiability status of graphs inside graph executors, and it changes the rules for optimizing differentiable subgraphs. Previously, every differentiable subgraph was instantiated as a separate graph executor; now they are simply more heavily optimized graph regions, and graph executors are instantiated only for their backward. zdevito

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10977
Differential Revision: D9600626
Pulled By: apaszke
fbshipit-source-id: dad09a0f586e396afbd5406319c1cd54fbb8a3d3
parent a320e5cbd3
commit 00df09b65d
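As background for the test changes further down: the updated helpers in test_jit.py recover the backward graph of a differentiable subgraph through the graph executor's debug state. A minimal sketch of how they compose, assuming a hypothetical `scripted_module` that has already run forward and backward, and using only the helper functions defined later in this diff:

    ge_state = scripted_module.get_debug_state()
    fwd_plan = get_execution_plan(ge_state)            # the test assumes exactly one execution plan
    grad_executor = get_grad_executor(fwd_plan)        # GraphExecutor instantiated for the differentiable graph's backward
    bwd_plan = get_execution_plan(grad_executor.get_debug_state())
    bwd_graph = bwd_plan.graph.copy()                  # copy: the debug state does not own its graph
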
@@ -4,11 +4,11 @@ graph(%0 : Float(*, *)
%3 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %0, %1)
return (%3);
}
with prim::FusionGroup_0 = graph(%1 : Float(*)
%4 : Float(*, *)
%5 : Float(*)) {
%6 : Float(*, *) = aten::mul(%4, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(*, *) = aten::add(%6, %1, %2)
return (%3);
with prim::FusionGroup_0 = graph(%0 : Float(*)
%1 : Float(*, *)
%2 : Float(*)) {
%3 : Float(*, *) = aten::mul(%1, %2)
%4 : int = prim::Constant[value=1]()
%5 : Float(*, *) = aten::add(%3, %0, %4)
return (%5);
}

@@ -3,11 +3,11 @@ graph(%0 : Float(*, *)
%2 : Float(*, *) = prim::FusionGroup_0[device=0](%0, %1)
return (%2);
}
with prim::FusionGroup_0 = graph(%3 : Float(*, *)
%4 : Float(*, *)) {
%6 : int = prim::Constant[value=1]()
%7 : Float(*, *) = aten::add(%3, %4, %6)
%5 : Float(*, *) = aten::mul(%3, %4)
%2 : Float(*, *) = prim::FusedConcat[dim=0](%7, %5)
return (%2);
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)) {
%2 : int = prim::Constant[value=1]()
%3 : Float(*, *) = aten::add(%0, %1, %2)
%4 : Float(*, *) = aten::mul(%0, %1)
%5 : Float(*, *) = prim::FusedConcat[dim=0](%3, %4)
return (%5);
}

@@ -6,12 +6,12 @@ graph(%0 : Float(*, *)
%5 : Float(*, *) = aten::add(%4, %2, %3)
return (%5);
}
with prim::FusionGroup_0 = graph(%3 : Float(*, *)
%4 : Float(*, *)) {
%7 : int = prim::Constant[value=1]()
%8 : Float(*, *) = aten::add(%3, %4, %7)
%5 : int = prim::Constant[value=1]()
%6 : Float(*, *) = aten::sub(%3, %4, %5)
%2 : Float(*, *) = prim::FusedConcat[dim=0](%8, %6)
return (%2);
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)) {
%2 : int = prim::Constant[value=1]()
%3 : Float(*, *) = aten::add(%0, %1, %2)
%4 : int = prim::Constant[value=1]()
%5 : Float(*, *) = aten::sub(%0, %1, %4)
%6 : Float(*, *) = prim::FusedConcat[dim=0](%3, %5)
return (%6);
}

@@ -64,10 +64,10 @@ graph(%0 : Dynamic
%2 : Dynamic
%3 : Dynamic
%4 : Dynamic) {
%23 : Dynamic, %24 : Dynamic = prim::GraphExecutor_0(%0, %3, %1, %4, %2)
%23 : Dynamic, %24 : Dynamic = prim::DifferentiableGraph_0(%0, %3, %1, %4, %2)
return (%24, %23);
}
with prim::GraphExecutor_0 = graph(%1 : Dynamic
with prim::DifferentiableGraph_0 = graph(%1 : Dynamic
%2 : Dynamic
%4 : Dynamic
%5 : Dynamic

@@ -3,14 +3,14 @@ graph(%0 : Float(*)
%2 : Float(*) = prim::FusionGroup_0[device=1](%0, %1)
return (%2);
}
with prim::FusionGroup_0 = graph(%5 : Float(*)
%10 : Float(*)) {
%11 : int = prim::Constant[value=1]()
%12 : Float(*) = aten::add(%5, %10, %11)
%9 : Float(*) = aten::mul(%5, %12)
%6 : int = prim::Constant[value=1]()
%7 : Float(*) = aten::add(%9, %5, %6)
%3 : Float(*) = aten::tanh(%7)
%1 : Float(*) = aten::sigmoid(%3)
return (%1);
with prim::FusionGroup_0 = graph(%0 : Float(*)
%1 : Float(*)) {
%2 : int = prim::Constant[value=1]()
%3 : Float(*) = aten::add(%0, %1, %2)
%4 : Float(*) = aten::mul(%0, %3)
%5 : int = prim::Constant[value=1]()
%6 : Float(*) = aten::add(%4, %0, %5)
%7 : Float(*) = aten::tanh(%6)
%8 : Float(*) = aten::sigmoid(%7)
return (%8);
}

@@ -6,14 +6,14 @@ graph(%0 : Float(*, *)
%6 : Float(*, *) = prim::FusionGroup_0[device=0](%4, %5)
return (%6);
}
with prim::FusionGroup_0 = graph(%11 : Dynamic
%14 : Dynamic) {
%15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%14)
%12 : Float(*, *), %13 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%11)
%9 : int = prim::Constant[value=1]()
%10 : Float(*, *) = aten::add(%13, %16, %9)
%5 : int = prim::Constant[value=1]()
%6 : Float(*, *) = aten::add(%12, %15, %5)
%2 : Float(*, *) = aten::mul(%6, %10)
return (%2);
with prim::FusionGroup_0 = graph(%0 : Dynamic
%1 : Dynamic) {
%2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%1)
%4 : Float(*, *), %5 : Float(*, *) = prim::ConstantChunk[chunks=2, dim=1](%0)
%6 : int = prim::Constant[value=1]()
%7 : Float(*, *) = aten::add(%5, %3, %6)
%8 : int = prim::Constant[value=1]()
%9 : Float(*, *) = aten::add(%4, %2, %8)
%10 : Float(*, *) = aten::mul(%9, %7)
return (%10);
}

@@ -16,29 +16,29 @@ graph(%0 : Float(*, *)
%16 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %14, %15)
return (%16);
}
with prim::FusionGroup_0 = graph(%15 : Float(*, *)
%41 : Dynamic
%46 : Dynamic) {
%47 : Float(*, *), %48 : Float(*, *), %49 : Float(*, *), %50 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%46)
%42 : Float(*, *), %43 : Float(*, *), %44 : Float(*, *), %45 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%41)
%39 : int = prim::Constant[value=1]()
%40 : Float(*, *) = aten::add(%42, %47, %39)
%35 : int = prim::Constant[value=1]()
%36 : Float(*, *) = aten::add(%43, %48, %35)
%31 : int = prim::Constant[value=1]()
%32 : Float(*, *) = aten::add(%44, %49, %31)
%27 : int = prim::Constant[value=1]()
%28 : Float(*, *) = aten::add(%45, %50, %27)
%24 : Float(*, *) = aten::sigmoid(%40)
%22 : Float(*, *) = aten::sigmoid(%36)
%20 : Float(*, *) = aten::tanh(%32)
%18 : Float(*, *) = aten::sigmoid(%28)
%16 : Float(*, *) = aten::mul(%22, %15)
%13 : Float(*, *) = aten::mul(%24, %20)
%9 : int = prim::Constant[value=1]()
%10 : Float(*, *) = aten::add(%16, %13, %9)
%6 : Float(*, *) = aten::tanh(%10)
%5 : Float(*, *) = aten::mul(%18, %6)
%2 : Float(*, *) = prim::FusedConcat[dim=0](%5, %10)
return (%2);
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Dynamic
%2 : Dynamic) {
%3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%7, %3, %11)
%13 : int = prim::Constant[value=1]()
%14 : Float(*, *) = aten::add(%8, %4, %13)
%15 : int = prim::Constant[value=1]()
%16 : Float(*, *) = aten::add(%9, %5, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%10, %6, %17)
%19 : Float(*, *) = aten::sigmoid(%12)
%20 : Float(*, *) = aten::sigmoid(%14)
%21 : Float(*, *) = aten::tanh(%16)
%22 : Float(*, *) = aten::sigmoid(%18)
%23 : Float(*, *) = aten::mul(%20, %0)
%24 : Float(*, *) = aten::mul(%19, %21)
%25 : int = prim::Constant[value=1]()
%26 : Float(*, *) = aten::add(%23, %24, %25)
%27 : Float(*, *) = aten::tanh(%26)
%28 : Float(*, *) = aten::mul(%22, %27)
%29 : Float(*, *) = prim::FusedConcat[dim=0](%28, %26)
return (%29);
}

@@ -16,28 +16,28 @@ graph(%0 : Float(*, *)
%16 : Float(*, *), %17 : Float(*, *) = prim::FusionGroup_0[device=0](%2, %14, %15)
return (%16, %17);
}
with prim::FusionGroup_0 = graph(%13 : Float(*, *)
%39 : Dynamic
%44 : Dynamic) {
%45 : Float(*, *), %46 : Float(*, *), %47 : Float(*, *), %48 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%44)
%40 : Float(*, *), %41 : Float(*, *), %42 : Float(*, *), %43 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%39)
%37 : int = prim::Constant[value=1]()
%38 : Float(*, *) = aten::add(%40, %45, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(*, *) = aten::add(%41, %46, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%42, %47, %29)
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Dynamic
%2 : Dynamic) {
%3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%7, %3, %11)
%13 : int = prim::Constant[value=1]()
%14 : Float(*, *) = aten::add(%8, %4, %13)
%15 : int = prim::Constant[value=1]()
%16 : Float(*, *) = aten::add(%9, %5, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%10, %6, %17)
%19 : Float(*, *) = aten::sigmoid(%12)
%20 : Float(*, *) = aten::sigmoid(%14)
%21 : Float(*, *) = aten::tanh(%16)
%22 : Float(*, *) = aten::sigmoid(%18)
%23 : Float(*, *) = aten::mul(%20, %0)
%24 : Float(*, *) = aten::mul(%19, %21)
%25 : int = prim::Constant[value=1]()
%26 : Float(*, *) = aten::add(%43, %48, %25)
%22 : Float(*, *) = aten::sigmoid(%38)
%20 : Float(*, *) = aten::sigmoid(%34)
%18 : Float(*, *) = aten::tanh(%30)
%16 : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%20, %13)
%11 : Float(*, *) = aten::mul(%22, %18)
%7 : int = prim::Constant[value=1]()
%8 : Float(*, *) = aten::add(%14, %11, %7)
%4 : Float(*, *) = aten::tanh(%8)
%2 : Float(*, *) = aten::mul(%16, %4)
return (%2, %8);
%26 : Float(*, *) = aten::add(%23, %24, %25)
%27 : Float(*, *) = aten::tanh(%26)
%28 : Float(*, *) = aten::mul(%22, %27)
return (%28, %26);
}

@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)
%1 : Double(4, 5)) {
%2 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
%3 : Double(3, 4) = aten::neg(%2), scope: TracedModule/traced_fn
%3 : Double(*, *) = aten::neg(%2), scope: TracedModule/traced_fn
return (%3);
}

@@ -1,5 +1,5 @@
graph(%0 : Double(3, 4)) {
%1 : Double(3, 4) = aten::neg(%0), scope: traced_fn1
%1 : Double(*, *) = aten::neg(%0), scope: traced_fn1
%2 : Long() = prim::Constant[value={1}]()
%3 : int = prim::Constant[value=1]()
%4 : Double(3, 4) = aten::add(%1, %2, %3)

@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)) {
%1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: TracedModule[TracedModule]
%2 : Double(3, 3) = aten::mm(%0, %1), scope: TracedModule[TracedModule]
%2 : Double(*, *) = aten::mm(%0, %1), scope: TracedModule[TracedModule]
%3 : Double() = prim::Constant[value={1}]()
%4 : int = prim::Constant[value=1]()
%5 : Double(3, 3) = aten::add(%2, %3, %4)

@@ -2,7 +2,7 @@ graph(%0 : Double(3, 4)
%1 : Double(4, 5)
%2 : Double(5, 7)) {
%3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
%4 : Double(3, 7) = aten::mm(%3, %2), scope: TracedModule/TracedModule[TracedModule1][mod]
%4 : Double(*, *) = aten::mm(%3, %2), scope: TracedModule/TracedModule[TracedModule1][mod]
%5 : Double() = prim::Constant[value={1}](), scope: TracedModule
%6 : int = prim::Constant[value=1](), scope: TracedModule
%7 : Double(3, 7) = aten::add(%4, %5, %6), scope: TracedModule

@@ -2,10 +2,10 @@ graph(%x : Float(*, *)) {
%1 : Float(*, *) = prim::FusionGroup_0[device=0](%x)
return (%1);
}
with prim::FusionGroup_0 = graph(%7 : Float(*, *)) {
%8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=3, dim=1](%7)
%6 : Float(*, *) = aten::mul(%8, %9)
%2 : int = prim::Constant[value=1]()
%3 : Float(*, *) = aten::add(%6, %10, %2)
return (%3);
with prim::FusionGroup_0 = graph(%0 : Float(*, *)) {
%1 : Float(*, *), %2 : Float(*, *), %3 : Float(*, *) = prim::ConstantChunk[chunks=3, dim=1](%0)
%4 : Float(*, *) = aten::mul(%1, %2)
%5 : int = prim::Constant[value=1]()
%6 : Float(*, *) = aten::add(%4, %3, %5)
return (%6);
}

@@ -5,26 +5,26 @@ graph(%s : Float(*, *, *)
%4 : Float(*, *, *) = prim::FusionGroup_0[device=0](%s, %y, %x, %z)
return (%4);
}
with prim::FusionGroup_0 = graph(%24 : Float(*, *, *)
%28 : Float(*, *, *)
%31 : Float(*, *, *)
%35 : Float(*, *, *)) {
%36 : Float(*, *, *), %37 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=2](%35)
%32 : Float(*, *, *), %33 : Float(*, *, *), %34 : Float(*, *, *) = prim::ConstantChunk[chunks=3, dim=1](%31)
%29 : Float(*, *, *), %30 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=0](%28)
%26 : int = prim::Constant[value=1]()
%27 : Float(*, *, *) = aten::add(%24, %32, %26)
%22 : int = prim::Constant[value=1]()
%23 : Float(*, *, *) = aten::add(%27, %33, %22)
%18 : int = prim::Constant[value=1]()
%19 : Float(*, *, *) = aten::add(%23, %34, %18)
%14 : int = prim::Constant[value=1]()
%15 : Float(*, *, *) = aten::add(%19, %29, %14)
%10 : int = prim::Constant[value=1]()
%11 : Float(*, *, *) = aten::add(%15, %30, %10)
%6 : int = prim::Constant[value=1]()
%7 : Float(*, *, *) = aten::add(%11, %36, %6)
%2 : int = prim::Constant[value=1]()
%3 : Float(*, *, *) = aten::add(%7, %37, %2)
return (%3);
with prim::FusionGroup_0 = graph(%0 : Float(*, *, *)
%1 : Float(*, *, *)
%2 : Float(*, *, *)
%3 : Float(*, *, *)) {
%4 : Float(*, *, *), %5 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=2](%3)
%6 : Float(*, *, *), %7 : Float(*, *, *), %8 : Float(*, *, *) = prim::ConstantChunk[chunks=3, dim=1](%2)
%9 : Float(*, *, *), %10 : Float(*, *, *) = prim::ConstantChunk[chunks=2, dim=0](%1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *, *) = aten::add(%0, %6, %11)
%13 : int = prim::Constant[value=1]()
%14 : Float(*, *, *) = aten::add(%12, %7, %13)
%15 : int = prim::Constant[value=1]()
%16 : Float(*, *, *) = aten::add(%14, %8, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *, *) = aten::add(%16, %9, %17)
%19 : int = prim::Constant[value=1]()
%20 : Float(*, *, *) = aten::add(%18, %10, %19)
%21 : int = prim::Constant[value=1]()
%22 : Float(*, *, *) = aten::add(%20, %4, %21)
%23 : int = prim::Constant[value=1]()
%24 : Float(*, *, *) = aten::add(%22, %5, %23)
return (%24);
}

@@ -7,9 +7,9 @@ graph(%0 : Float(*, *)
%6 : Dynamic
%7 : Dynamic
%8 : Dynamic
%x : Float(*, *)
%hx : Float(*, *)
%cx : Float(*, *)
%9 : Float(*, *)
%10 : Float(*, *)
%11 : Float(*, *)
%12 : Float(*, *)
%13 : Float(*, *)
%ingate : Float(*, *)

@@ -18,58 +18,67 @@ graph(%0 : Float(*, *)
%outgate : Float(*, *)
%18 : Float(*, *)) {
%19 : int = prim::Constant[value=1]()
%20 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %18, %0)
%21 : Float(*, *) = aten::mul(%20, %19)
%22 : Float(*, *) = aten::t(%hx)
%23 : Float(*, *) = aten::mm(%22, %21)
%24 : Float(*, *) = aten::t(%23)
%25 : Float(*, *) = aten::t(%x)
%26 : Float(*, *) = aten::mm(%25, %20)
%27 : Float(*, *) = aten::t(%26)
return (%27, %24, %21, %21);
%20 : Float(*, *), %21 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %11, %1, %18, %0)
%22 : Float(*, *) = aten::mul(%20, %19)
%23 : Float(*, *) = aten::t(%12)
%24 : Float(*, *) = aten::mm(%20, %23)
%25 : Float(*, *) = aten::mul(%24, %19)
%26 : Float(*, *) = aten::t(%10)
%27 : Float(*, *) = aten::mm(%26, %20)
%28 : Float(*, *) = aten::mul(%27, %19)
%29 : Float(*, *) = aten::t(%13)
%30 : Float(*, *) = aten::mm(%22, %29)
%31 : Float(*, *) = aten::t(%9)
%32 : Float(*, *) = aten::mm(%31, %22)
%33 : Float(*, *) = aten::t(%32)
%34 : Float(*, *) = aten::t(%28)
return (%34, %33, %30, %25, %22, %22, %21);
}
with prim::FusionGroup_0 = graph(%9 : Float(*, *)
%19 : Float(*, *)
%33 : Float(*, *)
%39 : Float(*, *)
%46 : Float(*, *)
%53 : Float(*, *)
%65 : Float(*, *)
%67 : Float(*, *)) {
%69 : Float(*, *) = aten::mul(%67, %65)
%68 : Float(*, *) = aten::mul(%67, %39)
%66 : Float(*, *) = aten::mul(%65, %65)
%64 : Float(*, *) = aten::neg(%66)
%61 : int = prim::Constant[value=1]()
%62 : Float(*, *) = aten::add(%64, %61, %61)
%59 : Float(*, *) = aten::mul(%68, %62)
%55 : int = prim::Constant[value=1]()
%56 : Float(*, *) = aten::add(%53, %59, %55)
%51 : int = prim::Constant[value=1]()
%52 : Float(*, *) = aten::mul(%56, %51)
%50 : Float(*, *) = aten::mul(%52, %33)
%49 : Float(*, *) = aten::mul(%52, %9)
%47 : Float(*, *) = aten::mul(%56, %46)
%44 : Float(*, *) = aten::neg(%39)
%42 : int = prim::Constant[value=1]()
%43 : Float(*, *) = aten::add(%44, %42, %42)
%40 : Float(*, *) = aten::mul(%69, %39)
%37 : Float(*, *) = aten::mul(%40, %43)
%34 : Float(*, *) = aten::mul(%33, %33)
%32 : Float(*, *) = aten::neg(%34)
%29 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%32, %29, %29)
%27 : Float(*, *) = aten::mul(%49, %30)
%24 : Float(*, *) = aten::neg(%19)
%22 : int = prim::Constant[value=1]()
%23 : Float(*, *) = aten::add(%24, %22, %22)
%20 : Float(*, *) = aten::mul(%47, %19)
%17 : Float(*, *) = aten::mul(%20, %23)
%14 : Float(*, *) = aten::neg(%9)
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*, *)
%5 : Float(*, *)
%6 : Float(*, *)
%7 : Float(*, *)) {
%8 : Float(*, *) = aten::mul(%5, %3)
%9 : Float(*, *) = aten::mul(%6, %6)
%10 : Float(*, *) = aten::neg(%9)
%11 : int = prim::Constant[value=1]()
%12 : int = prim::Constant[value=1]()
%13 : Float(*, *) = aten::add(%14, %12, %12)
%10 : Float(*, *) = aten::mul(%50, %9)
%7 : Float(*, *) = aten::mul(%10, %13)
%4 : Float(*, *) = prim::FusedConcat[dim=1](%7, %17, %27, %37)
return (%4);
%13 : Float(*, *) = aten::add(%10, %12, %12)
%14 : Float(*, *) = aten::mul(%8, %13)
%15 : int = prim::Constant[value=1]()
%16 : int = prim::Constant[value=1]()
%17 : Float(*, *) = aten::add(%7, %14, %16)
%18 : Float(*, *) = aten::mul(%17, %1)
%19 : Float(*, *) = aten::mul(%5, %6)
%20 : int = prim::Constant[value=1]()
%21 : Float(*, *) = aten::mul(%17, %20)
%22 : Float(*, *) = aten::mul(%21, %2)
%23 : Float(*, *) = aten::mul(%21, %0)
%24 : Float(*, *) = aten::mul(%17, %4)
%25 : Float(*, *) = aten::neg(%3)
%26 : int = prim::Constant[value=1]()
%27 : Float(*, *) = aten::add(%25, %26, %26)
%28 : Float(*, *) = aten::mul(%19, %3)
%29 : Float(*, *) = aten::mul(%28, %27)
%30 : Float(*, *) = aten::mul(%2, %2)
%31 : Float(*, *) = aten::neg(%30)
%32 : int = prim::Constant[value=1]()
%33 : Float(*, *) = aten::add(%31, %32, %32)
%34 : Float(*, *) = aten::mul(%23, %33)
%35 : Float(*, *) = aten::neg(%1)
%36 : int = prim::Constant[value=1]()
%37 : Float(*, *) = aten::add(%35, %36, %36)
%38 : Float(*, *) = aten::mul(%24, %1)
%39 : Float(*, *) = aten::mul(%38, %37)
%40 : Float(*, *) = aten::neg(%0)
%41 : int = prim::Constant[value=1]()
%42 : Float(*, *) = aten::add(%40, %41, %41)
%43 : Float(*, *) = aten::mul(%22, %0)
%44 : Float(*, *) = aten::mul(%43, %42)
%45 : Float(*, *) = prim::FusedConcat[dim=1](%44, %39, %34, %29)
return (%45, %18);
}

@@ -1,44 +1,54 @@
graph(%x.1 : Float(*, *)
%hx.1 : Float(*, *)
%cx.1 : Float(*, *)
graph(%x : Float(*, *)
%hx : Float(*, *)
%cx : Float(*, *)
%w_ih : Float(*, *)
%w_hh : Float(*, *)
%b_ih : Float(*)
%b_hh : Float(*)) {
%7 : Float(*, *) = aten::t(%w_ih)
%8 : Float(*, *) = aten::t(%w_hh)
%9 : Float(*, *) = aten::mm(%hx.1, %8)
%7 : Float(*, *), %8 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %w_hh, %hx, %x, %b_ih, %b_hh, %cx)
return (%8, %7);
}
with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*)
%5 : Float(*)
%6 : Float(*, *)) {
%7 : Float(*, *) = aten::t(%0)
%8 : Float(*, *) = aten::t(%1)
%9 : Float(*, *) = aten::mm(%2, %8)
%10 : int = prim::Constant[value=1]()
%11 : Float(*, *) = aten::addmm(%9, %x.1, %7, %10, %10)
%12 : Float(*, *) = aten::add(%11, %b_ih, %10)
%13 : Dynamic[] = prim::ListConstruct(%12, %b_hh)
%11 : Float(*, *) = aten::addmm(%9, %3, %7, %10, %10)
%12 : Float(*, *) = aten::add(%11, %4, %10)
%13 : Dynamic[] = prim::ListConstruct(%12, %5)
%14 : Dynamic[] = aten::broadcast_tensors(%13)
%15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%14)
%hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0[device=0](%cx.1, %15, %16)
return (%hy, %cy, %7, %8, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18);
%hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0[device=0](%6, %15, %16)
return (%cy, %hy, %7, %8, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18);
}
with prim::FusionGroup_0 = graph(%13 : Float(*, *)
%39 : Dynamic
%44 : Dynamic) {
%45 : Dynamic, %46 : Dynamic, %47 : Dynamic, %48 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%44)
%40 : Dynamic, %41 : Dynamic, %42 : Dynamic, %43 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%39)
%37 : int = prim::Constant[value=1]()
%38 : Float(*, *) = aten::add(%40, %45, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(*, *) = aten::add(%41, %46, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%42, %47, %29)
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Dynamic
%2 : Dynamic) {
%3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%7, %3, %11)
%13 : int = prim::Constant[value=1]()
%14 : Float(*, *) = aten::add(%8, %4, %13)
%15 : int = prim::Constant[value=1]()
%16 : Float(*, *) = aten::add(%9, %5, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%10, %6, %17)
%ingate.1 : Float(*, *) = aten::sigmoid(%12)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
%cellgate.1 : Float(*, *) = aten::tanh(%16)
%outgate.1 : Float(*, *) = aten::sigmoid(%18)
%23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
%24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%25 : int = prim::Constant[value=1]()
%26 : Float(*, *) = aten::add(%43, %48, %25)
%ingate.1 : Float(*, *) = aten::sigmoid(%38)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%34)
%cellgate.1 : Float(*, *) = aten::tanh(%30)
%outgate.1 : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%forgetgate.1, %13)
%11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%7 : int = prim::Constant[value=1]()
%cy : Float(*, *) = aten::add(%14, %11, %7)
%4 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %4)
return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
%cy : Float(*, *) = aten::add(%23, %24, %25)
%27 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %27)
return (%hy, %27, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}

@@ -10,12 +10,12 @@ graph(%0 : Float(*, *)
%9 : Dynamic
%10 : Dynamic
%11 : Dynamic
%x : Float(*, *)
%hx : Float(*, *)
%cx : Float(*, *)
%alpha : Float(*)
%beta_i : Float(*)
%beta_h : Float(*)
%12 : Float(*, *)
%13 : Float(*, *)
%14 : Float(*)
%15 : Float(*)
%16 : Float(*)
%17 : Float(*, *)
%18 : Float(*, *)
%Wx : Float(*, *)
%20 : Float(*, *)

@@ -26,85 +26,92 @@ graph(%0 : Float(*, *)
%cellgate : Float(*, *)
%outgate : Float(*, *)
%27 : Float(*, *)) {
%28 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %27, %0)
%29 : Float(*, *), %30 : Float(*, *), %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *) = prim::FusionGroup_1[device=0](%alpha, %beta_i, %Wx, %28, %Uz, %22, %beta_h)
%35 : Float(*, *) = aten::t(%hx)
%36 : Float(*, *) = aten::mm(%35, %31)
%37 : Float(*, *) = aten::t(%36)
%38 : Float(*, *) = aten::t(%x)
%39 : Float(*, *) = aten::mm(%38, %29)
%28 : Float(*, *), %29 : Float(*, *) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %17, %1, %27, %0)
%30 : Float(*, *), %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *), %35 : Float(*, *) = prim::FusionGroup_1[device=0](%14, %15, %Wx, %28, %Uz, %22, %16)
%36 : Float(*, *) = aten::t(%20)
%37 : Float(*, *) = aten::mm(%32, %36)
%38 : Float(*, *) = aten::t(%13)
%39 : Float(*, *) = aten::mm(%38, %32)
%40 : Float(*, *) = aten::t(%39)
return (%40, %37, %30, %32, %33, %34);
%41 : Float(*, *) = aten::t(%18)
%42 : Float(*, *) = aten::mm(%30, %41)
%43 : Float(*, *) = aten::t(%12)
%44 : Float(*, *) = aten::mm(%43, %30)
%45 : Float(*, *) = aten::t(%44)
return (%45, %42, %40, %37, %31, %33, %34, %35, %29);
}
with prim::FusionGroup_0 = graph(%9 : Float(*, *)
%19 : Float(*, *)
%33 : Float(*, *)
%39 : Float(*, *)
%46 : Float(*, *)
%53 : Float(*, *)
%65 : Float(*, *)
%67 : Float(*, *)) {
%69 : Float(*, *) = aten::mul(%67, %65)
%68 : Float(*, *) = aten::mul(%67, %39)
%66 : Float(*, *) = aten::mul(%65, %65)
%64 : Float(*, *) = aten::neg(%66)
%61 : int = prim::Constant[value=1]()
%62 : Float(*, *) = aten::add(%64, %61, %61)
%59 : Float(*, *) = aten::mul(%68, %62)
%55 : int = prim::Constant[value=1]()
%56 : Float(*, *) = aten::add(%53, %59, %55)
%51 : int = prim::Constant[value=1]()
%52 : Float(*, *) = aten::mul(%56, %51)
%50 : Float(*, *) = aten::mul(%52, %33)
%49 : Float(*, *) = aten::mul(%52, %9)
%47 : Float(*, *) = aten::mul(%56, %46)
%44 : Float(*, *) = aten::neg(%39)
%42 : int = prim::Constant[value=1]()
%43 : Float(*, *) = aten::add(%44, %42, %42)
%40 : Float(*, *) = aten::mul(%69, %39)
%37 : Float(*, *) = aten::mul(%40, %43)
%34 : Float(*, *) = aten::mul(%33, %33)
%32 : Float(*, *) = aten::neg(%34)
%29 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%32, %29, %29)
%27 : Float(*, *) = aten::mul(%49, %30)
%24 : Float(*, *) = aten::neg(%19)
%22 : int = prim::Constant[value=1]()
%23 : Float(*, *) = aten::add(%24, %22, %22)
%20 : Float(*, *) = aten::mul(%47, %19)
%17 : Float(*, *) = aten::mul(%20, %23)
%14 : Float(*, *) = aten::neg(%9)
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*, *)
%5 : Float(*, *)
%6 : Float(*, *)
%7 : Float(*, *)) {
%8 : Float(*, *) = aten::mul(%5, %3)
%9 : Float(*, *) = aten::mul(%6, %6)
%10 : Float(*, *) = aten::neg(%9)
%11 : int = prim::Constant[value=1]()
%12 : int = prim::Constant[value=1]()
%13 : Float(*, *) = aten::add(%14, %12, %12)
%10 : Float(*, *) = aten::mul(%50, %9)
%7 : Float(*, *) = aten::mul(%10, %13)
%4 : Float(*, *) = prim::FusedConcat[dim=1](%7, %17, %27, %37)
return (%4);
}
with prim::FusionGroup_1 = graph(%5 : Float(*)
%8 : Float(*)
%10 : Float(*, *)
%12 : Float(*, *)
%13 : Float(*, *)
%20 : Float(*, *)
%22 : Float(*)) {
%30 : int = prim::Constant[value=1]()
%29 : int = prim::Constant[value=1]()
%28 : int = prim::Constant[value=1]()
%13 : Float(*, *) = aten::add(%10, %12, %12)
%14 : Float(*, *) = aten::mul(%8, %13)
%15 : int = prim::Constant[value=1]()
%16 : int = prim::Constant[value=1]()
%17 : Float(*, *) = aten::add(%7, %14, %16)
%18 : Float(*, *) = aten::mul(%17, %1)
%19 : Float(*, *) = aten::mul(%5, %6)
%20 : int = prim::Constant[value=1]()
%21 : Float(*, *) = aten::mul(%17, %20)
%22 : Float(*, *) = aten::mul(%21, %2)
%23 : Float(*, *) = aten::mul(%21, %0)
%24 : Float(*, *) = aten::mul(%17, %4)
%25 : Float(*, *) = aten::neg(%3)
%26 : int = prim::Constant[value=1]()
%27 : Float(*, *) = aten::mul(%12, %26)
%25 : Float(*, *) = aten::mul(%27, %13)
%24 : Float(*, *) = aten::mul(%27, %10)
%23 : Float(*, *) = aten::mul(%27, %22)
%21 : Float(*, *) = aten::mul(%12, %20)
%19 : int = prim::Constant[value=1]()
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%23, %21, %17)
%14 : Float(*, *) = aten::mul(%12, %13)
%11 : Float(*, *) = aten::mul(%14, %10)
%9 : Float(*, *) = aten::mul(%27, %8)
%6 : Float(*, *) = aten::mul(%14, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(*, *) = aten::add(%9, %6, %2)
return (%3, %11, %18, %24, %25, %27);
%27 : Float(*, *) = aten::add(%25, %26, %26)
%28 : Float(*, *) = aten::mul(%19, %3)
%29 : Float(*, *) = aten::mul(%28, %27)
%30 : Float(*, *) = aten::mul(%2, %2)
%31 : Float(*, *) = aten::neg(%30)
%32 : int = prim::Constant[value=1]()
%33 : Float(*, *) = aten::add(%31, %32, %32)
%34 : Float(*, *) = aten::mul(%23, %33)
%35 : Float(*, *) = aten::neg(%1)
%36 : int = prim::Constant[value=1]()
%37 : Float(*, *) = aten::add(%35, %36, %36)
%38 : Float(*, *) = aten::mul(%24, %1)
%39 : Float(*, *) = aten::mul(%38, %37)
%40 : Float(*, *) = aten::neg(%0)
%41 : int = prim::Constant[value=1]()
%42 : Float(*, *) = aten::add(%40, %41, %41)
%43 : Float(*, *) = aten::mul(%22, %0)
%44 : Float(*, *) = aten::mul(%43, %42)
%45 : Float(*, *) = prim::FusedConcat[dim=1](%44, %39, %34, %29)
return (%45, %18);
}
with prim::FusionGroup_1 = graph(%0 : Float(*)
%1 : Float(*)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*, *)
%5 : Float(*, *)
%6 : Float(*)) {
%7 : int = prim::Constant[value=1]()
%8 : int = prim::Constant[value=1]()
%9 : int = prim::Constant[value=1]()
%10 : int = prim::Constant[value=1]()
%11 : Float(*, *) = aten::mul(%3, %10)
%12 : Float(*, *) = aten::mul(%11, %4)
%13 : Float(*, *) = aten::mul(%11, %2)
%14 : Float(*, *) = aten::mul(%11, %6)
%15 : Float(*, *) = aten::mul(%3, %5)
%16 : int = prim::Constant[value=1]()
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%14, %15, %17)
%19 : Float(*, *) = aten::mul(%3, %4)
%20 : Float(*, *) = aten::mul(%19, %2)
%21 : Float(*, *) = aten::mul(%11, %1)
%22 : Float(*, *) = aten::mul(%19, %0)
%23 : int = prim::Constant[value=1]()
%24 : Float(*, *) = aten::add(%21, %22, %23)
return (%24, %20, %18, %13, %12, %11);
}

@@ -1,60 +1,73 @@
graph(%x.1 : Float(*, *)
%hx.1 : Float(*, *)
%cx.1 : Float(*, *)
graph(%x : Float(*, *)
%hx : Float(*, *)
%cx : Float(*, *)
%w_ih : Float(*, *)
%w_hh : Float(*, *)
%alpha.1 : Float(*)
%beta_i.1 : Float(*)
%beta_h.1 : Float(*)
%alpha : Float(*)
%beta_i : Float(*)
%beta_h : Float(*)
%bias : Float(*)) {
%9 : Float(*, *) = aten::t(%w_ih)
%Wx.1 : Float(*, *) = aten::mm(%x.1, %9)
%11 : Float(*, *) = aten::t(%w_hh)
%Uz.1 : Float(*, *) = aten::mm(%hx.1, %11)
%13 : Float(*, *), %14 : Float(*, *) = prim::FusionGroup_0[device=0](%beta_h.1, %Uz.1, %beta_i.1, %Wx.1, %alpha.1)
%15 : Dynamic[] = prim::ListConstruct(%13, %bias)
%16 : Dynamic[] = aten::broadcast_tensors(%15)
%17 : Dynamic, %18 : Dynamic = prim::ListUnpack(%16)
%hy : Float(*, *), %20 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1[device=0](%cx.1, %17, %18)
return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %20);
%9 : Float(*, *), %10 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %x, %w_hh, %hx, %alpha, %beta_i, %beta_h, %bias, %cx)
return (%10, %9);
}
with prim::FusionGroup_0 = graph(%4 : Float(*)
%5 : Float(*, *)
%11 : Float(*)
%12 : Float(*, *)
%16 : Float(*)) {
%17 : Float(*, *) = aten::mul(%16, %12)
%15 : Float(*, *) = aten::mul(%17, %5)
%13 : Float(*, *) = aten::mul(%11, %12)
%9 : int = prim::Constant[value=1]()
%10 : Float(*, *) = aten::add(%15, %13, %9)
%6 : Float(*, *) = aten::mul(%4, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(*, *) = aten::add(%10, %6, %2)
return (%3, %17);
with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)
%2 : Float(*, *)
%3 : Float(*, *)
%4 : Float(*)
%5 : Float(*)
%6 : Float(*)
%7 : Float(*)
%8 : Float(*, *)) {
%9 : Float(*, *) = aten::t(%0)
%Wx.1 : Float(*, *) = aten::mm(%1, %9)
%11 : Float(*, *) = aten::t(%2)
%Uz.1 : Float(*, *) = aten::mm(%3, %11)
%13 : int = prim::Constant[value=1]()
%14 : Float(*, *), %15 : Float(*, *) = prim::FusionGroup_0[device=0](%6, %Uz.1, %5, %Wx.1, %4)
%16 : Dynamic[] = prim::ListConstruct(%14, %7)
%17 : Dynamic[] = aten::broadcast_tensors(%16)
%18 : Dynamic, %19 : Dynamic = prim::ListUnpack(%17)
%hy : Float(*, *), %21 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1[device=0](%8, %18, %19)
return (%cy, %hy, %9, %Wx.1, %11, %Uz.1, %15, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %21);
}
with prim::FusionGroup_1 = graph(%13 : Float(*, *)
%39 : Dynamic
%44 : Dynamic) {
%45 : Dynamic, %46 : Dynamic, %47 : Dynamic, %48 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%44)
%40 : Dynamic, %41 : Dynamic, %42 : Dynamic, %43 : Dynamic = prim::FusedChunk[chunks=4, dim=1](%39)
%37 : int = prim::Constant[value=1]()
%38 : Float(*, *) = aten::add(%40, %45, %37)
%33 : int = prim::Constant[value=1]()
%34 : Float(*, *) = aten::add(%41, %46, %33)
%29 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%42, %47, %29)
with prim::FusionGroup_0 = graph(%0 : Float(*)
%1 : Float(*, *)
%2 : Float(*)
%3 : Float(*, *)
%4 : Float(*)) {
%5 : Float(*, *) = aten::mul(%4, %3)
%6 : Float(*, *) = aten::mul(%5, %1)
%7 : Float(*, *) = aten::mul(%2, %3)
%8 : int = prim::Constant[value=1]()
%9 : Float(*, *) = aten::add(%6, %7, %8)
%10 : Float(*, *) = aten::mul(%0, %1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%9, %10, %11)
return (%12, %5);
}
with prim::FusionGroup_1 = graph(%0 : Float(*, *)
%1 : Dynamic
%2 : Dynamic) {
%3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%7, %3, %11)
%13 : int = prim::Constant[value=1]()
%14 : Float(*, *) = aten::add(%8, %4, %13)
%15 : int = prim::Constant[value=1]()
%16 : Float(*, *) = aten::add(%9, %5, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%10, %6, %17)
%ingate.1 : Float(*, *) = aten::sigmoid(%12)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
%cellgate.1 : Float(*, *) = aten::tanh(%16)
%outgate.1 : Float(*, *) = aten::sigmoid(%18)
%23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
%24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%25 : int = prim::Constant[value=1]()
%26 : Float(*, *) = aten::add(%43, %48, %25)
%ingate.1 : Float(*, *) = aten::sigmoid(%38)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%34)
%cellgate.1 : Float(*, *) = aten::tanh(%30)
%outgate.1 : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%forgetgate.1, %13)
%11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%7 : int = prim::Constant[value=1]()
%cy : Float(*, *) = aten::add(%14, %11, %7)
%4 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %4)
return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
%cy : Float(*, *) = aten::add(%23, %24, %25)
%27 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %27)
return (%hy, %27, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}

@@ -4,8 +4,8 @@ graph(%x : Float(*, *)) {
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)) {
%z : float = prim::Constant[value=3]()
%4 : int = prim::Constant[value=1]()
%y : Float(*, *) = aten::add(%0, %z, %4)
%2 : Float(*, *) = aten::mul(%0, %y)
return (%2);
%2 : int = prim::Constant[value=1]()
%y : Float(*, *) = aten::add(%0, %z, %2)
%4 : Float(*, *) = aten::mul(%0, %y)
return (%4);
}

@@ -152,12 +152,19 @@ def get_fn(file_name, script_path):


def get_execution_plan(graph_executor_state):
execution_plans = graph_executor_state.execution_plans.values()
execution_plans = list(graph_executor_state.execution_plans.values())
num_plans = len(execution_plans)
if num_plans != 1:
raise RuntimeError('This test assumes this GraphExecutor should '
'only have one execution plan, got: {}'.format(num_plans))
return list(execution_plans)[0]
return execution_plans[0]


def get_grad_executor(plan_state):
if len(list(plan_state.graph.nodes())) != 1:
raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
grad_executors = list(plan_state.code.grad_executors())
return grad_executors[0]


def backward_graph(script_module):

@@ -165,10 +172,8 @@ def backward_graph(script_module):
raise RuntimeError('Expected ScriptModule')
ge_state = script_module.get_debug_state()
fwd_plan = get_execution_plan(ge_state)
if fwd_plan.grad_executor is None:
raise RuntimeError('Error: tried to get grad_executor of function '
'that hasn\'t run backward yet.')
bwd_plan = get_execution_plan(fwd_plan.grad_executor)
grad_executor = get_grad_executor(fwd_plan)
bwd_plan = get_execution_plan(grad_executor.get_debug_state())
# Running JIT passes requires that we own the graph (with a shared_ptr).
# The debug state struct does not own its graph so we make a copy of it.
return bwd_plan.graph.copy()

@@ -2807,7 +2812,6 @@ a")
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm
@unittest.skip("Temporarily broken")
def test_lstm_fusion_cuda(self):
inputs = get_lstm_inputs('cuda', training=True)
module = self.checkScript(LSTMCellS, inputs)

@@ -2820,7 +2824,6 @@ a")
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm
@unittest.skip("Temporarily broken")
def test_milstm_fusion_cuda(self):
inputs = get_milstm_inputs('cuda', training=True)
module = self.checkScript(MiLSTMCell, inputs)

@@ -35,6 +35,10 @@ struct InputMetadata {
return device_;
}

at::Tensor zeros_like() const {
return at::zeros(shape_, at::TensorOptions(*type_, static_cast<int32_t>(device_)));
}

private:
const at::Type* type_ = nullptr;
at::DimVector shape_;

@@ -24,6 +24,7 @@ struct ArgumentInfo {
int device() const {
return device_;
}
// XXX: It is guaranteed that this will return false when called on non-tensor arguments
bool requires_grad() const {
return requires_grad_;
}

@@ -42,7 +43,8 @@ struct ArgumentInfo {
private:
unsigned is_tensor_ : 1;
unsigned defined_ : 1;
unsigned requires_grad_ : 6;
unsigned requires_grad_ : 1;
unsigned : 5;
unsigned dim_ : 8;
int device_ : 8; // NOTE: this needs to be signed because we use -1 to represent CPU
unsigned type_ : 8;

@@ -59,6 +61,9 @@ struct ArgumentSpec {
int32_t num_inputs = inputs.size();
for (int32_t i = 0; i < num_inputs; ++i) {
auto & arg = args[i];
// Initialize all fields to 0. This is convenient, because e.g.
// requires_grad() can be checked even on tensors.
std::memset(&arg, 0, sizeof(ArgumentInfo));
arg.is_tensor_ = static_cast<unsigned>(inputs[i].isTensor());
if (arg.is_tensor_) {
at::Tensor t = inputs[i].toTensor();

@@ -176,6 +176,12 @@ struct Attributes {

#undef CREATE_ACCESSOR

// Our Graphs are not very const-correct, so we need to allow returning
// non-const references too
GraphAttr::ValueType& g(Symbol name) {
return get<GraphAttr>(name);
}

// does not use CREATE_ACCESSOR because we need additional asserts
Derived* t_(Symbol name, TensorAttr::ConstructorType v) {
JIT_ASSERT(!v.defined() || !v.is_variable());

@@ -35,6 +35,7 @@ bool isDifferentiable(Node * n) {
"aten::neg(Tensor self) -> Tensor",
"aten::type_as(Tensor self, Tensor other) -> Tensor",
"aten::unsqueeze(Tensor self, int dim) -> Tensor",
"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
"aten::mm(Tensor self, Tensor mat2) -> Tensor",
"aten::lt(Tensor self, Tensor other) -> Tensor",
"aten::le(Tensor self, Tensor other) -> Tensor",

@@ -97,6 +98,9 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
return {grads.at(0), grads.at(0) * node->namedInput(attr::alpha), nullptr};

} else if (node->kind() == prim::AutogradAdd) {
return {grads.at(0), grads.at(0)};

} else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr};

@@ -136,29 +140,14 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
} else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};

} else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
return {grads.at(0) * node->namedInput(attr::beta),
grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
nullptr, nullptr};

} else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
SymbolicVariable dmat1, dmat2;
if (auto type = inputs.at(0).value()->type()->cast<CompleteTensorType>()) {
auto sizes = type->sizes(), strides = type->strides();
if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) {
dmat1 = inputs.at(1).mm(grads.at(0).t()).t();
} else {
dmat1 = grads.at(0).mm(inputs.at(1).t());
}
} else {
dmat1 = grads.at(0).mm(inputs.at(1).t());
}
if (auto type = inputs.at(1).value()->type()->cast<CompleteTensorType>()) {
auto sizes = type->sizes(), strides = type->strides();
if (strides.at(0) == 1 && strides.at(1) == sizes.at(0)) {
dmat2 = grads.at(0).t().mm(inputs.at(0)).t();
} else {
dmat2 = inputs.at(0).t().mm(grads.at(0));
}
} else {
dmat2 = inputs.at(0).t().mm(grads.at(0));
}
return {dmat1, dmat2};
return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))};

} else if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor")) {
const auto& input_sizes = inputs.at(0).sizes();

@@ -364,7 +353,9 @@ static ReverseDetails addReverseInline(Gradient& grad_desc,
JIT_ASSERT(grad_inputs.size() == node->inputs().size());
for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
if (!requires_grad(inputs[i])) continue;
JIT_ASSERT(grad_inputs[i]);
// NB: Not returning a gradient w.r.t. a value that requires grad is normal if the
// input is non-differentiable. This happens e.g. in the aten::type_as case.
if (!grad_inputs[i]) continue;
set_grad(inputs[i], grad_inputs[i]);
}
}

@@ -374,6 +365,11 @@
Value * input = inputs[i];
if (!requires_grad(input))
continue;
// NB: Not having a gradient defined w.r.t. an input to the graph which requires grad
// can happen and is not an error. It might have been used only in non-differentiable
// contexts (e.g. as second input to aten::type_as). In that case we simply ignore it
// as an output, because it won't ever produce any meaningful values.
if (grad_map.count(input) == 0) continue;
reverse_block->registerOutput(get_grad(input));
grad_desc.df_output_vjps.push_back(i);
}

@@ -557,7 +553,7 @@ Gradient differentiate(std::shared_ptr<Graph>& graph, const std::vector<bool>& r
// addReverseInline has to call gradientForNode if *any* of the outputs
// require grad, but it will emit vjps for *all* outputs. Use DCE to remove
// unnecessary nodes.
EliminateDeadCode(grad_desc.f);
EliminateDeadCode(rev_info.reverse_block);
// Fills in f, df, f_real_outputs, df_input_captures,
// modifies df_input_vjps (new vjps are added for temporaries)
lambdaLiftReverse(grad_desc, rev_info);

@@ -24,6 +24,7 @@
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/symbolic_variable.h"
#include "torch/csrc/jit/ivalue.h"
#include "torch/csrc/jit/custom_operator.h"

#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"

@@ -45,14 +46,34 @@ using tensor_list = std::vector<at::Tensor>;
using Variable = autograd::Variable;
using autograd::variable_list;

// this type is in ExecutionPlan to run its Gradient if it is
// specified. It has a list of inputs captured by ExecutionPlan that
// it concats with inputs to form the full set of inputs to graph.
// see struct Gradient for a description of how the derivative graph
// is constructed and what variables are captured.
struct ExecutionPlanAutogradFunction : public autograd::Function {
ExecutionPlanAutogradFunction(GraphExecutor graph, size_t capture_size)
: graph(std::move(graph)) {
struct ExecutionPlan {
ExecutionPlan() = default;
ExecutionPlan(std::shared_ptr<Graph> graph)
: code(graph)
, graph(std::move(graph)) {}

void run(Stack& stack) const {
return InterpreterState(code).runOneStage(stack);
}

operator bool() const {
return static_cast<bool>(graph);
}

ExecutionPlanState getDebugState() {
ExecutionPlanState state;
state.code = &code;
state.graph = graph.get();
return state;
}

Code code;
std::shared_ptr<Graph> graph;
};

struct DifferentiableGraphBackward : public autograd::Function {
DifferentiableGraphBackward(GraphExecutor executor, size_t capture_size)
: executor(std::move(executor)) {
is_var_capture.reserve(capture_size);
var_captures.reserve(capture_size);
ivalue_captures.reserve(capture_size);

@@ -74,10 +95,28 @@ struct ExecutionPlanAutogradFunction : public autograd::Function {
++ivalue_capture_it;
}
}
graph.run(stack);
return fmap(stack, [](IValue & val) {
return autograd::Variable(std::move(val).toTensor());
});

executor.run(stack);
JIT_ASSERT(stack.size() == num_outputs());

variable_list outputs;
outputs.reserve(num_outputs());
for (size_t i = 0; i < num_outputs(); ++i) {
if (should_compute_output(i)) {
auto output = std::move(stack[i]).toTensor();
const auto & edge = next_edge(i);
if (output.defined()) {
outputs.push_back(std::move(output));
} else if (edge.is_valid()) {
outputs.push_back(edge.function->input_metadata(edge.input_nr).zeros_like());
} else {
outputs.emplace_back();
}
} else {
outputs.emplace_back();
}
}
return outputs;
}

void capture(const IValue & val, bool is_output) {

@@ -91,7 +130,7 @@ struct ExecutionPlanAutogradFunction : public autograd::Function {
}
private:
friend struct ExecutionPlan;
GraphExecutor graph;
GraphExecutor executor;

// INVARIANT: is_var_capture.size() == var_captures.size() + ivalue_captures.size()
std::vector<bool> is_var_capture;

@@ -104,74 +143,17 @@ private:
// This will unwrap Variables, run the plan, and re-wrap them.
// It can optionally also have a gradient which is hooked up
// to the output Variables if present.
struct ExecutionPlan {
ExecutionPlan(std::shared_ptr<Graph>& graph)
: f(graph),
graph(graph),
num_inputs(graph->inputs().size()),
num_outputs(graph->outputs().size()) {}
ExecutionPlan(std::shared_ptr<Graph>& graph, Gradient grad)
: f(graph),
graph(graph),
struct DifferentiableGraphOp {
DifferentiableGraphOp(Gradient grad)
: f(grad.f),
grad(std::move(grad)),
grad_executor(this->grad.df),
num_inputs(graph->inputs().size()),
num_outputs(graph->outputs().size()) {}

void run(Stack & stack) const {
if (grad) {
return runWithGrad(stack);
}
InterpreterState(f).runOneStage(stack);
}

std::shared_ptr<Graph> get_graph() const {
return graph;
}

ExecutionPlanState getDebugState() {
ExecutionPlanState state;
state.f = &f;
state.graph = graph.get();
if (grad) {
state.grad = &grad;
state.grad_executor = std::unique_ptr<GraphExecutorState>(
new GraphExecutorState(grad_executor.getDebugState()));
} else {
state.grad = nullptr;
state.grad_executor.reset();
}
return state;
}

private:
void detachVariables(Stack & stack) const {
// It would be nice to use an ArrayRef here, but unfortunately those can only
// return const references, so we need to do a bunch of indexing ourselves.
const int64_t stack_size = stack.size();
const int64_t stack_offset = stack_size - num_inputs;
for (int64_t i = stack_offset; i < stack_size; ++i) {
auto & v = stack[i];
if (!v.isTensor()) continue;
auto t = std::move(v).toTensor();
v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() : std::move(t)};
}
}
// Capture (save) inputs that would be required to subsequently run backwards
void captureInputs(ExecutionPlanAutogradFunction & grad_fn, at::ArrayRef<IValue> inputs) const {
for (size_t offset : grad.df_input_captured_inputs) {
grad_fn.capture(inputs[offset], /*is_output*/false);
}
}
void captureOutputs(ExecutionPlanAutogradFunction & grad_fn, at::ArrayRef<IValue> outputs) const {
for (size_t offset : grad.df_input_captured_outputs) {
grad_fn.capture(outputs[offset], /*is_output*/true);
}
}
num_inputs(this->grad.f->inputs().size()),
num_outputs(this->grad.f->outputs().size()) {}

// XXX: keep in mind that stack can be larger than the inputs we need!
void runWithGrad(Stack & stack) const {
auto grad_fn = std::make_shared<ExecutionPlanAutogradFunction>(grad_executor,
int operator()(Stack & stack) const {
auto grad_fn = std::make_shared<DifferentiableGraphBackward>(grad_executor,
grad.df_input_captured_inputs.size() + grad.df_input_captured_outputs.size());

{

@@ -179,7 +161,7 @@ private:
// hook up the outputs of df to the gradient functions of the inputs that require gradients
for(auto idx : grad.df_output_vjps) {
auto v = Variable(inputs[idx].toTensor());
grad_fn->add_next_edge(v.gradient_edge());
grad_fn->add_next_edge(v.defined() ? v.gradient_edge() : autograd::Edge{});
}
captureInputs(*grad_fn, inputs);
}

@@ -201,8 +183,15 @@ private:
// reallocate variables that were already created in wrapTensors. We
// should add an API for this.
Variable output = outputs[idx].toTensor();
// NB: since our requires_grad setting is only a heuristic we might end up
// wanting to differentiate through integral tensors, which is generally a
// hard error in autograd.
if (at::isFloatingType(output.type().scalarType())) {
autograd::create_gradient_edge(output, grad_fn);
output.set_requires_grad(true);
} else {
grad_fn->add_input_metadata(autograd::Function::undefined_input{});
}
}
captureOutputs(*grad_fn, outputs);
// drop the temporary outputs so that we return the same number of
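The branch above only attaches a gradient edge when an output is floating point; integral outputs still get a placeholder entry on the backward function so that output indices stay aligned. A standalone sketch of that selection logic, using a plain enum in place of at::ScalarType and a counter struct in place of the real grad_fn (all names here are hypothetical):

#include <iostream>
#include <vector>

enum class ScalarType { Float, Double, Int, Long };

bool isFloatingType(ScalarType t) {
  return t == ScalarType::Float || t == ScalarType::Double;
}

// Stand-in for the backward node: it records one slot per forward output,
// but only floating-point outputs get a "real" gradient edge.
struct BackwardStub {
  std::vector<bool> has_gradient_edge;
  void addRealEdge() { has_gradient_edge.push_back(true); }
  void addUndefinedInput() { has_gradient_edge.push_back(false); }
};

int main() {
  std::vector<ScalarType> output_types = {ScalarType::Float, ScalarType::Long};
  BackwardStub grad_fn;
  for (ScalarType t : output_types) {
    if (isFloatingType(t)) {
      grad_fn.addRealEdge();       // analogue of create_gradient_edge + set_requires_grad(true)
    } else {
      grad_fn.addUndefinedInput(); // analogue of add_input_metadata(undefined_input{})
    }
  }
  std::cout << grad_fn.has_gradient_edge.size() << " slots recorded\n"; // prints 2
}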
@ -210,22 +199,89 @@ private:
|
|||
const size_t num_temporary_outputs = num_outputs - grad.f_real_outputs;
|
||||
stack.erase(stack.end() - num_temporary_outputs, stack.end());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
friend GraphExecutor* detail::getGradExecutor(Operation& op);
|
||||
|
||||
void detachVariables(Stack & stack) const {
|
||||
// It would be nice to use an ArrayRef here, but unfortunately those can only
|
||||
// return const references, so we need to do a bunch of indexing ourselves.
|
||||
const int64_t stack_size = stack.size();
|
||||
const int64_t stack_offset = stack_size - num_inputs;
|
||||
for (int64_t i = stack_offset; i < stack_size; ++i) {
|
||||
auto & v = stack[i];
|
||||
if (!v.isTensor()) continue;
|
||||
auto t = std::move(v).toTensor();
|
||||
v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() : std::move(t)};
|
||||
}
|
||||
}
|
||||
// Capture (save) inputs that would be required to subsequently run backwards
|
||||
void captureInputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef<IValue> inputs) const {
|
||||
for (size_t offset : grad.df_input_captured_inputs) {
|
||||
grad_fn.capture(inputs[offset], /*is_output*/false);
|
||||
}
|
||||
}
|
||||
void captureOutputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef<IValue> outputs) const {
|
||||
for (size_t offset : grad.df_input_captured_outputs) {
|
||||
grad_fn.capture(outputs[offset], /*is_output*/true);
|
||||
}
|
||||
}
|
||||
|
||||
Code f;
|
||||
// optimized graph for debugging and testing
|
||||
std::shared_ptr<Graph> graph;
|
||||
// description of gradient as a graph
|
||||
Gradient grad; // if(grad) is false when this is unused
|
||||
// executor for df, including code caches
|
||||
Gradient grad;
|
||||
GraphExecutor grad_executor;
|
||||
|
||||
const size_t num_inputs;
|
||||
const size_t num_outputs;
|
||||
};

void packGradient(Gradient gradient, Node *dnode) {
JIT_ASSERT(dnode->kind() == prim::DifferentiableGraph);
dnode->g_(attr::Subgraph, gradient.f)
->g_(attr::ReverseSubgraph, gradient.df)
->i_(attr::f_real_outputs, gradient.f_real_outputs)
->is_(attr::df_input_vjps, fmap<int64_t>(gradient.df_input_vjps))
->is_(attr::df_input_captured_inputs, fmap<int64_t>(gradient.df_input_captured_inputs))
->is_(attr::df_input_captured_outputs, fmap<int64_t>(gradient.df_input_captured_outputs))
->is_(attr::df_output_vjps, fmap<int64_t>(gradient.df_output_vjps));
}

Gradient getGradient(Node *n) {
JIT_ASSERT(n->kind() == prim::DifferentiableGraph);
Gradient grad;
grad.f = n->g(attr::Subgraph);
grad.df = n->g(attr::ReverseSubgraph);
grad.f_real_outputs = n->i(attr::f_real_outputs);
grad.df_input_vjps = fmap<size_t>(n->is(attr::df_input_vjps));
grad.df_input_captured_inputs = fmap<size_t>(n->is(attr::df_input_captured_inputs));
grad.df_input_captured_outputs = fmap<size_t>(n->is(attr::df_input_captured_outputs));
grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
return grad;
}

} // anonymous namespace

RegisterOperators reg_graph_executor_ops({
Operator(
prim::DifferentiableGraph,
[](Node *n) -> Operation {
return DifferentiableGraphOp(getGradient(n));
})
});

namespace detail {

GraphExecutor* getGradExecutor(Operation& op) {
if (auto diff_op = op.target<DifferentiableGraphOp>()) {
return &diff_op->grad_executor;
}
return nullptr;
}

} // namespace detail

// a Graph can be created via tracing, or via a language-based frontend
// GraphExecutor runs it. It can run the same graph on many different sizes
// and different requires_grad states, and handles specializations for each situation.
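packGradient/getGradient above serialize the Gradient description into subgraph and integer-list attributes on the DifferentiableGraph node and read them back when the operator is instantiated. The sketch below shows the same round-trip idea in self-contained form, using plain map-backed "attributes" instead of the JIT Node attribute API; every name in it is a stand-in, not the real interface.

#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Stand-in for the pieces of Gradient that travel through node attributes.
struct GradientInfo {
  int64_t f_real_outputs;
  std::vector<int64_t> df_input_vjps;
  std::vector<int64_t> df_output_vjps;
};

// Stand-in for a node's attribute storage (the real code uses i_/is_/g_ setters).
struct FakeNode {
  std::map<std::string, int64_t> ints;
  std::map<std::string, std::vector<int64_t>> int_lists;
};

void pack(const GradientInfo& g, FakeNode& n) {
  n.ints["f_real_outputs"] = g.f_real_outputs;
  n.int_lists["df_input_vjps"] = g.df_input_vjps;
  n.int_lists["df_output_vjps"] = g.df_output_vjps;
}

GradientInfo unpack(const FakeNode& n) {
  GradientInfo g;
  g.f_real_outputs = n.ints.at("f_real_outputs");
  g.df_input_vjps = n.int_lists.at("df_input_vjps");
  g.df_output_vjps = n.int_lists.at("df_output_vjps");
  return g;
}

int main() {
  GradientInfo g{2, {0, 1}, {1}};
  FakeNode n;
  pack(g, n);
  GradientInfo back = unpack(n);
  assert(back.f_real_outputs == 2 && back.df_input_vjps.size() == 2);
}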
@ -233,68 +289,49 @@ private:
|
|||
// tracing concerns separated.
|
||||
struct GraphExecutorImpl {
|
||||
|
||||
GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize, bool symbolically_differentiable)
|
||||
: graph(std::move(graph))
|
||||
static std::shared_ptr<Graph> prepareGraph(std::shared_ptr<Graph> graph) {
|
||||
auto copy = graph->copy();
|
||||
EraseShapeInformation(*copy);
|
||||
return copy;
|
||||
}
|
||||
|
||||
GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
|
||||
: graph(prepareGraph(graph))
|
||||
, optimize(optimize)
|
||||
, num_inputs(this->graph->inputs().size())
|
||||
, num_outputs(this->graph->outputs().size())
|
||||
, symbolically_differentiable(symbolically_differentiable)
|
||||
, may_introduce_gradient(calcMayIntroduceGradient(this->graph->block())) {}
|
||||
GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
|
||||
: GraphExecutorImpl(graph, optimize, isDifferentiable(*graph)) {}
|
||||
, num_outputs(this->graph->outputs().size()) {}
|
||||
|
||||
// entry point where execution begins
|
||||
void run(Stack & stack) {
|
||||
if(stack.size() < num_inputs) {
|
||||
std::stringstream ss;
|
||||
ss << "expected " << num_inputs << " inputs but got " << stack.size() << " inputs";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
auto inputs = last(stack, num_inputs);
|
||||
AT_CHECK(stack.size() >= num_inputs, "expected ", num_inputs, " inputs, but got only ", stack.size());
|
||||
|
||||
// the tracer has called a graph executor
|
||||
// there is no need to optimize, but we do need to splice the graph of
|
||||
// this executor into the trace. Otherwise we might unroll control-flow
|
||||
// operations.
|
||||
if(tracer::isTracing()) {
|
||||
return runTraced(stack);
|
||||
}
|
||||
|
||||
// this is the fallback pathway, when we cannot differentiate
|
||||
if(!optimize || (!symbolically_differentiable && needsGradient(inputs))) {
|
||||
return runFallback(stack);
|
||||
}
|
||||
|
||||
// either we can symbolically differentiate, or we do not need a gradient.
|
||||
// go down the route where we treat the inputs as tensors
|
||||
// and fully optimize
|
||||
auto & implementation = getOrCompile(inputs);
|
||||
return implementation.run(stack);
|
||||
auto & execution_plan = optimize ? getOrCompile(stack) : getOrCompileFallback();
|
||||
return execution_plan.run(stack);
|
||||
}
|
||||
|
||||
std::shared_ptr<Graph> graphFor(const Stack& stack) const {
|
||||
auto inputs = last(stack, num_inputs);
|
||||
ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs);
|
||||
|
||||
if (!optimize || (!symbolically_differentiable && needsGradient(inputs))) {
|
||||
JIT_ASSERTM(autograd_fallback_graph, "No graph found for given inputs");
|
||||
return autograd_fallback_graph;
|
||||
if (!optimize) {
|
||||
AT_CHECK(fallback, "No graph found for given inputs");
|
||||
return fallback.graph;
|
||||
}
|
||||
|
||||
auto it = plan_cache.find(spec);
|
||||
JIT_ASSERTM(it != plan_cache.end(), "No graph found for given inputs");
|
||||
return it->second.get_graph();
|
||||
AT_CHECK(it != plan_cache.end(), "No graph found for given inputs");
|
||||
return it->second.graph;
|
||||
}
|
||||
|
||||
GraphExecutorState getDebugState() {
|
||||
GraphExecutorState state;
|
||||
state.graph = graph.get();
|
||||
if (autograd_fallback) {
|
||||
state.autograd_fallback = &autograd_fallback;
|
||||
state.autograd_fallback_graph = autograd_fallback_graph.get();
|
||||
} else {
|
||||
state.autograd_fallback = nullptr;
|
||||
state.autograd_fallback_graph = nullptr;
|
||||
if (fallback) {
|
||||
state.fallback = fallback.getDebugState();
|
||||
}
|
||||
for (auto & entry : plan_cache) {
|
||||
state.execution_plans.emplace(entry.first, entry.second.getDebugState());
|
||||
|
|
@ -305,6 +342,121 @@ struct GraphExecutorImpl {
|
|||
private:
|
||||
friend struct GraphExecutor;
|
||||
|
||||
const ExecutionPlan & getOrCompileFallback() {
|
||||
std::lock_guard<std::mutex> lock(compile_mutex);
|
||||
if(!fallback) {
|
||||
auto graph_ = graph->copy();
|
||||
runRequiredPasses(graph_);
|
||||
fallback = ExecutionPlan(graph_);
|
||||
}
|
||||
return fallback;
|
||||
}
|
||||
|
||||
const ExecutionPlan & getOrCompile(const Stack& stack) {
|
||||
// outside lock guard, to minimize the time holding the lock on the fast path
|
||||
// ArgumentSpec even computes its hashCode here.
|
||||
ArgumentSpec spec(autograd::GradMode::is_enabled(), last(stack, num_inputs));
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(compile_mutex);
|
||||
auto it = plan_cache.find(spec);
|
||||
if (it != plan_cache.end())
|
||||
return it->second;
|
||||
auto plan = compileSpec(spec);
|
||||
auto r = plan_cache.emplace(std::move(spec), std::move(plan));
|
||||
return r.first->second;
|
||||
}
|
||||
}
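getOrCompile above hashes the ArgumentSpec outside the lock and then performs a single locked lookup-or-compile. A standalone sketch of that caching pattern with a toy spec type follows; the names are hypothetical, and the real cache is keyed by ArgumentSpec and stores ExecutionPlans.

#include <cstddef>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>

struct Spec {
  bool grad_enabled;
  int num_dims;
  bool operator==(const Spec& o) const {
    return grad_enabled == o.grad_enabled && num_dims == o.num_dims;
  }
};

struct SpecHash {
  size_t operator()(const Spec& s) const {
    return std::hash<int>()(s.num_dims) * 2 + (s.grad_enabled ? 1 : 0);
  }
};

struct Plan { std::string description; };  // stand-in for ExecutionPlan

class PlanCache {
 public:
  // Returns the cached plan for `spec`, compiling it under the lock if needed.
  const Plan& getOrCompile(const Spec& spec) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(spec);
    if (it != cache_.end())
      return it->second;
    auto r = cache_.emplace(spec, compile(spec));
    return r.first->second;
  }

 private:
  Plan compile(const Spec& spec) {
    return Plan{spec.grad_enabled ? "differentiated plan" : "plain plan"};
  }
  std::mutex mutex_;
  std::unordered_map<Spec, Plan, SpecHash> cache_;
};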
|
||||
|
||||
ExecutionPlan compileSpec(const ArgumentSpec & spec) {
|
||||
auto opt_graph = graph->copy();
|
||||
|
||||
// Phase 1. Specialize to input definedness (this is very important for
|
||||
// gradient graphs), and run required passes to bring the graph
|
||||
// to an executable form.
|
||||
specializeGrad(opt_graph, spec);
|
||||
runRequiredPasses(opt_graph);
|
||||
|
||||
// Phase 2. Propagate detailed information about the spec through the
|
||||
// graph (this enables more specializations in later passes).
|
||||
PropagateInputShapes(*opt_graph, spec);
|
||||
|
||||
// Phase 3. Run differentiable optimizations (i.e. simple graph rewrites that
|
||||
// we can still execute using autograd).
|
||||
runOptimization(opt_graph, spec);
|
||||
|
||||
// Phase 4. If this graph will be differentiated, we need to slice out the
|
||||
// symbolically differentiable subgraphs for further optimizations.
|
||||
// Phase 5. Apply non-differentiable optimizations to the graphs we've found
|
||||
// (or the whole graph if we know we won't need its derivative).
|
||||
if (needsGradient(opt_graph, spec)) {
|
||||
auto diff_nodes = CreateAutodiffSubgraphs(*opt_graph);
|
||||
for (Node * dnode : diff_nodes) {
|
||||
// XXX: we don't have requires_grad information on the intermediate values,
|
||||
// so we conservatively assume it's always true (on tensor inputs).
|
||||
auto diff_graph = std::move(dnode->g(attr::Subgraph));
|
||||
auto requires_grads = fmap(diff_graph->inputs(), [](Value* v) {
|
||||
// NB: only floating-point inputs can have requires_grad=True. If we
|
||||
// don't have type information, we have to assume that it's true.
|
||||
if (auto tensor_type = v->type()->cast<TensorType>()) {
|
||||
return at::isFloatingType(tensor_type->scalarType());
|
||||
}
|
||||
return v->type()->isSubtypeOf(DynamicType::get());
|
||||
});
|
||||
Gradient gradient = differentiate(diff_graph, requires_grads);
|
||||
runNondiffOptimization(gradient.f);
|
||||
packGradient(gradient, dnode);
|
||||
}
|
||||
} else {
|
||||
runNondiffOptimization(opt_graph);
|
||||
}
|
||||
return ExecutionPlan(opt_graph);
|
||||
}
|
||||
|
||||
void specializeGrad(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
|
||||
std::vector<bool> defined;
|
||||
for (size_t i = 0; i < spec.size(); ++i)
|
||||
defined.push_back(spec.at(i).defined());
|
||||
specializeUndef(*graph, defined);
|
||||
}
|
||||
|
||||
void runOptimization(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
|
||||
EliminateDeadCode(graph);
|
||||
EliminateCommonSubexpression(graph);
|
||||
UnrollLoops(graph);
|
||||
ConstantPropagation(graph);
|
||||
PeepholeOptimize(graph);
|
||||
CheckInplace(graph);
|
||||
BatchMM(graph);
|
||||
}
|
||||
|
||||
void runNondiffOptimization(std::shared_ptr<Graph>& graph) {
|
||||
FuseGraph(graph);
|
||||
}
|
||||
|
||||
static bool needsGradient(const std::shared_ptr<const Graph>& graph, const ArgumentSpec& spec) {
|
||||
if (!autograd::GradMode::is_enabled())
|
||||
return false;
|
||||
if (mayIntroduceGradient(graph->block()))
|
||||
return true;
|
||||
for(size_t i = 0; i < spec.size(); ++i) {
|
||||
if(spec.at(i).requires_grad())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool mayIntroduceGradient(const Block* b) {
|
||||
for (const Node* n : b->nodes()) {
|
||||
if (n->kind() == prim::PythonOp)
|
||||
return true;
|
||||
for (const Block* bb : n->blocks()) {
|
||||
if (mayIntroduceGradient(bb))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
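needsGradient and mayIntroduceGradient above combine three signals: the global grad mode, whether any nested block contains an op (such as a PythonOp) that can produce requires_grad tensors out of thin air, and whether any input requires grad. A self-contained sketch of the same decision with toy node/block types (names are illustrative only):

#include <vector>

struct ToyNode;
struct ToyBlock { std::vector<ToyNode> nodes; };
struct ToyNode {
  bool is_python_op = false;
  std::vector<ToyBlock> blocks;  // nested control-flow blocks
};

// Recursively checks nested blocks for ops that can introduce gradients.
bool mayIntroduceGradient(const ToyBlock& b) {
  for (const ToyNode& n : b.nodes) {
    if (n.is_python_op)
      return true;
    for (const ToyBlock& bb : n.blocks)
      if (mayIntroduceGradient(bb))
        return true;
  }
  return false;
}

bool needsGradient(bool grad_mode_enabled,
                   const ToyBlock& top_block,
                   const std::vector<bool>& input_requires_grad) {
  if (!grad_mode_enabled)
    return false;
  if (mayIntroduceGradient(top_block))
    return true;
  for (bool rg : input_requires_grad)
    if (rg)
      return true;
  return false;
}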
|
||||
|
||||
void runTraced(Stack & stack) {
|
||||
auto state = tracer::getTracingState();
|
||||
auto inputs = last(stack, num_inputs);
|
||||
|
|
@ -313,25 +465,18 @@ private:
|
|||
});
|
||||
|
||||
ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs);
|
||||
runFallback(stack);
|
||||
// NB: we could just run the fallback in here and call it a day, but that would lose all
|
||||
// the control flow information we have in the graph. Thus, we run the fallback to
|
||||
// get the correct output values, but we will override the tracing states later.
|
||||
getOrCompileFallback().run(stack);
|
||||
|
||||
auto all_dynamic = [](const at::ArrayRef<Value*> xs) {
|
||||
for(Value* x : xs) {
|
||||
if(x->type()->kind() != TypeKind::DynamicType)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
// Traces always have types propagated through them, so we make sure to
|
||||
// also propagate types through the graph we are inserting here.
|
||||
// However, this->graph itself may already have been generated with
|
||||
// tracing and so we only do the type propagation if no concrete types have
|
||||
// been set.
|
||||
auto local_graph = this->graph;
|
||||
if(all_dynamic(local_graph->inputs()) && all_dynamic(local_graph->outputs())) {
|
||||
local_graph = this->graph->copy();
|
||||
auto local_graph = this->graph->copy();
|
||||
PropagateInputShapes(*local_graph, spec);
|
||||
}
|
||||
auto output_values = script::inlineCallTo(*state->graph, *local_graph, input_values);
|
||||
|
||||
auto outputs = last(stack, num_outputs);
|
||||
|
|
@ -343,147 +488,32 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
void runFallback(Stack & stack) {
|
||||
auto & fb = getOrCreateAutogradFallback();
|
||||
InterpreterState(fb).runOneStage(stack);
|
||||
}
|
||||
|
||||
static bool calcMayIntroduceGradient(Block* b) {
|
||||
for(Node* n : b->nodes()) {
|
||||
if(n->kind() == prim::PythonOp)
|
||||
return true;
|
||||
for(Block* bb : n->blocks()) {
|
||||
if(calcMayIntroduceGradient(bb))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool needsGradient(at::ArrayRef<IValue> inputs) const {
|
||||
if (!autograd::GradMode::is_enabled()) {
|
||||
return false;
|
||||
}
|
||||
if (may_introduce_gradient)
|
||||
return true;
|
||||
for (const IValue & value : inputs) {
|
||||
if (!value.isTensor()) continue;
|
||||
auto t = value.toTensor();
|
||||
if (t.defined() && autograd::as_variable_ref(t).requires_grad())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const Code & getOrCreateAutogradFallback() {
|
||||
std::lock_guard<std::mutex> lock(compile_mutex);
|
||||
if(autograd_fallback) {
|
||||
return autograd_fallback;
|
||||
}
|
||||
auto graph_ = graph->copy();
|
||||
runRequiredPasses(graph_);
|
||||
if(optimize) {
|
||||
runOptimization(graph_, /*graphMustSupportVariables=*/true);
|
||||
if(!isDifferentiable(*graph_)) {
|
||||
EraseShapeInformation(*graph_);
|
||||
CreateAutodiffSubgraphs(*graph_);
|
||||
}
|
||||
}
|
||||
autograd_fallback_graph = graph_;
|
||||
autograd_fallback = Code(graph_);
|
||||
return autograd_fallback;
|
||||
}
|
||||
const ExecutionPlan & getOrCompile(at::ArrayRef<IValue> inputs) {
|
||||
// outside lock guard, to minimize the time holding the lock on the fast path
|
||||
// ArgumentSpec even computes its hashCode here.
|
||||
ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(compile_mutex);
|
||||
auto it = plan_cache.find(spec);
|
||||
if(it != plan_cache.end())
|
||||
return it->second;
|
||||
auto plan = compileSpec(spec);
|
||||
auto r = plan_cache.emplace(std::move(spec), std::move(plan));
|
||||
return r.first->second;
|
||||
}
|
||||
}
|
||||
|
||||
bool argumentSpecRequiresGradient(const ArgumentSpec & spec) {
|
||||
for(size_t i = 0; i < spec.size(); ++i) {
|
||||
if(spec.at(i).requires_grad())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ExecutionPlan compileSpec(const ArgumentSpec & spec) {
|
||||
auto graph_ = graph->copy();
|
||||
|
||||
specializeToSpec(graph_, spec);
|
||||
|
||||
if(!argumentSpecRequiresGradient(spec)) {
|
||||
runOptimization(graph_, /*graphMustSupportVariables=*/false);
|
||||
return ExecutionPlan(graph_);
|
||||
}
|
||||
JIT_ASSERT(symbolically_differentiable);
|
||||
|
||||
std::vector<bool> requires_grads;
|
||||
requires_grads.reserve(spec.size());
|
||||
for(size_t i = 0; i < spec.size(); i++)
|
||||
requires_grads.push_back(spec.at(i).requires_grad());
|
||||
|
||||
Gradient gradient = differentiate(graph_, requires_grads);
|
||||
graph_ = gradient.f;
|
||||
runOptimization(graph_, /*graphMustSupportVariables=*/false);
|
||||
return ExecutionPlan(graph_, std::move(gradient));
|
||||
}
|
||||
// the unoptimized starting graph
|
||||
// this is never mutated
|
||||
// The unoptimized starting graph. This field is effectively const, but we can't make it so
|
||||
// because Graph::copy() is not const (and making it const is not that easy at this point).
|
||||
std::shared_ptr<Graph> graph;
|
||||
|
||||
// true - do everything we can to make this graph run fast
|
||||
// false - do not modifiy the graph at all and just use the interpreter
|
||||
// to run the graph. Useful for debugging correctness issues in the implementation
|
||||
// If false, we'll run the graph as we get it, without any optimizations. Useful
|
||||
// for debugging.
|
||||
const bool optimize;
|
||||
const size_t num_inputs;
|
||||
const size_t num_outputs;
|
||||
|
||||
// GraphExecutor optimizes more aggresively when we _know_ the graph will be
|
||||
// symbolically differentiable.
|
||||
bool symbolically_differentiable;
|
||||
// Populated only when optimize is false (and in that case plan_cache will be unused).
|
||||
// The compiled version of graph.
|
||||
ExecutionPlan fallback;
|
||||
|
||||
// some ops, including python operations, can intorduce requires_grad=True
|
||||
// variables even though no inputs to this graph are availiable, if
|
||||
// the graph includes those operators then needGradient must be true
|
||||
// regardles of input state.
|
||||
bool may_introduce_gradient;
|
||||
|
||||
// when this graph has some parts that are not symbolically_differentable,
|
||||
// but some input does require a derivative, we create and use autograd_fallback,
|
||||
// which wraps up the fully differentiable subgraphs, and then runs the outer
|
||||
// graph through autograd.
|
||||
// Since we can't optimize black box functions anyway, there is only one fallback path,
|
||||
// and it must work on all sizes (so no optimizations that inspect sizes can run on it)
|
||||
std::shared_ptr<Graph> autograd_fallback_graph;
|
||||
Code autograd_fallback;
|
||||
|
||||
// optimizable code paths, used when we can differentiate or when no derivative is needed
|
||||
// Spec describes input conditions, Plan describes how to execute them.
|
||||
// Mapping from argument configurations to optimized versions of the graph that are
|
||||
// specialized to the spec.
|
||||
std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;
|
||||
|
||||
// GraphExecutor can be accessed from multiple thread so
|
||||
// anytime we are checking or updating the autograd_fallback or
|
||||
// plan_cache, we must hold the compile mutex.
|
||||
// along the fast path (no compilation) code should
|
||||
// hold this for as little time as possible.
|
||||
// GraphExecutors can be accessed from multiple threads, so this mutex needs to be
|
||||
// held every time we access the fallback or plan_cache.
|
||||
std::mutex compile_mutex;
|
||||
};
|
||||
|
||||
GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
|
||||
: pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
|
||||
|
||||
GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize, bool symbolically_differentiable)
|
||||
: pImpl(new GraphExecutorImpl(std::move(graph), optimize, symbolically_differentiable)) {}
|
||||
|
||||
void GraphExecutor::run(Stack & inputs) {
|
||||
return pImpl->run(inputs);
|
||||
}
|
||||
|
|
@ -509,49 +539,7 @@ void runRequiredPasses(const std::shared_ptr<Graph>& g) {
|
|||
// add valid expand nodes when the shapes are stable
|
||||
RemoveExpands(g);
|
||||
CanonicalizeOps(g);
|
||||
}
|
||||
|
||||
void specializeToSpec(const std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
|
||||
// clean up GradOf and AutogradAdd nodes
|
||||
// this must be first because later passes do not know what GradOfs are
|
||||
std::vector<bool> defined;
|
||||
for(size_t i = 0; i < spec.size(); ++i) {
|
||||
defined.push_back(spec.at(i).defined());
|
||||
}
|
||||
specializeUndef(*graph, defined);
|
||||
|
||||
// required passes shared with autograd fallback
|
||||
runRequiredPasses(graph);
|
||||
|
||||
// clean up dead constants from specialization
|
||||
EliminateDeadCode(graph);
|
||||
// calculate all input shapes
|
||||
PropagateInputShapes(*graph, spec);
|
||||
}
|
||||
|
||||
void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables) {
|
||||
|
||||
// these optimizations must run in the presence of variables
|
||||
// and when shape information is not statically known.
|
||||
EliminateDeadCode(graph);
|
||||
CheckInplace(graph);
|
||||
EliminateCommonSubexpression(graph);
|
||||
|
||||
if (!graphMustSupportVariables) {
|
||||
// These optimizations can introduce operators like FusionGroup that
|
||||
// do not work on variables
|
||||
|
||||
// They also may assume that concrete sizes/strides are availiable
|
||||
UnrollLoops(graph);
|
||||
ConstantPropagation(graph);
|
||||
//TODO: create peephole optimizations that are safe to run
|
||||
// when we are using variables, and when we do not know sizes.
|
||||
PeepholeOptimize(graph);
|
||||
// TODO: remove mandatory size checking in BatchMM, otherwise
|
||||
// it works fine on variables.
|
||||
BatchMM(graph);
|
||||
FuseGraph(graph);
|
||||
}
|
||||
EliminateDeadCode(g);
|
||||
}
|
||||
|
||||
}}
|
||||
|
|
|
|||
|
|
@@ -15,29 +15,20 @@ struct GraphExecutorState;
// They are only valid right after you call getDebugState() and should never
// be used again once another GraphExecutor function is called.
struct ExecutionPlanState {
Code* f;
Graph* graph;

// Those two fields are optional
Gradient* grad;
std::shared_ptr<GraphExecutorState> grad_executor; // shared_ptr to break the cycle...
Code* code = nullptr;
const Graph* graph = nullptr;
};

struct GraphExecutorState {
Graph* graph;
const Graph* graph;
ExecutionPlanState fallback; // XXX: members of this field are optional
std::unordered_map<ArgumentSpec, ExecutionPlanState> execution_plans;

// Those two fields are optional
Code* autograd_fallback;
Graph* autograd_fallback_graph;
};

struct GraphExecutorImpl;
struct TORCH_API GraphExecutor {
GraphExecutor() = default;
GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
// note: if not specified, symbolically_differentiable is computed from the graph.
GraphExecutor(std::shared_ptr<Graph> graph, bool optimize, bool symbolically_differentiable);
void run(Stack & inputs);
explicit operator bool() const {
return pImpl != nullptr;
@@ -53,15 +44,11 @@ private:
// regardless of whether sizes have been specialized or not.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);

// specialize 'graph' to the types, sizes, and other properties described in spec
// this prepares the graph for execution, including running runRequiredPasses,
// but the execution only remains valid for tensors whose properties match spec
// otherwise running the graph will have undefined results.
TORCH_API void specializeToSpec(const std::shared_ptr<Graph>& graph, const ArgumentSpec& spec);
namespace detail {

GraphExecutor* getGradExecutor(Operation& op);

} // namespace detail

// apply standard optimizations. if graphMustSupportVariables=false then
// then the passes are allowed to modify the graph in ways that make it no longer
// work with tensors that have requires_grad=True
TORCH_API void runOptimization(std::shared_ptr<Graph> & graph, bool graphMustSupportVariables);

}}
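GraphExecutorState and ExecutionPlanState above are non-owning snapshots: they hold raw pointers into the executor's internals and are only safe to read until the executor is used again. A minimal sketch of that pattern with made-up names, showing why the snapshot must be consumed immediately:

#include <memory>
#include <string>

// Non-owning snapshot of an executor's internals; the pointer is valid only
// until the executor mutates its state again.
struct DebugState {
  const std::string* optimized_graph = nullptr;
};

class ToyExecutor {
 public:
  DebugState getDebugState() const {
    DebugState s;
    s.optimized_graph = graph_.get();
    return s;
  }
  // Recompiling replaces the underlying object, so previously handed-out
  // DebugState pointers dangle after this call.
  void recompile() { graph_ = std::make_unique<std::string>("graph(%0) { ... }"); }
 private:
  std::unique_ptr<std::string> graph_ = std::make_unique<std::string>("graph(%0) { return (%0); }");
};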
@@ -129,8 +129,8 @@ void initJITBindings(PyObject *module) {
});
py::class_<ArgumentSpec>(m, "ArgumentSpec");
py::class_<Code>(m, "Code")
.def("executors", [](Code& c) {
return py::make_iterator(c.executors().begin(), c.executors().end());
.def("grad_executors", [](Code& c) {
return py::make_iterator(c.grad_executors().begin(), c.grad_executors().end());
});

py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
@@ -138,10 +138,7 @@ void initJITBindings(PyObject *module) {
return s.graph;
})
.def_property_readonly("code", [](ExecutionPlanState& s) {
return s.f;
})
.def_property_readonly("grad_executor", [](ExecutionPlanState& s) {
return s.grad_executor.get();
return s.code;
});

py::class_<Gradient>(m, "Gradient")
@@ -174,11 +171,8 @@ void initJITBindings(PyObject *module) {
.def_property_readonly("execution_plans", [](GraphExecutorState& s) {
return s.execution_plans;
})
.def_property_readonly("autograd_fallback", [](GraphExecutorState& s) {
return s.autograd_fallback;
})
.def_property_readonly("autograd_fallback_graph", [](GraphExecutorState& s) {
return s.autograd_fallback_graph;
.def_property_readonly("fallback", [](GraphExecutorState& s) {
return s.fallback;
});

py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
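The binding changes above swap the exposed members but keep the same pybind11 idioms: read-only properties over a state struct and py::make_iterator over a C++ range. Below is a small self-contained module in the same style; the module and type names are made up for illustration and are not part of the PyTorch bindings.

#include <pybind11/pybind11.h>
#include <vector>

namespace py = pybind11;

struct ToyState {
  int graph_id = 7;
  std::vector<int> plans{1, 2, 3};
};

PYBIND11_MODULE(toy_debug, m) {
  py::class_<ToyState>(m, "ToyState")
      .def(py::init<>())
      .def_property_readonly("graph_id", [](ToyState& s) { return s.graph_id; })
      // Expose the list as an iterator, as the Code bindings do for
      // grad_executors(); keep_alive ties the iterator to the ToyState object.
      .def("plans", [](ToyState& s) {
        return py::make_iterator(s.plans.begin(), s.plans.end());
      }, py::keep_alive<0, 1>());
}

From Python this would be consumed as list(toy_debug.ToyState().plans()).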
@@ -24,7 +24,7 @@ namespace torch { namespace jit {
_(prim, Eval) \
_(prim, Expand) /* onnx */ \
_(prim, FusionGroup) \
_(prim, GraphExecutor) \
_(prim, DifferentiableGraph) \
_(prim, If) \
_(prim, Jump) /* debug */ \
_(prim, JumpNZ) /* debug */ \
@@ -87,6 +87,12 @@ namespace torch { namespace jit {
_(onnx, Not) \
FORALL_ATTR_BASE_SYMBOLS(_) \
_(attr, Subgraph) \
_(attr, ReverseSubgraph) \
_(attr, f_real_outputs) \
_(attr, df_input_vjps) \
_(attr, df_input_captured_inputs) \
_(attr, df_input_captured_outputs) \
_(attr, df_output_vjps) \
_(attr, axes) \
_(attr, axis) \
_(attr, broadcast) \
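The _(attr, Subgraph) entries above extend an X-macro list: one macro enumerates every interned symbol, and each expansion site supplies its own _ to generate enums, strings, or registrations from the single list. A tiny standalone example of the technique, with unrelated toy names purely for illustration:

#include <iostream>

// Single source of truth: the list of symbols.
#define FORALL_TOY_SYMBOLS(_) \
  _(Subgraph)                 \
  _(ReverseSubgraph)          \
  _(f_real_outputs)

// Expansion 1: an enum with one entry per symbol.
enum class ToySymbol {
#define DEFINE_ENUM(name) name,
  FORALL_TOY_SYMBOLS(DEFINE_ENUM)
#undef DEFINE_ENUM
};

// Expansion 2: a name table generated from the same list.
const char* toString(ToySymbol s) {
  switch (s) {
#define DEFINE_CASE(name) case ToySymbol::name: return #name;
    FORALL_TOY_SYMBOLS(DEFINE_CASE)
#undef DEFINE_CASE
  }
  return "<unknown>";
}

int main() {
  std::cout << toString(ToySymbol::ReverseSubgraph) << "\n";  // prints ReverseSubgraph
}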
@@ -515,7 +515,7 @@ struct CodeImpl {

size_t insertInstruction(Node * n) {
auto inst = insertInstruction(n->kind(), n->getSourceLocation(), n->inputs(), moveFlags(n) , n->outputs());
instructions[inst].callback = getInterpreterOperation(n);
instructions[inst].callback = getOperation(n);
return inst;
}
size_t insertInstruction(Symbol sym,
@@ -603,27 +603,16 @@ struct CodeImpl {
return r;
}

// Returns a function implementing functionality of a given node,
// or nullptr if it's a no-op for autograd.
Operation getInterpreterOperation(jit::Node* node) {
if(node->kind() != prim::GraphExecutor) {
return getOperation(node);
const std::vector<GraphExecutor*>& grad_executors() {
if (!grad_executors_) {
grad_executors_.emplace();
for (Instruction & instr : instructions) {
if (auto executor = detail::getGradExecutor(instr.callback)) {
grad_executors_->push_back(executor);
}
// recursive graph executors cannot be Operators because they
// have to register themselves with the interpreter so that
// we can provide useful debugging information

auto executor = std::make_shared<GraphExecutor>(node->g(attr::Subgraph));
graph_executors.emplace_back(executor.get());
return [=](Stack& stack) mutable {
autograd::profiler::RecordFunction record("GraphExecutor");
executor->run(stack);
return 0;
};
}

const std::vector<GraphExecutor*>& executors() {
return graph_executors;
}
return *grad_executors_;
}

void dumpInstruction(std::ostream & out, size_t pc) const {
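grad_executors() above lazily scans the instruction list once, caches the result in an optional, and recovers the concrete DifferentiableGraphOp from the type-erased Operation via target<>(). Both pieces can be shown with standard C++ alone, using std::function and std::optional as stand-ins; the names below are illustrative only.

#include <functional>
#include <iostream>
#include <optional>
#include <vector>

using Stack = std::vector<int>;
using Operation = std::function<int(Stack&)>;  // same shape as the JIT's Operation

struct DiffGraphOp {
  int executor_id;
  int operator()(Stack&) const { return 0; }
};

// Recovers the concrete functor if (and only if) the std::function holds one.
const DiffGraphOp* getDiffOp(const Operation& op) {
  return op.target<DiffGraphOp>();
}

struct Code {
  std::vector<Operation> instructions{DiffGraphOp{1}, [](Stack&) { return 0; }, DiffGraphOp{2}};

  // Lazily built list of executor ids, computed on first use and then cached.
  const std::vector<int>& gradExecutorIds() {
    if (!ids_) {
      ids_.emplace();
      for (const Operation& op : instructions)
        if (const DiffGraphOp* d = getDiffOp(op))
          ids_->push_back(d->executor_id);
    }
    return *ids_;
  }

  std::optional<std::vector<int>> ids_;
};

int main() {
  Code code;
  for (int id : code.gradExecutorIds())
    std::cout << id << " ";  // prints "1 2"
}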
@@ -664,7 +653,7 @@ struct CodeImpl {
// It is also very useful for debugging interpreter problems to
// keep this around.
std::shared_ptr<Graph> graph;
std::vector<GraphExecutor*> graph_executors; // for debugging
at::optional<std::vector<GraphExecutor*>> grad_executors_;
PreprocessGraph preprocess;

std::unordered_map<size_t, int> unique_to_reg; // map from unique of nodes to register in register table
@@ -771,8 +760,8 @@ Code::Code(std::shared_ptr<Graph>& graph)
: pImpl(new CodeImpl(graph)) {}
Code::~Code() = default;

const std::vector<GraphExecutor*>& Code::executors() {
return pImpl->executors();
const std::vector<GraphExecutor*>& Code::grad_executors() {
return pImpl->grad_executors();
}

InterpreterState::InterpreterState(const Code & code)
@@ -29,8 +29,7 @@ struct TORCH_API Code {
Code(std::shared_ptr<Graph>& graph);
~Code();

// Returns pointers to GraphExecutors created to run GraphExecutor nodes in the given graph.
const std::vector<GraphExecutor*>& executors();
const std::vector<GraphExecutor*>& grad_executors();

explicit operator bool() const {
return pImpl != nullptr;
@@ -186,7 +186,7 @@ std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::v
IR_ELSE()
if(n->hasAttribute(attr::Subgraph) && groups) {
out << n->kind().toQualString() << "_" << groups->size();
if (n->numAttributes() > 1) {
if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) {
printAttributes(out, n, /*ignore_subgraph=*/true);
}
groups->push_back(n);
@@ -24,6 +24,9 @@ std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
r_outputs.at(i)->setStage(outputs.at(i)->stage());
rn_env[outputs.at(i)] = r_outputs.at(i);
}
if (node->hasAttribute(attr::Subgraph)) {
r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph)));
}
}
for (auto* output : graph->outputs()) {
r->registerOutput(rn_fn(output));
@@ -17,7 +17,7 @@ namespace {
// right before nodes[0] (i.e. it will not create cycles and all uses of
// new node will be after this position).
// prereq: nodes are in topological order
void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
Node* mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
JIT_ASSERT(nodes.size() > 0);
std::unordered_map<Value*, Value*> value_map;
Graph * graph = block->owningGraph();
@@ -66,11 +66,12 @@ void mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
nodes[i - 1]->destroy();
}
JIT_ASSERT(isDifferentiable(*new_graph));
return group_node;
}

}

void CreateAutodiffSubgraphs(Block * block, size_t threshold) {
void CreateAutodiffSubgraphs(Block * block, size_t threshold, std::vector<Node*>& diff_graphs) {
// This implementation is not optimal, but it is simple.
// It just scans through the list in order looking for runs of
// differentiable ops, and then grouping them together when
@@ -93,21 +94,23 @@ void CreateAutodiffSubgraphs(Block * block, size_t threshold) {
groupable.push_back(node);
} else {
if(groupable.size() >= threshold) {
mergeNodes(block, prim::GraphExecutor, groupable);
diff_graphs.push_back(mergeNodes(block, prim::DifferentiableGraph, groupable));
}
groupable.clear();
for (Block * sub_block : node->blocks()) {
CreateAutodiffSubgraphs(sub_block, threshold);
CreateAutodiffSubgraphs(sub_block, threshold, diff_graphs);
}
}
}
if(groupable.size() >= threshold) {
mergeNodes(block, prim::GraphExecutor, groupable);
diff_graphs.push_back(mergeNodes(block, prim::DifferentiableGraph, groupable));
}
}

void CreateAutodiffSubgraphs(Graph & graph, size_t threshold) {
CreateAutodiffSubgraphs(graph.block(), threshold);
std::vector<Node*> CreateAutodiffSubgraphs(Graph & graph, size_t threshold) {
std::vector<Node*> diff_nodes;
CreateAutodiffSubgraphs(graph.block(), threshold, diff_nodes);
return diff_nodes;
}
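CreateAutodiffSubgraphs above scans each block for maximal runs of differentiable nodes and only merges a run when it reaches the size threshold, now also collecting the merged nodes so the caller can differentiate them. A standalone sketch of that scan-and-group logic over a flat list (toy types; the real pass also recurses into sub-blocks and builds actual subgraph nodes):

#include <iostream>
#include <vector>

struct ToyNode {
  int id;
  bool differentiable;
};

// Groups maximal runs of differentiable nodes whose length reaches `threshold`
// and returns one vector of ids per merged group (analogue of the returned
// DifferentiableGraph nodes).
std::vector<std::vector<int>> groupDifferentiableRuns(const std::vector<ToyNode>& nodes,
                                                      size_t threshold) {
  std::vector<std::vector<int>> groups;
  std::vector<int> run;
  auto flush = [&]() {
    if (run.size() >= threshold)
      groups.push_back(run);
    run.clear();
  };
  for (const ToyNode& n : nodes) {
    if (n.differentiable)
      run.push_back(n.id);
    else
      flush();
  }
  flush();  // a trailing run at the end of the block also counts
  return groups;
}

int main() {
  std::vector<ToyNode> nodes = {{0, true}, {1, true}, {2, false}, {3, true}, {4, true}, {5, true}};
  auto groups = groupDifferentiableRuns(nodes, 2);
  std::cout << groups.size() << " groups\n";  // prints "2 groups"
}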
@@ -8,6 +8,7 @@ struct Graph;
// insert GraphExecutor nodes that group together
// subgraphs that are differentiable by the jit's autodiff passes
// threshold - minimum number of nodes that will appear in a block
TORCH_API void CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2);
// returns all differentiable blocks that have been found
TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2);

}}