mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Added remove_noop_ops to joint_graph_passes (#124451)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124451 Approved by: https://github.com/ezyang, https://github.com/fmassa
This commit is contained in:
parent
c446851829
commit
ffc202a1b9
|
|
@ -1048,10 +1048,10 @@ class GraphModule(torch.nn.Module):
|
|||
joint_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
|
||||
def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
|
||||
fw_graph = self.fw_graph
|
||||
joint_graph = self.joint_graph
|
||||
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None
|
||||
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = None
|
||||
getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
|
||||
getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
|
||||
getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None
|
||||
|
|
|
|||
|
|
@ -883,6 +883,9 @@ def _get_sfdp_patterns():
|
|||
"pass_dicts": patterns,
|
||||
"extra_check": extra_check,
|
||||
"scalar_workaround": workaround,
|
||||
# with dropout turned into clone, we end up with a number of
|
||||
# semantically identical graphs
|
||||
"skip_duplicates": True,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -313,6 +313,10 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
|
|||
config.joint_custom_pre_pass(graph.graph)
|
||||
count += 1
|
||||
|
||||
from .post_grad import remove_noop_ops
|
||||
|
||||
remove_noop_ops(graph.graph)
|
||||
|
||||
if config.joint_graph_constant_folding:
|
||||
constant_fold_uniform_value(graph)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,23 +42,19 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
|
||||
|
|
@ -123,11 +119,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
|
|
@ -56,18 +56,14 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
|
||||
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2)
|
||||
|
|
@ -137,7 +133,7 @@ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ign
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
|
|
@ -147,17 +143,13 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
|
|
@ -55,16 +55,12 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
|
||||
|
|
@ -144,11 +140,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -59,11 +59,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -71,8 +67,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor'))
|
||||
|
|
@ -116,13 +111,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
|
@ -158,11 +152,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
@ -171,8 +161,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
|
|
@ -220,12 +209,11 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
|
|
|||
|
|
@ -38,22 +38,17 @@ amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor)
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2)
|
||||
permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
|
||||
|
|
@ -78,8 +73,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default = CallFunction(aten.clone.default, div_Tensor)
|
||||
_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'), _users=0)
|
||||
_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0)
|
||||
|
||||
|
||||
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
||||
|
|
@ -96,19 +90,14 @@ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default
|
|||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
|
|
@ -137,5 +126,4 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'), _users=0)
|
||||
_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0)
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
|
|
@ -56,16 +56,12 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
|
||||
|
|
@ -148,11 +144,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
|
|
@ -60,16 +60,12 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _us
|
|||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default)
|
||||
|
|
@ -161,11 +157,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _us
|
|||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -60,11 +60,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -72,8 +68,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
|
||||
|
|
@ -119,13 +114,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
|
@ -147,7 +141,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -157,11 +151,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -169,8 +159,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
|
||||
|
|
@ -214,8 +203,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
|
|
@ -256,11 +244,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
@ -269,8 +253,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
|
|
@ -320,13 +303,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
|
@ -360,11 +342,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
@ -373,8 +351,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
|
|
@ -422,8 +399,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
|
|
@ -451,7 +427,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -463,11 +439,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -476,8 +448,7 @@ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default
|
|||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
|
|
@ -524,14 +495,13 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
|
@ -553,7 +523,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -564,11 +534,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -577,8 +543,7 @@ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default
|
|||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
|
|
@ -623,8 +588,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -64,11 +64,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _us
|
|||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -76,8 +72,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default)
|
||||
|
|
@ -128,13 +123,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
||||
view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
||||
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
|
||||
_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
|
@ -175,11 +169,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _us
|
|||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
@ -188,8 +178,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
|
|
@ -244,12 +233,11 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
||||
view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
||||
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
|
||||
_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -62,11 +62,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -74,8 +70,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
||||
|
|
@ -126,13 +121,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
_sfdp_pattern_18_inference = MultiOutputPattern([view_default_5,
|
||||
|
|
@ -160,7 +154,7 @@ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -170,11 +164,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -182,8 +172,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
||||
|
|
@ -232,8 +221,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
|
|
@ -280,11 +268,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
@ -293,8 +277,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
|
|
@ -349,13 +332,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
_sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5,
|
||||
|
|
@ -395,11 +377,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
@ -408,8 +386,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
|
|
@ -462,8 +439,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -57,11 +57,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
||||
|
|
@ -69,8 +65,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
||||
|
|
@ -114,8 +109,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
|
|
@ -141,7 +135,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
|
|
@ -151,11 +145,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
||||
|
|
@ -163,9 +153,8 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
|
|
@ -211,8 +200,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, convert_element_type_default)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
|
|
|
|||
|
|
@ -42,23 +42,19 @@ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor'))
|
||||
|
|
@ -123,11 +119,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -53,11 +53,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
||||
|
|
@ -65,8 +61,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor'))
|
||||
|
|
@ -103,8 +98,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
|
|
@ -137,11 +131,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
@ -150,8 +140,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
|
|
@ -192,8 +181,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
|
||||
|
|
@ -53,11 +53,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
||||
|
|
@ -65,8 +61,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_5 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
||||
mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5)
|
||||
mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor'))
|
||||
|
|
@ -103,8 +98,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default = CallFunction(aten.clone.default, div_Tensor)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
|
|
@ -137,11 +131,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
@ -150,8 +140,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored())
|
||||
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5)
|
||||
|
|
@ -192,8 +181,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
|
|
|
|||
|
|
@ -43,23 +43,19 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
|
||||
|
|
@ -126,11 +122,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -54,11 +54,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
||||
|
|
@ -66,8 +62,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
|
||||
|
|
@ -105,8 +100,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
|
|
@ -140,11 +134,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
|
|||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
|
|
@ -153,8 +143,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
||||
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
|
|
@ -196,8 +185,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
|||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
|
||||
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
||||
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -60,11 +60,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -74,8 +70,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
|
||||
|
|
@ -118,14 +113,13 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
|
@ -149,7 +143,7 @@ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ign
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -161,11 +155,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -174,8 +164,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
|
|
@ -220,13 +209,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_de
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
|
|
@ -56,18 +56,14 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
|
||||
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
|
||||
|
|
@ -137,7 +133,7 @@ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ign
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
||||
|
|
@ -147,17 +143,13 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
||||
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
||||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -60,11 +60,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -74,8 +70,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2)
|
||||
|
|
@ -118,14 +113,13 @@ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
|
||||
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
||||
|
|
@ -149,7 +143,7 @@ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ign
|
|||
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
||||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
||||
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
||||
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
||||
|
|
@ -161,11 +155,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
|
|||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
||||
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
|
||||
alias_default_1 = CallFunction(aten.alias.default, alias_default)
|
||||
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
|
||||
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
|
||||
neg_default = CallFunction(aten.neg.default, alias_default_3)
|
||||
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
||||
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
||||
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
||||
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
||||
|
|
@ -174,8 +164,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default
|
|||
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
||||
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
||||
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
|
||||
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
|
||||
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
||||
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
||||
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
||||
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
||||
|
|
@ -220,13 +209,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_de
|
|||
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
||||
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
||||
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
||||
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
|
||||
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
||||
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
||||
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
||||
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
||||
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
||||
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
|
||||
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
||||
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
||||
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
||||
_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
||||
|
|
|
|||
|
|
@ -906,6 +906,7 @@ class PatternPrettyPrinter:
|
|||
self.memoized_objs_pp: Dict[PatternExpr, str] = {}
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def run(obj: PatternExpr, output_name: str = "output") -> str:
|
||||
"""
|
||||
Serializes obj to python code with obj written out to `output_name`
|
||||
|
|
@ -1463,6 +1464,7 @@ def gen_register_replacement(
|
|||
extra_check: Callable[[Match], bool] = _return_true,
|
||||
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
|
||||
exclusive_arg_names: Sequence[str] = (),
|
||||
skip_duplicates: bool = False,
|
||||
) -> None:
|
||||
# Make sure the example_inputs is materialized.
|
||||
example_inputs = tuple(example_inputs)
|
||||
|
|
@ -1491,6 +1493,8 @@ def gen_register_replacement(
|
|||
# Since this is just an optimization we can clear it out.
|
||||
arg.constant = None
|
||||
|
||||
if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates:
|
||||
return
|
||||
_known_precompiled_patterns.append(
|
||||
(search_fn, example_inputs, trace_fn, scalar_workaround, pat)
|
||||
)
|
||||
|
|
@ -1790,6 +1794,11 @@ def fwd_only(
|
|||
# TODO - look into using aot autograd, asserting no mutating ops here
|
||||
with enable_python_dispatcher():
|
||||
gm = make_fx(fn, select_decomp_table(), tracing_mode="real")(*args)
|
||||
|
||||
from .fx_passes.post_grad import remove_noop_ops
|
||||
|
||||
remove_noop_ops(gm.graph)
|
||||
|
||||
if run_dce:
|
||||
gm.graph.eliminate_dead_code()
|
||||
gm.recompile()
|
||||
|
|
@ -1820,6 +1829,10 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph
|
|||
)(*args)
|
||||
assert gm
|
||||
|
||||
from .fx_passes.post_grad import remove_noop_ops
|
||||
|
||||
remove_noop_ops(gm.graph)
|
||||
|
||||
from .fx_passes.joint_graph import pointless_view
|
||||
|
||||
matcher_pass = PatternMatcherPass()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user