diff --git a/test/cpp/jit/test_exception.cpp b/test/cpp/jit/test_exception.cpp deleted file mode 100644 index b6b3cbcd679..00000000000 --- a/test/cpp/jit/test_exception.cpp +++ /dev/null @@ -1,159 +0,0 @@ -/* - * We have a python unit test for exceptions in test/jit/test_exception.py . - * Add a CPP version here to verify that excepted exception types thrown from - * C++. This is hard to test in python code since C++ exceptions will be - * translated to python exceptions. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -namespace py = pybind11; - -TEST(TestException, TestAssertion) { - std::string pythonCode = R"PY( - def foo(): - raise AssertionError("An assertion failed") - )PY"; - auto cu_ptr = torch::jit::compile(pythonCode); - torch::jit::GraphFunction* gf = - (torch::jit::GraphFunction*)&cu_ptr->get_function("foo"); - std::cerr << "Graph is\n" << *gf->graph() << std::endl; - - bool is_jit_exception = false; - std::string message; - c10::optional exception_class; - try { - cu_ptr->run_method("foo"); - } catch (JITException& e) { - is_jit_exception = true; - message = e.what(); - exception_class = e.getPythonClassName(); - } - EXPECT_TRUE(is_jit_exception); - EXPECT_FALSE(exception_class); - EXPECT_TRUE( - message.find("RuntimeError: AssertionError: An assertion failed") != - std::string::npos); -} - -struct MyPythonExceptionValue : public torch::jit::SugaredValue { - explicit MyPythonExceptionValue(const py::object& exception_class) { - qualified_name_ = - (py::str(py::getattr(exception_class, "__module__", py::str(""))) + - py::str(".") + - py::str(py::getattr(exception_class, "__name__", py::str("")))) - .cast(); - } - - std::string kind() const override { - return "My Python exception"; - } - - // Simplified from PythonExceptionValue::call - std::shared_ptr call( - const torch::jit::SourceRange& loc, - torch::jit::GraphFunction& caller, - at::ArrayRef args, - at::ArrayRef kwargs, - size_t n_binders) override { - TORCH_CHECK(args.size() == 1); - Value* error_message = args.at(0).value(*caller.graph()); - Value* qualified_class_name = - insertConstant(*caller.graph(), qualified_name_, loc); - return std::make_shared( - error_message, qualified_class_name); - } - - private: - std::string qualified_name_; -}; - -class SimpleResolver : public torch::jit::Resolver { - public: - explicit SimpleResolver() {} - - std::shared_ptr resolveValue( - const std::string& name, - torch::jit::GraphFunction& m, - const torch::jit::SourceRange& loc) override { - // follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is - // a python extension. We can not add that as a cpp_binary's dep) - if (name == "SimpleValueError") { - py::object obj = py::globals()["SimpleValueError"]; - return std::make_shared(obj); - } - TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'"); - } - - torch::jit::TypePtr resolveType( - const std::string& name, - const torch::jit::SourceRange& loc) override { - return nullptr; - } -}; - -/* - * - The python source code parsing for TorchScript here is learned from - * torch::jit::compile. - * - The code only parses one Def. If there are multiple in the code, those - * except the first one are skipped. - */ -TEST(TestException, TestCustomException) { - py::scoped_interpreter guard{}; - py::exec(R"PY( - class SimpleValueError(ValueError): - def __init__(self, message): - super(SimpleValueError, self).__init__(message) - )PY"); - - std::string pythonCode = R"PY( - def foo(): - raise SimpleValueError("An assertion failed") - )PY"; - - torch::jit::Parser p( - std::make_shared(pythonCode, "", 1)); - auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false)); - std::cerr << "Def is:\n" << def << std::endl; - auto cu = std::make_shared(); - (void)cu->define( - c10::nullopt, - {}, - {}, - {def}, - // class PythonResolver is defined in - // torch/csrc/jit/python/script_init.cpp. It's not in a header file so I - // can not use it. Create a SimpleResolver insteand - {std::make_shared()}, - nullptr); - torch::jit::GraphFunction* gf = - (torch::jit::GraphFunction*)&cu->get_function("foo"); - std::cerr << "Graph is\n" << *gf->graph() << std::endl; - bool is_jit_exception = false; - c10::optional exception_class; - std::string message; - try { - cu->run_method("foo"); - } catch (JITException& e) { - is_jit_exception = true; - exception_class = e.getPythonClassName(); - message = e.what(); - } - EXPECT_TRUE(is_jit_exception); - EXPECT_EQ("__main__.SimpleValueError", *exception_class); - EXPECT_TRUE( - message.find("__main__.SimpleValueError: An assertion failed") != - std::string::npos); -} - -} // namespace jit -} // namespace torch diff --git a/test/jit/myexception.py b/test/jit/myexception.py deleted file mode 100644 index 5937bd3c91b..00000000000 --- a/test/jit/myexception.py +++ /dev/null @@ -1,8 +0,0 @@ -r""" -Define exceptions used in test_exception.py. We define them in a -separate file on purpose to make sure the fully qualified exception class name -is captured correctly in suce cases. -""" -class MyKeyError(KeyError): - def __init__(self, msg): - super(KeyError, self).__init__(msg) diff --git a/test/jit/test_exception.py b/test/jit/test_exception.py deleted file mode 100644 index 4365bfeb39b..00000000000 --- a/test/jit/test_exception.py +++ /dev/null @@ -1,175 +0,0 @@ -from torch.testing._internal.common_utils import TestCase -import torch -from torch import nn - -r""" -Test TorchScript exception handling. -""" -class TestException(TestCase): - def test_assertions(self): - cu = torch.jit.CompilationUnit(''' - def foo(cond): - assert bool(cond), "hi" - return 0 - ''') - - cu.foo(torch.tensor(1)) - with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): - cu.foo(torch.tensor(0)) - - @torch.jit.script - def foo(cond): - assert bool(cond), "hi" - - foo(torch.tensor(1)) - # we don't currently validate the name of the exception - with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): - foo(torch.tensor(0)) - - def test_pyop_exception_message(self): - class Foo(torch.jit.ScriptModule): - def __init__(self): - super(Foo, self).__init__() - self.conv = nn.Conv2d(1, 10, kernel_size=5) - - @torch.jit.script_method - def forward(self, x): - return self.conv(x) - foo = Foo() - # testing that the correct error message propagates - with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"): - foo(torch.ones([123])) # wrong size - - def test_builtin_error_messsage(self): - with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"): - @torch.jit.script - def close_match(x): - return x.masked_fill(True) - - with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently " - "supported in TorchScript"): - @torch.jit.script - def unknown_op(x): - torch.set_anomaly_enabled(True) - return x - - def test_exceptions(self): - cu = torch.jit.CompilationUnit(''' - def foo(cond): - if bool(cond): - raise ValueError(3) - return 1 - ''') - - cu.foo(torch.tensor(0)) - with self.assertRaisesRegex(torch.jit.Error, "3"): - cu.foo(torch.tensor(1)) - - def foo(cond): - a = 3 - if bool(cond): - raise ArbitraryError(a, "hi") - if 1 == 2: - raise ArbitraryError - return a - - with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"): - torch.jit.script(foo) - - def exception_as_value(): - a = Exception() - print(a) - - with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"): - torch.jit.script(exception_as_value) - - @torch.jit.script - def foo_no_decl_always_throws(): - raise RuntimeError("Hi") - - # function that has no declared type but always throws set to None - output_type = next(foo_no_decl_always_throws.graph.outputs()).type() - self.assertTrue(str(output_type) == "NoneType") - - @torch.jit.script - def foo_decl_always_throws(): - # type: () -> Tensor - raise Exception("Hi") - - output_type = next(foo_decl_always_throws.graph.outputs()).type() - self.assertTrue(str(output_type) == "Tensor") - - def foo(): - raise 3 + 4 - - with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"): - torch.jit.script(foo) - - # a escapes scope - @torch.jit.script - def foo(): - if 1 == 1: - a = 1 - else: - if 1 == 1: - raise Exception("Hi") - else: - raise Exception("Hi") - return a - self.assertEqual(foo(), 1) - - @torch.jit.script - def tuple_fn(): - raise RuntimeError("hello", "goodbye") - - with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"): - tuple_fn() - - @torch.jit.script - def no_message(): - raise RuntimeError - - with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"): - no_message() - - def test_python_op_exception(self): - @torch.jit.ignore - def python_op(x): - raise Exception("bad!") - - @torch.jit.script - def fn(x): - return python_op(x) - - with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"): - fn(torch.tensor(4)) - - def test_dict_expansion_raises_error(self): - def fn(self): - d = {"foo": 1, "bar": 2, "baz": 3} - return {**d} - - with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, - "Dict expansion "): - torch.jit.script(fn) - - def test_custom_python_exception(self): - class MyValueError(ValueError): - def __init__(self, msg): - super(MyValueError, self).__init__(msg) - - @torch.jit.script - def fn(): - raise MyValueError("test custom exception") - - with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"): - fn() - - def test_custom_python_exception_defined_elsewhere(self): - from jit.myexception import MyKeyError - - @torch.jit.script - def fn(): - raise MyKeyError("This is a user defined key error") - with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"): - fn() diff --git a/test/test_jit.py b/test/test_jit.py index 9cdaaf78961..ffa614ff660 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -73,7 +73,6 @@ from jit.test_batch_mm import TestBatchMM # noqa: F401 from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401 from jit.test_dce import TestDCE # noqa: F401 from jit.test_sparse import TestSparse # noqa: F401 -from jit.test_exception import TestException # noqa: F401 # Torch from torch import Tensor @@ -12994,6 +12993,153 @@ dedent """ self.checkScript(dedent(code), (101,)) + def test_pyop_exception_message(self): + class Foo(torch.jit.ScriptModule): + def __init__(self): + super(Foo, self).__init__() + self.conv = nn.Conv2d(1, 10, kernel_size=5) + + @torch.jit.script_method + def forward(self, x): + return self.conv(x) + foo = Foo() + # testing that the correct error message propagates + with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"): + foo(torch.ones([123])) # wrong size + + def test_builtin_error_messsage(self): + with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"): + @torch.jit.script + def close_match(x): + return x.masked_fill(True) + + with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently " + "supported in TorchScript"): + @torch.jit.script + def unknown_op(x): + torch.set_anomaly_enabled(True) + return x + + def test_exceptions(self): + cu = torch.jit.CompilationUnit(''' + def foo(cond): + if bool(cond): + raise ValueError(3) + return 1 + ''') + + cu.foo(torch.tensor(0)) + with self.assertRaisesRegex(torch.jit.Error, "3"): + cu.foo(torch.tensor(1)) + + def foo(cond): + a = 3 + if bool(cond): + raise ArbitraryError(a, "hi") + if 1 == 2: + raise ArbitraryError + return a + + with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"): + torch.jit.script(foo) + + def exception_as_value(): + a = Exception() + print(a) + + with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"): + torch.jit.script(exception_as_value) + + @torch.jit.script + def foo_no_decl_always_throws(): + raise RuntimeError("Hi") + + # function that has no declared type but always throws set to None + output_type = next(foo_no_decl_always_throws.graph.outputs()).type() + self.assertTrue(str(output_type) == "NoneType") + + @torch.jit.script + def foo_decl_always_throws(): + # type: () -> Tensor + raise Exception("Hi") + + output_type = next(foo_decl_always_throws.graph.outputs()).type() + self.assertTrue(str(output_type) == "Tensor") + + def foo(): + raise 3 + 4 + + with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"): + torch.jit.script(foo) + + # a escapes scope + @torch.jit.script + def foo(): + if 1 == 1: + a = 1 + else: + if 1 == 1: + raise Exception("Hi") + else: + raise Exception("Hi") + return a + self.assertEqual(foo(), 1) + + @torch.jit.script + def tuple_fn(): + raise RuntimeError("hello", "goodbye") + + with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"): + tuple_fn() + + @torch.jit.script + def no_message(): + raise RuntimeError + + with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"): + no_message() + + def test_assertions(self): + cu = torch.jit.CompilationUnit(''' + def foo(cond): + assert bool(cond), "hi" + return 0 + ''') + + cu.foo(torch.tensor(1)) + with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): + cu.foo(torch.tensor(0)) + + @torch.jit.script + def foo(cond): + assert bool(cond), "hi" + + foo(torch.tensor(1)) + # we don't currently validate the name of the exception + with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): + foo(torch.tensor(0)) + + def test_python_op_exception(self): + @torch.jit.ignore + def python_op(x): + raise Exception("bad!") + + @torch.jit.script + def fn(x): + return python_op(x) + + with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"): + fn(torch.tensor(4)) + + def test_dict_expansion_raises_error(self): + def fn(self): + d = {"foo": 1, "bar": 2, "baz": 3} + return {**d} + + with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, + "Dict expansion "): + torch.jit.script(fn) + def test_module_parameters_and_buffers(self): weights = torch.randn(10, 10) bias = torch.randn(10) diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 3f66d66de5e..9373964648f 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -977,7 +977,7 @@ def is_scripting() -> bool: # Retrieves a fully-qualified name (module hierarchy + classname) for a given obj. -def _qualified_name(obj, mangle_name=True) -> str: +def _qualified_name(obj) -> str: # This special case allows us to override the qualified name on a type. # It's currently used in conjunction with tracing, where we create a # fake module to filter only supported attributes. However, since this @@ -1026,16 +1026,13 @@ def _qualified_name(obj, mangle_name=True) -> str: module_name = module_name.replace("<", "_") module_name = module_name.replace(">", "_") - # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h - # does not need mangle the python class name. - if mangle_name: - # __main__ is a builtin module, so rewrite it to "__torch__". - if module_name == "__main__": - module_name = "__torch__" - else: - # Everything else gets a "__torch__" prefix to avoid name collisions - # with the names of user values. - module_name = "__torch__." + module_name + # __main__ is a builtin module, so rewrite it to "__torch__". + if module_name == "__main__": + module_name = "__torch__" + else: + # Everything else gets a "__torch__" prefix to avoid name collisions + # with the names of user values. + module_name = "__torch__." + module_name if "." in name: raise RuntimeError(f"Could not get qualified name for class '{name}': " diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 5f7b44f1c2b..9cc51bc6155 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -2469,14 +2469,12 @@ struct to_ir { void emitRaise(const Raise& raise) { auto sv = emitSugaredExpr(raise.expr(), 1); Value* error_message = nullptr; - Value* qualified_class_name = nullptr; if (auto exception_instance = std::dynamic_pointer_cast(sv)) { // The typical case, an instance of the exception class was thrown: // raise RuntimeError("error") error_message = exception_instance->getValue(); - qualified_class_name = exception_instance->getQualifiedClassName(); } else if ( auto exception_class = std::dynamic_pointer_cast(sv)) { // A bare exception was thrown so add an empty message. e.g. @@ -2493,11 +2491,7 @@ struct to_ir { error_message = graph->insert(aten::str, {error_message}); } - graph->insert( - prim::RaiseException, - {error_message, qualified_class_name}, - {}, - raise.range()); + graph->insert(prim::RaiseException, {error_message}, {}, raise.range()); exit_blocks.insert(environment_stack->block()); } diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 7504fb69f62..91a2d3e4fbf 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -741,10 +741,7 @@ struct SimpleSelf : public Self { // This is not a SimpleValue so it can not pass through the code paths that // expect a SimpleValue as a sugared value. struct TORCH_API ExceptionMessageValue : public SugaredValue { - explicit ExceptionMessageValue( - Value* value, - Value* qualified_class_name = nullptr) - : value_(value), qualified_class_name_(qualified_class_name) {} + explicit ExceptionMessageValue(Value* value) : value_(value) {} std::string kind() const override { return "exception message"; @@ -754,14 +751,7 @@ struct TORCH_API ExceptionMessageValue : public SugaredValue { return value_; } - // qualified python class name - Value* getQualifiedClassName() { - return qualified_class_name_; - } - - private: Value* value_; - Value* qualified_class_name_; }; struct TORCH_API ExceptionValue : public SugaredValue { diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.cpp b/torch/csrc/jit/mobile/promoted_prim_ops.cpp index fba28d15e3a..50d1009b348 100644 --- a/torch/csrc/jit/mobile/promoted_prim_ops.cpp +++ b/torch/csrc/jit/mobile/promoted_prim_ops.cpp @@ -14,12 +14,7 @@ void tupleIndex(Stack& stack) { } void raiseException(Stack& stack) { - c10::optional qualified_class_name = - pop(stack).toOptional(); - std::string message; - pop(stack, message); - - throw JITException(message, qualified_class_name); + throw JITException(pop(stack).toStringRef()); } void is(Stack& stack) { diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index d5d97422a5c..1a4ac0370ac 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -914,11 +914,8 @@ std::shared_ptr PythonExceptionValue::call( ->insertNode(caller.graph()->createTuple(message_values)) ->output(); } - Value* qualified_class_name = - insertConstant(*caller.graph(), exception_class_qualified_name_, loc); - return std::make_shared( - error_message, qualified_class_name); + return std::make_shared(error_message); } bool isNamedTupleClass(const py::object& obj) { diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index 5fef124cf2b..d3559abda5c 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -328,12 +328,7 @@ struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue { struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue { explicit PythonExceptionValue(const py::object& exception_class) : ExceptionValue( - py::str(py::getattr(exception_class, "__name__", py::str("")))), - exception_class_qualified_name_( - py::str(py::module::import("torch._jit_internal") - .attr("_qualified_name")( - exception_class, - /*mangle_name=*/false))) {} + py::str(py::getattr(exception_class, "__name__", py::str("")))) {} std::string kind() const override { return "Python exception"; @@ -345,9 +340,6 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue { at::ArrayRef args, at::ArrayRef kwargs, size_t n_binders) override; - - private: - std::string exception_class_qualified_name_; }; // Python Slice class. diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index da593b9e3e6..3351ba6f8d8 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -714,19 +714,10 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } throw; } - auto* jit_exception = dynamic_cast(&e); + bool is_jit_exception = dynamic_cast(&e); // Janky af. See https://github.com/pytorch/pytorch/issues/54612 auto* not_implemented_error = dynamic_cast(&e); - - c10::optional python_class_name; - if (jit_exception) { - python_class_name = jit_exception->getPythonClassName(); - } - handleError( - ExceptionMessage(e), - (bool)jit_exception, - not_implemented_error, - python_class_name); + handleError(ExceptionMessage(e), is_jit_exception, not_implemented_error); return false; } } @@ -745,18 +736,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { void handleError( const ExceptionMessage& msg, bool is_jit_exception, - c10::NotImplementedError* not_implemented_error, - c10::optional python_class_name) { + c10::NotImplementedError* not_implemented_error) { std::ostringstream ss; - std::string class_name = - python_class_name ? *python_class_name : "RuntimeError"; ss << "The following operation failed in the TorchScript interpreter.\n"; formatStackTrace(ss); - ss << class_name << ": " << msg << "\n"; + ss << "RuntimeError: " << msg << "\n"; if (future_) { future_->setError(std::make_exception_ptr(Future::FutureError(ss.str()))); } else if (is_jit_exception) { - throw JITException(ss.str(), python_class_name); + throw JITException(ss.str()); } else if (not_implemented_error) { throw c10::NotImplementedError( ss.str(), diff --git a/torch/csrc/jit/runtime/jit_exception.cpp b/torch/csrc/jit/runtime/jit_exception.cpp index 600b92a111a..7e39fd26f5b 100644 --- a/torch/csrc/jit/runtime/jit_exception.cpp +++ b/torch/csrc/jit/runtime/jit_exception.cpp @@ -3,11 +3,7 @@ namespace torch { namespace jit { -JITException::JITException( - const std::string& msg, - c10::optional python_class_name) - : std::runtime_error(msg), - python_class_name_(std::move(python_class_name)) {} +JITException::JITException(const std::string& msg) : std::runtime_error(msg) {} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/jit_exception.h b/torch/csrc/jit/runtime/jit_exception.h index 573d82b5799..974a18cc2b0 100644 --- a/torch/csrc/jit/runtime/jit_exception.h +++ b/torch/csrc/jit/runtime/jit_exception.h @@ -2,24 +2,13 @@ #include -#include #include -#include namespace torch { namespace jit { struct TORCH_API JITException : public std::runtime_error { - explicit JITException( - const std::string& msg, - c10::optional python_class_name = c10::nullopt); - - c10::optional getPythonClassName() const { - return python_class_name_; - } - - private: - c10::optional python_class_name_; + explicit JITException(const std::string& msg); }; } // namespace jit diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index a51ec03d0d3..6587c994994 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -406,8 +406,7 @@ static const OperatorGeneratorArgs opGenArgs[] = { numToTensorScalar, aliasAnalysisFromSchema()), OperatorGeneratorArgs( - TORCH_SELECTIVE_SCHEMA( - "prim::RaiseException(str msg, str? cls=None) -> ()"), + TORCH_SELECTIVE_SCHEMA("prim::RaiseException(str msg) -> ()"), raiseException, aliasAnalysisFromSchema()), OperatorGeneratorArgs(