mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[op-bench] check device attribute in user inputs
Summary: The device attribute in the op benchmark can only include 'cpu' or 'cuda'. So adding a check in this diff. Test Plan: buck run caffe2/benchmarks/operator_benchmark:benchmark_all_test -- --warmup_iterations 1 --iterations 1 Reviewed By: ngimel Differential Revision: D22538252 fbshipit-source-id: 3e5af72221fc056b8d867321ad22e35a2557b8c3
This commit is contained in:
parent
a0f110190c
commit
4ddf27ba48
|
|
@ -17,6 +17,7 @@ This module contains utilities for writing microbenchmark tests.
|
|||
|
||||
# Here are the reserved keywords in the benchmark suite
|
||||
_reserved_keywords = {"probs", "total_samples", "tags"}
|
||||
_supported_devices = {"cpu", "cuda"}
|
||||
|
||||
def shape_to_string(shape):
|
||||
return ', '.join([str(x) for x in shape])
|
||||
|
|
@ -108,6 +109,7 @@ def cross_product_configs(**configs):
|
|||
({'M': 2}, {'N' : 4}),
|
||||
({'M': 2}, {'N' : 5}))
|
||||
"""
|
||||
_validate(configs)
|
||||
configs_attrs_list = []
|
||||
for key, values in configs.items():
|
||||
tmp_results = [{key : value} for value in values]
|
||||
|
|
@ -120,6 +122,13 @@ def cross_product_configs(**configs):
|
|||
return generated_configs
|
||||
|
||||
|
||||
def _validate(configs):
|
||||
""" Validate inputs from users."""
|
||||
if 'device' in configs:
|
||||
for v in configs['device']:
|
||||
assert(v in _supported_devices), "Device needs to be a string."
|
||||
|
||||
|
||||
def config_list(**configs):
|
||||
""" Generate configs based on the list of input shapes.
|
||||
This function will take input shapes specified in a list from user. Besides
|
||||
|
|
@ -153,6 +162,8 @@ def config_list(**configs):
|
|||
if any(attr not in configs for attr in reserved_names):
|
||||
raise ValueError("Missing attrs in configs")
|
||||
|
||||
_validate(configs)
|
||||
|
||||
cross_configs = None
|
||||
if 'cross_product_configs' in configs:
|
||||
cross_configs = cross_product_configs(**configs['cross_product_configs'])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user