test fixing benchmarks (#162503)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162503
Approved by: https://github.com/huydhn
ghstack dependencies: #160741
parent 760c478a14
commit 484c4093a8
@@ -393,10 +393,10 @@ elif [[ $TEST_CONFIG == *"perf_hf"* ]]; then
   test_hf_perf
 elif [[ $TEST_CONFIG == *"perf_timm"* ]]; then
   test_timm_perf
-elif [[ $TEST_CONFIG == *"perf_smoketest"* ]]; then
-  test_torchbench_smoketest "${SHARD_NUMBER}"
 elif [[ $TEST_CONFIG == *"aot_inductor_perf_smoketest"* ]]; then
   test_aoti_torchbench_smoketest "${SHARD_NUMBER}"
+elif [[ $TEST_CONFIG == *"perf_smoketest"* ]]; then
+  test_torchbench_smoketest "${SHARD_NUMBER}"
 elif [[ $TEST_CONFIG == *"mps"* ]]; then
   test_python_mps
 elif [[ $NUM_TEST_SHARDS -gt 1 ]]; then
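Note: the move above matters because bash's [[ $TEST_CONFIG == *"perf_smoketest"* ]] is a substring match, and aot_inductor_perf_smoketest contains perf_smoketest. Under the old ordering the AOTI configs were swallowed by the generic smoketest branch, so test_aoti_torchbench_smoketest never ran; checking the more specific pattern first fixes the dispatch. A minimal Python sketch of the same bug (the dispatch helper is hypothetical, for illustration only):

def dispatch(test_config: str) -> str:
    # Old (buggy) order: the generic substring check shadows the AOTI-specific one.
    if "perf_smoketest" in test_config:
        return "test_torchbench_smoketest"
    elif "aot_inductor_perf_smoketest" in test_config:
        return "test_aoti_torchbench_smoketest"  # unreachable for AOTI configs
    return "other"

# The AOTI config lands in the wrong branch until the checks are reordered.
assert dispatch("aot_inductor_perf_smoketest") == "test_torchbench_smoketest"

Reordering the branches, as the diff does, is the smallest fix; the alternative would be tightening the generic pattern so it cannot match the AOTI config name.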
@@ -1424,7 +1424,7 @@ class AOTInductorModelCache:
                 torch.hpu.max_memory_allocated() - pre_clone_memory_used
             ) / 1e9

-        inductor_configs = {}
+        inductor_configs = {"aot_inductor.package_constants_in_so": False}
         if mode == "max-autotune":
             inductor_configs["max_autotune"] = True
         ep = torch.export.export(
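Note: with aot_inductor.package_constants_in_so set to False, the package produced for the benchmark no longer embeds the model weights in the compiled shared object, so they must be supplied at load time (which the next hunk does). A minimal sketch of compiling with this config, assuming the current torch.export / torch._inductor.aoti_compile_and_package APIs; the toy module is illustrative, not the benchmark's model:

import torch

class ToyModel(torch.nn.Module):  # illustrative stand-in for a benchmark model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

ep = torch.export.export(ToyModel(), (torch.randn(2, 4),))
# Leave the weights out of the compiled shared object, as the diff does.
package_path = torch._inductor.aoti_compile_and_package(
    ep,
    inductor_configs={"aot_inductor.package_constants_in_so": False},
)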
@@ -1439,8 +1439,14 @@ class AOTInductorModelCache:
             ep, inductor_configs=inductor_configs
         ) # type: ignore[arg-type]

+        compiled = torch._inductor.aoti_load_package(package_path)
+        compiled.load_constants(
+            {**ep.state_dict, **ep.constants},
+            check_full_update=False,
+            user_managed=True,
+        )
         cls.cache[key] = (
-            torch._inductor.aoti_load_package(package_path),
+            compiled,
            clone_memory_used,
         )

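Note: this is the load-time half of the change. The exported program's state_dict and constants are pushed into the compiled runner to replace the weights that were excluded from the .so; user_managed=True makes the runner alias the caller's tensors instead of copying them (keeping the benchmark's memory accounting honest), and check_full_update=False accepts a partial constants map. Continuing the sketch from the previous hunk:

compiled = torch._inductor.aoti_load_package(package_path)
compiled.load_constants(
    {**ep.state_dict, **ep.constants},  # the weights the .so no longer carries
    check_full_update=False,  # tolerate a partial constants map
    user_managed=True,  # alias the caller's tensors rather than copying them
)
out = compiled(torch.randn(2, 4))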
@@ -721,8 +721,15 @@ void AOTIModelPackageLoader::load_constants(
   for (const auto& it : constants_map) {
     if (fqn_to_constant_name.find(it.first) != fqn_to_constant_name.end()) {
       updated_constants_map.emplace(fqn_to_constant_name[it.first], it.second);
-    } else {
-      throw std::runtime_error("Constant not found: " + it.first);
+    } else if (check_full_update) {
+      std::string constant_fqns = "";
+      for (const auto& it2 : fqn_to_constant_name) {
+        constant_fqns += it2.first + ", ";
+      }
+      throw std::runtime_error(
+          "The constant with FQN " + it.first +
+          " was not found in the model. The available constants are: " +
+          constant_fqns);
     }
   }

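Note: on the C++ side, an entry in the supplied constants map whose FQN the model does not know is now silently skipped unless check_full_update is set, and when it is set the error enumerates the FQNs the model actually has rather than the bare "Constant not found". A behavioral sketch from Python, under the assumption that the std::runtime_error surfaces as a RuntimeError:

constants = {**ep.state_dict, **ep.constants, "not_a_real_fqn": torch.zeros(1)}
compiled.load_constants(constants, check_full_update=False, user_managed=True)  # extra key is skipped
try:
    compiled.load_constants(constants, check_full_update=True, user_managed=True)
except RuntimeError as err:
    print(err)  # names the missing FQN and lists the constants the model does have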