Bump black version to 23.1.0 (#96578)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96578
Approved by: https://github.com/ezyang
Author: BowenBao (2023-03-14 19:46:45 -07:00), committed by PyTorch MergeBot
Parent: a229e78544
Commit: 60a68477a6
114 changed files with 111 additions and 167 deletions

View File

@ -878,7 +878,7 @@ init_command = [
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'--no-black-binary',
'black==22.3.0',
'black==23.1.0',
'ufmt==1.3.3',
'usort==1.0.2',
]
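
Apart from the version pin updated in the hunk above, the hunks in this commit are mechanical reformatting produced by black 23.1.0. The patterns that recur below are: a blank line directly after a function signature (at the start of the block) is removed; redundant parentheses are dropped from tuple targets in `for` loops and from single-exception `except` clauses; the closing quotes of multi-line docstrings move to their own line; return annotations that do not fit on one line are wrapped in parentheses; and blank lines are added before comments that sit directly above a `def`. A minimal sketch of the new style follows, with illustrative function names that are not taken from this diff:

```python
from typing import Dict


def pairwise_equal(xs, ys):
    # black 22.3.0 tolerated a blank line here and parentheses around the tuple target:
    #     for (x, y) in zip(xs, ys):
    # black 23.1.0 strips the blank line at the start of the block and the parentheses.
    for x, y in zip(xs, ys):
        if x != y:
            return False
    return True


def read_config(path: str) -> Dict[str, str]:
    """Load a ``key=value`` file into a dict.

    As in several hunks of this diff, the closing quotes of a multi-line
    docstring end up on their own line under black 23.1.0.
    """
    result: Dict[str, str] = {}
    try:
        with open(path) as f:
            for line in f:
                key, _, value = line.partition("=")
                result[key.strip()] = value.strip()
    # previously written as ``except (OSError) as err:``; the parentheses are dropped
    except OSError as err:
        raise RuntimeError(f"could not read {path}") from err
    return result
```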

View File

@ -4,6 +4,7 @@ import os
from typing import Set
# Note - hf and timm have their own version of this, torchbench does not
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> Set[str]:

View File

@ -11,12 +11,10 @@ def get_field(csv, model_name: str, field: str, typ=float):
def check_graph_breaks(actual_csv, expected_csv, expected_filename):
failed = []
improved = []
for model in actual_csv["name"]:
graph_breaks = get_field(actual_csv, model, "graph_breaks", typ=int)
expected_graph_breaks = get_field(expected_csv, model, "graph_breaks", typ=int)

View File

@ -31,7 +31,6 @@ ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/c
def query_job_sha(repo, sha):
params = {
"parameters": [
{"name": "sha", "type": "string", "value": sha},
@ -108,7 +107,6 @@ def write_filtered_csvs(root_path, dataframes):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)

View File

@ -373,7 +373,6 @@ class HuggingfaceRunner(BenchmarkRunner):
model_name,
batch_size=None,
):
is_training = self.args.training
use_eval_mode = self.args.use_eval_mode
dtype = torch.float32
@ -513,7 +512,6 @@ def refresh_model_names_and_batch_sizes():
lm_seen = set()
family_seen = set()
for cls_name in hf_fx._SUPPORTED_MODELS:
if "For" not in cls_name:
continue

View File

@ -73,7 +73,6 @@ def bench_op(
warmup=25,
rep=75,
):
skip = False
# allocate inputs, nchw
x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")

View File

@ -70,7 +70,6 @@ def bench_op(
warmup=25,
rep=75,
):
# allocate inputs, nchw
x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
w = torch.randn(

View File

@ -66,7 +66,6 @@ def bench_op(
warmup=25,
rep=75,
):
# allocate inputs, nchw
x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
w = torch.randn(

View File

@ -236,7 +236,6 @@ def bench(layer_params, layer_id, p, fusion_types=[""]):
row = [layer_id]
for fusion_type in fusion_types:
if fusion_type == "":
conv_torchinductor = getattr(Func, "conv_torchinductor")
conv = getattr(Func, "conv")

View File

@ -56,7 +56,6 @@ def bench(shape, layer_id, p, fusion_types=[""]):
row = [layer_id]
for fusion_type in fusion_types:
if fusion_type == "":
fn_mm = getattr(Func, "mm")
else:

View File

@ -46,7 +46,6 @@ def profile_op(
warmup=25,
rep=50,
):
# allocate inputs, nchw
x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
w = torch.randn(

View File

@ -60,6 +60,7 @@ out = csv.DictWriter(
out.writeheader()
out.writerow({"explain": gist_url})
# Sometimes backtraces will be in third party code, which results
# in very long file names. Delete the absolute path in this case.
def normalize_file(f):

View File

@ -182,7 +182,6 @@ class TimmRunnner(BenchmarkRunner):
model_name,
batch_size=None,
):
is_training = self.args.training
use_eval_mode = self.args.use_eval_mode

View File

@ -242,7 +242,6 @@ class TorchBenchmarkRunner(BenchmarkRunner):
batch_size=None,
part=None,
):
is_training = self.args.training
use_eval_mode = self.args.use_eval_mode
dynamic_shapes = self.args.dynamic_shapes

View File

@ -120,7 +120,7 @@ class TestInitialization(FSDPTest):
composable_handles = traversal_utils._get_fsdp_handles(composable_module)
fsdp_wrapped_handles = traversal_utils._get_fsdp_handles(fsdp_wrapped_model)
self.assertEqual(len(composable_handles), len(fsdp_wrapped_handles))
for (composable_handle, fsdp_wrapped_handle) in zip(
for composable_handle, fsdp_wrapped_handle in zip(
composable_handles, fsdp_wrapped_handles
):
self.assertEqual(
@ -179,7 +179,7 @@ class TestInitialization(FSDPTest):
policy=policy,
sync_module_states=True,
)
for (composable_param, fsdp_wrapped_param) in zip(
for composable_param, fsdp_wrapped_param in zip(
composable_module.parameters(),
fsdp_wrapped_model.parameters(),
):

View File

@ -116,7 +116,7 @@ class TestFSDPCheckpoint(FSDPTest):
assert outputs
assert models
for (l, o) in zip(losses[1:], outputs[1:]):
for l, o in zip(losses[1:], outputs[1:]):
self.assertEqual(losses[0], l)
self.assertEqual(outputs[0], o)
@ -324,7 +324,6 @@ class TestModel(nn.Module):
class TestFSDPCheckpointSubmodule(FSDPTest):
# TODO: grad value checks occasionally fails when use_reentrant = True
@skip_if_lt_x_gpu(2)
@parametrize("use_reentrant", [False])

View File

@ -70,7 +70,6 @@ class Net(nn.Module):
class DummyState:
__slots__ = ["process_group", "noise"]
def __init__(self, process_group: dist.ProcessGroup, noise: int):
@ -157,7 +156,6 @@ class TestCommunicationHooks(FSDPTest):
self.assertEqual(entry._communication_hook, default_hook)
for _ in range(4):
# Clear gradients
net_default_hook.zero_grad()
loss = net_default_hook(inpt).sum()
@ -183,7 +181,6 @@ class TestCommunicationHooks(FSDPTest):
]
def _init_model(self, core, sharding_strategy, mixed_precision=None):
device = torch.device("cuda")
return FSDP(
core,
@ -424,7 +421,6 @@ class TestCommunicationHooks(FSDPTest):
def test_fp16_hook(
self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
):
state = default_hooks.LowPrecisionState(process_group=_get_default_group())
hook = default_hooks.fp16_compress_hook
@ -452,7 +448,6 @@ class TestCommunicationHooks(FSDPTest):
def test_bf16_hook(
self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
):
state = default_hooks.LowPrecisionState(process_group=_get_default_group())
hook = default_hooks.bf16_compress_hook

View File

@ -160,7 +160,7 @@ class TestGradAcc(FSDPTest):
num_iters_to_acc = sum(config.num_iters for config in configs)
for _ in range(num_iters_to_acc - 1):
batches.append(tuple(permute_tensor(t) for t in batch))
for (batch1, batch2) in itertools.combinations(batches, r=2):
for batch1, batch2 in itertools.combinations(batches, r=2):
for t1, t2 in zip(batch1, batch2):
assert not torch.all(
t1 == t2

View File

@ -1338,7 +1338,6 @@ class TestFSDPOptimState(FSDPTest):
use_multiple_param_groups: bool,
use_optim_input: bool,
):
NUM_ITERS = 3
# Run a wrapped model for a few iterations
model1, optim1, optim_input1 = self._init_nested_model(

View File

@ -937,14 +937,14 @@ class TestFSDPStateDict(FSDPTest):
# Check that it can be loaded into FSDP.
new_fsdp, _ = _create_module()
_zero_model(new_fsdp)
for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
self.assertNotEqual(p1, p2)
with FSDP.state_dict_type(new_fsdp, STATE_DICT_MAPPING[state_dict_type]):
if state_dict_type != "local_state_dict":
# FlatParameter has not supported deepcopy yet.
state_dict = deepcopy(state_dict)
new_fsdp.load_state_dict(state_dict, strict=True)
for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
self.assertEqual(p1, p2)
# Test that the checkpoint can be loaded into a local model.
@ -954,7 +954,7 @@ class TestFSDPStateDict(FSDPTest):
param.zero_()
with fsdp.summon_full_params(fsdp):
for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
for p1, p2 in zip(fsdp.parameters(), local.parameters()):
self.assertNotEqual(p1, p2)
if state_dict_type == "local_state_dict":
@ -963,7 +963,7 @@ class TestFSDPStateDict(FSDPTest):
with fsdp.summon_full_params(fsdp):
if self.rank == 0:
local.load_state_dict(state_dict, strict=True)
for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
for p1, p2 in zip(fsdp.parameters(), local.parameters()):
self.assertEqual(p1, p2)
@skip_if_lt_x_gpu(2)

View File

@ -31,7 +31,6 @@ class TestShardUtils(TestCase):
out_offsets,
in_split_sizes,
):
for my_rank in range(world_size):
_in_split_sizes = in_split_sizes[my_rank]
_out_split_sizes = [

View File

@ -847,7 +847,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
try:
x = torch.randn(1).as_subclass(TensorProxy)
cnt = torch._dynamo.testing.CompileCounter()
out1 = foo(x)
@ -862,7 +861,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
def test_torch_function_with_closure(self):
def run():
counter = 0
def foo(x):
@ -1097,7 +1095,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
opt_mod = torch._dynamo.optimize("eager")(mod)
# Check parameters and buffers
for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()):
for p1, p2 in zip(mod.parameters(), opt_mod.parameters()):
self.assertTrue(id(p1) == id(p2))
def test_recursion(self):

View File

@ -1572,7 +1572,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertEqual(y, 10)
def test_sort_out(self):
dtype = torch.float32
device = "cpu"
@ -1607,7 +1606,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref, res))
def test_sigmoid_out(self):
dtype = torch.float32
device = "cpu"

View File

@ -178,7 +178,6 @@ class TestInductorConfig(TestCase):
a(torch.randn(10))
def test_api_options(self):
reduce_overhead_opts = torch._inductor.list_mode_options("reduce-overhead")
self.assertEqual(reduce_overhead_opts["triton.cudagraphs"], True)

View File

@ -79,7 +79,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
class TestInductorDynamic(TestCase):
compile_fn = partial(torch.compile, dynamic=True)
def setUp(self):

View File

@ -597,7 +597,6 @@ class TestInductorOpInfo(TestCase):
)
except Exception as e:
if test_expect is ExpectedTestResult.XFAILURE:
raise e

View File

@ -48,6 +48,7 @@ skipIfNoBFloat16Cuda = _skipper(
lambda: not torch.cuda.is_bf16_supported(), "BFloat16 CUDA is not available"
)
# skips tests for all versions below min_opset_version.
# if exporting the op is only supported after a specific version,
# add this wrapper to prevent running the test for opset_versions

View File

@ -494,7 +494,6 @@ class TestONNXOpset(pytorch_test_common.ExportTestCase):
("zeros", "border", "reflection"),
(True, False),
):
args = (
torch.randn(n, c, h_in, w_in), # x
torch.randn(n, h_out, w_out, 2), # grid,

View File

@ -13,14 +13,12 @@ from torch.testing._internal import common_utils
class TestONNXScriptExport(common_utils.TestCase):
# opset version is
# 1. local function is supported after opset 15
# 2. onnx-script requires users to determine opset in local function
opset_version = 15
def test_onnxscript_registration_with_multiple_models(self):
from onnxscript.onnx_opset import opset15 as op
# 1. Register Selu onnxscript function as custom Op

View File

@ -12,14 +12,12 @@ from torch.testing._internal import common_utils
class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
# opset version is
# 1. local function is supported after opset 15
# 2. onnx-script requires users to determine opset in local function
opset_version = 15
def test_selu_from_onnxscript_example(self):
x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()
@ -52,7 +50,6 @@ class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
self.run_test(model, x)
def test_layer_norm(self):
x = torch.randn(2, 3)
y = torch.randn(2, 3)
z = torch.randn(2, 3)

View File

@ -30,9 +30,7 @@ def export_to_onnx(
model: Union[torch.nn.Module, torch.jit.ScriptFunction],
input: Union[torch.Tensor, Tuple[torch.Tensor]],
custom_ops: Optional[
Iterable[
Union[contextlib.AbstractContextManager, contextlib.ContextDecorator],
]
Iterable[Union[contextlib.AbstractContextManager, contextlib.ContextDecorator]]
] = None,
mocks: Optional[Iterable] = None,
operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX,
@ -765,7 +763,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
)
def test_dropout_script(self):
eg = torch.zeros(1, 2, 3, requires_grad=True)
@jit_utils._trace(eg)

View File

@ -8600,7 +8600,6 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
@skipIfUnsupportedMinOpsetVersion(9)
def test_kldiv_loss(self):
x = torch.rand(5).log()
y = torch.rand(5)
self._kldiv_loss(x, y)
@ -12832,7 +12831,6 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
name_fn=lambda align_corners: str(align_corners),
)
def test_grid_sample(self, mode, padding_mode, align_corners):
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
class GridSampleModule(torch.nn.Module):

View File

@ -328,7 +328,6 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
self.opset_version = _constants.ONNX_MAX_OPSET
def test_setType_maintains_output_shape_for_single_custom_op(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
class CustomInverse(torch.nn.Module):
@ -363,7 +362,6 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
self.assertEqual(dim.dim_value, rank)
def test_no_setType_for_single_custom_op(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
class CustomInverse(torch.nn.Module):
@ -398,7 +396,6 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
def test_setType_maintains_output_shape_for_single_custom_op_with_dynamic_axes(
self,
):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
class CustomInverse(torch.nn.Module):
@ -438,7 +435,6 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
self.assertEqual(dims[i].dim_value, x.size()[i])
def test_setType_maintains_output_shape_for_single_custom_op_with_onnx_ops(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
class CustomInverse(torch.nn.Module):

View File

@ -133,6 +133,7 @@ ignores = [
ignores = [os.path.join(proj_dir, ignore) for ignore in ignores]
# Check if the compiler is hip-clang.
def is_hip_clang() -> bool:
try:

View File

@ -5,6 +5,7 @@ from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFW
from torchgen.context import native_function_manager
from torchgen.utils import T
# Like tools.api.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(

View File

@ -420,7 +420,6 @@ UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def get_infos_with_derivatives_list(
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]]
) -> List[DifferentiabilityInfo]:
diff_info_list = [
info
for diffinfo_dict in differentiability_infos.values()
@ -469,7 +468,6 @@ def gen_autograd_functions_python(
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
template_path: str,
) -> None:
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
num_shards = 5
fm.write(

View File

@ -221,6 +221,7 @@ ${assign_return_values} ([&]() {
TMP_VAR = "_tmp"
# FIXME: Ideally these functions should be methods on Type class, but we have a
# comment in codegen/model.py there saying these concepts are not well defined.
# Thus we put a version that commonly used by autograd codegen here.
@ -321,7 +322,8 @@ def emit_view_call(
def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
"""Generate an additional lambda function to recover views in backward when as_strided is not supported.
See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details."""
See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
"""
input_base = "input_base"
replay_view_func = ""
updated_unpacked_args: List[str] = []

View File

@ -17,6 +17,7 @@ from torchgen.utils import FileManager, mapMaybe
OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")
# Add 'at::' to types defined in ATen namespace, e.g. Tensor, TensorList, IntArrayRef and etc.
# TODO: maybe update the cpp argument API to take optional namespace argument?
def fully_qualified_type(argument_type: str) -> str:

View File

@ -761,7 +761,6 @@ def gen_variable_type(
template_path: str,
used_keys: Set[str],
) -> None:
"""VariableType.h and VariableType.cpp body
This is the at::Type subclass for differentiable tensors. The

View File

@ -52,6 +52,7 @@ _GLOBAL_LOAD_DERIVATIVE_CACHE = {}
_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
# Since every {view} and {view}_copy op shares the same derivative formula,
# we generate them here instead of duplicating them in the yaml.
@ -96,7 +97,6 @@ def load_derivatives(
global _GLOBAL_LOAD_DERIVATIVE_CACHE
key = (derivatives_yaml_path, native_yaml_path)
if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
with open(derivatives_yaml_path, "r") as f:
definitions = yaml.load(f, Loader=YamlLoader)

View File

@ -183,7 +183,6 @@ def create_debug_info_from_selected_models(
selected_models: List[dict],
new_style_rule: bool,
):
model_dict = {
"asset_info": {}, # maps asset name -> dict of asset metadata like hashes
"is_new_style_rule": new_style_rule,
@ -465,13 +464,13 @@ def fill_output(output: Dict[str, object], options: object):
# to True, since it indicates that this operator list came from something
# other than a traced operator list.
include_all_non_op_selectives = False
for (op_name, op_info) in operators.items():
for op_name, op_info in operators.items():
include_all_non_op_selectives = (
include_all_non_op_selectives or op_info.include_all_overloads
)
operators_as_dict = {}
for (k, v) in operators.items():
for k, v in operators.items():
operators_as_dict[k] = v.to_dict()
output["operators"] = operators_as_dict

View File

@ -18,14 +18,14 @@ from torchgen.selective_build.selector import (
def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]:
ops = []
for (op_name, op) in selective_builder.operators.items():
for op_name, op in selective_builder.operators.items():
ops.append(op_name)
return set(ops)
def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]:
ops = []
for (op_name, op) in selective_builder.operators.items():
for op_name, op in selective_builder.operators.items():
if op.is_used_for_training:
ops.append(op_name)
return set(ops)
@ -33,7 +33,7 @@ def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]:
def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
ops = []
for (op_name, op) in selective_builder.operators.items():
for op_name, op in selective_builder.operators.items():
if op.include_all_overloads:
ops.append(op_name)
if ops:
@ -47,7 +47,6 @@ def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> N
def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None:
supported_mobile_models_source = """/*
* Generated by gen_oplist.py
*/

View File

@ -38,7 +38,6 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:
def print_test_by_type(
tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str]
) -> None:
print("Tests " + type_name + " to collect coverage:", file=summary_file)
for test in tests:
if is_this_type_of_tests(test.name, test_set_by_type):

View File

@ -22,6 +22,7 @@ result = subprocess.run(
PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
IS_WINDOWS: bool = os.name == "nt"
# Returns '/usr/local/include/python<version number>'
def get_python_include_dir() -> str:
return gp()["include"]
@ -147,7 +148,7 @@ def check_file(
proc = run_command(
[binary, f"-p={build_dir}", *include_args, filename],
)
except (OSError) as err:
except OSError as err:
return [
LintMessage(
path=filename,

View File

@ -47,7 +47,7 @@ selected_mobile_ops_preamble = """#pragma once
def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
ops = []
for (op_name, op) in selective_builder.operators.items():
for op_name, op in selective_builder.operators.items():
if op.is_root_operator:
ops.append(op_name)
return set(ops)

View File

@ -142,7 +142,6 @@ def _format_rule_for_cpp(rule: _RuleType) -> str:
def gen_diagnostics_python(
rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
rule_class_lines = [_format_rule_for_python_class(rule) for rule in rules]
rule_field_lines = [_format_rule_for_python_field(rule) for rule in rules]
@ -165,7 +164,6 @@ def gen_diagnostics_python(
def gen_diagnostics_cpp(
rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
rule_lines = [_format_rule_for_cpp(rule) for rule in rules]
rule_names = [f'"{_kebab_case_to_snake_case(rule["name"])}",' for rule in rules]
@ -206,7 +204,6 @@ def gen_diagnostics(
out_cpp_dir: str,
out_docs_dir: str,
) -> None:
with open(rules_path, "r") as f:
rules = yaml.load(f, Loader=torchgen_utils.YamlLoader)

View File

@ -41,7 +41,7 @@ def apply_replacements(replacements: Dict[str, str], text: str) -> str:
Returns:
Text with replacements applied, if any.
"""
for (before, after) in replacements.items():
for before, after in replacements.items():
text = text.replace(before, after)
return text

View File

@ -54,7 +54,6 @@ def generate_code(
operator_selector = SelectiveBuilder.get_nop_selector()
if subset == "libtorch" or not subset:
gen_autograd(
native_functions_path or NATIVE_FUNCTIONS_PATH,
tags_path or TAGS_PATH,

View File

@ -134,7 +134,6 @@ def rocm_get_per_process_gpu_info() -> List[Dict[str, Any]]:
if __name__ == "__main__":
handle = None
try:
pynvml.nvmlInit()

View File

@ -132,7 +132,10 @@ def upload_to_s3(
json.dump(doc, body)
body.write("\n")
S3_RESOURCE.Object(f"{bucket_name}", f"{key}",).put(
S3_RESOURCE.Object(
f"{bucket_name}",
f"{key}",
).put(
Body=gzip.compress(body.getvalue().encode()),
ContentEncoding="gzip",
ContentType="application/json",

View File

@ -13,6 +13,7 @@ from torchgen.gen_backend_stubs import run
path = os.path.dirname(os.path.realpath(__file__))
gen_backend_stubs_path = os.path.join(path, "../torchgen/gen_backend_stubs.py")
# gen_backend_stubs.py is an integration point that is called directly by external backends.
# The tests here are to confirm that badly formed inputs result in reasonable error messages.
class TestGenBackendStubs(expecttest.TestCase):

View File

@ -9,7 +9,6 @@ import os
def main() -> None:
target = os.path.join("torch", "masked", "_docs.py")
try:

View File

@ -169,6 +169,7 @@ def get_decompositions(
import torch._decomp.decompositions
import torch._refs
# This list was copied from torch/_inductor/decomposition.py
# excluding decompositions that results in prim ops
# Resulting opset of decomposition is core aten ops

View File

@ -88,6 +88,7 @@ pw_cast_for_int_to_real = partial(
type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)
# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int):
for _ in range(dim - x.dim()):
@ -619,7 +620,6 @@ def slice_forward(
end: Optional[int] = None,
step: int = 1,
):
ndim = self.dim()
if ndim == 0:
raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")

View File

@ -86,6 +86,7 @@ def _register_jit_decomposition_for_jvp(decomp, use_python=False):
# The only decompositions here are temporary or hacks for the purposes of jvp
# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:

View File

@ -288,7 +288,6 @@ def _compile(
hooks: Hooks,
frame: Optional[types.FrameType] = None,
) -> Optional[GuardedCode]:
output: Optional[OutputGraph] = None
# This is shared across restarts
mutated_closure_cell_contents: Set[str] = set()

View File

@ -156,7 +156,6 @@ def filter_stack(stack):
def format_error_msg(exc, code, record_filename=None, frame=None):
msg = os.linesep * 2
if config.verbose:

View File

@ -11,6 +11,7 @@ logging.addLevelName(logging.CODE, "CODE")
# Disable progress bar by default, not in dynamo config because otherwise get a circular import
disable_progress = True
# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
return [

View File

@ -355,7 +355,7 @@ class BuiltinVariable(VariableTracker):
return None
# Return first handler that matches the type checks
for ((type1, type2), handler) in handlers[op]:
for (type1, type2), handler in handlers[op]:
if isinstance(a, type1) and isinstance(b, type2):
return handler
@ -641,7 +641,6 @@ class BuiltinVariable(VariableTracker):
)
for i in [a, b]
):
if any([isinstance(val, FakeItemVariable) for val in [a, b]]):
return variables.FakeItemVariable.from_tensor_variable(result)
@ -678,7 +677,6 @@ class BuiltinVariable(VariableTracker):
)
return SymNodeVariable.create(tx, proxy, None)
else:
unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")
call_min = _call_min_max

View File

@ -73,6 +73,7 @@ constant_fold_functions = [
if torch.distributed.is_available():
constant_fold_functions.append(torch.distributed.is_initialized)
# TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
def remap_as_fn___radd__(*args):
return torch._C._TensorBase.__radd__(*args)

View File

@ -412,7 +412,8 @@ class KernelArgs:
class CSEVariable:
"""A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
The backends can inherit from this class and overload the "create_cse_var" Kernel to do that.
The "update_on_args" method gives you a hook for annotations, see example of TritonCSEVariable in triton.py."""
The "update_on_args" method gives you a hook for annotations, see example of TritonCSEVariable in triton.py.
"""
def __init__(self, name):
self.name = name

View File

@ -1535,7 +1535,6 @@ class TritonScheduling:
@contextlib.contextmanager
def end_current_reduction_loop():
if current_loop_writes:
# flush out any other runnable nodes to reduce number of loops
for other_node in nodes[index + 1 :]:

View File

@ -183,7 +183,6 @@ class cpp:
# config specific to codegen/triton.py
class triton:
# Use cudagraphs on output code
cudagraphs = False

View File

@ -1,5 +1,6 @@
import torch
# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
# Works for length 2 patterns with 1 module and 1 function/method.
def matches_module_function_pattern(pattern, node, modules):

View File

@ -802,7 +802,6 @@ class Reduction(Loops):
reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
if reduction_numel == 0:
# N.B. This is a hack to generate the literal of the given type
# Ideally, we should be fixing `def constant` in triton.py
# but it breaks due to hardcoded dtypes in other places
@ -1252,7 +1251,6 @@ class PermuteView(BaseView):
class SqueezeView(BaseView):
@classmethod
def create(cls, x, *, dim=None):
if is_storage_and_layout(x):
storage, old_layout = as_storage_and_layout(x)
new_size = []
@ -3828,7 +3826,12 @@ class ConvolutionTransposeUnary(ExternKernelAlloc):
):
kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
transposed = True
(inputs, constant_args, kernel_layout, _,) = _prepare_convolution_fusion_create(
(
inputs,
constant_args,
kernel_layout,
_,
) = _prepare_convolution_fusion_create(
cls,
x,
weight,

View File

@ -2144,7 +2144,6 @@ def scatter(x, dim: int, index, src, **kwargs):
def scatter_fallback(
fn, self, dim: int, index, src, *, reduce: str = None, include_self: bool = True
):
if reduce not in {None, "sum"} or (
reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64}
):
@ -2158,7 +2157,6 @@ def scatter_fallback(
@register_lowering(aten.scatter_, type_promotion_kind=None)
def scatter_(self, dim: int, index, src, *, reduce: str = None):
if reduce == "add":
reduce = "sum"
elif reduce == "multiply":
@ -2674,7 +2672,6 @@ def constant_boundary_condition_2d(x, fill_value, padding):
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
x_out = ir.FloorDiv(
x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i]
)
@ -3212,7 +3209,6 @@ def avg_pool2d_backward(
count_include_pad,
divisor_override=None,
):
assert not divisor_override
if not stride:
stride = kernel_size

View File

@ -441,7 +441,6 @@ def shape_of_mm(a, b):
CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()),
)
def cat_mm(match, inputs, dim):
return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm)

View File

@ -129,7 +129,6 @@ if has_triton():
# allocate accumulator
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for crs in range(0, CRS, BLOCK_K):
# ------ matrix multiplication ------
acc += tl.dot(matrix_x, matrix_w)
# ------ update ptrs ------
@ -306,7 +305,6 @@ if has_triton():
# allocate accumulator
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for crs in range(0, CRS, BLOCK_K):
# ------ matrix multiplication ------
acc += tl.dot(matrix_x, matrix_w)
# ------ update ptrs ------

View File

@ -3,7 +3,6 @@ import torch
from ..utils import has_triton
if has_triton():
import triton
class _conv1x1:

View File

@ -273,7 +273,6 @@ class LOBPCGAutogradFunction(torch.autograd.Function):
ortho_fparams: Optional[Dict[str, float]] = None,
ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
# makes sure that input is contiguous for efficiency.
# Note: autograd does not support dense gradients for sparse input yet.
A = A.contiguous() if (not A.is_sparse) else A
@ -360,7 +359,6 @@ def lobpcg(
ortho_fparams: Optional[Dict[str, float]] = None,
ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
"""Find the k largest (or smallest) eigenvalues and the corresponding
eigenvectors of a symmetric positive definite generalized
eigenvalue problem using matrix-free LOBPCG methods.
@ -598,7 +596,6 @@ def _lobpcg(
ortho_fparams: Optional[Dict[str, float]] = None,
ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
# A must be square:
assert A.shape[-2] == A.shape[-1], A.shape
if B is not None:
@ -707,7 +704,6 @@ class LOBPCG:
method: str,
tracker: None,
) -> None:
# constant parameters
self.A = A
self.B = B
@ -833,7 +829,6 @@ class LOBPCG:
self.call_tracker()
while not self.stop_iteration():
self.update()
if not torch.jit.is_scripting() and self.tracker is not None:

View File

@ -2486,7 +2486,6 @@ def _cudnn_rnn(
batch_sizes,
dropout_state,
):
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
@ -2773,7 +2772,6 @@ import torch._refs.special
def activate_meta():
activate_meta_table = {}
# For a given op, we pick the most specific decomp function from

View File

@ -135,6 +135,7 @@ is_included_in_alias = torch._C._dispatch_is_included_in_alias
DispatchKey = torch._C.DispatchKey
# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
# 1. (Direct) operator registration

View File

@ -940,6 +940,7 @@ bitwise_xor = _make_elementwise_binary_prim(
# doc="",
# )
# div prim performs truncation division on integer inputs
# and true division for floating and complex inputs
def _div_aten(a, b):
@ -1151,6 +1152,7 @@ zeta = _make_elementwise_binary_prim(
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
#
# View operations
def _as_strided_meta(
@ -1701,6 +1703,7 @@ split_dim = _make_prim(
doc=_split_dim_doc,
)
# Note: allows dimensions to be specified redundantly
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
assert isinstance(a, TensorLike)
@ -1980,7 +1983,6 @@ rev = _make_prim(
def _where_meta(
pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
) -> TensorLikeType:
return _elementwise_meta(
a,
b,
@ -2004,6 +2006,7 @@ where = _make_prim(
doc=_where_doc,
)
#
# Type conversions
#
@ -2022,7 +2025,6 @@ def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorL
def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
# Propagates requires grad when possible
if not utils.is_grad_dtype(dtype):
requires_grad = False
@ -2078,6 +2080,7 @@ device_put = _make_prim(
doc=_device_put_doc,
)
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _item_meta(a: TensorLikeType) -> FakeTensor:
@ -2100,6 +2103,7 @@ item = _make_prim(
doc=_item_doc,
)
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
@ -2732,6 +2736,7 @@ svd = _make_prim(
# Randomness Prims
#
# TODO: add generator support
# NOTE: there is currently no way of acquiring the "default" torch generator
def _normal_meta(

View File

@ -60,6 +60,7 @@ DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType(
}
)
# nvFuserTensorTemplate and nvFuserScalarTemplate are helper objects
# for cached construction of the nvFuser's Fusion
# TODO: change what is stored in the cache for nvFuser's Tensor objects
@ -258,7 +259,6 @@ def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
)
for arg in flat_args
):
# Construction of the fusion is expensive and cached based on the GraphModule
# and symbolic nvFuser args.
nv_template_args = to_nvfuser_template_args(flat_args)

View File

@ -223,7 +223,6 @@ _nvfuser_impls["{fname}"] = _{fname}_nvfuser
def _native_batch_norm_nvfuser(
fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
"""
if weight is None:
weight = fd.define_null_tensor()
@ -565,7 +564,6 @@ def register_native_batch_norm():
momentum: float,
eps: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if torch._prims_common.is_complex_dtype(input.dtype):
raise NotImplementedError("Complex tensors are not supported")

View File

@ -379,6 +379,7 @@ from torch._decomp import register_decomposition
infer_aten_op = object()
# TODO: add type promotion support
def _make_elementwise_unary_reference(
type_promotion_kind,
@ -556,7 +557,6 @@ def exp2(a):
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
assert isinstance(a, TensorLike)
assert isinstance(value, Number)
@ -1118,7 +1118,6 @@ def float_power(
a: Union[TensorLikeType, NumberType],
b: Union[TensorLikeType, NumberType],
) -> Tensor:
if isinstance(a, Number) and isinstance(b, Number):
raise ValueError(
"Receive two Number inputs to an elementwise binary operation!"
@ -1168,6 +1167,7 @@ def float_power(
# For reference, see CPython's implementation:
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# TODO: add docstring
@_make_elementwise_binary_reference(
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
@ -1801,6 +1801,7 @@ def clamp_max(
# Conditional references
#
# https://pytorch.org/docs/stable/generated/torch.where.html
# TODO: implement alternate where
@register_decomposition(aten.where)
@ -4092,7 +4093,6 @@ def new_empty(
device: Optional[torch.device] = None,
pin_memory: bool = False,
) -> TensorLikeType:
dtype = a.dtype if dtype is None else dtype
layout = a.layout if layout is None else layout
device = a.device if device is None else device
@ -4275,7 +4275,6 @@ def empty_like(
requires_grad: bool = False,
memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
dtype = a.dtype if dtype is None else dtype
layout = a.layout if layout is None else layout
device = a.device if device is None else device

View File

@ -82,7 +82,6 @@ def _dropout_helper(
def alpha_dropout(
self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False
) -> TensorLikeType:
if inplace:
raise NotImplementedError
@ -178,7 +177,6 @@ def celu(
def dropout(
a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False
) -> TensorLikeType:
if inplace:
raise NotImplementedError

View File

@ -295,7 +295,6 @@ class MetaConverter:
torch._C.DispatchKey.ADInplaceOrView, False
)
try:
if base.dtype == t.dtype:
pass
elif is_c_of_r(base.dtype, t.dtype):

View File

@ -180,6 +180,7 @@ def _rebuild_tensor_v2(
_sparse_tensors_to_validate: List["torch.Tensor"] = []
# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have been already unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are

View File

@ -103,7 +103,6 @@ class DistributedDataParallel(Module):
gradient_as_bucket_view=False,
static_graph=False,
):
super().__init__()
self.logger: Optional[dist.Logger] = None
if not any((p.requires_grad for p in module.parameters())):

View File

@ -1849,10 +1849,10 @@ class FlatParamHandle:
flat_param.grad = None
def _deregister_orig_params(self):
for (param_name, module, _) in self.flat_param._param_infos:
for param_name, module, _ in self.flat_param._param_infos:
if hasattr(module, param_name):
delattr(module, param_name)
for (param_name, module, _, _, _, _) in self.flat_param._shared_param_infos:
for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
if hasattr(module, param_name):
delattr(module, param_name)

View File

@ -123,9 +123,9 @@ _context = engine.background_context
@contextlib.contextmanager
def create_export_diagnostic_context() -> Generator[
infra.DiagnosticContext, None, None
]:
def create_export_diagnostic_context() -> (
Generator[infra.DiagnosticContext, None, None]
):
"""Create a diagnostic context for export.
This is a workaround for code robustness since diagnostic context is accessed by

View File

@ -30,7 +30,6 @@ def _export(
args,
**kwargs,
) -> Union["onnx.ModelProto", bytes]:
export_options = options.ExportOptions()
export_options.update(**kwargs)
# Apply decomposition table to the input graph.

View File

@ -149,9 +149,9 @@ _ATENLIB_FUNCTIONS = {
}
def _create_op_overload_to_exporter_key_table() -> Dict[
Union[torch._ops.OpOverload, Callable], str
]:
def _create_op_overload_to_exporter_key_table() -> (
Dict[Union[torch._ops.OpOverload, Callable], str]
):
# TODO(justinchuby): Improve how the table is constructed.
table: Dict[Union[torch._ops.OpOverload, Callable], str] = {}
@ -189,9 +189,9 @@ _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE = _create_op_overload_to_exporter_key_table()
@_beartype.beartype
def _create_onnx_friendly_decomposition_table() -> Dict[
torch._ops.OpOverload, Callable
]:
def _create_onnx_friendly_decomposition_table() -> (
Dict[torch._ops.OpOverload, Callable]
):
decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
for op_overload, decomp_fn in torch._decomp.decomposition_table.items():
# Skip decomposition into "prim::*" ops, because they are not generally supported by ONNX.
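
The two hunks above also show how black 23.1.0 treats a return annotation that no longer fits on one line: rather than splitting the subscript across lines, it wraps the whole annotation in parentheses. A minimal sketch, assuming a hypothetical function name that is long enough to force the wrap:

```python
from typing import Callable, Dict, Union

import torch


# black 22.3.0 split the subscript of the return annotation:
#
#     def _create_operator_overload_to_exporter_key_table() -> Dict[
#         Union[torch._ops.OpOverload, Callable], str
#     ]:
#
# black 23.1.0 keeps the annotation intact and parenthesizes it instead.
def _create_operator_overload_to_exporter_key_table() -> (
    Dict[Union[torch._ops.OpOverload, Callable], str]
):
    return {}
```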

View File

@ -337,7 +337,6 @@ def _export_fx_node_to_onnxscript(
_validate_op_between_ort_torch(node, symbolic_fn, torch_args, torch_kwargs)
fx_name_to_onnxscipt_value[node.name] = output
elif node.op == "output":
if isinstance(node.args[0], torch.fx.Node):
onnx_tensor_or_tensor_tuple = fx_name_to_onnxscipt_value[node.args[0].name]
onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
@ -389,7 +388,6 @@ def _export_fx_node_to_onnxscript(
def export_fx_to_onnxscript(
fx_module_with_metadata: torch.fx.GraphModule, options: options.ExportOptions
):
# Initialize the ONNX graph
onnxscript_graph = graph_building.TorchScriptGraph()
tracer = graph_building.TorchScriptTracingEvaluator(onnxscript_graph)

View File

@ -253,7 +253,6 @@ def export_without_parameters_and_buffers(
Tuple[Any, ...],
Tuple[Any, ...],
]:
graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
module, *args, **kwargs
)

View File

@ -292,7 +292,6 @@ def _find_onnxscript_op(
def _convert_tensor_to_numpy(input: Any) -> Any:
try:
import numpy as np
except ImportError:

View File

@ -68,7 +68,6 @@ def batch_norm(
eps,
cudnn_enabled,
):
if (
torch.is_autocast_enabled()
and not symbolic_helper.args_have_same_dtype(

View File

@ -1324,7 +1324,6 @@ def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kw
if require_cast:
for input in inputs:
if input.isCompleteTensor():
input_scalar_type = _type_utils.JitScalarType.from_value(input)
if input_scalar_type != dtype_0:
@ -4484,7 +4483,6 @@ def _generic_rnn(
batch_first=None,
batch_sizes=None,
):
warnings.warn(
"Exporting a model to ONNX with a batch_size other than 1, "
+ "with a variable length with "

View File

@ -5,6 +5,7 @@ from typing import cast
import torch
from torch.types import Storage
# because get_storage_from_record returns a tensor!?
class _HasStorage:
def __init__(self, storage):

View File

@ -16,6 +16,7 @@ _zip_searchorder = (
(".py", False),
)
# Replace any occurrences of '\r\n?' in the input string with '\n'.
# This converts DOS and Mac line endings to Unix line endings.
def _normalize_line_endings(source):

View File

@ -916,7 +916,6 @@ class PackageExporter:
def _persistent_id(self, obj):
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
storage: Storage
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can

View File

@ -615,6 +615,7 @@ class AliasInfo:
# the great majority of PyTorch's (public) operators.
#
# Classes and methods for the operator database
@dataclass
class OpInfo:
@ -1549,6 +1550,7 @@ def make_error_inputs_elementwise_binary(error_inputs_func):
# The following functions and classes are for testing elementwise binary operators.
# Returns a generator of pairs of contiguous tensors on the requested device
# and with the requested dtype.
#
@ -1997,7 +1999,6 @@ class BinaryUfuncInfo(OpInfo):
supports_two_python_scalars=False, # Whether the operator allows scalar x scalar inputs
**kwargs,
):
self._original_binary_ufunc_args = locals().copy()
# Elementwise binary operations perform the equivalent of test_numpy_refs
@ -2144,7 +2145,6 @@ def _filter_unary_elementwise_tensor(a, *, op):
def generate_elementwise_unary_tensors(op, *, device, dtype, requires_grad, **kwargs):
# Special-cases bool
if dtype is torch.bool:
tensors = (
@ -2491,7 +2491,6 @@ class SpectralFuncInfo(OpInfo):
decorators=None,
**kwargs,
):
self._original_spectral_func_args = dict(locals()).copy()
self._original_spectral_func_args.update(kwargs)

View File

@ -29,6 +29,7 @@ from torch.testing._internal.opinfo.core import (
)
from torch.testing._internal.opinfo.utils import prod_numpy, reference_reduction_numpy
# Used for log_softmax, softmax, softmin
def sample_inputs_softmax_variant(
op_info,

View File

@ -53,7 +53,6 @@ class SpectralFuncPythonRefInfo(SpectralFuncInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo = _find_referenced_opinfo(
torch_opinfo_name, torch_opinfo_variant, op_db=op_db

View File

@ -36,6 +36,7 @@ from torch.testing._internal.opinfo.utils import (
if TEST_SCIPY:
import scipy.special
# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`,
# supports `exclude` argument.
# For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617

View File

@ -103,7 +103,6 @@ class PythonRefInfo(OpInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo_variant_name = torch_opinfo_variant_name
self.torch_opinfo = _find_referenced_opinfo(
@ -134,7 +133,6 @@ class ReductionPythonRefInfo(ReductionOpInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo_variant_name = torch_opinfo_variant_name
self.torch_opinfo = _find_referenced_opinfo(
@ -169,7 +167,6 @@ class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo_variant_name = torch_opinfo_variant_name
self.torch_opinfo = _find_referenced_opinfo(
@ -201,7 +198,6 @@ class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo_variant_name = torch_opinfo_variant_name
self.torch_opinfo = _find_referenced_opinfo(

View File

@ -1,5 +1,6 @@
import sympy
# The normal Python interpretation of the operators
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works

View File

@ -313,6 +313,7 @@ JIT_TO_CPP_DEFAULT = {
"long": "at::kLong",
}
# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type, *, symint: bool) -> str:
if d == "None" and str(t) == "Tensor?":

View File

@ -69,6 +69,7 @@ reapply_views_binding = Binding(
default=None,
)
# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
def name(

Some files were not shown because too many files have changed in this diff.