Apply UFMT to all non test/torch files (#106205)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106205
Approved by: https://github.com/albanD
This commit is contained in:
Edward Z. Yang 2023-07-28 16:04:39 -04:00 committed by PyTorch MergeBot
parent 1163800d0f
commit e6ec0efaf8
51 changed files with 1671 additions and 941 deletions

View File

@ -908,62 +908,6 @@ exclude_patterns = [
'third_party/**/*.pyi', 'third_party/**/*.pyi',
# These files are all grandfathered in, feel free to remove from this list # These files are all grandfathered in, feel free to remove from this list
# as necessary # as necessary
'aten/src/ATen/function_wrapper.py',
'aten/src/ATen/native/quantized/cpu/qnnpack/configure.py',
'aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py',
'aten/src/ATen/native/quantized/cpu/qnnpack/generate-wrapper.py',
'aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py',
'aten/src/ATen/nnapi/codegen.py',
'functorch/__init__.py',
'functorch/_src/__init__.py',
'functorch/_src/aot_autograd/__init__.py',
'functorch/_src/eager_transforms/__init__.py',
'functorch/_src/make_functional/__init__.py',
'functorch/_src/vmap/__init__.py',
'functorch/benchmarks/chrome_trace_parser.py',
'functorch/benchmarks/cse.py',
'functorch/benchmarks/operator_authoring.py',
'functorch/benchmarks/per_sample_grads.py',
'functorch/benchmarks/pointwise_scorecard.py',
'functorch/benchmarks/process_scorecard.py',
'functorch/compile/__init__.py',
'functorch/dim/__init__.py',
'functorch/dim/batch_tensor.py',
'functorch/dim/delayed_mul_tensor.py',
'functorch/dim/dim.py',
'functorch/dim/magic_trace.py',
'functorch/dim/op_properties.py',
'functorch/dim/reference.py',
'functorch/dim/tree_map.py',
'functorch/dim/wrap_type.py',
'functorch/docs/source/conf.py',
'functorch/einops/__init__.py',
'functorch/einops/_parsing.py',
'functorch/einops/rearrange.py',
'functorch/examples/compilation/eager_fusion.py',
'functorch/examples/compilation/fuse_module.py',
'functorch/examples/compilation/linear_train.py',
'functorch/examples/compilation/simple_function.py',
'functorch/examples/dp_cifar10/cifar10_opacus.py',
'functorch/examples/dp_cifar10/cifar10_transforms.py',
'functorch/examples/ensembling/parallel_train.py',
'functorch/examples/lennard_jones/lennard_jones.py',
'functorch/examples/maml_omniglot/maml-omniglot-higher.py',
'functorch/examples/maml_omniglot/maml-omniglot-ptonly.py',
'functorch/examples/maml_omniglot/maml-omniglot-transforms.py',
'functorch/examples/maml_omniglot/support/omniglot_loaders.py',
'functorch/examples/maml_regression/evjang.py',
'functorch/examples/maml_regression/evjang_transforms.py',
'functorch/examples/maml_regression/evjang_transforms_module.py',
'functorch/experimental/__init__.py',
'functorch/experimental/_cond.py',
'functorch/experimental/_map.py',
'functorch/experimental/control_flow.py',
'functorch/experimental/ops.py',
'functorch/notebooks/_src/plot_ensembling.py',
'functorch/notebooks/_src/plot_jacobians_and_hessians.py',
'functorch/notebooks/_src/plot_per_sample_gradients.py',
'functorch/op_analysis/gen_data.py',
'test/_nvfuser/__init__.py', 'test/_nvfuser/__init__.py',
'test/_nvfuser/test_dynamo.py', 'test/_nvfuser/test_dynamo.py',
'test/_nvfuser/test_python_frontend.py', 'test/_nvfuser/test_python_frontend.py',

View File

@ -31,7 +31,6 @@ def main(args):
], ],
extra_include_dirs="src", extra_include_dirs="src",
): ):
requantization_objects = [ requantization_objects = [
build.cc("requantization/precise-scalar.c"), build.cc("requantization/precise-scalar.c"),
build.cc("requantization/fp32-scalar.c"), build.cc("requantization/fp32-scalar.c"),
@ -192,7 +191,6 @@ def main(args):
}, },
extra_include_dirs=["src", "test"], extra_include_dirs=["src", "test"],
): ):
build.unittest("hgemm-test", build.cxx("hgemm.cc")) build.unittest("hgemm-test", build.cxx("hgemm.cc"))
build.unittest("q8avgpool-test", build.cxx("q8avgpool.cc")) build.unittest("q8avgpool-test", build.cxx("q8avgpool.cc"))
build.unittest("q8conv-test", build.cxx("q8conv.cc")) build.unittest("q8conv-test", build.cxx("q8conv.cc"))
@ -252,7 +250,6 @@ def main(args):
isa=benchmark_isa, isa=benchmark_isa,
extra_include_dirs="src", extra_include_dirs="src",
): ):
build.benchmark("add-bench", build.cxx("add.cc")) build.benchmark("add-bench", build.cxx("add.cc"))
build.benchmark("average-pooling-bench", build.cxx("average-pooling.cc")) build.benchmark("average-pooling-bench", build.cxx("average-pooling.cc"))
build.benchmark("channel-shuffle-bench", build.cxx("channel-shuffle.cc")) build.benchmark("channel-shuffle-bench", build.cxx("channel-shuffle.cc"))

View File

@ -7,6 +7,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import confu import confu
parser = confu.standard_parser("clog configuration script") parser = confu.standard_parser("clog configuration script")
@ -19,13 +20,16 @@ def main(args):
with build.options(source_dir="src", extra_include_dirs="src"): with build.options(source_dir="src", extra_include_dirs="src"):
build.static_library("clog", build.cc("clog.c")) build.static_library("clog", build.cc("clog.c"))
with build.options(source_dir="test", deps={ with build.options(
(build, build.deps.googletest): all, source_dir="test",
"log": build.target.is_android}): deps={(build, build.deps.googletest): all, "log": build.target.is_android},
):
build.unittest("clog-test", build.cxx("clog.cc")) build.unittest("clog-test", build.cxx("clog.cc"))
return build return build
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
main(sys.argv[1:]).generate() main(sys.argv[1:]).generate()

View File

