rework device type filter rule (#55753)

Summary:
Currently common_device_type generates device-specific test based on vague rules. see https://github.com/pytorch/pytorch/issues/55707.
This should fix https://github.com/pytorch/pytorch/issues/55707

# Changes included
This PR changes the rule:
1. First user provided args (`except_for` and `only_for`) are processed to filter out desired device type from a ALL_AVAILABLE_LIST
2. Then environment variables are processed the exact same way.

tests are generated based on the final filtered list.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55753

Test Plan: CI

Reviewed By: seemethere, ngimel

Differential Revision: D27709192

Pulled By: walterddr

fbshipit-source-id: 1d5378ef013b22a7fb9fdae24b486730b2e67401
This commit is contained in:
Rong Rong (AI Infra) 2021-04-12 16:04:01 -07:00 committed by Facebook GitHub Bot
parent dab1cdf7cb
commit d0cd16899f

View File

@ -427,6 +427,21 @@ def get_device_type_test_bases():
device_type_test_bases = get_device_type_test_bases()
def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None):
# device type cannot appear in both except_for and only_for
intersect = set(except_for if except_for else []) & set(only_for if only_for else [])
assert not intersect, f"device ({intersect}) appeared in both except_for and only_for"
if except_for is not None:
device_type_test_bases = filter(
lambda x: x.device_type not in except_for, device_type_test_bases)
if only_for is not None:
device_type_test_bases = filter(
lambda x: x.device_type in only_for, device_type_test_bases)
return list(device_type_test_bases)
# Note [How to extend DeviceTypeTestBase to add new test device]
# The following logic optionally allows downstream projects like pytorch/xla to
# add more test devices.
@ -474,31 +489,28 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None, on
# Acquires members names
# See Note [Overriding methods in generic tests]
generic_members = set(generic_test_class.__dict__.keys()) - set(empty_class.__dict__.keys())
generic_tests = [x for x in generic_members if x.startswith('test')]
generic_tests = [x for x in generic_members if x.startswith('test')]\
def split_if_not_empty(x):
# Filter out the device types based on user inputs
desired_device_type_test_bases = filter_desired_device_types(device_type_test_bases,
except_for, only_for)
def split_if_not_empty(x: str):
return x.split(",") if len(x) != 0 else []
# Derive defaults from environment variables if available, default is still none
# Filter out the device types based on environment variables if available
# Usage:
# export PYTORCH_TESTING_DEVICE_ONLY_FOR=cuda,cpu
# export PYTORCH_TESTING_DEVICE_EXCEPT_FOR=xla
if only_for is None:
only_for = split_if_not_empty(os.getenv("PYTORCH_TESTING_DEVICE_ONLY_FOR", ''))
env_only_for = split_if_not_empty(os.getenv("PYTORCH_TESTING_DEVICE_ONLY_FOR", ''))
env_except_for = split_if_not_empty(os.getenv("PYTORCH_TESTING_DEVICE_EXCEPT_FOR", ''))
desired_device_type_test_bases = filter_desired_device_types(desired_device_type_test_bases,
env_except_for, env_only_for)
if except_for is None:
except_for = split_if_not_empty(os.getenv("PYTORCH_TESTING_DEVICE_EXCEPT_FOR", ''))
# Creates device-specific test cases
for base in device_type_test_bases:
# Skips bases listed in except_for
if except_for and only_for:
assert base.device_type not in except_for or base.device_type not in only_for,\
"same device cannot appear in except_for and only_for"
if except_for and base.device_type in except_for:
continue
if only_for and base.device_type not in only_for:
continue
for base in desired_device_type_test_bases:
# Special-case for ROCm testing -- only test for 'cuda' i.e. ROCm device by default
# The except_for and only_for cases were already checked above. At this point we only need to check 'cuda'.
if TEST_WITH_ROCM and base.device_type != 'cuda':