mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Back out "Make TorchScript Preserve Fully Qualified Class Name for Python Exceptions"
Summary: as title Test Plan: ``` buck run mode/opt-split-dwarf -c=python.package_style=inplace //ai_infra/distributed_ai/pyper_test_framework/templates:pyper_release_v2 -- --model inline_cvr_post_imp_deterministic_shrunk_pyper_release_v2 --cluster TSCTestCluster --hpc_identity oncall_pyper_oncall --stage prod_offline_training --test_module training_platform ... ############## Start inline_cvr_post_imp_model Test Results Analysis ############## I1226 22:03:56.789000 3346280 test_driver.py:139 UNKNOWN ] Test finished in 808.2743511786684 seconds. +-------------------------+---------+------------------------+-----------------+ | Test Case | Status | Message | Model Entity ID | +-------------------------+---------+------------------------+-----------------+ | SmallWorld_release_test | Success | finished successfully. | 987987491 | +-------------------------+---------+------------------------+-----------------+ I1226 22:03:56.790000 3346280 test_driver.py:143 UNKNOWN ] test_run_id: 3d085f61-28d1-411d-bd27-940ea2554b23 use this id to find your run in scuba pyper_test_framework I1226 22:03:56.792000 3346280 test_driver.py:160 UNKNOWN ] Calling cleanup I1226 22:03:56.792000 3346280 training_platform_test_launcher.py:385 UNKNOWN ] Stopping launched jobs 1 I1226 22:03:59.563122 3346280 ClientSingletonManager.cpp:100] Shutting down Manifold ClientSingletonManager ``` Reviewed By: seemethere Differential Revision: D33325936 fbshipit-source-id: 64414bf7061ad77e8ac12eb8abafee4043e0fa1e
This commit is contained in:
parent
4ae71c8d34
commit
bf610f08b0
|
|
@ -1,159 +0,0 @@
|
||||||
/*
|
|
||||||
* We have a python unit test for exceptions in test/jit/test_exception.py .
|
|
||||||
* Add a CPP version here to verify that excepted exception types thrown from
|
|
||||||
* C++. This is hard to test in python code since C++ exceptions will be
|
|
||||||
* translated to python exceptions.
|
|
||||||
*/
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include <pybind11/embed.h>
|
|
||||||
#include <torch/csrc/jit/frontend/parser.h>
|
|
||||||
#include <torch/csrc/jit/frontend/resolver.h>
|
|
||||||
#include <torch/csrc/jit/runtime/jit_exception.h>
|
|
||||||
#include <torch/jit.h>
|
|
||||||
#include <iostream>
|
|
||||||
#include <stdexcept>
|
|
||||||
|
|
||||||
namespace torch {
|
|
||||||
namespace jit {
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
TEST(TestException, TestAssertion) {
|
|
||||||
std::string pythonCode = R"PY(
|
|
||||||
def foo():
|
|
||||||
raise AssertionError("An assertion failed")
|
|
||||||
)PY";
|
|
||||||
auto cu_ptr = torch::jit::compile(pythonCode);
|
|
||||||
torch::jit::GraphFunction* gf =
|
|
||||||
(torch::jit::GraphFunction*)&cu_ptr->get_function("foo");
|
|
||||||
std::cerr << "Graph is\n" << *gf->graph() << std::endl;
|
|
||||||
|
|
||||||
bool is_jit_exception = false;
|
|
||||||
std::string message;
|
|
||||||
c10::optional<std::string> exception_class;
|
|
||||||
try {
|
|
||||||
cu_ptr->run_method("foo");
|
|
||||||
} catch (JITException& e) {
|
|
||||||
is_jit_exception = true;
|
|
||||||
message = e.what();
|
|
||||||
exception_class = e.getPythonClassName();
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(is_jit_exception);
|
|
||||||
EXPECT_FALSE(exception_class);
|
|
||||||
EXPECT_TRUE(
|
|
||||||
message.find("RuntimeError: AssertionError: An assertion failed") !=
|
|
||||||
std::string::npos);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct MyPythonExceptionValue : public torch::jit::SugaredValue {
|
|
||||||
explicit MyPythonExceptionValue(const py::object& exception_class) {
|
|
||||||
qualified_name_ =
|
|
||||||
(py::str(py::getattr(exception_class, "__module__", py::str(""))) +
|
|
||||||
py::str(".") +
|
|
||||||
py::str(py::getattr(exception_class, "__name__", py::str(""))))
|
|
||||||
.cast<std::string>();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string kind() const override {
|
|
||||||
return "My Python exception";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simplified from PythonExceptionValue::call
|
|
||||||
std::shared_ptr<torch::jit::SugaredValue> call(
|
|
||||||
const torch::jit::SourceRange& loc,
|
|
||||||
torch::jit::GraphFunction& caller,
|
|
||||||
at::ArrayRef<torch::jit::NamedValue> args,
|
|
||||||
at::ArrayRef<torch::jit::NamedValue> kwargs,
|
|
||||||
size_t n_binders) override {
|
|
||||||
TORCH_CHECK(args.size() == 1);
|
|
||||||
Value* error_message = args.at(0).value(*caller.graph());
|
|
||||||
Value* qualified_class_name =
|
|
||||||
insertConstant(*caller.graph(), qualified_name_, loc);
|
|
||||||
return std::make_shared<ExceptionMessageValue>(
|
|
||||||
error_message, qualified_class_name);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::string qualified_name_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class SimpleResolver : public torch::jit::Resolver {
|
|
||||||
public:
|
|
||||||
explicit SimpleResolver() {}
|
|
||||||
|
|
||||||
std::shared_ptr<torch::jit::SugaredValue> resolveValue(
|
|
||||||
const std::string& name,
|
|
||||||
torch::jit::GraphFunction& m,
|
|
||||||
const torch::jit::SourceRange& loc) override {
|
|
||||||
// follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is
|
|
||||||
// a python extension. We can not add that as a cpp_binary's dep)
|
|
||||||
if (name == "SimpleValueError") {
|
|
||||||
py::object obj = py::globals()["SimpleValueError"];
|
|
||||||
return std::make_shared<MyPythonExceptionValue>(obj);
|
|
||||||
}
|
|
||||||
TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'");
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::jit::TypePtr resolveType(
|
|
||||||
const std::string& name,
|
|
||||||
const torch::jit::SourceRange& loc) override {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*
|
|
||||||
* - The python source code parsing for TorchScript here is learned from
|
|
||||||
* torch::jit::compile.
|
|
||||||
* - The code only parses one Def. If there are multiple in the code, those
|
|
||||||
* except the first one are skipped.
|
|
||||||
*/
|
|
||||||
TEST(TestException, TestCustomException) {
|
|
||||||
py::scoped_interpreter guard{};
|
|
||||||
py::exec(R"PY(
|
|
||||||
class SimpleValueError(ValueError):
|
|
||||||
def __init__(self, message):
|
|
||||||
super(SimpleValueError, self).__init__(message)
|
|
||||||
)PY");
|
|
||||||
|
|
||||||
std::string pythonCode = R"PY(
|
|
||||||
def foo():
|
|
||||||
raise SimpleValueError("An assertion failed")
|
|
||||||
)PY";
|
|
||||||
|
|
||||||
torch::jit::Parser p(
|
|
||||||
std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
|
|
||||||
auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
|
|
||||||
std::cerr << "Def is:\n" << def << std::endl;
|
|
||||||
auto cu = std::make_shared<torch::jit::CompilationUnit>();
|
|
||||||
(void)cu->define(
|
|
||||||
c10::nullopt,
|
|
||||||
{},
|
|
||||||
{},
|
|
||||||
{def},
|
|
||||||
// class PythonResolver is defined in
|
|
||||||
// torch/csrc/jit/python/script_init.cpp. It's not in a header file so I
|
|
||||||
// can not use it. Create a SimpleResolver insteand
|
|
||||||
{std::make_shared<SimpleResolver>()},
|
|
||||||
nullptr);
|
|
||||||
torch::jit::GraphFunction* gf =
|
|
||||||
(torch::jit::GraphFunction*)&cu->get_function("foo");
|
|
||||||
std::cerr << "Graph is\n" << *gf->graph() << std::endl;
|
|
||||||
bool is_jit_exception = false;
|
|
||||||
c10::optional<std::string> exception_class;
|
|
||||||
std::string message;
|
|
||||||
try {
|
|
||||||
cu->run_method("foo");
|
|
||||||
} catch (JITException& e) {
|
|
||||||
is_jit_exception = true;
|
|
||||||
exception_class = e.getPythonClassName();
|
|
||||||
message = e.what();
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(is_jit_exception);
|
|
||||||
EXPECT_EQ("__main__.SimpleValueError", *exception_class);
|
|
||||||
EXPECT_TRUE(
|
|
||||||
message.find("__main__.SimpleValueError: An assertion failed") !=
|
|
||||||
std::string::npos);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace jit
|
|
||||||
} // namespace torch
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
r"""
|
|
||||||
Define exceptions used in test_exception.py. We define them in a
|
|
||||||
separate file on purpose to make sure the fully qualified exception class name
|
|
||||||
is captured correctly in suce cases.
|
|
||||||
"""
|
|
||||||
class MyKeyError(KeyError):
|
|
||||||
def __init__(self, msg):
|
|
||||||
super(KeyError, self).__init__(msg)
|
|
||||||
|
|
@ -1,175 +0,0 @@
|
||||||
from torch.testing._internal.common_utils import TestCase
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Test TorchScript exception handling.
|
|
||||||
"""
|
|
||||||
class TestException(TestCase):
|
|
||||||
def test_assertions(self):
|
|
||||||
cu = torch.jit.CompilationUnit('''
|
|
||||||
def foo(cond):
|
|
||||||
assert bool(cond), "hi"
|
|
||||||
return 0
|
|
||||||
''')
|
|
||||||
|
|
||||||
cu.foo(torch.tensor(1))
|
|
||||||
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
|
||||||
cu.foo(torch.tensor(0))
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def foo(cond):
|
|
||||||
assert bool(cond), "hi"
|
|
||||||
|
|
||||||
foo(torch.tensor(1))
|
|
||||||
# we don't currently validate the name of the exception
|
|
||||||
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
|
||||||
foo(torch.tensor(0))
|
|
||||||
|
|
||||||
def test_pyop_exception_message(self):
|
|
||||||
class Foo(torch.jit.ScriptModule):
|
|
||||||
def __init__(self):
|
|
||||||
super(Foo, self).__init__()
|
|
||||||
self.conv = nn.Conv2d(1, 10, kernel_size=5)
|
|
||||||
|
|
||||||
@torch.jit.script_method
|
|
||||||
def forward(self, x):
|
|
||||||
return self.conv(x)
|
|
||||||
foo = Foo()
|
|
||||||
# testing that the correct error message propagates
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
|
|
||||||
foo(torch.ones([123])) # wrong size
|
|
||||||
|
|
||||||
def test_builtin_error_messsage(self):
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
|
|
||||||
@torch.jit.script
|
|
||||||
def close_match(x):
|
|
||||||
return x.masked_fill(True)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
|
|
||||||
"supported in TorchScript"):
|
|
||||||
@torch.jit.script
|
|
||||||
def unknown_op(x):
|
|
||||||
torch.set_anomaly_enabled(True)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def test_exceptions(self):
|
|
||||||
cu = torch.jit.CompilationUnit('''
|
|
||||||
def foo(cond):
|
|
||||||
if bool(cond):
|
|
||||||
raise ValueError(3)
|
|
||||||
return 1
|
|
||||||
''')
|
|
||||||
|
|
||||||
cu.foo(torch.tensor(0))
|
|
||||||
with self.assertRaisesRegex(torch.jit.Error, "3"):
|
|
||||||
cu.foo(torch.tensor(1))
|
|
||||||
|
|
||||||
def foo(cond):
|
|
||||||
a = 3
|
|
||||||
if bool(cond):
|
|
||||||
raise ArbitraryError(a, "hi")
|
|
||||||
if 1 == 2:
|
|
||||||
raise ArbitraryError
|
|
||||||
return a
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
|
|
||||||
torch.jit.script(foo)
|
|
||||||
|
|
||||||
def exception_as_value():
|
|
||||||
a = Exception()
|
|
||||||
print(a)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
|
|
||||||
torch.jit.script(exception_as_value)
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def foo_no_decl_always_throws():
|
|
||||||
raise RuntimeError("Hi")
|
|
||||||
|
|
||||||
# function that has no declared type but always throws set to None
|
|
||||||
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
|
|
||||||
self.assertTrue(str(output_type) == "NoneType")
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def foo_decl_always_throws():
|
|
||||||
# type: () -> Tensor
|
|
||||||
raise Exception("Hi")
|
|
||||||
|
|
||||||
output_type = next(foo_decl_always_throws.graph.outputs()).type()
|
|
||||||
self.assertTrue(str(output_type) == "Tensor")
|
|
||||||
|
|
||||||
def foo():
|
|
||||||
raise 3 + 4
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
|
|
||||||
torch.jit.script(foo)
|
|
||||||
|
|
||||||
# a escapes scope
|
|
||||||
@torch.jit.script
|
|
||||||
def foo():
|
|
||||||
if 1 == 1:
|
|
||||||
a = 1
|
|
||||||
else:
|
|
||||||
if 1 == 1:
|
|
||||||
raise Exception("Hi")
|
|
||||||
else:
|
|
||||||
raise Exception("Hi")
|
|
||||||
return a
|
|
||||||
self.assertEqual(foo(), 1)
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def tuple_fn():
|
|
||||||
raise RuntimeError("hello", "goodbye")
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
|
|
||||||
tuple_fn()
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def no_message():
|
|
||||||
raise RuntimeError
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
|
|
||||||
no_message()
|
|
||||||
|
|
||||||
def test_python_op_exception(self):
|
|
||||||
@torch.jit.ignore
|
|
||||||
def python_op(x):
|
|
||||||
raise Exception("bad!")
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def fn(x):
|
|
||||||
return python_op(x)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
|
|
||||||
fn(torch.tensor(4))
|
|
||||||
|
|
||||||
def test_dict_expansion_raises_error(self):
|
|
||||||
def fn(self):
|
|
||||||
d = {"foo": 1, "bar": 2, "baz": 3}
|
|
||||||
return {**d}
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
|
|
||||||
"Dict expansion "):
|
|
||||||
torch.jit.script(fn)
|
|
||||||
|
|
||||||
def test_custom_python_exception(self):
|
|
||||||
class MyValueError(ValueError):
|
|
||||||
def __init__(self, msg):
|
|
||||||
super(MyValueError, self).__init__(msg)
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def fn():
|
|
||||||
raise MyValueError("test custom exception")
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"):
|
|
||||||
fn()
|
|
||||||
|
|
||||||
def test_custom_python_exception_defined_elsewhere(self):
|
|
||||||
from jit.myexception import MyKeyError
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def fn():
|
|
||||||
raise MyKeyError("This is a user defined key error")
|
|
||||||
with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"):
|
|
||||||
fn()
|
|
||||||
148
test/test_jit.py
148
test/test_jit.py
|
|
@ -73,7 +73,6 @@ from jit.test_batch_mm import TestBatchMM # noqa: F401
|
||||||
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401
|
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401
|
||||||
from jit.test_dce import TestDCE # noqa: F401
|
from jit.test_dce import TestDCE # noqa: F401
|
||||||
from jit.test_sparse import TestSparse # noqa: F401
|
from jit.test_sparse import TestSparse # noqa: F401
|
||||||
from jit.test_exception import TestException # noqa: F401
|
|
||||||
|
|
||||||
# Torch
|
# Torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
@ -12994,6 +12993,153 @@ dedent """
|
||||||
|
|
||||||
self.checkScript(dedent(code), (101,))
|
self.checkScript(dedent(code), (101,))
|
||||||
|
|
||||||
|
def test_pyop_exception_message(self):
|
||||||
|
class Foo(torch.jit.ScriptModule):
|
||||||
|
def __init__(self):
|
||||||
|
super(Foo, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(1, 10, kernel_size=5)
|
||||||
|
|
||||||
|
@torch.jit.script_method
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv(x)
|
||||||
|
foo = Foo()
|
||||||
|
# testing that the correct error message propagates
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
|
||||||
|
foo(torch.ones([123])) # wrong size
|
||||||
|
|
||||||
|
def test_builtin_error_messsage(self):
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
|
||||||
|
@torch.jit.script
|
||||||
|
def close_match(x):
|
||||||
|
return x.masked_fill(True)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
|
||||||
|
"supported in TorchScript"):
|
||||||
|
@torch.jit.script
|
||||||
|
def unknown_op(x):
|
||||||
|
torch.set_anomaly_enabled(True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_exceptions(self):
|
||||||
|
cu = torch.jit.CompilationUnit('''
|
||||||
|
def foo(cond):
|
||||||
|
if bool(cond):
|
||||||
|
raise ValueError(3)
|
||||||
|
return 1
|
||||||
|
''')
|
||||||
|
|
||||||
|
cu.foo(torch.tensor(0))
|
||||||
|
with self.assertRaisesRegex(torch.jit.Error, "3"):
|
||||||
|
cu.foo(torch.tensor(1))
|
||||||
|
|
||||||
|
def foo(cond):
|
||||||
|
a = 3
|
||||||
|
if bool(cond):
|
||||||
|
raise ArbitraryError(a, "hi")
|
||||||
|
if 1 == 2:
|
||||||
|
raise ArbitraryError
|
||||||
|
return a
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
|
||||||
|
torch.jit.script(foo)
|
||||||
|
|
||||||
|
def exception_as_value():
|
||||||
|
a = Exception()
|
||||||
|
print(a)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
|
||||||
|
torch.jit.script(exception_as_value)
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def foo_no_decl_always_throws():
|
||||||
|
raise RuntimeError("Hi")
|
||||||
|
|
||||||
|
# function that has no declared type but always throws set to None
|
||||||
|
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
|
||||||
|
self.assertTrue(str(output_type) == "NoneType")
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def foo_decl_always_throws():
|
||||||
|
# type: () -> Tensor
|
||||||
|
raise Exception("Hi")
|
||||||
|
|
||||||
|
output_type = next(foo_decl_always_throws.graph.outputs()).type()
|
||||||
|
self.assertTrue(str(output_type) == "Tensor")
|
||||||
|
|
||||||
|
def foo():
|
||||||
|
raise 3 + 4
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
|
||||||
|
torch.jit.script(foo)
|
||||||
|
|
||||||
|
# a escapes scope
|
||||||
|
@torch.jit.script
|
||||||
|
def foo():
|
||||||
|
if 1 == 1:
|
||||||
|
a = 1
|
||||||
|
else:
|
||||||
|
if 1 == 1:
|
||||||
|
raise Exception("Hi")
|
||||||
|
else:
|
||||||
|
raise Exception("Hi")
|
||||||
|
return a
|
||||||
|
self.assertEqual(foo(), 1)
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def tuple_fn():
|
||||||
|
raise RuntimeError("hello", "goodbye")
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
|
||||||
|
tuple_fn()
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def no_message():
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
|
||||||
|
no_message()
|
||||||
|
|
||||||
|
def test_assertions(self):
|
||||||
|
cu = torch.jit.CompilationUnit('''
|
||||||
|
def foo(cond):
|
||||||
|
assert bool(cond), "hi"
|
||||||
|
return 0
|
||||||
|
''')
|
||||||
|
|
||||||
|
cu.foo(torch.tensor(1))
|
||||||
|
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
||||||
|
cu.foo(torch.tensor(0))
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def foo(cond):
|
||||||
|
assert bool(cond), "hi"
|
||||||
|
|
||||||
|
foo(torch.tensor(1))
|
||||||
|
# we don't currently validate the name of the exception
|
||||||
|
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
||||||
|
foo(torch.tensor(0))
|
||||||
|
|
||||||
|
def test_python_op_exception(self):
|
||||||
|
@torch.jit.ignore
|
||||||
|
def python_op(x):
|
||||||
|
raise Exception("bad!")
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def fn(x):
|
||||||
|
return python_op(x)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
|
||||||
|
fn(torch.tensor(4))
|
||||||
|
|
||||||
|
def test_dict_expansion_raises_error(self):
|
||||||
|
def fn(self):
|
||||||
|
d = {"foo": 1, "bar": 2, "baz": 3}
|
||||||
|
return {**d}
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
|
||||||
|
"Dict expansion "):
|
||||||
|
torch.jit.script(fn)
|
||||||
|
|
||||||
def test_module_parameters_and_buffers(self):
|
def test_module_parameters_and_buffers(self):
|
||||||
weights = torch.randn(10, 10)
|
weights = torch.randn(10, 10)
|
||||||
bias = torch.randn(10)
|
bias = torch.randn(10)
|
||||||
|
|
|
||||||
|
|
@ -977,7 +977,7 @@ def is_scripting() -> bool:
|
||||||
|
|
||||||
|
|
||||||
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
|
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
|
||||||
def _qualified_name(obj, mangle_name=True) -> str:
|
def _qualified_name(obj) -> str:
|
||||||
# This special case allows us to override the qualified name on a type.
|
# This special case allows us to override the qualified name on a type.
|
||||||
# It's currently used in conjunction with tracing, where we create a
|
# It's currently used in conjunction with tracing, where we create a
|
||||||
# fake module to filter only supported attributes. However, since this
|
# fake module to filter only supported attributes. However, since this
|
||||||
|
|
@ -1026,16 +1026,13 @@ def _qualified_name(obj, mangle_name=True) -> str:
|
||||||
module_name = module_name.replace("<", "_")
|
module_name = module_name.replace("<", "_")
|
||||||
module_name = module_name.replace(">", "_")
|
module_name = module_name.replace(">", "_")
|
||||||
|
|
||||||
# The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
|
# __main__ is a builtin module, so rewrite it to "__torch__".
|
||||||
# does not need mangle the python class name.
|
if module_name == "__main__":
|
||||||
if mangle_name:
|
module_name = "__torch__"
|
||||||
# __main__ is a builtin module, so rewrite it to "__torch__".
|
else:
|
||||||
if module_name == "__main__":
|
# Everything else gets a "__torch__" prefix to avoid name collisions
|
||||||
module_name = "__torch__"
|
# with the names of user values.
|
||||||
else:
|
module_name = "__torch__." + module_name
|
||||||
# Everything else gets a "__torch__" prefix to avoid name collisions
|
|
||||||
# with the names of user values.
|
|
||||||
module_name = "__torch__." + module_name
|
|
||||||
|
|
||||||
if "." in name:
|
if "." in name:
|
||||||
raise RuntimeError(f"Could not get qualified name for class '{name}': "
|
raise RuntimeError(f"Could not get qualified name for class '{name}': "
|
||||||
|
|
|
||||||
|
|
@ -2469,14 +2469,12 @@ struct to_ir {
|
||||||
void emitRaise(const Raise& raise) {
|
void emitRaise(const Raise& raise) {
|
||||||
auto sv = emitSugaredExpr(raise.expr(), 1);
|
auto sv = emitSugaredExpr(raise.expr(), 1);
|
||||||
Value* error_message = nullptr;
|
Value* error_message = nullptr;
|
||||||
Value* qualified_class_name = nullptr;
|
|
||||||
|
|
||||||
if (auto exception_instance =
|
if (auto exception_instance =
|
||||||
std::dynamic_pointer_cast<ExceptionMessageValue>(sv)) {
|
std::dynamic_pointer_cast<ExceptionMessageValue>(sv)) {
|
||||||
// The typical case, an instance of the exception class was thrown:
|
// The typical case, an instance of the exception class was thrown:
|
||||||
// raise RuntimeError("error")
|
// raise RuntimeError("error")
|
||||||
error_message = exception_instance->getValue();
|
error_message = exception_instance->getValue();
|
||||||
qualified_class_name = exception_instance->getQualifiedClassName();
|
|
||||||
} else if (
|
} else if (
|
||||||
auto exception_class = std::dynamic_pointer_cast<ExceptionValue>(sv)) {
|
auto exception_class = std::dynamic_pointer_cast<ExceptionValue>(sv)) {
|
||||||
// A bare exception was thrown so add an empty message. e.g.
|
// A bare exception was thrown so add an empty message. e.g.
|
||||||
|
|
@ -2493,11 +2491,7 @@ struct to_ir {
|
||||||
error_message = graph->insert(aten::str, {error_message});
|
error_message = graph->insert(aten::str, {error_message});
|
||||||
}
|
}
|
||||||
|
|
||||||
graph->insert(
|
graph->insert(prim::RaiseException, {error_message}, {}, raise.range());
|
||||||
prim::RaiseException,
|
|
||||||
{error_message, qualified_class_name},
|
|
||||||
{},
|
|
||||||
raise.range());
|
|
||||||
exit_blocks.insert(environment_stack->block());
|
exit_blocks.insert(environment_stack->block());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -741,10 +741,7 @@ struct SimpleSelf : public Self {
|
||||||
// This is not a SimpleValue so it can not pass through the code paths that
|
// This is not a SimpleValue so it can not pass through the code paths that
|
||||||
// expect a SimpleValue as a sugared value.
|
// expect a SimpleValue as a sugared value.
|
||||||
struct TORCH_API ExceptionMessageValue : public SugaredValue {
|
struct TORCH_API ExceptionMessageValue : public SugaredValue {
|
||||||
explicit ExceptionMessageValue(
|
explicit ExceptionMessageValue(Value* value) : value_(value) {}
|
||||||
Value* value,
|
|
||||||
Value* qualified_class_name = nullptr)
|
|
||||||
: value_(value), qualified_class_name_(qualified_class_name) {}
|
|
||||||
|
|
||||||
std::string kind() const override {
|
std::string kind() const override {
|
||||||
return "exception message";
|
return "exception message";
|
||||||
|
|
@ -754,14 +751,7 @@ struct TORCH_API ExceptionMessageValue : public SugaredValue {
|
||||||
return value_;
|
return value_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// qualified python class name
|
|
||||||
Value* getQualifiedClassName() {
|
|
||||||
return qualified_class_name_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
Value* value_;
|
Value* value_;
|
||||||
Value* qualified_class_name_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TORCH_API ExceptionValue : public SugaredValue {
|
struct TORCH_API ExceptionValue : public SugaredValue {
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,7 @@ void tupleIndex(Stack& stack) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void raiseException(Stack& stack) {
|
void raiseException(Stack& stack) {
|
||||||
c10::optional<std::string> qualified_class_name =
|
throw JITException(pop(stack).toStringRef());
|
||||||
pop(stack).toOptional<std::string>();
|
|
||||||
std::string message;
|
|
||||||
pop(stack, message);
|
|
||||||
|
|
||||||
throw JITException(message, qualified_class_name);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void is(Stack& stack) {
|
void is(Stack& stack) {
|
||||||
|
|
|
||||||
|
|
@ -914,11 +914,8 @@ std::shared_ptr<SugaredValue> PythonExceptionValue::call(
|
||||||
->insertNode(caller.graph()->createTuple(message_values))
|
->insertNode(caller.graph()->createTuple(message_values))
|
||||||
->output();
|
->output();
|
||||||
}
|
}
|
||||||
Value* qualified_class_name =
|
|
||||||
insertConstant(*caller.graph(), exception_class_qualified_name_, loc);
|
|
||||||
|
|
||||||
return std::make_shared<ExceptionMessageValue>(
|
return std::make_shared<ExceptionMessageValue>(error_message);
|
||||||
error_message, qualified_class_name);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isNamedTupleClass(const py::object& obj) {
|
bool isNamedTupleClass(const py::object& obj) {
|
||||||
|
|
|
||||||
|
|
@ -328,12 +328,7 @@ struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
|
||||||
struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
|
struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
|
||||||
explicit PythonExceptionValue(const py::object& exception_class)
|
explicit PythonExceptionValue(const py::object& exception_class)
|
||||||
: ExceptionValue(
|
: ExceptionValue(
|
||||||
py::str(py::getattr(exception_class, "__name__", py::str("")))),
|
py::str(py::getattr(exception_class, "__name__", py::str("")))) {}
|
||||||
exception_class_qualified_name_(
|
|
||||||
py::str(py::module::import("torch._jit_internal")
|
|
||||||
.attr("_qualified_name")(
|
|
||||||
exception_class,
|
|
||||||
/*mangle_name=*/false))) {}
|
|
||||||
|
|
||||||
std::string kind() const override {
|
std::string kind() const override {
|
||||||
return "Python exception";
|
return "Python exception";
|
||||||
|
|
@ -345,9 +340,6 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
|
||||||
at::ArrayRef<NamedValue> args,
|
at::ArrayRef<NamedValue> args,
|
||||||
at::ArrayRef<NamedValue> kwargs,
|
at::ArrayRef<NamedValue> kwargs,
|
||||||
size_t n_binders) override;
|
size_t n_binders) override;
|
||||||
|
|
||||||
private:
|
|
||||||
std::string exception_class_qualified_name_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Python Slice class.
|
// Python Slice class.
|
||||||
|
|
|
||||||
|
|
@ -714,19 +714,10 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||||
}
|
}
|
||||||
throw;
|
throw;
|
||||||
}
|
}
|
||||||
auto* jit_exception = dynamic_cast<JITException*>(&e);
|
bool is_jit_exception = dynamic_cast<JITException*>(&e);
|
||||||
// Janky af. See https://github.com/pytorch/pytorch/issues/54612
|
// Janky af. See https://github.com/pytorch/pytorch/issues/54612
|
||||||
auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
|
auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
|
||||||
|
handleError(ExceptionMessage(e), is_jit_exception, not_implemented_error);
|
||||||
c10::optional<std::string> python_class_name;
|
|
||||||
if (jit_exception) {
|
|
||||||
python_class_name = jit_exception->getPythonClassName();
|
|
||||||
}
|
|
||||||
handleError(
|
|
||||||
ExceptionMessage(e),
|
|
||||||
(bool)jit_exception,
|
|
||||||
not_implemented_error,
|
|
||||||
python_class_name);
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -745,18 +736,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||||
void handleError(
|
void handleError(
|
||||||
const ExceptionMessage& msg,
|
const ExceptionMessage& msg,
|
||||||
bool is_jit_exception,
|
bool is_jit_exception,
|
||||||
c10::NotImplementedError* not_implemented_error,
|
c10::NotImplementedError* not_implemented_error) {
|
||||||
c10::optional<std::string> python_class_name) {
|
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
std::string class_name =
|
|
||||||
python_class_name ? *python_class_name : "RuntimeError";
|
|
||||||
ss << "The following operation failed in the TorchScript interpreter.\n";
|
ss << "The following operation failed in the TorchScript interpreter.\n";
|
||||||
formatStackTrace(ss);
|
formatStackTrace(ss);
|
||||||
ss << class_name << ": " << msg << "\n";
|
ss << "RuntimeError: " << msg << "\n";
|
||||||
if (future_) {
|
if (future_) {
|
||||||
future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
|
future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
|
||||||
} else if (is_jit_exception) {
|
} else if (is_jit_exception) {
|
||||||
throw JITException(ss.str(), python_class_name);
|
throw JITException(ss.str());
|
||||||
} else if (not_implemented_error) {
|
} else if (not_implemented_error) {
|
||||||
throw c10::NotImplementedError(
|
throw c10::NotImplementedError(
|
||||||
ss.str(),
|
ss.str(),
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,7 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
JITException::JITException(
|
JITException::JITException(const std::string& msg) : std::runtime_error(msg) {}
|
||||||
const std::string& msg,
|
|
||||||
c10::optional<std::string> python_class_name)
|
|
||||||
: std::runtime_error(msg),
|
|
||||||
python_class_name_(std::move(python_class_name)) {}
|
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
|
|
@ -2,24 +2,13 @@
|
||||||
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
#include <c10/util/Optional.h>
|
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
struct TORCH_API JITException : public std::runtime_error {
|
struct TORCH_API JITException : public std::runtime_error {
|
||||||
explicit JITException(
|
explicit JITException(const std::string& msg);
|
||||||
const std::string& msg,
|
|
||||||
c10::optional<std::string> python_class_name = c10::nullopt);
|
|
||||||
|
|
||||||
c10::optional<std::string> getPythonClassName() const {
|
|
||||||
return python_class_name_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
c10::optional<std::string> python_class_name_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
|
|
|
||||||
|
|
@ -406,8 +406,7 @@ static const OperatorGeneratorArgs opGenArgs[] = {
|
||||||
numToTensorScalar,
|
numToTensorScalar,
|
||||||
aliasAnalysisFromSchema()),
|
aliasAnalysisFromSchema()),
|
||||||
OperatorGeneratorArgs(
|
OperatorGeneratorArgs(
|
||||||
TORCH_SELECTIVE_SCHEMA(
|
TORCH_SELECTIVE_SCHEMA("prim::RaiseException(str msg) -> ()"),
|
||||||
"prim::RaiseException(str msg, str? cls=None) -> ()"),
|
|
||||||
raiseException,
|
raiseException,
|
||||||
aliasAnalysisFromSchema()),
|
aliasAnalysisFromSchema()),
|
||||||
OperatorGeneratorArgs(
|
OperatorGeneratorArgs(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user