[3/N] fix typo in other folders (#166606)
fix typo in other folders

#166374 #166126

_typos.toml

```bash
[files]
extend-exclude = ["tools/linter/dictionary.txt"]

[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
Sur = "Sur"
nin = "nin"
tme = "tme"
inpt = "inpt"
mis = "mis"
Raison = "Raison"
ouput = "ouput"
nto = "nto"
Onwer = "Onwer"
callibrate = "callibrate"
ser = "ser"
Metdata = "Metdata"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166606
Approved by: https://github.com/ezyang
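The `_typos.toml` above excludes the linter dictionary and whitelists identifiers that only look like misspellings (e.g. `arange`, `inpt`). A minimal sketch of how such a config might be used locally, assuming the crate-ci/typos CLI is installed and discovers `_typos.toml` at the repository root:

```bash
# Sketch only: assumes the typos CLI (https://github.com/crate-ci/typos) is installed.
# With _typos.toml at the repository root, the config is picked up automatically.
typos                  # report spelling issues without modifying files
typos --write-changes  # apply the suggested fixes in place, then review the diff
```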
parent 32920926f0
commit 369f2d6951
@@ -374,7 +374,7 @@ cmake_dependent_option(
 "Build the lazy Torchscript backend, not compatible with mobile builds" ON
 "NOT INTERN_BUILD_MOBILE" OFF)
 cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
-cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
+cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin folder"
 OFF "USE_CUDA" OFF)
 cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
 "CPU_AARCH64" OFF)

@@ -1,4 +1,4 @@
-// Implementation of specal math functions for Metal
+// Implementation of special math functions for Metal
 #pragma once
 #include <c10/metal/expm1f.h>
 #include <c10/metal/igamma.h>

@@ -34,7 +34,7 @@ struct MemEvent {
 bool overlaps(const MemBlock& a, const MemBlock& b) {
 // two blocks dont overlap if
 // |---a--------|--------------b--------|
-// strat_a end_a <= start_b end_b
+// start_a end_a <= start_b end_b
 return !(
 (a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset));
 }

@@ -33,7 +33,7 @@ struct bitset final {
 constexpr bitset() noexcept = default;
 constexpr bitset(const bitset&) noexcept = default;
 constexpr bitset(bitset&&) noexcept = default;
-// there is an issure for gcc 5.3.0 when define default function as constexpr
+// there is an issue for gcc 5.3.0 when define default function as constexpr
 // see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
 bitset& operator=(const bitset&) noexcept = default;
 bitset& operator=(bitset&&) noexcept = default;

@@ -38,7 +38,7 @@ uint32_t crc32_combine (uint32_t crcA, uint32_t crcB, size_t lengthB);

 /// compute CRC32 (bitwise algorithm)
 uint32_t crc32_bitwise (const void* data, size_t length, uint32_t previousCrc32 = 0);
-/// compute CRC32 (half-byte algoritm)
+/// compute CRC32 (half-byte algorithm)
 uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32 = 0);

 #ifdef CRC32_USE_LOOKUP_TABLE_BYTE
@@ -96,7 +96,7 @@ uint32_t crc32_16bytes_prefetch(const void* data, size_t length, uint32_t previo
 #define __BIG_ENDIAN 4321
 #endif

-// define endianess and some integer data types
+// define endianness and some integer data types
 #if defined(_MSC_VER) || defined(__MINGW32__)
 // Windows always little endian
 #define __BYTE_ORDER __LITTLE_ENDIAN

@@ -168,7 +168,7 @@ namespace
 /// zlib's CRC32 polynomial
 const uint32_t Polynomial = 0xEDB88320;

-/// swap endianess
+/// swap endianness
 static inline uint32_t swap(uint32_t x)
 {
 #if defined(__GNUC__) || defined(__clang__)

@@ -229,7 +229,7 @@ uint32_t crc32_bitwise(const void* data, size_t length, uint32_t previousCrc32)
 }


-/// compute CRC32 (half-byte algoritm)
+/// compute CRC32 (half-byte algorithm)
 uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32)
 {
 uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF

@@ -662,7 +662,7 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB)
 // - if you append length(B) zeros to A and call it A' (think of it as AAAA000)
 // and prepend length(A) zeros to B and call it B' (think of it as 0000BBB)
 // then exists a C' = A' ^ B'
-// - remember: if you XOR someting with zero, it remains unchanged: X ^ 0 = X
+// - remember: if you XOR something with zero, it remains unchanged: X ^ 0 = X
 // - that means C' = A concat B so that crc(A concat B) = crc(C') = crc(A') ^ crc(B')
 // - the trick is to compute crc(A') based on crc(A)
 // and crc(B') based on crc(B)

@@ -76,7 +76,7 @@ typedef struct mz_zip_archive mz_zip_archive;
 // 2) Writing with 1-pass sequential access
 // -> We must take care not to require updating values that have already
 // been written. We place the variable-length index at the end and do
-// not put any indicies into the header to fulfill this constraint.
+// not put any index into the header to fulfill this constraint.

 // The model.json, which contains all the metadata information,
 // should be written as the last file. One reason is that the size of tensor
@@ -519,7 +519,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoadWithAllocator) {
 std::tie(data_ptr, size) = reader.getRecord("key1", &overrideAllocator);
 EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
 EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes);
-// allcoate with base allocator
+// allocate with base allocator
 std::tie(data_ptr, size) = reader.getRecord("key1");
 EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
 EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);

setup.py
@@ -1106,7 +1106,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
 continue
 self.copy_file(source_lib, target_lib)
 # Delete old rpath and add @loader_lib to the rpath
-# This should prevent delocate from attempting to package another instance
+# This should prevent deallocate from attempting to package another instance
 # of OpenMP library in torch wheel as well as loading two libomp.dylib into
 # the address space, as libraries are cached by their unresolved names
 install_name_tool_args = [

@@ -1060,7 +1060,7 @@ class OutputGraph(OutputGraphCommon):
 def module_key_name(*names: Any) -> str:
 # create a new unique name
 name = "_".join(map(str, names))
-# Strip _buffers[..]/_parmeters[..]/_modules[..] names
+# Strip _buffers[..]/_parameters[..]/_modules[..] names
 name = re.sub(
 r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]", r".\2", name
 )

@@ -2217,7 +2217,7 @@ class OutputGraph(OutputGraphCommon):
 backend_fake_mode = torch._subclasses.FakeTensorMode(
 shape_env=old_fake_mode.shape_env,
 )
-# TODO(voz): Ostensibily, this should be scoped and
+# TODO(voz): Ostensibly, this should be scoped and
 # restore back to old_fake_mode, but doing so currently violates
 # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
 self.tracing_context.fake_mode = backend_fake_mode

@@ -3414,7 +3414,7 @@ class SubgraphTracer(fx.Tracer):
 if proxy in self.lifted_freevars:
 return self.lifted_freevars[proxy]

-# We first lift proxy to parent's graph then lift to current grpah's input
+# We first lift proxy to parent's graph then lift to current graph's input
 # so that when we bind symints of the sizes in current graph, those symints
 # would already be lifted as inputs to parent graph.
 if proxy.tracer != self.parent:

@@ -3462,7 +3462,7 @@ class SubgraphTracer(fx.Tracer):
 def track_produced_symints(
 self, example_value: Any, e_proxy: Union[LazyProxy, torch.fx.Proxy]
 ) -> None:
-# When binding the symbols in an exmaple_value, we bind the symbols
+# When binding the symbols in an example_value, we bind the symbols
 # to the proxy's associated Tracer instead of current tracer.
 # This is because:
 # 1. We may be calling wrap_tensors during speculate_subgraph because
@@ -2089,7 +2089,7 @@ class InstructionTranslatorBase(
 def _raise_exception_variable(self, val: VariableTracker) -> NoReturn:
 # User can raise exception in 2 ways
 # 1) raise exception type - raise NotImplementedError
-# 2) raise exception instance - raise NotImplemetedError("foo")
+# 2) raise exception instance - raise NotImplementedError("foo")

 # 1) when user raises exception type
 val = self._create_exception_type(val)

@@ -2140,7 +2140,7 @@ class InstructionTranslatorBase(
 try:
 self._raise_exception_variable(val)
 finally:
-# Update __cause__/__supppress_context__ in the raised exception
+# Update __cause__/__suppress_context__ in the raised exception
 curr_exc = self.exn_vt_stack.get_current_exception()
 cause = self._create_exception_type(from_vt)
 curr_exc.call_setattr(self, ConstantVariable("__cause__"), cause) # type: ignore[arg-type, union-attr, assignment]

@@ -2417,8 +2417,8 @@ class InstructionTranslatorBase(

 # Users can check exception in 3 ways
 # 1) except NotImplementedError --> BuiltinVariable
-# 2) except CustomException --> UserDefinedExceptionClasVariable
-# 3) except (NotImplemetedError, AttributeError) -> TupleVariable
+# 2) except CustomException --> UserDefinedExceptionClassVariable
+# 3) except (NotImplementedError, AttributeError) -> TupleVariable

 if not isinstance(
 expected_exc_types,

@@ -54,7 +54,7 @@ if TYPE_CHECKING:
 from torch._dynamo.symbolic_convert import InstructionTranslator


-# [Adding a new supported class within the keys of ConstDictVarialble]
+# [Adding a new supported class within the keys of ConstDictVariable]
 # - Add its tracker type to is_hashable
 # - (perhaps) Define how it is compared in _HashableTracker._eq_impl


@@ -765,7 +765,7 @@ class TS2FXGraphConverter:
 raise ValueError(f"Unsupported JitType ({input_type}) when get device")

 def convert_prim_GetAttr(self, node: torch._C.Node):
-# Build fully qulified name
+# Build fully qualified name
 attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
 output_name = node.output().debugName()
 self.name_to_attribute_fqn[output_name] = attr_fqn

@@ -1455,7 +1455,7 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
 )
 gm = graph_converter.convert()

-# Post-proccessing step to deal with quantized operators.
+# Post-processing step to deal with quantized operators.
 replace_quantized_ops_with_standard_ops(gm)
 log.info("GraphModule: %s", gm.print_readable(print_output=False))

@@ -1477,7 +1477,7 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
 flattened, _ = flatten_fn(obj)

 # NOTE: This helper function will replicate an nn.Module in the exactly same
-# structure to be used together with _reparametrize_module. This will
+# structure to be used together with _reparameterize_module. This will
 # create a clone of the module with the new parameters and buffers without
 # affecting the original module.
 def copy_module(mod: torch.nn.Module):

@@ -771,7 +771,7 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]):
 maybe_subclass_meta: Optional[SubclassMeta]
 num_fw_outs_saved_for_bw: Optional[int]

-# Used by RuntimeWrapepr
+# Used by RuntimeWrapper
 indices_of_inps_to_detach: list[int]

 # Time taken to trace/compile the forward

@@ -99,7 +99,7 @@ def fx_graph_cse(fx_g: torch.fx.graph.Graph):
 # so it's not worth CSEing.
 or get_aten_target(n) is aten.empty
 or n in nodes_that_alias_outputs
-# This CSE pass currently doesn't handle re-propogation of unbacked
+# This CSE pass currently doesn't handle re-propagation of unbacked
 # meta where it'll sometimes eliminate a _local_scalar_dense but not
 # replace the meta of downstream users. eg. one bug we've seen is:
 #

@@ -20,7 +20,7 @@ from torch.utils._config_module import Config, install_config_module

 # [@compile_ignored: debug]
 _save_config_ignore = [
-# callable not serializeable
+# callable not serializable
 "joint_custom_pass",
 ]


@@ -44,7 +44,7 @@ class SchemaHolder:
 return cls(pytree.tree_unflatten([], tree_spec).schema)


-# regsiter_constant allows us to get a tree_spec from pytree.tree_flatten(SchemaHolder(FunctionSchema)).
+# register_constant allows us to get a tree_spec from pytree.tree_flatten(SchemaHolder(FunctionSchema)).
 # The tree_spec is proxable in the graph and we can get back the schema via
 # schema = pytree.tree_unflatten([], tree_spec).schema
 pytree.register_constant(SchemaHolder)

@@ -312,7 +312,7 @@ def generic_scan(operator, init, xs, dim=0, additional_inputs=()):
 out_tensor_mask = get_tensor_mask(dummy_out)
 dummy_out_masked = mask_list(out_tensor_mask, dummy_out)

-# Pre-alocate
+# Pre-allocate
 # outs -> Output matrix
 # idxs -> Index matrix for scatter_
 # out: (num_elems, M, N, ...)
@@ -708,7 +708,7 @@ def _stack_pytree(pytrees):
 # is partitioned into in order to recover it in saved_tensors_and_symints.
 #
 # In saved_tensors_and_symints, we can recover the original args by:
-# iterating over the pos list and pop one item from the front of paritioned_args[pos[i]].
+# iterating over the pos list and pop one item from the front of partitioned_args[pos[i]].
 # We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists.
 def save_tensors_and_symints_for_backward(ctx, args):
 assert all(

@@ -660,7 +660,7 @@ class WhileLoopStackOutputOp(HigherOrderOperator):
 #
 # gx = gy0 * bw(y0, x),
 #
-# where gy0 denotes the graident of loss with respect to y0, and bw(y0, x) denotes the graident of y0 with
+# where gy0 denotes the gradient of loss with respect to y0, and bw(y0, x) denotes the gradient of y0 with
 # respect to x. Note that bw can be computed from forward body_fn easily using torch.autograd.grad.
 # We could substitute the unknowns gy0, gy1, ..., with chain rule until gy4:
 #

@@ -769,7 +769,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
 # Note [Handle inputs that're not differentiable]
 # When a forward input is non-differentiable e.g. a symint or an integer tensor, their gradients
 # will be None. However, we don't want to return None in the subgraph because this complicates the
-# inductor codegen, where we need to do a non-unform treatment for None and tensors.
+# inductor codegen, where we need to do a non-uniform treatment for None and tensors.
 # So we set up masks and filter the None gradients so that only tensors are returned from each step.
 carries_tensor_masks = [
 bool(isinstance(t, torch.Tensor) and t.dtype.is_floating_point)

@@ -348,7 +348,7 @@ def _scatter_fused_allreduce_waits(
 # Some descendant users of the orig_comm_blocks may be scheduled before
 # the fused all_reduce. For example, the user nodes of the very first
 # all_reduce may be scheduled before the second all_reduce. Since the
-# fused all_reduce is inserted right after the last all_reudce, the
+# fused all_reduce is inserted right after the last all_reduce, the
 # order can be wrong.
 # `incorrect_order_nodes` records these nodes.


@@ -991,7 +991,7 @@ if torch._C._has_mkldnn:

 def _recover_linear():
 # convert reshape+linear+reshape to a single linear for applying fusion path.
-# concat_linear (pass_number=0) -> mkldnn_linear_pack (pass_numer=1) -> _recover_linear(pass_number=2)
+# concat_linear (pass_number=0) -> mkldnn_linear_pack (pass_number=1) -> _recover_linear(pass_number=2)
 @register_freezing_graph_pattern(
 CallFunction(
 aten.reshape.default,

@@ -585,7 +585,7 @@ def decompose_scan_to_while_loop(gm: torch.fx.GraphModule):
 # NOTE [Pre-allocate scan's output buffer]
 # In order to pre-allocate the output buffer for ys, we rely on the meta of scan's fx_node.
 # However, the meta consists of concrete symints, we need to bind those symints with
-# proxies in order to trace the torch.empyt_strided call correctly.
+# proxies in order to trace the torch.empty_strided call correctly.
 #
 # Also note that basic free symbols of tensor's shapes are guaranteed to be lifted as subgraph inputs
 # in dynamo so we can always re-construct the sym expression from placeholders.
@@ -677,7 +677,7 @@ class CompiledFxGraph(OutputCode):
 ]
 else:
 # On the forward we don't know whether or not
-# boxed_foward_device_index is set yet
+# boxed_forward_device_index is set yet
 boxed_forward_device_index = graph_kwargs.get(
 "boxed_forward_device_index", None
 )

@@ -530,7 +530,7 @@ class CachingAutotuner(KernelInterface):
 # = regs_per_multiprocessor / (nreg * 32 * num_warps)
 # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
 # = max_threads_per_multi_processor / (32 * num_warps)
-# Using a tigher upper bound can reveal more optimization opportunities.
+# Using a tighter upper bound can reveal more optimization opportunities.
 max_blocks_per_sm = max(
 device_prop.regs_per_multiprocessor // nreg_per_block, 1
 )

@@ -215,7 +215,7 @@ def triton_op(
 # the exported program to be high-level and serializable. If we decompose
 # the custom op to a functional hop and make it a node in exported program,
 # we need to figure out ways of serializing the hop and its arguments, which can be triton.jited
-# functions and triton dtypes. This is undesireble because:
+# functions and triton dtypes. This is undesirable because:
 # - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes.
 # - exported program will contain the implementation detail (e.g. triton source code) for a specific
 # backend (GPU), which is probably at a wrong level of abstraction.

@@ -530,7 +530,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
 dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
 return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)

-# NOTE [HigherOrderOprator Schema]
+# NOTE [HigherOrderOperator Schema]
 # Each invocation of a HigherOrderOperator (hop) should have its own schema because
 # the subgraphs and the arguments can be different even for the same hop.
 #

@@ -3155,7 +3155,7 @@ def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorL

 # Tries to take a view
 # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
-# Unbacked semnatics: if validty of in-place flattening is undecided we copy.
+# Unbacked semantics: if validity of in-place flattening is undecided we copy.
 new_shape, _new_strides = prims._collapse_view_helper(
 a, start_dim, end_dim, must_be_valid=None
 )

@@ -523,7 +523,7 @@ def fold_weight(
 del original_weights_lookup[str(lookup_counter)]
 lookup_counter += 1
 elif prepack_node is not None:
-# remove the foled node
+# remove the fold node
 continue
 else:
 # copy other nodes
@@ -1213,7 +1213,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
 boundaries_new_histogram = torch.linspace(
 update_min, update_max, self.bins + 1, device=update_min.device
 ).to(histogram.device)
-# this maps the mid-poits of the histogram to the new histogram's space
+# this maps the mid-points of the histogram to the new histogram's space
 bucket_assignments = (
 torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True)
 - 1

@@ -12,7 +12,7 @@ def lower_pt2e_quantized_to_x86(
 model: torch.fx.GraphModule,
 example_inputs: tuple[torch.Tensor, ...],
 ) -> torch.fx.GraphModule:
-"""Lower a PT2E-qantized model to x86 backend.
+"""Lower a PT2E-quantized model to x86 backend.

 Args:
 * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow.

@@ -4568,7 +4568,7 @@ std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
 if (!grad.defined() || (!A_requires_grad && !B_requires_grad)) {
 return std::make_tuple(Tensor{}, Tensor{});
 }
-// We always need to comput G_B
+// We always need to compute G_B
 const Tensor A_H = A.mH();
 const Tensor G_B =
 at::linalg_solve_triangular(A_H, grad, !upper, left, unitriangular);

@@ -1035,7 +1035,7 @@ PyObject* THCPModule_cudaGetSyncDebugMode(PyObject* self, PyObject* noargs) {
 ////////////////////////////////////////////////////////////////////////////////

 static void registerCudaDeviceProperties(PyObject* module) {
-// Add _cudaDevicePropertires class to torch._C
+// Add _cudaDeviceProperties class to torch._C
 auto m = py::handle(module).cast<py::module>();
 // CUuuid is defined in either cuda.h or driver_types.h
 // hipified to hipUUID which is defined in hip_runtime_api.h

@@ -103,7 +103,7 @@ class StoreExchange {
 size_t seq_id_ = 0;
 };

-// Teturns a pointer of virtual address that is mapped to the physical memory
+// Returns a pointer of virtual address that is mapped to the physical memory
 // held by the handle.
 void map_block(
 void** ptr,

@@ -72,7 +72,7 @@ c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(

 auto retFuture = rrefsReadyFuture->thenAsync(
 [this,
-// std::function must be copyable, hence hae to cast the unique_ptr to
+// std::function must be copyable, hence has to cast the unique_ptr to
 // a shared_ptr here.
 rpc = std::shared_ptr<RpcCommandBase>(std::move(rpc)),
 messageType = request.type(),
@@ -240,7 +240,7 @@ class TORCH_API RpcAgent {
 // should be profiled or not.
 void enableGILProfiling(bool flag);

-// Retrieve wheher we should profile GIL wait times or not.
+// Retrieve whether we should profile GIL wait times or not.
 bool isGILProfilingEnabled();

 // Set type resolver that will be passed to JIT pickler to resolver type Ptr

@@ -3534,7 +3534,7 @@ class RootGuardManager : public GuardManager {

 void add_no_tensor_aliasing_guard(
 std::shared_ptr<RelationalGuard> no_tensor_aliasing_guard) {
-// stash a pointer to the _no_tensor_alising_guard
+// stash a pointer to the _no_tensor_aliasing_guard
 _no_tensor_aliasing_guard = no_tensor_aliasing_guard;
 this->add_relational_guard_resetter(std::move(no_tensor_aliasing_guard));
 }

@@ -143,7 +143,7 @@ static std::unique_ptr<sycl::kernel> _createKernel(
 sycl::range<3> localRange(localRangeZ, localRangeY, localRangeX);
 sycl::nd_range<3> parallelWorkSize(globalRange, localRange);
 if (sharedMemory) {
-// numParams from sycl info = user provided args + sharedMemroyBuffer
+// numParams from sycl info = user provided args + sharedMemoryBuffer
 numParams -= 1;
 }
 // Submit the imported kernel.

@@ -14,7 +14,7 @@
 // Because AOTInductor generated code will copy-paste this cpp_prefix.h for
 // the CPU backend, we have to make sure the used headers are implemented
 // in a header-only way, i.e. all the function and class definitions are
-// in .h files instead of .cpp files, to avoid ABI backward-compatiblity
+// in .h files instead of .cpp files, to avoid ABI backward-compatibility
 // breakage.

 #include <ATen/NumericUtils.h>

@@ -441,7 +441,7 @@ The following sections look into each the stages in the script frontend in detai

 [frontend/tree.h](frontend/tree.h)

-Our frontends produce ASTs in the form of Tree objects. Trees are similar to [s-expressions](https://en.wikipedia.org/wiki/S-expression). Leafs (i.e. Atoms) are always strings. Compound trees have a `kind` (e.g `TK_CONST` or `TK_IDENT` defined in [lexer.h](frontend/lexer.h)) and a list of sub-trees. For instance, the Tree for `z.sigmoid() - (x + y)` is:
+Our frontends produce ASTs in the form of Tree objects. Trees are similar to [s-expressions](https://en.wikipedia.org/wiki/S-expression). Leaves (i.e. Atoms) are always strings. Compound trees have a `kind` (e.g `TK_CONST` or `TK_IDENT` defined in [lexer.h](frontend/lexer.h)) and a list of sub-trees. For instance, the Tree for `z.sigmoid() - (x + y)` is:

 ```
 (-

@@ -121,7 +121,7 @@ class NnapiBackend : public PyTorchBackendInterface {
 shape_compute_module.run_method("prepare", ser_model, inputs)
 .toTensorList();

-// Create and initialize NnapiComilation object
+// Create and initialize NnapiCompilation object
 comp_ = std::make_unique<torch::nnapi::bind::NnapiCompilation>();
 auto weights = dict.at("weights").toTensorVector();
 comp_->init(ser_model, weights);
@@ -379,7 +379,7 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
 function->append_type(getOrCreateTypeAnnotations(i));
 }

-// 3. If upgrader is needed, change change the OP instrunction to CALL
+// 3. If upgrader is needed, change change the OP instruction to CALL
 // instruction (In next PR, use_upgrader will be parsed to parseInstruction
 // function and do the actual change)
 if (use_upgrader) {

@@ -391,7 +391,7 @@ void BytecodeDeserializer::parseMethods(
 debug_handles_m_tuple,
 function.get());

-// 3. If upgrader is needed, change change the OP instrunction to CALL
+// 3. If upgrader is needed, change change the OP instruction to CALL
 // instruction (In next PR, use_upgrader will be parsed to parseInstruction
 // function and do the actual change)
 if (use_upgrader) {

@@ -196,7 +196,7 @@ struct BailOutGraphBuilderForNode {
 std::shared_ptr<Graph> buildBailOutGraphFrom(Node* n) {
 // add graph inputs for guard's input
 // and loop counts for loops `n` is contained in
-// to make sure we can line bailout grap's inputs up properly
+// to make sure we can line bailout graph's inputs up properly
 // with arguments to this BailOut node.
 for (auto bi : n->inputs()) {
 getOrAddInputForValue(bi);

@@ -1230,7 +1230,7 @@ void removeDequantizeFromInputs(const std::unordered_set<Value*>& inputs) {
 TORCH_INTERNAL_ASSERT(
 dequantized_val->uses().size() == 1,
 "Expect to have one dequantize node for each use");
-// Replace useses of dequantized_val with the input of
+// Replace uses of dequantized_val with the input of
 // dequantize node
 dequantized_val->replaceAllUsesWith(dequantize_node->inputs()[0]);
 dequantize_node->removeAllInputs();

@@ -162,7 +162,7 @@ InlinedCallStackPtr InlinedCallStackDeserializer::deserialize(
 }
 cached_inlined_callstacks_[tup] = cs_ptr;
 // Invoking move constructor
-// It is not clear if copy-ellision can happen since
+// It is not clear if copy-elision can happen since
 // cs_ptr is copied into map above.
 // This is to help avoid ref count update
 return cs_ptr;

@@ -106,7 +106,7 @@ std::array<
 GetBackendMetaSerialization() {
 // The array to save function pointer for BackendMeta serialization.
 // key is the DeviceType, value is std::pair obj.
-// value.first represent get function and value.seconde represent set function
+// value.first represent get function and value.second represent set function
 static std::array<
 std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
 at::COMPILE_TIME_MAX_DEVICE_TYPES>
@@ -830,7 +830,7 @@ std::shared_ptr<LazyGraphExecutor::Async> LazyGraphExecutor::
 const SyncTensorsConfig& config) {
 SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
 if (coll.indices.empty()) {
-/* Enure previous execution is complete before exiting this
+/* Ensure previous execution is complete before exiting this
 * function */
 TensorCollectionBarrier(&coll);
 return nullptr;

@@ -915,7 +915,7 @@ void passEventsToKineto(
 // on-demand Kineto activity handling. Enabling this path
 // for Profiler API could cause side effects as much has changed since.
 // Make a surgical fix here until we holistically assess the on-demand
-// vs API path framentation, which has been snowballing in complexity
+// vs API path fragmentation, which has been snowballing in complexity
 // and thus flakiness.
 if (config.global()) {
 e->kineto_activity_ = activity;

@@ -261,7 +261,7 @@ static PyObject* THXPModule_resetAccumulatedMemoryStats(
 // XPU module initialization

 static void registerXpuDeviceProperties(PyObject* module) {
-// Add _xpuDevicePropertires class to torch._C
+// Add _xpuDeviceProperties class to torch._C
 using namespace c10::xpu;
 auto get_device_type = [](const DeviceProp& prop) {
 std::ostringstream stream;

@@ -17,7 +17,7 @@ __all__: list[str] = []


 class _Checkpointer:
-"""This base class specefies a high level API for saving and loading
+"""This base class specifies a high level API for saving and loading
 distributed `state_dict` 's. It provides an abstraction over the low-level APIs
 provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling
 :py:meth: `torch.distributed.state_dict_saver.save` and

@@ -80,7 +80,7 @@ class BroadcastingTorchSaveReader(StorageReader):
 planner = cast(DefaultLoadPlanner, planner)

 # data is read in on the coordinator rank, and broadcast afterwards
-# this incurrs a communication cost, but it avoids having to load
+# this incurs a communication cost, but it avoids having to load
 # the entire checkpoint on each rank, hopefully preventing OOM issues
 # TODO: read on each host, instead of only the coordinator
 if self.is_coordinator:

@@ -252,7 +252,7 @@ class _PipelineStageBase(ABC):
 self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment]

 def get_outputs_meta(self) -> tuple[torch.Tensor, ...]:
-"""Get the output metadata (meta tensors) reprensenting the outputs of this stage"""
+"""Get the output metadata (meta tensors) representing the outputs of this stage"""
 assert self._outputs_meta is not None, (
 "Attempted to get_outputs_meta() without configuring output meta"
 )
@@ -723,7 +723,7 @@ class _PipelineStageBase(ABC):
 )
 self._validate_fwd_outputs(output_tuple)

-# We return the original user-provied output, not normalized to tuple.
+# We return the original user-provided output, not normalized to tuple.
 # See [Note: pipeline model output type]
 return output


@@ -1188,7 +1188,7 @@ class _PipelineStage(_PipelineStageBase):
 # No need to send back to rank 0
 # - If user.target is stage_backward:
 # No need to send assuming submod output is stored locally or
-# should be re-calucated in case of activation checkpointing
+# should be re-calculated in case of activation checkpointing
 return None

 def _create_act_send_info(self):

@@ -45,7 +45,7 @@ class EinsumDims:
 for input_dim in input_dims:
 dim_char_set.update(input_dim)

-# get a determinisitc order of all dim chars
+# get a deterministic order of all dim chars
 all_dim_chars = sorted(dim_char_set)

 # parse input and output dimensions

@@ -484,7 +484,7 @@ def replicate_tensor_dim(
 def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType:
 # 1. number of dimensions in input and src need to match.
 # 2. number of elements on all non-dim need to match between input and src.
-# 3. numer of elements in src in dim need to match the slice size.
+# 3. number of elements in src in dim need to match the slice size.
 # Given the above:
 # - We suggest for src to follow the sharding of input, except on the scatter dimension,
 # where our best bet for now is to make them replicated as a fall-back.

@@ -592,7 +592,7 @@ class DTensorRedistributePlanner:
 current = current_placements[mesh_dim]
 target = target_placements[mesh_dim]
 # If target is not Shard, we can directly redistribute since we
-# are traversing from innner to outer placements here
+# are traversing from inner to outer placements here
 if isinstance(target, Shard):
 # If target is Shard, check for nested sharding on the
 # tensor dim BEFORE the current mesh_dim

@@ -922,7 +922,7 @@ def _export_to_aten_ir(
 if decompose_custom_triton_ops
 else _disable_custom_triton_op_functional_decomposition
 )
-# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
+# This _reparameterize_module makes sure inputs and module.params/buffers have the same fake_mode,
 # otherwise aot_export_module will error out because it sees a mix of fake_modes.
 # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
 with ExitStack() as stack:
@ -1843,7 +1843,7 @@ def _export_to_aten_ir_make_fx(
|
||||||
)
|
)
|
||||||
return gm, sig
|
return gm, sig
|
||||||
|
|
||||||
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
|
# This _reparameterize_module makes sure inputs and module.params/buffers have the same fake_mode,
|
||||||
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
|
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
|
||||||
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
|
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
|
||||||
with ExitStack() as stack:
|
with ExitStack() as stack:
|
||||||
|
|
|
||||||
|
|
@ -281,7 +281,7 @@ def _split_decomp_table_to_cia_and_python_decomp(
|
||||||
for op in list(decomp_table.keys()):
|
for op in list(decomp_table.keys()):
|
||||||
# TODO we are silently allowing non-safe(non-functional) ops through a crack
|
# TODO we are silently allowing non-safe(non-functional) ops through a crack
|
||||||
# due to core aten decomp table having non-functional entries. Once we have
|
# due to core aten decomp table having non-functional entries. Once we have
|
||||||
# a tigher check around core aten decomp, we should warn users about them.
|
# a tighter check around core aten decomp, we should warn users about them.
|
||||||
# Tracking issue: (https://github.com/pytorch/pytorch/issues/135759)
|
# Tracking issue: (https://github.com/pytorch/pytorch/issues/135759)
|
||||||
|
|
||||||
# if it is a valid CIA op we can mess with in export, we check if it is:
|
# if it is a valid CIA op we can mess with in export, we check if it is:
|
||||||
|
|
|
||||||
|
|
@ -1829,7 +1829,7 @@ def norm( # noqa: F811
|
||||||
return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined]
|
return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined]
|
||||||
|
|
||||||
# TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
|
# TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
|
||||||
# remove the overloads where dim is an int and replace with BraodcastingList1
|
# remove the overloads where dim is an int and replace with BroadcastingList1
|
||||||
# and remove next four lines, replace _dim with dim
|
# and remove next four lines, replace _dim with dim
|
||||||
if dim is not None:
|
if dim is not None:
|
||||||
if isinstance(dim, (int, torch.SymInt)):
|
if isinstance(dim, (int, torch.SymInt)):
|
||||||
|
|
|
||||||
|
|
@ -4522,7 +4522,7 @@ class ShapeEnv:
|
||||||
|
|
||||||
# The order of checking the guards matters. In this specific example:
|
# The order of checking the guards matters. In this specific example:
|
||||||
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
|
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
|
||||||
# we may have an unnecessary shape speciliazation for y.
|
# we may have an unnecessary shape specialization for y.
|
||||||
def _maybe_specialize_sym_int_with_hint(
|
def _maybe_specialize_sym_int_with_hint(
|
||||||
self, maybe_sym: IntLikeType
|
self, maybe_sym: IntLikeType
|
||||||
) -> IntLikeType:
|
) -> IntLikeType:
|
||||||
|
|
@ -5830,7 +5830,7 @@ class ShapeEnv:
|
||||||
def issue_guard(guard: ShapeGuard) -> None:
|
def issue_guard(guard: ShapeGuard) -> None:
|
||||||
expr = self.simplify(guard.expr)
|
expr = self.simplify(guard.expr)
|
||||||
|
|
||||||
# Avoid re-issueing the same guard.
|
# Avoid re-issuing the same guard.
|
||||||
if expr in issued:
|
if expr in issued:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -118,7 +118,7 @@ def edge(a, b, tie_breaker=hash):
|
||||||
"""A should be checked before B
|
"""A should be checked before B
|
||||||
Tie broken by tie_breaker, defaults to ``hash``
|
Tie broken by tie_breaker, defaults to ``hash``
|
||||||
"""
|
"""
|
||||||
# A either supercedes B and B does not supercede A or if B does then call
|
# A either supersedes B and B does not supersede A or if B does then call
|
||||||
# tie_breaker
|
# tie_breaker
|
||||||
return supercedes(a, b) and (
|
return supercedes(a, b) and (
|
||||||
not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)
|
not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)
|
||||||
|
|
|
||||||
|
|
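As a rough illustration of the ordering rule in the `edge` docstring above, here is a self-contained toy; the `supercedes` shown is an assumed stand-in for the real one, and the example only shows how the `hash` tie-breaker resolves mutual supersession:

```python
def supercedes(a, b):
    # toy notion of "more specific": every type in a is a subclass of the matching type in b
    return len(a) == len(b) and all(issubclass(x, y) for x, y in zip(a, b))

def edge(a, b, tie_breaker=hash):
    """A should be checked before B; ties broken by tie_breaker."""
    return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b))

print(edge((int,), (object,)))  # True: (int,) is more specific, so it is checked first
print(edge((object,), (int,)))  # False
```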
@ -282,7 +282,7 @@ RuntimeConfigs {
|
||||||
|
|
||||||
Constant folding is the process of finding all of the constant-evaluable
|
Constant folding is the process of finding all of the constant-evaluable
|
||||||
subgraphs, evaluating them at startup, and then storing their results as
|
subgraphs, evaluating them at startup, and then storing their results as
|
||||||
constants as opposed to re-evaluting them every time.
|
constants as opposed to re-evaluating them every time.
|
||||||
|
|
||||||
To enable constant folding, you can set the following configurations.
|
To enable constant folding, you can set the following configurations.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
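To make the constant-folding description above concrete, a language-level toy sketch (hypothetical node format and names, not the actual RuntimeConfigs machinery): subgraphs whose inputs are all constants are evaluated once up front and cached, while everything else is kept for runtime.

```python
def fold_constants(nodes):
    # nodes: (name, fn, args) where args are literals or names of earlier nodes
    env, remaining = {}, []
    for name, fn, args in nodes:
        if all(isinstance(a, (int, float)) or a in env for a in args):
            env[name] = fn(*[env.get(a, a) for a in args])   # evaluated once, at "startup"
        else:
            remaining.append((name, fn, args))               # depends on runtime inputs
    return env, remaining

env, rest = fold_constants([
    ("c", lambda a, b: a * b, (2, 3)),        # constant-evaluable -> folded to 6
    ("y", lambda x, c: x + c, ("x", "c")),    # "x" is a runtime input, so this node is kept
])
print(env)   # {'c': 6}
print(rest)  # [('y', <lambda>, ('x', 'c'))]
```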
@ -438,7 +438,7 @@ def _view_as_dense(
|
||||||
# # this is because needs_broadcast indicates that the batch_size is 1
|
# # this is because needs_broadcast indicates that the batch_size is 1
|
||||||
# # and hence there is only 1 value for seq_len
|
# # and hence there is only 1 value for seq_len
|
||||||
# # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
|
# # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
|
||||||
# # ..., outut_batch_size * {*}_t.size(1)]
|
# # ..., output_batch_size * {*}_t.size(1)]
|
||||||
# # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
|
# # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
|
||||||
|
|
||||||
# if q_batch_size_needs_broadcast or not q_t.is_nested:
|
# if q_batch_size_needs_broadcast or not q_t.is_nested:
|
||||||
|
|
|
||||||
|
|
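The `cum_seq_lens` pattern mentioned in the commented-out code above can be illustrated directly (hypothetical sizes, purely for intuition):

```python
import torch

# e.g. output_batch_size = 4 sequences, each of the same (broadcast) length 7
output_batch_size, seq_len = 4, 7
cum_seq_lens = torch.arange(0, (output_batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32)
print(cum_seq_lens)  # tensor([ 0,  7, 14, 21, 28], dtype=torch.int32)
# i.e. [0, seq_len, 2*seq_len, ..., output_batch_size * seq_len]; Nnz = output_batch_size * seq_len
```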
@ -2229,7 +2229,7 @@ def gumbel_softmax(
|
||||||
).scatter_(dim, index, 1.0)
|
).scatter_(dim, index, 1.0)
|
||||||
ret = y_hard - y_soft.detach() + y_soft
|
ret = y_hard - y_soft.detach() + y_soft
|
||||||
else:
|
else:
|
||||||
# Reparametrization trick.
|
# Reparameterization trick.
|
||||||
ret = y_soft
|
ret = y_soft
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
|
||||||
|
|
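The hard/soft branches touched above correspond to the straight-through and reparameterization paths of `torch.nn.functional.gumbel_softmax`; a brief usage sketch (illustrative, not from this commit):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
y_soft = F.gumbel_softmax(logits, tau=1.0, hard=False)  # reparameterized soft sample
y_hard = F.gumbel_softmax(logits, tau=1.0, hard=True)   # one-hot forward, soft gradients (y_hard - y_soft.detach() + y_soft)
y_hard.sum().backward()                                  # gradients still reach logits
print(y_hard.sum(dim=-1))                                # each row sums to 1 (one-hot)
```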
@ -1471,7 +1471,7 @@ class _LazyConvXdMixin(LazyModuleMixin):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
# LazyConv1d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
# LazyConv1d defines weight as a Tensor but derived class defines it as UninitializeParameter
|
||||||
class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
|
class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
|
||||||
r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument.
|
r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument.
|
||||||
|
|
||||||
|
|
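The lazy `in_channels` behaviour referenced in these LazyConv hunks can be seen with a short snippet (illustrative usage only):

```python
import torch
from torch import nn

conv = nn.LazyConv1d(out_channels=16, kernel_size=3)  # in_channels not known yet
x = torch.randn(8, 4, 32)                             # (batch, channels=4, length)
out = conv(x)                                         # first call materializes the weight
print(conv.weight.shape)                              # torch.Size([16, 4, 3])
print(out.shape)                                      # torch.Size([8, 16, 30])
```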
@ -1543,7 +1543,7 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
# LazyConv2d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
# LazyConv2d defines weight as a Tensor but derived class defines it as UninitializeParameter
|
||||||
class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
|
class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
|
||||||
r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument.
|
r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument.
|
||||||
|
|
||||||
|
|
@ -1615,7 +1615,7 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
|
|
||||||
# LazyConv3d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
# LazyConv3d defines weight as a Tensor but derived class defines it as UninitializeParameter
|
||||||
class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
|
class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
|
||||||
r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument.
|
r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument.
|
||||||
|
|
||||||
|
|
@ -1688,7 +1688,7 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
|
||||||
return 3
|
return 3
|
||||||
|
|
||||||
|
|
||||||
# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UninitializeParameter
|
||||||
class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[misc]
|
class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[misc]
|
||||||
r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument.
|
r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument.
|
||||||
|
|
||||||
|
|
@ -1760,7 +1760,7 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UninitializeParameter
|
||||||
class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[misc]
|
class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[misc]
|
||||||
r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument.
|
r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument.
|
||||||
|
|
||||||
|
|
@ -1832,7 +1832,7 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
|
|
||||||
# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UnitializeParameter
|
# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UninitializeParameter
|
||||||
class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[misc]
|
class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[misc]
|
||||||
r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument.
|
r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -144,7 +144,7 @@ class BasePruningMethod(ABC):
|
||||||
|
|
||||||
method = _get_composite_method(cls, module, name, *args, **kwargs)
|
method = _get_composite_method(cls, module, name, *args, **kwargs)
|
||||||
# at this point we have no forward_pre_hooks but we could have an
|
# at this point we have no forward_pre_hooks but we could have an
|
||||||
# active reparametrization of the tensor if another pruning method
|
# active reparameterization of the tensor if another pruning method
|
||||||
# had been applied (in which case `method` would be a PruningContainer
|
# had been applied (in which case `method` would be a PruningContainer
|
||||||
# and not a simple pruning method).
|
# and not a simple pruning method).
|
||||||
|
|
||||||
|
|
|
||||||
|
|
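The "active reparameterization" the comment above refers to is the `weight_orig`/`weight_mask` pair plus forward pre-hook that pruning installs; a minimal sketch of how a second method stacks into a `PruningContainer` (usage only, not this commit's code):

```python
import torch
from torch import nn
from torch.nn.utils import prune

m = nn.Linear(4, 2)
prune.l1_unstructured(m, name="weight", amount=0.5)          # installs weight_orig, weight_mask, pre-hook
print(hasattr(m, "weight_orig"), hasattr(m, "weight_mask"))  # True True
prune.random_unstructured(m, name="weight", amount=0.25)     # second method is combined with the first
print(list(m._forward_pre_hooks.values())[0])                # the pruning hook (a PruningContainer after the second call)
```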
@ -164,7 +164,7 @@ _TORCH_DTYPE_TO_ABBREVIATION = {
|
||||||
|
|
||||||
SYM_VALUE_TYPE = Union[torch.SymInt, torch.SymFloat, torch.SymBool]
|
SYM_VALUE_TYPE = Union[torch.SymInt, torch.SymFloat, torch.SymBool]
|
||||||
META_VALUE_TYPE = Union[fake_tensor.FakeTensor, SYM_VALUE_TYPE, int, float, bool]
|
META_VALUE_TYPE = Union[fake_tensor.FakeTensor, SYM_VALUE_TYPE, int, float, bool]
|
||||||
# NOTE: Belows are from torch/fx/node.py
|
# NOTE: Below are from torch/fx/node.py
|
||||||
BaseArgumentTypes = Union[
|
BaseArgumentTypes = Union[
|
||||||
str,
|
str,
|
||||||
int,
|
int,
|
||||||
|
|
|
||||||
|
|
@ -810,7 +810,7 @@ class QuantizationTestCase(TestCase):
|
||||||
b = io.BytesIO()
|
b = io.BytesIO()
|
||||||
torch.save(model_dict, b)
|
torch.save(model_dict, b)
|
||||||
b.seek(0)
|
b.seek(0)
|
||||||
# weights_only=False as we sometimes get a ScriptObect here (weird)
|
# weights_only=False as we sometimes get a ScriptObject here (weird)
|
||||||
loaded_dict = torch.load(b, weights_only=False)
|
loaded_dict = torch.load(b, weights_only=False)
|
||||||
loaded_model.load_state_dict(loaded_dict)
|
loaded_model.load_state_dict(loaded_dict)
|
||||||
ref_out = ref_model(*x)
|
ref_out = ref_model(*x)
|
||||||
|
|
|
||||||
|
|
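The save/load round trip exercised in that test can be reproduced in isolation (generic sketch, unrelated to the specific quantized model):

```python
import io
import torch
from torch import nn

model = nn.Linear(3, 1)
buf = io.BytesIO()
torch.save(model.state_dict(), buf)
buf.seek(0)
state = torch.load(buf, weights_only=False)  # weights_only=False, matching the test above
fresh = nn.Linear(3, 1)
fresh.load_state_dict(state)
```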
@ -3,7 +3,7 @@ AutoHeuristic is a framework that allows one to use results from autotuning to l
|
||||||
|
|
||||||
## How to use AutoHeuristic
|
## How to use AutoHeuristic
|
||||||
In general, the following steps have to performed:
|
In general, the following steps have to performed:
|
||||||
- The AutoHeursitic constructor has to be called.
|
- The AutoHeuristic constructor has to be called.
|
||||||
- A script that runs benchmarks in order to collect training data has to be implemented.
|
- A script that runs benchmarks in order to collect training data has to be implemented.
|
||||||
- The train_decision.py (if you want to learn a decision tree) or train_regression.py (if you want to learn a regression tree) script has to be run in order to learn the heuristic and generate it to code.
|
- The train_decision.py (if you want to learn a decision tree) or train_regression.py (if you want to learn a regression tree) script has to be run in order to learn the heuristic and generate it to code.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -678,7 +678,7 @@ def gen_aoti_c_shim_files(
|
||||||
# Use "aten" as the device name when dispatch_key is Generic
|
# Use "aten" as the device name when dispatch_key is Generic
|
||||||
device_name = "aten" if dispatch_key is None else dispatch_key.lower()
|
device_name = "aten" if dispatch_key is None else dispatch_key.lower()
|
||||||
|
|
||||||
# header files were checked in for ABI-compatiblilty checking
|
# header files were checked in for ABI-compatibility checking
|
||||||
header_file_name = f"c_shim_{device_name}.h"
|
header_file_name = f"c_shim_{device_name}.h"
|
||||||
new_header = gen_aoti_c_shim(
|
new_header = gen_aoti_c_shim(
|
||||||
fallback_native_functions,
|
fallback_native_functions,
|
||||||
|
|
|
||||||