Updated version of #108885 addressing the review. In this PR:
- We add a VT.can_reconstruct utility that checks if VT.reconstruct()
does something.
- If functools.wraps(fn) is passed a `fn` that either has a source or
has .can_reconstruct() == True, then we stash the source (or the VT)
- Later on, we use the source (or VT.reconstruct) to actually
reconstruct the object in codegen.
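As background for the `functools.wraps(fn)` case, here is a minimal plain-Python illustration (no Dynamo involved) of what `functools.wraps` does with `fn`: it copies metadata onto the wrapper and keeps `fn` itself reachable as `__wrapped__`, which is presumably why the traced code needs a way to reconstruct `fn`. The names `log_calls` and `add` are made up for this sketch.

```python
import functools


def log_calls(fn):
    # functools.wraps copies __name__, __doc__, etc. from fn onto wrapper
    # and stores fn itself as wrapper.__wrapped__.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return fn(*args, **kwargs)

    wrapper.calls = 0
    return wrapper


@log_calls
def add(a, b):
    """Add two numbers."""
    return a + b


result = add(1, 2)
```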
Test Plan:
- New tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114279
Approved by: https://github.com/voznesenskym
This should be enough to get @voznesenskym 's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything properly no-op out. Main changes are:
(1) graph break on `aten::set_.source_Tensor_storage_offset` (we could support it but it isn't needed, seems safer to graph break)
(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and it was wrong (it would just clone() and call set_(), which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.
(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations, and metadata mutations due to `set_()` calls. This is mainly because at runtime, one requires calling `as_strided_()` to fix up metadata, while the other requires calling `set_()`.
(4) Made AOTAutograd's detection for metadata mutations / set_() mutations smarter and detect no-ops (if the storage and metadata are all the same).
I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with (existing) `has_data_mutation()` and (new) `has_metadata_mutation()`, which can more accurately distinguish between data mutations, `set_()` calls, and metadata mutations.
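To make the distinction in (3) concrete, here is a toy sketch of picking the runtime metadata fix-up from per-input flags. `ToyInputMutationInfo` and `metadata_fixup` are hypothetical names for illustration only; this is not the real `InputAliasInfo` or AOTAutograd's runtime wrapper.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyInputMutationInfo:
    # Hypothetical stand-in for per-input mutation flags.
    mutates_metadata: bool = False
    mutates_storage_metadata: bool = False  # metadata changed via set_()


def metadata_fixup(info: ToyInputMutationInfo) -> Optional[str]:
    """Return the name of the in-place op needed to fix up an input's metadata."""
    if info.mutates_storage_metadata:
        # The storage itself was swapped, so the input needs set_().
        return "set_"
    if info.mutates_metadata:
        # Only sizes/strides/offset changed, so as_strided_() is enough.
        return "as_strided_"
    return None
```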
**This PR is still silently incorrect in one case though**, which I'd like to discuss more. In particular, this example:
```
import torch

def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```
If you have an input that experiences both a data-mutation **and** a `x_old.set_(x_new)` call, there are two cases:
(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if x_new is a graph intermediate then we will functionalize the mutation. If x_new is a different graph input, then we will perform the usual `copy_()` on that other graph input
(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```
def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x_view.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations: a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output

def runtime_wrapper(x):
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```
There are two things that make this difficult to do though:
(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.
(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.
It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554
Approved by: https://github.com/ezyang
ghstack dependencies: #113926
With this PR it is possible to differentiate through NumPy code modulo
the usual caveats that apply to differentiation:
- That there are no graph breaks
- That the decomposition in `torch._numpy` is differentiable
@ev-br and I were somewhat careful to achieve the second point, but
it is not tested through and through, so YMMV
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114608
Approved by: https://github.com/voznesenskym
Summary:
The primary problem we are setting out to solve here is fake tensor freshness. Before this PR, fake tensors after dynamo represented fake tensors *at the end* of trace, so subsequent retraces like aot_autograd would start off with fake tensors in the wrong (end result) state, rather than their expected fresh state. The solution here is to start a fresh fake mode, and re-fakify the tensors. The nuance comes from ensuring that symbols are uniformly created for the symbolic sizes and strides of the tensor.
This PR is the result of *a lot* of back and forth with ezyang and eellison. Initially, the first pass at this was not super different from what we have in the PR - the broad strokes were the same:
1) We cache source->symbol in shape_env
2) We pass policy objects around, stored at dynamo fakification time, and reused for later fakification
3) We create a new fake mode for backends
(from https://github.com/pytorch/pytorch/pull/113605/files)
This is ugly, and has some layering violations. We detoured our decision making through a few other alternatives. Immutable/mutable fake tensor mode was the most interesting alternative, https://github.com/pytorch/pytorch/pull/113653, and was struck down on concerns of complexity in fake mode combined with it not covering all edge cases. We also detoured on what to do about tensor memoization returning potentially different tensors than requested, and whether that was an anti-pattern (it is) that we wanted to hack in with the symbol cache (we don't).
We went back to the drawing board here, but with a few concessions:
1) the cache for source->symbol must live outside of shape_env, for both lifecycle, and layering reasons
2) A good amount of work needs to be done to pipe policy around fake_mode and meta_utils correctly, to cover all the cases (ezyang did this)
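A toy sketch of why a source->symbol cache gives uniform symbols across re-fakification: if the cache outlives any single fake mode, fakifying the same source twice yields the same symbols. `ToySymbolCache` and `toy_fakify` are hypothetical names, and symbols are modelled as plain strings rather than real SymInts.

```python
import itertools


class ToySymbolCache:
    """Hypothetical cache mapping a source name to a symbol name."""

    def __init__(self):
        self._symbols = {}
        self._counter = itertools.count()

    def symbol_for(self, source_name):
        # Allocate a fresh symbol only the first time a source is seen.
        if source_name not in self._symbols:
            self._symbols[source_name] = f"s{next(self._counter)}"
        return self._symbols[source_name]


def toy_fakify(source_name, ndim, cache):
    """Return symbolic sizes for a tensor identified by its source."""
    return [cache.symbol_for(f"{source_name}.size()[{i}]") for i in range(ndim)]


cache = ToySymbolCache()  # lives outside any single fake mode
first = toy_fakify("L['x']", 2, cache)   # fakification at trace time
second = toy_fakify("L['x']", 2, cache)  # later re-fakification
```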
cc penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng
imported-using-ghimport
Test Plan: Imported from OSS
Reviewed By: huydhn, Chillee
Differential Revision: D51566250
Pulled By: voznesenskym
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114526
Approved by: https://github.com/Chillee, https://github.com/huydhn
This PR checks the tensor metadata of the outputs of cond's branches. This helped us identify several tests whose branches return outputs with different requires_grad. It also fixes the error message: the error was previously raised in torch.ops.higher_order.cond and is now raised in dynamo's CondHigherOrder.
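A rough sketch of the kind of check described here, using a hypothetical `ToyTensorMeta` record in place of real tensor metadata. The fields compared and the error type are assumptions for illustration, not what dynamo actually does.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTensorMeta:
    # Hypothetical subset of tensor metadata.
    shape: tuple
    dtype: str
    requires_grad: bool


def check_branch_outputs(true_outs, false_outs):
    """Raise if the two branches return outputs with different metadata."""
    if len(true_outs) != len(false_outs):
        raise ValueError(
            f"branches return {len(true_outs)} vs {len(false_outs)} outputs"
        )
    for i, (t, f) in enumerate(zip(true_outs, false_outs)):
        if t != f:
            raise ValueError(f"branch outputs differ at index {i}: {t} vs {f}")


same = [ToyTensorMeta((2, 3), "float32", False)]
grad_mismatch = [ToyTensorMeta((2, 3), "float32", True)]
```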
Test Plan:
Existing tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113900
Approved by: https://github.com/zou3519
ghstack dependencies: #113819
This PR adds should_flatten_outputs=True for cond. This effectively allows cond to support pytree outputs, with the output being flattened. Note: a single tensor output will automatically be cast to a tuple for torch.ops.higher_order.cond.
This PR also adds support for comparing BuiltinVariables, e.g. tuple. This lets dynamo inline the comparison of two tree_specs, which checks that both branches return the same tree_spec.
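To illustrate the idea of flattening branch outputs and comparing their tree specs, here is a self-contained toy flattener. `toy_flatten` is a hypothetical simplification that only handles tuples and lists; it is not `torch.utils._pytree`.

```python
def toy_flatten(obj):
    """Flatten nested tuples/lists into (leaves, spec)."""
    if isinstance(obj, (tuple, list)):
        leaves, child_specs = [], []
        for item in obj:
            item_leaves, item_spec = toy_flatten(item)
            leaves.extend(item_leaves)
            child_specs.append(item_spec)
        # The spec records the container type and the specs of its children.
        return leaves, (type(obj).__name__, tuple(child_specs))
    # Anything else is a leaf.
    return [obj], "*"


true_leaves, true_spec = toy_flatten((1, [2, 3]))
false_leaves, false_spec = toy_flatten((4, [5, 6]))
other_leaves, other_spec = toy_flatten((4, 5, 6))

# Both branches must produce the same spec for their outputs to be compatible.
branches_match = true_spec == false_spec
```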
Test Plan:
Existing tests. Will add more pytree tests and update the documentation in follow-up PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113819
Approved by: https://github.com/zou3519
The primary problem we are setting out to solve here is fake tensor freshness. Before this PR, fake tensors after dynamo represented fake tensors *at the end* of trace, so subsequent retraces like aot_autograd would start off with fake tensors in the wrong (end result) state, rather than their expected fresh state. The solution here is to start a fresh fake mode, and re-fakify the tensors. The nuance comes from ensuring that symbols are uniformly created for the symbolic sizes and strides of the tensor.
This PR is the result of *a lot* of back and forth with @ezyang and @eellison. Initially, the first pass at this was not super different from what we have in the PR - the broad strokes were the same:
1) We cache source->symbol in shape_env
2) We pass policy objects around, stored at dynamo fakification time, and reused for later fakification
3) We create a new fake mode for backends
(from https://github.com/pytorch/pytorch/pull/113605/files)
This is ugly, and has some layering violations. We detoured our decision making through a few other alternatives. Immutable/mutable fake tensor mode was the most interesting alternative, https://github.com/pytorch/pytorch/pull/113653, and was struck down on concerns of complexity in fake mode combined with it not covering all edge cases. We also detoured on what to do about tensor memoization returning potentially different tensors than requested, and whether that was an anti-pattern (it is) that we wanted to hack in with the symbol cache (we don't).
We went back to the drawing board here, but with a few concessions:
1) the cache for source->symbol must live outside of shape_env, for both lifecycle, and layering reasons
2) A good amount of work needs to be done to pipe policy around fake_mode and meta_utils correctly, to cover all the cases (@ezyang did this)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113926
Approved by: https://github.com/ezyang, https://github.com/eellison
During the course of fake tensor propagation (and, potentially, also Dynamo execution, although I do not believe it is possible to exercise this right now), we may generate deferred runtime asserts, which represent "guards" on unbacked symbols which cannot be immediately checked on entry to a code block; instead, they have to be checked at runtime. However, we currently accumulate these deferred runtime asserts into the ShapeEnv, and don't do anything with them.
This PR modifies Dynamo to automatically insert these runtime asserts into the FX graph, before passing it on to the backend compiler. The assert format coincides with the export assert format as practiced in `torch/_export/passes/add_runtime_assertions_for_constraints_pass.py`, but actually these passes are completely disjoint right now as I only handle deferred runtime asserts, while export only handles ranges (which I should probably also handle, but don't in this PR.)
The assertions must be inserted by Dynamo, because you could potentially then pass the asserts on to another backend like "eager", which no longer looks at the ShapeEnv. Thanks to previous work in export, these asserts are preserved in AOTAutograd, but they are dropped by Inductor, which needs to be fixed in future work. This piece will be a bit awkward, as Inductor would have preferred to work with the Sympy expressions directly, ah well.
Here is what the Dynamo traced FX graph looks like for the test in question:
```
<eval_with_key>.0 class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_
        # File: /data/users/ezyang/c/pytorch/wu.py:8, code: y = x.item()
        item = l_x_.item()
        # No stacktrace found for following nodes
        ge_1 = item >= 0
        scalar_tensor_default = torch.ops.aten.scalar_tensor.default(ge_1); ge_1 = None
        _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, "Deferred runtime assert failed: i0 >= 0, where i0 was defined by 'item' (for more information, run with TORCH_LOGS=+dynamo,dynamic)"); scalar_tensor_default = None
        # File: /data/users/ezyang/c/pytorch/wu.py:9, code: torch._check_is_size
        _check_is_size = torch._check_is_size(item)
        # File: /data/users/ezyang/c/pytorch/wu.py:10, code: if y >= 0:
        ge = item >= 0; item = None
        # File: /data/users/ezyang/c/pytorch/wu.py:11, code: return x * 2
        mul = l_x_ * 2; l_x_ = None
        return (mul,)
```
Note that we actually keep the `_check_is_size` in the graph redundantly. However, assert_async is retained in the graph, whereas _check_is_size ends up getting DCE'ed.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113958
Approved by: https://github.com/aakhundov, https://github.com/tugsbayasgalan
ghstack dependencies: #113978
Fixes https://github.com/pytorch/pytorch/issues/113041
In the case where we have an object represented as an UnspecializedNNModuleVariable, the source of an attribute on that object is `AttrSource(base=NotNNModuleSource(base=NNModuleSource(base=AttrSource(base=LocalSource(local_name='self', cell_or_freevar=False), member='seq'))), member='b')`. This causes dynamo to add an extra attribute as it doesn't go to this [`register_attr` step](eddce3c054/torch/_dynamo/variables/builder.py (L955-L962)).
However if we have an object represented as a UserDefinedObjectVariable, the source of an attribute on that object is `AttrSource(base=NNModuleSource(base=AttrSource(base=LocalSource(local_name='self', cell_or_freevar=False), member='seq')), member='b')`.
It seems that UnspecializedNNModuleVariables should behave in the same way as UserDefinedObjectVariables, but the sources in these two cases are different. So, I removed the part that changes the source in the UnspecializedNNModuleVariables, and it seems to work! And CI is green (+ reduced graph breaks).
```
def test_unspecialized_nnmodule(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = torch.tensor(1.0)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.a

    def forward_hook(
        module: torch.nn.Module, inputs, output
    ) -> torch.Tensor:
        return 2 * output

    seq = torch.nn.Sequential(TestModule()).eval()
    seq.b = torch.tensor(2)
    handle = seq.register_forward_hook(forward_hook)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.seq = seq

        def forward(self, x):
            # self.seq.b has source: AttrSource(base=NotNNModuleSource(base=NNModuleSource(base=AttrSource(base=LocalSource(local_name='self', cell_or_freevar=False), member='seq'))), member='b')
            return self.seq(x) + self.seq.b

    inp = (torch.randn(2, 8),)
    ep = export(M(), inp)
```
```
def test_user_defined_var(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = torch.tensor(1.0)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.a

    class UserDefined:
        def __init__(self):
            self.test_module = TestModule()
            self.b = torch.tensor(2)

        def __call__(self, x):
            return self.test_module(x)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.seq = UserDefined()

        def forward(self, x):
            # self.seq.b has source: AttrSource(base=NNModuleSource(base=AttrSource(base=LocalSource(local_name='self', cell_or_freevar=False), member='seq')), member='b')
            return self.seq(x) + self.seq.b

    inp = (torch.randn(2, 8),)
    ep = export(M(), inp)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113852
Approved by: https://github.com/yanboliang
Copied from @ezyang 's #113693.
The motivation for this change is that we'd like to guard on storage offset in inductor, to make assumptions about data alignment.
create_symbolic_sizes_strides_storage_offset() creates the sizes/strides/offset for fake tensors - they can either be integers or symints. This PR changes storage_offset to always be dynamic. In variables/builder.py, we remove a conditional so that all tensors get added to tracked_fakes. This is because the storage offset will be dynamic even if the other logic in builder.py suggests that it will be static; otherwise, we run into this issue:
1e260c851b/torch/fx/experimental/symbolic_shapes.py (L892-L895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113734
Approved by: https://github.com/ezyang
We saw some use cases in higher-order operators that try to directly inline a user-level function (e.g. pytree.tree_flatten and pytree.tree_unflatten) with no tensor operations by manually constructing a UserFunctionVariable and running call_function on it.
This PR consolidates this pattern a bit by adding a _make_inlined helper function, which improves the UX (i.e. the calling convention is kept the same as that of the function we'd like to inline), reduces redundancy, and increases readability.
Test Plan:
Existing tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113814
Approved by: https://github.com/yanboliang
Fixes https://github.com/pytorch/pytorch/issues/90552. This is a simpler fix that just detects the situation where AOTAutograd can't create a proper backward graph and graph breaks. This was technically a silent correctness issue before.
This PR tries to always graph break when we see a factory function that returns a tensor requiring grad. I check this by seeing if the op returned a `TensorVariable` in dynamo, and if one of the input arguments was a `requires_grad=True` kwarg. I think this is high-fidelity enough, and I'm also hoping that this is uncommon enough that a graph break is reasonable here.
The fix to avoid the graph break in user land is also pretty easy - just instantiate your tensor outside of the compiled region and plumb it in.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113277
Approved by: https://github.com/eellison
ghstack dependencies: #113267, #113416, #113584
The torch.compile + SAC unit test is causing adjacent unit tests to be flaky because it modifies a shared singleton object. This PR attaches the checkpoint context fn to the checkpointed GraphModule and looks it up during execution, avoiding the need to make the higher-order op stateful.
Specifically, we attach the `context_fn` to the checkpointed GraphModule. These two will be gc'ed at the same time, so it satisfies the lifetime requirement.
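The pattern of storing per-call state on the object rather than in a shared singleton can be sketched in plain Python. `ToyGraphModule` and `toy_checkpoint` are hypothetical stand-ins, and `context_fn` here returns a single context manager, which is a simplification and may not match the real checkpoint `context_fn` contract.

```python
import contextlib


class ToyGraphModule:
    """Hypothetical stand-in for a checkpointed GraphModule."""

    def __init__(self, body):
        self.body = body


def toy_checkpoint(gmod, *args):
    # Look the context fn up on the module at call time. No global state is
    # involved, so one module's context fn cannot leak into another call.
    context_fn = getattr(gmod, "context_fn", contextlib.nullcontext)
    with context_fn():
        return gmod.body(*args)


events = []


@contextlib.contextmanager
def recording_context():
    events.append("enter")
    try:
        yield
    finally:
        events.append("exit")


with_ctx = ToyGraphModule(lambda x: x * 2)
with_ctx.context_fn = recording_context  # lives and dies with the module
without_ctx = ToyGraphModule(lambda x: x + 1)

a = toy_checkpoint(with_ctx, 3)
b = toy_checkpoint(without_ctx, 3)
```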
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112672
Approved by: https://github.com/wanchaol
Fixes https://github.com/pytorch/pytorch/issues/113010
In eager mode, when you call an out= op like `add(..., out=out_arg)` with an out argument that is noncontiguous, the noncontiguous out arg will be returned directly. When we functionalize though, functionalization replaces it with a call to `add(...)` which ignores the contiguity of the original out arg.
Instead of trying to support this, this PR detects that situation and graph breaks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113267
Approved by: https://github.com/albanD
In TorchVision we use the following (simplified) dispatch mechanism:
```python
import torch

def kernel1(tensor):
    return tensor + 2

def dispatcher1(input):
    kernel = get_kernel(dispatcher1, type(input))
    return kernel(input)

def kernel2(tensor):
    return tensor - 2

def dispatcher2(input):
    kernel = get_kernel(dispatcher2, type(input))
    return kernel(input)

# We actually use the function and type as keys, rather than their names.
# However, this is currently not supported, but should be easy to add after
# https://github.com/pytorch/pytorch/pull/111196
REGISTRY = {
    "dispatcher1": {"Tensor": kernel1},
    "dispatcher2": {"Tensor": kernel2},
}

def get_kernel(dispatcher, input_type):
    dispatcher_registry = REGISTRY[dispatcher.__name__]
    for cls in input_type.__mro__:
        kernel = dispatcher_registry[cls.__name__]
        break
    return kernel
```
This can be compiled without graph breaks:
```python
cfn = torch.compile(dispatcher1, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 5)
cfn = torch.compile(dispatcher2, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 1)
```
However, if we start chaining these calls, we hit some issues:
```python
class Pipeline(torch.nn.Module):
    def forward(self, input):
        input = dispatcher1(input)
        input = dispatcher2(input)
        return input

cfn = torch.compile(Pipeline(), fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 3)
```
```
Can't access members of type(obj) for a generated custom object. Please use __class__ instead
```
The error message is not really helpful here. The following happens: when compiling `dispatcher1`, `get_kernel` gets inlined. That means when hitting `dispatcher2`, the `type` call no longer happens on an input with a source. Thus, in the first iteration we hit the top branch, while in the second we hit the bottom:
addb8e29cd/torch/_dynamo/variables/builtin.py (L1264-L1268)
And the error message I posted above originates from the type being treated as constant. This PR replaces this with a `SourcelessBuilder` instead.
With that fix in place, we hit another error, pointing to `input_type.__mro__`:
```
AssertionError: Consider SourcelessBuilder for ephemeral objects, usually objects created locally.
```
Fix is similar: instead of using a `VariableBuilder` here, we use a `SourcelessBuilder` in case we have no `source`:
addb8e29cd/torch/_dynamo/variables/builtin.py (L1167-L1168)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113340
Approved by: https://github.com/peterbell10, https://github.com/lezcano
This prepares the PR where we implement sets in terms of dicts.
To do so, rather than storing internally a dictionary that maps literals
to VariableTrackers, it stores (pretty much) a dictionary from VTs to VTs.
To do so, keys are wrapped in an opaque internal class `_Hashable`.
The `_Hashable` class is opaque on purpose so that it fails hard if
it inadvertently leaks back into user code.
We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make much clearer what are the things that need to be modified to add
a new supported key type to Dicts.
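A minimal sketch of the key-wrapping idea, assuming a toy tracker that just holds a Python constant. `ToyTracker` and `ToyHashable` are hypothetical names, and the set of supported key types is invented for illustration; the real `_Hashable` class is not reproduced here.

```python
class ToyTracker:
    """Hypothetical stand-in for a VariableTracker holding a constant."""

    def __init__(self, value):
        self.value = value


class ToyHashable:
    """Opaque key wrapper: hashes and compares by the wrapped tracker's value."""

    def __init__(self, tracker):
        # Reject unsupported key types up front, in one place.
        if not isinstance(tracker.value, (int, float, str, bool, tuple)):
            raise TypeError(f"unsupported dict key: {type(tracker.value).__name__}")
        self.tracker = tracker

    def __hash__(self):
        return hash(self.tracker.value)

    def __eq__(self, other):
        # Only equal to other wrappers, never to the raw underlying value.
        return isinstance(other, ToyHashable) and self.tracker.value == other.tracker.value


items = {}  # maps wrapped key trackers to value trackers
items[ToyHashable(ToyTracker("a"))] = ToyTracker(1)
items[ToyHashable(ToyTracker("a"))] = ToyTracker(2)  # same key, overwrites
items[ToyHashable(ToyTracker(3))] = ToyTracker("three")
```

Because `__eq__` only accepts other wrappers, a raw `"a"` does not match the wrapped key, which keeps the wrapper from silently mixing with user-level values.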
Fixes https://github.com/pytorch/pytorch/issues/107595
Fixes https://github.com/pytorch/pytorch/issues/111603
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111196
Approved by: https://github.com/jansel
They are used in many contexts that don't actually check if the returned
type is `None`. I have also created `try_get()` for the cases where we
do actually want an Optional type returned.
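The split between a non-Optional getter and a `try_get()` that returns an Optional can be sketched as follows. `ToyRegistry` is a hypothetical class for illustration; which getters this change actually touches is not stated in the text above.

```python
from typing import Dict, Optional


class ToyRegistry:
    """Hypothetical container with a strict get() and a lenient try_get()."""

    def __init__(self) -> None:
        self._items: Dict[str, int] = {}

    def set(self, key: str, value: int) -> None:
        self._items[key] = value

    def try_get(self, key: str) -> Optional[int]:
        # For callers that genuinely want to handle the missing case.
        return self._items.get(key)

    def get(self, key: str) -> int:
        # For callers that never check for None: fail loudly instead.
        value = self.try_get(key)
        if value is None:
            raise KeyError(f"no entry for {key!r}")
        return value


reg = ToyRegistry()
reg.set("a", 1)
```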
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113535
Approved by: https://github.com/ezyang
ghstack dependencies: #113412