mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Added two new utils to help with turning python functionalization on in AOTAutograd (next PR): (1) updated `torch._sync()`. Previously, this API could only handle `torch.Tensor` instances that had a `FunctionalTensorWrapper` TensorImpl. It now needs to handle python `FunctionalTensor`'s. In theory I can probably break BC and change this API (since it's private?), but I decided not to do it in this PR stack to minimize the chance of reverts. Instead of updating that API directly (which is in C++), I just added a python shim that first tries to unwrap the python `FunctionalTensor` if there is one, then calls the existing C++ logic (2) `mirror_autograd_meta` is now a standalone API that tries to mirror the `requires_grad` and `is_leaf` autograd metadata from one tensor to another. Previously this was hardcoded into `torch._to_functional_tensor()`. But I now need to use it in a more standalone way: later in AOTAutograd when we unwrap and re-wrap tensor subclasses, we need to manually mirror the autograd metadata from the original to the updated version of the subclass. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107917 Approved by: https://github.com/ezyang ghstack dependencies: #106404
118 lines
4.4 KiB
Python
118 lines
4.4 KiB
Python
import logging
|
|
import os
|
|
import tempfile
|
|
from contextlib import nullcontext
|
|
from typing import Any, Dict
|
|
|
|
import torch
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
# this arbitrary-looking assortment of functionality is provided here
|
|
# to have a central place for overrideable behavior. The motivating
|
|
# use is the FB build environment, where this source file is replaced
|
|
# by an equivalent.
|
|
|
|
# Compute `torch_parent`: the directory that contains the `torch` package
# directory, used below by get_file_path() to build paths relative to the
# installation root.
if torch._running_with_deploy():
    # __file__ is meaningless in the context of frozen torch used in torch deploy.
    # setting empty torch_parent should allow below functions to operate without crashing,
    # but it's unclear if there is a valid use case for them in the context of deploy.
    torch_parent = ""
else:
    if os.path.basename(os.path.dirname(__file__)) == "shared":
        # This file sits one level deeper (in a "shared" subdirectory),
        # e.g. torch/_utils_internal.py replaced by torch/shared/<file>,
        # so climb one extra directory to reach the parent of `torch/`.
        torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    else:
        # Normal layout: <torch_parent>/torch/<this file>.
        torch_parent = os.path.dirname(os.path.dirname(__file__))
|
|
|
|
|
|
def get_file_path(*path_components: str) -> str:
    """Build a path rooted at ``torch_parent`` from the given components."""
    segments = (torch_parent,) + path_components
    return os.path.join(*segments)
|
|
|
|
|
|
def get_file_path_2(*path_components: str) -> str:
    """Join the given path components with ``os.path.join`` (no implicit root)."""
    combined = os.path.join(*path_components)
    return combined
|
|
|
|
|
|
def get_writable_path(path: str) -> str:
    """Return *path* if it is writable, else a freshly created temp directory.

    The fallback directory's name ends with the basename of *path* so its
    origin stays recognizable. A new temp directory is created on every
    call that takes the fallback branch.
    """
    writable = os.access(path, os.W_OK)
    if not writable:
        return tempfile.mkdtemp(suffix=os.path.basename(path))
    return path
|
|
|
|
|
|
def prepare_multiprocessing_environment(path: str) -> None:
    """No-op hook for environment-specific multiprocessing setup.

    Per the module header, this file is replaced in other build
    environments, which may supply a real implementation here.
    """
|
|
|
|
|
|
def resolve_library_path(path: str) -> str:
|
|
return os.path.realpath(path)
|
|
|
|
|
|
# Meta only, see
|
|
# https://www.internalfb.com/intern/wiki/ML_Workflow_Observability/User_Guides/Adding_instrumentation_to_your_code/
|
|
#
|
|
# This will cause an event to get logged to Scuba via the signposts API. You
|
|
# can view samples on the API at https://fburl.com/scuba/workflow_signpost/zh9wmpqs
|
|
# we log to subsystem "torch", and the category and name you provide here.
|
|
# Each of the arguments translate into a Scuba column. We're still figuring
|
|
# out local conventions in PyTorch, but category should be something like
|
|
# "dynamo" or "inductor", and name should be a specific string describing what
|
|
# kind of event happened.
|
|
#
|
|
# Killswitch is at
|
|
# https://www.internalfb.com/intern/justknobs/?name=pytorch%2Fsignpost#event
|
|
def signpost_event(category: str, name: str, parameters: Dict[str, Any]):
    """Record an instrumentation event.

    In this build the event is only written to the module logger; per the
    comments above, internal builds replace this module with one that
    forwards to a real signpost/telemetry backend.
    """
    log.info("%s %s: %r", category, name, parameters)
|
|
|
|
|
|
def log_compilation_event(metrics):
    """Emit a compilation-metrics record to the module logger.

    Overridable per build environment, like the other hooks in this file.
    """
    log.info("%s", metrics)
|
|
|
|
|
|
def _functionalize_sync(t):
    """Sync a (possibly python-subclass) functional tensor.

    Shim over the C++ ``torch._functionalize_sync``: if *t* is a python
    ``FunctionalTensor``, temporarily pop any active FunctionalTensorMode
    and sync the wrapped ``t.elem`` instead; otherwise sync *t* directly.
    """
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import (
        FunctionalTensor,
        maybe_disable_functional_mode,
    )

    # NOTE(review): `ctx` is computed but never used below — the functional
    # mode is popped/restored manually via _unset/_set_dispatch_mode instead.
    ctx = (
        maybe_disable_functional_mode
        if isinstance(t, FunctionalTensor)
        else nullcontext
    )
    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
        # when we sync our inner tensor.
        # Why?
        # (1) If there are input mutations in the graph, then they will be re-applied during
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
        # (2) _sync() causes us to regenerate the updated tensor from the updated base,
        #     which dispatches to a bunch of view ops
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
        #     (since the sync was called from C++), not the python FunctionalTensor
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            # Sync the inner C++ FunctionalTensorWrapper with the mode popped.
            torch._functionalize_sync(t.elem)
        finally:
            # Restore the mode we popped, even if the sync raised.
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)
|
|
|
|
|
|
# Rendezvous address/port defaults — presumably consumed by distributed
# tests; confirm against callers before repurposing.
TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = 29500
# USE_GLOBAL_DEPS controls whether __init__.py tries to load
# libtorch_global_deps, see Note [Global dependencies]
USE_GLOBAL_DEPS = True
# USE_RTLD_GLOBAL_WITH_LIBTORCH controls whether __init__.py tries to load
# _C.so with RTLD_GLOBAL during the call to dlopen.
USE_RTLD_GLOBAL_WITH_LIBTORCH = False
|