[BE]: Update ruff to 0.11.8 (#153249)

Fixes a ton of false negatives throughout the codebase. Ruff also properly validates noqa comments now, and most of the changes here either fix typos in those comments or remove file-wide flake8 suppressions that were also silencing ruff issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153249
Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/seemethere

parent 5c3fddb9cc
commit 3555ebb63d
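
For context on what the changes below look like: ruff treats a bare `# flake8: noqa` as a file-level suppression of every rule, its own included, while a code-scoped directive only exempts the named rules. A minimal sketch of the three forms involved (hypothetical lines, not code from this PR):

    # flake8: noqa           # blanket: suppresses flake8 AND every ruff rule file-wide
    # flake8: noqa: B950     # scoped: no longer treated by ruff as a file-wide blanket
    x = undefined_name  # noqa: F821  # line-level: ruff 0.11 validates the code here

Removing or scoping the blanket form is what surfaces the previously silenced ruff findings that most hunks in this commit address.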
--- a/.flake8
+++ b/.flake8
@@ -19,6 +19,8 @@ ignore =
     G100,G101,G200
     # these ignores are from flake8-simplify. please fix or ignore with commented reason
     SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
+    # SIM104 is already covered by pyupgrade ruff
+    SIM104,
     # flake8-simplify code styles
     SIM102,SIM103,SIM106,SIM112,
     # TorchFix codes that don't make sense for PyTorch itself:
--- a/.github/scripts/filter_test_configs.py
+++ b/.github/scripts/filter_test_configs.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# ruff: noqa: LOG015
 
 import json
 import logging
@@ -1456,7 +1456,7 @@ init_command = [
    'black==23.12.1',
    'usort==1.0.8.post1',
    'isort==5.13.2',
-    'ruff==0.9.8',  # sync with RUFF
+    'ruff==0.11.8',  # sync with RUFF
 ]
 is_formatter = true

@@ -1542,7 +1542,7 @@ init_command = [
    'python3',
    'tools/linter/adapters/pip_init.py',
    '--dry-run={{DRYRUN}}',
-    'ruff==0.9.8',  # sync with PYFMT
+    'ruff==0.11.8',  # sync with PYFMT
 ]
 is_formatter = true
@@ -1379,7 +1379,7 @@ def _produce_dynamic_shapes_for_export(path, x):
 
     if not isinstance(x, torch.Tensor):
         return None
-    return {i: Dim.AUTO for i in getattr(x, "_dynamo_dynamic_indices", {})}
+    return dict.fromkeys(getattr(x, "_dynamo_dynamic_indices", {}), Dim.AUTO)
 
 
 class AOTInductorModelCache:
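
The `dict.fromkeys` rewrite in this hunk (and in several later ones) is the comprehension-to-constructor cleanup the updated ruff suggests when every key maps to the same constant. A minimal sketch of the equivalence:

    indices = [0, 2, 5]

    a = {i: "AUTO" for i in indices}    # flagged: constant-valued dict comprehension
    b = dict.fromkeys(indices, "AUTO")  # preferred: same mapping, clearer intent
    assert a == b

One caveat worth knowing: `dict.fromkeys` binds the same object to every key, so the rewrite is only a drop-in when that value is immutable or, as with `Dim.AUTO` here, meant to be shared.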
@@ -1671,7 +1671,7 @@ def maybe_snapshot_memory(should_snapshot_memory, suffix):
             )
         )
     except Exception as e:
-        logging.error("Failed to save memory snapshot, %s", e)
+        log.error("Failed to save memory snapshot, %s", e)
 
     torch.cuda.memory._record_memory_history(enabled=None)
 
@@ -2687,7 +2687,7 @@ class BenchmarkRunner:
         experiment,
         tag,
     ):
-        logging.info("Minifying %s...", name)
+        log.info("Minifying %s...", name)
         os.environ["TORCH_COMPILE_DEBUG"] = "1"
         os.environ["TORCHDYNAMO_REPRO_AFTER"] = "dynamo"
         os.environ["TORCHDYNAMO_REPRO_LEVEL"] = "4"

@@ -2702,9 +2702,9 @@ class BenchmarkRunner:
         try:
             shutil.move("repro.py", f"{repro_dir}/{name}_repro.py")
         except OSError:
-            logging.error("Could not find repro script for model %s", name)
+            log.error("Could not find repro script for model %s", name)
         else:
-            logging.info(
+            log.info(
                 "Repro script for model %s with minified graph saved to %s",
                 name,
                 repro_dir,
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# flake8: noqa: F821
 
 import importlib
 import logging

@@ -48,7 +49,6 @@ def pip_install(package):
 
 # Disable the flake warnings for the imports. Flake8 does not provide a way to
 # disable just warning for the entire file. Disabling flake8 entirely.
-# flake8: noqa
 imports = [
     "AlbertForPreTraining",
     "AutoConfig",

@@ -111,7 +111,7 @@ BATCH_SIZE_KNOWN_MODELS = {}
 # Get the list of models and their batch sizes
 MODELS_FILENAME = os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt")
 assert os.path.exists(MODELS_FILENAME)
-with open(MODELS_FILENAME, "r") as fh:
+with open(MODELS_FILENAME) as fh:
     lines = fh.readlines()
     lines = [line.rstrip() for line in lines]
     for line in lines:
@@ -166,7 +166,7 @@ def get_sequence_length(model_cls, model_name):
         seq_length = 10000  # NB: a more realistic size is 155136
     else:
         log.info(
-            f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily"
+            f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily"  # noqa: G004
         )
         seq_length = 128
     return seq_length
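
The `# noqa: G004` suppressions added in this file acknowledge flake8-logging-format's rule against f-strings in logging calls: an f-string is rendered even when the log level is disabled, whereas %-style arguments are only formatted if the record is actually emitted. A sketch of the two forms (assuming a configured module logger):

    import logging

    log = logging.getLogger(__name__)
    model_name = "BertForMaskedLM"

    log.info(f"Sequence Length not defined for {model_name}")   # G004: formatted eagerly
    log.info("Sequence Length not defined for %s", model_name)  # lazy: formatted only if emitted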
@@ -204,22 +204,16 @@ def generate_inputs_for_model(
 
     input_dict = {"input_ids": input}
 
-    if (
-        model_name.startswith("T5")
-        or model_name.startswith("M2M100")
-        or model_name.startswith("MT5")
-        or model_cls
-        in [
-            BlenderbotModel,
-            BlenderbotSmallModel,
-            BlenderbotForConditionalGeneration,
-            BlenderbotSmallForConditionalGeneration,
-            PegasusModel,
-            PegasusForConditionalGeneration,
-            MarianModel,
-            MarianMTModel,
-        ]
-    ):
+    if model_name.startswith(("T5", "M2M100", "MT5")) or model_cls in [
+        BlenderbotModel,
+        BlenderbotSmallModel,
+        BlenderbotForConditionalGeneration,
+        BlenderbotSmallForConditionalGeneration,
+        PegasusModel,
+        PegasusForConditionalGeneration,
+        MarianModel,
+        MarianMTModel,
+    ]:
         input_dict["decoder_input_ids"] = input
 
     if model_name.startswith("Lxmert"):
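
The collapsed condition above (and the `endswith` one in the next hunk) leans on the fact that `str.startswith` and `str.endswith` accept a tuple of affixes, so a chain of `or` clauses becomes a single call:

    name = "MT5Model"

    old = name.startswith("T5") or name.startswith("M2M100") or name.startswith("MT5")
    new = name.startswith(("T5", "M2M100", "MT5"))
    assert old == new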
@@ -251,11 +245,8 @@ def generate_inputs_for_model(
             device, 0, seq_length, (bs,)
         )
         input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,))
-    elif (
-        model_name.endswith("MaskedLM")
-        or model_name.endswith("HeadModel")
-        or model_name.endswith("CausalLM")
-        or model_name.endswith("DoubleHeadsModel")
+    elif model_name.endswith(
+        ("MaskedLM", "HeadModel", "CausalLM", "DoubleHeadsModel")
     ):
         input_dict["labels"] = rand_int_tensor(
             device, 0, vocab_size, (bs, seq_length)

@@ -429,7 +420,7 @@ class HuggingfaceRunner(BenchmarkRunner):
         elif batch_size is None:
             batch_size_default = 16
             log.info(
-                f"Batch size not specified for {model_name}. Setting batch_size=16"
+                f"Batch size not specified for {model_name}. Setting batch_size=16"  # noqa: G004
             )
 
         if batch_size is None:

@@ -438,7 +429,7 @@ class HuggingfaceRunner(BenchmarkRunner):
             if model_name in batch_size_divisors:
                 batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1)
                 log.info(
-                    f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
+                    f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"  # noqa: G004
                 )
 
         example_inputs = generate_inputs_for_model(

@@ -474,8 +465,8 @@ class HuggingfaceRunner(BenchmarkRunner):
             if index < start or index >= end:
                 continue
             if (
-                not re.search("|".join(args.filter), model_name, re.I)
-                or re.search("|".join(args.exclude), model_name, re.I)
+                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
+                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                 or model_name in args.exclude_exact
                 or model_name in self.skip_models
             ):

@@ -621,7 +612,7 @@ def refresh_model_names_and_batch_sizes():
                 + [f"--output={MODELS_FILENAME}"]
             )
         except subprocess.SubprocessError:
-            log.warning(f"Failed to find suitable batch size for {model_name}")
+            log.warning(f"Failed to find suitable batch size for {model_name}")  # noqa: G004
 
 
 def huggingface_main():
@@ -1,6 +1,5 @@
-# flake8: noqa
+# flake8: noqa: B902
 
-import triton
 from prettytable import PrettyTable
 
 import torch

@@ -18,7 +17,7 @@ torch.manual_seed(0)
 torch.backends.cuda.matmul.allow_tf32 = True
 
 
-class Func(object):
+class Func:
     # mm
     @torch._dynamo.optimize("inductor")
     def mm(a, b, bias):
@@ -45,7 +44,9 @@ class Func(object):
         return torch.relu(y)
 
 
-def bench(shape, layer_id, p, fusion_types=[""]):
+def bench(shape, layer_id, p, fusion_types=None):
+    if fusion_types is None:
+        fusion_types = [""]
     dtype = torch.float16
     M, K = shape[0]
     _, N = shape[1]
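
The `fusion_types=None` change is the standard fix for a mutable default argument: a default list is built once, at function definition time, and shared by every call, so the convention is a `None` sentinel replaced in the body. A minimal sketch of why:

    def bench(fusion_types=None):
        # A fresh list per call; a fusion_types=[""] default would be a single
        # list object reused across calls, so in-place mutations would leak.
        if fusion_types is None:
            fusion_types = [""]
        return fusion_types

    assert bench() is not bench()  # each call gets its own list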
@@ -60,7 +61,7 @@ def bench(shape, layer_id, p, fusion_types=[""]):
    row = [layer_id]
    for fusion_type in fusion_types:
        if fusion_type == "":
-            fn_mm = getattr(Func, "mm")
+            fn_mm = Func.mm
        else:
            fn_mm = getattr(Func, f"mm_{fusion_type}")
@@ -1450,7 +1450,7 @@ class DashboardUpdater:
        try:
            RegressionTracker(self.args).diff()
        except Exception:
-            logging.exception("")
+            log.exception("")
            with open(f"{self.args.output_dir}/gh_regression.txt", "w") as gh_fh:
                gh_fh.write("")
 
@@ -8,6 +8,9 @@ import pandas as pd
 from tabulate import tabulate
 
 
+log = logging.getLogger(__name__)
+
+
 def gmean(s):
     return s.product() ** (1 / len(s))
 
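
The `log = logging.getLogger(__name__)` additions throughout this commit are driven by ruff's new LOG015 rule: module-level `logging.info(...)` and friends log through the root logger, which cannot be filtered or configured per module. The replacement pattern, as a sketch:

    import logging

    log = logging.getLogger(__name__)  # named logger, configurable per module


    def parse(path):
        try:
            return open(path).read()
        except OSError:
            log.warning("failed parsing %s", path)  # was: logging.warning(...)
            raise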
@@ -67,7 +70,7 @@ def main(directory, amp, float32, perf_compare):
        try:
            dfs[os.path.basename(f)].append(pd.read_csv(f))
        except Exception:
-            logging.warning("failed parsing %s", f)
+            log.warning("failed parsing %s", f)
            raise
 
    # dtype -> statistic -> benchmark -> compiler -> value
@@ -43,7 +43,7 @@ def torchao_optimize_ctx(quantization: str):
            from torchao.quantization.autoquant import AUTOQUANT_CACHE
 
            if len(AUTOQUANT_CACHE) == 0:
-                raise Exception(  # noqa: TRY002`
+                raise Exception(  # noqa: TRY002
                    "NotAutoquantizable"
                    f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
                )
@@ -8,6 +8,8 @@ import pandas as pd
 from torch._functorch.benchmark_utils import compute_utilization
 
 
+log = logging.getLogger(__name__)
+
 # process the chrome traces output by the pytorch profiler
 # require the json input file's name to be in format {model_name}_chrome_trace_*.json
 # the runtimes file should have format (model_name, runtime)

@@ -65,7 +67,7 @@ def main():
            )
            print(f"{modelname}, {utilization}, {mm_conv_utilization}")
        except BaseException:
-            logging.exception("%s, ERROR", filename)
+            log.exception("%s, ERROR", filename)
            print(f"{filename}, ERROR")
 
 
@@ -73,6 +73,23 @@ quote-style = "double"
 
 [tool.ruff.lint]
 # NOTE: Synchoronize the ignores with .flake8
+external = [
+    "B001",
+    "B902",
+    "B950",
+    "E121",
+    "E122",
+    "E128",
+    "E131",
+    "E704",
+    "E723",
+    "F723",
+    "F812",
+    "P201",
+    "P204",
+    "T484",
+    "TOR901",
+]
 ignore = [
     # these ignores are from flake8-bugbear; please fix!
     "B007", "B008", "B017",
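
The new `external` list declares codes that appear in `# noqa` comments but belong to other tools (flake8-bugbear's B950, TorchFix's TOR901, and so on). Without it, ruff's noqa validation (RUF100) would report those suppressions as unused, since the codes are unknown to ruff, and an autofix pass could strip them and re-break the flake8 run. A hypothetical line it protects:

    message = "a deliberately long string literal ..........................................."  # noqa: B950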
@@ -108,6 +125,8 @@ ignore = [
    "SIM117",
    "SIM118",
    "UP007",  # keep-runtime-typing
+    "TC006",
+    "TC007",
 ]
 select = [
    "B",

@@ -173,7 +192,7 @@ select = [
    "RUF030",  # No print statement in assert
    "S324",  # for hashlib FIPS compliance
    "SLOT",
-    "TCH",
+    "TC",
    "TRY002",  # ban vanilla raise (todo fix NOQAs)
    "TRY203",
    "TRY401",  # verbose-log-message

@@ -187,6 +206,12 @@ select = [
 "functorch/notebooks/**" = [
    "F401",
 ]
+"test/export/**" = [
+    "PGH004"
+]
+"test/typing/**" = [
+    "PGH004"
+]
 "test/typing/reveal/**" = [
    "F821",
 ]

@@ -200,6 +225,9 @@ select = [
 "test/dynamo/test_debug_utils.py" = [
    "UP037",
 ]
+"test/dynamo/test_misc.py" = [
+    "PGH004",
+]
 "test/jit/**" = [
    "PLR0133",  # tests require this for JIT
    "PYI",

@@ -212,12 +240,20 @@ select = [
    "RUF015",
    "UP",  # We don't want to modify the jit test as they test specify syntax
 ]
+"test/inductor/s429861_repro.py" = [
+    "PGH004",
+]
 "test/inductor/test_torchinductor.py" = [
    "UP037",
 ]
 # autogenerated #TODO figure out why file level noqa is ignored
+"torch/_appdirs.py" = ["PGH004"]
+"torch/jit/_shape_functions.py" = ["PGH004"]
 "torch/_inductor/fx_passes/serialized_patterns/**" = ["F401", "F501"]
 "torch/_inductor/autoheuristic/artifacts/**" = ["F401", "F501"]
+"torch/_inductor/codegen/**" = [
+    "PGH004"
+]
 "torchgen/api/types/__init__.py" = [
    "F401",
    "F403",

@@ -232,3 +268,6 @@ select = [
 "torch/_vendor/**" = [
    "UP",  # No need to mess with _vendor
 ]
+"tools/linter/**" = [
+    "LOG015"  # please fix
+]
@@ -24,9 +24,15 @@ from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, T
 
 # TODO: Once more test files are created, move the contents to a ao folder.
 
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
-)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+handler = logging.StreamHandler()
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+handler.setFormatter(formatter)
+
+logger.addHandler(handler)
+logger.propagate = False  # Prevent duplicate logs if root logger also has handlers
 
 
 class TestQuantizedSparseKernels(TestCase):

@@ -78,10 +84,10 @@ class TestQuantizedSparseKernels(TestCase):
 
        for use_channelwise, dynamic_mode in product([True, False], [True, False]):
            if qengine_is_fbgemm() and dynamic_mode:
-                logging.info("dynamic sparse qlinear is only available in qnnpack")
+                logger.info("dynamic sparse qlinear is only available in qnnpack")
                continue
            if qengine_is_qnnpack() and not dynamic_mode:
-                logging.info("static sparse qlinear is only available in fbgemm")
+                logger.info("static sparse qlinear is only available in fbgemm")
                continue
            if use_channelwise:
                W_q = torch.quantize_per_channel(
@@ -1,5 +1,5 @@
 # Owner(s): ["module: dynamo"]
-# flake8: noqa
+# flake8: noqa: B950
 
 import functools
 import itertools

@@ -13,7 +13,6 @@ from torch import _inductor as inductor
 from torch._dynamo import compiled_autograd
 from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped
 from torch._dynamo.testing import normalize_gm
-from torch._dynamo.utils import counters
 from torch.fx.experimental.proxy_tensor import make_fx
 
 
@@ -1,17 +1,13 @@
 # Owner(s): ["module: dynamo"]
 
 # ruff: noqa: TRY002
-# flake8: noqa
 
-import dataclasses
-import gc
 import itertools
 import types
 import unittest
 import weakref
 from collections import defaultdict, namedtuple, OrderedDict
-from dataclasses import dataclass, fields, is_dataclass
-from typing import Any, Optional, Tuple
+from typing import Any
 
 import torch
 import torch._dynamo.config

@@ -22,8 +18,6 @@ import torch.nn
 import torch.utils.checkpoint
 from torch._dynamo.testing import same
 from torch._dynamo.utils import dict_items
-from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import TestCase
 
 
 class SimpleDict(dict):

@@ -435,7 +429,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
        config = dotdict({"a": 1, "b": 2})
 
        def fn(x):
-            x2 = x * 2
+            x2 = x * 2  # noqa: F841
            x3 = x * config.get("a", 3)
            return x3
 

@@ -643,8 +637,8 @@ class DictTests(torch._dynamo.test_case.TestCase):
        ):
 
            class CustomDict(super_class):
-                def __new__(self, *args, **kwargs):
-                    return super().__new__(self, *args, **kwargs)
+                def __new__(cls, *args, **kwargs):
+                    return super().__new__(cls, *args, **kwargs)
 
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

@@ -806,7 +800,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
            d = {"a": 2, "b": 3, "c": 5 * x}
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
-            for k, v in mp.items():
+            for k, v in mp.items():  # noqa: PERF102
                y += torch.cos(x * v)
            return mp
 
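
The `# noqa: PERF102` markers in these dict tests suppress the incorrect-dict-iterator rule, which fires because the key `k` is never used; the usual fix is `.values()`, but here iterating `.items()` is the point of the test. What the rule normally asks for, as a sketch:

    mp = {"a": 2, "b": 3}

    total = 0
    for k, v in mp.items():  # PERF102: key is unpacked but never used
        total += v

    for v in mp.values():  # preferred when only the values matter
        total += v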
@@ -823,7 +817,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
        def fn(x):
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
-            for k, v in mp.items():
+            for k, v in mp.items():  # noqa: PERF102
                y += torch.cos(x * v)
            d["d"] = 4
            return mp

@@ -844,7 +838,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
 
        def fn(x, mp):
            y = torch.sin(x * mp["a"])
-            for k, v in mp.items():
+            for k, v in mp.items():  # noqa: PERF102
                y += torch.cos(x * v)
            return y
 

@@ -939,7 +933,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
 
    def test_items_type(self):
        def fn():
-            d = dict({"a": 1, "b": "2", "c": torch.tensor(3)})
+            d = dict({"a": 1, "b": "2", "c": torch.tensor(3)})  # noqa: C418
            return d.items()
 
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
@@ -1,5 +1,5 @@
 # Owner(s): ["module: dynamo"]
-# flake8: noqa
+# flake8: noqa: B950
 import torch
 import torch._dynamo
 import torch._dynamo.test_case
@@ -10,7 +10,16 @@ from torch._C import parse_schema, Tag
 
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
-logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+log = logging.getLogger("log")
+log.setLevel(logging.INFO)
+
+handler = logging.StreamHandler()
+formatter = logging.Formatter(FORMAT)
+handler.setFormatter(formatter)
+
+log.addHandler(handler)
+log.propagate = False  # Avoid double logging if root logger has handlers
 
 # How to run this test locally:
 # 1 Have two virtual environments (eg conda env), one without PyTorch installed (venv_nightly)

@@ -259,10 +268,10 @@ def check_bc(existing_schemas):
        is_allow_list, trust_not_core_aten = allow_listed(existing_schema)
        if is_allow_list:
            if trust_not_core_aten or not is_core_aten_op(existing_schema):
-                logging.info("schema: %s found on allowlist, skipping", existing_schema)
+                log.info("schema: %s found on allowlist, skipping", existing_schema)
                continue
            else:
-                logging.info(
+                log.info(
                    "schema: %s found on allowlist, but is a core ATen op, checking BC. "
                    "NOTE: If you have removed an operator we will conservatively assume that "
                    "it is a core ATen op. If the operator you removed is not a core ATen op, "

@@ -272,13 +281,13 @@ def check_bc(existing_schemas):
                )
        if has_valid_upgraders(existing_schema, version_map):
            if not is_core_aten_op(existing_schema):
-                logging.info("schema: %s has valid upgrader, skipping", existing_schema)
+                log.info("schema: %s has valid upgrader, skipping", existing_schema)
                continue
            else:
-                logging.info(
+                log.info(
                    "schema: %s has a valid upgrader, but is a core ATen op, checking BC"
                )
-        logging.debug("processing existing schema: %s", existing_schema)
+        log.debug("processing existing schema: %s", existing_schema)
        matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
        found = False
        for matching_new_schema in matching_new_schemas:

@@ -286,7 +295,7 @@ def check_bc(existing_schemas):
                found = True
                break
        if not found:
-            logging.warning(
+            log.warning(
                "Can NOT find backward compatible schemas after changes "
                "for schema %s from the following candidates:\n[\n%s\n]",
                str(existing_schema),

@@ -296,9 +305,9 @@ def check_bc(existing_schemas):
            broken_ops.append(str(existing_schema))
            is_bc = False
    if is_bc:
-        logging.info("Found backward compatible schemas for all existing schemas")
+        log.info("Found backward compatible schemas for all existing schemas")
    else:
-        logging.warning(
+        log.warning(
            "The PR is introducing backward incompatible changes to the "
            "operator library. Please contact PyTorch team to confirm "
            "whether this change is wanted or not. \n\nBroken ops: "

@@ -315,9 +324,9 @@ def check_fc(existing_schemas):
    for existing_schema in existing_schemas:
        is_allow_list, _ = allow_listed(existing_schema)
        if is_allow_list:
-            logging.info("schema: %s found on allowlist, skipping", existing_schema)
+            log.info("schema: %s found on allowlist, skipping", existing_schema)
            continue
-        logging.info("processing existing schema: %s", existing_schema)
+        log.info("processing existing schema: %s", existing_schema)
        matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
        found = False
        possible_failure_reasons = []

@@ -331,13 +340,13 @@ def check_fc(existing_schemas):
            if reason != "":
                possible_failure_reasons.append(reason)
        if not found:
-            logging.warning(
+            log.warning(
                "Can NOT find forward compatible schemas after changes "
                "for schema %s from the following candidates:\n[\n\t%s\n]",
                str(existing_schema),
                "\n\t".join(str(s) for s in matching_new_schemas),
            )
-            logging.warning(
+            log.warning(
                "Refer to following reasons for failure "
                "to find FC schema:\n[\n%s\n]",
                "\n\t".join(str(r) for r in possible_failure_reasons),

@@ -345,9 +354,9 @@ def check_fc(existing_schemas):
            broken_ops.append(str(existing_schema))
            is_fc = False
    if is_fc:
-        logging.info("Found forward compatible schemas for all existing schemas")
+        log.info("Found forward compatible schemas for all existing schemas")
    else:
-        logging.warning(
+        log.warning(
            "The PR is introducing a potentially forward incompatible changes to the "
            "operator library. Please contact PyTorch team to confirm "
            "whether this change is wanted or not. \n\nBroken ops: "

@@ -374,7 +383,7 @@ if __name__ == "__main__":
            break
 
        if dont_parse(line.strip()):
-            logging.info("Not parsing schema line: %s", line.strip())
+            log.info("Not parsing schema line: %s", line.strip())
            continue
        s = parse_schema(line.strip())
        slist.append(s)
@@ -1,6 +1,5 @@
 # Owner(s): ["module: inductor"]
 # ruff: noqa: F841
-# flake8: noqa
 import collections
 import collections.abc
 import copy

@@ -296,7 +295,7 @@ class TestJointOps(TestCase):
        self.s.z = ["z"]
        p = pickle.dumps(self.s, i)
        dup = pickle.loads(p)
-        self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
+        self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))  # noqa: UP031
        if type(self.s) not in (OrderedSet, frozenset):
            self.assertEqual(self.s.x, dup.x)
            self.assertEqual(self.s.z, dup.z)
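
UP031 is pyupgrade's printf-string-formatting rule; the suppressions added through the rest of this file keep the `%`-interpolation these set tests already use. The modernization the rule would otherwise push toward looks like:

    a, b = "left", "right"

    old = "%s != %s" % (a, b)  # UP031: printf-style interpolation
    new = f"{a} != {b}"        # one modern equivalent
    assert old == new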
@@ -390,7 +389,7 @@ class TestJointOps(TestCase):
            self.assertEqual(repr(s), "{OrderedSet(...)}")
        else:
            name = repr(s).partition("(")[0]  # strip class name
-            self.assertEqual(repr(s), "%s({%s(...)})" % (name, name))
+            self.assertEqual(repr(s), "%s({%s(...)})" % (name, name))  # noqa: UP031
 
    @unittest.skip("Different hashing")
    def test_do_not_rehash_dict_keys(self):

@@ -454,7 +453,7 @@ class TestSet(TestJointOps, TestCase):
 
    def test_set_literal_insertion_order(self):
        # SF Issue #26020 -- Expect left to right insertion
-        s = {1, 1.0, True}
+        s = {1, 1.0, True}  # noqa: B033
        self.assertEqual(len(s), 1)
        stored_value = s.pop()
        self.assertEqual(type(stored_value), int)

@@ -715,19 +714,19 @@ class TestSet(TestJointOps, TestCase):
        myset = {1, 2, 3}
 
        myobj = TestRichSetCompare()
-        myset < myobj
+        myset < myobj  # noqa: B015
        self.assertTrue(myobj.gt_called)
 
        myobj = TestRichSetCompare()
-        myset > myobj
+        myset > myobj  # noqa: B015
        self.assertTrue(myobj.lt_called)
 
        myobj = TestRichSetCompare()
-        myset <= myobj
+        myset <= myobj  # noqa: B015
        self.assertTrue(myobj.ge_called)
 
        myobj = TestRichSetCompare()
-        myset >= myobj
+        myset >= myobj  # noqa: B015
        self.assertTrue(myobj.le_called)
 
 

@@ -834,7 +833,9 @@ class TestBasicOps(TestCase):
            p = pickle.dumps(self.OrderedSet, proto)
            copy = pickle.loads(p)
            self.assertEqual(
-                self.OrderedSet, copy, "%s != %s" % (self.OrderedSet, copy)
+                self.OrderedSet,
+                copy,
+                "%s != %s" % (self.OrderedSet, copy),  # noqa: UP031
            )
 
    def test_issue_37219(self):

@@ -1195,7 +1196,7 @@ class TestMutate(TestCase):
        expected_len = 0
        for v in self.values:
            tmp.add(v)
-            expected_len += 1
+            expected_len += 1  # noqa: SIM113
        self.assertEqual(len(tmp), expected_len)
        self.assertEqual(tmp, self.OrderedSet)
 

@@ -1518,7 +1519,7 @@ class TestOnlySetsString(TestOnlySetsInBinaryOps, TestCase):
 class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, TestCase):
    def setUp(self):
        def gen():
-            for i in range(0, 10, 2):
+            for i in range(0, 10, 2):  # noqa: UP028
                yield i
 
        self.OrderedSet = OrderedSet((1, 2, 3))

@@ -1541,7 +1542,7 @@ class TestCopying:
 
    def test_deep_copy(self):
        dup = copy.deepcopy(self.OrderedSet)
-        ##print type(dup), repr(dup)
+        # print type(dup), repr(dup)
        dup_list = sorted(dup, key=repr)
        set_list = sorted(self.OrderedSet, key=repr)
        self.assertEqual(len(dup_list), len(set_list))

@@ -1641,7 +1642,7 @@ class TestIdentities(TestCase):
 
 def R(seqn):
    "Regular generator"
-    for i in seqn:
+    for i in seqn:  # noqa: UP028
        yield i
 
 

@@ -1655,7 +1656,7 @@ class G:
        return self.seqn[i]
 
 
-class I:
+class I:  # noqa: E742
    "Sequence using iterator protocol"
 
    def __init__(self, seqn):

@@ -1681,7 +1682,7 @@ class Ig:
        self.i = 0
 
    def __iter__(self):
-        for val in self.seqn:
+        for val in self.seqn:  # noqa: UP028
            yield val
 
 

@@ -1743,7 +1744,7 @@ from itertools import chain
 
 def L(seqn):
    "Test multiple tiers of iterators"
-    return chain(map(lambda x: x, R(Ig(G(seqn)))))
+    return chain(map(lambda x: x, R(Ig(G(seqn)))))  # noqa: C417
 
 
 class TestVariousIteratorArgs(TestCase):

@@ -1909,7 +1910,7 @@ def powerset(U):
 def cube(n):
    """Graph of n-dimensional hypercube."""
    singletons = [frozenset([x]) for x in range(n)]
-    return dict(
+    return dict(  # noqa: C404
        [(x, frozenset([x ^ s for s in singletons])) for x in powerset(range(n))]
    )

@@ -1946,7 +1947,7 @@ def faces(G):
                f.add(frozenset([v1, v2, v3, v4]))
            else:
                for v5 in G[v4]:
-                    if v5 == v3 or v5 == v2:
+                    if v5 == v3 or v5 == v2:  # noqa: SIM109
                        continue
                    if v1 in G[v5]:
                        f.add(frozenset([v1, v2, v3, v4, v5]))
@@ -1,8 +1,5 @@
 # Owner(s): ["oncall: jit"]
-# flake8: noqa
 
-import sys
-import unittest
 from dataclasses import dataclass, field, InitVar
 from enum import Enum
 from typing import List, Optional
@@ -1,10 +1,5 @@
 # Owner(s): ["oncall: jit"]
-# flake8: noqa
 
-import sys
-import unittest
-from enum import Enum
-from typing import List, Optional
 
 import torch
 from jit.myfunction_a import my_function_a
@@ -1,4 +1,5 @@
 # Owner(s): ["module: onnx"]
+# flake8: noqa: B950
 """Test op correctness by comparing with PyTorch results.
 
 ## Usage

@@ -32,14 +33,13 @@ wrangler function. See `_mean_input_wrangler` for an example.
 op, use `ops_test_common.duplicate_opinfo` to create new OpInfo with new names and map each
 to one overload.
 """
-# flake8: noqa
 
 from __future__ import annotations
 
 import copy
 import dataclasses
 import functools
-from typing import Any, Callable, Collection, Optional
+from typing import Any, Callable, Optional, TYPE_CHECKING
 from typing_extensions import Self
 
 import numpy as np

@@ -51,6 +51,10 @@ from torch.testing._internal import common_methods_invocations
 from torch.testing._internal.opinfo import definitions as opinfo_definitions
 
 
+if TYPE_CHECKING:
+    from collections.abc import Collection
+
+
 # Create a copy of the op_db to modify
 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
 
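
The `TYPE_CHECKING` block added above is what the flake8-type-checking rules (newly selected under their renamed `TC` prefix in pyproject.toml) ask for: imports used only in annotations move behind the guard, so type checkers see them while runtime skips them. The shape of the pattern, as a self-contained sketch:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by static type checkers only, never at runtime.
        from collections.abc import Collection


    def total(values: Collection[int]) -> int:
        # With `from __future__ import annotations`, the hint stays a string
        # at runtime, so the guarded import is never actually needed here.
        return sum(values)

    print(total([1, 2, 3]))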
@@ -32,7 +32,7 @@ class TestGraphUtils(TestCase):
        example_inputs = (torch.randn(1, 3, 5, 5),)
 
        # program capture
-        m, guards = torchdynamo.export(  # noqa: F841©
+        m, guards = torchdynamo.export(  # noqa: F841
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
@@ -6,9 +8,17 @@ import torch
 import torch.distributed as c10d
 
 
-logging.basicConfig(
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
-)
+FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+log = logging.getLogger("log")
+log.setLevel(logging.INFO)
+
+handler = logging.StreamHandler()
+formatter = logging.Formatter(FORMAT)
+handler.setFormatter(formatter)
+
+log.addHandler(handler)
+log.propagate = False  # Prevent log duplication
 
 if __name__ == "__main__":
    parser = argparse.ArgumentParser(

@@ -29,14 +37,14 @@ if __name__ == "__main__":
 
    store = c10d.TCPStore(args.addr, port, world_size, rank == 0)
    process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
-    logging.info("Running first allreduce")
+    log.info("Running first allreduce")
    process_group.allreduce(torch.rand(10).cuda(rank)).wait()
    if rank == 0:
-        logging.info("Running second allreduce only on rank 0")
+        log.info("Running second allreduce only on rank 0")
        work = process_group.allreduce(torch.rand(10).cuda(rank))
-        logging.info("Waiting for allreduce to complete...")
+        log.info("Waiting for allreduce to complete...")
        work.wait()
-        logging.info("Second allreduce successful: %s", work.is_success())
+        log.info("Second allreduce successful: %s", work.is_success())
    else:
-        logging.info("Aborting all other ranks.")
+        log.info("Aborting all other ranks.")
        os.abort()
@@ -181,7 +181,7 @@ class Pair(NamedTuple):
 
 
 # for testing pytrees
-class Foo:  # noqa: B209
+class Foo:
    def __init__(self, a, b):
        self.a = a
        self.b = b
@@ -38,12 +38,13 @@ from quantization.core.test_workflow_module import TestDistributed  # noqa: F401
 from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule  # noqa: F401
 from quantization.core.test_backend_config import TestBackendConfig  # noqa: F401
 from quantization.core.test_utils import TestUtils  # noqa: F401
+log = logging.getLogger(__name__)
 try:
    # This test has extra data dependencies, so in some environments, e.g. Meta internal
    # Buck, it has its own test runner.
    from quantization.core.test_docs import TestQuantizationDocs  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 
 # Eager Mode Workflow. Tests for the functionality of APIs and different features implemented
 # using eager mode.

@@ -77,7 +78,7 @@ try:
 except ImportError as e:
    # In FBCode we separate FX out into a separate target for the sake of dev
    # velocity. These are covered by a separate test target `quantization_fx`
-    logging.warning(e)
+    log.warning(e)
 
 # PyTorch 2 Export Quantization
 try:

@@ -99,7 +100,7 @@ try:
 except ImportError as e:
    # In FBCode we separate PT2 out into a separate target for the sake of dev
    # velocity. These are covered by a separate test target `quantization_pt2e`
-    logging.warning(e)
+    log.warning(e)
 
 try:
    from quantization.fx.test_numeric_suite_fx import TestFXGraphMatcher  # noqa: F401

@@ -108,7 +109,7 @@ try:
    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteNShadows  # noqa: F401
    from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIsModels  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 
 # Test the model report module
 try:

@@ -120,19 +121,19 @@ try:
    from quantization.fx.test_model_report_fx import TestFxDetectOutliers  # noqa: F401
    from quantization.fx.test_model_report_fx import TestFxModelReportVisualizer  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 
 # Equalization for FX mode
 try:
    from quantization.fx.test_equalize_fx import TestEqualizeFx  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 
 # Backward Compatibility. Tests serialization and BC for quantized modules.
 try:
    from quantization.bc.test_backward_compatibility import TestSerialization  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 
 # JIT Graph Mode Quantization
 from quantization.jit.test_quantize_jit import TestQuantizeJit  # noqa: F401

@@ -151,29 +152,29 @@ from quantization.ao_migration.test_ao_migration import TestAOMigrationNNIntrins
 try:
    from quantization.ao_migration.test_quantization_fx import TestAOMigrationQuantizationFx  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 
 # Experimental functionality
 try:
    from quantization.core.experimental.test_bits import TestBitsCPU  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 try:
    from quantization.core.experimental.test_bits import TestBitsCUDA  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 try:
    from quantization.core.experimental.test_floatx import TestFloat8DtypeCPU  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 try:
    from quantization.core.experimental.test_floatx import TestFloat8DtypeCUDA  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 try:
    from quantization.core.experimental.test_floatx import TestFloat8DtypeCPUOnlyCPU  # noqa: F401
 except ImportError as e:
-    logging.warning(e)
+    log.warning(e)
 
 if __name__ == '__main__':
    run_tests()
@@ -595,7 +595,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
        print(f"log file: {log_file}")
        yield root_logger
    except Exception as e:
-        logging.exception("Fatal exception")
+        logging.exception("Fatal exception")  # noqa: LOG015
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)

@@ -603,7 +603,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
        # You could logging.debug here to suppress the backtrace
        # entirely, but there is no reason to hide it from technically
        # savvy users.
-        logging.info("", exc_info=True)
+        logging.info("", exc_info=True)  # noqa: LOG015
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
@@ -112,7 +112,7 @@ except ModuleNotFoundError:
 try:
    import torch._logging
    import torch._numpy as tnp
-    from torch._guards import detect_fake_mode  # noqa: F401n
+    from torch._guards import detect_fake_mode  # noqa: F401
    from torch._logging import LazyString
 
    from . import config
@@ -939,7 +939,7 @@ class HalideKernel(SIMDKernel):
 
        # group the expression by variables used
        offset = sympy.S.Zero
-        split_expr = {s: sympy.S.Zero for s in symbols}
+        split_expr = dict.fromkeys(symbols, sympy.S.Zero)
        split_failed: list[tuple[list[sympy.Symbol], sympy.Expr]] = []
        index = sympy.expand(self.rename_indexing(index))
        for part in index.args if isinstance(index, sympy.Add) else [index]:
@@ -1,4 +1,4 @@
-import os  # noqa: C101
+import os
 import sys
 from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
 

@ -482,7 +482,7 @@ class StorageWeakRefWrapper:

@classmethod
def from_weakref_and_data_ptr(
cls: type[S],
cls: type[StorageWeakRefWrapper],
cdata: Any,
data_ptr: int,
extra_ref_check: Optional[Callable[[], bool]] = None,

@ -149,7 +149,7 @@ def grouped_gemm_lowering(
has_bias=[bias is not None for bias in b],
trans_w=True,
epilogue_creator=None,
act_mapping={num: x for num in range(num_gemm)},
act_mapping=dict.fromkeys(range(num_gemm), x),
)

input_nodes = [x, *w]

@ -3029,7 +3029,7 @@ class Scheduler:
if fusion_log.isEnabledFor(logging.DEBUG):
fusion_log.debug("fuse_nodes_once, candidates:")
for node in fused_nodes:
fusion_log.debug(" " + node.debug_str_short()) # noqa: G003
fusion_log.debug(" %s", node.debug_str_short())

# These are potential fusions which we are async compiling,
# and which we will benchmark profitability of.

@ -69,7 +69,7 @@ try:
not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
except ValueError as e:
if "'not_implemented' not registered" in str(e):
import logging as not_implemented_log
not_implemented_log = logging.getLogger(__name__ + ".not_implemented")
else:
raise e

@ -228,7 +228,7 @@ def convert_pt2e(
# for detailed explanation of output quantized model
quantized_model = convert_pt2e(prepared_model)

""" # flake8: noqa
"""
torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
if not isinstance(use_reference_representation, bool):
raise ValueError(

@ -358,13 +358,11 @@ class SACEstimator(TorchDispatchMode):
output_ids = tuple(hash(st) for st in out_storages)
# 4. If the function is not inplace, return
if not is_inplace(func):
return curr_idx, output_ids, {mod_fqn: () for mod_fqn in active_mod_fqns}
return curr_idx, output_ids, dict.fromkeys(active_mod_fqns, ())

op_idx = curr_idx
# 5. Initialize the parent op ids of the inplace op for each of the active modules
mod_op_parent_idxs: dict[str, int] = {
mod_fqn: -1 for mod_fqn in active_mod_fqns
}
mod_op_parent_idxs: dict[str, int] = dict.fromkeys(active_mod_fqns, -1)
for i, d in enumerate(self._sac_metadata):
# 6. Find the first occurence of a tensor corresponding to each module that
# shares the same storage as the current tensor

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
# flake8: noqa C101
import itertools
from collections.abc import Iterable, Iterator
from typing import Union

@ -6,6 +6,7 @@ import sys
import types
from collections.abc import Iterator, Mapping
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import Self

import torch
import torch.distributed.rpc as rpc

@ -319,34 +320,34 @@ class _RemoteModule(nn.Module):
def add_module(self, name: str, module: Optional[Module]) -> None:
_raise_not_supported(self.add_module.__name__)

def apply(self: T, fn: Callable[[Module], None]) -> T: # type: ignore[return]
def apply(self, fn: Callable[[Module], None]) -> Self: # type: ignore[return]
_raise_not_supported(self.apply.__name__)

def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
def cuda(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return]
_raise_not_supported(self.cuda.__name__)

def ipu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
def ipu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return]
_raise_not_supported(self.ipu.__name__)

def xpu(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
def xpu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return]
_raise_not_supported(self.xpu.__name__)

def cpu(self: T) -> T: # type: ignore[return]
def cpu(self) -> Self: # type: ignore[return]
_raise_not_supported(self.cpu.__name__)

def type(self: T, dst_type: Union[dtype, str]) -> T: # type: ignore[return]
def type(self, dst_type: Union[dtype, str]) -> Self: # type: ignore[return]
_raise_not_supported(self.type.__name__)

def float(self: T) -> T: # type: ignore[return]
def float(self) -> Self: # type: ignore[return]
_raise_not_supported(self.float.__name__)

def double(self: T) -> T: # type: ignore[return]
def double(self) -> Self: # type: ignore[return]
_raise_not_supported(self.double.__name__)

def half(self: T) -> T: # type: ignore[return]
def half(self) -> Self: # type: ignore[return]
_raise_not_supported(self.half.__name__)

def bfloat16(self: T) -> T: # type: ignore[return]
def bfloat16(self) -> Self: # type: ignore[return]
_raise_not_supported(self.bfloat16.__name__)

def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var]

@ -428,19 +429,19 @@ class _RemoteModule(nn.Module):
):
_raise_not_supported(self.named_modules.__name__)

def train(self: T, mode: bool = True) -> T:
def train(self, mode: bool = True) -> Self:
return self.module_rref.rpc_sync().train() # type: ignore[operator, union-attr]

def eval(self: T) -> T:
def eval(self) -> Self:
return self.module_rref.rpc_sync().eval() # type: ignore[operator, union-attr]

def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return]
def requires_grad_(self, requires_grad: bool = True) -> Self: # type: ignore[return]
_raise_not_supported(self.requires_grad_.__name__)

def zero_grad(self, set_to_none: bool = True) -> None:
_raise_not_supported(self.zero_grad.__name__)

def share_memory(self: T) -> T: # type: ignore[return]
def share_memory(self) -> Self: # type: ignore[return]
_raise_not_supported(self.share_memory.__name__)

def extra_repr(self) -> str: # type: ignore[return]

@ -1,5 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# flake8: noqa

from .binary import _apply_native_binary, _is_native_binary
from .core import is_masked_tensor, MaskedTensor

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
# flake8: noqa C101
# flake8: noqa: B950
"""This module implements the user facing API for flex_attention in PyTorch."""
import functools
import inspect

@ -1004,7 +1004,7 @@ class Module:

return self

def apply(self: T, fn: Callable[["Module"], None]) -> T:
def apply(self, fn: Callable[["Module"], None]) -> Self:
r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.

Typical use includes initializing the parameters of a model

@ -1045,7 +1045,7 @@ class Module:
fn(self)
return self

def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
def cuda(self, device: Optional[Union[int, device]] = None) -> Self:
r"""Move all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So

@ -1064,7 +1064,7 @@ class Module:
"""
return self._apply(lambda t: t.cuda(device))

def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
def ipu(self, device: Optional[Union[int, device]] = None) -> Self:
r"""Move all model parameters and buffers to the IPU.

This also makes associated parameters and buffers different objects. So

@ -1083,7 +1083,7 @@ class Module:
"""
return self._apply(lambda t: t.ipu(device))

def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
def xpu(self, device: Optional[Union[int, device]] = None) -> Self:
r"""Move all model parameters and buffers to the XPU.

This also makes associated parameters and buffers different objects. So

@ -1102,7 +1102,7 @@ class Module:
"""
return self._apply(lambda t: t.xpu(device))

def mtia(self: T, device: Optional[Union[int, device]] = None) -> T:
def mtia(self, device: Optional[Union[int, device]] = None) -> Self:
r"""Move all model parameters and buffers to the MTIA.

This also makes associated parameters and buffers different objects. So

@ -1121,7 +1121,7 @@ class Module:
"""
return self._apply(lambda t: t.mtia(device))

def cpu(self: T) -> T:
def cpu(self) -> Self:
r"""Move all model parameters and buffers to the CPU.

.. note::

@ -1132,7 +1132,7 @@ class Module:
"""
return self._apply(lambda t: t.cpu())

def type(self: T, dst_type: Union[dtype, str]) -> T:
def type(self, dst_type: Union[dtype, str]) -> Self:
r"""Casts all parameters and buffers to :attr:`dst_type`.

.. note::

@ -1146,7 +1146,7 @@ class Module:
"""
return self._apply(lambda t: t.type(dst_type))

def float(self: T) -> T:
def float(self) -> Self:
r"""Casts all floating point parameters and buffers to ``float`` datatype.

.. note::

@ -1157,7 +1157,7 @@ class Module:
"""
return self._apply(lambda t: t.float() if t.is_floating_point() else t)

def double(self: T) -> T:
def double(self) -> Self:
r"""Casts all floating point parameters and buffers to ``double`` datatype.

.. note::

@ -1168,7 +1168,7 @@ class Module:
"""
return self._apply(lambda t: t.double() if t.is_floating_point() else t)

def half(self: T) -> T:
def half(self) -> Self:
r"""Casts all floating point parameters and buffers to ``half`` datatype.

.. note::

@ -1179,7 +1179,7 @@ class Module:
"""
return self._apply(lambda t: t.half() if t.is_floating_point() else t)

def bfloat16(self: T) -> T:
def bfloat16(self) -> Self:
r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype.

.. note::

@ -1191,8 +1191,8 @@ class Module:
return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)

def to_empty(
self: T, *, device: Optional[DeviceLikeType], recurse: bool = True
) -> T:
self, *, device: Optional[DeviceLikeType], recurse: bool = True
) -> Self:
r"""Move the parameters and buffers to the specified device without copying storage.

Args:

@ -2837,7 +2837,7 @@ class Module:
memo, submodule_prefix, remove_duplicate
)

def train(self: T, mode: bool = True) -> T:
def train(self, mode: bool = True) -> Self:
r"""Set the module in training mode.

This has an effect only on certain modules. See the documentation of
@ -2859,7 +2859,7 @@ class Module:
|
||||||
module.train(mode)
|
module.train(mode)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def eval(self: T) -> T:
|
def eval(self) -> Self:
|
||||||
r"""Set the module in evaluation mode.
|
r"""Set the module in evaluation mode.
|
||||||
|
|
||||||
This has an effect only on certain modules. See the documentation of
|
This has an effect only on certain modules. See the documentation of
|
||||||

@ -2877,7 +2877,7 @@ class Module:
"""
return self.train(False)

def requires_grad_(self: T, requires_grad: bool = True) -> T:
def requires_grad_(self, requires_grad: bool = True) -> Self:
r"""Change if autograd should record operations on parameters in this module.

This method sets the parameters' :attr:`requires_grad` attributes

@ -2928,7 +2928,7 @@ class Module:
p.grad.requires_grad_(False)
p.grad.zero_()

def share_memory(self: T) -> T:
def share_memory(self) -> Self:
r"""See :meth:`torch.Tensor.share_memory_`."""
return self._apply(lambda t: t.share_memory_())

@ -70,7 +70,7 @@ def from_dynamic_axes_to_dynamic_shapes(
raise ValueError(
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
)
dynamic_shapes[input_name] = {k: torch.export.Dim.DYNAMIC for k in axes}
dynamic_shapes[input_name] = dict.fromkeys(axes, torch.export.Dim.DYNAMIC)
elif axes is None:
dynamic_shapes[input_name] = None
else:

@ -1,7 +1,6 @@
"""torch.ops.aten operators under the `core` module."""
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
# ruff: noqa: TCH001,TCH002
# flake8: noqa

from __future__ import annotations

@ -1,7 +1,7 @@
"""torch.ops.aten operators under the `core` module."""
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
# ruff: noqa: TCH001,TCH002
# flake8: noqa
# flake8: noqa: B950

from __future__ import annotations

@ -2,7 +2,6 @@

# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
# ruff: noqa: TCH001,TCH002
# flake8: noqa

from __future__ import annotations

@ -12,7 +11,6 @@ import torch
from torch.onnx._internal.exporter._torchlib._tensor_typing import (
BOOL,
FLOAT,
INT64,
IntType,
TensorType,
)