Added remove_noop_ops to joint_graph_passes (#124451)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124451
Approved by: https://github.com/ezyang, https://github.com/fmassa
This commit is contained in:
chilli 2024-06-07 16:08:14 -07:00 committed by PyTorch MergeBot
parent c446851829
commit ffc202a1b9
23 changed files with 187 additions and 399 deletions

View File

@ -1048,10 +1048,10 @@ class GraphModule(torch.nn.Module):
joint_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
fw_graph = self.fw_graph
joint_graph = self.joint_graph
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = None
getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None

View File

@ -883,6 +883,9 @@ def _get_sfdp_patterns():
"pass_dicts": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
# with dropout turned into clone, we end up with a number of
# semantically identical graphs
"skip_duplicates": True,
}

View File

@ -313,6 +313,10 @@ def joint_graph_passes(graph: torch.fx.GraphModule):
config.joint_custom_pre_pass(graph.graph)
count += 1
from .post_grad import remove_noop_ops
remove_noop_ops(graph.graph)
if config.joint_graph_constant_folding:
constant_fold_uniform_value(graph)

View File

@ -42,23 +42,19 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
@ -123,11 +119,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())

View File

@ -46,7 +46,7 @@ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
@ -56,18 +56,14 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2)
@ -137,7 +133,7 @@ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ign
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
@ -147,17 +143,13 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())

View File

@ -46,7 +46,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
@ -55,16 +55,12 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
@ -144,11 +140,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())

View File

@ -48,7 +48,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
@ -59,11 +59,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -71,8 +67,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor'))
@ -116,13 +111,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
@ -158,11 +152,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
@ -171,8 +161,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
@ -220,12 +209,11 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)

View File

@ -38,22 +38,17 @@ amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
alias_default = CallFunction(aten.alias.default, div_Tensor)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
@ -78,8 +73,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default = CallFunction(aten.clone.default, div_Tensor)
_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'), _users=0)
_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
@ -96,19 +90,14 @@ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
@ -137,5 +126,4 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'), _users=0)
_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0)

View File

@ -47,7 +47,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
@ -56,16 +56,12 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
@ -148,11 +144,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())

View File

@ -50,7 +50,7 @@ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
@ -60,16 +60,12 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _us
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default)
@ -161,11 +157,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _us
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())

View File

@ -49,7 +49,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
@ -60,11 +60,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -72,8 +68,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
@ -119,13 +114,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
@ -147,7 +141,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
@ -157,11 +151,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -169,8 +159,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
@ -214,8 +203,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
@ -256,11 +244,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
@ -269,8 +253,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
@ -320,13 +303,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
@ -360,11 +342,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
@ -373,8 +351,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
@ -422,8 +399,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
@ -451,7 +427,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
@ -463,11 +439,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -476,8 +448,7 @@ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
@ -524,14 +495,13 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
@ -553,7 +523,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
@ -564,11 +534,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -577,8 +543,7 @@ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
@ -623,8 +588,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())

View File

@ -52,7 +52,7 @@ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
@ -64,11 +64,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _us
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -76,8 +72,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default)
@ -128,13 +123,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
@ -175,11 +169,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _us
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
@ -188,8 +178,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
@ -244,12 +233,11 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)

View File

@ -51,7 +51,7 @@ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
@ -62,11 +62,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -74,8 +70,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
@ -126,13 +121,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
_sfdp_pattern_18_inference = MultiOutputPattern([view_default_5,
@ -160,7 +154,7 @@ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
@ -170,11 +164,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -182,8 +172,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
@ -232,8 +221,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
@ -280,11 +268,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
@ -293,8 +277,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
@ -349,13 +332,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
_sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5,
@ -395,11 +377,7 @@ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored(
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
@ -408,8 +386,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
@ -462,8 +439,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())

View File

@ -48,7 +48,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
@ -57,11 +57,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
@ -69,8 +65,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
@ -114,8 +109,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
@ -141,7 +135,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
@ -151,11 +145,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
@ -163,9 +153,8 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
@ -211,8 +200,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default = CallFunction(aten.clone.default, convert_element_type_default)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())

View File

@ -42,23 +42,19 @@ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor'))
@ -123,11 +119,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())

