pytorch/torch/utils
Richard Zou 6025f8148a Implement _broadcast_to_and_flatten(pytree, spec) (#46288)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46288

This "broadcasts" `pytree` to have the same structure as `spec`
and then flattens it.
I find it hard to describe what that does in words, so here are some examples:

- Broadcasting 1 to have the same structure as [0, [0, 0]] would
return [1, [1, 1]]. Further flattening it gives us [1, 1, 1].
- Broadcasting [1, 2] to have the same structure as [0, [0, 0]] would
return [1, [2, 2]]. Further flattening it gives us [1, 2, 2].
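To make the two examples concrete, here is a minimal pure-Python sketch of the broadcast-then-flatten idea. It rests on several assumptions. Only lists and tuples count as internal tree nodes. The target shape is given as an example structure like `[0, [0, 0]]` rather than a `TreeSpec`. The helper names are hypothetical and are not PyTorch's API. The sketch returns `None` when `pytree` cannot be broadcast to the target.

```python
from typing import Any, List, Optional


def _is_container(x: Any) -> bool:
    # Assumption: only lists and tuples are internal nodes; everything else is a leaf.
    return isinstance(x, (list, tuple))


def _count_leaves(structure: Any) -> int:
    # Number of leaves beneath `structure` (a bare leaf counts as one).
    if not _is_container(structure):
        return 1
    return sum(_count_leaves(child) for child in structure)


def broadcast_to_and_flatten(tree: Any, structure: Any) -> Optional[List[Any]]:
    """Hypothetical helper: broadcast `tree` to the shape of `structure`, then flatten.

    Returns None if `tree` does not fit inside `structure`.
    """
    if not _is_container(tree):
        # A leaf is repeated once for every leaf of the matching subtree.
        return [tree] * _count_leaves(structure)
    if not _is_container(structure):
        # `tree` has more structure than the target at this position.
        return None
    if type(tree) is not type(structure) or len(tree) != len(structure):
        # Container kinds or child counts differ, so there is no valid broadcast.
        return None
    result: List[Any] = []
    for child, child_structure in zip(tree, structure):
        flat = broadcast_to_and_flatten(child, child_structure)
        if flat is None:
            return None
        result.extend(flat)
    return result


print(broadcast_to_and_flatten(1, [0, [0, 0]]))          # [1, 1, 1]
print(broadcast_to_and_flatten([1, 2], [0, [0, 0]]))     # [1, 2, 2]
print(broadcast_to_and_flatten([1, 2, 3], [0, [0, 0]]))  # None
```

Returning `None` instead of raising leaves it to the caller (for example, vmap) to produce a user-facing error message.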

What is this used for?
----------------------
The next PR up in the stack uses this helper function to allow vmap to
accept nested data structures. `vmap(fn, in_dims)(*inputs)` allows the
user to specify in_dims with a tree structure that is a sub-graph of the
structure of `inputs`, where both trees share the same root. Each leaf of
`in_dims` then applies to the entire corresponding subtree of `inputs`.
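One way to read the requirement on `in_dims` is that it must be a "prefix" of `inputs`. Both trees have the same root, and `in_dims` is obtained by cutting `inputs` off at some nodes so that each cut node becomes a single leaf. The check below is a hypothetical sketch of that condition. It assumes lists and tuples are the only containers and uses strings as stand-ins for tensors. It is not the validation code vmap actually uses.

```python
from typing import Any


def is_tree_prefix(prefix: Any, tree: Any) -> bool:
    """Hypothetical check: can `prefix` be broadcast up to the structure of `tree`?"""
    if not isinstance(prefix, (list, tuple)):
        # A leaf in `prefix` may stand in for any subtree of `tree`.
        return True
    if type(prefix) is not type(tree) or len(prefix) != len(tree):
        # Same container kind and same number of children are required.
        return False
    return all(is_tree_prefix(p, t) for p, t in zip(prefix, tree))


x, y, z = "x", "y", "z"  # stand-ins for tensors
print(is_tree_prefix(0, (x, y, z)))         # True
print(is_tree_prefix((0, 1), (x, [y, z])))  # True
print(is_tree_prefix((0, [1, 2]), (x, y)))  # False: [1, 2] has more structure than y
```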

For example, one can do `vmap(fn, in_dims=0)(x, y, z)`. `in_dims` is 0
and inputs is (x, y, z). We would like to broadcast in_dims up to the
structure of inputs to get (0, 0, 0).

Another example is `vmap(fn, in_dims=(0, 1))(x, [y, z])`. `in_dims` is
(0, 1) and inputs is (x, [y, z]). We would like to broadcast in_dims up
to the structure of inputs to get (0, [1, 1]); this value of in_dims is
used to say "let's vmap over dim 0 for x and dim 1 for y and z".
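Both vmap examples can be reproduced with a small sketch that broadcasts `in_dims` up to the structure of `inputs` without flattening. The sketch then pairs each input leaf with its dim. `broadcast_to` and `leaves` are hypothetical helpers written for illustration. They assume lists and tuples are the only containers, use strings as stand-ins for tensors, and are not functions from `torch.utils._pytree`.

```python
from typing import Any, List


def broadcast_to(prefix: Any, tree: Any) -> Any:
    """Copy `tree`'s structure, replacing each leaf with the `prefix` value covering it."""
    if not isinstance(prefix, (list, tuple)):
        if isinstance(tree, (list, tuple)):
            # Push the single prefix value down into every child of this subtree.
            return type(tree)(broadcast_to(prefix, child) for child in tree)
        return prefix
    if type(prefix) is not type(tree) or len(prefix) != len(tree):
        raise ValueError(f"{prefix!r} cannot be broadcast to {tree!r}")
    return type(tree)(broadcast_to(p, t) for p, t in zip(prefix, tree))


def leaves(tree: Any) -> List[Any]:
    # Left-to-right flattening of a nested list/tuple structure.
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in leaves(child)]
    return [tree]


x, y, z = "x", "y", "z"  # stand-ins for tensors
print(broadcast_to(0, (x, y, z)))  # (0, 0, 0)

inputs = (x, [y, z])
in_dims = broadcast_to((0, 1), inputs)
print(in_dims)  # (0, [1, 1])
# "vmap over dim 0 for x and dim 1 for y and z":
print(list(zip(leaves(inputs), leaves(in_dims))))  # [('x', 0), ('y', 1), ('z', 1)]
```

Flattening both trees in the same left-to-right order lines each input leaf up with its dim. The flattened helper this PR adds appears to produce that list of dims in one step.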

Test Plan
---------
New tests.

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D24392891

Pulled By: zou3519

fbshipit-source-id: 6f494d8b6359582f1b4ab6b8dd6a956d8bfe8ed4
2020-10-20 07:52:14 -07:00
backcompat Simplify python warning settings and cleanup tests. 2017-06-11 05:37:59 -04:00
benchmark More Timer refinement (#46023) 2020-10-15 16:32:53 -07:00
bottleneck Fix type annotations for a number of torch.utils submodules (#42711) 2020-08-14 18:12:48 -07:00
data Fix possible padding length overflow in DistributedSampler (#45329) 2020-10-14 17:19:44 -07:00
ffi remove support for c extensions (#12122) 2018-10-01 13:55:28 -07:00
hipify Annotate torch.utils.(tensorboard/show_pickle/hypify) (#44216) 2020-09-29 18:14:26 -07:00
tensorboard Annotate torch.utils.(tensorboard/show_pickle/hypify) (#44216) 2020-09-29 18:14:26 -07:00
__init__.py Remove py2 compatible future imports (#44735) 2020-09-16 12:55:57 -07:00
_cpp_extension_versioner.py Remove py2 compatible future imports (#44735) 2020-09-16 12:55:57 -07:00
_pytree.py Implement _broadcast_to_and_flatten(pytree, spec) (#46288) 2020-10-20 07:52:14 -07:00
bundled_inputs.py add typing annotations for a few torch.utils.* modules (#43806) 2020-09-11 10:20:55 -07:00
checkpoint.py [pytorch] activation checkpointing: enable mixing tensor without requires_grad (#45934) 2020-10-14 21:28:02 -07:00
collect_env.py Remove py2 compatible future imports (#44735) 2020-09-16 12:55:57 -07:00
cpp_extension.py Replace list(map(...)) constructs by list comprehensions (#46461) 2020-10-19 18:42:49 -07:00
dlpack.py Remove py2 compatible future imports (#44735) 2020-09-16 12:55:57 -07:00
file_baton.py Remove py2 compatible future imports (#44735) 2020-09-16 12:55:57 -07:00
hooks.py Remove py2 compatible future imports (#44735) 2020-09-16 12:55:57 -07:00
mkldnn.py Remove py2 compatible future imports (#44735) 2020-09-16 12:55:57 -07:00
mobile_optimizer.py [Metal] Add the Python binding for optimize_for_mobile (#46456) 2020-10-17 10:26:25 -07:00
model_zoo.py add/move a few apis in torch.hub (#18758) 2019-04-10 23:10:39 -07:00
show_pickle.py Annotate torch.utils.(tensorboard/show_pickle/hypify) (#44216) 2020-09-29 18:14:26 -07:00
throughput_benchmark.py Remove py2 compatible future imports (#44735) 2020-09-16 12:55:57 -07:00