Enable in oss (#124031)

The biggest movement is 4% on HF inference and 9% on TIMM inference. Note that this is max-autotune mode, so we are more tolerant of compile-time increases. We could improve compilation time by limiting:
```python
# Take how many of the top triton kernels to benchmark epilogue
max_epilogue_benchmarked_choices = 3
```
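
For example, a minimal sketch of capping that knob from user code, assuming the current `torch._inductor.config` module path:

```python
import torch
import torch._inductor.config as inductor_config

# Benchmark epilogue fusion against fewer of the top Triton GEMM
# candidates; smaller values trade peak performance for compile time.
inductor_config.max_epilogue_benchmarked_choices = 1

@torch.compile(mode="max-autotune")
def matmul_relu(a, b):
    return torch.relu(a @ b)
```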

There is a hf_Whisper failure which you can repro on main, without this stack, with `TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --backend inductor --amp --accuracy --training --only hf_Whisper`. Turning off epilogue fusion fixes the accuracy. I bisected the failure to a particular epilogue; however, when you compare the results of that epilogue with the corresponding separate kernels, the outputs are equivalent.
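
A sketch of that workaround via the inductor config (applied process-wide here; `epilogue_fusion` is the corresponding flag in `torch._inductor.config`):

```python
import torch._inductor.config as inductor_config

# Work around the hf_Whisper accuracy failure by disabling epilogue
# fusion entirely; note the bisected epilogue itself produces results
# equivalent to the corresponding separate kernels.
inductor_config.epilogue_fusion = False
```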

Inference:

<img width="1686" alt="image" src="https://github.com/pytorch/pytorch/assets/11477974/0b240080-cd33-4c08-89d3-583103b1fb0c">

Training:

<img width="1329" alt="Screenshot 2024-04-16 at 6 16 30 PM" src="https://github.com/pytorch/pytorch/assets/11477974/db0afcc9-7288-4c27-84ce-4fc1a5690788">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124031
Approved by: https://github.com/Chillee, https://github.com/shunting314
ghstack dependencies: #124030, #122642, #123229, #122825
Author: eellison (2024-04-18 21:38:19 -07:00), committed by PyTorch MergeBot
Commit: 000d55870a (parent: e6a788ac26)
7 changed files with 66 additions and 3 deletions

```diff
@@ -2578,6 +2578,14 @@ class BenchmarkRunner:
         # E.g., the output order might not match, None might be part of output, etc.
         try:
+            if self.args.training and self.args.amp:
+                if process_fn := self.get_output_amp_train_process_func.get(
+                    name, None
+                ):
+                    correct_result = process_fn(correct_result)
+                    new_result = process_fn(new_result)
+                    fp64_outputs = process_fn(fp64_outputs)
+
             if not same(
                 correct_result,
                 new_result,
```

```diff
@@ -55,6 +55,12 @@ imports = [
 ]


+def process_hf_reformer_output(out):
+    assert isinstance(out, list)
+    # second output is unstable
+    return [elem for i, elem in enumerate(out) if i != 1]
+
+
 try:
     mod = importlib.import_module("transformers")
     for cls in imports:
@@ -532,6 +538,10 @@ class HuggingfaceRunner(BenchmarkRunner):
             return SKIP_ACCURACY_CHECK_MODELS
         return set()

+    @property
+    def get_output_amp_train_process_func(self):
+        return {}
+
     def pick_grad(self, name, is_training):
         if is_training:
             return torch.enable_grad()
```
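
For illustration, the hook registered for hf_Reformer just filters out the unstable element (the values below are hypothetical placeholders):

```python
# The second output is numerically unstable under AMP training, so the
# hook drops it before the accuracy comparison.
out = ["stable_0", "unstable_1", "stable_2"]
assert process_hf_reformer_output(out) == ["stable_0", "stable_2"]
```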

```diff
@@ -194,6 +194,10 @@ class TimmRunner(BenchmarkRunner):
     def force_fp16_for_bf16_models(self):
         return set()

+    @property
+    def get_output_amp_train_process_func(self):
+        return {}
+
     @property
     def skip_accuracy_check_as_eager_non_deterministic(self):
         if self.args.accuracy and self.args.training:
```

```diff
@@ -88,6 +88,30 @@ def load_yaml_file():
     return maybe_list_to_set(data)


+def process_hf_reformer_output(out):
+    assert isinstance(out, list)
+    # second output is unstable
+    return [elem for i, elem in enumerate(out) if i != 1]
+
+
+def process_hf_whisper_output(out):
+    out_ret = []
+    for i, elem in enumerate(out):
+        if i == 0:
+            assert isinstance(elem, dict)
+            out_ret.append({k: v for k, v in elem.items() if k != "logits"})
+        elif i != 1:
+            out_ret.append(elem)
+    return out_ret
+
+
+process_train_model_output = {
+    "hf_Reformer": process_hf_reformer_output,
+    "hf_Whisper": process_hf_whisper_output,
+}
+
+
 class TorchBenchmarkRunner(BenchmarkRunner):
     def __init__(self):
         super().__init__()
@@ -142,6 +166,10 @@ class TorchBenchmarkRunner(BenchmarkRunner):
     def non_deterministic_models(self):
         return self._config["non_deterministic"]

+    @property
+    def get_output_amp_train_process_func(self):
+        return process_train_model_output
+
     @property
     def skip_not_suitable_for_training_models(self):
         return self._skip["test"]["training"]
```
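
Similarly, a hypothetical illustration of the hf_Whisper hook: it strips the unstable "logits" entry from the first output and drops the unstable second output entirely:

```python
# Placeholder values for illustration only.
out = [{"logits": "unstable", "encoder_state": "t0"}, "unstable_1", "t2"]
assert process_hf_whisper_output(out) == [{"encoder_state": "t0"}, "t2"]
```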

```diff
@@ -3932,7 +3932,17 @@ class TritonScheduling(BaseScheduling):
         wrapped_jit_function = mod.triton_

         # call once to trigger the compilation
-        call(wrapped_jit_function.clone_args(*args)[0])
+        try:
+            call(wrapped_jit_function.clone_args(*args)[0])
+        except Exception as e:
+            log.debug(
+                "Exception (%s) in compiling fused nodes %s",
+                e,
+                {n.get_name() for n in nodes},
+            )
+            ms = float("inf")
+            store_cache()
+            return ms, mod.__file__

         launchers = wrapped_jit_function.launchers
         assert len(launchers) == 1
```
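
The added control flow, shown in isolation (hypothetical helper name, not the actual scheduler API):

```python
# If a fused candidate fails to compile, score it as infinitely slow so
# it loses every benchmark comparison instead of aborting compilation.
def benchmark_or_inf(compile_and_run):
    try:
        return compile_and_run()
    except Exception:
        return float("inf")
```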

```diff
@@ -303,7 +303,10 @@ benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
 enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")

 benchmark_multi_templates = (
-    os.environ.get("TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES", "0") == "1"
+    os.environ.get(
+        "TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES", "0" if is_fbcode() else "1"
+    )
+    == "1"
 )

 # Take how many of the top triton kernels to benchmark epilogue
```
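
With the default flipped on in OSS, opting back out takes one environment variable; a sketch, noting the default is read when the inductor config module is first imported:

```python
import os

# Must run before torch._inductor.config is first imported.
os.environ["TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES"] = "0"

import torch  # noqa: E402
```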

```diff
@@ -1035,7 +1035,7 @@ class AlgorithmSelectorCache(PersistentCache):
         precompile_fn = precompile(choices)

-        if return_multi_template:
+        if return_multi_template and (config.max_autotune or config.max_autotune_gemm):

             def get_timings():
                 timings = do_autotuning(precompile_fn)
```
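
Since the multi-template path is now gated on max-autotune, it is only exercised under one of these settings (a sketch using the public config flags):

```python
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True  # or torch.compile(mode="max-autotune")
# Alternatively, restrict autotuning to GEMMs only:
# inductor_config.max_autotune_gemm = True
```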