mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Static Runtime] Do not replace with copy variants if TE fuser is enabled (#72946)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72946
The passes to replace with copy variants are run after TensorExpr fusion. Due to this the resulting graph does not conform to the assumptions made in the fuser.
So, even if these flags `use_copy_variants`, `use_maybe_copy_variants` are turned on, the corresponding passes will not be executed if TensorExpr fusion is enabled.
ghstack-source-id: 149429753
Test Plan: Tested locally.
Reviewed By: mikeiovine
Differential Revision: D34283842
fbshipit-source-id: 74edea517a00c85dff0319f9c8b3ac8befe09018
(cherry picked from commit 3798af7f1b)
This commit is contained in:
parent
02afdd54b9
commit
2724e4c039
|
|
@ -157,10 +157,10 @@ void OptimizeGraph(
|
|||
// TODO: we can avoid this guard by moving operations
|
||||
// to exposed folders.
|
||||
#ifdef FBCODE_CAFFE2
|
||||
if (opts.use_copy_variants) {
|
||||
if (opts.use_copy_variants && !opts.enable_tensorexpr_fusion) {
|
||||
ReplaceWithCopy(graph);
|
||||
}
|
||||
if (opts.use_maybe_copy_variants) {
|
||||
if (opts.use_maybe_copy_variants && !opts.enable_tensorexpr_fusion) {
|
||||
ReplaceWithMaybeCopy(graph);
|
||||
}
|
||||
FuseListUnpack(graph);
|
||||
|
|
|
|||
|
|
@ -166,11 +166,18 @@ struct TORCH_API StaticModuleOptions {
|
|||
bool manage_output_tensors{false};
|
||||
// Gates the ReplaceWithCopy pass, which replaces ops that
|
||||
// sometimes alias their outputs with out variants that
|
||||
// always copy (so the output may participate in memory planning)
|
||||
// always copy (so the output may participate in memory planning).
|
||||
// Since replacing with copies is done after TensorExpr fusion, the
|
||||
// resulting graph does not conform to the assumptions made in the fuser.
|
||||
// So, even if this flag is turned on, the ReplaceWithCopy pass will not
|
||||
// be executed if TensorExpr fusion is enabled.
|
||||
bool use_copy_variants{true};
|
||||
// Gates the ReplaceWithMaybeCopy pass, which replaces ops that
|
||||
// sometimes alias their outputs with subgraphs that include an out
|
||||
// variant.
|
||||
// For the same reason as `use_copy_variants`, the ReplaceWithMaybeCopy pass
|
||||
// will not be executed if TensorExpr fusion is enabled, even if this flag
|
||||
// is turned on.
|
||||
bool use_maybe_copy_variants{true};
|
||||
// enable TensorExpr fusion of ops at model loading time
|
||||
bool enable_tensorexpr_fusion{false};
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user