View File

@ -44,7 +44,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
@ -53,11 +53,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
@ -65,8 +61,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor'))
@ -103,8 +98,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
@ -137,11 +131,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
@ -150,8 +140,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
@ -192,8 +181,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())

View File

@ -44,7 +44,7 @@ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
@ -53,11 +53,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
@ -65,8 +61,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5)
mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor'))
@ -103,8 +98,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default = CallFunction(aten.clone.default, div_Tensor)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
@ -137,11 +131,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
@ -150,8 +140,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored())
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5)
@ -192,8 +181,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())

View File

@ -43,23 +43,19 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
@ -126,11 +122,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())

View File

@ -45,7 +45,7 @@ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
@ -54,11 +54,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
@ -66,8 +62,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
@ -105,8 +100,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default = CallFunction(aten.clone.default, div_Tensor_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
@ -140,11 +134,7 @@ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignore
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, convert_element_type_default_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
@ -153,8 +143,7 @@ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
@ -196,8 +185,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
clone_default = CallFunction(aten.clone.default, convert_element_type_default_1)
expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())

View File

@ -48,7 +48,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
@ -60,11 +60,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -74,8 +70,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
@ -118,14 +113,13 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
@ -149,7 +143,7 @@ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ign
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
@ -161,11 +155,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -174,8 +164,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
@ -220,13 +209,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_de
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)

View File

@ -46,7 +46,7 @@ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
@ -56,18 +56,14 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
@ -137,7 +133,7 @@ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ign
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
@ -147,17 +143,13 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())

View File

@ -48,7 +48,7 @@ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
@ -60,11 +60,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -74,8 +70,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2)
@ -118,14 +113,13 @@ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
@ -149,7 +143,7 @@ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ign
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
@ -161,11 +155,7 @@ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_form
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
alias_default = CallFunction(aten.alias.default, div_Tensor_1)
alias_default_1 = CallFunction(aten.alias.default, alias_default)
alias_default_2 = CallFunction(aten.alias.default, alias_default_1)
alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2)
neg_default = CallFunction(aten.neg.default, alias_default_3)
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
@ -174,8 +164,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
@ -220,13 +209,12 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_de
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)

View File

@ -906,6 +906,7 @@ class PatternPrettyPrinter:
self.memoized_objs_pp: Dict[PatternExpr, str] = {}
@staticmethod
@functools.lru_cache(None)
def run(obj: PatternExpr, output_name: str = "output") -> str:
"""
Serializes obj to python code with obj written out to `output_name`
@ -1463,6 +1464,7 @@ def gen_register_replacement(
extra_check: Callable[[Match], bool] = _return_true,
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
exclusive_arg_names: Sequence[str] = (),
skip_duplicates: bool = False,
) -> None:
# Make sure the example_inputs is materialized.
example_inputs = tuple(example_inputs)
@ -1491,6 +1493,8 @@ def gen_register_replacement(
# Since this is just an optimization we can clear it out.
arg.constant = None
if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates:
return
_known_precompiled_patterns.append(
(search_fn, example_inputs, trace_fn, scalar_workaround, pat)
)
@ -1790,6 +1794,11 @@ def fwd_only(
# TODO - look into using aot autograd, asserting no mutating ops here
with enable_python_dispatcher():
gm = make_fx(fn, select_decomp_table(), tracing_mode="real")(*args)
from .fx_passes.post_grad import remove_noop_ops
remove_noop_ops(gm.graph)
if run_dce:
gm.graph.eliminate_dead_code()
gm.recompile()
@ -1820,6 +1829,10 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph
)(*args)
assert gm
from .fx_passes.post_grad import remove_noop_ops
remove_noop_ops(gm.graph)
from .fx_passes.joint_graph import pointless_view
matcher_pass = PatternMatcherPass()