Re-enable SGD (#117434)

Re-enables the SGD optimizer now that compile times are more reasonable. [Benchmark run](https://github.com/pytorch/pytorch/actions/runs/7511073761)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117434
Approved by: https://github.com/anijain2305, https://github.com/janeyx99
This commit is contained in:
Michael Lazos 2024-01-19 04:28:42 +00:00 committed by PyTorch MergeBot
parent 924ed91612
commit f302a0d380
21 changed files with 402 additions and 373 deletions

View File

@ -1 +1 @@
d6dc1a01a4d8c9d3a4369fbc68a14192c60006f6
2990cb38c17e06d0dbe25437674ca40130d76a8f

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
AlbertForMaskedLM,pass,5
AlbertForMaskedLM,pass,4
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,8
BartForCausalLM,pass,13
BartForCausalLM,pass,12
BartForConditionalGeneration,pass,25
BartForConditionalGeneration,pass,24
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,13
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForConditionalGeneration,pass,25
BlenderbotSmallForConditionalGeneration,pass,24
@ -74,7 +74,7 @@ DistillGPT2,pass,4
ElectraForCausalLM,pass,5
ElectraForCausalLM,pass,4
@ -98,15 +98,15 @@ LayoutLMForSequenceClassification,pass,6
M2M100ForConditionalGeneration,pass,5
M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,13
MBartForCausalLM,pass,12
MBartForConditionalGeneration,pass,25
MBartForConditionalGeneration,pass,24
@ -130,19 +130,19 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,13
OPTForCausalLM,pass,12
PLBartForCausalLM,pass,13
PLBartForCausalLM,pass,12
PLBartForConditionalGeneration,pass,30
PLBartForConditionalGeneration,pass,29
PegasusForCausalLM,pass,13
PegasusForCausalLM,pass,12
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,4
Speech2Text2ForCausalLM,pass,13
Speech2Text2ForCausalLM,pass,12
@ -170,11 +170,11 @@ T5Small,pass,4
TrOCRForCausalLM,pass,13
TrOCRForCausalLM,pass,12
XGLMForCausalLM,pass,13
XGLMForCausalLM,pass,12

1 name accuracy graph_breaks
2 AlbertForMaskedLM pass 5 4
3 AlbertForQuestionAnswering pass 4
4 AllenaiLongformerBase pass 8
5 BartForCausalLM pass 13 12
6 BartForConditionalGeneration pass 25 24
7 BertForMaskedLM pass 4
8 BertForQuestionAnswering pass 4
14 DebertaForQuestionAnswering pass 4
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 4
18 DistilBertForQuestionAnswering pass 4
19 DistillGPT2 pass 4
20 ElectraForCausalLM pass 5 4
21 ElectraForQuestionAnswering pass 4
22 GPT2ForSequenceClassification pass 6
23 GoogleFnet pass 4
24 LayoutLMForMaskedLM pass 4
34 OPTForCausalLM pass 13 12
35 PLBartForCausalLM pass 13 12
36 PLBartForConditionalGeneration pass 30 29
37 PegasusForCausalLM pass 13 12
38 PegasusForConditionalGeneration pass 23
39 RobertaForCausalLM pass 4
40 RobertaForQuestionAnswering pass 4
41 Speech2Text2ForCausalLM pass 13 12
42 T5ForConditionalGeneration pass 4
43 T5Small pass 4
44 TrOCRForCausalLM pass 13 12
74
75
76
77
78
79
80
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
adv_inception_v3,pass,7
adv_inception_v3,pass,6
@ -10,7 +10,7 @@ beit_base_patch16_224,pass,6
botnet26t_256,pass,7
botnet26t_256,pass,6
@ -18,11 +18,11 @@ cait_m36_384,eager_fail_to_run,0
coat_lite_mini,pass,7
coat_lite_mini,pass,6
convit_base,pass,7
convit_base,pass,6
@ -50,11 +50,11 @@ dla102,pass,6
dm_nfnet_f0,pass,7
dm_nfnet_f0,pass,6
dpn107,pass,7
dpn107,pass,6
@ -74,15 +74,15 @@ fbnetc_100,pass,6
fbnetv3_b,pass,7
fbnetv3_b,pass,6
gernet_l,pass,7
gernet_l,pass,6
ghostnet_100,pass,7
ghostnet_100,pass,6
@ -90,7 +90,7 @@ gluon_inception_v3,pass,6
gmixer_24_224,pass,7
gmixer_24_224,pass,6
@ -102,7 +102,7 @@ hrnet_w18,pass,5
inception_v3,pass,7
inception_v3,pass,6
@ -110,7 +110,7 @@ jx_nest_base,pass,6
lcnet_050,fail_accuracy,7
lcnet_050,fail_accuracy,6
@ -122,7 +122,7 @@ mixer_b16_224,pass,6
mixnet_l,pass,7
mixnet_l,pass,6
@ -138,7 +138,7 @@ mobilenetv3_large_100,pass,6
mobilevit_s,pass,7
mobilevit_s,pass,6
@ -146,7 +146,7 @@ nfnet_l0,pass,6
pit_b_224,pass,7
pit_b_224,pass,6
@ -154,11 +154,11 @@ pnasnet5large,pass,5
poolformer_m36,pass,7
poolformer_m36,pass,6
regnety_002,pass,7
regnety_002,pass,6
@ -166,23 +166,23 @@ repvgg_a2,pass,6
res2net101_26w_4s,pass,7
res2net101_26w_4s,pass,6
res2net50_14w_8s,pass,7
res2net50_14w_8s,pass,6
res2next50,pass,7
res2next50,pass,6
resmlp_12_224,pass,7
resmlp_12_224,pass,6
resnest101e,pass,7
resnest101e,pass,6
@ -190,11 +190,11 @@ rexnet_100,pass,6
sebotnet33ts_256,pass,7
sebotnet33ts_256,pass,6
selecsls42b,pass,7
selecsls42b,pass,6
@ -206,19 +206,19 @@ swin_base_patch4_window7_224,pass,6
swsl_resnext101_32x16d,pass,7
swsl_resnext101_32x16d,pass,6
tf_efficientnet_b0,pass,7
tf_efficientnet_b0,pass,6
tf_mixnet_l,pass,7
tf_mixnet_l,pass,6
tinynet_a,pass,7
tinynet_a,pass,6

1 name accuracy graph_breaks
2 adv_inception_v3 pass 7 6
3 beit_base_patch16_224 pass 6
4 botnet26t_256 pass 7 6
5 cait_m36_384 eager_fail_to_run 0
6 coat_lite_mini pass 7 6
7 convit_base pass 7 6
8 convmixer_768_32 pass 5
10 crossvit_9_240 pass 6
11 cspdarknet53 pass 6
12 deit_base_distilled_patch16_224 pass 6
13 dla102 pass 6
14 dm_nfnet_f0 pass 7 6
15 dpn107 pass 7 6
16 eca_botnext26ts_256 pass 6
18 ese_vovnet19b_dw pass 6
19 fbnetc_100 pass 6
20 fbnetv3_b pass 7 6
21 gernet_l pass 7 6
22 ghostnet_100 pass 7 6
23 gluon_inception_v3 pass 6
24 gmixer_24_224 pass 7 6
25 gmlp_s16_224 pass 6
26 hrnet_w18 pass 5
27 inception_v3 pass 7 6
28 jx_nest_base pass 6
50 selecsls42b pass 7 6
51 spnasnet_100 pass 6
52 swin_base_patch4_window7_224 pass 6
53 swsl_resnext101_32x16d pass 7 6
54 tf_efficientnet_b0 pass 7 6
55 tf_mixnet_l pass 7 6
56 tinynet_a pass 7 6
57 tnt_s_patch16_224 pass 6
58 twins_pcpvt_base pass 6
59 visformer_small pass 6
60 vit_base_patch16_224 pass 6
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
90
91
92
93
94
95
96
102
103
104
105
106
107
108
110
111
112
113
114
115
116
122
123
124
125
126
127
128
138
139
140
141
142
143
144
146
147
148
149
150
151
152
154
155
156
157
158
159
160
161
162
163
164
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
190
191
192
193
194
195
196
197
198
199
200
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224

View File

@ -2,11 +2,11 @@ name,accuracy,graph_breaks
torchrec_dlrm,pass,7
torchrec_dlrm,pass,6
BERT_pytorch,pass,7
BERT_pytorch,pass,6
@ -18,7 +18,7 @@ DALLE2_pytorch,eager_fail_to_run,0
LearningToPaint,pass,7
LearningToPaint,pass,6
@ -26,7 +26,7 @@ Super_SloMo,pass,6
alexnet,pass,7
alexnet,pass,6
@ -50,15 +50,15 @@ cm3leon_generate,eager_fail_to_run,0
dcgan,pass,7
dcgan,pass,6
demucs,pass,10
demucs,pass,9
densenet121,pass,7
densenet121,pass,6
@ -70,7 +70,7 @@ detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0
dlrm,pass,7
dlrm,pass,6
@ -86,7 +86,7 @@ drq,pass,5
fastNLP_Bert,pass,11
fastNLP_Bert,pass,10
@ -126,7 +126,7 @@ hf_Reformer,pass,25
hf_T5_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
@ -162,7 +162,7 @@ mnasnet1_0,pass,6
mobilenet_v2,pass,7
mobilenet_v2,pass,6
@ -174,7 +174,7 @@ mobilenet_v3_large,pass,6
moco,pass,18
moco,pass,17
@ -190,19 +190,19 @@ opacus_cifar10,eager_fail_to_run,0
phlippe_densenet,pass,7
phlippe_densenet,pass,6
phlippe_resnet,pass,7
phlippe_resnet,pass,6
pytorch_CycleGAN_and_pix2pix,pass,7
pytorch_CycleGAN_and_pix2pix,pass,6
pytorch_stargan,pass,7
pytorch_stargan,pass,6
@ -214,11 +214,11 @@ resnet152,pass,6
resnet18,pass,7
resnet18,pass,6
resnet50,pass,7
resnet50,pass,6
@ -234,7 +234,7 @@ sam,eager_fail_to_run,0
shufflenet_v2_x1_0,pass,7
shufflenet_v2_x1_0,pass,6
@ -242,15 +242,15 @@ soft_actor_critic,pass,5
speech_transformer,pass,17
speech_transformer,pass,16
squeezenet1_1,pass,7
squeezenet1_1,pass,6
stable_diffusion_text_encoder,pass,6
stable_diffusion_text_encoder,pass,5
@ -262,7 +262,7 @@ timm_efficientnet,pass,6
timm_regnet,pass,7
timm_regnet,pass,6
@ -270,7 +270,7 @@ timm_resnest,pass,6
timm_vision_transformer,pass,7
timm_vision_transformer,pass,6
@ -278,7 +278,7 @@ timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,7
timm_vovnet,pass,6
@ -290,11 +290,11 @@ tts_angular,pass,8
vgg16,pass,7
vgg16,pass,6
vision_maskrcnn,pass,34
vision_maskrcnn,pass,33

1 name accuracy graph_breaks
2 torchrec_dlrm pass 7 6
3 BERT_pytorch pass 7 6
4 Background_Matting pass_due_to_skip 0
5 DALLE2_pytorch eager_fail_to_run 0
6 LearningToPaint pass 7 6
7 Super_SloMo pass 6
8 alexnet pass 7 6
9 basic_gnn_edgecnn pass 21
10 basic_gnn_gcn pass 12
11 basic_gnn_gin pass 6
12 basic_gnn_sage pass 6
18 detectron2_maskrcnn_r_50_c4 eager_fail_to_run 0
19 dlrm pass 7 6
20 doctr_det_predictor eager_fail_to_run 0
21 doctr_reco_predictor eager_fail_to_run 0
22 drq pass 5
23 fastNLP_Bert pass 11 10
24 functorch_dp_cifar10 pass 6
26 hf_Albert pass 5
27 hf_Bart pass 5
28 hf_BigBird pass 5
29 hf_DistilBert pass 5
30 hf_GPT2 pass 5
31 hf_GPT2_large pass_due_to_skip 0
32 hf_Reformer pass 25
50 phlippe_resnet pass 7 6
51 pytorch_CycleGAN_and_pix2pix pass 7 6
52 pytorch_stargan pass 7 6
53 pytorch_unet pass 6
54 resnet152 pass 6
55 resnet18 pass 7 6
56 resnet50 pass 7 6
57 resnet50_quantized_qat eager_fail_to_run 0
58 resnext50_32x4d pass 6
59 sam eager_fail_to_run 0
60 shufflenet_v2_x1_0 pass 7 6
61 soft_actor_critic pass 5
62 speech_transformer pass 17 16
63 squeezenet1_1 pass 7 6
64 stable_diffusion_text_encoder pass 6 5
70 timm_vision_transformer_large pass_due_to_skip 0
71 timm_vovnet pass 7 6
72 torch_multimodal_clip pass 6
73 tts_angular pass 8
74 vgg16 pass 7 6
75 vision_maskrcnn pass 34 33
76 yolov3 pass 8
86
87
88
89
90
91
92
126
127
128
129
130
131
132
162
163
164
165
166
167
168
174
175
176
177
178
179
180
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
214
215
216
217
218
219
220
221
222
223
224
234
235
236
237
238
239
240
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
262
263
264
265
266
267
268
270
271
272
273
274
275
276
278
279
280
281
282
283
284
290
291
292
293
294
295
296
297
298
299
300

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
AlbertForMaskedLM,pass,5
AlbertForMaskedLM,pass,4
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,8
BartForCausalLM,pass,13
BartForCausalLM,pass,12
BartForConditionalGeneration,pass,25
BartForConditionalGeneration,pass,24
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,13
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForConditionalGeneration,pass,25
BlenderbotSmallForConditionalGeneration,pass,24
@ -74,7 +74,7 @@ DistillGPT2,pass,4
ElectraForCausalLM,pass,5
ElectraForCausalLM,pass,4
@ -98,15 +98,15 @@ LayoutLMForSequenceClassification,pass,6
M2M100ForConditionalGeneration,pass,5
M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,13
MBartForCausalLM,pass,12
MBartForConditionalGeneration,pass,25
MBartForConditionalGeneration,pass,24
@ -130,19 +130,19 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,13
OPTForCausalLM,pass,12
PLBartForCausalLM,pass,13
PLBartForCausalLM,pass,12
PLBartForConditionalGeneration,pass,30
PLBartForConditionalGeneration,pass,29
PegasusForCausalLM,pass,13
PegasusForCausalLM,pass,12
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,4
Speech2Text2ForCausalLM,pass,13
Speech2Text2ForCausalLM,pass,12
@ -170,11 +170,11 @@ T5Small,pass,4
TrOCRForCausalLM,pass,13
TrOCRForCausalLM,pass,12
XGLMForCausalLM,pass,13
XGLMForCausalLM,pass,12

1 name accuracy graph_breaks
2 AlbertForMaskedLM pass 5 4
3 AlbertForQuestionAnswering pass 4
4 AllenaiLongformerBase pass 8
5 BartForCausalLM pass 13 12
6 BartForConditionalGeneration pass 25 24
7 BertForMaskedLM pass 4
8 BertForQuestionAnswering pass 4
14 DebertaForQuestionAnswering pass 4
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 4
18 DistilBertForQuestionAnswering pass 4
19 DistillGPT2 pass 4
20 ElectraForCausalLM pass 5 4
21 ElectraForQuestionAnswering pass 4
22 GPT2ForSequenceClassification pass 6
23 GoogleFnet pass 4
24 LayoutLMForMaskedLM pass 4
34 OPTForCausalLM pass 13 12
35 PLBartForCausalLM pass 13 12
36 PLBartForConditionalGeneration pass 30 29
37 PegasusForCausalLM pass 13 12
38 PegasusForConditionalGeneration pass 23
39 RobertaForCausalLM pass 4
40 RobertaForQuestionAnswering pass 4
41 Speech2Text2ForCausalLM pass 13 12
42 T5ForConditionalGeneration pass 4
43 T5Small pass 4
44 TrOCRForCausalLM pass 13 12
74
75
76
77
78
79
80
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
adv_inception_v3,pass,7
adv_inception_v3,pass,6
@ -10,7 +10,7 @@ beit_base_patch16_224,pass,6
botnet26t_256,pass,7
botnet26t_256,pass,6
@ -18,11 +18,11 @@ cait_m36_384,eager_fail_to_run,0
coat_lite_mini,pass,7
coat_lite_mini,pass,6
convit_base,pass,7
convit_base,pass,6
@ -50,11 +50,11 @@ dla102,pass,6
dm_nfnet_f0,pass,7
dm_nfnet_f0,pass,6
dpn107,pass,7
dpn107,pass,6
@ -74,15 +74,15 @@ fbnetc_100,pass,6
fbnetv3_b,pass,7
fbnetv3_b,pass,6
gernet_l,pass,7
gernet_l,pass,6
ghostnet_100,pass,7
ghostnet_100,pass,6
@ -90,7 +90,7 @@ gluon_inception_v3,pass,6
gmixer_24_224,pass,7
gmixer_24_224,pass,6
@ -102,7 +102,7 @@ hrnet_w18,pass,5
inception_v3,pass,7
inception_v3,pass,6
@ -110,7 +110,7 @@ jx_nest_base,pass,6
lcnet_050,fail_accuracy,7
lcnet_050,fail_accuracy,6
@ -122,7 +122,7 @@ mixer_b16_224,pass,6
mixnet_l,pass,7
mixnet_l,pass,6
@ -138,7 +138,7 @@ mobilenetv3_large_100,pass,6
mobilevit_s,pass,7
mobilevit_s,pass,6
@ -146,7 +146,7 @@ nfnet_l0,pass,6
pit_b_224,pass,7
pit_b_224,pass,6
@ -154,11 +154,11 @@ pnasnet5large,pass,5
poolformer_m36,pass,7
poolformer_m36,pass,6
regnety_002,pass,7
regnety_002,pass,6
@ -166,23 +166,23 @@ repvgg_a2,pass,6
res2net101_26w_4s,pass,7
res2net101_26w_4s,pass,6
res2net50_14w_8s,pass,7
res2net50_14w_8s,pass,6
res2next50,pass,7
res2next50,pass,6
resmlp_12_224,pass,7
resmlp_12_224,pass,6
resnest101e,pass,7
resnest101e,pass,6
@ -190,11 +190,11 @@ rexnet_100,pass,6
sebotnet33ts_256,pass,7
sebotnet33ts_256,pass,6
selecsls42b,pass,7
selecsls42b,pass,6
@ -206,19 +206,19 @@ swin_base_patch4_window7_224,pass,6
swsl_resnext101_32x16d,pass,7
swsl_resnext101_32x16d,pass,6
tf_efficientnet_b0,pass,7
tf_efficientnet_b0,pass,6
tf_mixnet_l,pass,7
tf_mixnet_l,pass,6
tinynet_a,pass,7
tinynet_a,pass,6

1 name accuracy graph_breaks
2 adv_inception_v3 pass 7 6
3 beit_base_patch16_224 pass 6
4 botnet26t_256 pass 7 6
5 cait_m36_384 eager_fail_to_run 0
6 coat_lite_mini pass 7 6
7 convit_base pass 7 6
8 convmixer_768_32 pass 5
10 crossvit_9_240 pass 6
11 cspdarknet53 pass 6
12 deit_base_distilled_patch16_224 pass 6
13 dla102 pass 6
14 dm_nfnet_f0 pass 7 6
15 dpn107 pass 7 6
16 eca_botnext26ts_256 pass 6
18 ese_vovnet19b_dw pass 6
19 fbnetc_100 pass 6
20 fbnetv3_b pass 7 6
21 gernet_l pass 7 6
22 ghostnet_100 pass 7 6
23 gluon_inception_v3 pass 6
24 gmixer_24_224 pass 7 6
25 gmlp_s16_224 pass 6
26 hrnet_w18 pass 5
27 inception_v3 pass 7 6
28 jx_nest_base pass 6
50 selecsls42b pass 7 6
51 spnasnet_100 pass 6
52 swin_base_patch4_window7_224 pass 6
53 swsl_resnext101_32x16d pass 7 6
54 tf_efficientnet_b0 pass 7 6
55 tf_mixnet_l pass 7 6
56 tinynet_a pass 7 6
57 tnt_s_patch16_224 pass 6
58 twins_pcpvt_base pass 6
59 visformer_small pass 6
60 vit_base_patch16_224 pass 6
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
90
91
92
93
94
95
96
102
103
104
105
106
107
108
110
111
112
113
114
115
116
122
123
124
125
126
127
128
138
139
140
141
142
143
144
146
147
148
149
150
151
152
154
155
156
157
158
159
160
161
162
163
164
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
190
191
192
193
194
195
196
197
198
199
200
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
BERT_pytorch,pass,7
BERT_pytorch,pass,6
@ -14,7 +14,7 @@ DALLE2_pytorch,eager_fail_to_run,0
LearningToPaint,pass,7
LearningToPaint,pass,6
@ -22,7 +22,7 @@ Super_SloMo,pass,6
alexnet,pass,7
alexnet,pass,6
@ -46,15 +46,15 @@ cm3leon_generate,eager_fail_to_run,0
dcgan,pass,7
dcgan,pass,6
demucs,pass,10
demucs,pass,9
densenet121,pass,7
densenet121,pass,6
@ -66,7 +66,7 @@ detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0
dlrm,pass,7
dlrm,pass,6
@ -82,7 +82,7 @@ drq,pass,5
fastNLP_Bert,pass,11
fastNLP_Bert,pass,10
@ -122,7 +122,7 @@ hf_Reformer,pass,25
hf_T5_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
@ -158,7 +158,7 @@ mnasnet1_0,pass,6
mobilenet_v2,pass,7
mobilenet_v2,pass,6
@ -170,7 +170,7 @@ mobilenet_v3_large,pass,6
moco,pass,18
moco,pass,17
@ -186,19 +186,19 @@ opacus_cifar10,eager_fail_to_run,0
phlippe_densenet,pass,7
phlippe_densenet,pass,6
phlippe_resnet,pass,7
phlippe_resnet,pass,6
pytorch_CycleGAN_and_pix2pix,pass,7
pytorch_CycleGAN_and_pix2pix,pass,6
pytorch_stargan,pass,7
pytorch_stargan,pass,6
@ -210,11 +210,11 @@ resnet152,pass,6
resnet18,pass,7
resnet18,pass,6
resnet50,pass,7
resnet50,pass,6
@ -230,7 +230,7 @@ sam,eager_fail_to_run,0
shufflenet_v2_x1_0,pass,7
shufflenet_v2_x1_0,pass,6
@ -238,11 +238,11 @@ soft_actor_critic,pass,5
squeezenet1_1,pass,7
squeezenet1_1,pass,6
stable_diffusion_text_encoder,pass,6
stable_diffusion_text_encoder,pass,5
@ -254,7 +254,7 @@ timm_efficientnet,pass,6
timm_regnet,pass,7
timm_regnet,pass,6
@ -262,7 +262,7 @@ timm_resnest,pass,6
timm_vision_transformer,pass,7
timm_vision_transformer,pass,6
@ -270,7 +270,7 @@ timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,7
timm_vovnet,pass,6
@ -282,11 +282,11 @@ tts_angular,pass,8
vgg16,pass,7
vgg16,pass,6
vision_maskrcnn,pass,34
vision_maskrcnn,pass,33

1 name accuracy graph_breaks
2 BERT_pytorch pass 7 6
3 Background_Matting pass_due_to_skip 0
4 DALLE2_pytorch eager_fail_to_run 0
5 LearningToPaint pass 7 6
6 Super_SloMo pass 6
7 alexnet pass 7 6
8 basic_gnn_edgecnn pass 21
14 demucs pass 10 9
15 densenet121 pass 7 6
16 detectron2_fcos_r_50_fpn model_fail_to_load 0
17 detectron2_maskrcnn_r_50_c4 eager_fail_to_run 0
18 dlrm pass 7 6
19 doctr_det_predictor eager_fail_to_run 0
20 doctr_reco_predictor eager_fail_to_run 0
22 fastNLP_Bert pass 11 10
23 functorch_dp_cifar10 pass 6
24 functorch_maml_omniglot pass 6
25 hf_Albert pass 5
26 hf_Bart pass 5
27 hf_BigBird fail_to_run 3
28 hf_DistilBert pass 5
46 nvidia_deeprecommender pass 6
47 opacus_cifar10 eager_fail_to_run 0
48 phlippe_densenet pass 7 6
49 phlippe_resnet pass 7 6
50 pytorch_CycleGAN_and_pix2pix pass 7 6
51 pytorch_stargan pass 7 6
52 pytorch_unet pass 6
53 resnet152 pass 6
54 resnet18 pass 7 6
55 resnet50 pass 7 6
56 resnet50_quantized_qat eager_fail_to_run 0
57 resnext50_32x4d pass 6
58 sam eager_fail_to_run 0
59 shufflenet_v2_x1_0 pass 7 6
60 soft_actor_critic pass 5
66 timm_resnest pass 6
67 timm_vision_transformer pass 7 6
68 timm_vision_transformer_large pass_due_to_skip 0
69 timm_vovnet pass 7 6
70 torch_multimodal_clip pass 6
71 tts_angular pass 8
72 vgg16 pass 7 6
82
83
84
85
86
87
88
122
123
124
125
126
127
128
158
159
160
161
162
163
164
170
171
172
173
174
175
176
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
210
211
212
213
214
215
216
217
218
219
220
230
231
232
233
234
235
236
238
239
240
241
242
243
244
245
246
247
248
254
255
256
257
258
259
260
262
263
264
265
266
267
268
270
271
272
273
274
275
276
282
283
284
285
286
287
288
289
290
291
292

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
AlbertForMaskedLM,pass,5
AlbertForMaskedLM,pass,4
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,8
BartForCausalLM,pass,13
BartForCausalLM,pass,12
BartForConditionalGeneration,pass,25
BartForConditionalGeneration,pass,24
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,13
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForConditionalGeneration,pass,25
BlenderbotSmallForConditionalGeneration,pass,24
@ -74,7 +74,7 @@ DistillGPT2,pass,4
ElectraForCausalLM,pass,5
ElectraForCausalLM,pass,4
@ -98,15 +98,15 @@ LayoutLMForSequenceClassification,pass,6
M2M100ForConditionalGeneration,pass,5
M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,13
MBartForCausalLM,pass,12
MBartForConditionalGeneration,pass,25
MBartForConditionalGeneration,pass,24
@ -130,19 +130,19 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,13
OPTForCausalLM,pass,12
PLBartForCausalLM,pass,13
PLBartForCausalLM,pass,12
PLBartForConditionalGeneration,pass,30
PLBartForConditionalGeneration,pass,29
PegasusForCausalLM,pass,13
PegasusForCausalLM,pass,12
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,4
Speech2Text2ForCausalLM,pass,13
Speech2Text2ForCausalLM,pass,12
@ -170,11 +170,11 @@ T5Small,pass,4
TrOCRForCausalLM,pass,13
TrOCRForCausalLM,pass,12
XGLMForCausalLM,pass,13
XGLMForCausalLM,pass,12

1 name accuracy graph_breaks
2 AlbertForMaskedLM pass 5 4
3 AlbertForQuestionAnswering pass 4
4 AllenaiLongformerBase pass 8
5 BartForCausalLM pass 13 12
6 BartForConditionalGeneration pass 25 24
7 BertForMaskedLM pass 4
8 BertForQuestionAnswering pass 4
14 DebertaForQuestionAnswering pass 4
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 4
18 DistilBertForQuestionAnswering pass 4
19 DistillGPT2 pass 4
20 ElectraForCausalLM pass 5 4
21 ElectraForQuestionAnswering pass 4
22 GPT2ForSequenceClassification pass 6
23 GoogleFnet pass 4
24 LayoutLMForMaskedLM pass 4
34 OPTForCausalLM pass 13 12
35 PLBartForCausalLM pass 13 12
36 PLBartForConditionalGeneration pass 30 29
37 PegasusForCausalLM pass 13 12
38 PegasusForConditionalGeneration pass 23
39 RobertaForCausalLM pass 4
40 RobertaForQuestionAnswering pass 4
41 Speech2Text2ForCausalLM pass 13 12
42 T5ForConditionalGeneration pass 4
43 T5Small pass 4
44 TrOCRForCausalLM pass 13 12
74
75
76
77
78
79
80
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
adv_inception_v3,pass,7
adv_inception_v3,pass,6
@ -10,7 +10,7 @@ beit_base_patch16_224,pass,6
botnet26t_256,pass,7
botnet26t_256,pass,6
@ -18,11 +18,11 @@ cait_m36_384,eager_fail_to_run,0
coat_lite_mini,pass,7
coat_lite_mini,pass,6
convit_base,pass,7
convit_base,pass,6
@ -50,11 +50,11 @@ dla102,pass,6
dm_nfnet_f0,pass,7
dm_nfnet_f0,pass,6
dpn107,pass,7
dpn107,pass,6
@ -74,15 +74,15 @@ fbnetc_100,pass,6
fbnetv3_b,pass,7
fbnetv3_b,pass,6
gernet_l,pass,7
gernet_l,pass,6
ghostnet_100,pass,7
ghostnet_100,pass,6
@ -90,7 +90,7 @@ gluon_inception_v3,pass,6
gmixer_24_224,pass,7
gmixer_24_224,pass,6
@ -102,7 +102,7 @@ hrnet_w18,pass,5
inception_v3,pass,7
inception_v3,pass,6
@ -110,7 +110,7 @@ jx_nest_base,pass,6
lcnet_050,pass,7
lcnet_050,pass,6
@ -122,7 +122,7 @@ mixer_b16_224,pass,6
mixnet_l,pass,7
mixnet_l,pass,6
@ -138,7 +138,7 @@ mobilenetv3_large_100,pass,6
mobilevit_s,pass,7
mobilevit_s,pass,6
@ -146,7 +146,7 @@ nfnet_l0,pass,6
pit_b_224,pass,7
pit_b_224,pass,6
@ -154,11 +154,11 @@ pnasnet5large,pass,5
poolformer_m36,pass,7
poolformer_m36,pass,6
regnety_002,pass,7
regnety_002,pass,6
@ -166,23 +166,23 @@ repvgg_a2,pass,6
res2net101_26w_4s,pass,7
res2net101_26w_4s,pass,6
res2net50_14w_8s,pass,7
res2net50_14w_8s,pass,6
res2next50,pass,7
res2next50,pass,6
resmlp_12_224,pass,7
resmlp_12_224,pass,6
resnest101e,pass,7
resnest101e,pass,6
@ -190,11 +190,11 @@ rexnet_100,pass,6
sebotnet33ts_256,pass,7
sebotnet33ts_256,pass,6
selecsls42b,pass,7
selecsls42b,pass,6
@ -206,19 +206,19 @@ swin_base_patch4_window7_224,pass,6
swsl_resnext101_32x16d,pass,7
swsl_resnext101_32x16d,pass,6
tf_efficientnet_b0,pass,7
tf_efficientnet_b0,pass,6
tf_mixnet_l,pass,7
tf_mixnet_l,pass,6
tinynet_a,pass,7
tinynet_a,pass,6

1 name accuracy graph_breaks
2 adv_inception_v3 pass 7 6
3 beit_base_patch16_224 pass 6
4 botnet26t_256 pass 7 6
5 cait_m36_384 eager_fail_to_run 0
6 coat_lite_mini pass 7 6
7 convit_base pass 7 6
8 convmixer_768_32 pass 5
10 crossvit_9_240 pass 6
11 cspdarknet53 pass 6
12 deit_base_distilled_patch16_224 pass 6
13 dla102 pass 6
14 dm_nfnet_f0 pass 7 6
15 dpn107 pass 7 6
16 eca_botnext26ts_256 pass 6
18 ese_vovnet19b_dw pass 6
19 fbnetc_100 pass 6
20 fbnetv3_b pass 7 6
21 gernet_l pass 7 6
22 ghostnet_100 pass 7 6
23 gluon_inception_v3 pass 6
24 gmixer_24_224 pass 7 6
25 gmlp_s16_224 pass 6
26 hrnet_w18 pass 5
27 inception_v3 pass 7 6
28 jx_nest_base pass 6
50 selecsls42b pass 7 6
51 spnasnet_100 pass 6
52 swin_base_patch4_window7_224 pass 6
53 swsl_resnext101_32x16d pass 7 6
54 tf_efficientnet_b0 pass 7 6
55 tf_mixnet_l pass 7 6
56 tinynet_a pass 7 6
57 tnt_s_patch16_224 pass 6
58 twins_pcpvt_base pass 6
59 visformer_small pass 6
60 vit_base_patch16_224 pass 6
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
90
91
92
93
94
95
96
102
103
104
105
106
107
108
110
111
112
113
114
115
116
122
123
124
125
126
127
128
138
139
140
141
142
143
144
146
147
148
149
150
151
152
154
155
156
157
158
159
160
161
162
163
164
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
190
191
192
193
194
195
196
197
198
199
200
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
BERT_pytorch,pass,7
BERT_pytorch,pass,6
@ -14,7 +14,7 @@ DALLE2_pytorch,eager_fail_to_run,0
LearningToPaint,pass,7
LearningToPaint,pass,6
@ -22,7 +22,7 @@ Super_SloMo,pass,6
alexnet,pass,7
alexnet,pass,6
@ -46,7 +46,7 @@ cm3leon_generate,eager_fail_to_run,0
dcgan,pass,7
dcgan,pass,6
@ -54,7 +54,7 @@ demucs,fail_to_run,4
densenet121,pass,7
densenet121,pass,6
@ -66,7 +66,7 @@ detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0
dlrm,pass,7
dlrm,pass,6
@ -82,7 +82,7 @@ drq,pass,5
fastNLP_Bert,pass,11
fastNLP_Bert,pass,10
@ -122,7 +122,7 @@ hf_Reformer,pass,25
hf_T5_base,OOM,3
hf_T5_base,eager_2nd_run_OOM,0
@ -158,7 +158,7 @@ mnasnet1_0,pass,6
mobilenet_v2,pass,7
mobilenet_v2,pass,6
@ -170,7 +170,7 @@ mobilenet_v3_large,pass,6
moco,pass,18
moco,pass,17
@ -186,19 +186,19 @@ opacus_cifar10,eager_fail_to_run,0
phlippe_densenet,pass,7
phlippe_densenet,pass,6
phlippe_resnet,pass,7
phlippe_resnet,pass,6
pytorch_CycleGAN_and_pix2pix,pass,7
pytorch_CycleGAN_and_pix2pix,pass,6
pytorch_stargan,pass,7
pytorch_stargan,pass,6
@ -210,11 +210,11 @@ resnet152,pass,6
resnet18,pass,7
resnet18,pass,6
resnet50,pass,7
resnet50,pass,6
@ -230,7 +230,7 @@ sam,eager_fail_to_run,0
shufflenet_v2_x1_0,pass,7
shufflenet_v2_x1_0,pass,6
@ -238,11 +238,11 @@ soft_actor_critic,pass,5
squeezenet1_1,pass,7
squeezenet1_1,pass,6
stable_diffusion_text_encoder,pass,6
stable_diffusion_text_encoder,pass,5
@ -254,7 +254,7 @@ timm_efficientnet,pass,6
timm_regnet,pass,7
timm_regnet,pass,6
@ -262,7 +262,7 @@ timm_resnest,pass,6
timm_vision_transformer,pass,7
timm_vision_transformer,pass,6
@ -270,7 +270,7 @@ timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,7
timm_vovnet,pass,6
@ -282,11 +282,11 @@ tts_angular,pass,8
vgg16,pass,7
vgg16,pass,6
vision_maskrcnn,pass,34
vision_maskrcnn,pass,33

1 name accuracy graph_breaks
2 BERT_pytorch pass 7 6
3 Background_Matting pass_due_to_skip 0
4 DALLE2_pytorch eager_fail_to_run 0
5 LearningToPaint pass 7 6
6 Super_SloMo pass 6
7 alexnet pass 7 6
8 basic_gnn_edgecnn pass 21
14 demucs fail_to_run 4
15 densenet121 pass 7 6
16 detectron2_fcos_r_50_fpn model_fail_to_load 0
17 detectron2_maskrcnn_r_50_c4 eager_fail_to_run 0
18 dlrm pass 7 6
19 doctr_det_predictor eager_fail_to_run 0
20 doctr_reco_predictor eager_fail_to_run 0
22 fastNLP_Bert pass 11 10
23 functorch_dp_cifar10 pass 6
24 functorch_maml_omniglot pass 6
25 hf_Albert pass 5
26 hf_Bart pass 5
27 hf_BigBird fail_to_run 3
28 hf_DistilBert pass 5
46 nvidia_deeprecommender pass 6
47 opacus_cifar10 eager_fail_to_run 0
48 phlippe_densenet pass 7 6
49 phlippe_resnet pass 7 6
50 pytorch_CycleGAN_and_pix2pix pass 7 6
51 pytorch_stargan pass 7 6
52 pytorch_unet pass 6
54 resnet18 pass 7 6
55 resnet50 pass 7 6
56 resnet50_quantized_qat eager_fail_to_run 0
57 resnext50_32x4d pass 6
58 sam eager_fail_to_run 0
59 shufflenet_v2_x1_0 pass 7 6
60 soft_actor_critic pass 5
66 timm_resnest pass 6
67 timm_vision_transformer pass 7 6
68 timm_vision_transformer_large pass_due_to_skip 0
69 timm_vovnet pass 7 6
70 torch_multimodal_clip pass 6
71 tts_angular pass 8
72 vgg16 pass 7 6
82
83
84
85
86
87
88
122
123
124
125
126
127
128
158
159
160
161
162
163
164
170
171
172
173
174
175
176
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
210
211
212
213
214
215
216
217
218
219
220
230
231
232
233
234
235
236
238
239
240
241
242
243
244
245
246
247
248
254
255
256
257
258
259
260
262
263
264
265
266
267
268
270
271
272
273
274
275
276
282
283
284
285
286
287
288
289
290
291
292

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
AlbertForMaskedLM,pass,5
AlbertForMaskedLM,pass,4
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,8
BartForCausalLM,pass,13
BartForCausalLM,pass,12
BartForConditionalGeneration,pass,25
BartForConditionalGeneration,pass,24
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,13
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForConditionalGeneration,pass,25
BlenderbotSmallForConditionalGeneration,pass,24
@ -74,7 +74,7 @@ DistillGPT2,pass,4
ElectraForCausalLM,pass,5
ElectraForCausalLM,pass,4
@ -98,15 +98,15 @@ LayoutLMForSequenceClassification,pass,6
M2M100ForConditionalGeneration,pass,5
M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,13
MBartForCausalLM,pass,12
MBartForConditionalGeneration,pass,25
MBartForConditionalGeneration,pass,24
@ -130,19 +130,19 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,13
OPTForCausalLM,pass,12
PLBartForCausalLM,pass,13
PLBartForCausalLM,pass,12
PLBartForConditionalGeneration,pass,30
PLBartForConditionalGeneration,pass,29
PegasusForCausalLM,pass,13
PegasusForCausalLM,pass,12
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,4
Speech2Text2ForCausalLM,pass,13
Speech2Text2ForCausalLM,pass,12
@ -170,11 +170,11 @@ T5Small,pass,4
TrOCRForCausalLM,pass,13
TrOCRForCausalLM,pass,12
XGLMForCausalLM,pass,13
XGLMForCausalLM,pass,12

1 name accuracy graph_breaks
2 AlbertForMaskedLM pass 5 4
3 AlbertForQuestionAnswering pass 4
4 AllenaiLongformerBase pass 8
5 BartForCausalLM pass 13 12
6 BartForConditionalGeneration pass 25 24
7 BertForMaskedLM pass 4
8 BertForQuestionAnswering pass 4
14 DebertaForQuestionAnswering pass 4
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 4
18 DistilBertForQuestionAnswering pass 4
19 DistillGPT2 pass 4
20 ElectraForCausalLM pass 5 4
21 ElectraForQuestionAnswering pass 4
22 GPT2ForSequenceClassification pass 6
23 GoogleFnet pass 4
24 LayoutLMForMaskedLM pass 4
34 OPTForCausalLM pass 13 12
35 PLBartForCausalLM pass 13 12
36 PLBartForConditionalGeneration pass 30 29
37 PegasusForCausalLM pass 13 12
38 PegasusForConditionalGeneration pass 23
39 RobertaForCausalLM pass 4
40 RobertaForQuestionAnswering pass 4
41 Speech2Text2ForCausalLM pass 13 12
42 T5ForConditionalGeneration pass 4
43 T5Small pass 4
44 TrOCRForCausalLM pass 13 12
74
75
76
77
78
79
80
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
adv_inception_v3,pass,7
adv_inception_v3,pass,6
@ -10,7 +10,7 @@ beit_base_patch16_224,pass,6
botnet26t_256,pass,7
botnet26t_256,pass,6
@ -18,11 +18,11 @@ cait_m36_384,eager_fail_to_run,0
coat_lite_mini,pass,7
coat_lite_mini,pass,6
convit_base,pass,7
convit_base,pass,6
@ -50,11 +50,11 @@ dla102,pass,6
dm_nfnet_f0,pass,7
dm_nfnet_f0,pass,6
dpn107,pass,7
dpn107,pass,6
@ -74,15 +74,15 @@ fbnetc_100,pass,6
fbnetv3_b,pass,7
fbnetv3_b,pass,6
gernet_l,pass,7
gernet_l,pass,6
ghostnet_100,pass,7
ghostnet_100,pass,6
@ -90,7 +90,7 @@ gluon_inception_v3,pass,6
gmixer_24_224,pass,7
gmixer_24_224,pass,6
@ -102,7 +102,7 @@ hrnet_w18,pass,5
inception_v3,pass,7
inception_v3,pass,6
@ -110,7 +110,7 @@ jx_nest_base,pass,6
lcnet_050,pass,7
lcnet_050,pass,6
@ -122,7 +122,7 @@ mixer_b16_224,pass,6
mixnet_l,pass,7
mixnet_l,pass,6
@ -138,7 +138,7 @@ mobilenetv3_large_100,pass,6
mobilevit_s,pass,7
mobilevit_s,pass,6
@ -146,7 +146,7 @@ nfnet_l0,pass,6
pit_b_224,pass,7
pit_b_224,pass,6
@ -154,11 +154,11 @@ pnasnet5large,pass,5
poolformer_m36,pass,7
poolformer_m36,pass,6
regnety_002,pass,7
regnety_002,pass,6
@ -166,23 +166,23 @@ repvgg_a2,pass,6
res2net101_26w_4s,pass,7
res2net101_26w_4s,pass,6
res2net50_14w_8s,pass,7
res2net50_14w_8s,pass,6
res2next50,pass,7
res2next50,pass,6
resmlp_12_224,pass,7
resmlp_12_224,pass,6
resnest101e,pass,7
resnest101e,pass,6
@ -190,11 +190,11 @@ rexnet_100,pass,6
sebotnet33ts_256,pass,7
sebotnet33ts_256,pass,6
selecsls42b,pass,7
selecsls42b,pass,6
@ -206,19 +206,19 @@ swin_base_patch4_window7_224,pass,6
swsl_resnext101_32x16d,pass,7
swsl_resnext101_32x16d,pass,6
tf_efficientnet_b0,pass,7
tf_efficientnet_b0,pass,6
tf_mixnet_l,pass,7
tf_mixnet_l,pass,6
tinynet_a,pass,7
tinynet_a,pass,6

1 name accuracy graph_breaks
2 adv_inception_v3 pass 7 6
3 beit_base_patch16_224 pass 6
4 botnet26t_256 pass 7 6
5 cait_m36_384 eager_fail_to_run 0
6 coat_lite_mini pass 7 6
7 convit_base pass 7 6
8 convmixer_768_32 pass 5
10 crossvit_9_240 pass 6
11 cspdarknet53 pass 6
12 deit_base_distilled_patch16_224 pass 6
13 dla102 pass 6
14 dm_nfnet_f0 pass 7 6
15 dpn107 pass 7 6
16 eca_botnext26ts_256 pass 6
18 ese_vovnet19b_dw pass 6
19 fbnetc_100 pass 6
20 fbnetv3_b pass 7 6
21 gernet_l pass 7 6
22 ghostnet_100 pass 7 6
23 gluon_inception_v3 pass 6
24 gmixer_24_224 pass 7 6
25 gmlp_s16_224 pass 6
26 hrnet_w18 pass 5
27 inception_v3 pass 7 6
28 jx_nest_base pass 6
50 selecsls42b pass 7 6
51 spnasnet_100 pass 6
52 swin_base_patch4_window7_224 pass 6
53 swsl_resnext101_32x16d pass 7 6
54 tf_efficientnet_b0 pass 7 6
55 tf_mixnet_l pass 7 6
56 tinynet_a pass 7 6
57 tnt_s_patch16_224 pass 6
58 twins_pcpvt_base pass 6
59 visformer_small pass 6
60 vit_base_patch16_224 pass 6
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
90
91
92
93
94
95
96
102
103
104
105
106
107
108
110
111
112
113
114
115
116
122
123
124
125
126
127
128
138
139
140
141
142
143
144
146
147
148
149
150
151
152
154
155
156
157
158
159
160
161
162
163
164
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
190
191
192
193
194
195
196
197
198
199
200
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224

View File

@ -2,11 +2,11 @@ name,accuracy,graph_breaks
torchrec_dlrm,pass,7
torchrec_dlrm,pass,6
BERT_pytorch,pass,7
BERT_pytorch,pass,6
@ -18,7 +18,7 @@ DALLE2_pytorch,eager_fail_to_run,0
LearningToPaint,pass,7
LearningToPaint,pass,6
@ -26,7 +26,7 @@ Super_SloMo,pass,6
alexnet,pass,7
alexnet,pass,6
@ -50,15 +50,15 @@ cm3leon_generate,eager_fail_to_run,0
dcgan,pass,7
dcgan,pass,6
demucs,pass,10
demucs,pass,9
densenet121,pass,7
densenet121,pass,6
@ -70,7 +70,7 @@ detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0
dlrm,pass,7
dlrm,pass,6
@ -86,7 +86,7 @@ drq,pass,5
fastNLP_Bert,pass,11
fastNLP_Bert,pass,10
@ -126,7 +126,7 @@ hf_Reformer,pass,25
hf_T5_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
@ -162,7 +162,7 @@ mnasnet1_0,pass,6
mobilenet_v2,pass,7
mobilenet_v2,pass,6
@ -174,7 +174,7 @@ mobilenet_v3_large,pass,6
moco,pass,18
moco,pass,17
@ -190,19 +190,19 @@ opacus_cifar10,eager_fail_to_run,0
phlippe_densenet,pass,7
phlippe_densenet,pass,6
phlippe_resnet,pass,7
phlippe_resnet,pass,6
pytorch_CycleGAN_and_pix2pix,pass,7
pytorch_CycleGAN_and_pix2pix,pass,6
pytorch_stargan,pass,7
pytorch_stargan,pass,6
@ -214,11 +214,11 @@ resnet152,pass,6
resnet18,pass,7
resnet18,pass,6
resnet50,pass,7
resnet50,pass,6
@ -234,7 +234,7 @@ sam,eager_fail_to_run,0
shufflenet_v2_x1_0,pass,7
shufflenet_v2_x1_0,pass,6
@ -242,15 +242,15 @@ soft_actor_critic,pass,5
speech_transformer,pass,17
speech_transformer,pass,16
squeezenet1_1,pass,7
squeezenet1_1,pass,6
stable_diffusion_text_encoder,pass,6
stable_diffusion_text_encoder,pass,5
@ -262,7 +262,7 @@ timm_efficientnet,pass,6
timm_regnet,pass,7
timm_regnet,pass,6
@ -270,7 +270,7 @@ timm_resnest,pass,6
timm_vision_transformer,pass,7
timm_vision_transformer,pass,6
@ -278,7 +278,7 @@ timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,7
timm_vovnet,pass,6
@ -290,11 +290,11 @@ tts_angular,pass,8
vgg16,pass,7
vgg16,pass,6
vision_maskrcnn,pass,34
vision_maskrcnn,pass,33

1 name accuracy graph_breaks
2 torchrec_dlrm pass 7 6
3 BERT_pytorch pass 7 6
4 Background_Matting pass_due_to_skip 0
5 DALLE2_pytorch eager_fail_to_run 0
6 LearningToPaint pass 7 6
7 Super_SloMo pass 6
8 alexnet pass 7 6
9 basic_gnn_edgecnn pass 21
10 basic_gnn_gcn pass 12
11 basic_gnn_gin pass 6
12 basic_gnn_sage pass 6
18 detectron2_maskrcnn_r_50_c4 eager_fail_to_run 0
19 dlrm pass 7 6
20 doctr_det_predictor eager_fail_to_run 0
21 doctr_reco_predictor eager_fail_to_run 0
22 drq pass 5
23 fastNLP_Bert pass 11 10
24 functorch_dp_cifar10 pass 6
26 hf_Albert pass 5
27 hf_Bart pass 5
28 hf_BigBird pass 5
29 hf_DistilBert pass 5
30 hf_GPT2 pass 5
31 hf_GPT2_large pass_due_to_skip 0
32 hf_Reformer pass 25
50 phlippe_resnet pass 7 6
51 pytorch_CycleGAN_and_pix2pix pass 7 6
52 pytorch_stargan pass 7 6
53 pytorch_unet pass 6
54 resnet152 pass 6
55 resnet18 pass 7 6
56 resnet50 pass 7 6
57 resnet50_quantized_qat eager_fail_to_run 0
58 resnext50_32x4d pass 6
59 sam eager_fail_to_run 0
60 shufflenet_v2_x1_0 pass 7 6
61 soft_actor_critic pass 5
62 speech_transformer pass 17 16
63 squeezenet1_1 pass 7 6
64 stable_diffusion_text_encoder pass 6 5
70 timm_vision_transformer_large pass_due_to_skip 0
71 timm_vovnet pass 7 6
72 torch_multimodal_clip pass 6
73 tts_angular pass 8
74 vgg16 pass 7 6
75 vision_maskrcnn pass 34 33
76 yolov3 pass 8
86
87
88
89
90
91
92
126
127
128
129
130
131
132
162
163
164
165
166
167
168
174
175
176
177
178
179
180
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
214
215
216
217
218
219
220
221
222
223
224
234
235
236
237
238
239
240
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
262
263
264
265
266
267
268
270
271
272
273
274
275
276
278
279
280
281
282
283
284
290
291
292
293
294
295
296
297
298
299
300

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
AlbertForMaskedLM,pass,5
AlbertForMaskedLM,pass,4
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,8
BartForCausalLM,pass,13
BartForCausalLM,pass,12
BartForConditionalGeneration,pass,25
BartForConditionalGeneration,pass,24
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,13
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForConditionalGeneration,pass,25
BlenderbotSmallForConditionalGeneration,pass,24
@ -74,7 +74,7 @@ DistillGPT2,pass,4
ElectraForCausalLM,pass,5
ElectraForCausalLM,pass,4
@ -98,15 +98,15 @@ LayoutLMForSequenceClassification,pass,6
M2M100ForConditionalGeneration,pass,5
M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,13
MBartForCausalLM,pass,12
MBartForConditionalGeneration,pass,25
MBartForConditionalGeneration,pass,24
@ -130,19 +130,19 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,13
OPTForCausalLM,pass,12
PLBartForCausalLM,pass,13
PLBartForCausalLM,pass,12
PLBartForConditionalGeneration,pass,30
PLBartForConditionalGeneration,pass,29
PegasusForCausalLM,pass,13
PegasusForCausalLM,pass,12
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,4
Speech2Text2ForCausalLM,pass,13
Speech2Text2ForCausalLM,pass,12
@ -170,11 +170,11 @@ T5Small,pass,4
TrOCRForCausalLM,pass,13
TrOCRForCausalLM,pass,12
XGLMForCausalLM,pass,13
XGLMForCausalLM,pass,12

1 name accuracy graph_breaks
2 AlbertForMaskedLM pass 5 4
3 AlbertForQuestionAnswering pass 4
4 AllenaiLongformerBase pass 8
5 BartForCausalLM pass 13 12
6 BartForConditionalGeneration pass 25 24
7 BertForMaskedLM pass 4
8 BertForQuestionAnswering pass 4
14 DebertaForQuestionAnswering pass 4
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 4
18 DistilBertForQuestionAnswering pass 4
19 DistillGPT2 pass 4
20 ElectraForCausalLM pass 5 4
21 ElectraForQuestionAnswering pass 4
22 GPT2ForSequenceClassification pass 6
23 GoogleFnet pass 4
24 LayoutLMForMaskedLM pass 4
34 OPTForCausalLM pass 13 12
35 PLBartForCausalLM pass 13 12
36 PLBartForConditionalGeneration pass 30 29
37 PegasusForCausalLM pass 13 12
38 PegasusForConditionalGeneration pass 23
39 RobertaForCausalLM pass 4
40 RobertaForQuestionAnswering pass 4
41 Speech2Text2ForCausalLM pass 13 12
42 T5ForConditionalGeneration pass 4
43 T5Small pass 4
44 TrOCRForCausalLM pass 13 12
74
75
76
77
78
79
80
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -2,7 +2,7 @@ name,accuracy,graph_breaks
adv_inception_v3,pass,7
adv_inception_v3,pass,6
@ -10,7 +10,7 @@ beit_base_patch16_224,pass,6
botnet26t_256,pass,7
botnet26t_256,pass,6
@ -18,11 +18,11 @@ cait_m36_384,eager_fail_to_run,0
coat_lite_mini,pass,7
coat_lite_mini,pass,6
convit_base,pass,7
convit_base,pass,6
@ -50,11 +50,11 @@ dla102,pass,6
dm_nfnet_f0,pass,7
dm_nfnet_f0,pass,6
dpn107,pass,7
dpn107,pass,6
@ -74,15 +74,15 @@ fbnetc_100,pass,6
fbnetv3_b,pass,7
fbnetv3_b,pass,6
gernet_l,pass,7
gernet_l,pass,6
ghostnet_100,pass,7
ghostnet_100,pass,6
@ -90,7 +90,7 @@ gluon_inception_v3,pass,6
gmixer_24_224,pass,7
gmixer_24_224,pass,6
@ -102,7 +102,7 @@ hrnet_w18,pass,5
inception_v3,pass,7
inception_v3,pass,6
@ -110,7 +110,7 @@ jx_nest_base,pass,6
lcnet_050,pass,7
lcnet_050,pass,6
@ -122,7 +122,7 @@ mixer_b16_224,pass,6
mixnet_l,pass,7
mixnet_l,pass,6
@ -138,7 +138,7 @@ mobilenetv3_large_100,pass,6
mobilevit_s,pass,7
mobilevit_s,pass,6
@ -146,7 +146,7 @@ nfnet_l0,pass,6
pit_b_224,pass,7
pit_b_224,pass,6
@ -154,11 +154,11 @@ pnasnet5large,pass,5
poolformer_m36,pass,7
poolformer_m36,pass,6
regnety_002,pass,7
regnety_002,pass,6
@ -166,23 +166,23 @@ repvgg_a2,pass,6
res2net101_26w_4s,pass,7
res2net101_26w_4s,pass,6
res2net50_14w_8s,pass,7
res2net50_14w_8s,pass,6
res2next50,pass,7
res2next50,pass,6
resmlp_12_224,pass,7
resmlp_12_224,pass,6
resnest101e,pass,7
resnest101e,pass,6
@ -190,11 +190,11 @@ rexnet_100,pass,6
sebotnet33ts_256,pass,7
sebotnet33ts_256,pass,6
selecsls42b,pass,7
selecsls42b,pass,6
@ -206,19 +206,19 @@ swin_base_patch4_window7_224,pass,6
swsl_resnext101_32x16d,pass,7
swsl_resnext101_32x16d,pass,6
tf_efficientnet_b0,pass,7
tf_efficientnet_b0,pass,6
tf_mixnet_l,pass,7
tf_mixnet_l,pass,6
tinynet_a,pass,7
tinynet_a,pass,6

1 name accuracy graph_breaks
2 adv_inception_v3 pass 7 6
3 beit_base_patch16_224 pass 6
4 botnet26t_256 pass 7 6
5 cait_m36_384 eager_fail_to_run 0
6 coat_lite_mini pass 7 6
7 convit_base pass 7 6
8 convmixer_768_32 pass 5
10 crossvit_9_240 pass 6
11 cspdarknet53 pass 6
12 deit_base_distilled_patch16_224 pass 6
13 dla102 pass 6
14 dm_nfnet_f0 pass 7 6
15 dpn107 pass 7 6
16 eca_botnext26ts_256 pass 6
18 ese_vovnet19b_dw pass 6
19 fbnetc_100 pass 6
20 fbnetv3_b pass 7 6
21 gernet_l pass 7 6
22 ghostnet_100 pass 7 6
23 gluon_inception_v3 pass 6
24 gmixer_24_224 pass 7 6
25 gmlp_s16_224 pass 6
26 hrnet_w18 pass 5
27 inception_v3 pass 7 6
28 jx_nest_base pass 6
50 selecsls42b pass 7 6
51 spnasnet_100 pass 6
52 swin_base_patch4_window7_224 pass 6
53 swsl_resnext101_32x16d pass 7 6
54 tf_efficientnet_b0 pass 7 6
55 tf_mixnet_l pass 7 6
56 tinynet_a pass 7 6
57 tnt_s_patch16_224 pass 6
58 twins_pcpvt_base pass 6
59 visformer_small pass 6
60 vit_base_patch16_224 pass 6
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
90
91
92
93
94
95
96
102
103
104
105
106
107
108
110
111
112
113
114
115
116
122
123
124
125
126
127
128
138
139
140
141
142
143
144
146
147
148
149
150
151
152
154
155
156
157
158
159
160
161
162
163
164
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
190
191
192
193
194
195
196
197
198
199
200
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224

View File

@ -2,11 +2,11 @@ name,accuracy,graph_breaks
torchrec_dlrm,pass,7
torchrec_dlrm,pass,6
BERT_pytorch,pass,7
BERT_pytorch,pass,6
@ -18,7 +18,7 @@ DALLE2_pytorch,eager_fail_to_run,0
LearningToPaint,pass,7
LearningToPaint,pass,6
@ -26,7 +26,7 @@ Super_SloMo,pass,6
alexnet,pass,7
alexnet,pass,6
@ -50,15 +50,15 @@ cm3leon_generate,eager_fail_to_run,0
dcgan,pass,7
dcgan,pass,6
demucs,pass,10
demucs,pass,9
densenet121,pass,7
densenet121,pass,6
@ -70,7 +70,7 @@ detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0
dlrm,pass,7
dlrm,pass,6
@ -86,7 +86,7 @@ drq,pass,5
fastNLP_Bert,pass,11
fastNLP_Bert,pass,10
@ -126,7 +126,7 @@ hf_Reformer,pass,25
hf_T5_base,OOM,3
hf_T5_base,eager_2nd_run_OOM,0
@ -162,7 +162,7 @@ mnasnet1_0,pass,6
mobilenet_v2,pass,7
mobilenet_v2,pass,6
@ -174,7 +174,7 @@ mobilenet_v3_large,pass,6
moco,pass,18
moco,pass,17
@ -190,19 +190,19 @@ opacus_cifar10,eager_fail_to_run,0
phlippe_densenet,pass,7
phlippe_densenet,pass,6
phlippe_resnet,pass,7
phlippe_resnet,pass,6
pytorch_CycleGAN_and_pix2pix,pass,7
pytorch_CycleGAN_and_pix2pix,pass,6
pytorch_stargan,pass,7
pytorch_stargan,pass,6
@ -214,11 +214,11 @@ resnet152,pass,6
resnet18,pass,7
resnet18,pass,6
resnet50,pass,7
resnet50,pass,6
@ -234,7 +234,7 @@ sam,eager_fail_to_run,0
shufflenet_v2_x1_0,pass,7
shufflenet_v2_x1_0,pass,6
@ -242,15 +242,15 @@ soft_actor_critic,pass,5
speech_transformer,pass,17
speech_transformer,pass,16
squeezenet1_1,pass,7
squeezenet1_1,pass,6
stable_diffusion_text_encoder,pass,6
stable_diffusion_text_encoder,pass,5
@ -262,7 +262,7 @@ timm_efficientnet,pass,6
timm_regnet,pass,7
timm_regnet,pass,6
@ -270,7 +270,7 @@ timm_resnest,pass,6
timm_vision_transformer,pass,7
timm_vision_transformer,pass,6
@ -278,7 +278,7 @@ timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,7
timm_vovnet,pass,6
@ -290,11 +290,11 @@ tts_angular,pass,8
vgg16,pass,7
vgg16,pass,6
vision_maskrcnn,pass,34
vision_maskrcnn,pass,33

1 name accuracy graph_breaks
2 torchrec_dlrm pass 7 6
3 BERT_pytorch pass 7 6
4 Background_Matting pass_due_to_skip 0
5 DALLE2_pytorch eager_fail_to_run 0
6 LearningToPaint pass 7 6
7 Super_SloMo pass 6
8 alexnet pass 7 6
9 basic_gnn_edgecnn pass 21
10 basic_gnn_gcn pass 12
11 basic_gnn_gin pass 6
12 basic_gnn_sage pass 6
18 detectron2_maskrcnn_r_50_c4 eager_fail_to_run 0
19 dlrm pass 7 6
20 doctr_det_predictor eager_fail_to_run 0
21 doctr_reco_predictor eager_fail_to_run 0
22 drq pass 5
23 fastNLP_Bert pass 11 10
24 functorch_dp_cifar10 pass 6
26 hf_Albert pass 5
27 hf_Bart pass 5
28 hf_BigBird pass 5
29 hf_DistilBert pass 5
30 hf_GPT2 pass 5
31 hf_GPT2_large pass_due_to_skip 0
32 hf_Reformer pass 25
50 phlippe_resnet pass 7 6
51 pytorch_CycleGAN_and_pix2pix pass 7 6
52 pytorch_stargan pass 7 6
53 pytorch_unet pass 6
54 resnet152 pass 6
55 resnet18 pass 7 6
56 resnet50 pass 7 6
57 resnet50_quantized_qat eager_fail_to_run 0
58 resnext50_32x4d pass 6
59 sam eager_fail_to_run 0
60 shufflenet_v2_x1_0 pass 7 6
61 soft_actor_critic pass 5
62 speech_transformer pass 17 16
63 squeezenet1_1 pass 7 6
64 stable_diffusion_text_encoder pass 6 5
70 timm_vision_transformer_large pass_due_to_skip 0
71 timm_vovnet pass 7 6
72 torch_multimodal_clip pass 6
73 tts_angular pass 8
74 vgg16 pass 7 6
75 vision_maskrcnn pass 34 33
76 yolov3 pass 8
86
87
88
89
90
91
92
126
127
128
129
130
131
132
162
163
164
165
166
167
168
174
175
176
177
178
179
180
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
214
215
216
217
218
219
220
221
222
223
224
234
235
236
237
238
239
240
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
262
263
264
265
266
267
268
270
271
272
273
274
275
276
278
279
280
281
282
283
284
290
291
292
293
294
295
296
297
298
299
300

View File

@ -1982,6 +1982,12 @@ class BenchmarkRunner:
if device == "cuda" and self.args.training and name not in CI_SKIP_OPTIMIZER:
if (name in CI_USE_SGD and self.args.ci) or name in BENCHMARK_USE_SGD:
self.optimizer = torch.optim.SGD(params, lr=0.01, foreach=True)
# Disable multi_tensor_sgd for benchmarking: there isn't a large performance benefit (~1%) to compiling
# this optimizer because it is a single foreach add, and compiling it increases compile time.
# After autotuning and fake tensor caching land, we can enable it, because the compile time impact will be lower.
# Fake Tensor caching: https://github.com/pytorch/pytorch/pull/113873
# Autotuning: https://github.com/pytorch/pytorch/issues/117447
self.optimizer.step = torch._dynamo.disable(self.optimizer.step)
else:
self.optimizer = torch.optim.Adam(
params, lr=0.01, capturable=True, foreach=True

View File

@ -11,8 +11,8 @@ import torch
import torch._inductor
# The rest of the optimizers not yet imported: Adamax, LBFGS, RAdam, SGD, SparseAdam
from torch.optim import Adadelta, Adagrad, Adam, AdamW, ASGD, NAdam, RMSprop, Rprop
# The rest of the optimizers not yet imported: Adamax, LBFGS, RAdam, SparseAdam
from torch.optim import Adadelta, Adagrad, Adam, AdamW, ASGD, NAdam, RMSprop, Rprop, SGD
from torch.testing._internal.common_optimizers import optim_db
@ -41,6 +41,10 @@ KERNEL_COUNT_OVERRIDES = {
"test_adadelta_foreach_weight_decay_maximize_cpu": 12,
"test_adadelta_foreach_rho_weight_decay_cpu": 12,
"test_adadelta_foreach_weight_decay_cpu": 12,
"test_sgd_foreach_momentum_weight_decay_cpu": 16,
"test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16,
"test_sgd_foreach_momentum_dampening_cuda": 5,
"test_sgd_foreach_momentum_cuda": 5,
}
# also tracks currently supported optimizers
@ -53,6 +57,7 @@ KERNEL_COUNTS = {
Adadelta: KernelCounts(multitensor=1, singletensor=4),
Adagrad: KernelCounts(multitensor=5, singletensor=8),
ASGD: KernelCounts(multitensor=2, singletensor=12),
SGD: KernelCounts(multitensor=2, singletensor=8),
}
@ -233,8 +238,7 @@ def make_recompile_test(optim_cls, closure=None, kernel_count=2, **kwargs):
# check no recompile here
with torch.set_grad_enabled(False):
compiled_step()
for _ in range(4):
compiled_step()
# perturb state to force recompile
@ -247,12 +251,20 @@ def make_recompile_test(optim_cls, closure=None, kernel_count=2, **kwargs):
compiled_step()
if self.check_kernel_count:
if optim_cls is SGD:
# SGD triggers an additional recompile
# because of momentum buffer list mutation in step()
multiplier = 3
else:
# currently, we compile the step and the rest of the computation
# separately because the step is a single element tensor
# hence, the usual kernel count is 2
# multiply by 2 to account for the recompile
multiplier = 2
self.assertEqual(
torch._inductor.metrics.generated_kernel_count, 2 * kernel_count
torch._inductor.metrics.generated_kernel_count,
multiplier * kernel_count,
)
return test_fn
@ -271,8 +283,6 @@ class CompiledOptimizerTests(TestCase):
super().tearDown()
torch._inductor.metrics.reset()
# test_sgd = make_test(SGD, kernel_count=1, lr=0.01)
test_adam_recompile = make_recompile_test(Adam, lr=0.01)
test_adamw_recompile = make_recompile_test(AdamW, lr=0.01)
# Need an impl which does not use python scalars
@ -289,7 +299,12 @@ class CompiledOptimizerTests(TestCase):
test_asgd_recompile_foreach = make_recompile_test(
ASGD, kernel_count=2, lr=0.01, foreach=True
)
# test_sgd_recompile = make_recompile_test(SGD, kernel_count=1, lr=0.01)
test_sgd_recompile_single = make_recompile_test(
SGD, kernel_count=4, lr=0.01, foreach=False
)
test_sgd_recompile_foreach = make_recompile_test(
SGD, kernel_count=1, lr=0.01, foreach=True
)
@requires_cuda()
def test_static_address_finalizer(self):

View File

@ -1478,7 +1478,6 @@ class TorchPatcher:
disabled_multi_tensor_opt_modules = {
adamax,
radam, # data-dependent control flow
sgd, # for now, until we can speed up compilation (this affects the benchmarks)
}
for opt_mod in optimizer_modules:

View File

@ -187,11 +187,23 @@ class OptimizerVariable(UserDefinedObjectVariable):
def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
"""Update the args and kwargs to the traced optimizer call"""
for arg, py_arg in zip(args, py_args):
if isinstance(arg, ListVariable) and all(
isinstance(t, torch.Tensor) for t in py_arg
):
if isinstance(arg, ListVariable):
assert isinstance(
py_arg, list
), "py_arg should be a list in optimizer variable"
for i, val in enumerate(py_arg):
tx.output.side_effects.mutation(arg)
arg.items.extend([self.wrap_tensor(tx, t) for t in py_arg])
if isinstance(val, torch.Tensor):
arg.items.append(self.wrap_tensor(tx, val))
else:
from .builder import SourcelessBuilder, VariableBuilder
if arg.source:
arg.items.append(
VariableBuilder(tx, GetItemSource(arg.source, i))(val)
)
else:
arg.items.append(SourcelessBuilder()(tx, val))
def create_finalizer(self, tx):
names_to_delete = self.static_tensor_names

View File

@ -34,7 +34,6 @@ dynamo_expected_failures = {
"TestCppExtensionOpenRgistration.test_open_device_registration",
"TestAutogradFallback.test_inplace_autograd_function_registered_to_cpu_mode_warn",
"TestAutogradFallback.test_inplace_autograd_function_registered_to_cpu_mode_nothing",
"TestFunctionalOptimParity.test_functional_optim_parity_sgd",
"TestIndexingCPU.test_invalid_index_cpu",
"NumpyTestsCPU.test_boolean_shape_mismatch_cpu",
"TestIndexingCPU.test_empty_ndim_index_bool_cpu",
@ -912,7 +911,6 @@ dynamo_expected_failures = {
"TestSDPACPU.test_scaled_dot_product_fused_attention_vs_math_cpu_fused_kernel0_bfloat16_batch_size_2_seq_len_267_n_head_1_head_dim_16_causal_True_train_False_cpu_bfloat16",
"TestSDPACPU.test_scaled_dot_product_fused_attention_vs_math_cpu_fused_kernel0_bfloat16_batch_size_12_seq_len_1030_n_head_1_head_dim_8_causal_False_train_False_cpu_bfloat16",
"TestSDPACPU.test_scaled_dot_product_fused_attention_vs_math_cpu_fused_kernel0_float64_batch_size_2_seq_len_267_n_head_3_head_dim_16_causal_False_train_False_cpu_float64",
"TestTransformersCPU.test_train_with_is_causal_cpu",
"TestSDPACPU.test_scaled_dot_product_fused_attention_vs_math_cpu_fused_kernel0_float32_batch_size_12_seq_len_1030_n_head_1_head_dim_8_causal_True_train_True_cpu_float32",
"TestSDPACPU.test_scaled_dot_product_fused_attention_vs_math_cpu_fused_kernel0_float64_batch_size_12_seq_len_267_n_head_3_head_dim_16_causal_False_train_True_cpu_float64",
"TestSDPACPU.test_scaled_dot_product_fused_attention_vs_math_cpu_fused_kernel0_float64_batch_size_12_seq_len_267_n_head_3_head_dim_8_causal_False_train_True_cpu_float64",
@ -2875,7 +2873,6 @@ dynamo_expected_failures = {
"TestTorchTidyProfiler.test_impl_reuse", # profiler/test_profiler
"TestExperimentalUtils.test_profiler_pattern_matcher_json_report", # profiler/test_profiler
"TestTorchTidyProfiler.test_tensorimpl_invalidation_full", # profiler/test_profiler
"TestProfiler.test_kineto_profiler_multiple_steppers", # profiler/test_profiler
"TestProfiler.test_profiler_tracing", # profiler/test_profiler
"TestProfiler.test_is_profiler_enabled", # profiler/test_profiler
"TestExperimentalUtils.test_utils_compute_idle_time", # profiler/test_profiler