Summary:
See wrapper.codegen_reinterpret_view(), it return a temporary handle for tensor, which has following problem.
```
# NB, the return handle here represents a temporary tensor, which will be automatically
# released.
# Here's a sample usage in the cpp wrapper code:
# ```
# aoti_torch_addmm_out(
# buf1,
# arg1_1,
# RAIIAtenTensorHandle(tmp_tensor_handle_0),
# buf0,
# 1L,
# 1L));
# ```
# RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out.
# This could be problematic when it's used in a different pattern, for example:
# ````
# AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
# aoti_torch_proxy_executor_call_function(..., tensor_args);
# ````
# RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter
# kernel call.
return f"RAIIAtenTensorHandle({tmp_name})"
```
As a result, ProxyExecutor would generate following code, which cause invalid memory access.
Before:
```
// Source Nodes: [fn_with_tuple_output], Original ATen: [fb.fn_with_tuple_output]
AtenTensorHandle tmp_tensor_handle_2;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor(buf3, 2, int_array_0, int_array_1, 0L, &tmp_tensor_handle_2));
...
AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
int64_t int_args[] = {1};
aoti_torch_proxy_executor_call_function(proxy_executor, 1, 1, int_args, 3, tensor_args);
buf3.reset();
```
With fix in this diff, ProxyExecutor generates following code
After:
```
// Source Nodes: [fn_with_tuple_output], Original ATen: [fb.fn_with_tuple_output]
AtenTensorHandle tmp_tensor_handle_2;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor(buf3, 2, int_array_0, int_array_1, 0L, &tmp_tensor_handle_2));
...
aoti_torch_proxy_executor_call_function(proxy_executor, 1, 1, std::vector<int64_t>{1}.data(), 3, std::vector<AtenTensorHandle>{RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}.data());
buf3.reset();
```
I am not exactly a big fan of such `std::vector{...}.data()` for creating a temp array, but I can't think of another fix.
Test Plan: buck2 run mode/dev-nosan deeplearning/aot_inductor/test:test_custom_ops
Reviewed By: desertfire
Differential Revision: D49758764
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110451
Approved by: https://github.com/desertfire
A resubmit of https://github.com/pytorch/pytorch/pull/108447. Copy over the descriptions:
This is a follow-up of the discussion in https://github.com/pytorch/pytorch/pull/108356, where we want to repalce source_fn with source_fn_stack
Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()
@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
def true_fn(pred2, x, y):
return x + y
def false_fn(pred2, x, y):
def true_fn2(x, y):
return x.sin() - y.cos()
def false_fn2(x, y):
return x.cos() - y.sin()
return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))
return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
l_pred_ = L_pred_
l_pred2_ = L_pred2_
l_x_ = L_x_
l_y_ = L_y_
cond_true_1 = self.cond_true_1
cond_false_1 = self.cond_false_1
cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]); l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
return (cond,)
class GraphModule(torch.nn.Module):
def forward(self, l_pred2_, l_x_, l_y_):
add = l_x_ + l_y_; l_x_ = l_y_ = None
return add
class GraphModule(torch.nn.Module):
def forward(self, l_pred2_, l_x_, l_y_):
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]); l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
return cond
class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_):
sin = l_x_.sin(); l_x_ = None
cos = l_y_.cos(); l_y_ = None
sub = sin - cos; sin = cos = None
return sub
class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_):
cos = l_x_.cos(); l_x_ = None
sin = l_y_.sin(); l_y_ = None
sub = cos - sin; cos = sin = None
return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```
After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```
Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108595
Approved by: https://github.com/angelayi, https://github.com/zou3519
Based on William's recent diff on preserving node metadata on retracing, we no longer need to skip dynamo on retracing. This softens our previous restriction of not allowing any new constraints from user side because we can utilize dynamo to analyze through constraints now. As a result, re-export can technically happen with any new constraints. This opens up another problem that "Is it ok to use more loose constraints on the retracing?" If we allow loose constraints, we can technically diverge from eager behaviour because for example we could have eliminated unsafe control flow based on previous assumption. But we can also argue this is ok because we can say we treat the Exported callable to be an independent callable from its' original source code.
We can technically ban loose constraints inside export, but my concern is we are breaking abstraction by doing special case checks on ExportedProgram.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109476
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
Summary:
This is a prototype for running extern fallback kernels with a host side proxy executor.
Sample of generated cpp wrapper call:
```
at::Tensor buf0; // output buffer
void* tensor_args_var_0[] = {&arg0_1, &arg0_1, &arg1_1, &arg0_1, &arg1_1, &buf0};
int64_t int_args_var_1[] = {81, 81, 7, 7, 7, 81};
proxy_executor->call_function("buf0", int_args_var_1, tensor_args_var_0);
```
- In my current implementation, proxy executor interprets the raw pointers according to the ops schema.
This assumes that custom op MUST have a valid schema registered to Dispatcher. (I would like to validate this assumption)
- I am using callboxed() API of the custom kernels. This is inevitable, as we wish to have a single call_function API for all possible custom kernels.
- These are all the input argument types I have support so far.
union Argument {
# Bool value does not matter
1: bool asNone;
2: TensorArgument asTensor;
3: list<TensorArgument> asTensors;
5: i64 asInt;
7: list<i64> asInts;
8: double asFloat;
9: list<double> asFloats;
10: string asString;
10.5: list<string> asStrings;
11: SymIntArgument asSymInt;
12: list<SymIntArgument> asSymInts;
13: ScalarType asScalarType;
14: MemoryFormat asMemoryFormat;
15: Layout asLayout;
16: Device asDevice;
17: bool asBool;
18: list<bool> asBools;
}
- Need a policy for handling unpopulated argument with default values. Here are the options, and it has BC implications.
1. requires exported fx graph to explicitly populate default values, if users doesn't specify.
2. requires cpp wrapper to explicitly populate default values, if fx graph doesn't specify.
3. Proxy executor look up from opSchema for default values.
For fixing T162112344
Test Plan:
frontend:
buck2 run mode/dev-sand mode/inplace -c fbcode.enable_gpu_sections=True sigmoid/frontend:export_main
test:
buck2 run mode/dev-sand //deeplearning/aot_inductor/test:test_custom_ops
backend:
buck2 run mode/dev-nosan //deeplearning/aot_inductor/fb:main
buck2 test 'fbcode//mode/opt' fbcode//caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark -- --exact 'caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark - test_aot_inductor_benchmark_cmf30x (caffe2.torch.fb.model_transform.experimental.benchmark.test.test_aot_inductor_benchmark.AOTInductorBenchmark)'
Reviewed By: suo
Differential Revision: D48747417
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108350
Approved by: https://github.com/izaitsevfb
Fixes https://github.com/pytorch/pytorch/pull/102577#issuecomment-1650905536
Serializing to json is more stable, and renamed the API:
```
# Takes in a treespec and returns the serialized treespec as a string. Also optionally takes in a protocol version number.
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
# Takes in a serialized treespec and outputs a TreeSpec
def treespec_loads(data: str) -> TreeSpec:
```
If users want to register their own serialization format for a given pytree, they can go through the `_register_treespec_serializer` API which optionally takes in a `getstate` and `setstate` function.
```
_register_treespec_serializer(type_, *, getstate, setstate)
# Takes in the context, and outputs a json-dumpable context
def getstate(context: Context) -> DumpableContext:
# Takes in a json-dumpable context, and reconstructs the original context
def setstate(dumpable_context: DumpableContext) -> Context:
```
We will serialize to the following dataclass, and then json.dump this it to string.
```
class TreeSpec
type: Optional[str] # a string name of the type. null for the case of a LeafSpec
context: Optional[Any] # optional, a json dumpable format of the context
children_specs: List[TreeSpec],
}
```
If no getstate/setstate function is registered, we will by default serialize the context using `json.dumps/loads`. We will also serialize the type through `f"{typ.__module__}.{typ.__name__}"`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106116
Approved by: https://github.com/zou3519
Summary: Previously serializing graphs using map would error
because map returns a singleton tensor list rather than a
single tensor. So this diff adds support for if a higher order operator
returns a list of tensors as output.
We also run into an issue with roundtripping the source_fn on
map nodes/subgraphs. The source_fn originally is
<functorch.experimental._map.MapWrapper object at 0x7f80a0549930>, which
serializes to `functorch.experimental._map.map`. However, we are unable
to construct the function from this string. This should be fixed once
map becomes a fully supported operator like
torch.ops.higher_order.cond.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D48631302](https://our.internmc.facebook.com/intern/diff/D48631302)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107837
Approved by: https://github.com/zhxchen17
ghstack dependencies: #107818
Some NvidaTRT folks were asking for a way to integrate the serialization of custom objects with export's serialization. After some discussion (more background [here](https://docs.google.com/document/d/1lJfxakmgeoEt50inWZ53MdUtOSa_0ihwCuPy_Ak--wc/edit)), we settled on a way for users to register their custom object's serializer/deserializer functions.
Since TorchScript's `.def_pickle` already exists for [registering custom classes](https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html), and `tensorrt.ICudaEngine` already contains a `.def_pickle` implementation, we'll start off by reusing the existing framework and integrating it with export's serialization.
TorchScript's `.def_pickle` requires users to register two functions, which end up being the `__getstate__` and `__setstate__` methods on the class. The semantics of `__getstate__` and `__setstate__` in TorchScript are equivalent to that of Python pickle modules. This is then registered using pybind's `py::pickle` function [here](https://www.internalfb.com/code/fbsource/[f44e048145e4697bccfaec300798fce7daefb858]/fbcode/caffe2/torch/csrc/jit/python/script_init.cpp?lines=861-916) to be used with Python's pickle to initialize a ScriptObject with the original class, and set the state back to what it used to be.
I attempted to call `__getstate__` and `__setstate__` directly, but I couldn't figure out how to initial the object to be called with `__setstate__` in python. One option would be to create a `torch._C.ScriptObject` and then set the class and call `__setstate__`, but there is no constructor initialized for ScriptObjects. Another option would be to construct an instance of the serialized class itself, but if the class constructor required arguments, I wouldn't know what to initialize it with. In ScriptObject's `py::pickle` registration it directly creates the object [here](https://www.internalfb.com/code/fbsource/[f44e048145e4697bccfaec300798fce7daefb858]/fbcode/caffe2/torch/csrc/jit/python/script_init.cpp?lines=892-906), which is why I was thinking that just directly using Python's `pickle` will be ok since it is handled here.
So, what I did is that I check if the object is pickle-able, meaning it contains `__getstate__` and `__setstate__` methods, and if so, I serialize it with Python's pickle. TorchScript does have its own implementation of [pickle/unpickle](https://www.internalfb.com/code/fbsource/[59cbc569ccbcaae0db9ae100c96cf0bae701be9a][history]/fbcode/caffe2/torch/csrc/jit/serialization/pickle.h?lines=19%2C82), but it doesn't seem to have pybinded functions callable from python.
A question is -- is it ok to combine this pickle + json serialization?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107666
Approved by: https://github.com/gmagogsfm
Summary: Currently serializing graphs which return get_attr's directly as output fails. This diff adds support for that only in EXIR serializer while we still support unlifted params.
Test Plan: Added test case.
Differential Revision: D48258552
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107610
Approved by: https://github.com/angelayi
Added the following APIs:
```
def save(
ep: ExportedProgram,
f: Union[str, pathlib.Path, io.BytesIO],
extra_files: Optional[Dict[str, Any]] = None,
opset_version: Optional[Dict[str, int]] = None,
) -> None:
"""
Saves a version of the given exported program for use in a separate process.
Args:
ep (ExportedProgram): The exported program to save.
f (str): A file-like object (has to implement write and flush)
or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): Map from filename to contents
which will be stored as part of f.
opset_version (Optional[Dict[str, int]]): A map of opset names
to the version of this opset
"""
def load(
f: Union[str, pathlib.Path, io.BytesIO],
extra_files: Optional[Dict[str, Any]] = None,
expected_opset_version: Optional[Dict[str, int]] = None,
) -> ExportedProgram:
"""
Loads an ExportedProgram previously saved with torch._export.save
Args:
ep (ExportedProgram): The exported program to save.
f (str): A file-like object (has to implement write and flush)
or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): The extra filenames given in
this map would be loaded and their content would be stored in the
provided map.
expected_opset_version (Optional[Dict[str, int]]): A map of opset names
to expected opset versions
Returns:
An ExportedProgram object
"""
```
Example usage:
```
# With buffer
buffer = io.BytesIO()
torch._export.save(ep, buffer)
buffer.seek(0)
loaded_ep = torch._export.load(buffer)
# With file
with tempfile.NamedTemporaryFile() as f:
torch._export.save(ep, f)
f.seek(0)
loaded_ep = torch._export.load(f)
# With Path
with TemporaryFileName() as fname:
path = pathlib.Path(fname)
torch._export.save(ep, path)
loaded_ep = torch._export.load(path)
# Saving with extra files
buffer = io.BytesIO()
save_extra_files = {"extra.txt": "moo"}
torch._export.save(ep, buffer, save_extra_files)
buffer.seek(0)
load_extra_files = {"extra.txt": ""}
loaded_ep = torch._export.load(buffer, extra_files)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107309
Approved by: https://github.com/avikchaudhuri, https://github.com/gmagogsfm, https://github.com/tugsbayasgalan
Since constrain_as_size has been fixed, I tried serializing it, but ran into some issues.
Notably, after each `.transform` call, I added a helper `_get_updated_range_constraints` to update the range constrains list. This is because when we retrace in a pass, the symbolic values being used changes, so we need to update this dictionary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107386
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
Added a version number to the schema for BC issues. We will add this number to the serialized ExportedProgram and then when deserializing, if the number does not match up with the existing deserializer, we will error. We should update the number of there are any major changes to the schema.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107420
Approved by: https://github.com/zhxchen17
Some notable changes:
1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2.
2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
Some notable changes:
1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2.
2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
Summary: Internal model and Resnet uses "re-export" flow now. Also did some refactoring to make the code little cleaner
Some changes for OSS:
1. Correctly use the "cached" fake tensors so that static symbols are still resolved to static
2. Change logic in PassBase to allocate static shapes for parameters
3. Add "is_torch_exported" tag to every node to make it survive during various graph transformations.
4. Added experimental wrapper API for quantization team to get pre_dispatch=True graph. Note that it doesn't actually do that right now. But we plan to switch soon.
Test Plan: CI
Differential Revision: D47890878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106676
Approved by: https://github.com/jerryzh168
Summary:
Adding support for edge dialect ops in `exir/serde`. This diff does the following:
- Moves the global `serialize_operator/deserialize_operator` implementations in`export/serde/serialize.py` into `GraphModuleSerializer` and `GraphModuleDeserializer`
- Adds implementations of `serialize_operator/deserialize_operator` inside `GraphModuleSerializer` and `GraphModuleDeserializer` in `exir/serde/serialize.py`
Test Plan: CI + Enabled edge dialect ops in `executorch/exir/tests/test_serde.py`
Differential Revision: D47938280
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106371
Approved by: https://github.com/angelayi
@SherlockNoMad mentioned that it's not bc safe to directly access these attributes, so I moved them to @property fields, and added a `@compatibility` decorator. For now I just set it to True for graph_module/graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106170
Approved by: https://github.com/SherlockNoMad
Summary: Sometimes the graph that is being serialized contains nodes with side effects + no users (ex. out variants of operators), so we don't want to eliminate those when deserializing.
Test Plan: CI
Differential Revision: D47735009
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105875
Approved by: https://github.com/ydwu4
Solving #105242.
During export, the exported function's signature changes multiple times. Suppose we'd like to export f as shown in following example:
```python
def f(arg1, arg2, kw1, kw2):
pass
args = (arg1, arg2)
kwargs = {"kw2":arg3, "kw1":arg4}
torch.export(f, args, kwargs)
```
The signature changes mutiple times during export process in the following order:
1. **gm_torch_level = dynamo.export(f, *args, \*\*kwargs)**. In this step, we turn all kinds of parameters such as **postional_only**, **var_positioinal**, **kw_only**, and **var_kwargs** into **positional_or_kw**.It also preserves the positional and kword argument names in original function (i.e. f in this example) [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/export.py#L546C13-L546C27). The order of kwargs will be the **key order** of kwargs (after python 3.6, the order is the insertion of order of keys) instead of the original function signature and the order is baked into a _orig_args varaible of gm_torch_level's pytree info. So we'll have:
```python
def gm_torch_level(arg1, arg2, kw2, kw1)
```
Such difference is acceptable as it's transparent to users of export.
2. **gm_aot_export = aot_export_module(gm_torch_level, pos_or_kw_args)**. In this step, we need to turn kwargs into positional args in the order of how gm_torch_level expected, which is stored in _orig_args. The returned gm_aot_export has the graph signature of flat_args, in_spec = pytree.tree_flatten(pos_or_kw_args):
``` python
flat_args, _ = pytree.tree_flatten(pos_or_kw_args)
def gm_aot_export(*flat_args)
```
3. **exported_program(*args, \*\*kwargs)**. The epxorted artifact is exported_program, which is a wrapper over gm_aot_export and has the same calling convention as the original function "f". To do this, we need to 1. specialize the order of kwargs into pos_or_kw_args and 2. flatten the pos_or_kw_args into what gm_aot_export expected. We can combine the two steps into one with :
```python
_, in_spec = pytree.tree_flatten((args, kwargs))
# Then during exported_program.__call__(*args, **kwargs)
flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
```
, where kwargs is treated as a normal pytree whose keyorder is preserved in in_spec.
Implementation-wise, we treat _orig_args in dynamo exported graph module as single source of truth and kwags are ordered following it.
Test plan:
See added tests in test_export.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105337
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
Summary: Submodules may have a none call-spec values, which is ok. Updating types + serializer to handle this
Test Plan: CI
Reviewed By: ydwu4, zhxchen17
Differential Revision: D47353101
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105179
Approved by: https://github.com/zhxchen17
Summary:
Some serialized nn_module_stacks contain nested commas, something like:
`(getitem(L['module'],0),torch.nn.modules.linear.Linear)`
Fixing the parsing so that we can deserialize the string in the format of: `(local identifier, module type)`
Test Plan: CI
Differential Revision: D47252881
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104721
Approved by: https://github.com/zhxchen17