mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Turn on linting for functorch (#81987)
Test Plan: - wait for CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/81987 Approved by: https://github.com/samdow
This commit is contained in:
parent
5cb802a63b
commit
f42ed3f98f
5
.flake8
5
.flake8
|
|
@ -22,8 +22,9 @@ exclude =
|
|||
./docs/caffe2,
|
||||
./docs/cpp/src,
|
||||
./docs/src,
|
||||
# See NOTE: [Impending functorch move]
|
||||
./functorch,
|
||||
./functorch/docs,
|
||||
./functorch/examples,
|
||||
./functorch/notebooks,
|
||||
./scripts,
|
||||
./test/generated_type_hints_smoketest.py,
|
||||
./third_party,
|
||||
|
|
|
|||
|
|
@ -9,11 +9,9 @@ exclude_patterns = [
|
|||
'docs/caffe2/**',
|
||||
'docs/cpp/src/**',
|
||||
'docs/src/**',
|
||||
# NOTE: [Impending functorch move]
|
||||
# In preparation for the functorch -> pytorch merge,
|
||||
# we are adding the following excludes so that functorch passes
|
||||
# lint when it gets merged in. Please don't delete.
|
||||
'functorch/**',
|
||||
'functorch/docs/**',
|
||||
'functorch/examples/**',
|
||||
'functorch/notebooks/**',
|
||||
'scripts/**',
|
||||
'test/generated_type_hints_smoketest.py',
|
||||
'third_party/**',
|
||||
|
|
@ -227,8 +225,6 @@ code = 'TYPEIGNORE'
|
|||
include_patterns = ['**/*.py', '**/*.pyi']
|
||||
exclude_patterns = [
|
||||
'test/test_jit.py',
|
||||
# See NOTE: [Impending functorch move]
|
||||
'functorch/**',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
|
@ -301,8 +297,6 @@ exclude_patterns=[
|
|||
'tools/clang_format_hash/**',
|
||||
'test/cpp/jit/upgrader_models/*.ptl',
|
||||
'test/cpp/jit/upgrader_models/*.ptl.ff',
|
||||
# See NOTE: [Impending functorch move]
|
||||
'functorch/**',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
|
@ -322,8 +316,6 @@ exclude_patterns = [
|
|||
'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h',
|
||||
'test/cpp/jit/upgrader_models/*.ptl',
|
||||
'test/cpp/jit/upgrader_models/*.ptl.ff',
|
||||
# See NOTE: [Impending functorch move]
|
||||
'functorch/**',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
|
@ -353,8 +345,6 @@ exclude_patterns = [
|
|||
'test/cpp/jit/upgrader_models/*.ptl',
|
||||
'test/cpp/jit/upgrader_models/*.ptl.ff',
|
||||
'.lintrunner.toml',
|
||||
# See NOTE: [Impending functorch move]
|
||||
'functorch/**',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
|
@ -436,8 +426,6 @@ exclude_patterns = [
|
|||
'**/git-pre-commit',
|
||||
'**/git-clang-format',
|
||||
'**/gradlew',
|
||||
# See NOTE: [Impending functorch move]
|
||||
'functorch/**',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ set -ex
|
|||
echo CU_VERSION is "${CU_VERSION}"
|
||||
echo CUDA_VERSION is "${CUDA_VERSION}"
|
||||
|
||||
# Currenly, CU_VERSION and CUDA_VERSION are not consistent.
|
||||
# Currenly, CU_VERSION and CUDA_VERSION are not consistent.
|
||||
# to understand this code, see https://github.com/pytorch/vision/issues/4443
|
||||
version="cpu"
|
||||
if [[ ! -z "${CUDA_VERSION}" ]] ; then
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def get_model_name(filename):
|
|||
return modelname
|
||||
|
||||
def get_total_length(run_times_df, modelname):
|
||||
return float(run_times_df[run_times_df["name"]==modelname]["runtime"])
|
||||
return float(run_times_df[run_times_df["name"] == modelname]["runtime"])
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -51,16 +51,16 @@ def main():
|
|||
else:
|
||||
print("Please provide a filename or a folder name")
|
||||
|
||||
print(f"modelname, GPU Utilization, MM and Conv time")
|
||||
print("modelname, GPU Utilization, MM and Conv time")
|
||||
|
||||
run_times_df = pd.read_csv(args.runtime)
|
||||
run_times_df = pd.read_csv(args.runtime)
|
||||
for filename in filenames:
|
||||
try:
|
||||
modelname = get_model_name(filename)
|
||||
total_length = get_total_length(run_times_df, modelname) * 1e6
|
||||
utilization, mm_conv_utilization = compute_utilization(filenames, total_length)
|
||||
print(f"{modelname}, {utilization}, {mm_conv_utilization}")
|
||||
except:
|
||||
except BaseException:
|
||||
logging.exception(f"{filename}, ERROR")
|
||||
print(f"{filename}, ERROR")
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,8 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
from opacus.utils.module_modification import convert_batchnorm_modules
|
||||
from torchvision.datasets import CIFAR10
|
||||
import time
|
||||
|
||||
from functools import partial
|
||||
import functorch
|
||||
from functorch import vmap, grad
|
||||
from functorch import make_functional
|
||||
from opacus import PrivacyEngine
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ nops = len(ops)
|
|||
pivot_op_shape = df.pivot_table(values="time", index=["operator", "shape"], columns=["fuser"])
|
||||
pivot_speedups = (pivot_op_shape.T / pivot_op_shape["eager"]).T
|
||||
|
||||
plt.rcParams["figure.figsize"] = (20,100)
|
||||
plt.rcParams["figure.figsize"] = (20, 100)
|
||||
fig, axs = plt.subplots(nops)
|
||||
plt.subplots_adjust(hspace=0.5)
|
||||
for idx, op in enumerate(ops):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import torch
|
||||
import time
|
||||
from functorch.compile import memory_efficient_fusion, clear_compile_cache
|
||||
import benchmark_helper
|
||||
|
||||
|
|
|
|||
|
|
@ -63,7 +63,9 @@ def profile_cuda_kernels(fn, args, string_id="Model time"):
|
|||
print("################################################\n\n\n\n")
|
||||
|
||||
|
||||
def time_with_torch_timer(fn, args, string_id, kwargs={}):
|
||||
def time_with_torch_timer(fn, args, string_id, kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
print("################################################")
|
||||
print(f"#### Torch Timer for {string_id} starts #########")
|
||||
print("################################################")
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
import torch
|
||||
import time
|
||||
from functorch.compile import memory_efficient_pointwise_fusion, clear_compile_cache
|
||||
import benchmark_helper
|
||||
|
||||
### ALL comments regarding the patetrns
|
||||
# ALL comments regarding the patetrns
|
||||
|
||||
|
||||
def bias_gelu_dropout(input, bias):
|
||||
|
|
|
|||
|
|
@ -3,4 +3,4 @@
|
|||
font-family: Arial Black;
|
||||
dominant-baseline: central;
|
||||
text-anchor: middle;
|
||||
}</style></svg>
|
||||
}</style></svg>
|
||||
|
|
|
|||
|
Before Width: | Height: | Size: 6.2 KiB After Width: | Height: | Size: 6.2 KiB |
|
|
@ -40,4 +40,4 @@ Compilers (experimental)
|
|||
:nosignatures:
|
||||
|
||||
nop
|
||||
ts_compile
|
||||
ts_compile
|
||||
|
|
|
|||
|
|
@ -465,4 +465,4 @@ def parse_args():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
|
|
|||
1
functorch/examples/maml_omniglot/.gitignore
vendored
1
functorch/examples/maml_omniglot/.gitignore
vendored
|
|
@ -1,3 +1,2 @@
|
|||
omniglot/
|
||||
maml-accs.png
|
||||
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ def _create_batched_inputs(
|
|||
flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
|
||||
# See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
|
||||
batched_inputs = [arg if in_dim is None else
|
||||
_add_batch_dim(arg, in_dim, vmap_level) # type: ignore
|
||||
_add_batch_dim(arg, in_dim, vmap_level)
|
||||
for in_dim, arg in zip(flat_in_dims, flat_args)]
|
||||
return tree_unflatten(batched_inputs, args_spec)
|
||||
|
||||
|
|
|
|||
|
|
@ -259,4 +259,3 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
|||
}
|
||||
|
||||
}}
|
||||
|
||||
|
|
|
|||
|
|
@ -97,4 +97,3 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
|||
// Not sure how to add the ones with irregular args to the mix cleanly (i.e. randint takes an extra int parameter)
|
||||
}
|
||||
}}
|
||||
|
||||
|
|
|
|||
|
|
@ -470,4 +470,3 @@ inline VmapDimVector range(int64_t start, int64_t stop) {
|
|||
}
|
||||
|
||||
}}
|
||||
|
||||
|
|
|
|||
|
|
@ -216,4 +216,3 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
|||
VARIADIC_BDIMS_BOXED(_lu_with_info);
|
||||
}
|
||||
}}
|
||||
|
||||
|
|
|
|||
|
|
@ -463,7 +463,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
|
|||
decltype(&ATEN_FN2(randint_like, low_dtype)), &ATEN_FN2(randint_like, low_dtype), int64_t, int64_t, TENSOR_LIKE_COMMON_ARG_TYPES>);
|
||||
m.impl("rand_like", tensor_like_random_batch_rule<decltype(&ATEN_FN(rand_like)), &ATEN_FN(rand_like), TENSOR_LIKE_COMMON_ARG_TYPES>);
|
||||
m.impl("randn_like", tensor_like_random_batch_rule<decltype(&ATEN_FN(randn_like)), &ATEN_FN(randn_like), TENSOR_LIKE_COMMON_ARG_TYPES>);
|
||||
|
||||
|
||||
#undef RANDOM_BATCH_RULE
|
||||
#undef RANDOM_BATCH_RULE2
|
||||
#undef RANDOM_INPLACE_BATCH_RULE
|
||||
|
|
|
|||
|
|
@ -37,4 +37,3 @@ inline bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) {
|
|||
}
|
||||
|
||||
}}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ Let's demonstrate how to do this using an ensemble of simple CNNs.
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Here's a simple CNN
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ efficiently using a standard autodiff system like PyTorch Autograd; functorch
|
|||
provides ways of computing various higher-order autodiff quantities efficiently.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
torch.manual_seed(0)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ and optimization research.
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Here's a simple CNN
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
### Holds the colab ready versions of the notebook tutorials.
|
||||
### Holds the colab ready versions of the notebook tutorials.
|
||||
|
||||
These are similar to the jupyter notebooks, but have additional colab specific changes including the building of functorch in colab to prep for running.
|
||||
|
||||
The colabs and notebooks are not auto-synced atm, thus currently updates to one need to be synched to the other.
|
||||
The colabs and notebooks are not auto-synced atm, thus currently updates to one need to be synched to the other.
|
||||
|
|
|
|||
|
|
@ -96,6 +96,3 @@ There's a couple different resources for finding batching rules to write.
|
|||
1. [BatchingRegistrations.cpp](functorch/csrc/BatchingRegistrations.cpp): This is probably the easiest place to start. These were batching rules that were written with an old API, and thus have a lot of cruft in them that are no longer necessary. Porting these batching rules to using one of the above options is an easy way to get started and help us reduce tech debt :) Once you've gotten your footing with writing batching rules, you can start helping with writing new batching rules.
|
||||
2. Popular operators. See [1](https://github.com/facebookresearch/functorch/issues/112), [2](https://github.com/facebookresearch/functorch/issues/101), [3](https://github.com/facebookresearch/functorch/issues/102), and [4](https://github.com/facebookresearch/functorch/issues/102). These contain lists of (user-facing) PyTorch operators sorted by usages, along with whether they have a batching rule implemented or not.
|
||||
3. [Master List](https://docs.google.com/spreadsheets/d/1Sp4HUjxwMifS5oDQg0yvjqk7hKOpCfKO4jWH4MTGP-k/edit#gid=0). This is the master list of vmap operator support :). It's generated by [this script](op_analysis/gen_data.py). Theoretically, we want to support most of the operators in that list (that aren't composite or out variants).
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user