Bump black version to 23.1.0 (#96578)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96578
Approved by: https://github.com/ezyang
BowenBao 2023-03-14 19:46:45 -07:00 committed by PyTorch MergeBot
parent a229e78544
commit 60a68477a6
114 changed files with 111 additions and 167 deletions

@@ -878,7 +878,7 @@ init_command = [
    'tools/linter/adapters/pip_init.py',
    '--dry-run={{DRYRUN}}',
    '--no-black-binary',
-   'black==22.3.0',
+   'black==23.1.0',
    'ufmt==1.3.3',
    'usort==1.0.2',
]
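The hunk above is the whole behavioural change of the PR: it bumps the black release that lintrunner's pip_init adapter installs; every other hunk in this commit is the result of re-running the formatter. A minimal sketch (hypothetical helper, not part of the PR) of checking that a local environment matches the new pin before reformatting:

# Hypothetical check, not from this PR: compare the installed black with the pinned version.
import importlib.metadata

PINNED_BLACK = "23.1.0"  # the version pinned in the lint config above
try:
    installed = importlib.metadata.version("black")
except importlib.metadata.PackageNotFoundError:
    raise SystemExit("black is not installed; let lintrunner's pip_init adapter install the pinned version")
if installed != PINNED_BLACK:
    raise SystemExit(f"black {installed} is installed, but the lint config pins {PINNED_BLACK}")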

@@ -4,6 +4,7 @@ import os
from typing import Set

+
# Note - hf and timm have their own version of this, torchbench does not
# TOOD(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> Set[str]:
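Hunks like this one that only add a line are black 23.1 inserting an extra blank line so that a comment sitting directly above a top-level definition is separated from the preceding code by the usual two blank lines. A minimal sketch of the resulting layout (hypothetical module, not code from the PR):

import os


# the comment stays attached to the function; the two-blank-line gap now sits above it
def list_model_files(dirname):
    return sorted(os.listdir(dirname))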

@@ -11,12 +11,10 @@ def get_field(csv, model_name: str, field: str, typ=float):

def check_graph_breaks(actual_csv, expected_csv, expected_filename):
-
-
    failed = []
    improved = []
    for model in actual_csv["name"]:
        graph_breaks = get_field(actual_csv, model, "graph_breaks", typ=int)
        expected_graph_breaks = get_field(expected_csv, model, "graph_breaks", typ=int)

@@ -31,7 +31,6 @@ ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/c
def query_job_sha(repo, sha):
-
    params = {
        "parameters": [
            {"name": "sha", "type": "string", "value": sha},

@@ -108,7 +107,6 @@ def write_filtered_csvs(root_path, dataframes):
if __name__ == "__main__":
-
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )

@@ -373,7 +373,6 @@ class HuggingfaceRunner(BenchmarkRunner):
        model_name,
        batch_size=None,
    ):
-
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dtype = torch.float32

@@ -513,7 +512,6 @@ def refresh_model_names_and_batch_sizes():
    lm_seen = set()
    family_seen = set()
    for cls_name in hf_fx._SUPPORTED_MODELS:
-
        if "For" not in cls_name:
            continue
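Most hunks in this commit look like the two above: black 23.1 deletes a blank line that sat directly after a block opener such as a def or class signature, an if, a for, or a try. A before/after sketch with a hypothetical helper (not code from the PR):

# Accepted by black 22.3:
def load_config_old(path):

    with open(path) as f:
        return f.read()

# What black 23.1 produces:
def load_config_new(path):
    with open(path) as f:
        return f.read()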

@@ -73,7 +73,6 @@ def bench_op(
    warmup=25,
    rep=75,
):
-
    skip = False
    # allocate inputs, nchw
    x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")

@@ -70,7 +70,6 @@ def bench_op(
    warmup=25,
    rep=75,
):
-
    # allocate inputs, nchw
    x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
    w = torch.randn(

@@ -66,7 +66,6 @@ def bench_op(
    warmup=25,
    rep=75,
):
-
    # allocate inputs, nchw
    x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
    w = torch.randn(

@@ -236,7 +236,6 @@ def bench(layer_params, layer_id, p, fusion_types=[""]):
    row = [layer_id]
    for fusion_type in fusion_types:
        if fusion_type == "":
-
            conv_torchinductor = getattr(Func, "conv_torchinductor")
            conv = getattr(Func, "conv")

@@ -56,7 +56,6 @@ def bench(shape, layer_id, p, fusion_types=[""]):
    row = [layer_id]
    for fusion_type in fusion_types:
        if fusion_type == "":
-
            fn_mm = getattr(Func, "mm")
        else:

@@ -46,7 +46,6 @@ def profile_op(
    warmup=25,
    rep=50,
):
-
    # allocate inputs, nchw
    x = torch.randn((BATCH, IN_C, IN_H, IN_W), dtype=dtype, device="cuda")
    w = torch.randn(

@@ -60,6 +60,7 @@ out = csv.DictWriter(
out.writeheader()
out.writerow({"explain": gist_url})

+
# Sometimes backtraces will be in third party code, which results
# in very long file names. Delete the absolute path in this case.
def normalize_file(f):

@@ -182,7 +182,6 @@ class TimmRunnner(BenchmarkRunner):
        model_name,
        batch_size=None,
    ):
-
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

@@ -242,7 +242,6 @@ class TorchBenchmarkRunner(BenchmarkRunner):
        batch_size=None,
        part=None,
    ):
-
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dynamic_shapes = self.args.dynamic_shapes

@@ -120,7 +120,7 @@ class TestInitialization(FSDPTest):
        composable_handles = traversal_utils._get_fsdp_handles(composable_module)
        fsdp_wrapped_handles = traversal_utils._get_fsdp_handles(fsdp_wrapped_model)
        self.assertEqual(len(composable_handles), len(fsdp_wrapped_handles))
-       for (composable_handle, fsdp_wrapped_handle) in zip(
+       for composable_handle, fsdp_wrapped_handle in zip(
            composable_handles, fsdp_wrapped_handles
        ):
            self.assertEqual(

@@ -179,7 +179,7 @@ class TestInitialization(FSDPTest):
            policy=policy,
            sync_module_states=True,
        )
-       for (composable_param, fsdp_wrapped_param) in zip(
+       for composable_param, fsdp_wrapped_param in zip(
            composable_module.parameters(),
            fsdp_wrapped_model.parameters(),
        ):
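The other recurring change, shown in the two FSDP hunks above, is that black 23.1 drops the redundant parentheses around tuple-unpacking targets in for statements. Both spellings are identical at runtime; a small sketch (hypothetical data, not from the PR):

pairs = [(1, "a"), (2, "b")]
for (num, letter) in pairs:  # the form black 22.3 left alone
    print(num, letter)
for num, letter in pairs:  # the form black 23.1 emits
    print(num, letter)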

@@ -116,7 +116,7 @@ class TestFSDPCheckpoint(FSDPTest):
        assert outputs
        assert models

-       for (l, o) in zip(losses[1:], outputs[1:]):
+       for l, o in zip(losses[1:], outputs[1:]):
            self.assertEqual(losses[0], l)
            self.assertEqual(outputs[0], o)

@@ -324,7 +324,6 @@ class TestModel(nn.Module):
class TestFSDPCheckpointSubmodule(FSDPTest):
-
    # TODO: grad value checks occasionally fails when use_reentrant = True
    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [False])

@@ -70,7 +70,6 @@ class Net(nn.Module):
class DummyState:
-
    __slots__ = ["process_group", "noise"]

    def __init__(self, process_group: dist.ProcessGroup, noise: int):

@@ -157,7 +156,6 @@ class TestCommunicationHooks(FSDPTest):
        self.assertEqual(entry._communication_hook, default_hook)

        for _ in range(4):
-
            # Clear gradients
            net_default_hook.zero_grad()
            loss = net_default_hook(inpt).sum()

@@ -183,7 +181,6 @@ class TestCommunicationHooks(FSDPTest):
        ]

    def _init_model(self, core, sharding_strategy, mixed_precision=None):
-
        device = torch.device("cuda")
        return FSDP(
            core,

@@ -424,7 +421,6 @@ class TestCommunicationHooks(FSDPTest):
    def test_fp16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
-
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.fp16_compress_hook

@@ -452,7 +448,6 @@ class TestCommunicationHooks(FSDPTest):
    def test_bf16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
-
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.bf16_compress_hook

@@ -160,7 +160,7 @@ class TestGradAcc(FSDPTest):
            num_iters_to_acc = sum(config.num_iters for config in configs)
            for _ in range(num_iters_to_acc - 1):
                batches.append(tuple(permute_tensor(t) for t in batch))
-           for (batch1, batch2) in itertools.combinations(batches, r=2):
+           for batch1, batch2 in itertools.combinations(batches, r=2):
                for t1, t2 in zip(batch1, batch2):
                    assert not torch.all(
                        t1 == t2

@@ -1338,7 +1338,6 @@ class TestFSDPOptimState(FSDPTest):
        use_multiple_param_groups: bool,
        use_optim_input: bool,
    ):
-
        NUM_ITERS = 3
        # Run a wrapped model for a few iterations
        model1, optim1, optim_input1 = self._init_nested_model(

@@ -937,14 +937,14 @@ class TestFSDPStateDict(FSDPTest):
        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
-       for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
+       for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        with FSDP.state_dict_type(new_fsdp, STATE_DICT_MAPPING[state_dict_type]):
            if state_dict_type != "local_state_dict":
                # FlatParameter has not supported deepcopy yet.
                state_dict = deepcopy(state_dict)
            new_fsdp.load_state_dict(state_dict, strict=True)
-       for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
+       for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.

@@ -954,7 +954,7 @@ class TestFSDPStateDict(FSDPTest):
                param.zero_()

        with fsdp.summon_full_params(fsdp):
-           for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
+           for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

        if state_dict_type == "local_state_dict":

@@ -963,7 +963,7 @@ class TestFSDPStateDict(FSDPTest):
        with fsdp.summon_full_params(fsdp):
            if self.rank == 0:
                local.load_state_dict(state_dict, strict=True)
-               for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
+               for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    self.assertEqual(p1, p2)

    @skip_if_lt_x_gpu(2)

@@ -31,7 +31,6 @@ class TestShardUtils(TestCase):
        out_offsets,
        in_split_sizes,
    ):
-
        for my_rank in range(world_size):
            _in_split_sizes = in_split_sizes[my_rank]
            _out_split_sizes = [

@@ -847,7 +847,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
        torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)

        try:
-
            x = torch.randn(1).as_subclass(TensorProxy)
            cnt = torch._dynamo.testing.CompileCounter()
            out1 = foo(x)

@@ -862,7 +861,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
    def test_torch_function_with_closure(self):
        def run():
-
            counter = 0

            def foo(x):

@@ -1097,7 +1095,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
        opt_mod = torch._dynamo.optimize("eager")(mod)

        # Check parameteres and buffers
-       for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()):
+       for p1, p2 in zip(mod.parameters(), opt_mod.parameters()):
            self.assertTrue(id(p1) == id(p2))

    def test_recursion(self):

@@ -1572,7 +1572,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(y, 10)

    def test_sort_out(self):
-
        dtype = torch.float32
        device = "cpu"

@@ -1607,7 +1606,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(same(ref, res))

    def test_sigmoid_out(self):
-
        dtype = torch.float32
        device = "cpu"

@@ -178,7 +178,6 @@ class TestInductorConfig(TestCase):
        a(torch.randn(10))

    def test_api_options(self):
-
        reduce_overhead_opts = torch._inductor.list_mode_options("reduce-overhead")
        self.assertEqual(reduce_overhead_opts["triton.cudagraphs"], True)

@@ -79,7 +79,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
    class TestInductorDynamic(TestCase):
-
        compile_fn = partial(torch.compile, dynamic=True)

        def setUp(self):

@@ -597,7 +597,6 @@ class TestInductorOpInfo(TestCase):
                )
            except Exception as e:
-
                if test_expect is ExpectedTestResult.XFAILURE:
                    raise e

@@ -48,6 +48,7 @@ skipIfNoBFloat16Cuda = _skipper(
    lambda: not torch.cuda.is_bf16_supported(), "BFloat16 CUDA is not available"
)

+
# skips tests for all versions below min_opset_version.
# if exporting the op is only supported after a specific version,
# add this wrapper to prevent running the test for opset_versions

@@ -494,7 +494,6 @@ class TestONNXOpset(pytorch_test_common.ExportTestCase):
            ("zeros", "border", "reflection"),
            (True, False),
        ):
-
            args = (
                torch.randn(n, c, h_in, w_in),  # x
                torch.randn(n, h_out, w_out, 2),  # grid,

@@ -13,14 +13,12 @@ from torch.testing._internal import common_utils
class TestONNXScriptExport(common_utils.TestCase):
-
    # opset version is
    # 1. local function is supported after opset 15
    # 2. onnx-script requires users to determine opset in local function
    opset_version = 15

    def test_onnxscript_registration_with_multiple_models(self):
-
        from onnxscript.onnx_opset import opset15 as op

        # 1. Register Selu onnxscript function as custom Op

@@ -12,14 +12,12 @@ from torch.testing._internal import common_utils
class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
-
    # opset version is
    # 1. local function is supported after opset 15
    # 2. onnx-script requires users to determine opset in local function
    opset_version = 15

    def test_selu_from_onnxscript_example(self):
-
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model = torch.nn.SELU()

@@ -52,7 +50,6 @@ class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
        self.run_test(model, x)

    def test_layer_norm(self):
-
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)

@@ -30,9 +30,7 @@ def export_to_onnx(
    model: Union[torch.nn.Module, torch.jit.ScriptFunction],
    input: Union[torch.Tensor, Tuple[torch.Tensor]],
    custom_ops: Optional[
-       Iterable[
-           Union[contextlib.AbstractContextManager, contextlib.ContextDecorator],
-       ]
+       Iterable[Union[contextlib.AbstractContextManager, contextlib.ContextDecorator]]
    ] = None,
    mocks: Optional[Iterable] = None,
    operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX,

@@ -765,7 +763,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
        )

    def test_dropout_script(self):
-
        eg = torch.zeros(1, 2, 3, requires_grad=True)

        @jit_utils._trace(eg)

@@ -8600,7 +8600,6 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
    @skipIfUnsupportedMinOpsetVersion(9)
    def test_kldiv_loss(self):
-
        x = torch.rand(5).log()
        y = torch.rand(5)
        self._kldiv_loss(x, y)

@@ -12832,7 +12831,6 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
        name_fn=lambda align_corners: str(align_corners),
    )
    def test_grid_sample(self, mode, padding_mode, align_corners):
-
        n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4

        class GridSampleModule(torch.nn.Module):

@@ -328,7 +328,6 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
        self.opset_version = _constants.ONNX_MAX_OPSET

    def test_setType_maintains_output_shape_for_single_custom_op(self):
-
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):

@@ -363,7 +362,6 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
        self.assertEqual(dim.dim_value, rank)

    def test_no_setType_for_single_custom_op(self):
-
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):

@@ -398,7 +396,6 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
    def test_setType_maintains_output_shape_for_single_custom_op_with_dynamic_axes(
        self,
    ):
-
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):

@@ -438,7 +435,6 @@ class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase):
        self.assertEqual(dims[i].dim_value, x.size()[i])

    def test_setType_maintains_output_shape_for_single_custom_op_with_onnx_ops(self):
-
        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)

        class CustomInverse(torch.nn.Module):

@@ -133,6 +133,7 @@ ignores = [
ignores = [os.path.join(proj_dir, ignore) for ignore in ignores]

+
# Check if the compiler is hip-clang.
def is_hip_clang() -> bool:
    try:

@@ -5,6 +5,7 @@ from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFW
from torchgen.context import native_function_manager
from torchgen.utils import T

+
# Like tools.api.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(

@@ -420,7 +420,6 @@ UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def get_infos_with_derivatives_list(
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]]
) -> List[DifferentiabilityInfo]:
-
    diff_info_list = [
        info
        for diffinfo_dict in differentiability_infos.values()

@@ -469,7 +468,6 @@ def gen_autograd_functions_python(
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
-
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write(

@@ -221,6 +221,7 @@ ${assign_return_values} ([&]() {
TMP_VAR = "_tmp"

+
# FIXME: Ideally these functions should be methods on Type class, but we have a
# comment in codegen/model.py there saying these concepts are not well defined.
# Thus we put a version that commonly used by autograd codegen here.

@@ -321,7 +322,8 @@ def emit_view_call(
def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
    """Generate an additional lambda function to recover views in backward when as_strided is not supported.
-   See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details."""
+   See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
+   """
    input_base = "input_base"
    replay_view_func = ""
    updated_unpacked_args: List[str] = []

@@ -17,6 +17,7 @@ from torchgen.utils import FileManager, mapMaybe
OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")

+
# Add 'at::' to types defined in ATen namespace, e.g. Tensor, TensorList, IntArrayRef and etc.
# TODO: maybe update the cpp argument API to take optional namespace argument?
def fully_qualified_type(argument_type: str) -> str:

@@ -761,7 +761,6 @@ def gen_variable_type(
    template_path: str,
    used_keys: Set[str],
) -> None:
-
    """VariableType.h and VariableType.cpp body

    This is the at::Type subclass for differentiable tensors. The

@@ -52,6 +52,7 @@ _GLOBAL_LOAD_DERIVATIVE_CACHE = {}
_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)

+
# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
# Since every {view} and {view}_copy op shares the same derivative formula,
# we generate them here instead of duplicating them in the yaml.

@@ -96,7 +97,6 @@ def load_derivatives(
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    key = (derivatives_yaml_path, native_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
-
        with open(derivatives_yaml_path, "r") as f:
            definitions = yaml.load(f, Loader=YamlLoader)

@@ -183,7 +183,6 @@ def create_debug_info_from_selected_models(
    selected_models: List[dict],
    new_style_rule: bool,
):
-
    model_dict = {
        "asset_info": {},  # maps asset name -> dict of asset metadata like hashes
        "is_new_style_rule": new_style_rule,

@@ -465,13 +464,13 @@ def fill_output(output: Dict[str, object], options: object):
    # to True, since it indicates that this operator list came from something
    # other than a traced operator list.
    include_all_non_op_selectives = False
-   for (op_name, op_info) in operators.items():
+   for op_name, op_info in operators.items():
        include_all_non_op_selectives = (
            include_all_non_op_selectives or op_info.include_all_overloads
        )

    operators_as_dict = {}
-   for (k, v) in operators.items():
+   for k, v in operators.items():
        operators_as_dict[k] = v.to_dict()
    output["operators"] = operators_as_dict

@@ -18,14 +18,14 @@ from torchgen.selective_build.selector import (
def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    ops = []
-   for (op_name, op) in selective_builder.operators.items():
+   for op_name, op in selective_builder.operators.items():
        ops.append(op_name)
    return set(ops)


def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    ops = []
-   for (op_name, op) in selective_builder.operators.items():
+   for op_name, op in selective_builder.operators.items():
        if op.is_used_for_training:
            ops.append(op_name)
    return set(ops)

@@ -33,7 +33,7 @@ def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]:
def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
    ops = []
-   for (op_name, op) in selective_builder.operators.items():
+   for op_name, op in selective_builder.operators.items():
        if op.include_all_overloads:
            ops.append(op_name)
    if ops:

@@ -47,7 +47,6 @@ def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> N
def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None:
-
    supported_mobile_models_source = """/*
 * Generated by gen_oplist.py
 */

@@ -38,7 +38,6 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:
def print_test_by_type(
    tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str]
) -> None:
-
    print("Tests " + type_name + " to collect coverage:", file=summary_file)
    for test in tests:
        if is_this_type_of_tests(test.name, test_set_by_type):

@@ -22,6 +22,7 @@ result = subprocess.run(
PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
IS_WINDOWS: bool = os.name == "nt"

+
# Returns '/usr/local/include/python<version number>'
def get_python_include_dir() -> str:
    return gp()["include"]

@@ -147,7 +148,7 @@ def check_file(
        proc = run_command(
            [binary, f"-p={build_dir}", *include_args, filename],
        )
-   except (OSError) as err:
+   except OSError as err:
        return [
            LintMessage(
                path=filename,
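The last hunk above also shows black 23.1 removing the redundant parentheses around a single exception class in an except clause. Both forms catch the same exception; a short sketch (hypothetical function, not from the PR):

def read_first_line(path):
    try:
        with open(path) as f:
            return f.readline()
    except OSError as err:  # black 22.3 also accepted: except (OSError) as err:
        return f"failed to read {path}: {err}"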

@@ -47,7 +47,7 @@ selected_mobile_ops_preamble = """#pragma once
def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    ops = []
-   for (op_name, op) in selective_builder.operators.items():
+   for op_name, op in selective_builder.operators.items():
        if op.is_root_operator:
            ops.append(op_name)
    return set(ops)

@@ -142,7 +142,6 @@ def _format_rule_for_cpp(rule: _RuleType) -> str:
def gen_diagnostics_python(
    rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
-
    rule_class_lines = [_format_rule_for_python_class(rule) for rule in rules]
    rule_field_lines = [_format_rule_for_python_field(rule) for rule in rules]

@@ -165,7 +164,6 @@ def gen_diagnostics_python(
def gen_diagnostics_cpp(
    rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
-
    rule_lines = [_format_rule_for_cpp(rule) for rule in rules]
    rule_names = [f'"{_kebab_case_to_snake_case(rule["name"])}",' for rule in rules]

@@ -206,7 +204,6 @@ def gen_diagnostics(
    out_cpp_dir: str,
    out_docs_dir: str,
) -> None:
-
    with open(rules_path, "r") as f:
        rules = yaml.load(f, Loader=torchgen_utils.YamlLoader)

@@ -41,7 +41,7 @@ def apply_replacements(replacements: Dict[str, str], text: str) -> str:
    Returns:
        Text with replacements applied, if any.
    """
-   for (before, after) in replacements.items():
+   for before, after in replacements.items():
        text = text.replace(before, after)
    return text

@@ -54,7 +54,6 @@ def generate_code(
    operator_selector = SelectiveBuilder.get_nop_selector()

    if subset == "libtorch" or not subset:
-
        gen_autograd(
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            tags_path or TAGS_PATH,

@@ -134,7 +134,6 @@ def rocm_get_per_process_gpu_info() -> List[Dict[str, Any]]:
if __name__ == "__main__":
    handle = None
    try:
-
        pynvml.nvmlInit()

@@ -132,7 +132,10 @@ def upload_to_s3(
        json.dump(doc, body)
        body.write("\n")

-   S3_RESOURCE.Object(f"{bucket_name}", f"{key}",).put(
+   S3_RESOURCE.Object(
+       f"{bucket_name}",
+       f"{key}",
+   ).put(
        Body=gzip.compress(body.getvalue().encode()),
        ContentEncoding="gzip",
        ContentType="application/json",
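This hunk shows black 23.1 honoring the trailing comma left inside the chained .Object(...) call: instead of keeping the call on one line, it now splits it with one argument per line. A small sketch of the same layout with a hypothetical function (not from the PR):

def object_key(bucket, key):
    return f"{bucket}/{key}"

# The trailing comma inside the call keeps it exploded, one argument per line.
key = object_key(
    "example-bucket",
    "reports/latest.json",
)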

@@ -13,6 +13,7 @@ from torchgen.gen_backend_stubs import run
path = os.path.dirname(os.path.realpath(__file__))
gen_backend_stubs_path = os.path.join(path, "../torchgen/gen_backend_stubs.py")

+
# gen_backend_stubs.py is an integration point that is called directly by external backends.
# The tests here are to confirm that badly formed inputs result in reasonable error messages.
class TestGenBackendStubs(expecttest.TestCase):

@@ -9,7 +9,6 @@ import os
def main() -> None:
-
    target = os.path.join("torch", "masked", "_docs.py")

    try:

@@ -169,6 +169,7 @@ def get_decompositions(
import torch._decomp.decompositions
import torch._refs

+
# This list was copied from torch/_inductor/decomposition.py
# excluding decompositions that results in prim ops
# Resulting opset of decomposition is core aten ops

@@ -88,6 +88,7 @@ pw_cast_for_int_to_real = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)

+
# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int):
    for _ in range(dim - x.dim()):

@@ -619,7 +620,6 @@ def slice_forward(
    end: Optional[int] = None,
    step: int = 1,
):
-
    ndim = self.dim()
    if ndim == 0:
        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")

@@ -86,6 +86,7 @@ def _register_jit_decomposition_for_jvp(decomp, use_python=False):

+
# The only decompositions here are temporary or hacks for the purposes of jvp
# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:

@@ -288,7 +288,6 @@ def _compile(
    hooks: Hooks,
    frame: Optional[types.FrameType] = None,
) -> Optional[GuardedCode]:
-
    output: Optional[OutputGraph] = None

    # This is shared across restarts
    mutated_closure_cell_contents: Set[str] = set()

@@ -156,7 +156,6 @@ def filter_stack(stack):
def format_error_msg(exc, code, record_filename=None, frame=None):
-
    msg = os.linesep * 2

    if config.verbose:

@@ -11,6 +11,7 @@ logging.addLevelName(logging.CODE, "CODE")
# Disable progress bar by default, not in dynamo config because otherwise get a circular import
disable_progress = True

+
# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
    return [

@@ -355,7 +355,7 @@ class BuiltinVariable(VariableTracker):
            return None

        # Return first handler that matches the type checks
-       for ((type1, type2), handler) in handlers[op]:
+       for (type1, type2), handler in handlers[op]:
            if isinstance(a, type1) and isinstance(b, type2):
                return handler

@@ -641,7 +641,6 @@ class BuiltinVariable(VariableTracker):
            )
            for i in [a, b]
        ):
-
            if any([isinstance(val, FakeItemVariable) for val in [a, b]]):
                return variables.FakeItemVariable.from_tensor_variable(result)

@@ -678,7 +677,6 @@ class BuiltinVariable(VariableTracker):
                )
                return SymNodeVariable.create(tx, proxy, None)
            else:
-
                unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")

    call_min = _call_min_max

@@ -73,6 +73,7 @@ constant_fold_functions = [
if torch.distributed.is_available():
    constant_fold_functions.append(torch.distributed.is_initialized)

+
# TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
def remap_as_fn___radd__(*args):
    return torch._C._TensorBase.__radd__(*args)

@@ -412,7 +412,8 @@ class KernelArgs:
class CSEVariable:
    """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
    The backends can inherit from this class and overload the "create_cse_var" Kernel to do that.
-   The "update_on_args" method gives you a hook for annotations, see example of TritonCSEVariable in triton.py."""
+   The "update_on_args" method gives you a hook for annotations, see example of TritonCSEVariable in triton.py.
+   """

    def __init__(self, name):
        self.name = name
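This hunk and the earlier emit_view_lambda one show the same docstring rule: when the last line of a multi-line docstring is already long, black 23.1 moves the closing quotes onto their own line instead of appending them to that line. A minimal sketch (hypothetical function, not from the PR):

def summarize(report):
    """Summarize a benchmark report.

    The closing quotes land on their own line because this final sentence already fills most of the line.
    """
    return len(report)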

@@ -1535,7 +1535,6 @@ class TritonScheduling:
        @contextlib.contextmanager
        def end_current_reduction_loop():
            if current_loop_writes:
-
                # flush out any other runnable nodes to reduce number of loops
                for other_node in nodes[index + 1 :]:

@@ -183,7 +183,6 @@ class cpp:
# config specific to codegen/triton.py
class triton:
-
    # Use cudagraphs on output code
    cudagraphs = False

@@ -1,5 +1,6 @@
import torch

+
# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
# Works for length 2 patterns with 1 module and 1 function/method.
def matches_module_function_pattern(pattern, node, modules):

@@ -802,7 +802,6 @@ class Reduction(Loops):
        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        if reduction_numel == 0:
-
            # N.B. This is a hack to generate the literal of the given type
            # Ideally, we should be fixing `def constant` in triton.py
            # but it breaks due to hardcoded dtypes in other places

@@ -1252,7 +1251,6 @@ class PermuteView(BaseView):
class SqueezeView(BaseView):
    @classmethod
    def create(cls, x, *, dim=None):
-
        if is_storage_and_layout(x):
            storage, old_layout = as_storage_and_layout(x)
            new_size = []

@@ -3828,7 +3826,12 @@ class ConvolutionTransposeUnary(ExternKernelAlloc):
    ):
        kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
        transposed = True
-       (inputs, constant_args, kernel_layout, _,) = _prepare_convolution_fusion_create(
+       (
+           inputs,
+           constant_args,
+           kernel_layout,
+           _,
+       ) = _prepare_convolution_fusion_create(
            cls,
            x,
            weight,

@@ -2144,7 +2144,6 @@ def scatter(x, dim: int, index, src, **kwargs):
def scatter_fallback(
    fn, self, dim: int, index, src, *, reduce: str = None, include_self: bool = True
):
-
    if reduce not in {None, "sum"} or (
        reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64}
    ):

@@ -2158,7 +2157,6 @@ def scatter_fallback(
@register_lowering(aten.scatter_, type_promotion_kind=None)
def scatter_(self, dim: int, index, src, *, reduce: str = None):
-
    if reduce == "add":
        reduce = "sum"
    elif reduce == "multiply":

@@ -2674,7 +2672,6 @@ def constant_boundary_condition_2d(x, fill_value, padding):
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
-
    x_out = ir.FloorDiv(
        x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i]
    )

@@ -3212,7 +3209,6 @@ def avg_pool2d_backward(
    count_include_pad,
    divisor_override=None,
):
-
    assert not divisor_override
    if not stride:
        stride = kernel_size

@@ -441,7 +441,6 @@ def shape_of_mm(a, b):
    CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()),
)
def cat_mm(match, inputs, dim):
-
    return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm)

@@ -129,7 +129,6 @@ if has_triton():
        # allocate accumulator
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for crs in range(0, CRS, BLOCK_K):
-
            # ------ matrix multiplication ------
            acc += tl.dot(matrix_x, matrix_w)
            # ------ update ptrs ------

@@ -306,7 +305,6 @@ if has_triton():
        # allocate accumulator
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for crs in range(0, CRS, BLOCK_K):
-
            # ------ matrix multiplication ------
            acc += tl.dot(matrix_x, matrix_w)
            # ------ update ptrs ------

@@ -3,7 +3,6 @@ import torch
from ..utils import has_triton

if has_triton():
-
    import triton

    class _conv1x1:

@@ -273,7 +273,6 @@ class LOBPCGAutogradFunction(torch.autograd.Function):
        ortho_fparams: Optional[Dict[str, float]] = None,
        ortho_bparams: Optional[Dict[str, bool]] = None,
    ) -> Tuple[Tensor, Tensor]:
-
        # makes sure that input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A

@@ -360,7 +359,6 @@ def lobpcg(
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
-
    """Find the k largest (or smallest) eigenvalues and the corresponding
    eigenvectors of a symmetric positive definite generalized
    eigenvalue problem using matrix-free LOBPCG methods.

@@ -598,7 +596,6 @@ def _lobpcg(
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
-
    # A must be square:
    assert A.shape[-2] == A.shape[-1], A.shape
    if B is not None:

@@ -707,7 +704,6 @@ class LOBPCG:
        method: str,
        tracker: None,
    ) -> None:
-
        # constant parameters
        self.A = A
        self.B = B

@@ -833,7 +829,6 @@ class LOBPCG:
            self.call_tracker()

        while not self.stop_iteration():
-
            self.update()

            if not torch.jit.is_scripting() and self.tracker is not None:

@@ -2486,7 +2486,6 @@ def _cudnn_rnn(
    batch_sizes,
    dropout_state,
):
-
    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        seq_length = len(batch_sizes)

@@ -2773,7 +2772,6 @@ import torch._refs.special
def activate_meta():
-
    activate_meta_table = {}

    # For a given op, we pick the most specific decomp function from

@@ -135,6 +135,7 @@ is_included_in_alias = torch._C._dispatch_is_included_in_alias
DispatchKey = torch._C.DispatchKey

+
# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
    # 1. (Direct) operator registration

@@ -940,6 +940,7 @@ bitwise_xor = _make_elementwise_binary_prim(
# doc="",
# )

+
# div prim performs truncation division on integer inputs
# and true division for floating and complex inputs
def _div_aten(a, b):

@@ -1151,6 +1152,7 @@ zeta = _make_elementwise_binary_prim(
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

+
#
# View operations
def _as_strided_meta(

@@ -1701,6 +1703,7 @@ split_dim = _make_prim(
    doc=_split_dim_doc,
)

+
# Note: allows dimensions to be specified redundantly
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
    assert isinstance(a, TensorLike)

@@ -1980,7 +1983,6 @@ rev = _make_prim(
def _where_meta(
    pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
) -> TensorLikeType:
-
    return _elementwise_meta(
        a,
        b,

@@ -2004,6 +2006,7 @@ where = _make_prim(
    doc=_where_doc,
)

+
#
# Type conversions
#

@@ -2022,7 +2025,6 @@ def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorL
def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
-
    # Propagates requires grad when possible
    if not utils.is_grad_dtype(dtype):
        requires_grad = False

@@ -2078,6 +2080,7 @@ device_put = _make_prim(
    doc=_device_put_doc,
)

+
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _item_meta(a: TensorLikeType) -> FakeTensor:

@@ -2100,6 +2103,7 @@ item = _make_prim(
    doc=_item_doc,
)

+
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:

@@ -2732,6 +2736,7 @@ svd = _make_prim(
# Randomness Prims
#

+
# TODO: add generator support
# NOTE: there is currently no way of acquiring the "default" torch generator
def _normal_meta(

@@ -60,6 +60,7 @@ DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType(
    }
)

+
# nvFuserTensorTemplate and nvFuserScalarTemplate are helper objects
# for cached construction of the nvFuser's Fusion
# TODO: change what is stored in the cache for nvFuser's Tensor objects

@@ -258,7 +259,6 @@ def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
        )
        for arg in flat_args
    ):
-
        # Construction of the fusion is expensive and cached based on the GraphModule
        # and symbolic nvFuser args.
        nv_template_args = to_nvfuser_template_args(flat_args)

@@ -223,7 +223,6 @@ _nvfuser_impls["{fname}"] = _{fname}_nvfuser
def _native_batch_norm_nvfuser(
    fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
-
    """
    if weight is None:
        weight = fd.define_null_tensor()

@@ -565,7 +564,6 @@ def register_native_batch_norm():
        momentum: float,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-
        if torch._prims_common.is_complex_dtype(input.dtype):
            raise NotImplementedError("Complex tensors are not supported")

@@ -379,6 +379,7 @@ from torch._decomp import register_decomposition
infer_aten_op = object()

+
# TODO: add type promotion support
def _make_elementwise_unary_reference(
    type_promotion_kind,

@@ -556,7 +557,6 @@ def exp2(a):
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
-
    assert isinstance(a, TensorLike)
    assert isinstance(value, Number)

@@ -1118,7 +1118,6 @@ def float_power(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> Tensor:
-
    if isinstance(a, Number) and isinstance(b, Number):
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"

@@ -1168,6 +1167,7 @@ def float_power(
# For reference, see CPython's implementation:
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636

+
# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,

@@ -1801,6 +1801,7 @@ def clamp_max(
# Conditional references
#

+
# https://pytorch.org/docs/stable/generated/torch.where.html
# TODO: implement alternate where
@register_decomposition(aten.where)

@@ -4092,7 +4093,6 @@ def new_empty(
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
-
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

@@ -4275,7 +4275,6 @@ def empty_like(
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
-
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

View File

@ -82,7 +82,6 @@ def _dropout_helper(
def alpha_dropout(
self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False
) -> TensorLikeType:
if inplace:
raise NotImplementedError
@ -178,7 +177,6 @@ def celu(
def dropout(
a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False
) -> TensorLikeType:
if inplace:
raise NotImplementedError

View File

@ -295,7 +295,6 @@ class MetaConverter:
torch._C.DispatchKey.ADInplaceOrView, False
)
try:
if base.dtype == t.dtype:
pass
elif is_c_of_r(base.dtype, t.dtype):

View File

@ -180,6 +180,7 @@ def _rebuild_tensor_v2(
_sparse_tensors_to_validate: List["torch.Tensor"] = []
# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have been already unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are

View File

@ -103,7 +103,6 @@ class DistributedDataParallel(Module):
gradient_as_bucket_view=False,
static_graph=False,
):
super().__init__()
self.logger: Optional[dist.Logger] = None
if not any((p.requires_grad for p in module.parameters())):

View File

@ -1849,10 +1849,10 @@ class FlatParamHandle:
flat_param.grad = None
def _deregister_orig_params(self):
- for (param_name, module, _) in self.flat_param._param_infos:
+ for param_name, module, _ in self.flat_param._param_infos:
if hasattr(module, param_name):
delattr(module, param_name)
- for (param_name, module, _, _, _, _) in self.flat_param._shared_param_infos:
+ for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
if hasattr(module, param_name):
delattr(module, param_name)

View File

@ -123,9 +123,9 @@ _context = engine.background_context
@contextlib.contextmanager
- def create_export_diagnostic_context() -> Generator[
-     infra.DiagnosticContext, None, None
- ]:
+ def create_export_diagnostic_context() -> (
+     Generator[infra.DiagnosticContext, None, None]
+ ):
"""Create a diagnostic context for export.
This is a workaround for code robustness since diagnostic context is accessed by

View File

@ -30,7 +30,6 @@ def _export(
args,
**kwargs,
) -> Union["onnx.ModelProto", bytes]:
export_options = options.ExportOptions()
export_options.update(**kwargs)
# Apply decomposition table to the input graph.

View File

@ -149,9 +149,9 @@ _ATENLIB_FUNCTIONS = {
}
- def _create_op_overload_to_exporter_key_table() -> Dict[
-     Union[torch._ops.OpOverload, Callable], str
- ]:
+ def _create_op_overload_to_exporter_key_table() -> (
+     Dict[Union[torch._ops.OpOverload, Callable], str]
+ ):
# TODO(justinchuby): Improve how the table is constructed.
table: Dict[Union[torch._ops.OpOverload, Callable], str] = {}
@ -189,9 +189,9 @@ _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE = _create_op_overload_to_exporter_key_table()
@_beartype.beartype
- def _create_onnx_friendly_decomposition_table() -> Dict[
-     torch._ops.OpOverload, Callable
- ]:
+ def _create_onnx_friendly_decomposition_table() -> (
+     Dict[torch._ops.OpOverload, Callable]
+ ):
decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
for op_overload, decomp_fn in torch._decomp.decomposition_table.items():
# Skip decomposition into "prim::*" ops, because they are not generally supported by ONNX.

View File

@ -337,7 +337,6 @@ def _export_fx_node_to_onnxscript(
_validate_op_between_ort_torch(node, symbolic_fn, torch_args, torch_kwargs)
fx_name_to_onnxscipt_value[node.name] = output
elif node.op == "output":
if isinstance(node.args[0], torch.fx.Node):
onnx_tensor_or_tensor_tuple = fx_name_to_onnxscipt_value[node.args[0].name]
onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
@ -389,7 +388,6 @@ def _export_fx_node_to_onnxscript(
def export_fx_to_onnxscript(
fx_module_with_metadata: torch.fx.GraphModule, options: options.ExportOptions
):
# Initialize the ONNX graph
onnxscript_graph = graph_building.TorchScriptGraph()
tracer = graph_building.TorchScriptTracingEvaluator(onnxscript_graph)

View File

@ -253,7 +253,6 @@ def export_without_parameters_and_buffers(
Tuple[Any, ...],
Tuple[Any, ...],
]:
graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
module, *args, **kwargs
)

View File

@ -292,7 +292,6 @@ def _find_onnxscript_op(
def _convert_tensor_to_numpy(input: Any) -> Any:
try:
import numpy as np
except ImportError:

View File

@ -68,7 +68,6 @@ def batch_norm(
eps,
cudnn_enabled,
):
if (
torch.is_autocast_enabled()
and not symbolic_helper.args_have_same_dtype(

View File

@ -1324,7 +1324,6 @@ def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kw
if require_cast:
for input in inputs:
if input.isCompleteTensor():
input_scalar_type = _type_utils.JitScalarType.from_value(input)
if input_scalar_type != dtype_0:
@ -4484,7 +4483,6 @@ def _generic_rnn(
batch_first=None,
batch_sizes=None,
):
warnings.warn(
"Exporting a model to ONNX with a batch_size other than 1, "
+ "with a variable length with "

View File

@ -5,6 +5,7 @@ from typing import cast
import torch
from torch.types import Storage
# because get_storage_from_record returns a tensor!?
class _HasStorage:
def __init__(self, storage):

View File

@ -16,6 +16,7 @@ _zip_searchorder = (
(".py", False), (".py", False),
) )
# Replace any occurrences of '\r\n?' in the input string with '\n'. # Replace any occurrences of '\r\n?' in the input string with '\n'.
# This converts DOS and Mac line endings to Unix line endings. # This converts DOS and Mac line endings to Unix line endings.
def _normalize_line_endings(source): def _normalize_line_endings(source):

View File

@ -916,7 +916,6 @@ class PackageExporter:
def _persistent_id(self, obj):
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
storage: Storage
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can

View File

@ -615,6 +615,7 @@ class AliasInfo:
# the great majority of PyTorch's (public) operators.
#
# Classes and methods for the operator database
@dataclass
class OpInfo:
@ -1549,6 +1550,7 @@ def make_error_inputs_elementwise_binary(error_inputs_func):
# The following functions and classes are for testing elementwise binary operators.
# Returns a generator of pairs of contiguous tensors on the requested device
# and with the requested dtype.
#
@ -1997,7 +1999,6 @@ class BinaryUfuncInfo(OpInfo):
supports_two_python_scalars=False, # Whether the operator allows scalar x scalar inputs
**kwargs,
):
self._original_binary_ufunc_args = locals().copy()
# Elementwise binary operations perform the equivalent of test_numpy_refs
@ -2144,7 +2145,6 @@ def _filter_unary_elementwise_tensor(a, *, op):
def generate_elementwise_unary_tensors(op, *, device, dtype, requires_grad, **kwargs):
# Special-cases bool
if dtype is torch.bool:
tensors = (
@ -2491,7 +2491,6 @@ class SpectralFuncInfo(OpInfo):
decorators=None,
**kwargs,
):
self._original_spectral_func_args = dict(locals()).copy()
self._original_spectral_func_args.update(kwargs)

View File

@ -29,6 +29,7 @@ from torch.testing._internal.opinfo.core import (
)
from torch.testing._internal.opinfo.utils import prod_numpy, reference_reduction_numpy
# Used for log_softmax, softmax, softmin
def sample_inputs_softmax_variant(
op_info,

View File

@ -53,7 +53,6 @@ class SpectralFuncPythonRefInfo(SpectralFuncInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo = _find_referenced_opinfo(
torch_opinfo_name, torch_opinfo_variant, op_db=op_db

View File

@ -36,6 +36,7 @@ from torch.testing._internal.opinfo.utils import (
if TEST_SCIPY:
import scipy.special
# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`,
# supports `exclude` argument.
# For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617

View File

@ -103,7 +103,6 @@ class PythonRefInfo(OpInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo_variant_name = torch_opinfo_variant_name
self.torch_opinfo = _find_referenced_opinfo(
@ -134,7 +133,6 @@ class ReductionPythonRefInfo(ReductionOpInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo_variant_name = torch_opinfo_variant_name
self.torch_opinfo = _find_referenced_opinfo(
@ -169,7 +167,6 @@ class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo_variant_name = torch_opinfo_variant_name
self.torch_opinfo = _find_referenced_opinfo(
@ -201,7 +198,6 @@ class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo_variant_name = torch_opinfo_variant_name
self.torch_opinfo = _find_referenced_opinfo(

View File

@ -1,5 +1,6 @@
import sympy
# The normal Python interpretation of the operators
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works

View File

@ -313,6 +313,7 @@ JIT_TO_CPP_DEFAULT = {
"long": "at::kLong", "long": "at::kLong",
} }
# Convert a JIT default into C++ expression representing the default # Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type, *, symint: bool) -> str: def default_expr(d: str, t: Type, *, symint: bool) -> str:
if d == "None" and str(t) == "Tensor?": if d == "None" and str(t) == "Tensor?":

View File

@ -69,6 +69,7 @@ reapply_views_binding = Binding(
default=None,
)
# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
def name(

Some files were not shown because too many files have changed in this diff.