Commit Graph

86 Commits

Author SHA1 Message Date
Han Qi (qihqi)
f0e3c4929b only copy meta if available (#92623)
Test Plan:
```
buck2 test mode/opt //torchmultimodal/tests:tests -- --exact 'torchmultimodal/tests:tests - test_albef.py::test_albef_image_embeddings_momentum'
```
now passes

Reviewed By: malfet

Differential Revision: D42608385

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92623
Approved by: https://github.com/tugsbayasgalan
2023-01-19 23:39:53 +00:00
Han Qi
00fe63d1d8 fx Graph should copy meta on deepcopy (#92062)
Summary:
fx Graph should copy meta on deepcopy

Test Plan:
Unit test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92062
Approved by: https://github.com/zhxchen17
2023-01-18 02:49:14 +00:00
Zhengxu Chen
b7aa22d6db [fx] Fix GraphModule.print_readable() (#88730)
Summary: `__nested_code()` appears to have been removed.

Test Plan: CI

Differential Revision: D41149662

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88730
Approved by: https://github.com/SherlockNoMad
2022-11-09 21:39:48 +00:00
Horace He
e150a6212b Added gm.print_readable to torchinductor_trace output (#87717)
cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87717
Approved by: https://github.com/ngimel
2022-10-25 22:31:49 +00:00
anjali411
a6c0442cce Add __all__ to torch.{autograd, fx, cuda} submodules (#85343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85343
Approved by: https://github.com/albanD
2022-10-09 14:46:54 +00:00
Angela Yi
dd82b31e55 [fx] Add metadata to fx.GraphModule (#84378)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84378
Approved by: https://github.com/SherlockNoMad
2022-09-01 18:36:52 +00:00
Sherlock Huang
7e5c76da47 Make graph_module.print_readable() discoverable (#83960)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83960
Approved by: https://github.com/ezyang
2022-08-25 23:56:50 +00:00
Sherlock Huang
bf8d5e8328 Pretty print stack trace with gm.print_readable() (#83706)
Precondition: https://github.com/pytorch/torchdynamo/pull/899

Given following function
```
def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    s = s.sum()
```
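
For orientation, a minimal sketch (not part of this PR) of producing and printing one of the traces below; it assumes the `func` defined above (returning `s`) and a PyTorch build where `GraphModule.print_readable()` exists:
```
import torch
import torch.fx as fx

gm = fx.symbolic_trace(func)   # default symbolic_trace frontend
gm.print_readable()            # prints forward() annotated with stack-trace comments
```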

Here are the possible results with the various tracing frontends: dynamo, symbolic_trace, make_fx
- joint graph with torchdynamo.optimize("aot_nop")
Notice that it has a special stack trace for the gradient-addition node (for multiple uses of a tensor) in the backward pass.
Notice that "No stacktrace found for following nodes" is shown for nodes without a stack trace.
```
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None

    # No stacktrace found for following nodes
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    expand_default = torch.ops.aten.expand.default(tangents_1, [2, 10, 10]);  tangents_1 = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    unbind_int = torch.ops.aten.unbind.int(expand_default);  expand_default = None
    getitem = unbind_int[0]
    getitem_1 = unbind_int[1];  unbind_int = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    cos_default = torch.ops.aten.cos.default(pow_tensor_scalar);  pow_tensor_scalar = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem_1, cos_default);  getitem_1 = cos_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(getitem, detach_default_1, 0);  getitem = detach_default_1 = None

    # Gradient addition node due to mulitple use of tensor around:, File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    add_tensor_1 = torch.ops.aten.add.Tensor(mul_tensor, threshold_backward_default);  mul_tensor = threshold_backward_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 1.0);  add_tensor = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor_1 = torch.ops.aten.mul.Tensor(add_tensor_1, mul_scalar);  add_tensor_1 = mul_scalar = None
    sum_sym_int = torch.ops.aten.sum.SymInt(mul_tensor_1, [0], True)
    view_sym_int = torch.ops.aten.view.SymInt(sum_sym_int, [10]);  sum_sym_int = None
    return pytree.tree_unflatten([sum_default, mul_tensor_1, view_sym_int], self._out_spec)
```
- default symbolic_trace
Notice that nodes without a stack trace are folded under the same region
```
def forward(self, a, b):

    # No stacktrace found for following nodes
    add = a + b;  a = b = None
    square = torch.square(add);  add = None
    relu = square.relu()
    sin = square.sin();  square = None
    stack = torch.stack([relu, sin]);  relu = sin = None
    sum_1 = stack.sum();  stack = None
    return sum_1
```
- symbolic_trace with record_stack_traces=True
```
def forward(self, a, b):

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add = a + b;  a = b = None
    square = torch.square(add);  add = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu = square.relu()

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin = square.sin();  square = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack = torch.stack([relu, sin]);  relu = sin = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_1 = stack.sum();  stack = None
    return sum_1
```

- make_fx without decomposition
```
def forward(self, a_1, b_1):

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(a_1, b_1);  a_1 = b_1 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2);  add_tensor = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar);  pow_tensor_scalar = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None
    return sum_default
```
- make_fx with decomposition to prims
```
def forward(self, a_1, b_1):

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(b_1, [10, 10], [1]);  b_1 = None
    add_default = torch.ops.prims.add.default(a_1, broadcast_in_dim_default);  a_1 = broadcast_in_dim_default = None
    mul_default = torch.ops.prims.mul.default(add_default, add_default);  add_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    le_default = torch.ops.prims.le.default(mul_default, 0.0)
    where_default = torch.ops.prims.where.default(le_default, 0.0, mul_default);  le_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.prims.sin.default(mul_default);  mul_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    cat_default = torch.ops.prims.cat.default([where_default, sin_default], 0);  where_default = sin_default = None
    split_dim_default = torch.ops.prims.split_dim.default(cat_default, 0, 2);  cat_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    convert_element_type_default = torch.ops.prims.convert_element_type.default(split_dim_default, torch.float32);  split_dim_default = None
    sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0, 1, 2]);  convert_element_type_default = None
    return sum_default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83706
Approved by: https://github.com/Chillee, https://github.com/ezyang
2022-08-24 23:00:57 +00:00
Sergii Dymchenko
591222f5d9 Fix use-dict-literal lint (#83718)
Fix use-dict-literal pylint suggestions by changing `dict()` to `{}`. This PR makes the change in every Python file except test/jit/test_list_dict.py, where I think the intent is to test the dict constructor.
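An illustrative before/after of the change (variable name made up, not taken from the diff):
```
settings = dict()   # flagged by pylint's use-dict-literal check
settings = {}       # preferred dict literal
```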
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83718
Approved by: https://github.com/albanD
2022-08-24 00:26:46 +00:00
Sherlock Huang
43e7fee764 [Reland] Recursively print graph module and its submodule (#81639)
ghstack-source-id: fcfc024c440981ee3fe3537a5816089eadf2cc13
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81639
Approved by: https://github.com/ezyang
2022-07-21 16:58:25 +00:00
PyTorch MergeBot
4035a53cca Revert "Recursively print graph module and its submodule (#81080)"
This reverts commit fe7262329c.

Reverted https://github.com/pytorch/pytorch/pull/81080 on behalf of https://github.com/DanilBaibak due to Break internal build
2022-07-18 14:46:26 +00:00
Sherlock Huang
fe7262329c Recursively print graph module and its submodule (#81080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080
Approved by: https://github.com/ezyang
2022-07-18 01:19:03 +00:00
Horace He
e7e835e50a Fix to_folder by adding custom_builtins to dump (#81433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81433
Approved by: https://github.com/jamesr66a
2022-07-14 21:39:13 +00:00
PyTorch MergeBot
58532256e9 Revert "Add __all__ for torch.distributed and fx modules (#80460)"
This reverts commit 5d40c3d5c8.

Reverted https://github.com/pytorch/pytorch/pull/80460 on behalf of https://github.com/malfet due to Broke MacOS testing, see https://github.com/pytorch/pytorch/runs/7105579664?check_suite_focus=true
2022-06-29 16:20:55 +00:00
anjali411
5d40c3d5c8 Add __all__ for torch.distributed and fx modules (#80460)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80460
Approved by: https://github.com/albanD, https://github.com/rohan-varma
2022-06-29 02:53:56 +00:00
James Reed
7311390d35 [WIP] Make constructor calls in experimental MetaTracer serializable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76789

Approved by: https://github.com/pbelevich
2022-05-11 00:19:47 +00:00
Michael Suo
fb0f285638 [lint] upgrade mypy to latest version
Fixes https://github.com/pytorch/pytorch/issues/75927.

Had to fix some bugs and add some ignores.

To check if clean:
```
lintrunner --paths-cmd='git grep -Il .' --take MYPY,MYPYSTRICT
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76753

Approved by: https://github.com/malfet
2022-05-03 20:51:34 +00:00
PyTorch MergeBot
3d7428d9ac Revert "[lint] upgrade mypy to latest version"
This reverts commit 9bf18aab94.

Reverted https://github.com/pytorch/pytorch/pull/76753 on behalf of https://github.com/suo
2022-05-03 20:01:18 +00:00
Michael Suo
9bf18aab94 [lint] upgrade mypy to latest version
Fixes https://github.com/pytorch/pytorch/issues/75927.

Had to fix some bugs and add some ignores.

To check if clean:
```
lintrunner --paths-cmd='git grep -Il .' --take MYPY,MYPYSTRICT
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76753

Approved by: https://github.com/malfet
2022-05-03 19:43:28 +00:00
James Reed
bf730e5039 Fix unnecessary recursion in GraphModule.__call__
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76068

Approved by: https://github.com/Chillee
2022-04-21 03:25:39 +00:00
Pavel Belevich
96c8f64459 Remove with_traceback(None) in wrapped_call to show the root cause error
Before:
```
Traceback (most recent call last):
  File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
    t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
    return self.executor.run(*executor_args)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
    return super().run(*args, initial_env=initial_env)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
    self.env[node] = self.run_node(node)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
    return super().call_module(target, args, kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
    return submod(*args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e.with_traceback(None)
AttributeError: 'NoneType' object has no attribute 'dtype'
```
After:
```
Traceback (most recent call last):
  File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
    t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
    return self.executor.run(*executor_args)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
    return super().run(*args, initial_env=initial_env)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
    self.env[node] = self.run_node(node)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
    return super().call_module(target, args, kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
    return submod(*args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 622, in wrapped_call
    return super(cls, self).__call__(*args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "<eval_with_key>.42", line 74, in forward
  File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/utils/fx.py", line 180, in wrapper
    return func(*args, **kwargs)
  File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/modeling_utils.py", line 256, in create_extended_attention_mask_for_decoder
    causal_mask = causal_mask.to(attention_mask.dtype)
AttributeError: 'NoneType' object has no attribute 'dtype'
```

The last lines of the stack trace now show where the problem actually is.
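A minimal self-contained illustration (not FX's actual code) of why the change helps: re-raising with `with_traceback(None)` discards the frames that point at the root cause, while a plain `raise e` keeps them.
```
import traceback

def failing_op():
    return None.dtype   # AttributeError: 'NoneType' object has no attribute 'dtype'

def wrapped_call_old():
    try:
        return failing_op()
    except Exception as e:
        raise e.with_traceback(None)   # traceback restarts here; the failing_op frame is lost

def wrapped_call_new():
    try:
        return failing_op()
    except Exception as e:
        raise e                        # traceback still includes the frame inside failing_op

try:
    wrapped_call_new()
except AttributeError:
    traceback.print_exc()              # last frame shows the `None.dtype` line
```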
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74655
Approved by: https://github.com/ansley, https://github.com/rohan-varma
2022-03-25 14:40:45 +00:00
Animesh Jain
7ebab9247d FX graph module - prevent infinite recursion (#73866)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73866

`super(type(self), self)` in `wrapped_call` leads to infinite recursion for a subclass of an FX GraphModule. This happens when we call `_stateless.functional_call` on an FX module: https://github.com/pytorch/pytorch/blob/master/torch%2Fnn%2Futils%2F_stateless.py
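
A minimal, self-contained illustration (not FX code) of why `super(type(self), self)` recurses when the method is invoked on a subclass instance:
```
class Root:
    def call(self):
        return "ok"

class Base(Root):
    def call(self):
        # For a Sub instance, type(self) is Sub, so super(Sub, self).call
        # resolves back to Base.call and the method calls itself forever.
        return super(type(self), self).call()

class Sub(Base):
    pass

print(Base().call())   # "ok": type(self) is Base, so super() reaches Root.call
try:
    Sub().call()
except RecursionError:
    print("infinite recursion, as described above")
```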

Test Plan:
Tests added in https://github.com/pytorch/pytorch/pull/62436

Imported from OSS

Reviewed By: jansel

Differential Revision: D34737828

fbshipit-source-id: 871b897e1210173ccc83fe34d53fc41af00a39ee
(cherry picked from commit 3d0c5fc71503fa2782b497a9d39ce26288fd219b)
2022-03-09 06:09:57 +00:00
Horace He
d635d0f86e Refactor FX codegen into extensible Codegen object (#72566)
Summary:
The goal of this is to make FX's codegen extensible. I've refactored it into a class with 5 extensibility points on it.

```
class Codegen(object):
    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        """
        Given the free variables and a return annotation, generates the beginning of the FX function.
        By default, `generate_prologue(['a', 'b'], '') == 'def forward(a, b):'`
        """
    def generate_output(self, output_args: Argument) -> str:
        """
        Given the output arguments, generates the return statement of the FX function.
        """
    def process_inputs(self, args: Any) -> Any:
        """
        Transforms the inputs so that the graph can take them as arguments, as
        non-default codegen may result in the inputs to the function being
        different from the inputs to the graph.

        If the graph was directly runnable, this invariant should hold true
        `f.process_outputs(f.graph(*f.process_inputs(*inputs))) == f(*inputs)`
        """
    def process_outputs(self, outputs: Any) -> Any:
        """
        Transforms the outputs of the graph to be identical to the codegen.

        See ``process_inputs`` for more details.
        """
    def additional_globals(self) -> List[Tuple[str, Any]]:
        """
        If your codegen uses extra global values, add them here.
        For example, return ['List', typing.List] if you need ``List`` in the global context.
        """
```

So, for example, the `ListCodeGen` we want for AOTAutograd looks like this
```
        class ListCodeGen(CodeGen):
            def generate_prologue(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert(len(inputs) == 1)
                return inputs[0]
```
and
```
        def f(a, b):
            return a + b

        nf = fx.symbolic_trace(f)
        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        print(nf.code)
```
would result in
```
def forward(self, args_list: List[torch.Tensor]):
    a, b = args_list
    add = a + b;  a = b = None
    return add
```

Backwards compatibility changes - I added `process_outputs` and `process_inputs` to `fx.Graph`, while removing `flatten_inputs` and `flatten_outputs` - those didn't have `backwards_compatibility` on them, so I *think* it's probably fine?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72566

Reviewed By: desertfire

Differential Revision: D34160424

Pulled By: Chillee

fbshipit-source-id: ebf6411312b373e3fbcb13288a34befa449a2375
(cherry picked from commit 13cd12eaa1)
2022-02-11 18:13:29 +00:00
Horace He
df6eb9bbab Fixed to_folder not saving dtype (#69983)
Summary:
As above.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69983

Reviewed By: pbelevich, ngimel

Differential Revision: D33466529

Pulled By: Chillee

fbshipit-source-id: 2d2f0ad5b8e2492aba4c19fa034c8b6c0848a568
2022-01-06 22:15:56 -08:00
James Reed
3eb9443619 [FX] Fix issue where GraphModule.delete_all_unused_submodules deletes submodules from called leaf modules (#66430)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66430

On the whole, I'm not totally satisfied with this approach. I think we should be building a prefix tree data structure during initial iteration over the submodules and querying that when deleting submodules. But I think this approach works and I want to see if we can get it in before 1.10
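
A hedged usage sketch of the method being fixed (the module structure is made up for illustration):
```
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.used = torch.nn.Linear(4, 4)
        self.unused = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.used(x)

gm = fx.symbolic_trace(M())
# Removes submodules not referenced by any call_module/get_attr node; this PR
# ensures it does not also delete submodules of called leaf modules.
gm.delete_all_unused_submodules()
print([name for name, _ in gm.named_modules()])   # 'unused' should be gone
```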

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D31546137

Pulled By: jamesr66a

fbshipit-source-id: f08b8409a3cf511277017ccccb916097b7c4c4fe
2021-10-11 19:37:51 -07:00
James Reed
0c4e4e588e [FX] Rename reduce functions back to their old, public names (#64324)
Summary:
Unfortunately pickle serializes the names of these functions. Also put them under backward-compatibility enforcement.
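
A small illustration of the constraint described here (the function is a stand-in, not FX's code):
```
import pickle

def reduce_graph_module(x):
    return x

payload = pickle.dumps(reduce_graph_module)
# The payload records the module-qualified *name* of the function; if it is
# later renamed (e.g. made private), older payloads fail with AttributeError.
print(pickle.loads(payload) is reduce_graph_module)   # True while the name exists
```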

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64324

Test Plan: Local repro https://fb.workplace.com/groups/3440841732711443/permalink/4018921611570116/

Reviewed By: SplitInfinity, TailofJune

Differential Revision: D30684185

Pulled By: jamesr66a

fbshipit-source-id: 900701220155d15115cd0c07cf7774a2891bd04f
2021-08-31 22:36:11 -07:00
Jay Leverett
44fcb00a56 Fix redundant class definition in GraphModule singleton constructor (#64274)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63883

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64274

Reviewed By: jamesr66a

Differential Revision: D30675970

Pulled By: jayleverett

fbshipit-source-id: e74ef2a28013f0fa7c58d14f38e66cfe48d26b74
2021-08-31 17:34:14 -07:00
James Reed
538647fe1f [WIP][FX] BC guarantees for 1.10 (#63888)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63888

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D30523133

Pulled By: jamesr66a

fbshipit-source-id: b04cc0d842a74862f42ecba98b757310cd2ec7b0
2021-08-30 19:56:46 -07:00
James Reed
4e37a015c7 [FX] Fix _replicate_for_data_parallel (#63821)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63821

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D30502115

Pulled By: jamesr66a

fbshipit-source-id: 0f004f95def6e1ba21ccbeab40cb0a739a0ad20c
2021-08-24 13:48:15 -07:00
David Esiobu
79693bb86a Use linecache.lazycache to cache generated code. (#63453)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63453

Instead of patching linecache.getlines, use linecache.lazycache and
parts of the loader protocol described in PEP-302
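
A hedged sketch of the mechanism described above (not the PR's code): register a lazy entry so generated source can be recovered on demand through the loader's `get_source`.
```
import linecache

class _Loader:
    """Minimal PEP 302-style loader exposing get_source()."""
    def __init__(self, src):
        self._src = src
    def get_source(self, fullname):
        return self._src

src = "def forward(x):\n    return x + 1\n"
filename = "<eval_with_key>.0"   # FX-style key; the exact key here is illustrative
linecache.lazycache(filename, {"__name__": filename, "__loader__": _Loader(src)})
print(linecache.getlines(filename))   # source is now available, e.g. for tracebacks
```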

Test Plan:
python3 test/test_fx.py

Imported from OSS

Reviewed By: suo

Differential Revision: D30388176

fbshipit-source-id: 92933711ecf3a21a07e1d6b0d1185ab0efd8341c
2021-08-19 09:17:01 -07:00
James Reed
d661e646ad [FX] Fix GraphModule deepcopy to use deepcopied graph (#63090)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63090

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D30252471

Pulled By: jamesr66a

fbshipit-source-id: cafd7d7917935a5ea6ffa2a7fe9e9b2a9578b3e3
2021-08-18 13:17:14 -07:00
Bradley Davis
7a1ab9f5d7 [fx] store Tracer class on Graph and GraphModule for package deserialization [v2, the re-do] (#63121)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63121

Re-introducing this diff with a small change: skip setting the Tracer class on a GraphModule when the Tracer class is not defined at module level (which would prevent pickling).

Previous, reverted Pull Request: https://github.com/pytorch/pytorch/pull/62497

Reviewed By: houseroad

Differential Revision: D30252776

fbshipit-source-id: 42d2bc846e4b32d00563419c38c02b63cd0986e6
2021-08-12 17:28:50 -07:00
Lu Fang
847a7cfa10 Back out "[fx] store Tracer class on Graph and GraphModule for package deserialization" (#63053)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63053

Original commit changeset: eca09424ad30

The original diff - D30019214 (6286d33878) breaks the publish flow in model saving.

Test Plan: ci

Differential Revision: D30236517

fbshipit-source-id: 3e05db02fc1cbbc2ed262c83bf56d555277abb34
2021-08-10 21:58:08 -07:00
Bradley Davis
6286d33878 [fx] store Tracer class on Graph and GraphModule for package deserialization (#62497)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62497

Previously named: add support for custom tracer in __reduce_package__

Stores a Tracer class on a Graph created by Tracer, and copies the Tracer class into the GraphModule's state so that when a GraphModule is packaged by torch package, it can be reconstructed with the same Tracer and GraphModule class name.

Reviewed By: suo

Differential Revision: D30019214

fbshipit-source-id: eca09424ad30feb93524d481268b066ea55b892a
2021-08-09 13:07:30 -07:00
Bradley Davis
093495d3f0 [fx] prevent implicit submodule inlining when submodule is a GraphModule (#62436)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62436

## Problem

Given two modules and a tracer that indiscriminately marks all modules as a leaf:
```
class InnerModule(torch.nn.Module):

    def forward(self, t):
        return t + t

class MyModule(torch.nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, t):
        x = self.inner(t)
        y = self.inner(t)
        return x + y

class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, module, name):
        return True
```

One might generally expect the following behavior (note call_module nodes):
```
print(">> Outer GraphModule (with inner module as nn.Module):")
inner = InnerModule()
m = MyModule(inner)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())

>> Outer GraphModule (with inner module as nn.Module):
opcode         name     target                   args              kwargs
-------------  -------  -----------------------  ----------------  --------
placeholder    t        t                        ()                {}
call_module    inner    inner                    (t,)              {}
call_module    inner_1  inner                    (t,)              {}
call_function  add      <built-in function add>  (inner, inner_1)  {}
output         output   output                   (add,)            {}
None
```

However, when the inner module is first symbolically traced, the symbolic trace of the outer module ignores `is_leaf_module` entirely and traces through the whole module (note the call_function nodes).
```
print(">> Inner module as GraphModule:")
inner = InnerModule()
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
print(inner_gm.graph.print_tabular())

print(">> Outer GraphModule (with inner module as GraphModule):")
m = MyModule(inner_gm)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())

>> Inner module as GraphModule:
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    t       t                        ()      {}
call_function  add     <built-in function add>  (t, t)  {}
output         output  output                   (add,)  {}
None

>> Outer GraphModule (with inner module as GraphModule):
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    t       t                        ()            {}
call_function  add     <built-in function add>  (t, t)        {}
call_function  add_1   <built-in function add>  (t, t)        {}
call_function  add_2   <built-in function add>  (add, add_1)  {}
output         output  output                   (add_2,)      {}
None
```

This is surprising behavior and at first glance violates the tracer's intent. As I understand it, `torch.fx.symbolic_trace.Tracer.trace` intends to patch `torch.nn.Module.__call__` with a `module_call_wrapper()` that records a `call_module` node if the module is a leaf, and otherwise executes `torch.fx._symbolic_trace._orig_module_call = torch.nn.Module.__call__`, which is set at module loading time.

**Every submodule should be a leaf, but no `call_module` nodes are created when that submodule is a `GraphModule`. Why?**

Upon further inspection, I found:

- The constructor for GraphModule includes a path to `GraphModule.recompile()` via the setter for a `fx.Graph`:
```
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))

File "/torch/fx/graph_module.py", line 252, in __init__
self.graph = graph

File "/torch/nn/modules/module.py", line 1183, in __setattr__
object.__setattr__(self, name, value)

File "/torch/fx/graph_module.py", line 277, in graph
self.recompile()
```
- `recompile()` wraps the `__call__` method by holding a reference to the `__call__` method at the time of recompilation:
```
cls = type(self)
cls_call = cls.__call__
...
def wrapped_call(self, *args, **kwargs):
    try:
        return cls_call(self, *args, **kwargs)
    except Exception as e:
        ...
cls.__call__ = wrapped_call
```
- Recompilation of the inner GraphModule happens on initialization, before creation or tracing of the outer module. Adding some old-fashioned print debug statements gives:
```
Inner Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
recompile: cls.__call__ now wraps _orig_module_call, <function Module._call_impl at 0x7faaebfee8b0>

Outer Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
tracing: patching method <class 'torch.nn.modules.module.Module'>.__call__ <function Module._call_impl at 0x7faaebfee8b0> with <function Module._call_impl at 0x7fa9d42bce50>

outer module MRO before tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>

outer module MRO during tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>

inner module MRO before tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>

inner module MRO during tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
```

- The outer module is patched correctly, but the inner module's first element in its MRO is the `wrapped_call` from `recompile` that still invokes `<function Module._call_impl at 0x7faaebfee8b0>` directly. Therefore, no call_module nodes are created.

## In Practice

In practice, this behavior affects the ability of `torch.package` to package `GraphModules` whose submodules are `GraphModules`. In our case, the `GraphModule` submodules are not passed through a constructor, but created separately and installed on the root `GraphModule` via `setattr`. This means that prior to packaging, there appear to be no issues with the module, since the root's graph was created before any call_module targets were replaced with `GraphModules`.

When unpackaging such a model with `torch.package`, `torch.fx.graph_module._deserialize_graph_module` uses an inline `KeepModules` tracer that sets all submodules to leaves; the unpackaged module is implicitly and surprisingly inlined in the process.

## Potential Solution

This behavior was previously not understood by us, and so the current workaround is a gnarly process of wrapping all submodules with an `nn.Module` that has a manually installed forward method.

Changing `wrapped_call` to `return super(type(self), self).__call__(*args, **kwargs)` whenever `__call__` is inherited at least appears to solve the issue. Does this seem like an acceptable approach?
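
A minimal sketch of that idea (names and structure are assumptions, not the merged diff): dispatch to the captured `__call__` only when the class defines its own, and otherwise defer to the live MRO with `super(cls, self)` so a tracer's patched `Module.__call__` takes effect (the `super(type(self), self)` form would hit the recursion later fixed in #73866, also in this log):
```
def install_wrapped_call(cls):
    # Capture the class's own __call__ only if it defines one; None means
    # __call__ is inherited (e.g. from nn.Module).
    cls_call = cls.__call__ if "__call__" in vars(cls) else None

    def wrapped_call(self, *args, **kwargs):
        if cls_call is not None:
            return cls_call(self, *args, **kwargs)
        # Defer to the current MRO so a patched Module.__call__ is actually used.
        return super(cls, self).__call__(*args, **kwargs)

    cls.__call__ = wrapped_call
```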

## Other Thoughts
- Repeated calls to `recompile` create nested `wrapped_calls`, all for the purpose of error handling. This seems probably unnecessary ¯\\_(ツ)\_/¯
- If a root module with an overridden `__call__` method is symbolically traced, the override is ignored

Test Plan:
```
buck test:
    ✓ ListingSuccess: caffe2/test:fx - main (12.570)
    ✓ Pass: caffe2/test:fx - test_tracing_graphmodules_as_leaf_submodules (test_fx.TestFX) (11.982)
```

Reviewed By: ansley

Differential Revision: D29997935

fbshipit-source-id: 1988fbb025b14188da26a3e73e94fb789c3c1f74
2021-08-02 13:37:08 -07:00
Harut Movsisyan
810e19979d Torch deploy for fx.graph_module with non-torch dependencies (#61680)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61680

This diff enables torch deploy for fx.graph_module with non-torch dependencies. Here are the issues that currently prevent this, which are fixed in this change:
- Pickle is used as an internal format to transmit objects between interpreters. It needs to serialize Python code, but to get the source code for imports from python_code.globals it needs access to the PackageImporter. Currently a regular `__reduce__` function is used, which has no notion of a custom importer.
- When deserializing pickled objects on an interpreter, empty globals are passed to exec, so it cannot resolve non-torch imports located in the package. We need to be able to point exec to our custom PackageImporter.
- Subclasses extending fx.graph_module should be able to optionally provide their own Tracer (extending fx.Tracer).

As a solution, a new reducer (`__reduce_deploy__`) is introduced for the torch deploy workflow. The reducer is registered in _deploy.py (the entry point for the C++ torch deploy API) when saving the object and transmitting it between interpreters. It allows us to pass a proper PackageImporter for each interpreter when pickling/unpickling fx.graph_module. It also defines an API for passing a custom fx.Tracer when needed.
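
For context, a generic illustration of per-pickler reducer registration with plain `pickle` (the deploy-specific PackageImporter wiring described above is not shown):
```
import io
import pickle

class Payload:
    def __init__(self, data):
        self.data = data

def _reduce_payload(obj):
    # A reducer returns (callable, args); a deploy-aware reducer could instead
    # route reconstruction through a package-aware importer.
    return (Payload, (obj.data,))

buf = io.BytesIO()
pickler = pickle.Pickler(buf)
pickler.dispatch_table = {Payload: _reduce_payload}   # register the custom reducer
pickler.dump(Payload("x"))
print(pickle.loads(buf.getvalue()).data)              # "x"
```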

Test Plan:
Added UT to cover changes.
```
buck test //caffe2/torch/csrc/deploy:test_deploy
```
```
buck test caffe2/test:fx
```

Reviewed By: suo

Differential Revision: D29690088

fbshipit-source-id: 3a8dbe02d5d7e085534aa61b7773c86f0f8c19b0
2021-07-21 10:29:48 -07:00
Nikita Shulga
4e94e84f65 Type annotate torch.nn.Module ctor (#61334)
Summary:
- Annotate generic types
- Fix some type violations
- Override `_modules` and `_parameters` in `Sequential`, `ModuleList`, `ModuleDict`, etc.

Fixes https://github.com/pytorch/pytorch/issues/45497

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61334

Reviewed By: albanD

Differential Revision: D29579533

Pulled By: malfet

fbshipit-source-id: 5cd8ca918b260ca35cfdd873dee8851d39d17de2
2021-07-16 13:59:06 -07:00
Ansley Ussery
5268b5a29a Add parsing logic for Tuple[()] annotation (#58340)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58340
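
For context, the annotation in question (a tiny illustration, not the parsing code itself):
```
from typing import Tuple

def returns_nothing() -> Tuple[()]:   # the empty-tuple annotation FX now parses
    return ()
```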

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D28459502

Pulled By: ansley

fbshipit-source-id: 4bb188448d66269b42b068858b895debac86e9ee
2021-05-25 12:12:43 -07:00
James Reed
7b73fdf597 [FX] Fix retracing wrapped functions (#58061)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58061

Test Plan: Imported from OSS

Reviewed By: yuhc

Differential Revision: D28358801

Pulled By: jamesr66a

fbshipit-source-id: c7c9a8a80e5bfe1eb1f6d2cf858ac7e57153a860
2021-05-17 19:50:16 -07:00
Horace He
8d363d37da [FX] Adds PyTree support to FX through concrete_args (#55888)
Summary:
```
class Foo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        for k in x:
            for v in x[k]:
                v += y
        return x

example_dict = {'x': {'a': [fx.HOLE], 'z': [fx.HOLE, fx.HOLE]}}
new_f = fx.symbolic_trace(Foo(), concrete_args=example_dict)
print(new_f.code)
new_f(torch.randn(5), {'x': {'a': [torch.randn(5)], 'z': [torch.randn(5), torch.randn(5)]}})

fx.symbolic_trace(new_f, concrete_args=example_dict)
```

prints out
```
def forward(self, y, x):
    y, tree_2, tree_3, tree_4 = pytree.tree_flatten([y, x])[0]
    add = tree_2 + y
    add_1 = tree_3 + y
    add_2 = tree_4 + y;  y = None
    return {'a': [tree_2], 'z': [tree_3, tree_4]}
```

Currently, I store `in_spec` as an extra attribute on `fx.Graph`, and then include it when we do the codegen. I'm not sure if this is the right approach - it introduces a divergence between what's in `fx.Graph` and what's in the python code.

Perhaps the best API is something explicit like `fx.Graph.flatten_args`, but that does make calling things a bit ... more verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55888

Reviewed By: jamesr66a

Differential Revision: D27884694

Pulled By: Chillee

fbshipit-source-id: f9e8a70c63a8df63c9f9bd0a6459255daa5a8df8
2021-05-07 04:48:35 -07:00
Sam Estep
75024e228c Add lint for unqualified type: ignore (#56290)
Summary:
The other half of https://github.com/pytorch/pytorch/issues/56272.
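
For context, what the lint distinguishes (assumed semantics, illustrated on made-up lines):
```
from typing import Any

def load() -> Any: ...

x: int = load()  # type: ignore              # unqualified -- the new lint flags this
y: int = load()  # type: ignore[assignment]  # qualified with an error code -- allowed
```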

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56290

Test Plan:
CI should pass on the tip of this PR, and we know that the lint works because the following CI runs (before this PR was finished) failed:

- https://github.com/pytorch/pytorch/runs/2384511062
- https://github.com/pytorch/pytorch/actions/runs/765036024

Reviewed By: seemethere

Differential Revision: D27867219

Pulled By: samestep

fbshipit-source-id: e648f07b6822867e70833e23ddafe7fb7eaca235
2021-04-21 08:07:23 -07:00
Nikita Shulga
add49e7e4e Enforce PEP263 for PyTorch python codebase (#55346)
Summary:
All Python files containing non-ASCII characters should be correctly annotated with a `# -*- coding: utf-8 -*-` comment.

Delete a number of superfluous UTF-8 characters, most commonly the UTF-8 right single quotation mark U+2019 (’) used instead of the ASCII apostrophe ', for example `Module’s` -> `Module's`
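
A tiny illustration of both points (the file contents are made up):
```
# -*- coding: utf-8 -*-
# PEP 263 coding declaration: must appear on the first or second line of the file.

doc = "The Module's forward method"   # ASCII apostrophe, not U+2019
```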

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55346

Reviewed By: samestep

Differential Revision: D27582044

Pulled By: malfet

fbshipit-source-id: c1cd89655915858ff3a41f675cdfffff795a8e44
2021-04-06 18:31:38 -07:00
Shiyan Deng
15b087cdd2 [fx]Allow rewrite a symbolic traced module (#54011)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54011

After symbolic tracing, `fn` seems to already have "forward" in its globals. In this case, `new_keys` would have a length of 0, and we take "forward" from `global_dict` directly as `fn_compiled`.

Test Plan: Added a new test in test_fx_experimental.

Reviewed By: ansley

Differential Revision: D27049012

fbshipit-source-id: 7fbeb50ebb717900ff5fc0a8a0925d6a97f5a6dd
2021-04-05 18:35:51 -07:00
James Reed
05a03a6c8c [FX][EZ] Fix type correctness on GraphModule.graph (#54305)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54305

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D27181176

Pulled By: jamesr66a

fbshipit-source-id: ed91cfed193984249c07a5bafc7aa732bfe0194d
2021-03-19 11:48:15 -07:00
Jordan Fix
1053c96693 [GraphModule] Back out changes to module root version of __init__ (#53791)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53791

Reviewed By: houseroad

Differential Revision: D26970869

fbshipit-source-id: 80684516f57fd2d1aca794f17fe488b2fe2b2f64
2021-03-10 23:18:56 -08:00
Jordan Fix
3b0e4a6ed4 [GraphModule] Improve buffer registration during init (#53444)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53444

GraphModule construction has two options when constructing the base nn.Module: a dict of names to attrs to assign to the GraphModule, or another nn.Module to copy attrs from.

- For the dict case, add logic to explicitly register `torch.Tensor`s that are not `nn.Parameter`s as buffers on the GraphModule, else fall back to `__setattr__` (see the sketch after this list).
- For the other `nn.Module` case, check in the source module whether the attr to copy in is a buffer, and register it as such, else fall back to `__setattr__`.
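
A hedged sketch of the dict-based construction path described in the first bullet (graph and attribute names are made up):
```
import torch
import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder("x")
scale = graph.get_attr("scale")                       # references the "scale" attribute
graph.output(graph.call_function(torch.mul, (x, scale)))

# Dict-of-attrs constructor: a plain tensor target is registered as a buffer
# rather than just set via __setattr__.
gm = fx.GraphModule({"scale": torch.ones(3)}, graph)
print(dict(gm.named_buffers()))                       # expect {'scale': tensor([1., 1., 1.])}
```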

Test Plan: Added tests for fetching params and buffers from a GraphModule using both dict and module `__init__`s

Reviewed By: jamesr66a

Differential Revision: D26860055

fbshipit-source-id: 8d9999f91fef20aaa10969558006fc356247591f
2021-03-09 21:05:01 -08:00
Sam Estep
8c798e0622 Forbid trailing whitespace (#53406)
Summary:
Context: https://github.com/pytorch/pytorch/pull/53299#discussion_r587882857

These are the only hand-written parts of this diff:
- the addition to `.github/workflows/lint.yml`
- the file endings changed in these four files (to appease FB-internal land-blocking lints):
  - `GLOSSARY.md`
  - `aten/src/ATen/core/op_registration/README.md`
  - `scripts/README.md`
  - `torch/csrc/jit/codegen/fuser/README.md`

The rest was generated by running this command (on macOS):
```
git grep -I -l ' $' -- . ':(exclude)**/contrib/**' ':(exclude)third_party' | xargs gsed -i 's/ *$//'
```

I looked over the auto-generated changes and didn't see anything that looked problematic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53406

Test Plan:
This run (after adding the lint but before removing existing trailing spaces) failed:
- https://github.com/pytorch/pytorch/runs/2043032377

This run (on the tip of this PR) succeeded:
- https://github.com/pytorch/pytorch/runs/2043296348

Reviewed By: walterddr, seemethere

Differential Revision: D26856620

Pulled By: samestep

fbshipit-source-id: 3f0de7f7c2e4b0f1c089eac9b5085a58dd7e0d97
2021-03-05 17:22:55 -08:00
Ansley Ussery
85109ce427 Support submodule manipulation in GraphModule (#52358)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52358

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D26759260

Pulled By: ansley

fbshipit-source-id: 25d2b9124a7d957704f1700a45dca143aaed391d
2021-03-04 14:52:35 -08:00
Michael Suo
958d9a8364 [fx/package] make GraphModules packageable (#51976)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51976

FX serializes things by serializing Python code as a string and exec'ing
it on load. This accomplishes one goal (we don't have to pickle the
graph object directly) but breaks the pickle abstraction in ways that
are not composable with `torch.package`.

In particular:
1. `forward` is serialized by saving Python code. On load, it's installed by `exec`ing that code. This `exec` call needs to have the right importer installed, otherwise it will not import modules from the `torch.package` but instead import from the Python environment.
2. Any types/functions used are emitted as `import` statements in the generated Python code. These are effectively dynamic dependencies of the `GraphModule` being saved, and need to be registered as such so that the `PackageImporter` will package them.

To address these, this PR introduces a new protocol for the
importer/exporter: `__reduce_package__`.

A class can implement `__reduce_package__` to customize how it is placed in the importer/exporter. It functions very similarly to `__reduce__`, except:
- `__reduce_package__` takes one argument, which is the `PackageExporter` instance. Users can use this instance to save stuff to the package to implement their serialization. `__reduce__` takes no args.
- Only the 2-element tuple version of the return value for `__reduce__` is supported (this could be extended if necessary).
- When the reduction function is called on load, an additional argument is added to the beginning of the args tuple. This is the `PackageImporter` instance doing the loading.

The `__reduce_package__` protocol is defined using `persistent_id` and
`persistent_load`, which ensures that we can still use the cpickle
implementation of the pickler by default.
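
A hedged sketch of what implementing the protocol could look like for a user class (the method name and call convention follow the description above; `save_text`/`load_text` on the exporter/importer are assumptions about `torch.package`):
```
class MyArtifact:
    def __init__(self, payload: str):
        self.payload = payload

    def __reduce_package__(self, exporter):
        # Use the PackageExporter to store whatever we need in the package,
        # then return the 2-tuple (callable, args) form, as with __reduce__.
        exporter.save_text("artifacts", "payload.txt", self.payload)
        return (_load_my_artifact, ("artifacts", "payload.txt"))

def _load_my_artifact(importer, package, resource):
    # On load, the PackageImporter is prepended to the args tuple.
    return MyArtifact(importer.load_text(package, resource))
```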

Pull Request resolved: #51971

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D26340591

Pulled By: suo

fbshipit-source-id: 5872a7d22e832056399a7372bae8a57807717882
2021-02-23 22:43:00 -08:00
Michael Suo
ecf3ca00d8 [fx] Separate globals assignment from code generation (#51974)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51974

Right now, when an FX `Graph` references an external object, we will emit
code like:

    import foo
    def forward(input: foo.bar.baz):
        ...

This is problematic in a world with `torch.package`, since the name
`foo.bar.baz` may reference a name from any number of packages.

This PR lays the groundwork for FX-package integration by separating the
resolution of external references from the generation of the function
code.

When generating a Graph's Python source, we keep track of all external
references and assign them unique names. At the end, we have a
dictionary mapping names -> actual objects. This becomes the `globals`
namespace we pass to `exec` when installing the forward function in a
`GraphModule`. This is nice because we can always be sure that `exec` is
seeing the same objects that were referenced from the `Graph`, no import
statements needed.
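
A simplified illustration of that mechanism (not FX's actual codegen):
```
class ExternalType:                                    # stands in for an object referenced by the Graph
    pass

namespace = {"foo_bar_baz": ExternalType}              # unique name -> actual object
src = "def forward(input: foo_bar_baz):\n    return input\n"
exec(compile(src, "<fx-generated>", "exec"), namespace)
forward = namespace["forward"]
print(forward(3))                                      # 3; the annotation resolved with no imports
```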

At serialization time, we use a `ModuleEnv` to resolve the globals dict
to a set of import statements that can be run to reproduce the `globals`
namespace. This is only used on serialization/deserialization, and those
functions are expected to check that the import statements are producing
the correct results.

Concretely, the code above will now look like:

    from foo.bar import baz as foo_bar_baz
    def forward(input: foo_bar_baz):
        ...

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D26340593

Pulled By: suo

fbshipit-source-id: fe247f75205d0a03fd067bdd0f95491e8edf1436
2021-02-23 13:48:03 -08:00