mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This will be the last disruptive functorch internals change. Why are we moving these files? - As a part of rationalizing functorch we are moving the code in functorch/_src to torch/_functorch - This is so that we can offer the functorch APIs as native PyTorch APIs (coming soon) and resolve some internal build issues. Why are we moving all of these files at once? - It's better to break developers all at once rather than many times Test Plan: - wait for tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091 Approved by: https://github.com/anijain2305, https://github.com/ezyang
25 lines
612 B
Python
25 lines
612 B
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
|
|
|
|
def tree_map_(fn_, pytree):
    """Call ``fn_`` on every leaf of ``pytree`` purely for its side effects.

    The return values of ``fn_`` are discarded and no new pytree is built;
    the very same ``pytree`` object that was passed in is returned (``fn_``
    may of course have mutated leaves in place).

    Args:
        fn_: callable invoked once per leaf, in ``tree_flatten`` order.
        pytree: any structure accepted by ``tree_flatten``.

    Returns:
        The input ``pytree`` object itself.
    """
    flat_args, _ = tree_flatten(pytree)
    # A plain loop instead of a list comprehension: only the side effects of
    # ``fn_`` matter, so there is no reason to build and throw away a list.
    for arg in flat_args:
        fn_(arg)
    return pytree
|
|
|
|
|
|
class PlaceHolder:
    """Dummy leaf object whose repr is a bare ``*``.

    Instances are substituted for real leaf values when printing the layout
    of a pytree, so only the container structure shows up in the output.
    """

    def __repr__(self):
        return '*'
|
|
|
|
|
|
def treespec_pprint(spec):
    """Render the container structure described by ``spec`` as a string.

    Each of the ``spec.num_leaves`` leaf slots is filled with a
    ``PlaceHolder`` (repr ``*``), the tree is rebuilt with ``tree_unflatten``,
    and the repr of that rebuilt tree is returned. For example, a spec
    obtained from ``(a, [b, c])`` would print as ``(*, [*, *])``.
    """
    placeholders = [PlaceHolder() for _ in range(spec.num_leaves)]
    return repr(tree_unflatten(placeholders, spec))
|