Revert "[dynamo] simplify implementation for os.fspath (#133801)"

This reverts commit c5f6b72041.

Reverted https://github.com/pytorch/pytorch/pull/133801 on behalf of https://github.com/ZainRizvi due to Sorry, but this breaks internal tests because of using functools ([comment](https://github.com/pytorch/pytorch/pull/133778#issuecomment-2310445169))
This commit is contained in:
PyTorch MergeBot 2024-08-26 15:16:17 +00:00
parent bb67ff2ba7
commit e1fc4362fb
4 changed files with 25 additions and 38 deletions

View File

@ -154,3 +154,15 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
if isinstance(obj, cls):
obj.__init__(*args, **kwargs)
return obj
def fspath(path):
# Python equivalent of os.fspath
if isinstance(path, (str, bytes)):
return path
elif hasattr(path, "__fspath__"):
return path.__fspath__()
else:
raise TypeError(
f"expected str, bytes or os.PathLike object, not {type(path).__name__}"
)

View File

@ -15,7 +15,6 @@ POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
"builtins",
"functools",
"itertools",
"os",
)
POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
importlib.import_module(f".{submodule}", package=polyfills.__name__)

View File

@ -1,36 +0,0 @@
"""
Python polyfills for os
"""
from __future__ import annotations
import os
from typing import AnyStr
from ..decorators import substitute_in_graph
__all__ = ["fspath"]
# Copied from os.py in the standard library
@substitute_in_graph(os.fspath)
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
if isinstance(path, (str, bytes)):
return path
path_type = type(path)
try:
path_repr = path_type.__fspath__(path) # type: ignore[arg-type]
except AttributeError:
if hasattr(path_type, "__fspath__"):
raise
raise TypeError(
f"expected str, bytes or os.PathLike object, not {path_type.__name__}",
) from None
if isinstance(path_repr, (str, bytes)):
return path_repr # type: ignore[return-value]
raise TypeError(
f"expected {path_type.__name__}.__fspath__() to return str or bytes, "
f"not {type(path_repr).__name__}",
)

View File

@ -4,6 +4,7 @@ import dataclasses
import functools
import inspect
import itertools
import os
import random
import re
import sys
@ -14,7 +15,7 @@ import torch._C
import torch._numpy as tnp
import torch.utils._pytree as pytree
from .. import config, variables
from .. import config, polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import unimplemented
@ -29,6 +30,7 @@ from ..source import (
)
from ..utils import (
check_unspec_or_constant_args,
hashable,
identity,
is_tensor_base_attr_getter,
proxy_args_kwargs,
@ -42,6 +44,10 @@ from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObject
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS = {
os.fspath: polyfills.fspath,
}
class SuperVariable(VariableTracker):
_nonvar_fields = {
@ -1107,6 +1113,12 @@ class PythonModuleVariable(VariableTracker):
else:
attr_value = self.value.__dict__[name]
if hashable(attr_value):
if polyfill_fn := POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS.get(
attr_value, None
):
return variables.UserFunctionVariable(polyfill_fn)
if self.source:
new_source = AttrSource(self.source, name)
return VariableBuilder(tx, new_source)(attr_value)