@ -8,12 +8,12 @@
# Kernels are ordered (see `sort_index`), and when dispatching, # Kernels are ordered (see `sort_index`), and when dispatching,
# we select the first kernel in the list that supports the inputs # we select the first kernel in the list that supports the inputs
import argparse
import collections import collections
import itertools import itertools
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypeVar from typing import Dict, List, Optional, Tuple, TypeVar
import argparse
DTYPES = { DTYPES = {
"f32": "float", "f32": "float",
@ -303,7 +303,11 @@ T = TypeVar("T", FwdKernel, BwdKernel)
def write_decl_impl( def write_decl_impl(
kernels: List[T], family_name: str, impl_file: str, autogen_dir: Path, disable_def: str = None kernels: List[T],
family_name: str,
impl_file: str,
autogen_dir: Path,
disable_def: str = None,
) -> None: ) -> None:
cpp_file_header = """/* cpp_file_header = """/*
* Copyright (c) Meta Platforms, Inc. and affiliates. * Copyright (c) Meta Platforms, Inc. and affiliates.
@ -382,22 +386,28 @@ def main(output_dir: Optional[str]) -> None:
FwdKernel.get_all(), FwdKernel.get_all(),
"cutlassF", "cutlassF",
impl_file="<ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>", impl_file="<ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>",
autogen_dir=output_dir autogen_dir=output_dir,
) )
write_decl_impl( write_decl_impl(
BwdKernel.get_all(), BwdKernel.get_all(),
"cutlassB", "cutlassB",
impl_file="<ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>", impl_file="<ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>",
autogen_dir=output_dir autogen_dir=output_dir,
) )
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='generate_kernels', prog="generate_kernels",
description='Generate the mem-eff kernels template instantiations') description="Generate the mem-eff kernels template instantiations",
)
# Set an optional output directory # Set an optional output directory
parser.add_argument('-o', '--output_dir', required=False, help="Where to generate the kernels " parser.add_argument(
" will default to <ATen/native/transformers/cuda/mem_eff_attention/kernels/> ") "-o",
"--output_dir",
required=False,
help="Where to generate the kernels "
" will default to <ATen/native/transformers/cuda/mem_eff_attention/kernels/> ",
)
args = parser.parse_args() args = parser.parse_args()
main(args.output_dir) main(args.output_dir)

View File

@ -7,9 +7,9 @@ that opens libneuralnetworks.so with dlopen and finds the functions
we need with dlsym. We also generate a "check" wrapper that checks we need with dlsym. We also generate a "check" wrapper that checks
return values and throws C++ exceptions on errors. return values and throws C++ exceptions on errors.
""" """
import sys
import re
import pathlib import pathlib
import re
import sys
import textwrap import textwrap
@ -36,39 +36,155 @@ PREFIX = """\
NNAPI_FUNCTIONS = [ NNAPI_FUNCTIONS = [
("int", "ANeuralNetworks_getDeviceCount", "uint32_t* numDevices"), # noqa: B950 ("int", "ANeuralNetworks_getDeviceCount", "uint32_t* numDevices"), # noqa: B950
("int", "ANeuralNetworks_getDevice", "uint32_t devIndex, ANeuralNetworksDevice** device"), # noqa: B950 (
("int", "ANeuralNetworksDevice_getName", "const ANeuralNetworksDevice* device, const char** name"), # noqa: B950 "int",
("int", "ANeuralNetworksDevice_getVersion", "const ANeuralNetworksDevice* device, const char** version"), # noqa: B950 "ANeuralNetworks_getDevice",
("int", "ANeuralNetworksDevice_getFeatureLevel", "const ANeuralNetworksDevice* device, int64_t* featureLevel"), # noqa: B950 "uint32_t devIndex, ANeuralNetworksDevice** device",
("int", "ANeuralNetworksModel_getSupportedOperationsForDevices", " const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps"), # noqa: B950 ), # noqa: B950
("int", "ANeuralNetworksCompilation_createForDevices", "ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation"), # noqa: B950 (
("int", "ANeuralNetworksExecution_compute", "ANeuralNetworksExecution* execution"), # noqa: B950 "int",
("int", "ANeuralNetworksMemory_createFromFd", "size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory"), # noqa: B950 "ANeuralNetworksDevice_getName",
("void", "ANeuralNetworksMemory_free", "ANeuralNetworksMemory* memory"), # noqa: B950 "const ANeuralNetworksDevice* device, const char** name",
("int", "ANeuralNetworksModel_create", "ANeuralNetworksModel** model"), # noqa: B950 ), # noqa: B950
(
"int",
"ANeuralNetworksDevice_getVersion",
"const ANeuralNetworksDevice* device, const char** version",
), # noqa: B950
(
"int",
"ANeuralNetworksDevice_getFeatureLevel",
"const ANeuralNetworksDevice* device, int64_t* featureLevel",
), # noqa: B950
(
"int",
"ANeuralNetworksModel_getSupportedOperationsForDevices",
" const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps",
), # noqa: B950
(
"int",
"ANeuralNetworksCompilation_createForDevices",
"ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation", # noqa: B950
),
(
"int",
"ANeuralNetworksExecution_compute",
"ANeuralNetworksExecution* execution",
), # noqa: B950
(
"int",
"ANeuralNetworksMemory_createFromFd",
"size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory",
), # noqa: B950
(
"void",
"ANeuralNetworksMemory_free",
"ANeuralNetworksMemory* memory",
), # noqa: B950
(
"int",
"ANeuralNetworksModel_create",
"ANeuralNetworksModel** model",
), # noqa: B950
("void", "ANeuralNetworksModel_free", "ANeuralNetworksModel* model"), # noqa: B950 ("void", "ANeuralNetworksModel_free", "ANeuralNetworksModel* model"), # noqa: B950
("int", "ANeuralNetworksModel_finish", "ANeuralNetworksModel* model"), # noqa: B950 ("int", "ANeuralNetworksModel_finish", "ANeuralNetworksModel* model"), # noqa: B950
("int", "ANeuralNetworksModel_addOperand", "ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type"), # noqa: B950 (
("int", "ANeuralNetworksModel_setOperandValue", "ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length"), # noqa: B950 "int",
("int", "ANeuralNetworksModel_setOperandValueFromMemory", "ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950 "ANeuralNetworksModel_addOperand",
("int", "ANeuralNetworksModel_addOperation", "ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), # noqa: B950 "ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type",
("int", "ANeuralNetworksModel_identifyInputsAndOutputs", "ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), # noqa: B950 ), # noqa: B950
("int", "ANeuralNetworksModel_relaxComputationFloat32toFloat16", "ANeuralNetworksModel* model, bool allow"), # noqa: B950 (
("int", "ANeuralNetworksCompilation_create", "ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation"), # noqa: B950 "int",
("void", "ANeuralNetworksCompilation_free", "ANeuralNetworksCompilation* compilation"), # noqa: B950 "ANeuralNetworksModel_setOperandValue",
("int", "ANeuralNetworksCompilation_setPreference", "ANeuralNetworksCompilation* compilation, int32_t preference"), # noqa: B950 "ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length",
("int", "ANeuralNetworksCompilation_finish", "ANeuralNetworksCompilation* compilation"), # noqa: B950 ), # noqa: B950
("int", "ANeuralNetworksExecution_create", "ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution"), # noqa: B950 (
("void", "ANeuralNetworksExecution_free", "ANeuralNetworksExecution* execution"), # noqa: B950 "int",
("int", "ANeuralNetworksExecution_setInput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length"), # noqa: B950 "ANeuralNetworksModel_setOperandValueFromMemory",
("int", "ANeuralNetworksExecution_setInputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950 "ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length",
("int", "ANeuralNetworksExecution_setOutput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length"), # noqa: B950 ), # noqa: B950
("int", "ANeuralNetworksExecution_setOutputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), # noqa: B950 (
("int", "ANeuralNetworksExecution_startCompute", "ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event"), # noqa: B950 "int",
"ANeuralNetworksModel_addOperation",
"ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs", # noqa: B950
),
(
"int",
"ANeuralNetworksModel_identifyInputsAndOutputs",
"ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs",
), # noqa: B950
(
"int",
"ANeuralNetworksModel_relaxComputationFloat32toFloat16",
"ANeuralNetworksModel* model, bool allow",
), # noqa: B950
(
"int",
"ANeuralNetworksCompilation_create",
"ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation",
), # noqa: B950
(
"void",
"ANeuralNetworksCompilation_free",
"ANeuralNetworksCompilation* compilation",
), # noqa: B950
(
"int",
"ANeuralNetworksCompilation_setPreference",
"ANeuralNetworksCompilation* compilation, int32_t preference",
), # noqa: B950
(
"int",
"ANeuralNetworksCompilation_finish",
"ANeuralNetworksCompilation* compilation",
), # noqa: B950
(
"int",
"ANeuralNetworksExecution_create",
"ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution",
), # noqa: B950
(
"void",
"ANeuralNetworksExecution_free",
"ANeuralNetworksExecution* execution",
), # noqa: B950
(
"int",
"ANeuralNetworksExecution_setInput",
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length", # noqa: B950
),
(
"int",
"ANeuralNetworksExecution_setInputFromMemory",
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", # noqa: B950
),
(
"int",
"ANeuralNetworksExecution_setOutput",
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length",
), # noqa: B950
(
"int",
"ANeuralNetworksExecution_setOutputFromMemory",
"ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", # noqa: B950
),
(
"int",
"ANeuralNetworksExecution_startCompute",
"ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event",
), # noqa: B950
("int", "ANeuralNetworksEvent_wait", "ANeuralNetworksEvent* event"), # noqa: B950 ("int", "ANeuralNetworksEvent_wait", "ANeuralNetworksEvent* event"), # noqa: B950
("void", "ANeuralNetworksEvent_free", "ANeuralNetworksEvent* event"), # noqa: B950 ("void", "ANeuralNetworksEvent_free", "ANeuralNetworksEvent* event"), # noqa: B950
("int", "ANeuralNetworksExecution_getOutputOperandRank", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank"), # noqa: B950 (
("int", "ANeuralNetworksExecution_getOutputOperandDimensions", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions"), # noqa: B950 "int",
"ANeuralNetworksExecution_getOutputOperandRank",
"ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank",
), # noqa: B950
(
"int",
"ANeuralNetworksExecution_getOutputOperandDimensions",
"ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions",
), # noqa: B950
] ]
@ -82,18 +198,26 @@ def main(argv):
struct_members.append(f" {ret}(*{short_name})({args});") struct_members.append(f" {ret}(*{short_name})({args});")
load_functions.append(f' *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");') load_functions.append(
load_functions.append(f' check_nnapi_.{short_name} = check_{short_name};') f' *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");'
)
load_functions.append(f" check_nnapi_.{short_name} = check_{short_name};")
call_args = "".join(re.findall(r"\w+(?:,|$)", args)) call_args = "".join(re.findall(r"\w+(?:,|$)", args))
if ret == "void": if ret == "void":
define_checks.append(textwrap.dedent(f"""\ define_checks.append(
textwrap.dedent(
f"""\
{ret} check_{short_name}({args}) {{ {ret} check_{short_name}({args}) {{
CAFFE_ENFORCE(nnapi_.{short_name}); CAFFE_ENFORCE(nnapi_.{short_name});
nnapi_.{short_name}({call_args}); nnapi_.{short_name}({call_args});
}}""")) }}"""
)
)
if ret == "int": if ret == "int":
define_checks.append(textwrap.dedent(f"""\ define_checks.append(
textwrap.dedent(
f"""\
{ret} check_{short_name}({args}) {{ {ret} check_{short_name}({args}) {{
CAFFE_ENFORCE(nnapi_.{short_name}); CAFFE_ENFORCE(nnapi_.{short_name});
int ret = nnapi_.{short_name}({call_args}); int ret = nnapi_.{short_name}({call_args});
@ -103,13 +227,16 @@ def main(argv):
"{short_name}", "failed with error ", ret "{short_name}", "failed with error ", ret
); );
return ret; return ret;
}}""")) }}"""
)
)
out_dir = pathlib.Path(__file__).parent out_dir = pathlib.Path(__file__).parent
(out_dir / "nnapi_wrapper.h").write_text( (out_dir / "nnapi_wrapper.h").write_text(
PREFIX + PREFIX
textwrap.dedent("""\ + textwrap.dedent(
"""\
#ifndef NNAPI_WRAPPER_H_ #ifndef NNAPI_WRAPPER_H_
#define NNAPI_WRAPPER_H_ #define NNAPI_WRAPPER_H_
#include <stddef.h> #include <stddef.h>
@ -122,13 +249,14 @@ def main(argv):
void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi); void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi);
#endif #endif
#endif #endif
""") """
.replace("__STRUCT_MEMBERS__", "\n".join(struct_members)) ).replace("__STRUCT_MEMBERS__", "\n".join(struct_members))
) )
(out_dir / "nnapi_wrapper.cpp").write_text( (out_dir / "nnapi_wrapper.cpp").write_text(
PREFIX + PREFIX
textwrap.dedent("""\ + textwrap.dedent(
"""\
#ifndef _WIN32 #ifndef _WIN32
#include <dlfcn.h> #include <dlfcn.h>
#endif #endif
@ -157,7 +285,8 @@ def main(argv):
*check_nnapi = &check_nnapi_; *check_nnapi = &check_nnapi_;
#endif #endif
} }
""") """
)
.replace("__DEFINE_CHECK_FUNCTIONS__", "\n".join(define_checks)) .replace("__DEFINE_CHECK_FUNCTIONS__", "\n".join(define_checks))
.replace("__LOAD_FUNCTIONS__", "\n".join(load_functions)) .replace("__LOAD_FUNCTIONS__", "\n".join(load_functions))
) )

View File

@ -5,6 +5,27 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
from torch._functorch.deprecated import (
combine_state_for_ensemble,
functionalize,
grad,
grad_and_value,
hessian,
jacfwd,
jacrev,
jvp,
make_functional,
make_functional_with_buffers,
vjp,
vmap,
)
# utilities. Maybe these should go in their own namespace in the future?
from torch._functorch.make_functional import (
FunctionalModule,
FunctionalModuleWithBuffers,
)
# Top-level APIs. Please think carefully before adding something to the # Top-level APIs. Please think carefully before adding something to the
# top-level namespace: # top-level namespace:
# - private helper functions should go into torch._functorch # - private helper functions should go into torch._functorch
@ -14,15 +35,4 @@ import torch
# Was never documented # Was never documented
from torch._functorch.python_key import make_fx from torch._functorch.python_key import make_fx
from torch._functorch.deprecated import (
vmap, grad, grad_and_value, vjp, jvp, jacrev, jacfwd, hessian, functionalize,
make_functional, make_functional_with_buffers, combine_state_for_ensemble,
)
# utilities. Maybe these should go in their own namespace in the future?
from torch._functorch.make_functional import (
FunctionalModule,
FunctionalModuleWithBuffers,
)
__version__ = torch.__version__ __version__ = torch.__version__

View File

@ -2,6 +2,6 @@
# If you are not a PyTorch developer and you are relying on the following # If you are not a PyTorch developer and you are relying on the following
# imports, please file an issue. # imports, please file an issue.
from torch._functorch.eager_transforms import ( from torch._functorch.eager_transforms import (
_unwrap_functional_tensor,
_assert_wrapped_functional, _assert_wrapped_functional,
_unwrap_functional_tensor,
) )

View File

@ -4,13 +4,13 @@
from torch._functorch.vmap import ( from torch._functorch.vmap import (
_add_batch_dim, _add_batch_dim,
_broadcast_to_and_flatten, _broadcast_to_and_flatten,
_create_batched_inputs,
_get_name, _get_name,
_process_batched_inputs,
_remove_batch_dim, _remove_batch_dim,
_unwrap_batched,
_validate_and_get_batch_size, _validate_and_get_batch_size,
Tensor, Tensor,
tree_flatten, tree_flatten,
tree_unflatten, tree_unflatten,
_process_batched_inputs,
_create_batched_inputs,
_unwrap_batched,
) )

View File

@ -1,8 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import logging
import os import os
import logging
import pandas as pd import pandas as pd
from torch._functorch.benchmark_utils import compute_utilization from torch._functorch.benchmark_utils import compute_utilization
@ -20,6 +21,7 @@ def get_model_name(filename):
modelname = tail[: tail.find("_chrome_trace")] modelname = tail[: tail.find("_chrome_trace")]
return modelname return modelname
def get_total_length(run_times_df, modelname): def get_total_length(run_times_df, modelname):
return float(run_times_df[run_times_df["name"] == modelname]["runtime"]) return float(run_times_df[run_times_df["name"] == modelname]["runtime"])
@ -31,14 +33,14 @@ def main():
"--runtime", "-runf", help="file name of the runtime file", required=True "--runtime", "-runf", help="file name of the runtime file", required=True
) )
group.add_argument( group.add_argument(
"--filename", "-f", action="append", help="a filename of the json file to process" "--filename",
) "-f",
group.add_argument( action="append",
"--folder", "-fd", help="a folder of the json files to process" help="a filename of the json file to process",
) )
group.add_argument("--folder", "-fd", help="a folder of the json files to process")
args = parser.parse_args() args = parser.parse_args()
if args.filename: if args.filename:
filenames = args.filename filenames = args.filename
elif args.folder: elif args.folder:
@ -58,11 +60,14 @@ def main():
try: try:
modelname = get_model_name(filename) modelname = get_model_name(filename)
total_length = get_total_length(run_times_df, modelname) * 1e6 total_length = get_total_length(run_times_df, modelname) * 1e6
utilization, mm_conv_utilization = compute_utilization(filenames, total_length) utilization, mm_conv_utilization = compute_utilization(
filenames, total_length
)
print(f"{modelname}, {utilization}, {mm_conv_utilization}") print(f"{modelname}, {utilization}, {mm_conv_utilization}")
except BaseException: except BaseException:
logging.exception("%s, ERROR", filename) logging.exception("%s, ERROR", filename)
print(f"{filename}, ERROR") print(f"{filename}, ERROR")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,9 +1,10 @@
import torch import torch
import torch.fx as fx import torch.fx as fx
from functorch import make_fx from functorch import make_fx
from torch.profiler import profile, ProfilerActivity
from torch._functorch.compile_utils import fx_graph_cse from torch._functorch.compile_utils import fx_graph_cse
from torch.profiler import profile, ProfilerActivity
def profile_it(f, inp): def profile_it(f, inp):
for _ in range(5): for _ in range(5):
@ -20,6 +21,7 @@ def profile_it(f, inp):
cuda_time_total = cuda_time_total + e.cuda_time_total cuda_time_total = cuda_time_total + e.cuda_time_total
return cuda_time_total / itr return cuda_time_total / itr
def profile_function(name, f, inp): def profile_function(name, f, inp):
fx_g = make_fx(f)(inp) fx_g = make_fx(f)(inp)
@ -34,17 +36,23 @@ def profile_function(name, f, inp):
avg_cuda_time_g = profile_it(new_g, inp) avg_cuda_time_g = profile_it(new_g, inp)
num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes) num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes)
print(f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}") print(
f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}"
)
g_gpu = torch.Generator(device='cuda')
g_gpu = torch.Generator(device="cuda")
g_gpu.manual_seed(2147483647) g_gpu.manual_seed(2147483647)
inp = torch.randn(2**20, device='cuda', generator=g_gpu) inp = torch.randn(2**20, device="cuda", generator=g_gpu)
def f1(x): def f1(x):
return x.cos().cos() return x.cos().cos()
profile_function("f1", f1, inp) profile_function("f1", f1, inp)
def fsum(x): def fsum(x):
a = x.sum() a = x.sum()
b = x.sum() b = x.sum()
@ -52,22 +60,29 @@ def fsum(x):
d = x.sum() d = x.sum()
return a + b + c + d return a + b + c + d
profile_function("fsum", fsum, inp) profile_function("fsum", fsum, inp)
def fconcat(x): def fconcat(x):
a = torch.cat((x, x)) a = torch.cat((x, x))
b = torch.cat((x, x)) b = torch.cat((x, x))
return a + b return a + b
profile_function("fconcat", fconcat, inp) profile_function("fconcat", fconcat, inp)
def fsum2(x): def fsum2(x):
a = x.sum() a = x.sum()
for _ in range(30): for _ in range(30):
a = a + x.sum() a = a + x.sum()
return a return a
profile_function("fsum2", fsum2, inp) profile_function("fsum2", fsum2, inp)
def fsummulti(x): def fsummulti(x):
a = 0 a = 0
for _ in range(3): for _ in range(3):
@ -75,8 +90,10 @@ def fsummulti(x):
a = a * x.sum() a = a * x.sum()
return a return a
profile_function("fsummulti", fsummulti, inp) profile_function("fsummulti", fsummulti, inp)
def fsummulti2(x): def fsummulti2(x):
a = 0 a = 0
for _ in range(30): for _ in range(30):
@ -84,20 +101,25 @@ def fsummulti2(x):
a = a * x.sum() a = a * x.sum()
return a return a
profile_function("fsummulti2", fsummulti2, inp) profile_function("fsummulti2", fsummulti2, inp)
def fcos(x): def fcos(x):
a = 0 a = 0
for _ in range(3): for _ in range(3):
a = a + x.cos() a = a + x.cos()
return a return a
profile_function("fcos", fcos, inp) profile_function("fcos", fcos, inp)
def fcos2(x): def fcos2(x):
a = 0 a = 0
for _ in range(30): for _ in range(30):
a = a + x.cos() a = a + x.cos()
return a return a
profile_function("fcos2", fcos2, inp) profile_function("fcos2", fcos2, inp)

View File

@ -1,7 +1,8 @@
import timeit
from functools import partial from functools import partial
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import timeit
import torch import torch
from functorch.compile import pointwise_operator from functorch.compile import pointwise_operator

View File

@ -1,14 +1,14 @@
import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision.models as models import torchvision.models as models
from opacus.utils.module_modification import convert_batchnorm_modules
import time
from functorch import vmap, grad from functorch import grad, make_functional, vmap
from functorch import make_functional
from opacus import PrivacyEngine from opacus import PrivacyEngine
from opacus.utils.module_modification import convert_batchnorm_modules
device = 'cuda' device = "cuda"
batch_size = 128 batch_size = 128
torch.manual_seed(0) torch.manual_seed(0)
@ -20,6 +20,7 @@ images = torch.randn(batch_size, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (batch_size,), device=device) targets = torch.randint(0, 10, (batch_size,), device=device)
func_model, weights = make_functional(model_functorch) func_model, weights = make_functional(model_functorch)
def compute_loss(weights, image, target): def compute_loss(weights, image, target):
images = image.unsqueeze(0) images = image.unsqueeze(0)
targets = target.unsqueeze(0) targets = target.unsqueeze(0)
@ -27,11 +28,11 @@ def compute_loss(weights, image, target):
loss = criterion(output, targets) loss = criterion(output, targets)
return loss return loss
def functorch_per_sample_grad(): def functorch_per_sample_grad():
compute_grad = grad(compute_loss) compute_grad = grad(compute_loss)
compute_per_sample_grad = vmap(compute_grad, (None, 0, 0)) compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))
start = time.time() start = time.time()
result = compute_per_sample_grad(weights, images, targets) result = compute_per_sample_grad(weights, images, targets)
torch.cuda.synchronize() torch.cuda.synchronize()
@ -39,6 +40,7 @@ def functorch_per_sample_grad():
return result, end - start # end - start in seconds return result, end - start # end - start in seconds
torch.manual_seed(0) torch.manual_seed(0)
model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10)) model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_opacus = model_opacus.to(device) model_opacus = model_opacus.to(device)
@ -54,6 +56,7 @@ privacy_engine = PrivacyEngine(
max_grad_norm=10000.0, max_grad_norm=10000.0,
) )
def opacus_per_sample_grad(): def opacus_per_sample_grad():
start = time.time() start = time.time()
output = model_opacus(images) output = model_opacus(images)
@ -63,7 +66,7 @@ def opacus_per_sample_grad():
end = time.time() end = time.time()
expected = [p.grad_sample for p in model_opacus.parameters()] expected = [p.grad_sample for p in model_opacus.parameters()]
for p in model_opacus.parameters(): for p in model_opacus.parameters():
delattr(p, 'grad_sample') delattr(p, "grad_sample")
p.grad = None p.grad = None
return expected, end - start return expected, end - start

View File

@ -1,14 +1,16 @@
import sys
import time
import torch
import inspect import inspect
import itertools import itertools
import sys
import time
import torch
from functorch import pointwise_operator from functorch import pointwise_operator
torch.set_num_threads(1) torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False) torch._C._debug_set_fusion_group_inlining(False)
def rand(*shape): def rand(*shape):
return torch.rand(*shape).mul(16).add(1) return torch.rand(*shape).mul(16).add(1)
@ -19,105 +21,139 @@ def rand(*shape):
def scalar(): def scalar():
return (rand(1), rand(1)) return (rand(1), rand(1))
def small(): def small():
return (rand(32), rand(32)) return (rand(32), rand(32))
def small_2d(): def small_2d():
return (rand(1, 32), rand(1, 32)) return (rand(1, 32), rand(1, 32))
def small_broadcast(): def small_broadcast():
return (rand(4, 32), rand(32)) return (rand(4, 32), rand(32))
def medium(): def medium():
return (rand(32, 12, 64, 64), rand(32, 12, 64, 64)) return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))
def medium_sliced(): def medium_sliced():
return (rand(32, 12, 64, 64)[..., ::2], return (rand(32, 12, 64, 64)[..., ::2], rand(32, 12, 64, 64)[..., ::2])
rand(32, 12, 64, 64)[..., ::2])
def medium_transpose(): def medium_transpose():
return (rand(32, 12, 64, 64).transpose(-1, -2), return (
rand(32, 12, 64, 64).transpose(-1, -2)) rand(32, 12, 64, 64).transpose(-1, -2),
rand(32, 12, 64, 64).transpose(-1, -2),
)
def medium2(): def medium2():
return (rand(32, 3, 224, 224), rand(32, 3, 224, 224)) return (rand(32, 3, 224, 224), rand(32, 3, 224, 224))
def medium3d(): def medium3d():
return (rand(16, 32, 64), rand(16, 32, 64)) return (rand(16, 32, 64), rand(16, 32, 64))
def medium_channels_last(): def medium_channels_last():
return (rand(32, 3, 224, 224).to(memory_format=torch.channels_last), return (
rand(32, 3, 224, 224).to(memory_format=torch.channels_last)) rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
)
def medium_broadcast(): def medium_broadcast():
return (rand(32, 12, 64, 64), rand(64)) return (rand(32, 12, 64, 64), rand(64))
def medium_broadcast_channels_last(): def medium_broadcast_channels_last():
return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last), return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last), rand(3, 1, 1))
rand(3, 1, 1))
def large(): def large():
return (rand(8192, 8192), rand(8192, 8192)) return (rand(8192, 8192), rand(8192, 8192))
def large_transpose(): def large_transpose():
return (rand(8192, 8192).transpose(0, 1), return (rand(8192, 8192).transpose(0, 1), rand(8192, 8192).transpose(0, 1))
rand(8192, 8192).transpose(0, 1))
def large_channels_last(): def large_channels_last():
return (rand(32, 32, 256, 256).to(memory_format=torch.channels_last), return (
rand(32, 32, 256, 256).to(memory_format=torch.channels_last)) rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
)
def pathological_broadcast(): def pathological_broadcast():
return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2)) return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Operator test cases # Operator test cases
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
def add(a, b): def add(a, b):
return a + b return a + b
def sub(a, b): def sub(a, b):
return a - b return a - b
def mul(a, b): def mul(a, b):
return a * b return a * b
def div(a, b): def div(a, b):
return a / b return a / b
def relu(a): def relu(a):
return a.relu() return a.relu()
def sigmoid(a): def sigmoid(a):
return a.sigmoid() return a.sigmoid()
def tanh(a): def tanh(a):
return a.tanh() return a.tanh()
def log(a): def log(a):
return a.log() return a.log()
def exp(a): def exp(a):
return a.exp() return a.exp()
def square(a): def square(a):
return a**2 return a**2
def fma(a, b): def fma(a, b):
return a * b + b return a * b + b
def hardswish(a): def hardswish(a):
return a * (a + 3.0).clamp(0.0, 6.0) / 6.0 return a * (a + 3.0).clamp(0.0, 6.0) / 6.0
def native_hardswish(a): def native_hardswish(a):
return torch._C._nn.hardswish(a) return torch._C._nn.hardswish(a)
def softplus(a): def softplus(a):
return (a * 1.0).exp().log1p() / 1.0 return (a * 1.0).exp().log1p() / 1.0
def mish(a): def mish(a):
return a * ((a * 1.0).exp().log1p() / 1.0).tanh() return a * ((a * 1.0).exp().log1p() / 1.0).tanh()
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Helpers # Helpers
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
@ -128,6 +164,7 @@ def time_cpu(fn, args, iters):
e = time.perf_counter() e = time.perf_counter()
return e - s return e - s
def time_cuda(fn, args, iters): def time_cuda(fn, args, iters):
start = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
@ -138,19 +175,23 @@ def time_cuda(fn, args, iters):
torch.cuda.synchronize() torch.cuda.synchronize()
return start.elapsed_time(end) / 1e3 return start.elapsed_time(end) / 1e3
def benchmark_with_timer(fn, args, timer): def benchmark_with_timer(fn, args, timer):
timer(fn, args, 3) timer(fn, args, 3)
calibration = timer(fn, args, 1) calibration = timer(fn, args, 1)
iters = int(1.0 / calibration) iters = int(1.0 / calibration)
return timer(fn, args, iters) / iters return timer(fn, args, iters) / iters
def benchmark(fn, args): def benchmark(fn, args):
timer = time_cpu if args[0].device.type == "cpu" else time_cuda timer = time_cpu if args[0].device.type == "cpu" else time_cuda
return benchmark_with_timer(fn, args, timer) return benchmark_with_timer(fn, args, timer)
def micros(s): def micros(s):
return f"{s * 1e6:.1f}" return f"{s * 1e6:.1f}"
shapes = [ shapes = [
scalar, scalar,
small, small,
@ -211,7 +252,17 @@ for shape, operator in itertools.product(shapes, operators):
args = shape()[:nargs] args = shape()[:nargs]
result = benchmark(operator, args) result = benchmark(operator, args)
print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)])) print(
",".join(
[
"eager",
args[0].device.type,
operator.__name__,
shape.__name__,
micros(result),
]
)
)
try: try:
if shape == medium_transpose: if shape == medium_transpose:
raise RuntimeError("pointwise_operator hangs on medium_transpose") raise RuntimeError("pointwise_operator hangs on medium_transpose")
@ -219,11 +270,41 @@ for shape, operator in itertools.product(shapes, operators):
raise RuntimeError("pointwise_operator fails on medium_transpose") raise RuntimeError("pointwise_operator fails on medium_transpose")
pw_op = pointwise_operator(operator) pw_op = pointwise_operator(operator)
result = benchmark(pw_op, args) result = benchmark(pw_op, args)
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)])) print(
",".join(
[
"pointwise",
args[0].device.type,
operator.__name__,
shape.__name__,
micros(result),
]
)
)
except Exception: except Exception:
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))])) print(
",".join(
[
"pointwise",
args[0].device.type,
operator.__name__,
shape.__name__,
micros(float("nan")),
]
)
)
ts_op = torch.jit.script(operator) ts_op = torch.jit.script(operator)
result = benchmark(ts_op, args) result = benchmark(ts_op, args)
print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)])) print(
",".join(
[
"fuser",
args[0].device.type,
operator.__name__,
shape.__name__,
micros(result),
]
)
)
sys.stdout.flush() sys.stdout.flush()

View File

@ -1,11 +1,13 @@
import pandas
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas
df = pandas.read_csv("perf.csv") df = pandas.read_csv("perf.csv")
ops = pandas.unique(df["operator"]) ops = pandas.unique(df["operator"])
nops = len(ops) nops = len(ops)
pivot_op_shape = df.pivot_table(values="time", index=["operator", "shape"], columns=["fuser"]) pivot_op_shape = df.pivot_table(
values="time", index=["operator", "shape"], columns=["fuser"]
)
pivot_speedups = (pivot_op_shape.T / pivot_op_shape["eager"]).T pivot_speedups = (pivot_op_shape.T / pivot_op_shape["eager"]).T
plt.rcParams["figure.figsize"] = (20, 100) plt.rcParams["figure.figsize"] = (20, 100)

View File

@ -1,31 +1,31 @@
from torch._functorch.python_key import pythonkey_decompose from torch._functorch import config
from torch._functorch.fx_minifier import minifier
from torch._functorch.aot_autograd import ( from torch._functorch.aot_autograd import (
aot_function, aot_function,
aot_module, aot_module,
aot_module_simplified,
compiled_function, compiled_function,
compiled_module, compiled_module,
aot_module_simplified,
get_graph_being_compiled,
get_aot_graph_name,
get_aot_compilation_context, get_aot_compilation_context,
get_aot_graph_name,
get_graph_being_compiled,
make_boxed_compiler,
make_boxed_func, make_boxed_func,
make_boxed_compiler
) )
from torch._functorch.compilers import ( from torch._functorch.compilers import (
ts_compile,
draw_graph_compile,
nop,
nnc_jit,
memory_efficient_fusion,
debug_compile, debug_compile,
default_decompositions,
draw_graph_compile,
memory_efficient_fusion,
nnc_jit,
nop,
print_compile, print_compile,
default_decompositions ts_compile,
) )
from torch._functorch.fx_minifier import minifier
from torch._functorch.partitioners import ( from torch._functorch.partitioners import (
min_cut_rematerialization_partition,
default_partition, default_partition,
draw_graph, draw_graph,
draw_joint_graph, draw_joint_graph,
min_cut_rematerialization_partition,
) )
from torch._functorch import config from torch._functorch.python_key import pythonkey_decompose

View File

@ -1,20 +1,26 @@
import torch
from typing import Union, Sequence
import inspect
import dis import dis
from .tree_map import tree_flatten, tree_map import inspect
from .wrap_type import wrap_type from typing import Sequence, Union
import torch
import functorch._C import functorch._C
from functorch._C import dim as _C from functorch._C import dim as _C
from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type
_C._patch_tensor_class() _C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
class DimensionMismatchError(Exception): class DimensionMismatchError(Exception):
pass pass
class DimensionBindError(Exception): class DimensionBindError(Exception):
pass pass
from . import op_properties from . import op_properties
# use dict to avoid writing C++ bindings for set # use dict to avoid writing C++ bindings for set
@ -24,11 +30,11 @@ use_c = True
if not use_c: if not use_c:
from . import reference from . import reference
class _Tensor: class _Tensor:
# fast path around slow wrapping/unwrapping logic for simply queries used # fast path around slow wrapping/unwrapping logic for simply queries used
# by the implementation... # by the implementation...
@property @property
def dims(self): def dims(self):
return tuple(d for d in self._levels if isinstance(d, Dim)) return tuple(d for d in self._levels if isinstance(d, Dim))
@ -47,11 +53,12 @@ class _Tensor:
def __repr__(self): def __repr__(self):
tensor, levels, ndim = self._tensor, self._levels, self.ndim tensor, levels, ndim = self._tensor, self._levels, self.ndim
return f'{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}' return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
TensorLike = (_Tensor, torch.Tensor) TensorLike = (_Tensor, torch.Tensor)
class Dim(_C.Dim, _Tensor): class Dim(_C.Dim, _Tensor):
# note that _C.Dim comes before tensor because we want the Dim API for things like size to take precendence. # note that _C.Dim comes before tensor because we want the Dim API for things like size to take precendence.
# Tensor defines format, but we want to print Dims with special formatting # Tensor defines format, but we want to print Dims with special formatting
@ -69,6 +76,7 @@ def cat(tensors, dim, new_dim):
n = dims() n = dims()
return stack(tensors, n, dim).index([n, dim], new_dim) return stack(tensors, n, dim).index([n, dim], new_dim)
if use_c: if use_c:
_wrap = _C._wrap _wrap = _C._wrap
@ -107,41 +115,41 @@ if use_c:
else: else:
_Tensor.order = reference.positional _Tensor.order = reference.positional
_def('mean') _def("mean")
_def('sum') _def("sum")
_def('all') _def("all")
_def('amax') _def("amax")
_def('amin') _def("amin")
_def('aminmax') _def("aminmax")
_def('any') _def("any")
_def('count_nonzero') _def("count_nonzero")
_def('logsumexp') _def("logsumexp")
_def('nanmean') _def("nanmean")
_def('nansum') _def("nansum")
_def('prod') _def("prod")
_def('std', keepdim_offset=2) _def("std", keepdim_offset=2)
_def('var', keepdim_offset=2) _def("var", keepdim_offset=2)
_def('max', single_dim=True) _def("max", single_dim=True)
_def('min', single_dim=True) _def("min", single_dim=True)
_def('argmax', single_dim=True) _def("argmax", single_dim=True)
_def('argmin', single_dim=True) _def("argmin", single_dim=True)
_def('kthvalue', single_dim=True) _def("kthvalue", single_dim=True)
_def('median', single_dim=True) _def("median", single_dim=True)
_def('nanmedian', single_dim=True) _def("nanmedian", single_dim=True)
_def('mode', single_dim=True) _def("mode", single_dim=True)
_def('sort', reduce=False) _def("sort", reduce=False)
_def('argsort', reduce=False) _def("argsort", reduce=False)
_def('unbind', single_dim=True) _def("unbind", single_dim=True)
_def('chunk', dim_offset=1, reduce=False) _def("chunk", dim_offset=1, reduce=False)
_def('cummax', single_dim=True, reduce=False) _def("cummax", single_dim=True, reduce=False)
_def('cummin', single_dim=True, reduce=False) _def("cummin", single_dim=True, reduce=False)
_def('cumprod', single_dim=True, reduce=False) _def("cumprod", single_dim=True, reduce=False)
_def('cumprod_', single_dim=True, reduce=False) _def("cumprod_", single_dim=True, reduce=False)
_def('cumsum', single_dim=True, reduce=False) _def("cumsum", single_dim=True, reduce=False)
_def('cumsum_', single_dim=True, reduce=False) _def("cumsum_", single_dim=True, reduce=False)
_def('logcumsumexp', single_dim=True, reduce=False) _def("logcumsumexp", single_dim=True, reduce=False)
_def('renorm', dim_offset=1, single_dim=True, reduce=False) _def("renorm", dim_offset=1, single_dim=True, reduce=False)
_def('softmax', single_dim=True, reduce=False) _def("softmax", single_dim=True, reduce=False)
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False) softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
# stuff to handle in the future, because they require special # stuff to handle in the future, because they require special

View File

@ -3,14 +3,13 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from torch._C._functorch import (
_vmap_add_layers,
_vmap_remove_layers,
)
from contextlib import contextmanager from contextlib import contextmanager
from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
_enabled = False _enabled = False
@contextmanager @contextmanager
def _enable_layers(dims): def _enable_layers(dims):
global _enabled global _enabled

View File

@ -4,9 +4,11 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
from . import _Tensor, Tensor from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple from .reference import _dims, _enable_layers, llist, ltuple
class DelayedMulTensor(_Tensor): class DelayedMulTensor(_Tensor):
def __init__(self, lhs, rhs): def __init__(self, lhs, rhs):
self._lhs, self._rhs = lhs, rhs self._lhs, self._rhs = lhs, rhs
@ -37,7 +39,9 @@ class DelayedMulTensor(_Tensor):
@property @property
def _tensor(self): def _tensor(self):
if self._tensor_data is None: if self._tensor_data is None:
self._tensor_data = Tensor.from_batched(self._batchtensor, self._has_device)._tensor self._tensor_data = Tensor.from_batched(
self._batchtensor, self._has_device
)._tensor
return self._tensor_data return self._tensor_data
@property @property
@ -48,20 +52,26 @@ class DelayedMulTensor(_Tensor):
def dims(self): def dims(self):
return ltuple(super().dims) return ltuple(super().dims)
def sum(self, dim): def sum(self, dim):
dims = _dims(dim, 0, False, False) dims = _dims(dim, 0, False, False)
n = ord('a') n = ord("a")
all_levels = self._levels all_levels = self._levels
def to_char(d): def to_char(d):
return chr(n + all_levels.index(d)) return chr(n + all_levels.index(d))
plhs, levelslhs = self._lhs._tensor, self._lhs._levels plhs, levelslhs = self._lhs._tensor, self._lhs._levels
prhs, levelsrhs = self._rhs._tensor, self._rhs._levels prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
new_dims = tuple(d for d in self.dims if d not in dims) new_dims = tuple(d for d in self.dims if d not in dims)
new_levels = [l for l in self._levels if l not in dims] new_levels = [l for l in self._levels if l not in dims]
fmt = ''.join([*(to_char(d) for d in levelslhs), ',', fmt = "".join(
*(to_char(d) for d in levelsrhs), '->', [
*(to_char(d) for d in new_levels)]) *(to_char(d) for d in levelslhs),
",",
*(to_char(d) for d in levelsrhs),
"->",
*(to_char(d) for d in new_levels),
]
)
result_data = torch.einsum(fmt, (plhs, prhs)) result_data = torch.einsum(fmt, (plhs, prhs))
return Tensor.from_positional(result_data, new_levels, True) return Tensor.from_positional(result_data, new_levels, True)

View File

@ -4,11 +4,14 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
_vmap_levels = [] _vmap_levels = []
@dataclass @dataclass
class LevelInfo: class LevelInfo:
level: int level: int
alive: bool = True alive: bool = True
class Dim: class Dim:
def __init__(self, name: str, size: Union[None, int] = None): def __init__(self, name: str, size: Union[None, int] = None):
self.name = name self.name = name
@ -20,7 +23,9 @@ class Dim:
def __del__(self): def __del__(self):
if self._vmap_level is not None: if self._vmap_level is not None:
_vmap_active_levels[self._vmap_stack].alive = False _vmap_active_levels[self._vmap_stack].alive = False
while not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level: while (
not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level
):
_vmap_decrement_nesting() _vmap_decrement_nesting()
_vmap_levels.pop() _vmap_levels.pop()
@ -33,13 +38,14 @@ class Dim:
def size(self, size: int): def size(self, size: int):
if self._size is None: if self._size is None:
self._size = size self._size = size
self._vmap_level = _vmap_increment_nesting(size, 'same') self._vmap_level = _vmap_increment_nesting(size, "same")
self._vmap_stack = len(_vmap_levels) self._vmap_stack = len(_vmap_levels)
_vmap_levels.append(LevelInfo(self._vmap_level)) _vmap_levels.append(LevelInfo(self._vmap_level))
elif self._size != size: elif self._size != size:
raise DimensionBindError( raise DimensionBindError(
f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}") f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
)
@property @property
def is_bound(self): def is_bound(self):
@ -50,10 +56,13 @@ class Dim:
def extract_name(inst): def extract_name(inst):
assert inst.opname == 'STORE_FAST' or inst.opname == 'STORE_NAME' assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
return inst.argval return inst.argval
_cache = {} _cache = {}
def dims(lists=0): def dims(lists=0):
frame = inspect.currentframe() frame = inspect.currentframe()
assert frame is not None assert frame is not None
@ -66,17 +75,22 @@ def dims(lists=0):
instructions = list(dis.get_instructions(calling_frame.f_code)) instructions = list(dis.get_instructions(calling_frame.f_code))
unpack = instructions[first] unpack = instructions[first]
if unpack.opname == 'STORE_FAST' or unpack.opname == 'STORE_NAME': if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
# just a single dim, not a list # just a single dim, not a list
name = unpack.argval name = unpack.argval
ctor = Dim if lists == 0 else DimList ctor = Dim if lists == 0 else DimList
_cache[key] = lambda: ctor(name=name) _cache[key] = lambda: ctor(name=name)
else: else:
assert unpack.opname == 'UNPACK_SEQUENCE' assert unpack.opname == "UNPACK_SEQUENCE"
ndims = unpack.argval ndims = unpack.argval
names = tuple(extract_name(instructions[first + 1 + i]) for i in range(ndims)) names = tuple(
extract_name(instructions[first + 1 + i]) for i in range(ndims)
)
first_list = len(names) - lists first_list = len(names) - lists
_cache[key] = lambda: tuple(Dim(n) if i < first_list else DimList(name=n) for i, n in enumerate(names)) _cache[key] = lambda: tuple(
Dim(n) if i < first_list else DimList(name=n)
for i, n in enumerate(names)
)
return _cache[key]() return _cache[key]()
@ -87,6 +101,7 @@ def _dim_set(positional, arg):
else: else:
assert isinstance(a, int) assert isinstance(a, int)
return positional[a] return positional[a]
if arg is None: if arg is None:
return positional return positional
elif not isinstance(arg, (Dim, int)): elif not isinstance(arg, (Dim, int)):

View File

@ -3,25 +3,33 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
import os import os
import subprocess
import signal import signal
import subprocess
from contextlib import contextmanager
@contextmanager @contextmanager
def magic_trace(output='trace.fxt', magic_trace_cache='/tmp/magic-trace'): def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
pid = os.getpid() pid = os.getpid()
if not os.path.exists(magic_trace_cache): if not os.path.exists(magic_trace_cache):
print(f"Downloading magic_trace to: {magic_trace_cache}") print(f"Downloading magic_trace to: {magic_trace_cache}")
subprocess.run(['wget', '-O', magic_trace_cache, '-q', subprocess.run(
'https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace']) [
subprocess.run(['chmod', '+x', magic_trace_cache]) "wget",
args = [magic_trace_cache, 'attach', '-pid', str(pid), '-o', output] "-O",
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding='utf-8') magic_trace_cache,
"-q",
"https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace",
]
)
subprocess.run(["chmod", "+x", magic_trace_cache])
args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
while True: while True:
x = p.stderr.readline() x = p.stderr.readline()
print(x) print(x)
if 'Attached' in x: if "Attached" in x:
break break
try: try:
yield yield
@ -31,4 +39,4 @@ def magic_trace(output='trace.fxt', magic_trace_cache='/tmp/magic-trace'):
print(p.stderr.read()) print(p.stderr.read())
p.stderr.close() p.stderr.close()
if r != 0: if r != 0:
raise ValueError(f'magic_trace exited abnormally: {r}') raise ValueError(f"magic_trace exited abnormally: {r}")

View File

@ -4,29 +4,58 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
# pointwise operators can go through a faster pathway # pointwise operators can go through a faster pathway
tensor_magic_methods = [ tensor_magic_methods = ["add", ""]
'add',
''
]
pointwise_magic_methods_with_reverse = ( pointwise_magic_methods_with_reverse = (
'add', 'sub', 'mul', 'floordiv', 'div', 'truediv', 'mod', "add",
'pow', 'lshift', 'rshift', 'and', 'or', 'xor' "sub",
"mul",
"floordiv",
"div",
"truediv",
"mod",
"pow",
"lshift",
"rshift",
"and",
"or",
"xor",
) )
pointwise_magic_methods = ( pointwise_magic_methods = (
*(x for m in pointwise_magic_methods_with_reverse for x in (m, 'r' + m)), *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
'eq', 'gt', 'le', 'lt', 'ge', 'gt', 'ne', 'neg', 'pos', "eq",
'abs', 'invert', "gt",
'iadd', 'isub', 'imul', 'ifloordiv', 'idiv', "le",
'itruediv', 'imod', 'ipow', 'ilshift', 'irshift', 'iand', "lt",
'ior', 'ixor', "ge",
'int', 'long', 'float', 'complex', "gt",
"ne",
"neg",
"pos",
"abs",
"invert",
"iadd",
"isub",
"imul",
"ifloordiv",
"idiv",
"itruediv",
"imod",
"ipow",
"ilshift",
"irshift",
"iand",
"ior",
"ixor",
"int",
"long",
"float",
"complex",
) )
pointwise_methods = ( pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
*(f'__{m}__' for m in pointwise_magic_methods),
)
pointwise = ( pointwise = (
*(getattr(torch.Tensor, m) for m in pointwise_methods), *(getattr(torch.Tensor, m) for m in pointwise_methods),

View File

@ -6,23 +6,28 @@
# reference python implementations for C ops # reference python implementations for C ops
import torch import torch
from .tree_map import tree_flatten, tree_map
from .batch_tensor import _enable_layers
from . import op_properties
from functorch._C import dim as _C from functorch._C import dim as _C
from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map
DimList = _C.DimList DimList = _C.DimList
from functools import reduce
import operator import operator
from functools import reduce
# use dict to avoid writing C++ bindings for set # use dict to avoid writing C++ bindings for set
pointwise = set(op_properties.pointwise) pointwise = set(op_properties.pointwise)
def prod(x): def prod(x):
return reduce(operator.mul, x, 1) return reduce(operator.mul, x, 1)
def _wrap_dim(d, N, keepdim): def _wrap_dim(d, N, keepdim):
from . import Dim from . import Dim
if isinstance(d, Dim): if isinstance(d, Dim):
assert not keepdim, "cannot preserve first-class dimensions with keepdim=True" assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
return d return d
@ -31,40 +36,52 @@ def _wrap_dim(d, N, keepdim):
else: else:
return d return d
def _dims(d, N, keepdim, single_dim): def _dims(d, N, keepdim, single_dim):
from . import Dim from . import Dim
if isinstance(d, (Dim, int)): if isinstance(d, (Dim, int)):
return ltuple((_wrap_dim(d, N, keepdim),)) return ltuple((_wrap_dim(d, N, keepdim),))
assert not single_dim, f"expected a single dimension or int but found: {d}" assert not single_dim, f"expected a single dimension or int but found: {d}"
return ltuple(_wrap_dim(x, N, keepdim) for x in d) return ltuple(_wrap_dim(x, N, keepdim) for x in d)
def _bind_dims_to_size(lhs_size, rhs, lhs_debug): def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
from . import DimensionMismatchError from . import DimensionMismatchError
not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound) not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
if len(not_bound) == 1: if len(not_bound) == 1:
idx, d = not_bound[0] idx, d = not_bound[0]
rhs_so_far = prod(r.size for r in rhs if r.is_bound) rhs_so_far = prod(r.size for r in rhs if r.is_bound)
if lhs_size % rhs_so_far != 0: if lhs_size % rhs_so_far != 0:
rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs) rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}") raise DimensionMismatchError(
f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
)
new_size = lhs_size // rhs_so_far new_size = lhs_size // rhs_so_far
d.size = new_size d.size = new_size
elif len(not_bound) > 1: elif len(not_bound) > 1:
rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs) rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
raise DimensionMismatchError(f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}") raise DimensionMismatchError(
f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
)
else: else:
rhs_size = prod(r.size for r in rhs) rhs_size = prod(r.size for r in rhs)
if lhs_size != rhs_size: if lhs_size != rhs_size:
raise DimensionMismatchError( raise DimensionMismatchError(
f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}") f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
)
def _tensor_levels(inp): def _tensor_levels(inp):
from . import _Tensor from . import _Tensor
if isinstance(inp, _Tensor): if isinstance(inp, _Tensor):
return inp._tensor, llist(inp._levels), inp._has_device return inp._tensor, llist(inp._levels), inp._has_device
else: else:
return inp, llist(range(-inp.ndim, 0)), True return inp, llist(range(-inp.ndim, 0)), True
def _match_levels(v, from_levels, to_levels): def _match_levels(v, from_levels, to_levels):
view = [] view = []
permute = [] permute = []
@ -90,6 +107,7 @@ def _match_levels(v, from_levels, to_levels):
# should not physically move if possible # should not physically move if possible
def _positional_no_permute(self, dim, expand_dim=False): def _positional_no_permute(self, dim, expand_dim=False):
from . import Tensor from . import Tensor
ptensor, levels = self._tensor, llist(self._levels) ptensor, levels = self._tensor, llist(self._levels)
try: try:
idx = levels.index(dim) idx = levels.index(dim)
@ -107,8 +125,10 @@ def _positional_no_permute(self, dim, expand_dim=False):
levels[idx] = -idx_batched - 1 levels[idx] = -idx_batched - 1
return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
def seq(a, b): def seq(a, b):
from . import Dim from . import Dim
if isinstance(a, Dim) != isinstance(b, Dim): if isinstance(a, Dim) != isinstance(b, Dim):
return False return False
if isinstance(a, Dim): if isinstance(a, Dim):
@ -116,6 +136,7 @@ def seq(a, b):
else: else:
return a == b return a == b
class isin: class isin:
def __contains__(self, item): def __contains__(self, item):
for x in self: for x in self:
@ -133,18 +154,27 @@ class isin:
class llist(isin, list): class llist(isin, list):
pass pass
class ltuple(isin, tuple): class ltuple(isin, tuple):
pass pass
empty_dict = {} empty_dict = {}
@classmethod @classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict): def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
from . import _Tensor, TensorLike, Tensor from . import _Tensor, Tensor, TensorLike
from .delayed_mul_tensor import DelayedMulTensor from .delayed_mul_tensor import DelayedMulTensor
if orig is torch.Tensor.__mul__: if orig is torch.Tensor.__mul__:
lhs, rhs = args lhs, rhs = args
if isinstance(lhs, _Tensor) and isinstance(rhs, _Tensor) and lhs.ndim == 0 and rhs.ndim == 0: if (
isinstance(lhs, _Tensor)
and isinstance(rhs, _Tensor)
and lhs.ndim == 0
and rhs.ndim == 0
):
return DelayedMulTensor(lhs, rhs) return DelayedMulTensor(lhs, rhs)
all_dims = llist() all_dims = llist()
flat_args, unflatten = tree_flatten((args, kwargs)) flat_args, unflatten = tree_flatten((args, kwargs))
@ -172,7 +202,11 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
for i, f in enumerate(flat_args): for i, f in enumerate(flat_args):
if isinstance(f, TensorLike): if isinstance(f, TensorLike):
ptensor, levels, _ = _tensor_levels(f) ptensor, levels, _ = _tensor_levels(f)
if isinstance(f, _Tensor) and not f._has_device and device_holding_tensor is not None: if (
isinstance(f, _Tensor)
and not f._has_device
and device_holding_tensor is not None
):
ptensor = ptensor.to(device=device_holding_tensor.device) ptensor = ptensor.to(device=device_holding_tensor.device)
flat_args[i] = ptensor flat_args[i] = ptensor
for l in levels: for l in levels:
@ -187,14 +221,19 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
def wrap(t): def wrap(t):
if isinstance(t, TensorLike): if isinstance(t, TensorLike):
return Tensor.from_positional(t, result_levels, device_holding_tensor is not None) return Tensor.from_positional(
t, result_levels, device_holding_tensor is not None
)
return t return t
return tree_map(wrap, result) return tree_map(wrap, result)
else: else:
def wrap(t): def wrap(t):
if isinstance(t, TensorLike): if isinstance(t, TensorLike):
return Tensor.from_batched(t, device_holding_tensor is not None) return Tensor.from_batched(t, device_holding_tensor is not None)
return t return t
with _enable_layers(all_dims): with _enable_layers(all_dims):
print(f"batch_tensor for {orig}") print(f"batch_tensor for {orig}")
args, kwargs = unflatten(unwrap(f) for f in flat_args) args, kwargs = unflatten(unwrap(f) for f in flat_args)
@ -202,8 +241,10 @@ def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
# print("END", orig) # print("END", orig)
return tree_map(wrap, result) return tree_map(wrap, result)
def positional(self, *dims): def positional(self, *dims):
from . import Dim, Tensor from . import Dim, Tensor
ptensor, levels = self._tensor, llist(self._levels) ptensor, levels = self._tensor, llist(self._levels)
flat_dims = llist() flat_dims = llist()
view = [] view = []
@ -231,7 +272,9 @@ def positional(self, *dims):
try: try:
idx = levels.index(d) idx = levels.index(d)
except ValueError as e: except ValueError as e:
raise DimensionBindError(f'tensor of dimensions {self.dims} does not contain dim {d}') from e raise DimensionBindError(
f"tensor of dimensions {self.dims} does not contain dim {d}"
) from e
p = permute[idx] p = permute[idx]
del levels[idx] del levels[idx]
del permute[idx] del permute[idx]
@ -248,12 +291,15 @@ def positional(self, *dims):
result = result.reshape(*view, *result.size()[len(flat_dims) :]) result = result.reshape(*view, *result.size()[len(flat_dims) :])
return result return result
def _contains_dim(input): def _contains_dim(input):
from . import Dim from . import Dim
for i in input: for i in input:
if isinstance(i, Dim): if isinstance(i, Dim):
return True return True
def expand(self, *sizes): def expand(self, *sizes):
if not _contains_dim(sizes): if not _contains_dim(sizes):
return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes)) return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
@ -265,27 +311,36 @@ def expand(self, *sizes):
_not_present = object() _not_present = object()
def _getarg(name, offset, args, kwargs, default): def _getarg(name, offset, args, kwargs, default):
if len(args) > offset: if len(args) > offset:
return args[offset] return args[offset]
return kwargs.get(name, default) return kwargs.get(name, default)
def _patcharg(name, offset, args, kwargs, value): def _patcharg(name, offset, args, kwargs, value):
if len(args) > offset: if len(args) > offset:
args[offset] = value args[offset] = value
else: else:
kwargs[name] = value kwargs[name] = value
def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False, reduce=True):
from . import TensorLike, Dim, Tensor def _wrap(
orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
from . import Dim, Tensor, TensorLike
def fn(self, *args, **kwargs): def fn(self, *args, **kwargs):
dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present) dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
if dim is _not_present or (single_dim and not isinstance(dim, Dim)): if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
with _enable_layers(self.dims): with _enable_layers(self.dims):
print(f"dim fallback batch_tensor for {orig}") print(f"dim fallback batch_tensor for {orig}")
return Tensor.from_batched(orig(self._batchtensor, *args, **kwargs), self._has_device) return Tensor.from_batched(
keepdim = _getarg('keepdim', keepdim_offset, args, kwargs, False) if reduce else False orig(self._batchtensor, *args, **kwargs), self._has_device
)
keepdim = (
_getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
)
t, levels = self._tensor, llist(self._levels) t, levels = self._tensor, llist(self._levels)
dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim) dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
dim_indices = tuple(levels.index(d) for d in dims) dim_indices = tuple(levels.index(d) for d in dims)
@ -295,7 +350,9 @@ def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False
new_levels = levels new_levels = levels
if len(dim_indices) == 1: if len(dim_indices) == 1:
dim_indices = dim_indices[0] # so that dims that really only take a single argument work... dim_indices = dim_indices[
0
] # so that dims that really only take a single argument work...
args = list(args) args = list(args)
_patcharg(dim_name, dim_offset, args, kwargs, dim_indices) _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)
@ -303,21 +360,27 @@ def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False
if isinstance(t, TensorLike): if isinstance(t, TensorLike):
return Tensor.from_positional(t, new_levels, self._has_device) return Tensor.from_positional(t, new_levels, self._has_device)
return t return t
with _enable_layers(new_levels): with _enable_layers(new_levels):
print(f"dim used batch_tensor for {orig}") print(f"dim used batch_tensor for {orig}")
r = orig(t, *args, **kwargs) r = orig(t, *args, **kwargs)
return tree_map(wrap, r) return tree_map(wrap, r)
return fn return fn
def _def(name, *args, **kwargs): def _def(name, *args, **kwargs):
from . import _Tensor from . import _Tensor
orig = getattr(torch.Tensor, name) orig = getattr(torch.Tensor, name)
setattr(_Tensor, name, _wrap(orig, *args, **kwargs)) setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
no_slice = slice(None) no_slice = slice(None)
_orig_getitem = torch.Tensor.__getitem__ _orig_getitem = torch.Tensor.__getitem__
class dim_tracker: class dim_tracker:
def __init__(self): def __init__(self):
self.dims = llist() self.dims = llist()
@ -331,8 +394,10 @@ class dim_tracker:
def __getitem__(self, d): def __getitem__(self, d):
return self.count[self.dims.index(d)] return self.count[self.dims.index(d)]
def t__getitem__(self, input): def t__getitem__(self, input):
from . import Dim, DimensionBindError, _Tensor, TensorLike, DimList, Tensor from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike
# * bail to original example if we have a single non-Dim tensor, or a non-tensor # * bail to original example if we have a single non-Dim tensor, or a non-tensor
# * locate ... or an unbound tensor list, and determine its size, bind dim list # * locate ... or an unbound tensor list, and determine its size, bind dim list
# (remember that None does not count to the total dim count) # (remember that None does not count to the total dim count)
@ -345,10 +410,13 @@ def t__getitem__(self, input):
# this handles bool indexing handling, as well as some other simple cases. # this handles bool indexing handling, as well as some other simple cases.
is_simple = (not isinstance(input, Dim) and is_simple = (
not isinstance(input, (tuple, list)) and not isinstance(input, Dim)
and not isinstance(input, (tuple, list))
and
# WAR for functorch bug where zero time tensors in getitem are not handled correctly. # WAR for functorch bug where zero time tensors in getitem are not handled correctly.
not (isinstance(input, TensorLike) and input.ndim == 0)) not (isinstance(input, TensorLike) and input.ndim == 0)
)
if is_simple: if is_simple:
if isinstance(self, _Tensor): if isinstance(self, _Tensor):
@ -368,8 +436,10 @@ def t__getitem__(self, input):
for i, s in enumerate(input): for i, s in enumerate(input):
if s is ... or isinstance(s, DimList) and not s.is_bound: if s is ... or isinstance(s, DimList) and not s.is_bound:
if expanding_object is not None: if expanding_object is not None:
msg = 'at most one ... or unbound dimension list can exist in indexing list but' \ msg = (
f' found 2 at offsets {i} and {expanding_object}' "at most one ... or unbound dimension list can exist in indexing list but"
f" found 2 at offsets {i} and {expanding_object}"
)
raise DimensionBindError(msg) raise DimensionBindError(msg)
expanding_object = i expanding_object = i
@ -381,12 +451,16 @@ def t__getitem__(self, input):
ndim = self.ndim ndim = self.ndim
if dims_indexed > ndim: if dims_indexed > ndim:
raise IndexError(f'at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions.') raise IndexError(
f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
)
if expanding_object is not None: if expanding_object is not None:
expanding_ndims = ndim - dims_indexed expanding_ndims = ndim - dims_indexed
obj = input[expanding_object] obj = input[expanding_object]
if obj is ...: if obj is ...:
input[expanding_object:expanding_object + 1] = [no_slice] * expanding_ndims input[expanding_object : expanding_object + 1] = [
no_slice
] * expanding_ndims
else: else:
obj.bind_len(expanding_ndims) obj.bind_len(expanding_ndims)
# flatten the dimslists into the indexing # flatten the dimslists into the indexing
@ -420,7 +494,7 @@ def t__getitem__(self, input):
elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim): elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
for d in idx: for d in idx:
dims_seen.record(idx) dims_seen.record(idx)
_bind_dims_to_size(sz, idx, f'offset {i}') _bind_dims_to_size(sz, idx, f"offset {i}")
view_sizes.extend(d.size for d in idx) view_sizes.extend(d.size for d in idx)
requires_view = True requires_view = True
dim_packs.append(i) dim_packs.append(i)
@ -499,6 +573,7 @@ def t__getitem__(self, input):
return Tensor.from_positional(result, result_levels, has_device) return Tensor.from_positional(result, result_levels, has_device)
# XXX - dim is optional and can be the outer-most dimension... # XXX - dim is optional and can be the outer-most dimension...
def stack(tensors, new_dim, dim=0, out=None): def stack(tensors, new_dim, dim=0, out=None):
if isinstance(dim, int): if isinstance(dim, int):
@ -517,12 +592,20 @@ def stack(tensors, new_dim, dim=0, out=None):
pr = torch.stack(ptensors, index, out=out) pr = torch.stack(ptensors, index, out=out)
return pr.index((index, index + 1), (new_dim, dim)) return pr.index((index, index + 1), (new_dim, dim))
_orig_split = torch.Tensor.split _orig_split = torch.Tensor.split
def split(self, split_size_or_sections, dim=0): def split(self, split_size_or_sections, dim=0):
from . import Dim, _Tensor from . import _Tensor, Dim
if isinstance(split_size_or_sections, int) or any(isinstance(t, int) for t in split_size_or_sections):
if isinstance(split_size_or_sections, int) or any(
isinstance(t, int) for t in split_size_or_sections
):
if isinstance(dim, Dim): if isinstance(dim, Dim):
raise ValueError('when dim is specified as a Dim object, split sizes must also be dimensions.') raise ValueError(
"when dim is specified as a Dim object, split sizes must also be dimensions."
)
return _orig_split(self, split_size_or_sections, dim=dim) return _orig_split(self, split_size_or_sections, dim=dim)
if isinstance(dim, Dim): if isinstance(dim, Dim):
@ -542,8 +625,9 @@ def split(self, split_size_or_sections, dim=0):
unbound.append(i) unbound.append(i)
if unbound: if unbound:
assert total_bound_size <= size, \ assert (
f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" total_bound_size <= size
), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
remaining_size = size - total_bound_size remaining_size = size - total_bound_size
chunk_size = -(-remaining_size // len(unbound)) chunk_size = -(-remaining_size // len(unbound))
for u in unbound: for u in unbound:
@ -552,6 +636,10 @@ def split(self, split_size_or_sections, dim=0):
sizes[u] = sz sizes[u] = sz
remaining_size -= sz remaining_size -= sz
else: else:
assert total_bound_size == size, \ assert (
f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" total_bound_size == size
return tuple(t.index(dim, d) for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))) ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
return tuple(
t.index(dim, d)
for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
)

View File

@ -5,8 +5,10 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from functorch._C import dim from functorch._C import dim
tree_flatten = dim.tree_flatten tree_flatten = dim.tree_flatten
def tree_map(fn, tree): def tree_map(fn, tree):
vs, unflatten = tree_flatten(tree) vs, unflatten = tree_flatten(tree)
return unflatten(fn(v) for v in vs) return unflatten(fn(v) for v in vs)

View File

@ -4,22 +4,35 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from types import FunctionType, BuiltinMethodType, MethodDescriptorType, WrapperDescriptorType, GetSetDescriptorType from types import (
BuiltinMethodType,
FunctionType,
GetSetDescriptorType,
MethodDescriptorType,
WrapperDescriptorType,
)
from functorch._C import dim as _C from functorch._C import dim as _C
_wrap_method = _C._wrap_method _wrap_method = _C._wrap_method
FUNC_TYPES = (FunctionType, MethodDescriptorType, BuiltinMethodType, WrapperDescriptorType) FUNC_TYPES = (
FunctionType,
MethodDescriptorType,
BuiltinMethodType,
WrapperDescriptorType,
)
PROPERTY_TYPES = (GetSetDescriptorType, property) PROPERTY_TYPES = (GetSetDescriptorType, property)
def _py_wrap_method(orig, __torch_function__): def _py_wrap_method(orig, __torch_function__):
def impl(*args, **kwargs): def impl(*args, **kwargs):
return __torch_function__(orig, None, args, kwargs) return __torch_function__(orig, None, args, kwargs)
return impl return impl
def wrap_type(use_c, to_patch, pattern, __torch_function__): def wrap_type(use_c, to_patch, pattern, __torch_function__):
if use_c: if use_c:
wrap_method = _wrap_method wrap_method = _wrap_method
else: else:
@ -29,18 +42,27 @@ def wrap_type(use_c, to_patch, pattern, __torch_function__):
for t in reversed(pattern.mro()[:-1]): # skip object for t in reversed(pattern.mro()[:-1]): # skip object
all.update(t.__dict__) all.update(t.__dict__)
def wrap_attr(orig): def wrap_attr(orig):
return property(wrap_method(orig.__get__, __torch_function__)) return property(wrap_method(orig.__get__, __torch_function__))
for name, obj in all.items(): for name, obj in all.items():
if name in ('__dict__', '__new__', '__init__', '__repr__', '__weakref__', '__doc__', '__module__', '__dir__'): if name in (
"__dict__",
"__new__",
"__init__",
"__repr__",
"__weakref__",
"__doc__",
"__module__",
"__dir__",
):
continue continue
# skip things that have been overloaded # skip things that have been overloaded
# things that come from object like `__eq__` still need to be patched, however. # things that come from object like `__eq__` still need to be patched, however.
if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(object, name, None): if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
object, name, None
):
continue continue
if isinstance(obj, FUNC_TYPES): if isinstance(obj, FUNC_TYPES):

View File

@ -14,18 +14,21 @@
# documentation root, use os.path.abspath to make it absolute, like shown here. # documentation root, use os.path.abspath to make it absolute, like shown here.
# #
import os import os
import functorch
# import sys # import sys
# source code directory, relative to this file, for sphinx-autobuild # source code directory, relative to this file, for sphinx-autobuild
# sys.path.insert(0, os.path.abspath('../..')) # sys.path.insert(0, os.path.abspath('../..'))
import torch import torch
import functorch
RELEASE = os.environ.get('RELEASE', False) RELEASE = os.environ.get("RELEASE", False)
import sys
import pytorch_sphinx_theme import pytorch_sphinx_theme
import sys
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
@ -35,18 +38,18 @@ import sys
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. # ones.
extensions = [ extensions = [
'sphinx.ext.autodoc', "sphinx.ext.autodoc",
'sphinx.ext.autosummary', "sphinx.ext.autosummary",
'sphinx.ext.doctest', "sphinx.ext.doctest",
'sphinx.ext.intersphinx', "sphinx.ext.intersphinx",
'sphinx.ext.todo', "sphinx.ext.todo",
'sphinx.ext.coverage', "sphinx.ext.coverage",
'sphinx.ext.napoleon', "sphinx.ext.napoleon",
'sphinx.ext.viewcode', "sphinx.ext.viewcode",
# 'sphinxcontrib.katex', # 'sphinxcontrib.katex',
'sphinx.ext.autosectionlabel', "sphinx.ext.autosectionlabel",
'sphinx_copybutton', "sphinx_copybutton",
'myst_nb', "myst_nb",
] ]
# sys.path.insert(0, os.path.abspath('./notebooks')) # sys.path.insert(0, os.path.abspath('./notebooks'))
@ -75,21 +78,21 @@ napoleon_use_ivar = True
autosummary_generate = True autosummary_generate = True
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ["_templates"]
# The suffix(es) of source filenames. # The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string: # You can specify multiple suffix as a list of string:
# #
# source_suffix = ['.rst', '.md'] # source_suffix = ['.rst', '.md']
source_suffix = '.rst' source_suffix = ".rst"
# The master toctree document. # The master toctree document.
master_doc = 'index' master_doc = "index"
# General information about the project. # General information about the project.
project = 'functorch' project = "functorch"
copyright = 'PyTorch Contributors' copyright = "PyTorch Contributors"
author = 'PyTorch Contributors' author = "PyTorch Contributors"
functorch_version = str(functorch.__version__) functorch_version = str(functorch.__version__)
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
@ -98,16 +101,16 @@ functorch_version = str(functorch.__version__)
# #
# The short X.Y version. # The short X.Y version.
# TODO: change to [:2] at v1.0 # TODO: change to [:2] at v1.0
version = 'nightly (' + functorch_version + ')' version = "nightly (" + functorch_version + ")"
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
# TODO: verify this works as expected # TODO: verify this works as expected
release = 'nightly' release = "nightly"
# Customized html_title here. # Customized html_title here.
# Default is " ".join(project, release, "documentation") if not set # Default is " ".join(project, release, "documentation") if not set
# TODO: I don't know if this flag works, please check before using it # TODO: I don't know if this flag works, please check before using it
if RELEASE: if RELEASE:
raise RuntimeError('NYI') raise RuntimeError("NYI")
# remove hash (start with 'a') from version number if any # remove hash (start with 'a') from version number if any
# version_end = functorch_version.find('a') # version_end = functorch_version.find('a')
# if version_end == -1: # if version_end == -1:
@ -128,10 +131,10 @@ language = "en"
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path # This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['notebooks/colab**', 'notebooks/_src/**'] exclude_patterns = ["notebooks/colab**", "notebooks/_src/**"]
# The name of the Pygments (syntax highlighting) style to use. # The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx' pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing. # If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True todo_include_todos = True
@ -140,7 +143,7 @@ todo_include_todos = True
autodoc_inherit_docstrings = False autodoc_inherit_docstrings = False
# Disable displaying type annotations, these can be very verbose # Disable displaying type annotations, these can be very verbose
autodoc_typehints = 'none' autodoc_typehints = "none"
# Enable overriding of function signatures in the first line of the docstring. # Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True autodoc_docstring_signature = True
@ -159,7 +162,7 @@ autodoc_docstring_signature = True
# #
# #
html_theme = 'pytorch_sphinx_theme' html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
@ -178,10 +181,10 @@ html_theme_options = {
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static'] html_static_path = ["_static"]
html_css_files = [ html_css_files = [
'css/custom.css', "css/custom.css",
] ]
@ -191,19 +194,20 @@ def setup(app):
# and can be moved outside of this function (and the setup(app) function # and can be moved outside of this function (and the setup(app) function
# can be deleted). # can be deleted).
html_css_files = [ html_css_files = [
'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css' "https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css"
] ]
# In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is # In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is
# `add_stylesheet` (deprecated in 1.8). # `add_stylesheet` (deprecated in 1.8).
add_css = getattr(app, 'add_css_file', app.add_stylesheet) add_css = getattr(app, "add_css_file", app.add_stylesheet)
for css_file in html_css_files: for css_file in html_css_files:
add_css(css_file) add_css(css_file)
# -- Options for HTMLHelp output ------------------------------------------ # -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc' htmlhelp_basename = "PyTorchdoc"
# -- Options for LaTeX output --------------------------------------------- # -- Options for LaTeX output ---------------------------------------------
@ -212,15 +216,12 @@ latex_elements = {
# The paper size ('letterpaper' or 'a4paper'). # The paper size ('letterpaper' or 'a4paper').
# #
# 'papersize': 'letterpaper', # 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt'). # The font size ('10pt', '11pt' or '12pt').
# #
# 'pointsize': '10pt', # 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble. # Additional stuff for the LaTeX preamble.
# #
# 'preamble': '', # 'preamble': '',
# Latex figure (float) alignment # Latex figure (float) alignment
# #
# 'figure_align': 'htbp', # 'figure_align': 'htbp',
@ -230,8 +231,13 @@ latex_elements = {
# (source start file, target name, title, # (source start file, target name, title,
# author, documentclass [howto, manual, or own class]). # author, documentclass [howto, manual, or own class]).
latex_documents = [ latex_documents = [
(master_doc, 'pytorch.tex', 'PyTorch Documentation', (
'Torch Contributors', 'manual'), master_doc,
"pytorch.tex",
"PyTorch Documentation",
"Torch Contributors",
"manual",
),
] ]
@ -239,10 +245,7 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [ man_pages = [(master_doc, "functorch", "functorch Documentation", [author], 1)]
(master_doc, 'functorch', 'functorch Documentation',
[author], 1)
]
# -- Options for Texinfo output ------------------------------------------- # -- Options for Texinfo output -------------------------------------------
@ -251,37 +254,44 @@ man_pages = [
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
(master_doc, 'functorch', 'functorch Documentation', (
author, 'functorch', 'One line description of project.', master_doc,
'Miscellaneous'), "functorch",
"functorch Documentation",
author,
"functorch",
"One line description of project.",
"Miscellaneous",
),
] ]
# Example configuration for intersphinx: refer to the Python standard library. # Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = { intersphinx_mapping = {
'python': ('https://docs.python.org/3', None), "python": ("https://docs.python.org/3", None),
'numpy': ('https://numpy.org/doc/stable', None), "numpy": ("https://numpy.org/doc/stable", None),
"torch": ("https://pytorch.org/docs/stable/", None), "torch": ("https://pytorch.org/docs/stable/", None),
} }
import sphinx.ext.doctest
# -- A patch that prevents Sphinx from cross-referencing ivar tags ------- # -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043 # See http://stackoverflow.com/a/41184353/3343043
from docutils import nodes from docutils import nodes
from sphinx.util.docfields import TypedField
from sphinx import addnodes from sphinx import addnodes
import sphinx.ext.doctest from sphinx.util.docfields import TypedField
# Without this, doctest adds any example with a `>>>` as a test # Without this, doctest adds any example with a `>>>` as a test
doctest_test_doctest_blocks = '' doctest_test_doctest_blocks = ""
doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS
doctest_global_setup = ''' doctest_global_setup = """
import torch import torch
try: try:
import torchvision import torchvision
except ImportError: except ImportError:
torchvision = None torchvision = None
''' """
def patched_make_field(self, types, domain, items, **kw): def patched_make_field(self, types, domain, items, **kw):
@ -291,43 +301,51 @@ def patched_make_field(self, types, domain, items, **kw):
# (List, unicode, Tuple) -> nodes.field # (List, unicode, Tuple) -> nodes.field
def handle_item(fieldarg, content): def handle_item(fieldarg, content):
par = nodes.paragraph() par = nodes.paragraph()
par += addnodes.literal_strong('', fieldarg) # Patch: this line added par += addnodes.literal_strong("", fieldarg) # Patch: this line added
# par.extend(self.make_xrefs(self.rolename, domain, fieldarg, # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
# addnodes.literal_strong)) # addnodes.literal_strong))
if fieldarg in types: if fieldarg in types:
par += nodes.Text(' (') par += nodes.Text(" (")
# NOTE: using .pop() here to prevent a single type node to be # NOTE: using .pop() here to prevent a single type node to be
# inserted twice into the doctree, which leads to # inserted twice into the doctree, which leads to
# inconsistencies later when references are resolved # inconsistencies later when references are resolved
fieldtype = types.pop(fieldarg) fieldtype = types.pop(fieldarg)
if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
typename = u''.join(n.astext() for n in fieldtype) typename = "".join(n.astext() for n in fieldtype)
typename = typename.replace('int', 'python:int') typename = typename.replace("int", "python:int")
typename = typename.replace('long', 'python:long') typename = typename.replace("long", "python:long")
typename = typename.replace('float', 'python:float') typename = typename.replace("float", "python:float")
typename = typename.replace('bool', 'python:bool') typename = typename.replace("bool", "python:bool")
typename = typename.replace('type', 'python:type') typename = typename.replace("type", "python:type")
par.extend(self.make_xrefs(self.typerolename, domain, typename, par.extend(
addnodes.literal_emphasis, **kw)) self.make_xrefs(
self.typerolename,
domain,
typename,
addnodes.literal_emphasis,
**kw,
)
)
else: else:
par += fieldtype par += fieldtype
par += nodes.Text(')') par += nodes.Text(")")
par += nodes.Text(' -- ') par += nodes.Text(" -- ")
par += content par += content
return par return par
fieldname = nodes.field_name('', self.label) fieldname = nodes.field_name("", self.label)
if len(items) == 1 and self.can_collapse: if len(items) == 1 and self.can_collapse:
fieldarg, content = items[0] fieldarg, content = items[0]
bodynode = handle_item(fieldarg, content) bodynode = handle_item(fieldarg, content)
else: else:
bodynode = self.list_type() bodynode = self.list_type()
for fieldarg, content in items: for fieldarg, content in items:
bodynode += nodes.list_item('', handle_item(fieldarg, content)) bodynode += nodes.list_item("", handle_item(fieldarg, content))
fieldbody = nodes.field_body('', bodynode) fieldbody = nodes.field_body("", bodynode)
return nodes.field('', fieldname, fieldbody) return nodes.field("", fieldname, fieldbody)
TypedField.make_field = patched_make_field TypedField.make_field = patched_make_field
copybutton_prompt_text = r'>>> |\.\.\. ' copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True copybutton_prompt_is_regexp = True

View File

@ -1,3 +1,3 @@
from .rearrange import rearrange from .rearrange import rearrange
__all__ = ['rearrange'] __all__ = ["rearrange"]

View File

@ -40,7 +40,9 @@ class AnonymousAxis:
def __init__(self, value: str) -> None: def __init__(self, value: str) -> None:
self.value = int(value) self.value = int(value)
if self.value < 1: if self.value < 1:
raise ValueError(f'Anonymous axis should have positive length, not {self.value}') raise ValueError(
f"Anonymous axis should have positive length, not {self.value}"
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.value}-axis" return f"{self.value}-axis"
@ -49,7 +51,13 @@ class AnonymousAxis:
class ParsedExpression: class ParsedExpression:
"""Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)').""" """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""
def __init__(self, expression: str, *, allow_underscore: bool = False, allow_duplicates: bool = False) -> None: def __init__(
self,
expression: str,
*,
allow_underscore: bool = False,
allow_duplicates: bool = False,
) -> None:
"""Parse the expression and store relevant metadata. """Parse the expression and store relevant metadata.
Args: Args:
@ -66,10 +74,13 @@ class ParsedExpression:
self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = [] self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
if "." in expression: if "." in expression:
if "..." not in expression: if "..." not in expression:
raise ValueError("Expression may contain dots only inside ellipsis (...)") raise ValueError(
"Expression may contain dots only inside ellipsis (...)"
)
if str.count(expression, "...") != 1 or str.count(expression, ".") != 3: if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
raise ValueError( raise ValueError(
"Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ") "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
)
expression = expression.replace("...", _ellipsis) expression = expression.replace("...", _ellipsis)
self.has_ellipsis = True self.has_ellipsis = True
@ -78,7 +89,9 @@ class ParsedExpression:
def add_axis_name(x: str) -> None: def add_axis_name(x: str) -> None:
if x in self.identifiers: if x in self.identifiers:
if not (allow_underscore and x == "_") and not allow_duplicates: if not (allow_underscore and x == "_") and not allow_duplicates:
raise ValueError(f"Indexing expression contains duplicate dimension '{x}'") raise ValueError(
f"Indexing expression contains duplicate dimension '{x}'"
)
if x == _ellipsis: if x == _ellipsis:
self.identifiers.add(_ellipsis) self.identifiers.add(_ellipsis)
if bracket_group is None: if bracket_group is None:
@ -96,10 +109,14 @@ class ParsedExpression:
else: else:
pass # no need to think about 1s inside parenthesis pass # no need to think about 1s inside parenthesis
return return
is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore) is_axis_name, reason = self.check_axis_name_return_reason(
x, allow_underscore=allow_underscore
)
if not (is_number or is_axis_name): if not (is_number or is_axis_name):
raise ValueError(f"Invalid axis identifier: {x}\n{reason}") raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
axis_name: Union[str, AnonymousAxis] = AnonymousAxis(x) if is_number else x axis_name: Union[str, AnonymousAxis] = (
AnonymousAxis(x) if is_number else x
)
self.identifiers.add(axis_name) self.identifiers.add(axis_name)
if is_number: if is_number:
self.has_non_unitary_anonymous_axes = True self.has_non_unitary_anonymous_axes = True
@ -116,7 +133,9 @@ class ParsedExpression:
current_identifier = None current_identifier = None
if char == "(": if char == "(":
if bracket_group is not None: if bracket_group is not None:
raise ValueError("Axis composition is one-level (brackets inside brackets not allowed)") raise ValueError(
"Axis composition is one-level (brackets inside brackets not allowed)"
)
bracket_group = [] bracket_group = []
elif char == ")": elif char == ")":
if bracket_group is None: if bracket_group is None:
@ -137,7 +156,9 @@ class ParsedExpression:
add_axis_name(current_identifier) add_axis_name(current_identifier)
@staticmethod @staticmethod
def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]: def check_axis_name_return_reason(
name: str, allow_underscore: bool = False
) -> Tuple[bool, str]:
"""Check if the given axis name is valid, and a message explaining why if not. """Check if the given axis name is valid, and a message explaining why if not.
Valid axes names are python identifiers except keywords, and should not start or end with an underscore. Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
@ -157,10 +178,14 @@ class ParsedExpression:
return False, "axis name should should not start or end with underscore" return False, "axis name should should not start or end with underscore"
else: else:
if keyword.iskeyword(name): if keyword.iskeyword(name):
warnings.warn(f"It is discouraged to use axes names that are keywords: {name}", RuntimeWarning) warnings.warn(
f"It is discouraged to use axes names that are keywords: {name}",
RuntimeWarning,
)
if name in ["axis"]: if name in ["axis"]:
warnings.warn( warnings.warn(
"It is discouraged to use 'axis' as an axis name and will raise an error in future", FutureWarning "It is discouraged to use 'axis' as an axis name and will raise an error in future",
FutureWarning,
) )
return True, "" return True, ""
@ -178,8 +203,9 @@ class ParsedExpression:
return is_valid return is_valid
def parse_pattern(
def parse_pattern(pattern: str, axes_lengths: Mapping[str, int]) -> Tuple[ParsedExpression, ParsedExpression]: pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
"""Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object. """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
Args: Args:
@ -203,9 +229,13 @@ def parse_pattern(pattern: str, axes_lengths: Mapping[str, int]) -> Tuple[Parsed
right = ParsedExpression(right_str) right = ParsedExpression(right_str)
if not left.has_ellipsis and right.has_ellipsis: if not left.has_ellipsis and right.has_ellipsis:
raise ValueError(f"Ellipsis found in right side, but not left side of a pattern {pattern}") raise ValueError(
f"Ellipsis found in right side, but not left side of a pattern {pattern}"
)
if left.has_ellipsis and left.has_ellipsis_parenthesized: if left.has_ellipsis and left.has_ellipsis_parenthesized:
raise ValueError(f"Ellipsis is parenthesis in the left side is not allowed: {pattern}") raise ValueError(
f"Ellipsis is parenthesis in the left side is not allowed: {pattern}"
)
return left, right return left, right
@ -222,18 +252,24 @@ def validate_rearrange_expressions(
""" """
for length in axes_lengths.values(): for length in axes_lengths.values():
if (length_type := type(length)) is not int: if (length_type := type(length)) is not int:
raise TypeError(f"rearrange axis lengths must be integers, got: {length_type}") raise TypeError(
f"rearrange axis lengths must be integers, got: {length_type}"
)
if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes: if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
raise ValueError("rearrange only supports unnamed axes of size 1") raise ValueError("rearrange only supports unnamed axes of size 1")
difference = set.symmetric_difference(left.identifiers, right.identifiers) difference = set.symmetric_difference(left.identifiers, right.identifiers)
if len(difference) > 0: if len(difference) > 0:
raise ValueError(f"Identifiers only on one side of rearrange expression (should be on both): {difference}") raise ValueError(
f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
)
unmatched_axes = axes_lengths.keys() - left.identifiers unmatched_axes = axes_lengths.keys() - left.identifiers
if len(unmatched_axes) > 0: if len(unmatched_axes) > 0:
raise ValueError(f"Identifiers not found in rearrange expression: {unmatched_axes}") raise ValueError(
f"Identifiers not found in rearrange expression: {unmatched_axes}"
)
def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str: def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
@ -259,6 +295,8 @@ def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
'(d0,), (), (d1,), (d2,), (d3, d4)' '(d0,), (), (d1,), (d2,), (d3, d4)'
""" """
return ", ".join( return ", ".join(
item if isinstance(item, str) else f"({comma_separate(item)}{',' if len(item) == 1 else ''})" item
if isinstance(item, str)
else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
for item in collection for item in collection
) )

View File

@ -4,8 +4,15 @@ import functools
from typing import Callable, Dict, List, Sequence, Tuple, Union from typing import Callable, Dict, List, Sequence, Tuple, Union
import torch import torch
from functorch._C import dim as _C from functorch._C import dim as _C
from ._parsing import AnonymousAxis, _ellipsis, comma_separate, parse_pattern, validate_rearrange_expressions from ._parsing import (
_ellipsis,
AnonymousAxis,
comma_separate,
parse_pattern,
validate_rearrange_expressions,
)
__all__ = ["rearrange"] __all__ = ["rearrange"]
@ -79,10 +86,12 @@ def _create_rearrange_callable(
dims_i += 1 dims_i += 1
elif dimension == _ellipsis: elif dimension == _ellipsis:
identifier = _ellipsis identifier = _ellipsis
identifier_dim_map[identifier] = tuple(first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)) identifier_dim_map[identifier] = tuple(
first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
)
dims_i += n_ellipsis_dims dims_i += n_ellipsis_dims
else: else:
raise ValueError(f'Unexpected dimension: {dimension}') raise ValueError(f"Unexpected dimension: {dimension}")
def composition_to_dims( def composition_to_dims(
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]] composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
@ -92,11 +101,17 @@ def _create_rearrange_callable(
dim_composition: List[Union[str, Tuple[str, ...]]] = [] dim_composition: List[Union[str, Tuple[str, ...]]] = []
for dimension in composition: for dimension in composition:
if isinstance(dimension, list): if isinstance(dimension, list):
dim_composition.append(tuple(dim for identifier in dimension for dim in identifier_dim_map[identifier])) dim_composition.append(
tuple(
dim
for identifier in dimension
for dim in identifier_dim_map[identifier]
)
)
elif dimension == _ellipsis: elif dimension == _ellipsis:
dim_composition.extend(identifier_dim_map[_ellipsis]) dim_composition.extend(identifier_dim_map[_ellipsis])
else: else:
raise ValueError(f'Unexpected dimension: {dimension}') raise ValueError(f"Unexpected dimension: {dimension}")
return dim_composition return dim_composition
left_dims = composition_to_dims(left.composition) left_dims = composition_to_dims(left.composition)
@ -108,16 +123,22 @@ def _create_rearrange_callable(
custom_rearrange_callable_name = "do_rearrange" custom_rearrange_callable_name = "do_rearrange"
custom_rearrange_callable_code = ( custom_rearrange_callable_code = (
(
f"def {custom_rearrange_callable_name}(tensor):\n" f"def {custom_rearrange_callable_name}(tensor):\n"
f" {comma_separate(first_class_dims)} = dims({n_dims})\n" f" {comma_separate(first_class_dims)} = dims({n_dims})\n"
)
+ ( + (
"".join(f" {dim}.size = {length}\n" for (dim, length) in specified_lengths) "".join(
if specified_lengths else "" f" {dim}.size = {length}\n" for (dim, length) in specified_lengths
)
if specified_lengths
else ""
) )
+ f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n" + f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
+ ( + (
f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n" f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
if anon_dims else " return tensor\n" if anon_dims
else " return tensor\n"
) )
) )
@ -126,7 +147,9 @@ def _create_rearrange_callable(
def rearrange( def rearrange(
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], pattern: str, **axes_lengths: int tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
pattern: str,
**axes_lengths: int,
) -> torch.Tensor: ) -> torch.Tensor:
r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
@ -177,6 +200,8 @@ def rearrange(
if not isinstance(tensor, torch.Tensor): if not isinstance(tensor, torch.Tensor):
tensor = torch.stack(tensor) tensor = torch.stack(tensor)
rearrange_callable = _create_rearrange_callable(tensor.ndim, pattern, **axes_lengths) rearrange_callable = _create_rearrange_callable(
tensor.ndim, pattern, **axes_lengths
)
return rearrange_callable(tensor) return rearrange_callable(tensor)

View File

@ -1,7 +1,8 @@
from functorch.compile import aot_function, tvm_compile
import torch
import time import time
import torch
import torch.utils import torch.utils
from functorch.compile import aot_function, tvm_compile
a = torch.randn(2000, 1, 4, requires_grad=True) a = torch.randn(2000, 1, 4, requires_grad=True)
b = torch.randn(1, 2000, 4) b = torch.randn(1, 2000, 4)
@ -11,8 +12,8 @@ def f(a):
return (a * b).sum(dim=0) return (a * b).sum(dim=0)
fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops') fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops') bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
compiled_f = aot_function(f, fw_compiler, bw_compiler) compiled_f = aot_function(f, fw_compiler, bw_compiler)
# fw_compiler = lambda x, _: x # fw_compiler = lambda x, _: x
@ -32,13 +33,15 @@ def bench(func):
def bench_jax(): def bench_jax():
import jax.numpy as jnp
import jax import jax
import jax.numpy as jnp
jax_a = jnp.array(a.detach().numpy()) jax_a = jnp.array(a.detach().numpy())
jax_b = jnp.array(b.detach().numpy()) jax_b = jnp.array(b.detach().numpy())
def f(a): def f(a):
return jnp.sin((a * jax_b).sum(axis=[0])).sum() return jnp.sin((a * jax_b).sum(axis=[0])).sum()
jit_f = jax.jit(jax.grad(f)) jit_f = jax.jit(jax.grad(f))
jit_f(jax_a) jit_f(jax_a)
begin = time.time() begin = time.time()

View File

@ -1,15 +1,16 @@
import timeit import timeit
from functorch.compile import compiled_module, tvm_compile
import torch.nn as nn
import torch import torch
import torch.nn as nn
from functorch.compile import compiled_module, tvm_compile
def nop(f, _): def nop(f, _):
return f return f
fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops') fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops') bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
fw_compiler = nop fw_compiler = nop
bw_compiler = nop bw_compiler = nop

View File

@ -4,11 +4,13 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from functorch import make_functional import time
from functorch.compile import nnc_jit
import torch import torch
import torch.nn as nn import torch.nn as nn
import time from functorch import make_functional
from functorch.compile import nnc_jit
torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_override_can_fuse_on_cpu(True)
@ -54,7 +56,9 @@ def functional_step(x, weights):
return out, new_weights return out, new_weights
optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0) optim = torch.optim.SGD(
jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0
)
def jit_step(x, weights): def jit_step(x, weights):

View File

@ -4,10 +4,11 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import time
import torch
from functorch import grad, make_fx from functorch import grad, make_fx
from functorch.compile import nnc_jit from functorch.compile import nnc_jit
import torch
import time
def f(x): def f(x):

View File

@ -17,8 +17,8 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torch.utils.data import torch.utils.data
import torchvision.transforms as transforms import torchvision.transforms as transforms
from torchvision import models
from opacus import PrivacyEngine from opacus import PrivacyEngine
from torchvision import models
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from tqdm import tqdm from tqdm import tqdm
@ -52,7 +52,6 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
top1_acc = [] top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)): for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device) images = images.to(device)
target = target.to(device) target = target.to(device)
@ -279,6 +278,7 @@ def main():
) )
logger.info(metrics) logger.info(metrics)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training") parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument( parser.add_argument(
@ -309,7 +309,7 @@ def parse_args():
default=256, default=256,
type=int, type=int,
metavar="N", metavar="N",
help="mini-batch size for test dataset (default: 256)" help="mini-batch size for test dataset (default: 256)",
) )
parser.add_argument( parser.add_argument(
"--sample-rate", "--sample-rate",

View File

@ -17,12 +17,12 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torch.utils.data import torch.utils.data
import torchvision.transforms as transforms import torchvision.transforms as transforms
from torch.func import functional_call, grad_and_value, vmap
from torchvision import models from torchvision import models
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from tqdm import tqdm from tqdm import tqdm
from torch.func import vmap, grad_and_value, functional_call
logging.basicConfig( logging.basicConfig(
format="%(asctime)s:%(levelname)s:%(message)s", format="%(asctime)s:%(levelname)s:%(message)s",
datefmt="%m/%d/%Y %H:%M:%S", datefmt="%m/%d/%Y %H:%M:%S",
@ -44,12 +44,16 @@ def accuracy(preds, labels):
def compute_norms(sample_grads): def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0] batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads] norms = [
sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads
]
norms = torch.stack(norms, dim=0).norm(2, dim=0) norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms, batch_size return norms, batch_size
def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0): def clip_and_accumulate_and_add_noise(
model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0
):
sample_grads = tuple(param.grad_sample for param in model.parameters()) sample_grads = tuple(param.grad_sample for param in model.parameters())
# step 0: compute the norms # step 0: compute the norms
@ -60,13 +64,16 @@ def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise
clip_factor = clip_factor.clamp(max=1.0) clip_factor = clip_factor.clamp(max=1.0)
# step 2: clip # step 2: clip
grads = tuple(torch.einsum('i,i...', clip_factor, sample_grad) grads = tuple(
for sample_grad in sample_grads) torch.einsum("i,i...", clip_factor, sample_grad) for sample_grad in sample_grads
)
# step 3: add gaussian noise # step 3: add gaussian noise
stddev = max_per_sample_grad_norm * noise_multiplier stddev = max_per_sample_grad_norm * noise_multiplier
noises = tuple(torch.normal(0, stddev, grad_param.shape, device=grad_param.device) noises = tuple(
for grad_param in grads) torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
for grad_param in grads
)
grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads)) grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))
# step 4: assign the new grads, delete the sample grads # step 4: assign the new grads, delete the sample grads
@ -84,7 +91,6 @@ def train(args, model, train_loader, optimizer, epoch, device):
top1_acc = [] top1_acc = []
for i, (images, target) in enumerate(tqdm(train_loader)): for i, (images, target) in enumerate(tqdm(train_loader)):
images = images.to(device) images = images.to(device)
target = target.to(device) target = target.to(device)
@ -120,8 +126,9 @@ def train(args, model, train_loader, optimizer, epoch, device):
# detaching weights since we don't need to track gradients outside of transforms # detaching weights since we don't need to track gradients outside of transforms
# and this is more performant # and this is more performant
detached_weights = {k: v.detach() for k, v in weights.items()} detached_weights = {k: v.detach() for k, v in weights.items()}
sample_grads, (sample_loss, output) = \ sample_grads, (sample_loss, output) = vmap(grads_loss_output, (None, 0, 0))(
vmap(grads_loss_output, (None, 0, 0))(detached_weights, images, target) detached_weights, images, target
)
loss = sample_loss.mean() loss = sample_loss.mean()
for name, grad_sample in sample_grads.items(): for name, grad_sample in sample_grads.items():
@ -129,7 +136,8 @@ def train(args, model, train_loader, optimizer, epoch, device):
# Step 2: Clip the per-sample-grads, sum them to form grads, and add noise # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
clip_and_accumulate_and_add_noise( clip_and_accumulate_and_add_noise(
model, args.max_per_sample_grad_norm, args.sigma) model, args.max_per_sample_grad_norm, args.sigma
)
preds = np.argmax(output.detach().cpu().numpy(), axis=1) preds = np.argmax(output.detach().cpu().numpy(), axis=1)
labels = target.detach().cpu().numpy() labels = target.detach().cpu().numpy()
@ -270,9 +278,7 @@ def main():
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group["lr"] = lr param_group["lr"] = lr
train_duration = train( train_duration = train(args, model, train_loader, optimizer, epoch, device)
args, model, train_loader, optimizer, epoch, device
)
top1_acc = test(args, model, test_loader, device) top1_acc = test(args, model, test_loader, device)
# remember best acc@1 and save checkpoint # remember best acc@1 and save checkpoint
@ -308,6 +314,7 @@ def main():
) )
logger.info(metrics) logger.info(metrics)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training") parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
parser.add_argument( parser.add_argument(
@ -338,7 +345,7 @@ def parse_args():
default=256, default=256,
type=int, type=int,
metavar="N", metavar="N",
help="mini-batch size for test dataset (default: 256)" help="mini-batch size for test dataset (default: 256)",
) )
parser.add_argument( parser.add_argument(
"--sample-rate", "--sample-rate",

View File

@ -1,9 +1,10 @@
import argparse import argparse
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.func import functional_call, grad_and_value, vmap, stack_module_state from torch.func import functional_call, grad_and_value, stack_module_state, vmap
# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a # Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
# tutorial on Model Ensembling with JAX by Will Whitney. # tutorial on Model Ensembling with JAX by Will Whitney.
@ -33,15 +34,21 @@ DEVICE = args.device
# Step 1: Make some spirals # Step 1: Make some spirals
def make_spirals(n_samples, noise_std=0., rotations=1.): def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
ts = torch.linspace(0, 1, n_samples, device=DEVICE) ts = torch.linspace(0, 1, n_samples, device=DEVICE)
rs = ts**0.5 rs = ts**0.5
thetas = rs * rotations * 2 * math.pi thetas = rs * rotations * 2 * math.pi
signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1 signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
labels = (signs > 0).to(torch.long).to(DEVICE) labels = (signs > 0).to(torch.long).to(DEVICE)
xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std xs = (
ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std rs * signs * torch.cos(thetas)
+ torch.randn(n_samples, device=DEVICE) * noise_std
)
ys = (
rs * signs * torch.sin(thetas)
+ torch.randn(n_samples, device=DEVICE) * noise_std
)
points = torch.stack([xs, ys], dim=1) points = torch.stack([xs, ys], dim=1)
return points, labels return points, labels
@ -70,6 +77,7 @@ class MLPClassifier(nn.Module):
loss_fn = nn.NLLLoss() loss_fn = nn.NLLLoss()
model = MLPClassifier().to(DEVICE) model = MLPClassifier().to(DEVICE)
def train_step_fn(weights, batch, targets, lr=0.2): def train_step_fn(weights, batch, targets, lr=0.2):
def compute_loss(weights, batch, targets): def compute_loss(weights, batch, targets):
output = functional_call(model, weights, batch) output = functional_call(model, weights, batch)
@ -109,6 +117,7 @@ def init_fn(num_models):
params, _ = stack_module_state(models) params, _ = stack_module_state(models)
return params return params
# Step 6: Now, can we try multiple models at the same time? # Step 6: Now, can we try multiple models at the same time?
# The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps # The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps
# on decreasing # on decreasing

View File

@ -4,11 +4,11 @@
import torch import torch
from torch import nn from torch import nn
from torch.nn.functional import mse_loss
from torch.func import jacrev, vmap from torch.func import jacrev, vmap
from torch.nn.functional import mse_loss
sigma = 0.5 sigma = 0.5
epsilon = 4. epsilon = 4.0
def lennard_jones(r): def lennard_jones(r):
@ -29,7 +29,9 @@ norms = torch.norm(drs, dim=1).reshape(-1, 1)
# Create training energies # Create training energies
training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1) training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1)
# Create forces with random direction vectors # Create forces with random direction vectors
training_forces = torch.stack([force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)]) training_forces = torch.stack(
[force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)]
)
model = nn.Sequential( model = nn.Sequential(
nn.Linear(1, 16), nn.Linear(1, 16),
@ -40,7 +42,7 @@ model = nn.Sequential(
nn.Tanh(), nn.Tanh(),
nn.Linear(16, 16), nn.Linear(16, 16),
nn.Tanh(), nn.Tanh(),
nn.Linear(16, 1) nn.Linear(16, 1),
) )
@ -54,7 +56,10 @@ def make_prediction(model, drs):
def loss_fn(energies, forces, predicted_energies, predicted_forces): def loss_fn(energies, forces, predicted_energies, predicted_forces):
return mse_loss(energies, predicted_energies) + 0.01 * mse_loss(forces, predicted_forces) / 3 return (
mse_loss(energies, predicted_energies)
+ 0.01 * mse_loss(forces, predicted_forces) / 3
)
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3) optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

View File

@ -27,38 +27,43 @@ Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch https://github.com/bamos/HowToTrainYourMAMLPytorch
""" """
from support.omniglot_loaders import OmniglotNShot
import higher
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
import matplotlib.pyplot as plt
import argparse import argparse
import time import time
import pandas as pd import higher
import numpy as np
import matplotlib as mpl import matplotlib as mpl
mpl.use('Agg') import matplotlib.pyplot as plt
plt.style.use('bmh') import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from support.omniglot_loaders import OmniglotNShot
from torch import nn
mpl.use("Agg")
plt.style.use("bmh")
def main(): def main():
argparser = argparse.ArgumentParser() argparser = argparse.ArgumentParser()
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5) argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
argparser.add_argument( argparser.add_argument(
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5) "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
)
argparser.add_argument( argparser.add_argument(
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15) "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
)
argparser.add_argument("--device", type=str, help="device", default="cuda")
argparser.add_argument( argparser.add_argument(
'--device', type=str, help='device', default='cuda') "--task-num",
argparser.add_argument( "--task_num",
'--task-num', '--task_num',
type=int, type=int,
help='meta batch size, namely task num', help="meta batch size, namely task num",
default=32) default=32,
argparser.add_argument('--seed', type=int, help='random seed', default=1) )
argparser.add_argument("--seed", type=int, help="random seed", default=1)
args = argparser.parse_args() args = argparser.parse_args()
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
@ -69,7 +74,7 @@ def main():
# Set up the Omniglot loader. # Set up the Omniglot loader.
device = args.device device = args.device
db = OmniglotNShot( db = OmniglotNShot(
'/tmp/omniglot-data', "/tmp/omniglot-data",
batchsz=args.task_num, batchsz=args.task_num,
n_way=args.n_way, n_way=args.n_way,
k_shot=args.k_spt, k_shot=args.k_spt,
@ -97,7 +102,8 @@ def main():
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), nn.MaxPool2d(2, 2),
Flatten(), Flatten(),
nn.Linear(64, args.n_way)).to(device) nn.Linear(64, args.n_way),
).to(device)
# We will use Adam to (meta-)optimize the initial parameters # We will use Adam to (meta-)optimize the initial parameters
# to be adapted. # to be adapted.
@ -134,9 +140,10 @@ def train(db, net, device, meta_opt, epoch, log):
qry_accs = [] qry_accs = []
meta_opt.zero_grad() meta_opt.zero_grad()
for i in range(task_num): for i in range(task_num):
with higher.innerloop_ctx( with higher.innerloop_ctx(net, inner_opt, copy_initial_weights=False) as (
net, inner_opt, copy_initial_weights=False fnet,
) as (fnet, diffopt): diffopt,
):
# Optimize the likelihood of the support set by taking # Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters. # gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task. # This adapts the model's meta-parameters to the task.
@ -153,8 +160,7 @@ def train(db, net, device, meta_opt, epoch, log):
qry_logits = fnet(x_qry[i]) qry_logits = fnet(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i]) qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_losses.append(qry_loss.detach()) qry_losses.append(qry_loss.detach())
qry_acc = (qry_logits.argmax( qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc) qry_accs.append(qry_acc)
# print([b.shape for b in fnet[1].buffers()]) # print([b.shape for b in fnet[1].buffers()])
@ -166,21 +172,23 @@ def train(db, net, device, meta_opt, epoch, log):
meta_opt.step() meta_opt.step()
qry_losses = sum(qry_losses) / task_num qry_losses = sum(qry_losses) / task_num
qry_accs = 100. * sum(qry_accs) / task_num qry_accs = 100.0 * sum(qry_accs) / task_num
i = epoch + float(batch_idx) / n_train_iter i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time iter_time = time.time() - start_time
if batch_idx % 4 == 0: if batch_idx % 4 == 0:
print( print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
) )
log.append({ log.append(
'epoch': i, {
'loss': qry_losses, "epoch": i,
'acc': qry_accs, "loss": qry_losses,
'mode': 'train', "acc": qry_accs,
'time': time.time(), "mode": "train",
}) "time": time.time(),
}
)
def test(db, net, device, epoch, log): def test(db, net, device, epoch, log):
@ -196,7 +204,7 @@ def test(db, net, device, epoch, log):
qry_accs = [] qry_accs = []
for _ in range(n_test_iter): for _ in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test') x_spt, y_spt, x_qry, y_qry = db.next("test")
task_num, setsz, c_, h, w = x_spt.size() task_num, setsz, c_, h, w = x_spt.size()
@ -206,7 +214,10 @@ def test(db, net, device, epoch, log):
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1) inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
for i in range(task_num): for i in range(task_num):
with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt): with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (
fnet,
diffopt,
):
# Optimize the likelihood of the support set by taking # Optimize the likelihood of the support set by taking
# gradient steps w.r.t. the model's parameters. # gradient steps w.r.t. the model's parameters.
# This adapts the model's meta-parameters to the task. # This adapts the model's meta-parameters to the task.
@ -217,24 +228,22 @@ def test(db, net, device, epoch, log):
# The query loss and acc induced by these parameters. # The query loss and acc induced by these parameters.
qry_logits = fnet(x_qry[i]).detach() qry_logits = fnet(x_qry[i]).detach()
qry_loss = F.cross_entropy( qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
qry_logits, y_qry[i], reduction='none')
qry_losses.append(qry_loss.detach()) qry_losses.append(qry_loss.detach())
qry_accs.append( qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item() qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item() qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
print( print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' log.append(
{
"epoch": epoch + 1,
"loss": qry_losses,
"acc": qry_accs,
"mode": "test",
"time": time.time(),
}
) )
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log): def plot(log):
@ -243,17 +252,17 @@ def plot(log):
df = pd.DataFrame(log) df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4)) fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train'] train_df = df[df["mode"] == "train"]
test_df = df[df['mode'] == 'test'] test_df = df[df["mode"] == "test"]
ax.plot(train_df['epoch'], train_df['acc'], label='Train') ax.plot(train_df["epoch"], train_df["acc"], label="Train")
ax.plot(test_df['epoch'], test_df['acc'], label='Test') ax.plot(test_df["epoch"], test_df["acc"], label="Test")
ax.set_xlabel('Epoch') ax.set_xlabel("Epoch")
ax.set_ylabel('Accuracy') ax.set_ylabel("Accuracy")
ax.set_ylim(70, 100) ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right') fig.legend(ncol=2, loc="lower right")
fig.tight_layout() fig.tight_layout()
fname = 'maml-accs.png' fname = "maml-accs.png"
print(f'--- Plotting accuracy to {fname}') print(f"--- Plotting accuracy to {fname}")
fig.savefig(fname) fig.savefig(fname)
plt.close(fig) plt.close(fig)
@ -265,5 +274,5 @@ class Flatten(nn.Module):
return input.view(input.size(0), -1) return input.view(input.size(0), -1)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -27,38 +27,43 @@ Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch https://github.com/bamos/HowToTrainYourMAMLPytorch
""" """
from support.omniglot_loaders import OmniglotNShot
from functorch import make_functional_with_buffers
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
import matplotlib.pyplot as plt
import argparse import argparse
import time import time
import pandas as pd
import numpy as np
import matplotlib as mpl import matplotlib as mpl
mpl.use('Agg') import matplotlib.pyplot as plt
plt.style.use('bmh') import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from functorch import make_functional_with_buffers
from support.omniglot_loaders import OmniglotNShot
from torch import nn
mpl.use("Agg")
plt.style.use("bmh")
def main(): def main():
argparser = argparse.ArgumentParser() argparser = argparse.ArgumentParser()
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5) argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
argparser.add_argument( argparser.add_argument(
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5) "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
)
argparser.add_argument( argparser.add_argument(
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15) "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
)
argparser.add_argument("--device", type=str, help="device", default="cuda")
argparser.add_argument( argparser.add_argument(
'--device', type=str, help='device', default='cuda') "--task-num",
argparser.add_argument( "--task_num",
'--task-num', '--task_num',
type=int, type=int,
help='meta batch size, namely task num', help="meta batch size, namely task num",
default=32) default=32,
argparser.add_argument('--seed', type=int, help='random seed', default=1) )
argparser.add_argument("--seed", type=int, help="random seed", default=1)
args = argparser.parse_args() args = argparser.parse_args()
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
@ -69,7 +74,7 @@ def main():
# Set up the Omniglot loader. # Set up the Omniglot loader.
device = args.device device = args.device
db = OmniglotNShot( db = OmniglotNShot(
'/tmp/omniglot-data', "/tmp/omniglot-data",
batchsz=args.task_num, batchsz=args.task_num,
n_way=args.n_way, n_way=args.n_way,
k_shot=args.k_spt, k_shot=args.k_spt,
@ -97,7 +102,8 @@ def main():
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), nn.MaxPool2d(2, 2),
Flatten(), Flatten(),
nn.Linear(64, args.n_way)).to(device) nn.Linear(64, args.n_way),
).to(device)
net.train() net.train()
fnet, params, buffers = make_functional_with_buffers(net) fnet, params, buffers = make_functional_with_buffers(net)
@ -153,8 +159,7 @@ def train(db, net, device, meta_opt, epoch, log):
qry_logits = fnet(new_params, buffers, x_qry[i]) qry_logits = fnet(new_params, buffers, x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i]) qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_losses.append(qry_loss.detach()) qry_losses.append(qry_loss.detach())
qry_acc = (qry_logits.argmax( qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
dim=1) == y_qry[i]).sum().item() / querysz
qry_accs.append(qry_acc) qry_accs.append(qry_acc)
# Update the model's meta-parameters to optimize the query # Update the model's meta-parameters to optimize the query
@ -164,21 +169,23 @@ def train(db, net, device, meta_opt, epoch, log):
meta_opt.step() meta_opt.step()
qry_losses = sum(qry_losses) / task_num qry_losses = sum(qry_losses) / task_num
qry_accs = 100. * sum(qry_accs) / task_num qry_accs = 100.0 * sum(qry_accs) / task_num
i = epoch + float(batch_idx) / n_train_iter i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time iter_time = time.time() - start_time
if batch_idx % 4 == 0: if batch_idx % 4 == 0:
print( print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
) )
log.append({ log.append(
'epoch': i, {
'loss': qry_losses, "epoch": i,
'acc': qry_accs, "loss": qry_losses,
'mode': 'train', "acc": qry_accs,
'time': time.time(), "mode": "train",
}) "time": time.time(),
}
)
def test(db, net, device, epoch, log): def test(db, net, device, epoch, log):
@ -194,7 +201,7 @@ def test(db, net, device, epoch, log):
qry_accs = [] qry_accs = []
for batch_idx in range(n_test_iter): for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test') x_spt, y_spt, x_qry, y_qry = db.next("test")
task_num, setsz, c_, h, w = x_spt.size() task_num, setsz, c_, h, w = x_spt.size()
# TODO: Maybe pull this out into a separate module so it # TODO: Maybe pull this out into a separate module so it
@ -211,24 +218,22 @@ def test(db, net, device, epoch, log):
# The query loss and acc induced by these parameters. # The query loss and acc induced by these parameters.
qry_logits = fnet(new_params, buffers, x_qry[i]).detach() qry_logits = fnet(new_params, buffers, x_qry[i]).detach()
qry_loss = F.cross_entropy( qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
qry_logits, y_qry[i], reduction='none')
qry_losses.append(qry_loss.detach()) qry_losses.append(qry_loss.detach())
qry_accs.append( qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item() qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item() qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
print( print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' log.append(
{
"epoch": epoch + 1,
"loss": qry_losses,
"acc": qry_accs,
"mode": "test",
"time": time.time(),
}
) )
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log): def plot(log):
@ -237,17 +242,17 @@ def plot(log):
df = pd.DataFrame(log) df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4)) fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train'] train_df = df[df["mode"] == "train"]
test_df = df[df['mode'] == 'test'] test_df = df[df["mode"] == "test"]
ax.plot(train_df['epoch'], train_df['acc'], label='Train') ax.plot(train_df["epoch"], train_df["acc"], label="Train")
ax.plot(test_df['epoch'], test_df['acc'], label='Test') ax.plot(test_df["epoch"], test_df["acc"], label="Test")
ax.set_xlabel('Epoch') ax.set_xlabel("Epoch")
ax.set_ylabel('Accuracy') ax.set_ylabel("Accuracy")
ax.set_ylim(70, 100) ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right') fig.legend(ncol=2, loc="lower right")
fig.tight_layout() fig.tight_layout()
fname = 'maml-accs.png' fname = "maml-accs.png"
print(f'--- Plotting accuracy to {fname}') print(f"--- Plotting accuracy to {fname}")
fig.savefig(fname) fig.savefig(fname)
plt.close(fig) plt.close(fig)
@ -259,5 +264,5 @@ class Flatten(nn.Module):
return input.view(input.size(0), -1) return input.view(input.size(0), -1)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -27,39 +27,44 @@ Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch https://github.com/bamos/HowToTrainYourMAMLPytorch
""" """
from support.omniglot_loaders import OmniglotNShot
from torch.func import vmap, grad, functional_call
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
import torch
import matplotlib.pyplot as plt
import argparse import argparse
import time
import functools import functools
import time
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd import pandas as pd
import numpy as np import torch
import matplotlib as mpl import torch.nn.functional as F
mpl.use('Agg') import torch.optim as optim
plt.style.use('bmh') from support.omniglot_loaders import OmniglotNShot
from torch import nn
from torch.func import functional_call, grad, vmap
mpl.use("Agg")
plt.style.use("bmh")
def main(): def main():
argparser = argparse.ArgumentParser() argparser = argparse.ArgumentParser()
argparser.add_argument('--n-way', '--n_way', type=int, help='n way', default=5) argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
argparser.add_argument( argparser.add_argument(
'--k-spt', '--k_spt', type=int, help='k shot for support set', default=5) "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
)
argparser.add_argument( argparser.add_argument(
'--k-qry', '--k_qry', type=int, help='k shot for query set', default=15) "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
)
argparser.add_argument("--device", type=str, help="device", default="cuda")
argparser.add_argument( argparser.add_argument(
'--device', type=str, help='device', default='cuda') "--task-num",
argparser.add_argument( "--task_num",
'--task-num', '--task_num',
type=int, type=int,
help='meta batch size, namely task num', help="meta batch size, namely task num",
default=32) default=32,
argparser.add_argument('--seed', type=int, help='random seed', default=1) )
argparser.add_argument("--seed", type=int, help="random seed", default=1)
args = argparser.parse_args() args = argparser.parse_args()
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
@ -70,7 +75,7 @@ def main():
# Set up the Omniglot loader. # Set up the Omniglot loader.
device = args.device device = args.device
db = OmniglotNShot( db = OmniglotNShot(
'/tmp/omniglot-data', "/tmp/omniglot-data",
batchsz=args.task_num, batchsz=args.task_num,
n_way=args.n_way, n_way=args.n_way,
k_shot=args.k_spt, k_shot=args.k_spt,
@ -95,7 +100,8 @@ def main():
nn.ReLU(inplace=inplace_relu), nn.ReLU(inplace=inplace_relu),
nn.MaxPool2d(2, 2), nn.MaxPool2d(2, 2),
nn.Flatten(), nn.Flatten(),
nn.Linear(64, args.n_way)).to(device) nn.Linear(64, args.n_way),
).to(device)
net.train() net.train()
@ -132,8 +138,7 @@ def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
# These will be used to update the model's meta-parameters. # These will be used to update the model's meta-parameters.
qry_logits = functional_call(net, (new_params, buffers), x_qry) qry_logits = functional_call(net, (new_params, buffers), x_qry)
qry_loss = F.cross_entropy(qry_logits, y_qry) qry_loss = F.cross_entropy(qry_logits, y_qry)
qry_acc = (qry_logits.argmax( qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz
dim=1) == y_qry).sum() / querysz
return qry_loss, qry_acc return qry_loss, qry_acc
@ -163,21 +168,23 @@ def train(db, net, device, meta_opt, epoch, log):
meta_opt.step() meta_opt.step()
qry_losses = qry_losses.detach().sum() / task_num qry_losses = qry_losses.detach().sum() / task_num
qry_accs = 100. * qry_accs.sum() / task_num qry_accs = 100.0 * qry_accs.sum() / task_num
i = epoch + float(batch_idx) / n_train_iter i = epoch + float(batch_idx) / n_train_iter
iter_time = time.time() - start_time iter_time = time.time() - start_time
if batch_idx % 4 == 0: if batch_idx % 4 == 0:
print( print(
f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}' f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
) )
log.append({ log.append(
'epoch': i, {
'loss': qry_losses, "epoch": i,
'acc': qry_accs, "loss": qry_losses,
'mode': 'train', "acc": qry_accs,
'time': time.time(), "mode": "train",
}) "time": time.time(),
}
)
def test(db, net, device, epoch, log): def test(db, net, device, epoch, log):
@ -194,7 +201,7 @@ def test(db, net, device, epoch, log):
qry_accs = [] qry_accs = []
for batch_idx in range(n_test_iter): for batch_idx in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test') x_spt, y_spt, x_qry, y_qry = db.next("test")
task_num, setsz, c_, h, w = x_spt.size() task_num, setsz, c_, h, w = x_spt.size()
# TODO: Maybe pull this out into a separate module so it # TODO: Maybe pull this out into a separate module so it
@ -207,28 +214,28 @@ def test(db, net, device, epoch, log):
spt_logits = functional_call(net, (new_params, buffers), x_spt[i]) spt_logits = functional_call(net, (new_params, buffers), x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i]) spt_loss = F.cross_entropy(spt_logits, y_spt[i])
grads = torch.autograd.grad(spt_loss, new_params.values()) grads = torch.autograd.grad(spt_loss, new_params.values())
new_params = {k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)} new_params = {
k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads)
}
# The query loss and acc induced by these parameters. # The query loss and acc induced by these parameters.
qry_logits = functional_call(net, (new_params, buffers), x_qry[i]).detach() qry_logits = functional_call(net, (new_params, buffers), x_qry[i]).detach()
qry_loss = F.cross_entropy( qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
qry_logits, y_qry[i], reduction='none')
qry_losses.append(qry_loss.detach()) qry_losses.append(qry_loss.detach())
qry_accs.append( qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
(qry_logits.argmax(dim=1) == y_qry[i]).detach())
qry_losses = torch.cat(qry_losses).mean().item() qry_losses = torch.cat(qry_losses).mean().item()
qry_accs = 100. * torch.cat(qry_accs).float().mean().item() qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
print( print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' log.append(
{
"epoch": epoch + 1,
"loss": qry_losses,
"acc": qry_accs,
"mode": "test",
"time": time.time(),
}
) )
log.append({
'epoch': epoch + 1,
'loss': qry_losses,
'acc': qry_accs,
'mode': 'test',
'time': time.time(),
})
def plot(log): def plot(log):
@ -237,20 +244,20 @@ def plot(log):
df = pd.DataFrame(log) df = pd.DataFrame(log)
fig, ax = plt.subplots(figsize=(6, 4)) fig, ax = plt.subplots(figsize=(6, 4))
train_df = df[df['mode'] == 'train'] train_df = df[df["mode"] == "train"]
test_df = df[df['mode'] == 'test'] test_df = df[df["mode"] == "test"]
ax.plot(train_df['epoch'], train_df['acc'], label='Train') ax.plot(train_df["epoch"], train_df["acc"], label="Train")
ax.plot(test_df['epoch'], test_df['acc'], label='Test') ax.plot(test_df["epoch"], test_df["acc"], label="Test")
ax.set_xlabel('Epoch') ax.set_xlabel("Epoch")
ax.set_ylabel('Accuracy') ax.set_ylabel("Accuracy")
ax.set_ylim(70, 100) ax.set_ylim(70, 100)
fig.legend(ncol=2, loc='lower right') fig.legend(ncol=2, loc="lower right")
fig.tight_layout() fig.tight_layout()
fname = 'maml-accs.png' fname = "maml-accs.png"
print(f'--- Plotting accuracy to {fname}') print(f"--- Plotting accuracy to {fname}")
fig.savefig(fname) fig.savefig(fname)
plt.close(fig) plt.close(fig)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -17,38 +17,38 @@
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py # https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py
# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py # https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py
import torchvision.transforms as transforms import errno
from PIL import Image import os
import os.path
import numpy as np import numpy as np
import torch import torch
import torch.utils.data as data import torch.utils.data as data
import os import torchvision.transforms as transforms
import os.path from PIL import Image
import errno
class Omniglot(data.Dataset): class Omniglot(data.Dataset):
urls = [ urls = [
'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip",
'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip' "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip",
] ]
raw_folder = 'raw' raw_folder = "raw"
processed_folder = 'processed' processed_folder = "processed"
training_file = 'training.pt' training_file = "training.pt"
test_file = 'test.pt' test_file = "test.pt"
''' """
The items are (filename,category). The index of all the categories can be found in self.idx_classes The items are (filename,category). The index of all the categories can be found in self.idx_classes
Args: Args:
- root: the directory where the dataset will be stored - root: the directory where the dataset will be stored
- transform: how to transform the input - transform: how to transform the input
- target_transform: how to transform the target - target_transform: how to transform the target
- download: need to download the dataset - download: need to download the dataset
''' """
def __init__(self, root, transform=None, target_transform=None, def __init__(self, root, transform=None, target_transform=None, download=False):
download=False):
self.root = root self.root = root
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
@ -57,14 +57,16 @@ class Omniglot(data.Dataset):
if download: if download:
self.download() self.download()
else: else:
raise RuntimeError('Dataset not found.' + ' You can use download=True to download it') raise RuntimeError(
"Dataset not found." + " You can use download=True to download it"
)
self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
self.idx_classes = index_classes(self.all_items) self.idx_classes = index_classes(self.all_items)
def __getitem__(self, index): def __getitem__(self, index):
filename = self.all_items[index][0] filename = self.all_items[index][0]
img = str.join('/', [self.all_items[index][2], filename]) img = str.join("/", [self.all_items[index][2], filename])
target = self.idx_classes[self.all_items[index][1]] target = self.idx_classes[self.all_items[index][1]]
if self.transform is not None: if self.transform is not None:
@ -78,8 +80,11 @@ class Omniglot(data.Dataset):
return len(self.all_items) return len(self.all_items)
def _check_exists(self): def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \ return os.path.exists(
os.path.exists(os.path.join(self.root, self.processed_folder, "images_background")) os.path.join(self.root, self.processed_folder, "images_evaluation")
) and os.path.exists(
os.path.join(self.root, self.processed_folder, "images_background")
)
def download(self): def download(self):
import urllib import urllib
@ -99,15 +104,15 @@ class Omniglot(data.Dataset):
raise raise
for url in self.urls: for url in self.urls:
print('== Downloading ' + url) print("== Downloading " + url)
data = urllib.request.urlopen(url) data = urllib.request.urlopen(url)
filename = url.rpartition('/')[2] filename = url.rpartition("/")[2]
file_path = os.path.join(self.root, self.raw_folder, filename) file_path = os.path.join(self.root, self.raw_folder, filename)
with open(file_path, 'wb') as f: with open(file_path, "wb") as f:
f.write(data.read()) f.write(data.read())
file_processed = os.path.join(self.root, self.processed_folder) file_processed = os.path.join(self.root, self.processed_folder)
print("== Unzip from " + file_path + " to " + file_processed) print("== Unzip from " + file_path + " to " + file_processed)
zip_ref = zipfile.ZipFile(file_path, 'r') zip_ref = zipfile.ZipFile(file_path, "r")
zip_ref.extractall(file_processed) zip_ref.extractall(file_processed)
zip_ref.close() zip_ref.close()
print("Download finished.") print("Download finished.")
@ -115,10 +120,10 @@ class Omniglot(data.Dataset):
def find_classes(root_dir): def find_classes(root_dir):
retour = [] retour = []
for (root, dirs, files) in os.walk(root_dir): for root, dirs, files in os.walk(root_dir):
for f in files: for f in files:
if (f.endswith("png")): if f.endswith("png"):
r = root.split('/') r = root.split("/")
lr = len(r) lr = len(r)
retour.append((f, r[lr - 2] + "/" + r[lr - 1], root)) retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
print(f"== Found {len(retour)} items ") print(f"== Found {len(retour)} items ")
@ -135,7 +140,6 @@ def index_classes(items):
class OmniglotNShot: class OmniglotNShot:
def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None): def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None):
""" """
Different from mnistNShot, the Different from mnistNShot, the
@ -149,41 +153,52 @@ class OmniglotNShot:
self.resize = imgsz self.resize = imgsz
self.device = device self.device = device
if not os.path.isfile(os.path.join(root, 'omniglot.npy')): if not os.path.isfile(os.path.join(root, "omniglot.npy")):
# if root/data.npy does not exist, just download it # if root/data.npy does not exist, just download it
self.x = Omniglot( self.x = Omniglot(
root, download=True, root,
download=True,
transform=transforms.Compose( transform=transforms.Compose(
[lambda x: Image.open(x).convert('L'), [
lambda x: Image.open(x).convert("L"),
lambda x: x.resize((imgsz, imgsz)), lambda x: x.resize((imgsz, imgsz)),
lambda x: np.reshape(x, (imgsz, imgsz, 1)), lambda x: np.reshape(x, (imgsz, imgsz, 1)),
lambda x: np.transpose(x, [2, 0, 1]), lambda x: np.transpose(x, [2, 0, 1]),
lambda x: x / 255.]), lambda x: x / 255.0,
]
),
) )
temp = {} # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label} temp = (
for (img, label) in self.x: {}
) # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
for img, label in self.x:
if label in temp.keys(): if label in temp.keys():
temp[label].append(img) temp[label].append(img)
else: else:
temp[label] = [img] temp[label] = [img]
self.x = [] self.x = []
for label, imgs in temp.items(): # labels info deserted , each label contains 20imgs for (
label,
imgs,
) in temp.items(): # labels info deserted , each label contains 20imgs
self.x.append(np.array(imgs)) self.x.append(np.array(imgs))
# as different class may have different number of imgs # as different class may have different number of imgs
self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total] self.x = np.array(self.x).astype(
np.float
) # [[20 imgs],..., 1623 classes in total]
# each character contains 20 imgs # each character contains 20 imgs
print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1] print("data shape:", self.x.shape) # [1623, 20, 84, 84, 1]
temp = [] # Free memory temp = [] # Free memory
# save all dataset into npy file. # save all dataset into npy file.
np.save(os.path.join(root, 'omniglot.npy'), self.x) np.save(os.path.join(root, "omniglot.npy"), self.x)
print('write into omniglot.npy.') print("write into omniglot.npy.")
else: else:
# if data.npy exists, just load it. # if data.npy exists, just load it.
self.x = np.load(os.path.join(root, 'omniglot.npy')) self.x = np.load(os.path.join(root, "omniglot.npy"))
print('load from omniglot.npy.') print("load from omniglot.npy.")
# [1623, 20, 84, 84, 1] # [1623, 20, 84, 84, 1]
# TODO: can not shuffle here, we must keep training and test set distinct! # TODO: can not shuffle here, we must keep training and test set distinct!
@ -200,11 +215,18 @@ class OmniglotNShot:
# save pointer of current read batch in total cache # save pointer of current read batch in total cache
self.indexes = {"train": 0, "test": 0} self.indexes = {"train": 0, "test": 0}
self.datasets = {"train": self.x_train, "test": self.x_test} # original data cached self.datasets = {
"train": self.x_train,
"test": self.x_test,
} # original data cached
print("DB: train", self.x_train.shape, "test", self.x_test.shape) print("DB: train", self.x_train.shape, "test", self.x_test.shape)
self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]), # current epoch data cached self.datasets_cache = {
"test": self.load_data_cache(self.datasets["test"])} "train": self.load_data_cache(
self.datasets["train"]
), # current epoch data cached
"test": self.load_data_cache(self.datasets["test"]),
}
def normalization(self): def normalization(self):
""" """
@ -238,16 +260,15 @@ class OmniglotNShot:
# print('preload next 50 caches of batchsz of batch.') # print('preload next 50 caches of batchsz of batch.')
for sample in range(10): # num of episodes for sample in range(10): # num of episodes
x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
for i in range(self.batchsz): # one batch means one set for i in range(self.batchsz): # one batch means one set
x_spt, y_spt, x_qry, y_qry = [], [], [], [] x_spt, y_spt, x_qry, y_qry = [], [], [], []
selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False) selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)
for j, cur_class in enumerate(selected_cls): for j, cur_class in enumerate(selected_cls):
selected_img = np.random.choice(
selected_img = np.random.choice(20, self.k_shot + self.k_query, False) 20, self.k_shot + self.k_query, False
)
# meta-training and meta-test # meta-training and meta-test
x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]]) x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]])
@ -257,10 +278,14 @@ class OmniglotNShot:
# shuffle inside a batch # shuffle inside a batch
perm = np.random.permutation(self.n_way * self.k_shot) perm = np.random.permutation(self.n_way * self.k_shot)
x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm] x_spt = np.array(x_spt).reshape(
self.n_way * self.k_shot, 1, self.resize, self.resize
)[perm]
y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm] y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
perm = np.random.permutation(self.n_way * self.k_query) perm = np.random.permutation(self.n_way * self.k_query)
x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm] x_qry = np.array(x_qry).reshape(
self.n_way * self.k_query, 1, self.resize, self.resize
)[perm]
y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm] y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
# append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84] # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
@ -270,22 +295,30 @@ class OmniglotNShot:
y_qrys.append(y_qry) y_qrys.append(y_qry)
# [b, setsz, 1, 84, 84] # [b, setsz, 1, 84, 84]
x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize) x_spts = (
np.array(x_spts)
.astype(np.float32)
.reshape(self.batchsz, setsz, 1, self.resize, self.resize)
)
y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz) y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz)
# [b, qrysz, 1, 84, 84] # [b, qrysz, 1, 84, 84]
x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize) x_qrys = (
np.array(x_qrys)
.astype(np.float32)
.reshape(self.batchsz, querysz, 1, self.resize, self.resize)
)
y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz) y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz)
x_spts, y_spts, x_qrys, y_qrys = ( x_spts, y_spts, x_qrys, y_qrys = (
torch.from_numpy(z).to(self.device) for z in torch.from_numpy(z).to(self.device)
[x_spts, y_spts, x_qrys, y_qrys] for z in [x_spts, y_spts, x_qrys, y_qrys]
) )
data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
return data_cache return data_cache
def next(self, mode='train'): def next(self, mode="train"):
""" """
Gets next batch from the dataset with name. Gets next batch from the dataset with name.
:param mode: The name of the splitting (one of "train", "val", "test") :param mode: The name of the splitting (one of "train", "val", "test")

View File

@ -2,13 +2,15 @@
# (https://github.com/ericjang/maml-jax). # (https://github.com/ericjang/maml-jax).
# We translated his implementation from JAX to PyTorch. # We translated his implementation from JAX to PyTorch.
import matplotlib.pyplot as plt
import math import math
import torch
import numpy as np
from torch.nn import functional as F
import matplotlib as mpl import matplotlib as mpl
mpl.use('Agg') import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import functional as F
mpl.use("Agg")
def net(x, params): def net(x, params):
@ -23,13 +25,15 @@ def net(x, params):
params = [ params = [
torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(), torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(), torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(40, 40)
torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
.requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(), torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(1, 40)
torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
.requires_grad_(),
torch.Tensor(1).zero_().requires_grad_(), torch.Tensor(1).zero_().requires_grad_(),
] ]
@ -46,17 +50,18 @@ def sample_tasks(outer_batch_size, inner_batch_size):
As = [] As = []
phases = [] phases = []
for _ in range(outer_batch_size): for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=.5)) As.append(np.random.uniform(low=0.1, high=0.5))
phases.append(np.random.uniform(low=0., high=np.pi)) phases.append(np.random.uniform(low=0.0, high=np.pi))
def get_batch(): def get_batch():
xs, ys = [], [] xs, ys = [], []
for A, phase in zip(As, phases): for A, phase in zip(As, phases):
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
y = A * np.sin(x + phase) y = A * np.sin(x + phase)
xs.append(x) xs.append(x)
ys.append(y) ys.append(y)
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float) return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
x1, y1 = get_batch() x1, y1 = get_batch()
x2, y2 = get_batch() x2, y2 = get_batch()
return x1, y1, x2, y2 return x1, y1, x2, y2
@ -80,14 +85,17 @@ for it in range(20000):
return F.mse_loss(v_f, y2) return F.mse_loss(v_f, y2)
task = sample_tasks(num_tasks, K) task = sample_tasks(num_tasks, K)
inner_losses = [get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i]) for i in range(num_tasks)] inner_losses = [
get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i])
for i in range(num_tasks)
]
loss2 = sum(inner_losses) / len(inner_losses) loss2 = sum(inner_losses) / len(inner_losses)
loss2.backward() loss2.backward()
opt.step() opt.step()
if it % 100 == 0: if it % 100 == 0:
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
losses.append(loss2.detach()) losses.append(loss2.detach())
t_A = torch.tensor(0.0).uniform_(0.1, 0.5) t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
@ -112,11 +120,11 @@ test_y = t_A * torch.sin(test_x + t_b)
test_f = net(test_x, t_params) test_f = net(test_x, t_params)
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)') plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)') plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples') plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
plt.legend() plt.legend()
plt.savefig('maml-sine.png') plt.savefig("maml-sine.png")
plt.figure() plt.figure()
plt.plot(np.convolve(losses, [.05] * 20)) plt.plot(np.convolve(losses, [0.05] * 20))
plt.savefig('losses.png') plt.savefig("losses.png")

View File

@ -2,14 +2,16 @@
# (https://github.com/ericjang/maml-jax). # (https://github.com/ericjang/maml-jax).
# We translated his implementation from JAX to PyTorch. # We translated his implementation from JAX to PyTorch.
from torch.func import grad, vmap
import matplotlib.pyplot as plt
import math import math
import torch
import numpy as np
from torch.nn import functional as F
import matplotlib as mpl import matplotlib as mpl
mpl.use('Agg') import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.func import grad, vmap
from torch.nn import functional as F
mpl.use("Agg")
def net(params, x): def net(params, x):
@ -24,13 +26,15 @@ def net(params, x):
params = [ params = [
torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(), torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(), torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(40, 40)
torch.Tensor(40, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
.requires_grad_(),
torch.Tensor(40).zero_().requires_grad_(), torch.Tensor(40).zero_().requires_grad_(),
torch.Tensor(1, 40)
torch.Tensor(1, 40).uniform_(-1. / math.sqrt(40), 1. / math.sqrt(40)).requires_grad_(), .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
.requires_grad_(),
torch.Tensor(1).zero_().requires_grad_(), torch.Tensor(1).zero_().requires_grad_(),
] ]
@ -54,17 +58,18 @@ def sample_tasks(outer_batch_size, inner_batch_size):
As = [] As = []
phases = [] phases = []
for _ in range(outer_batch_size): for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=.5)) As.append(np.random.uniform(low=0.1, high=0.5))
phases.append(np.random.uniform(low=0., high=np.pi)) phases.append(np.random.uniform(low=0.0, high=np.pi))
def get_batch(): def get_batch():
xs, ys = [], [] xs, ys = [], []
for A, phase in zip(As, phases): for A, phase in zip(As, phases):
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
y = A * np.sin(x + phase) y = A * np.sin(x + phase)
xs.append(x) xs.append(x)
ys.append(y) ys.append(y)
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float) return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
x1, y1 = get_batch() x1, y1 = get_batch()
x2, y2 = get_batch() x2, y2 = get_batch()
return x1, y1, x2, y2 return x1, y1, x2, y2
@ -94,7 +99,7 @@ for it in range(20000):
opt.step() opt.step()
if it % 100 == 0: if it % 100 == 0:
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
losses.append(loss2.detach()) losses.append(loss2.detach())
t_A = torch.tensor(0.0).uniform_(0.1, 0.5) t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
@ -119,11 +124,11 @@ test_y = t_A * torch.sin(test_x + t_b)
test_f = net(t_params, test_x) test_f = net(t_params, test_x)
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)') plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)') plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples') plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
plt.legend() plt.legend()
plt.savefig('maml-sine.png') plt.savefig("maml-sine.png")
plt.figure() plt.figure()
plt.plot(np.convolve(losses, [.05] * 20)) plt.plot(np.convolve(losses, [0.05] * 20))
plt.savefig('losses.png') plt.savefig("losses.png")

View File

@ -2,15 +2,17 @@
# (https://github.com/ericjang/maml-jax). # (https://github.com/ericjang/maml-jax).
# We translated his implementation from JAX to PyTorch. # We translated his implementation from JAX to PyTorch.
from functorch import grad, vmap, make_functional
import matplotlib.pyplot as plt
import math import math
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch
from functorch import grad, make_functional, vmap
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg') mpl.use("Agg")
class ThreeLayerNet(nn.Module): class ThreeLayerNet(nn.Module):
@ -30,6 +32,7 @@ class ThreeLayerNet(nn.Module):
x = self.fc3(x) x = self.fc3(x)
return x return x
# TODO: Use F.mse_loss # TODO: Use F.mse_loss
@ -51,17 +54,18 @@ def sample_tasks(outer_batch_size, inner_batch_size):
As = [] As = []
phases = [] phases = []
for _ in range(outer_batch_size): for _ in range(outer_batch_size):
As.append(np.random.uniform(low=0.1, high=.5)) As.append(np.random.uniform(low=0.1, high=0.5))
phases.append(np.random.uniform(low=0., high=np.pi)) phases.append(np.random.uniform(low=0.0, high=np.pi))
def get_batch(): def get_batch():
xs, ys = [], [] xs, ys = [], []
for A, phase in zip(As, phases): for A, phase in zip(As, phases):
x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1)) x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
y = A * np.sin(x + phase) y = A * np.sin(x + phase)
xs.append(x) xs.append(x)
ys.append(y) ys.append(y)
return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float) return torch.tensor(xs, dtype=torch.float), torch.tensor(ys, dtype=torch.float)
x1, y1 = get_batch() x1, y1 = get_batch()
x2, y2 = get_batch() x2, y2 = get_batch()
return x1, y1, x2, y2 return x1, y1, x2, y2
@ -91,7 +95,7 @@ for it in range(20000):
opt.step() opt.step()
if it % 100 == 0: if it % 100 == 0:
print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
losses.append(loss2.detach()) losses.append(loss2.detach())
t_A = torch.tensor(0.0).uniform_(0.1, 0.5) t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
@ -116,11 +120,11 @@ test_y = t_A * torch.sin(test_x + t_b)
test_f = net(t_params, test_x) test_f = net(t_params, test_x)
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)') plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)') plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples') plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
plt.legend() plt.legend()
plt.savefig('maml-sine.png') plt.savefig("maml-sine.png")
plt.figure() plt.figure()
plt.plot(np.convolve(losses, [.05] * 20)) plt.plot(np.convolve(losses, [0.05] * 20))
plt.savefig('losses.png') plt.savefig("losses.png")

View File

@ -1,5 +1,6 @@
# PyTorch forward-mode is not mature yet # PyTorch forward-mode is not mature yet
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import hessian, jacfwd, jvp from torch._functorch.eager_transforms import hessian, jacfwd, jvp
from torch._functorch.vmap import chunk_vmap from torch._functorch.vmap import chunk_vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from functorch import functionalize from functorch import functionalize

View File

@ -1,26 +1,31 @@
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
from torch.multiprocessing.reductions import StorageWeakRef
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize from torch._dynamo.exc import CondOpArgsMismatchError
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
_wrap_all_tensors_to_functional,
functionalize,
)
from torch._ops import HigherOrderOperator from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ( from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing, disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
make_fx, make_fx,
ProxyTorchDispatchMode,
track_tensor_tree, track_tensor_tree,
) )
from torch.fx.passes.shape_prop import _extract_tensor_metadata from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import ( from torch.utils._python_dispatch import (
_get_current_dispatch_mode, _get_current_dispatch_mode,
_pop_mode_temporarily, _pop_mode_temporarily,
) )
from torch.utils._pytree import tree_flatten from torch.utils._pytree import tree_flatten
from torch._dynamo.exc import CondOpArgsMismatchError
@dataclass @dataclass
@ -34,9 +39,14 @@ In order to do this, we need implementations for each of the dispatch keys.
""" """
cond = HigherOrderOperator("cond") cond = HigherOrderOperator("cond")
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors" assert isinstance(
assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors" operands, (list, tuple)
), "Cond operands must be a list or tuple of tensors"
assert all(
isinstance(o, torch.Tensor) for o in operands
), "Cond operands must be a list of tensors"
with disable_proxy_modes_tracing(): with disable_proxy_modes_tracing():
true_graph = make_fx(true_fn)(*operands) true_graph = make_fx(true_fn)(*operands)
@ -45,11 +55,11 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
true_outs = [] true_outs = []
false_outs = [] false_outs = []
for node in true_graph.graph.nodes: for node in true_graph.graph.nodes:
if node.op == 'output': if node.op == "output":
true_outs.extend(node.args) true_outs.extend(node.args)
for node in false_graph.graph.nodes: for node in false_graph.graph.nodes:
if node.op == 'output': if node.op == "output":
false_outs.extend(node.args) false_outs.extend(node.args)
flat_true_outs, _ = pytree.tree_flatten(true_outs) flat_true_outs, _ = pytree.tree_flatten(true_outs)
@ -64,7 +74,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
for i in range(0, len(flat_true_outs)): for i in range(0, len(flat_true_outs)):
true_out = flat_true_outs[i] true_out = flat_true_outs[i]
false_out = flat_false_outs[i] false_out = flat_false_outs[i]
if true_out.meta['tensor_meta'] != false_out.meta['tensor_meta']: if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
raise CondOpArgsMismatchError( raise CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:" f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}" f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
@ -85,7 +95,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
true_name = next_name true_name = next_name
false_name = f"false_graph_{i}" false_name = f"false_graph_{i}"
assert(not hasattr(proxy_mode.tracer.root, false_name)) assert not hasattr(proxy_mode.tracer.root, false_name)
proxy_mode.tracer.root.register_module(true_name, true_graph) proxy_mode.tracer.root.register_module(true_name, true_graph)
proxy_mode.tracer.root.register_module(false_name, false_graph) proxy_mode.tracer.root.register_module(false_name, false_graph)
@ -94,8 +104,9 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {}, out_proxy = proxy_mode.tracer.create_proxy(
name="conditional") "call_function", func_overload, proxy_args, {}, name="conditional"
)
# At this point, we're *guaranteed* that whether an output came from the # At this point, we're *guaranteed* that whether an output came from the
# true or false branch is indistinguishable. So, as this is just for tracing # true or false branch is indistinguishable. So, as this is just for tracing
@ -112,7 +123,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
@cond.py_impl(DispatchKey.CompositeExplicitAutograd) @cond.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_dense(pred, true_fn, false_fn, operands): def cond_dense(pred, true_fn, false_fn, operands):
mode = _get_current_dispatch_mode() mode = _get_current_dispatch_mode()
assert (mode is None), "Mode should never be enabled for CPU/CUDA key" assert mode is None, "Mode should never be enabled for CPU/CUDA key"
if pred: if pred:
return true_fn(*operands) return true_fn(*operands)
else: else:
@ -125,8 +136,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands]) flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
requires_grad = any( requires_grad = any(
isinstance(arg, torch.Tensor) and arg.requires_grad isinstance(arg, torch.Tensor) and arg.requires_grad for arg in flat_operands
for arg in flat_operands
) )
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)): with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)):
@ -148,6 +158,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
var = var.detach() var = var.detach()
var.requires_grad = True var.requires_grad = True
return var return var
return err_fn(fake_requires_grad(result)) return err_fn(fake_requires_grad(result))
return result return result
@ -156,7 +167,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
@cond.py_impl(ProxyTorchDispatchMode) @cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands): def inner(pred, true_fn, false_fn, operands):
mode = _get_current_dispatch_mode() mode = _get_current_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key" assert mode is not None, "Mode should always be enabled for python fallback key"
with _pop_mode_temporarily() as mode: with _pop_mode_temporarily() as mode:
if mode.enable_tracing: if mode.enable_tracing:
return trace_cond(mode, cond, pred, true_fn, false_fn, operands) return trace_cond(mode, cond, pred, true_fn, false_fn, operands)
@ -177,7 +188,8 @@ def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
false_meta = _extract_tensor_metadata(false_out) false_meta = _extract_tensor_metadata(false_out)
if true_meta != false_meta: if true_meta != false_meta:
raise RuntimeError( raise RuntimeError(
f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}") f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}"
)
return true_outs return true_outs
@ -203,7 +215,10 @@ def _has_potential_branch_input_mutation(branch, inputs):
input_nodes.add(node) input_nodes.add(node)
if node.op == "call_function": if node.op == "call_function":
target = node.target target = node.target
if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable: if (
isinstance(target, torch._ops.OpOverload)
and target._schema.is_mutable
):
for arg in node.args: for arg in node.args:
if arg in input_nodes: if arg in input_nodes:
return True return True
@ -241,13 +256,15 @@ def _has_potential_branch_input_alias(branch, inputs):
# for map operator, where num_mapped_args is a scalar # for map operator, where num_mapped_args is a scalar
# and doesn't have a "val" meta. # and doesn't have a "val" meta.
if node.op == "placeholder" and "val" in node.meta: if node.op == "placeholder" and "val" in node.meta:
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage())) input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
if node.op == "output": if node.op == "output":
def check_alias(out): def check_alias(out):
if out is not None and "val" in out.meta: if out is not None and "val" in out.meta:
out_storage = StorageWeakRef(out.meta['val']._typed_storage()) out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
return out_storage in input_storages return out_storage in input_storages
return False return False
if any(pytree.tree_flatten(pytree.tree_map(check_alias, node.args))[0]): if any(pytree.tree_flatten(pytree.tree_map(check_alias, node.args))[0]):
return True return True
@ -263,22 +280,30 @@ def _has_potential_branch_input_alias(branch, inputs):
@cond.py_impl(DispatchKey.Functionalize) @cond.py_impl(DispatchKey.Functionalize)
def cond_func(pred, true_fn, false_fn, inputs): def cond_func(pred, true_fn, false_fn, inputs):
reapply_views = torch._C._functionalization_reapply_views_tls() reapply_views = torch._C._functionalization_reapply_views_tls()
unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views) unwrapped_inputs = _unwrap_all_tensors_from_functional(
unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views) inputs, reapply_views=reapply_views
mode = 'mutations_and_views' if reapply_views else 'mutations' )
unwrapped_pred = _unwrap_all_tensors_from_functional(
pred, reapply_views=reapply_views
)
mode = "mutations_and_views" if reapply_views else "mutations"
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
functional_true = functionalize(true_fn, remove=mode) functional_true = functionalize(true_fn, remove=mode)
functional_false = functionalize(false_fn, remove=mode) functional_false = functionalize(false_fn, remove=mode)
for branch in [true_fn, false_fn]: for branch in [true_fn, false_fn]:
if _has_potential_branch_input_mutation(branch, unwrapped_inputs): if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch " raise UnsupportedAliasMutationException(
"might be modifying the input!") "One of torch.cond branch " "might be modifying the input!"
)
if _has_potential_branch_input_alias(branch, unwrapped_inputs): if _has_potential_branch_input_alias(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch " raise UnsupportedAliasMutationException(
"might be aliasing the input!") "One of torch.cond branch " "might be aliasing the input!"
)
cond_return = cond(unwrapped_pred, functional_true, functional_false, unwrapped_inputs) cond_return = cond(
unwrapped_pred, functional_true, functional_false, unwrapped_inputs
)
return _wrap_all_tensors_to_functional(cond_return, level=0) return _wrap_all_tensors_to_functional(cond_return, level=0)
@ -290,10 +315,14 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
2. Our check for above condition is not exhaustive 2. Our check for above condition is not exhaustive
""" """
reapply_views = interpreter.functionalize_add_back_views() reapply_views = interpreter.functionalize_add_back_views()
mode = 'mutations_and_views' if reapply_views else 'mutations' mode = "mutations_and_views" if reapply_views else "mutations"
# At this point, we will see functionalized tensors, so need to unwrap them first # At this point, we will see functionalized tensors, so need to unwrap them first
unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views) unwrapped_inputs = _unwrap_all_tensors_from_functional(
unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views) inputs, reapply_views=reapply_views
)
unwrapped_pred = _unwrap_all_tensors_from_functional(
pred, reapply_views=reapply_views
)
functional_true_fn = functionalize(true_fn, remove=mode) functional_true_fn = functionalize(true_fn, remove=mode)
functional_false_fn = functionalize(false_fn, remove=mode) functional_false_fn = functionalize(false_fn, remove=mode)
@ -301,16 +330,21 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
with interpreter.lower(): with interpreter.lower():
for branch in [functional_true_fn, functional_false_fn]: for branch in [functional_true_fn, functional_false_fn]:
if _has_potential_branch_input_mutation(branch, unwrapped_inputs): if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch " raise UnsupportedAliasMutationException(
"might be modifying the input!") "One of torch.cond branch " "might be modifying the input!"
)
for branch in [true_fn, false_fn]: for branch in [true_fn, false_fn]:
if _has_potential_branch_input_alias(branch, unwrapped_inputs): if _has_potential_branch_input_alias(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException("One of torch.cond branch " raise UnsupportedAliasMutationException(
"might be aliasing the input!") "One of torch.cond branch " "might be aliasing the input!"
)
cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs) cond_return = cond(
unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs
)
return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level()) return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())
# TODO(voz): Make this automatic for keys, this is very ugly atm # TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonDispatcher) cond.fallthrough(DispatchKey.PythonDispatcher)
cond.fallthrough(DispatchKey.PythonTLSSnapshot) cond.fallthrough(DispatchKey.PythonTLSSnapshot)

View File

@ -1,23 +1,32 @@
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import create_joint, AOTConfig from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
_wrap_all_tensors_to_functional,
functionalize,
)
from torch._ops import HigherOrderOperator from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.fake_tensor import FakeTensorMode
from torch.multiprocessing.reductions import StorageWeakRef
from torch.fx.experimental.proxy_tensor import ( from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing, disable_proxy_modes_tracing,
make_fx, make_fx,
ProxyTorchDispatchMode, ProxyTorchDispatchMode,
track_tensor_tree, track_tensor_tree,
) )
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import ( from torch.utils._python_dispatch import (
_get_current_dispatch_mode, _get_current_dispatch_mode,
_pop_mode_temporarily, _pop_mode_temporarily,
) )
from torch._dispatch.python import suspend_functionalization
from ._cond import _has_potential_branch_input_alias, _has_potential_branch_input_mutation, UnsupportedAliasMutationException from ._cond import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
UnsupportedAliasMutationException,
)
# TODO: We add this to prevent dymamo from tracing into map_wrapper, # TODO: We add this to prevent dymamo from tracing into map_wrapper,
@ -26,16 +35,19 @@ class MapWrapper(HigherOrderOperator):
def __call__(self, xs, *args): def __call__(self, xs, *args):
return map_wrapper(xs, *args) return map_wrapper(xs, *args)
map = MapWrapper("map", _deprecated_global_ns=True) map = MapWrapper("map", _deprecated_global_ns=True)
map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True) map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)
dummy_aot_config = AOTConfig(fw_compiler=None, dummy_aot_config = AOTConfig(
fw_compiler=None,
bw_compiler=None, bw_compiler=None,
partition_fn=None, partition_fn=None,
decompositions={}, decompositions={},
num_params_buffers=0, num_params_buffers=0,
aot_id=0, aot_id=0,
keep_inference_input_mutations=False) keep_inference_input_mutations=False,
)
def create_fw_bw_graph(f, num_mapped_args, *args): def create_fw_bw_graph(f, num_mapped_args, *args):
@ -59,20 +71,33 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
with suspend_functionalization(): with suspend_functionalization():
with disable_proxy_modes_tracing(): with disable_proxy_modes_tracing():
def from_fun(t): def from_fun(t):
if isinstance(t, torch.Tensor): if isinstance(t, torch.Tensor):
return torch.empty_strided(t.size(), t.stride(), requires_grad=t.requires_grad) return torch.empty_strided(
t.size(), t.stride(), requires_grad=t.requires_grad
)
return t return t
example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]] example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
example_pos_args = [from_fun(arg) if isinstance(arg, torch.Tensor) else arg for arg in pos_args] example_pos_args = [
example_flat_out = pytree.tree_map(from_fun, f(*example_xs, *example_pos_args)) from_fun(arg) if isinstance(arg, torch.Tensor) else arg
if any(not isinstance(out, torch.Tensor) for out in example_flat_out if out is not None): for arg in pos_args
raise RuntimeError("Expect outputs of map only contains tensors or None. " ]
f"Got types {[type(out) for out in example_flat_out]}.") example_flat_out = pytree.tree_map(
from_fun, f(*example_xs, *example_pos_args)
)
if any(
not isinstance(out, torch.Tensor)
for out in example_flat_out
if out is not None
):
raise RuntimeError(
"Expect outputs of map only contains tensors or None. "
f"Got types {[type(out) for out in example_flat_out]}."
)
example_grad = [from_fun(out) for out in example_flat_out] example_grad = [from_fun(out) for out in example_flat_out]
fw_graph = make_fx(f)(*example_xs, *example_pos_args) fw_graph = make_fx(f)(*example_xs, *example_pos_args)
def joint_f(*example_args): def joint_f(*example_args):
@ -84,20 +109,39 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
def fw_with_masks(*args): def fw_with_masks(*args):
fw_out = f(*args) fw_out = f(*args)
return fw_out, [True if isinstance(ret, torch.Tensor) and ret.requires_grad else False for ret in fw_out] return fw_out, [
True
if isinstance(ret, torch.Tensor) and ret.requires_grad
else False
for ret in fw_out
]
joint = create_joint(fw_with_masks, aot_config=dummy_aot_config) joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
_, grads = joint(list(mapped_input) + list(args), _, grads = joint(
[grad for grad in mapped_grads if grad is not None and grad.requires_grad]) list(mapped_input) + list(args),
[
grad
for grad in mapped_grads
if grad is not None and grad.requires_grad
],
)
# In order to keep map functional for backward graph, # In order to keep map functional for backward graph,
# we clone outputs that are aliasing inputs # we clone outputs that are aliasing inputs
input_storage = {StorageWeakRef(arg._typed_storage()) for arg in example_args if isinstance(arg, torch.Tensor)} input_storage = {
StorageWeakRef(arg._typed_storage())
for arg in example_args
if isinstance(arg, torch.Tensor)
}
def maybe_clone(t): def maybe_clone(t):
if isinstance(t, torch.Tensor) and StorageWeakRef(t._typed_storage()) in input_storage: if (
isinstance(t, torch.Tensor)
and StorageWeakRef(t._typed_storage()) in input_storage
):
return t.clone() return t.clone()
return t return t
return pytree.tree_map(maybe_clone, grads) return pytree.tree_map(maybe_clone, grads)
joint_num_mapped = len(example_grad) + len(example_xs) joint_num_mapped = len(example_grad) + len(example_xs)
@ -114,12 +158,12 @@ def map_wrapper(f, xs, *args):
shapes = [xs.shape for xs in flat_xs] shapes = [xs.shape for xs in flat_xs]
leading_dim_size = shapes[0][0] leading_dim_size = shapes[0][0]
if leading_dim_size == 0: if leading_dim_size == 0:
raise RuntimeError( raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")
"Leading dimensions of mapped xs cannot be 0.")
if any(cur_shape[0] != leading_dim_size for cur_shape in shapes): if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
raise RuntimeError( raise RuntimeError(
f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}.") f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
)
out_spec = None out_spec = None
@ -131,7 +175,11 @@ def map_wrapper(f, xs, *args):
nonlocal out_spec nonlocal out_spec
out_spec = tmp_out_spec out_spec = tmp_out_spec
return flat_out return flat_out
return pytree.tree_unflatten(map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec)
return pytree.tree_unflatten(
map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
)
class MapAutogradOp(torch.autograd.Function): class MapAutogradOp(torch.autograd.Function):
@staticmethod @staticmethod
@ -148,9 +196,16 @@ class MapAutogradOp(torch.autograd.Function):
fw_mapped_args = fw_args[: ctx._num_mapped_args] fw_mapped_args = fw_args[: ctx._num_mapped_args]
pos_args = fw_args[ctx._num_mapped_args :] pos_args = fw_args[ctx._num_mapped_args :]
grads = map_impl(ctx._joint_graph, ctx._num_mapped_args + len(flat_grads), *fw_mapped_args, *flat_grads, *pos_args) grads = map_impl(
ctx._joint_graph,
ctx._num_mapped_args + len(flat_grads),
*fw_mapped_args,
*flat_grads,
*pos_args,
)
return None, None, None, *grads return None, None, None, *grads
def trace_map(proxy_mode, func_overload, f, num_mapped, *args): def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
xs = list(args[:num_mapped]) xs = list(args[:num_mapped])
pos_args = list(args[num_mapped:]) pos_args = list(args[num_mapped:])
@ -168,6 +223,7 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
if isinstance(t, torch.Tensor): if isinstance(t, torch.Tensor):
return t.expand(leading_dim_size, *t.shape) return t.expand(leading_dim_size, *t.shape)
return t return t
expanded_outs = pytree.tree_map(expand_tensor, example_outs) expanded_outs = pytree.tree_map(expand_tensor, example_outs)
next_name = None next_name = None
@ -182,9 +238,13 @@ def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
proxy_mode.tracer.root.register_module(next_name, body_graph) proxy_mode.tracer.root.register_module(next_name, body_graph)
node_args = (body_graph, num_mapped, *args) node_args = (body_graph, num_mapped, *args)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {}, out_proxy = proxy_mode.tracer.create_proxy(
name="map_impl") "call_function", func_overload, proxy_args, {}, name="map_impl"
return track_tensor_tree(expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer) )
return track_tensor_tree(
expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
)
def _unstack_pytree(xs): def _unstack_pytree(xs):
flat_xs, inspec = pytree.tree_flatten(xs) flat_xs, inspec = pytree.tree_flatten(xs)
@ -192,7 +252,9 @@ def _unstack_pytree(xs):
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}") raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs): if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
raise RuntimeError(f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}") raise RuntimeError(
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
)
a = zip(*flat_xs) a = zip(*flat_xs)
pytrees = [] pytrees = []
@ -200,6 +262,7 @@ def _unstack_pytree(xs):
pytrees.append(pytree.tree_unflatten(tuple, inspec)) pytrees.append(pytree.tree_unflatten(tuple, inspec))
return pytrees return pytrees
def _stack_pytree(pytrees): def _stack_pytree(pytrees):
flat_out = [] flat_out = []
out_spec = None out_spec = None
@ -220,6 +283,7 @@ def _stack_pytree(pytrees):
raise RuntimeError(f"Cannot stack {leaves}.") raise RuntimeError(f"Cannot stack {leaves}.")
return pytree.tree_unflatten(stacked_out, out_spec) return pytree.tree_unflatten(stacked_out, out_spec)
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd) @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, num_mapped_args, *args): def map_dense(f, num_mapped_args, *args):
xs = args[:num_mapped_args] xs = args[:num_mapped_args]
@ -240,7 +304,7 @@ def map_autograd(f, num_mapped_args, *args):
@map_impl.py_impl(ProxyTorchDispatchMode) @map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, num_mapped, *args): def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
mode = _get_current_dispatch_mode() mode = _get_current_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key" assert mode is not None, "Mode should always be enabled for python fallback key"
with _pop_mode_temporarily() as mode: with _pop_mode_temporarily() as mode:
if mode.enable_tracing: if mode.enable_tracing:
return trace_map(mode, map_impl, f, num_mapped, *args) return trace_map(mode, map_impl, f, num_mapped, *args)
@ -259,8 +323,10 @@ def map_func(f, num_mapped, *args):
xs = args[:num_mapped] xs = args[:num_mapped]
pos_args = args[num_mapped:] pos_args = args[num_mapped:]
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views) unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views) unwrapped_args = _unwrap_all_tensors_from_functional(
mode = 'mutations_and_views' if reapply_views else 'mutations' pos_args, reapply_views=reapply_views
)
mode = "mutations_and_views" if reapply_views else "mutations"
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
functional_map_fn = functionalize(f, remove=mode) functional_map_fn = functionalize(f, remove=mode)
@ -268,18 +334,17 @@ def map_func(f, num_mapped, *args):
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args) example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
if _has_potential_branch_input_mutation(f, example_inputs): if _has_potential_branch_input_mutation(f, example_inputs):
raise UnsupportedAliasMutationException( raise UnsupportedAliasMutationException("torch.map is mutating the input!")
"torch.map is mutating the input!"
)
if _has_potential_branch_input_alias(f, example_inputs): if _has_potential_branch_input_alias(f, example_inputs):
raise UnsupportedAliasMutationException( raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
"torch.map is aliasing the input!"
)
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args) map_return = map_impl(
functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
)
return _wrap_all_tensors_to_functional(map_return, level=0) return _wrap_all_tensors_to_functional(map_return, level=0)
@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize) @map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, num_mapped, *args): def map_functionalize(interpreter, f, num_mapped, *args):
""" """
@ -290,10 +355,12 @@ def map_functionalize(interpreter, f, num_mapped, *args):
xs = args[:num_mapped] xs = args[:num_mapped]
pos_args = args[num_mapped:] pos_args = args[num_mapped:]
reapply_views = interpreter.functionalize_add_back_views() reapply_views = interpreter.functionalize_add_back_views()
mode = 'mutations_and_views' if reapply_views else 'mutations' mode = "mutations_and_views" if reapply_views else "mutations"
# At this point, we will see functionalized tensors, so need to unwrap them first # At this point, we will see functionalized tensors, so need to unwrap them first
unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views) unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views) unwrapped_args = _unwrap_all_tensors_from_functional(
pos_args, reapply_views=reapply_views
)
functional_map_fn = functionalize(f, remove=mode) functional_map_fn = functionalize(f, remove=mode)
@ -301,18 +368,17 @@ def map_functionalize(interpreter, f, num_mapped, *args):
with disable_proxy_modes_tracing(): with disable_proxy_modes_tracing():
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args) example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
if _has_potential_branch_input_mutation(f, example_inputs): if _has_potential_branch_input_mutation(f, example_inputs):
raise UnsupportedAliasMutationException( raise UnsupportedAliasMutationException("torch.map is mutating the input!")
"torch.map is mutating the input!"
)
if _has_potential_branch_input_alias(f, example_inputs): if _has_potential_branch_input_alias(f, example_inputs):
raise UnsupportedAliasMutationException( raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
"torch.map is aliasing the input!"
)
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args) map_return = map_impl(
functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
)
return _wrap_all_tensors_to_functional(map_return, level=interpreter.level()) return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
# TODO(voz) Make this automatic for keys, this is very ugly atm # TODO(voz) Make this automatic for keys, this is very ugly atm
map_impl.fallthrough(DispatchKey.PythonDispatcher) map_impl.fallthrough(DispatchKey.PythonDispatcher)
map_impl.fallthrough(DispatchKey.PythonTLSSnapshot) map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)

View File

@ -1,2 +1,2 @@
from ._map import map # noqa: F401
from ._cond import cond, UnsupportedAliasMutationException # noqa: F401 from ._cond import cond, UnsupportedAliasMutationException # noqa: F401
from ._map import map # noqa: F401

View File

@ -19,8 +19,10 @@ Let's demonstrate how to do this using an ensemble of simple CNNs.
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
torch.manual_seed(0) torch.manual_seed(0)
# Here's a simple CNN # Here's a simple CNN
class SimpleCNN(nn.Module): class SimpleCNN(nn.Module):
def __init__(self): def __init__(self):
@ -44,11 +46,12 @@ class SimpleCNN(nn.Module):
output = x output = x
return output return output
# Let's generate some dummy data. Pretend that we're working with an MNIST dataset # Let's generate some dummy data. Pretend that we're working with an MNIST dataset
# where the images are 28 by 28. # where the images are 28 by 28.
# Furthermore, let's say we wish to combine the predictions from 10 different # Furthermore, let's say we wish to combine the predictions from 10 different
# models. # models.
device = 'cuda' device = "cuda"
num_models = 10 num_models = 10
data = torch.randn(100, 64, 1, 28, 28, device=device) data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device) targets = torch.randint(10, (6400,), device=device)
@ -81,6 +84,7 @@ predictions2 = [model(minibatch) for model in models]
# functorch offers the following convenience function to do that. It returns a # functorch offers the following convenience function to do that. It returns a
# stateless version of the model (fmodel) and stacked parameters and buffers. # stateless version of the model (fmodel) and stacked parameters and buffers.
from functorch import combine_state_for_ensemble from functorch import combine_state_for_ensemble
fmodel, params, buffers = combine_state_for_ensemble(models) fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params] [p.requires_grad_() for p in params]
@ -92,15 +96,20 @@ fmodel, params, buffers = combine_state_for_ensemble(models)
print([p.size(0) for p in params]) print([p.size(0) for p in params])
assert minibatches.shape == (num_models, 64, 1, 28, 28) assert minibatches.shape == (num_models, 64, 1, 28, 28)
from functorch import vmap from functorch import vmap
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches) predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
assert torch.allclose(predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6) assert torch.allclose(
predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6
)
# Option 2: get predictions using the same minibatch of data # Option 2: get predictions using the same minibatch of data
# vmap has an in_dims arg that specify which dimensions to map over. # vmap has an in_dims arg that specify which dimensions to map over.
# Using ``None``, we tell vmap we want the same minibatch to apply for all of # Using ``None``, we tell vmap we want the same minibatch to apply for all of
# the 10 models. # the 10 models.
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch) predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6) assert torch.allclose(
predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6
)
# A quick note: there are limitations around what types of functions can be # A quick note: there are limitations around what types of functions can be
# transformed by vmap. The best functions to transform are ones that are # transformed by vmap. The best functions to transform are ones that are

View File

@ -8,11 +8,14 @@ deep learning models. It is difficult (or annoying) to compute these quantities
efficiently using a standard autodiff system like PyTorch Autograd; functorch efficiently using a standard autodiff system like PyTorch Autograd; functorch
provides ways of computing various higher-order autodiff quantities efficiently. provides ways of computing various higher-order autodiff quantities efficiently.
""" """
from functools import partial
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from functools import partial
torch.manual_seed(0) torch.manual_seed(0)
###################################################################### ######################################################################
# Setup: Comparing functorch vs the naive approach # Setup: Comparing functorch vs the naive approach
# -------------------------------------------------------------------- # --------------------------------------------------------------------
@ -21,6 +24,7 @@ torch.manual_seed(0)
def predict(weight, bias, x): def predict(weight, bias, x):
return F.linear(x, weight, bias).tanh() return F.linear(x, weight, bias).tanh()
# Here's some dummy data: a weight, a bias, and a feature vector. # Here's some dummy data: a weight, a bias, and a feature vector.
D = 16 D = 16
weight = torch.randn(D, D) weight = torch.randn(D, D)
@ -34,19 +38,24 @@ x = torch.randn(D)
xp = x.clone().requires_grad_() xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D) unit_vectors = torch.eye(D)
def compute_jac(xp): def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0] jacobian_rows = [
for vec in unit_vectors] torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors
]
return torch.stack(jacobian_rows) return torch.stack(jacobian_rows)
jacobian = compute_jac(xp) jacobian = compute_jac(xp)
# Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid # Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid
# of the for-loop and vectorize the computation. We can't directly apply vmap # of the for-loop and vectorize the computation. We can't directly apply vmap
# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform: # to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
from functorch import vmap, vjp from functorch import vjp, vmap
_, vjp_fn = vjp(partial(predict, weight, bias), x) _, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors) (ft_jacobian,) = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian) assert torch.allclose(ft_jacobian, jacobian)
# In another tutorial a composition of reverse-mode AD and vmap gave us # In another tutorial a composition of reverse-mode AD and vmap gave us
@ -59,6 +68,7 @@ assert torch.allclose(ft_jacobian, jacobian)
# argument that says which argument we would like to compute Jacobians with # argument that says which argument we would like to compute Jacobians with
# respect to. # respect to.
from functorch import jacrev from functorch import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x) ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian) assert torch.allclose(ft_jacobian, jacobian)
@ -67,6 +77,7 @@ assert torch.allclose(ft_jacobian, jacobian)
# there are). In general, we expect that vectorization via ``vmap`` can help # there are). In general, we expect that vectorization via ``vmap`` can help
# eliminate overhead and give better utilization of your hardware. # eliminate overhead and give better utilization of your hardware.
from torch.utils.benchmark import Timer from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals()) without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(without_vmap.timeit(500)) print(without_vmap.timeit(500))
@ -95,7 +106,7 @@ ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
# In reverse-mode AD, we are computing the jacobian row-by-row, while in # In reverse-mode AD, we are computing the jacobian row-by-row, while in
# forward-mode AD (which computes Jacobian-vector products), we are computing # forward-mode AD (which computes Jacobian-vector products), we are computing
# it column-by-column. The Jacobian matrix has M rows and N columns. # it column-by-column. The Jacobian matrix has M rows and N columns.
from functorch import jacrev, jacfwd from functorch import jacfwd, jacrev
# Benchmark with more inputs than outputs # Benchmark with more inputs than outputs
Din = 32 Din = 32
@ -106,8 +117,8 @@ x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(f'jacfwd time: {using_fwd.timeit(500)}') print(f"jacfwd time: {using_fwd.timeit(500)}")
print(f'jacrev time: {using_bwd.timeit(500)}') print(f"jacrev time: {using_bwd.timeit(500)}")
# Benchmark with more outputs than inputs # Benchmark with more outputs than inputs
Din = 2048 Din = 2048
@ -118,8 +129,8 @@ x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
print(f'jacfwd time: {using_fwd.timeit(500)}') print(f"jacfwd time: {using_fwd.timeit(500)}")
print(f'jacrev time: {using_bwd.timeit(500)}') print(f"jacrev time: {using_bwd.timeit(500)}")
###################################################################### ######################################################################
# Hessian computation with functorch.hessian # Hessian computation with functorch.hessian
@ -132,6 +143,7 @@ print(f'jacrev time: {using_bwd.timeit(500)}')
# Depending on your model, you may want to use ``jacfwd(jacfwd(f))`` or # Depending on your model, you may want to use ``jacfwd(jacfwd(f))`` or
# ``jacrev(jacrev(f))`` instead to compute hessians. # ``jacrev(jacrev(f))`` instead to compute hessians.
from functorch import hessian from functorch import hessian
# # TODO: make sure PyTorch has tanh_backward implemented for jvp!! # # TODO: make sure PyTorch has tanh_backward implemented for jvp!!
# hess0 = hessian(predict, argnums=2)(weight, bias, x) # hess0 = hessian(predict, argnums=2)(weight, bias, x)
# hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x) # hess1 = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
@ -148,9 +160,11 @@ hess2 = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
# The easiest way to do this is to sum over the batch dimension and then # The easiest way to do this is to sum over the batch dimension and then
# compute the Jacobian of that function: # compute the Jacobian of that function:
def predict_with_output_summed(weight, bias, x): def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0) return predict(weight, bias, x).sum(0)
batch_size = 64 batch_size = 64
Din = 31 Din = 31
Dout = 33 Dout = 33

View File

@ -12,8 +12,10 @@ and optimization research.
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
torch.manual_seed(0) torch.manual_seed(0)
# Here's a simple CNN # Here's a simple CNN
class SimpleCNN(nn.Module): class SimpleCNN(nn.Module):
def __init__(self): def __init__(self):
@ -37,12 +39,14 @@ class SimpleCNN(nn.Module):
output = x output = x
return output return output
def loss_fn(predictions, targets): def loss_fn(predictions, targets):
return F.nll_loss(predictions, targets) return F.nll_loss(predictions, targets)
# Let's generate a batch of dummy data. Pretend that we're working with an # Let's generate a batch of dummy data. Pretend that we're working with an
# MNIST dataset where the images are 28 by 28 and we have a minibatch of size 64. # MNIST dataset where the images are 28 by 28 and we have a minibatch of size 64.
device = 'cuda' device = "cuda"
num_models = 10 num_models = 10
batch_size = 64 batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device) data = torch.randn(batch_size, 1, 28, 28, device=device)
@ -56,6 +60,7 @@ predictions = model(data)
loss = loss_fn(predictions, targets) loss = loss_fn(predictions, targets)
loss.backward() loss.backward()
# Conceptually, per-sample-gradient computation is equivalent to: for each sample # Conceptually, per-sample-gradient computation is equivalent to: for each sample
# of the data, perform a forward and a backward pass to get a gradient. # of the data, perform a forward and a backward pass to get a gradient.
def compute_grad(sample, target): def compute_grad(sample, target):
@ -65,12 +70,14 @@ def compute_grad(sample, target):
loss = loss_fn(prediction, target) loss = loss_fn(prediction, target)
return torch.autograd.grad(loss, list(model.parameters())) return torch.autograd.grad(loss, list(model.parameters()))
def compute_sample_grads(data, targets): def compute_sample_grads(data, targets):
sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)] sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
sample_grads = zip(*sample_grads) sample_grads = zip(*sample_grads)
sample_grads = [torch.stack(shards) for shards in sample_grads] sample_grads = [torch.stack(shards) for shards in sample_grads]
return sample_grads return sample_grads
per_sample_grads = compute_sample_grads(data, targets) per_sample_grads = compute_sample_grads(data, targets)
# sample_grads[0] is the per-sample-grad for model.conv1.weight # sample_grads[0] is the per-sample-grad for model.conv1.weight
@ -85,9 +92,11 @@ print(per_sample_grads[0].shape)
# We can compute per-sample-gradients efficiently by using function transforms. # We can compute per-sample-gradients efficiently by using function transforms.
# First, let's create a stateless functional version of ``model`` by using # First, let's create a stateless functional version of ``model`` by using
# ``functorch.make_functional_with_buffers``. # ``functorch.make_functional_with_buffers``.
from functorch import make_functional_with_buffers, vmap, grad from functorch import grad, make_functional_with_buffers, vmap
fmodel, params, buffers = make_functional_with_buffers(model) fmodel, params, buffers = make_functional_with_buffers(model)
# Next, let's define a function to compute the loss of the model given a single # Next, let's define a function to compute the loss of the model given a single
# input rather than a batch of inputs. It is important that this function accepts the # input rather than a batch of inputs. It is important that this function accepts the
# parameters, the input, and the target, because we will be transforming over them. # parameters, the input, and the target, because we will be transforming over them.
@ -100,6 +109,7 @@ def compute_loss(params, buffers, sample, target):
loss = loss_fn(predictions, targets) loss = loss_fn(predictions, targets)
return loss return loss
# Now, let's use ``grad`` to create a new function that computes the gradient # Now, let's use ``grad`` to create a new function that computes the gradient
# with respect to the first argument of compute_loss (i.e. the params). # with respect to the first argument of compute_loss (i.e. the params).
ft_compute_grad = grad(compute_loss) ft_compute_grad = grad(compute_loss)

View File

@ -1,8 +1,9 @@
import yaml
import csv import csv
import torch
from collections import defaultdict from collections import defaultdict
import torch
import yaml
def get_ops_for_key(key): def get_ops_for_key(key):
# Needs modified PyTorch C++ code to work # Needs modified PyTorch C++ code to work
@ -12,7 +13,7 @@ def get_ops_for_key(key):
ops = torch._C._dispatch_get_registrations_for_dispatch_key(key) ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
cleaned_ops = [] cleaned_ops = []
for i in ops: for i in ops:
if 'aten::' not in i: if "aten::" not in i:
continue continue
cleaned_ops.append(i[6:].strip()) cleaned_ops.append(i[6:].strip())
return set(cleaned_ops) return set(cleaned_ops)
@ -20,12 +21,17 @@ def get_ops_for_key(key):
def gen_data(special_op_lists, analysis_name): def gen_data(special_op_lists, analysis_name):
all_ops = get_ops_for_key(None) all_ops = get_ops_for_key(None)
composite_ops = get_ops_for_key('CompositeImplicitAutograd') composite_ops = get_ops_for_key("CompositeImplicitAutograd")
noncomposite_ops = all_ops - composite_ops noncomposite_ops = all_ops - composite_ops
ops = yaml.load(open('../../aten/src/ATen/native/native_functions.yaml').read(), Loader=yaml.CLoader) ops = yaml.load(
open("../../aten/src/ATen/native/native_functions.yaml").read(),
Loader=yaml.CLoader,
)
annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))} annotated_ops = {
a.strip(): b.strip() for a, b in list(csv.reader(open("annotated_ops")))
}
from collections import defaultdict from collections import defaultdict
uniq_ops = [] uniq_ops = []
@ -33,18 +39,18 @@ def gen_data(special_op_lists, analysis_name):
overload_types = defaultdict(list) overload_types = defaultdict(list)
cnt = 0 cnt = 0
for op in ops: for op in ops:
func_str = op['func'] func_str = op["func"]
name = func_str[:func_str.index('(')] name = func_str[: func_str.index("(")]
if '.' in name: if "." in name:
uniq_name = name[:name.index('.')] uniq_name = name[: name.index(".")]
overload_types[name[name.index('.') + 1:]].append(name) overload_types[name[name.index(".") + 1 :]].append(name)
else: else:
uniq_name = name uniq_name = name
op['name'] = uniq_name op["name"] = uniq_name
full_name = func_str[:func_str.index('(')] full_name = func_str[: func_str.index("(")]
op['full_name'] = full_name op["full_name"] = full_name
ret_type = func_str[func_str.index('->') + 3:] ret_type = func_str[func_str.index("->") + 3 :]
op['ret_type'] = ret_type op["ret_type"] = ret_type
cnt += 1 cnt += 1
if uniq_name in uniq_names: if uniq_name in uniq_names:
continue continue
@ -54,70 +60,78 @@ def gen_data(special_op_lists, analysis_name):
def annotate_ops(ops, is_unique): def annotate_ops(ops, is_unique):
categorization = defaultdict(int) categorization = defaultdict(int)
for op in ops: for op in ops:
if op['name'][-1] == '_': if op["name"][-1] == "_":
categorization['inplace'] += 1 categorization["inplace"] += 1
op['meta'] = 'inplace' op["meta"] = "inplace"
continue continue
if not is_unique and 'a!' in op['func'].lower(): if not is_unique and "a!" in op["func"].lower():
categorization['out'] += 1 categorization["out"] += 1
op['meta'] = 'out' op["meta"] = "out"
continue continue
if 'conv' in op['name']: if "conv" in op["name"]:
categorization['conv'] += 1 categorization["conv"] += 1
op['meta'] = 'conv' op["meta"] = "conv"
continue continue
if 'pool' in op['name']: if "pool" in op["name"]:
categorization['pool'] += 1 categorization["pool"] += 1
op['meta'] = 'pool' op["meta"] = "pool"
continue continue
if 'backward' in op['name']: if "backward" in op["name"]:
categorization['backward'] += 1 categorization["backward"] += 1
op['meta'] = 'backward' op["meta"] = "backward"
continue continue
if op['name'][0] == '_' and op['name'][1] != '_': if op["name"][0] == "_" and op["name"][1] != "_":
categorization['private'] += 1 categorization["private"] += 1
op['meta'] = 'private' op["meta"] = "private"
continue continue
if 'batch_norm' in op['name']: if "batch_norm" in op["name"]:
categorization['batch_norm'] += 1 categorization["batch_norm"] += 1
op['meta'] = 'batch_norm' op["meta"] = "batch_norm"
continue continue
if 'Tensor' not in op['func'] or 'Tensor' not in op['ret_type']: if "Tensor" not in op["func"] or "Tensor" not in op["ret_type"]:
categorization['non_tensor'] += 1 categorization["non_tensor"] += 1
op['meta'] = 'non_tensor' op["meta"] = "non_tensor"
continue continue
if 'cudnn' in op['name'] or 'mkldnn' in op['name'] or 'miopen' in op['name'] or \ if (
'native' in op['name'] or 'thnn' in op['name'] or 'slow' in op['name']: "cudnn" in op["name"]
categorization['backend'] += 1 or "mkldnn" in op["name"]
op['meta'] = 'backend' or "miopen" in op["name"]
or "native" in op["name"]
or "thnn" in op["name"]
or "slow" in op["name"]
):
categorization["backend"] += 1
op["meta"] = "backend"
continue continue
if op['name'] in annotated_ops: if op["name"] in annotated_ops:
categorization['core'] += 1 categorization["core"] += 1
op['meta'] = 'core ' + annotated_ops[op['name']] op["meta"] = "core " + annotated_ops[op["name"]]
continue continue
categorization['core'] += 1 categorization["core"] += 1
op['meta'] = 'core unknown' op["meta"] = "core unknown"
return categorization return categorization
annotate_ops(ops, is_unique=False) annotate_ops(ops, is_unique=False)
with open(f"{analysis_name}", 'w') as f: with open(f"{analysis_name}", "w") as f:
for op in ops: for op in ops:
info = [ info = [
op['full_name'], op['meta'], op['full_name'] not in noncomposite_ops op["full_name"],
op["meta"],
op["full_name"] not in noncomposite_ops,
] + [check(op) for check in special_op_lists] ] + [check(op) for check in special_op_lists]
f.write(','.join([str(i) for i in info]) + '\n') f.write(",".join([str(i) for i in info]) + "\n")
def name_check(lst): def name_check(lst):
return lambda x: x['name'] in lst return lambda x: x["name"] in lst
def full_name_check(lst): def full_name_check(lst):
return lambda x: x['full_name'] in lst return lambda x: x["full_name"] in lst
# Generates batching rule data # Generates batching rule data
gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap.txt') gen_data([full_name_check(get_ops_for_key("FuncTorchBatched"))], "vmap.txt")
def remove_suffix(input_string, suffix): def remove_suffix(input_string, suffix):
@ -125,6 +139,7 @@ def remove_suffix(input_string, suffix):
return input_string[: -len(suffix)] return input_string[: -len(suffix)]
return input_string return input_string
def remove_prefix(input_string, prefix): def remove_prefix(input_string, prefix):
if prefix and input_string.startswith(prefix): if prefix and input_string.startswith(prefix):
return input_string[len(prefix) :] return input_string[len(prefix) :]
@ -132,26 +147,36 @@ def remove_prefix(input_string, prefix):
if True: if True:
with open('run_ops.txt') as f: with open("run_ops.txt") as f:
opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()] opinfo_ops = [remove_suffix(i.strip(), ".default") for i in f.readlines()]
with open('count_ops.txt') as f: with open("count_ops.txt") as f:
opinfo_counts = [i.strip() for i in f.readlines()] opinfo_counts = [i.strip() for i in f.readlines()]
opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts))) opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))
def count_fn(x): def count_fn(x):
return opinfo_counts[x['full_name']] return opinfo_counts[x["full_name"]]
with open('run_decompositions.txt') as f: with open("run_decompositions.txt") as f:
decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()] decomposed_ops = [remove_suffix(i.strip(), ".default") for i in f.readlines()]
with open('public_api') as f: with open("public_api") as f:
ref_api = [i.strip() for i in f.readlines()] ref_api = [i.strip() for i in f.readlines()]
def has_ref_impl(x): def has_ref_impl(x):
name = x['name'] name = x["name"]
for prefix in ["linalg_", "special_"]: for prefix in ["linalg_", "special_"]:
name = remove_prefix(name, prefix) name = remove_prefix(name, prefix)
prefixes = ['nn.functional', 'fft', 'special', 'linalg'] prefixes = ["nn.functional", "fft", "special", "linalg"]
return any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api return (
any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
)
gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn, has_ref_impl], 'decompositions.txt') gen_data(
[
full_name_check(opinfo_ops),
full_name_check(decomposed_ops),
count_fn,
has_ref_impl,
],
"decompositions.txt",
)