mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[op-bench] check device attribute in user inputs
Summary: The device attribute in the op benchmark can only include 'cpu' or 'cuda'. So adding a check in this diff. Test Plan: buck run caffe2/benchmarks/operator_benchmark:benchmark_all_test -- --warmup_iterations 1 --iterations 1 Reviewed By: ngimel Differential Revision: D22538252 fbshipit-source-id: 3e5af72221fc056b8d867321ad22e35a2557b8c3
This commit is contained in:
parent
a0f110190c
commit
4ddf27ba48
|
|
@ -17,6 +17,7 @@ This module contains utilities for writing microbenchmark tests.
|
|||
|
||||
# Here are the reserved keywords in the benchmark suite
|
||||
_reserved_keywords = {"probs", "total_samples", "tags"}
|
||||
_supported_devices = {"cpu", "cuda"}
|
||||
|
||||
def shape_to_string(shape):
|
||||
return ', '.join([str(x) for x in shape])
|
||||
|
|
@ -108,6 +109,7 @@ def cross_product_configs(**configs):
|
|||
({'M': 2}, {'N' : 4}),
|
||||
({'M': 2}, {'N' : 5}))
|
||||
"""
|
||||
_validate(configs)
|
||||
configs_attrs_list = []
|
||||
for key, values in configs.items():
|
||||
tmp_results = [{key : value} for value in values]
|
||||
|
|
@ -120,6 +122,13 @@ def cross_product_configs(**configs):
|
|||
return generated_configs
|
||||
|
||||
|
||||
def _validate(configs):
|
||||
""" Validate inputs from users."""
|
||||
if 'device' in configs:
|
||||
for v in configs['device']:
|
||||
assert(v in _supported_devices), "Device needs to be a string."
|
||||
|
||||
|
||||
def config_list(**configs):
|
||||
""" Generate configs based on the list of input shapes.
|
||||
This function will take input shapes specified in a list from user. Besides
|
||||
|
|
@ -153,6 +162,8 @@ def config_list(**configs):
|
|||
if any(attr not in configs for attr in reserved_names):
|
||||
raise ValueError("Missing attrs in configs")
|
||||
|
||||
_validate(configs)
|
||||
|
||||
cross_configs = None
|
||||
if 'cross_product_configs' in configs:
|
||||
cross_configs = cross_product_configs(**configs['cross_product_configs'])
|
||||
|
|
@ -318,14 +329,14 @@ def get_operator_range(chars_range):
|
|||
if chars_range == 'None' or chars_range is None:
|
||||
return None
|
||||
|
||||
if all(item not in chars_range for item in [',', '-']):
|
||||
raise ValueError("The correct format for operator_range is "
|
||||
if all(item not in chars_range for item in [',', '-']):
|
||||
raise ValueError("The correct format for operator_range is "
|
||||
"<start>-<end>, or <point>, <start>-<end>")
|
||||
|
||||
ops_start_chars_set = set()
|
||||
ranges = chars_range.split(',')
|
||||
for item in ranges:
|
||||
if len(item) == 1:
|
||||
for item in ranges:
|
||||
if len(item) == 1:
|
||||
ops_start_chars_set.add(item.lower())
|
||||
continue
|
||||
start, end = item.split("-")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user