Commit Graph

915 Commits

Author SHA1 Message Date
PyTorch MergeBot
380ccfd442 Revert "Added round_with_scale_factor arg to ATen (#97868)"
This reverts commit aa99c5b4ed.

Reverted https://github.com/pytorch/pytorch/pull/97868 on behalf of https://github.com/osalpekar due to Caused breakages in the glow compiler - see [D45374622](https://www.internalfb.com/diff/D45374622) for more details
2023-04-28 20:47:00 +00:00
vfdev-5
aa99c5b4ed Added round_with_scale_factor arg to ATen (#97868)
Addresses #62396 following the strategy described in https://github.com/pytorch/pytorch/pull/64983#issuecomment-1026177629.

The output size is fixed to match OpenCV, scikit-image, and SciPy when a scale factor is specified; this is done on the ATen side only, due to JIT forward-compatibility (FC) constraints.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97868
Approved by: https://github.com/lezcano, https://github.com/mikaylagawarecki
2023-04-26 18:48:37 +00:00
Vivek Khandelwal
bb4998b531 Add shape function for aten::cross_entropy_loss (#97875)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97875
Approved by: https://github.com/davidberard98
2023-04-12 22:11:56 +00:00
Vivek Khandelwal
5810f5ad1a Fix aten::squeeze.dims shape function (#98078)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

Fixes https://github.com/llvm/torch-mlir/issues/1690#issuecomment-1491931180.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98078
Approved by: https://github.com/davidberard98
2023-03-31 20:24:09 +00:00
Han Qi (qihqi)
b895a0a675 [BE] Move flatbuffer related python C bindings to script_init (#97476)
Summary:
A separate C binding module for flatbuffer was introduced because
not all dependents of PyTorch want (or can) bundle in flatbuffer.

However, flatbuffer is now included by default, so this separate binding is no longer needed.

Test Plan: existing unit tests

Differential Revision: D44352583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97476
Approved by: https://github.com/dbort
2023-03-28 17:56:32 +00:00
Vivek Khandelwal
428540001d Add shape function for squeeze.dims op (#93919)
Changes to `_native_batch_norm_legit` and `upsample_nearest2d` in `serialized_shape_function_registry.cpp` are made only because this file is auto-generated and had not been regenerated after the changes in `_shape_functions.py` for those two ops.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93919
Approved by: https://github.com/davidberard98
2023-03-28 14:55:00 +00:00
David Berard
a133b5081c [JIT] Partially support ForwardRef type annotations for NamedTuple attributes (#96933)
**Summary** NamedTuple attributes can be annotated to declare their type:
```python
class MyNamedTuple(NamedTuple):
    x: int
    y: torch.Tensor
    z: MyOtherType
```
Normally in Python you can also declare your types as strings, e.g. `x: 'int'`. But NamedTuples previously didn't support this, because their annotation-evaluation process was slightly different. This PR updates the NamedTuple attribute type-annotation evaluation method to support ForwardRef declarations (i.e. declaring types as strings).
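
For illustration, a minimal sketch of the pattern this PR enables (hypothetical `Point` type; assumes the fix is in place):
```python
from typing import NamedTuple, Optional

import torch

class Point(NamedTuple):
    x: "torch.Tensor"            # string annotations become ForwardRefs
    y: "Optional[torch.Tensor]"

@torch.jit.script
def make_point(t: torch.Tensor) -> Point:
    return Point(t, None)

print(make_point(torch.ones(2)))
```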

**Details**

Below I repeat the comment I left in _jit_internal.py:

NamedTuple types are slightly different from normal types.

Normally, annotations are evaluated like this (during jit.script):
1. Load strings of python code into c++ and parse.
2. Get annotations as strings
3. Use the PythonResolver's resolution callback (rcb) to convert the string into a python object
4. We call into annotations.py:ann_to_type to convert python obj from step 3 into a type that torchscript understands.

NamedTuples are more complicated, because they have sub-types. Normally, once we have the NamedTuple type object from step 3, we can just look at the annotation literal values and use ann_to_type directly on them.

But sometimes, users will annotate with string literals, e.g.
```
   x: 'int'
```
This also happens with PEP 563 (`from __future__ import annotations`)

These annotations appear in the annotation dict as ForwardRef('int').

Then, we need to convert the string into a python object. This requires having local context for custom objects or imported types. rcb() is what gives us this. So, we plumb rcb through the stack so it can be used in this context for the if block below.

FAQ:
- Why do we need this special handling for NamedTuple but string annotations work fine for normal types? Normally, we parse the string directly and then call rcb() directly from C++.
- Why not use ForwardRef._evaluate? For that, we need globals() and locals() for the local context where the NamedTuple was defined. rcb is what lets us look up into these. So, basically rcb does the hard work for us.
- What is rcb? rcb is a ResolutionCallback - python callable that takes a string and returns a type. It's generated by `createResolutionCallback.*` in _jit_internal.py.

**Why is this only partial support**:

This only plumbs the rcb through some paths. In particular, the `toSugaredValue` path uses a fake rcb.

**Alternatives**:

We could also treat this the way we treat non-nn.Module classes: we evaluate them separately, ahead of time. That solution is probably better, but probably requires a more risky refactor for the way NamedTuples are handled.

Fixes #95858

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96933
Approved by: https://github.com/qihqi
2023-03-22 15:20:38 +00:00
Fabian Schuetze
a7a09adb86 Add location information for assertions in torch.jit.annotations.try_ann_to_type (#96423)
There are two assertions in `torch.jit.annotations.try_ann_to_type` that could benefit from adding source level location information.

For example, the current assertion:
```
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
```
reports:
```
AssertionError: Unsupported annotation typing.Union[typing.Dict, NoneType] could not be resolved because typing.Dict could not be resolved at
```
I find it beneficial to know from which line of code this assertion was triggered. Adding the location information then reports:
```
AssertionError: Unsupported annotation typing.Union[typing.Dict, NoneType] could not be resolved because typing.Dict could not be resolved at
  File "/home/schuetze/Documents/work/github/prediction_net/multimodal/models/heads/retina_head.py", line 189
    def forward(self, fpn_features: t.Dict[str, torch.Tensor],
                inputs: t.Dict[str, torch.Tensor],
                gts: t.Optional[t.Dict] = None) -> t.Dict[str, t.Any]:
                     ~~~~~~~~~~~~~~~~~~ <--- HERE
        """
        """
```

Adding this location information is related to #96420, but the changes in this PR can be made without any API changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96423
Approved by: https://github.com/davidberard98
2023-03-11 21:49:13 +00:00
Will Constable
2f6a371ae9 Revert "Optimize nn.Module __call__ fast path for dynamo (#95931)" (#96242)
Reverting due to concerns over silent unsoundness (skipped hooks) if users have directly added hooks dicts without using official torch APIs.

This reverts commit 26045336ca.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96242
Approved by: https://github.com/albanD
2023-03-10 01:05:01 +00:00
PyTorch MergeBot
9137f53ec2 Revert "Error when jit.trace/script is used with torch.compile (#91681)"
This reverts commit fa92b6a7b0.

Reverted https://github.com/pytorch/pytorch/pull/91681 on behalf of https://github.com/izaitsevfb due to Breaks internal tests, see T147501786
2023-03-08 18:47:38 +00:00
lijiahao
3d5eba811a Add shape function for stack op (#92205)
As @ramiro050 requested in https://github.com/llvm/torch-mlir/pull/1747, this PR moves the shape function for the stack op from torch-mlir to PyTorch upstream.
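
For context, TorchScript shape functions compute output shapes from input shapes represented as `List[int]`. A hand-written sketch in that style, illustrative rather than the exact upstream code:
```python
from typing import List

def stack_shape(tensors: List[List[int]], dim: int) -> List[int]:
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    base = tensors[0]
    for shape in tensors:
        assert shape == base, "all input tensors must have the same shape"
    # the output gains one extra dimension of size len(tensors) at `dim`
    if dim < 0:
        dim += len(base) + 1
    return base[:dim] + [len(tensors)] + base[dim:]

print(stack_shape([[2, 3], [2, 3]], dim=0))   # [2, 2, 3]
print(stack_shape([[2, 3], [2, 3]], dim=-1))  # [2, 3, 2]
```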

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92205
Approved by: https://github.com/eellison
2023-03-07 20:45:56 +00:00
Mark Saroufim
fa92b6a7b0 Error when jit.trace/script is used with torch.compile (#91681)
Fixes https://github.com/pytorch/pytorch/issues/93485

```python
import torch
from torchvision.models import resnet50

model = resnet50(weights=None)
compile_model = torch.compile(model)
print(type(compile_model))
example_forward_input = torch.rand(1, 3, 224, 224)
c_model_traced = torch.jit.trace(compile_model, example_forward_input) # or torch.jit.script
torch.jit.save(c_model_traced, "c_trace_model.pt")
```

Should I also raise a warning if a user tries to compile a scripted or traced model? It works just fine now on ResNet, but I'm not sure if that's something we want to explicitly discourage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91681
Approved by: https://github.com/desertfire
2023-03-06 02:03:35 +00:00
Will Constable
26045336ca Optimize nn.Module __call__ fast path for dynamo (#95931)
This PR reduces the guard overhead introduced by dynamo tracing of module forward hooks.

It can and maybe should be followed by a wider change proposed by @voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards.  (But this observer change seems more involved...)

Idea: maintain a flag, and keep it up to date whenever adding or
removing hooks. Use the flag rather than dict checks to enter the call fast path (a sketch follows below).
  - need to extend RemovableHandle to keep a ref to the nn.Module so it can update the flag on removal.
  - also need to handle the flag in ScriptModule, which still uses the Python call impl when called from Python.
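
A minimal sketch of the flag idea (illustrative names, not the actual torch internals):
```python
import itertools

class RemovableHandle:
    _ids = itertools.count()

    def __init__(self, hooks_dict, module):
        self.id = next(RemovableHandle._ids)
        self._hooks_dict = hooks_dict
        self._module = module  # back-reference so removal can clear the flag

    def remove(self):
        self._hooks_dict.pop(self.id, None)
        if not self._hooks_dict:
            self._module._has_hooks = False

class Module:
    def __init__(self):
        self._forward_hooks = {}
        self._has_hooks = False  # kept in sync on every add/remove

    def register_forward_hook(self, hook):
        handle = RemovableHandle(self._forward_hooks, self)
        self._forward_hooks[handle.id] = hook
        self._has_hooks = True
        return handle

    def forward(self, x):
        return x

    def __call__(self, x):
        if not self._has_hooks:
            return self.forward(x)  # fast path: one flag check, no dict lookups
        result = self.forward(x)
        for hook in self._forward_hooks.values():
            hook(self, x, result)
        return result

m = Module()
h = m.register_forward_hook(lambda mod, inp, out: print("hook fired"))
m(1)   # slow path: hook fires
h.remove()
m(1)   # fast path again
```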

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95931
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
2023-03-04 15:09:40 +00:00
Xuehai Pan
ef731cdaf0 [2/3] Update .pyi Python stub files: Prettify rnn.py by using type annotated NamedTuple (#95267)
Changes:

- #95200

1. Recognize `.py.in` and `.pyi.in` files as Python in VS Code for a better development experience.
2. Fix deep setting merge in `tools/vscode_settings.py`.

- => this PR: #95267

3. Use `NamedTuple` rather than `namedtuple + __annotations__` for `torch.nn.utils.rnn.PackedSequence_`:

    `namedtuple + __annotations__`:

    ```python
    PackedSequence_ = namedtuple('PackedSequence_',
                                 ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])

    # type annotation for PackedSequence_ to make it compatible with TorchScript
    PackedSequence_.__annotations__ = {'data': torch.Tensor, 'batch_sizes': torch.Tensor,
                                       'sorted_indices': Optional[torch.Tensor],
                                       'unsorted_indices': Optional[torch.Tensor]}
    ```

    `NamedTuple` (Python 3.6+):

    ```python
    class PackedSequence_(NamedTuple):
        data: torch.Tensor
        batch_sizes: torch.Tensor
        sorted_indices: Optional[torch.Tensor]
        unsorted_indices: Optional[torch.Tensor]
    ```

- #95268

4. Sort import statements and remove unnecessary imports in `.pyi`, `.pyi.in` files.
5. Format `.pyi`, `.pyi.in` files and remove unnecessary ellipsis `...` in type stubs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95267
Approved by: https://github.com/janeyx99
2023-03-01 19:37:23 +00:00
Peter Bell
bc438af6fe std/var: support floating point correction value (#94073)
Ref https://github.com/pytorch/pytorch/issues/61492#issuecomment-1413003480

The array API specifies correction to be `Union[int, float]` while we currently only support integers.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html

As std/var is calculated currently, the final count of elements is already done
in floating point so we can make the correction floating point without any loss
of precision or generality.
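
A quick illustration of what this enables; the variance divisor is `N - correction`, so a float correction generalizes Bessel's correction:
```python
import torch

x = torch.randn(100)
print(torch.var(x, correction=1))    # Bessel-corrected sample variance (default)
print(torch.var(x, correction=0))    # population variance
print(torch.var(x, correction=0.5))  # fractional correction, enabled by this PR
```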

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94073
Approved by: https://github.com/ezyang
2023-02-23 05:50:45 +00:00
Jason Ansel
ae57bd6630 PT2/TorchScript interoperability fix (#94678)
Allows torch.compile() to inline into ScriptFunction
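
A minimal sketch of the interoperability being fixed (`scripted_add` is an illustrative function, not from the PR):
```python
import torch

@torch.jit.script
def scripted_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

@torch.compile
def fn(x, y):
    # with this fix, dynamo can inline the ScriptFunction instead of
    # falling back on it
    return scripted_add(x, y) * 2

print(fn(torch.randn(2), torch.randn(2)))
```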

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94678
Approved by: https://github.com/ezyang
2023-02-15 01:21:10 +00:00
Xuehai Pan
b005ec62b9 [BE] Remove dependency on six and future (#94709)
Remove the Python 2/3 compatibility libraries [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future), along with `torch._six`. We only support Python 3.8+ now; it's time to retire them.
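
A representative rewrite from this kind of cleanup (hedged: `string_classes` is shown as one typical `torch._six` alias):
```python
# before: Python 2/3 compatibility shim
# from torch._six import string_classes
# isinstance(x, string_classes)

# after: plain Python 3
x = "hello"
print(isinstance(x, str))
```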

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-02-14 09:14:14 +00:00
Xuehai Pan
046e88a291 [BE] [3/3] Rewrite super() calls in test (#94592)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.
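
The typical non-semantic rewrite replaces the legacy two-argument form with the zero-argument form:
```python
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        # before: super(MyModule, self).__init__()
        super().__init__()  # after: equivalent zero-argument form
```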

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Cases where the rewrite would change semantics are kept unchanged, e.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94592
Approved by: https://github.com/ezyang, https://github.com/seemethere
2023-02-12 22:20:53 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the set call.
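
Representative rewrites of this kind:
```python
items = ["a", "b", "a"]

s1 = set(a for a in items)        # before: generator expression passed to set()
s2 = {a for a in items}           # after: set comprehension

d1 = dict((k, 1) for k in items)  # before
d2 = {k: 1 for k in items}        # after
assert s1 == s2 and d1 == d2
```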

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Cases where the rewrite would change semantics are kept unchanged, e.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Aaron Gokaslan
9171f7d4cd [BE] Modernize PyTorch even more for 3.8 with pyupgrade (#94520)
Applies some more pyupgrade fixits to PyTorch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94520
Approved by: https://github.com/ezyang
2023-02-10 18:02:50 +00:00
Maxwell Nuyens
0d0ebcdfe5 feature: adding the ability to restore shapes after loading a traced model (#90744)
Adds the ability to store the inputs used when tracing a model as part of torch.jit.save, and to restore the input shapes via torch.jit.load if the appropriate variables are set.

Fixes [89185](https://github.com/pytorch/pytorch/issues/89185)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90744
Approved by: https://github.com/davidberard98
2023-02-10 17:12:52 +00:00
Aaron Gokaslan
1e2d82b8e4 [BE] Merge isinstance calls together (#94419)
Simplifies and speeds up `isinstance` calls by checking for multiple types at the same time.
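
For example:
```python
x = 1.0

# before: two separate isinstance calls
if isinstance(x, int) or isinstance(x, float):
    pass

# after: one call checking a tuple of types
if isinstance(x, (int, float)):
    pass
```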

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94419
Approved by: https://github.com/ezyang
2023-02-09 00:47:26 +00:00
Aaron Gokaslan
3ce1ebb6fb Apply some safe comprehension optimizations (#94323)
Optimizes unnecessary collection casts, unnecessary calls to `list`, `tuple`, and `dict`, and simplifies calls to the `sorted` builtin. This should strictly improve speed and readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94323
Approved by: https://github.com/albanD
2023-02-07 23:53:46 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
albanD
496c0a207b Make segment_reduce properly private. (#93166)
I am deliberately not changing the ATen function, to reduce the amount of BC issues on the TorchScript side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93166
Approved by: https://github.com/ngimel
2023-02-06 18:32:23 +00:00
Ivan Kobzarev
6a2838eec5 [jit] jit._drop fun modifier to allow in jit class non-jit decl funs (#93012)
`@torch.jit.unused` and `@torch.jit.ignore` do not allow keeping a member function with a non-scriptable declaration (e.g. return type) in a TorchScripted class.

This adds a FunctionModifier `_DROP` that fully skips such functions during scripting while keeping them in the code of the scripted class.

E.g. it can be used for:

```
@torch.jit._drop
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
    # torch.fx classes are not scriptable
    return tracer.create_node(
        "call_function",
        CFX,
        args=(tracer.create_arg(self.features),),
        kwargs={},
    )

def __iter__(self) -> Iterator[torch.Tensor]:
    return iter(self.a)
```

Testing:
Added a test case in `test/jit/test_types.py` with non-scriptable type annotations (fx.* classes) that fails before the fix and passes after it.

```
python test/test_jit.py
```

Differential Revision: [D42774830](https://our.internmc.facebook.com/intern/diff/D42774830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93012
Approved by: https://github.com/davidberard98
2023-02-01 09:02:05 +00:00
Ivan Kobzarev
2fc73622f8 [jit] Support Awaitable type (#90863)
We want to make TorchRec sharded models TorchScriptable.

TorchRec sharded models use the generic types Awaitable[W] and LazyAwaitable[W] (https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L212).
In a sharded model those types are used instead of the contained type W, with an initialization function that produces an object of type W.

The moment the first attribute of W is requested, `LazyAwaitable[W]` calls its initialization function (on the same stack), caches the result inside, and from then on works transparently as an object of W. So we can think of it as delayed object initialization.

To support this behavior in TorchScript, we propose a new TorchScript type: `Await`.
In eager mode it works the same as `LazyAwaitable[W]` in TorchRec, being dynamically typed: it acts as type `W` while it is an `Await[W]`.

Within TorchScript it is `Await[W]` and can only be explicitly converted to `W` using the special function `torch.jit._awaitable_wait(aw)`.
Creation of this `Await[W]` is done via another special function, `torch.jit._awaitable(func, *args)`.

The semantics are close to `torch.jit.Future` with fork/wait, and it uses the same JIT mechanics (inlined fork closures), with the difference that the function is not started in parallel on fork. It is only stored as a lambda inside the IValue and is called on the same thread when `torch.jit._awaitable_wait` is called.

For example (more examples in this PR `test/jit/test_await.py`)
```
      def delayed(z: Tensor) -> Tensor:
          return z * 3

      @torch.jit.script
      def fn(x: Tensor):
          aw: Await[Tensor] = torch.jit._awaitable(delayed, torch.eye(2))
          a = torch.eye(2)
          b = torch.jit._awaitable_wait(aw)
          return a + b + x
```

Function semantics:

`_awaitable(func -> Callable[Tuple[...], W], *args, **kwargs) -> Await[W]`

Creates an Await object and owns args and kwargs. On the first `_awaitable_wait` call, it executes `func` and owns the result. Subsequent `_awaitable_wait` calls return that cached result.

`_awaitable_wait(Await[W]) -> W`
Returns the cached result of `W` if this is not the first `_awaitable_wait` call on this Await object; otherwise calls the stored function first.

`_awaitable_nowait(W) -> Await[W]`

Creates a trivial `Await[W]` wrapper around the specified object, to be type compliant for the corner cases.

Differential Revision: [D42502706](https://our.internmc.facebook.com/intern/diff/D42502706)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90863
Approved by: https://github.com/davidberard98
2023-01-30 17:38:59 +00:00
Nikita Shulga
5976f0bdfe Set min supported Python version to 3.8 (#93155)
Also, greps for `if sys.version_info .cond. (3, 8)` checks and replaces them with the appropriate action.

This is the last in a series of PRs that moved CI/CD away from testing PyTorch behavior against Python 3.7.

Fixes https://github.com/pytorch/pytorch/issues/80513

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93155
Approved by: https://github.com/huydhn
2023-01-29 18:28:46 +00:00
Vivek Khandelwal
f77a9a585c Add shape function for movedim op (#91696)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91696
Approved by: https://github.com/davidberard98
2023-01-06 18:24:52 +00:00
Sergii Dymchenko
f51f6aa387 Fix non-existing parameters in docstrings (#90505)
Continuation after https://github.com/pytorch/pytorch/pull/90163.

Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in the presence of *args/**kwargs or decorators):

_Edit:_
I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if its first docstring argument is wrong. I'll create a separate PR if I find more issues with the corrected script.

``` python
import ast
import os
import docstring_parser

for root, dirs, files in os.walk('.'):
    for name in files:
        if root.startswith("./.git/") or root.startswith("./third_party/"):
            continue
        if name.endswith(".py"):
            full_name = os.path.join(root, name)
            with open(full_name, "r") as source:
                tree = ast.parse(source.read())
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        all_node_args = node.args.args
                        if node.args.vararg is not None:
                            all_node_args.append(node.args.vararg)
                        if node.args.kwarg is not None:
                            all_node_args.append(node.args.kwarg)
                        if node.args.posonlyargs is not None:
                            all_node_args.extend(node.args.posonlyargs)
                        if node.args.kwonlyargs is not None:
                            all_node_args.extend(node.args.kwonlyargs)
                        args = [a.arg for a in all_node_args]
                        docstring = docstring_parser.parse(ast.get_docstring(node))
                        doc_args = [a.arg_name for a in docstring.params]
                        clean_doc_args = []
                        for a in doc_args:
                            clean_a = ""
                            for c in a.split()[0]:
                                if c.isalnum() or c == '_':
                                    clean_a += c
                            if clean_a:
                                clean_doc_args.append(clean_a)
                        doc_args = clean_doc_args
                        for a in doc_args:
                            if a not in args:
                                print(full_name, node.lineno, args, doc_args)
                            break

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
2022-12-09 21:43:09 +00:00
Rohan Varma
9c80f13692 [Resubmit] state_dict_pre_hook (#90435)
Resubmit of https://github.com/pytorch/pytorch/pull/88541 which got stale.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90435
Approved by: https://github.com/fegin
2022-12-08 07:54:14 +00:00
Ram Rachum
351d73b97f Fix exception causes all over the codebase (#90271)
This is the continuation of #90134 and hopefully the final PR in this series.
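
The typical fix in this series chains the original exception explicitly, so the root-cause traceback is preserved:
```python
def parse_int(s: str) -> int:
    try:
        return int(s)
    except ValueError as e:
        # `raise ... from e` keeps the original exception as __cause__, instead
        # of the misleading "During handling ... another exception occurred"
        raise RuntimeError(f"could not parse {s!r}") from e
```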

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980
2022-12-07 04:29:00 +00:00
Jane Xu
8695f0cced Rectify native_batch_norm schema by splitting it into two legit schemas (#88697)
Using the same repro from the issue (but with BatchNorm2D)

Rectifies native_batch_norm schema by splitting the schema into 2:
1. one will have NON-optional alias-able running_mean and running_var inputs
2. the other will just not have those parameters at all (no_stats variation)

**Calling for name suggestions!**

## test plan
I've added tests in test_functionalization.py as well as an entry in common_methods_invocations.py for `native_batch_norm_legit`
CI should pass.

## next steps
Because of bc/fc reasons, we reroute native_batch_norm to call our new schemas ONLY through the python dispatcher, but in 2 weeks or so, we should make `native_batch_norm_legit` the official batch_norm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88697
Approved by: https://github.com/albanD
2022-11-23 23:23:17 +00:00
Shen Li
f5d18574a3 Allow Module forward-pre and forward hooks to take kwargs (#89389)
closes #35643

This PR is mostly borrowed from #82042. Thanks @Padarn for implementing
the first version and debugging into the errors.

Based on the discussion in #82042, this PR adds a `with_kwargs`
argument to the `register_forward_pre_hook` and `register_forward_hook`
methods. When the arg is set to True, the provided hook must accept
kwargs. Under the hood, this PR adds a
`_forward_pre_hooks_with_kwargs` and a `_forward_hooks_with_kwargs`
set to keep track of which hooks accept kwargs.
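
A usage sketch of the new argument:
```python
import torch
import torch.nn as nn

def pre_hook(module, args, kwargs):
    # with with_kwargs=True the hook sees (and may replace) both
    # positional and keyword arguments
    return args, kwargs

m = nn.Linear(2, 2)
handle = m.register_forward_pre_hook(pre_hook, with_kwargs=True)
m(torch.randn(1, 2))
handle.remove()
```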

Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89389
Approved by: https://github.com/soulitzer
2022-11-23 02:43:32 +00:00
Nikita Shulga
767f6aa49f [JIT][Security] Do not blindly eval input string (#89189)
Introduce an `_eval_no_call` method that evaluates a statement only if it
does not contain any calls (determined by examining the bytecode), thus preventing a command-injection exploit.

Added a simple unit test to check that
`torch.jit.annotations.get_signature` would not result in calling arbitrary
code.
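
A sketch of the approach, close in spirit to the description above though not necessarily the exact torch code: compile the expression, scan its bytecode for call opcodes, and only then evaluate it.
```python
import dis

def _eval_no_call(stmt: str, glob: dict, loc: dict):
    """Evaluate `stmt` only if its bytecode performs no calls."""
    bytecode = compile(stmt, "<string>", mode="eval")
    for insn in dis.get_instructions(bytecode):
        if "CALL" in insn.opname:
            raise RuntimeError(f"expression must not contain calls: {stmt!r}")
    return eval(bytecode, glob, loc)

print(_eval_no_call("int", globals(), locals()))  # fine: <class 'int'>
# _eval_no_call("__import__('os').system('id')", globals(), locals())  # raises
```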

That said, this code path exists only for Python 2 compatibility, and perhaps
should simply be removed.

Fixes https://github.com/pytorch/pytorch/issues/88868

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89189
Approved by: https://github.com/suo
2022-11-17 22:05:30 +00:00
Kazuaki Ishizaki
1cd6ebe095 Fix typos in messages under torch (#89049)
This PR fixes typos in messages in `.py` files under the `torch` directory.
In `torch/onnx/symbolic_opset16.py` only, it also fixes a typo in a comment to make the operator name correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89049
Approved by: https://github.com/lezcano
2022-11-17 04:18:14 +00:00
Philip Meier
bc73affdad prepare removal of deprecated functionality in torch.testing (#87969)
_Redo of #86586 with all BC breaking changes granularly placed into separate commits._

---

Per title. Deprecation happened on Feb 25, 2022 in c6f1bbc0ac, which made it into the 1.12 release. Since it is now 245 days later and the next release will be 1.14, the removals later in the stack comply with the [BC policy](https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#minimizing-the-disruption-of-bc-breaking-changes).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87969
Approved by: https://github.com/mruberry
2022-11-02 14:04:48 +00:00
Kazuaki Ishizaki
2ddefbdc3c Fix typos used in documents under torch directory (#88300)
This PR fixes typos in comments of Python files that were found via the search box at https://pytorch.org/docs/master/search.html

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88300
Approved by: https://github.com/lezcano
2022-11-02 09:38:13 +00:00
Mike Iovine
aaba0bd306 [JIT] Fix torch.jit.script for functions with many decorators (#87804)
Summary:
Python's function parsing via the `ast` module records the line number of the function definition, not that of the first decorator. So this diff fixes crashes like this:

```
IndexError: vector::_M_range_check: __n (which is 10) >= this->size() (which is 8)
```
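
A quick way to observe the line-number behavior described in the summary (on Python 3.8+, `lineno` points at the `def` line rather than the first decorator):
```python
import ast

src = """\
@deco1
@deco2
def f():
    pass
"""
node = ast.parse(src).body[0]
print(node.lineno)  # 3: the line of `def f():`, not line 1 (@deco1)
```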

Test Plan: New unit test

Differential Revision: D40726352

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87804
Approved by: https://github.com/tugsbayasgalan, https://github.com/davidberard98
2022-10-27 12:29:51 +00:00
tangleintel
7980ed95bd Support unpacking python dictionary in torch.jit.trace() (#81623)
# Support unpacking python dictionary in **torch.jit.trace()**

## Problem statement & Motivation
### Problem 1(usability):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=value1, key2=value2, key3=value3)`**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3, key2:value2}`**

The problem is that if you want to trace the model using the dict data from the given dataset, you need to unpack the dictionary, reorder its values manually, and make up a tuple **`data_tuple = (value1, value2, value3)`** as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling process is not user friendly.

### Problem 2 (feasibility):
Say, if you have a model and its forward method defined as follows:
**`def forward(self, key1=None, key2=None, key3=None)`** -> The default value is **None**
And you have a dataset and each data point in the dataset is a python dict as follows:
**`data = {key1:value1, key3:value3}`** -> Only **part of** the values required by forward are given; the rest use their default values.

The problem is that if you want to trace the model using the dict data from the given dataset, it's not feasible at all, because you can pass neither a tuple like **`T1 = (value1, value3)`** nor **`T2 = (value1, None, value3)`**. T1 will mismatch value3 with key2, and T2 includes a **None**, which will be blocked by the tracer's type checking. (Of course you can pass **`T3 = (value1,)`** to make the trace function finish without an exception, but the traced model you get is probably not what you expect, since different inputs may produce different traced results.)

These problems come up with HuggingFace's PyTorch models, especially in text-classification tasks with datasets such as [MRPC](https://paperswithcode.com/dataset/mrpc), [MNLI](https://paperswithcode.com/dataset/multinli), etc.

## Solution
To address these two issues, we propose to support a new type, namely Python dict, as the example_inputs parameter for torch.jit.trace(). We can use the runtime type information of the example_inputs object to determine whether we fall back to the original tuple path or go into the new dictionary path. Both problem 1 and problem 2 can be solved by utilizing the `**` operator.

## Limitation & Mitigation

1. If we use a dict as example_inputs to trace the model, then we have to pass a dictionary to the traced model too. (This is because we may change the order of the debug names of the input parameters in the TorchScript IR, so we can't assume the traced model's input parameter order is the same as the original model's.) We need to highlight this in the documentation to mitigate the problem.

    For example:
```
# fetch a data from dataloader, and the data is a dictionary
# and the example_inputs_dict is like: {key1:value1, key3:value3, key2:value2}
# the forward() is like: def forward(self, key1=value1, key2=value2, key3=value3)
example_inputs_dict = next(iter(dataloader))
jit_model = model.eval()
# use the dictionary to trace the model
jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False)  # Now the IR will be graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key3 : type3, %key2 : type2)
jit_model = torch.jit.freeze(jit_model)

# It's OK to use dict as the parameter for traced model
jit_model(**example_inputs_dict)

example_inputs_tuple = (value1, value3, value2)
# It's wrong to rely on the original args order.
jit_model(*example_inputs_tuple)

```
## Note
1. This PR will make some UTs introduced in [39601](https://github.com/pytorch/pytorch/pull/39601) fail; in our solution these cases should be classified as unpacking a tuple containing a single dictionary element.
2. I think there is some ambiguity: **torch.jit.trace()**'s documentation currently only specifies passing a tuple or a single Tensor as the example_inputs parameter, but it seems a dictionary can still be passed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81623
Approved by: https://github.com/davidberard98
2022-10-15 05:33:09 +00:00
Kshiteej K
54ee95c8ec [nn] module: full_backward_pre_hook (#86700)
Fixes https://github.com/pytorch/pytorch/issues/42824

* [x] Test
* [x] Doc
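
A usage sketch of the new hook (assuming the API shape described in the linked issue):
```python
import torch
import torch.nn as nn

def bwd_pre_hook(module, grad_output):
    # called before this module's backward runs; may return a new grad_output
    return grad_output

m = nn.Linear(3, 3)
m.register_full_backward_pre_hook(bwd_pre_hook)
out = m(torch.randn(1, 3, requires_grad=True))
out.sum().backward()
```
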
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86700
Approved by: https://github.com/soulitzer
2022-10-13 17:36:39 +00:00
Prashant Kumar
7ddf167ba5 Move the asserts in shape functions upsample_nearest_2d op. (#85801)
The assert checks are moved to the top, and the function now returns `out`. This is needed by the downstream torch-mlir project to correctly determine the output type.


Pull Request resolved: https://github.com/pytorch/pytorch/pull/85801
Approved by: https://github.com/eellison
2022-09-30 18:30:06 +00:00
David Berard
424aad7f82 [JIT] support freezing modules that don't have a forward method (#85779)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85779
Approved by: https://github.com/eellison
2022-09-28 17:05:01 +00:00
Shisuiuzumaki
647aeb831f torch/jit/_trace.py in compare_outputs(original, reference, match_wha… (#84850)
Fixes #83533

### Bug:
```
/opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py in _check_trace(check_inputs, func, traced_func, check_tolerance, strict, force_outplace, is_trace_module, _module_class)
    525         traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
    526         fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function")
--> 527         if compare_outputs(traced_outs, fn_outs, "Python function"):
    528             check_outs = run_mod_and_filter_tensor_outputs(
    529                 check_mod_func, inputs, "repeated trace"

/opt/homebrew/lib/python3.9/site-packages/torch/jit/_trace.py in compare_outputs(original, reference, match_what)
    500                     else:
    501                         torch.testing.assert_close(
--> 502                             orig.double(),
    503                             ref.double(),
    504                             rtol=check_tolerance,

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

```

### Fix:
```
if orig.is_mps or ref.is_mps:
    torch.testing.assert_close(
        orig.float(),
        ref.float(),
        rtol=check_tolerance,
        atol=default_tolerances(orig, ref)[1],
        equal_nan=True,
    )
else:
    torch.testing.assert_close(
        orig.double(),
        ref.double(),
        rtol=check_tolerance,
        atol=default_tolerances(orig, ref)[1],
        equal_nan=True,
    )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84850
Approved by: https://github.com/davidberard98
2022-09-16 01:45:20 +00:00
Wenzhe Xue
a2cccb2d6b add oneDNN graph fuser context API and unittest (#82491)
### Description
Add a oneDNN Graph context manager API, for consistency with the other fusers.

NNC and nvFuser can be used in two ways: 1) a function to enable/disable, and 2) a context manager; the latter is used extensively in libraries like Dynamo. Currently the oneDNN Graph fuser only has the former. To promote its usage, this PR adds a context manager for the oneDNN Graph fuser.

This PR should not affect any performance.

### Testing
A unit-test `test_context_manager` is added under `test/test_jit_llga_fuser.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82491
Approved by: https://github.com/malfet
2022-09-12 20:09:00 +00:00
Mikayla Gawarecki
e217b30b0f Add torch.nested namespace (#84102)
First step towards #83775
- only `to_padded_tensor` is moved to the nested namespace for now
- following the schema used for `special`, `fft`, `linalg` and other namespaces, nested functions are registered in native_functions.yaml as `nested_{function_name}` and are bound to the desired Python name in
`torch/nested/__init__.py`, and the desired C++ name in `torch/csrc/api/include/torch/nested.h`.

~~**Question**: should we keep the documentation for `Tensor.to_padded_tensor` or can this deleted since it is shared by `torch.nested.to_padded_tensor`?~~

[generated nested docs](https://docs-preview.pytorch.org/84102/nested.html?highlight=nested#module-torch.nested)
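
A usage sketch of the relocated function (the `nested_tensor` constructor is shown with its later-release spelling, as an assumption):
```python
import torch

nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(4)])
padded = torch.nested.to_padded_tensor(nt, padding=0.0)
print(padded.shape)  # torch.Size([2, 4])
```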

Differential Revision: [D39361148](https://our.internmc.facebook.com/intern/diff/D39361148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84102
Approved by: https://github.com/drisspg
2022-09-12 16:31:05 +00:00
zaf
d32a762147 [quant][ao_migration] torch.nn.quantized.dynamic → torch.ao.nn.quantized.dynamic (#78714)
Context: in order to avoid cluttering the `torch.nn` namespace,
the quantized modules namespace is moved to `torch.ao.nn`.

The list of the `nn.quantized` files that are being migrated:

- [ ] `torch.nn.quantized` → `torch.ao.nn.quantized`
    - [X] `torch.nn.quantized.functional` → `torch.ao.nn.quantized.functional`
    - [X] `torch.nn.quantized.modules` → `torch.ao.nn.quantized.modules`
    - [X] [Current PR] `torch.nn.quantized.dynamic` → `torch.ao.nn.quantized.dynamic`
    - [ ] `torch.nn.quantized._reference` → `torch.ao.nn.quantized._reference`
- [ ] `torch.nn.quantizable` → `torch.ao.nn.quantizable`
- [ ] `torch.nn.qat` → `torch.ao.nn.qat`
    - [ ] `torch.nn.qat.modules` → `torch.ao.nn.qat.modules`
    - [ ] `torch.nn.qat.dynamic` → `torch.ao.nn.qat.dynamic`
- [ ] `torch.nn.intrinsic` → `torch.ao.nn.intrinsic`
    - [ ] `torch.nn.intrinsic.modules` → `torch.ao.nn.intrinsic.modules`
    - [ ] `torch.nn.intrinsic.qat` → `torch.ao.nn.intrinsic.qat`
    - [ ] `torch.nn.intrinsic.quantized` → `torch.ao.nn.intrinsic.quantized`
        - [ ] `torch.nn.intrinsic.quantized.modules` → `torch.ao.nn.intrinsic.quantized.modules`
        - [ ] `torch.nn.intrinsic.quantized.dynamic` → `torch.ao.nn.intrinsic.quantized.dynamic`

The majority of the files are just moved to the new location.
However, specific files need to be double-checked:

- [Documentation](docs/source/quantization-support.rst) @vkuzo
- [Public API test list](test/allowlist_for_publicAPI.json) @peterbell10
- [BC test](test/quantization/bc/test_backward_compatibility.py) @vkuzo
- [IR emitter](torch/csrc/jit/frontend/ir_emitter.cpp) @jamesr66a
- [JIT serialization](torch/csrc/jit/serialization/import_source.cpp) @IvanKobzarev @jamesr66a

Differential Revision: [D36860660](https://our.internmc.facebook.com/intern/diff/D36860660/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36860660/)!

Differential Revision: [D36860660](https://our.internmc.facebook.com/intern/diff/D36860660)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78714
Approved by: https://github.com/jerryzh168
2022-08-25 16:50:34 +00:00
George Petterson
35d4fa444b Fix for transposed convolution shape functions (#83557)
This fixes an issue with #80860 when the in channels and out channels are different.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83557
Approved by: https://github.com/Gamrix
2022-08-22 19:05:41 +00:00
PyTorch MergeBot
b1a7b67529 Revert "[quant][ao_migration] torch.nn.quantized.dynamic → torch.ao.nn.quantized.dynamic (#78714)"
This reverts commit e6fb97d8ae.

Reverted https://github.com/pytorch/pytorch/pull/78714 on behalf of https://github.com/janeyx99 due to sorry, reverting so https://github.com/pytorch/pytorch/pull/78713 could be cleanly reverted
2022-08-22 07:30:48 +00:00