add support for ir scalar literal parsing for inf/-inf/True/False (#163924)

Currently the ir parser doesn't support parse ir like
```
graph():
  %12 : float = prim::Constant[value=-inf]()
  %13 : float = prim::Constant[value=inf]()
  %14 : bool = prim::Constant[value=True]()
  %15 : bool = prim::Constant[value=False]()
  return (%12)
```

So the python script below will throw error.

```
#!/bin/env python
import torch

def test():
    return [True, False]
f = torch.jit.script(test)
torch._C._jit_pass_constant_propagation(f.graph)
ts_str = f.graph.__repr__()
print(ts_str)
ts = torch.parse_ir(ts_str)
func = torch._C._create_function_from_graph("forward", ts)
ret = func()
assert ret == [True, False]

def test():
    return [float("inf"), float("-inf")]
f = torch.jit.script(test)
torch._C._jit_pass_constant_propagation(f.graph)
ts_str = f.graph.__repr__()
print(ts_str)
ts = torch.parse_ir(ts_str)
func = torch._C._create_function_from_graph("forward", ts)
ret = func()
assert ret == [float("inf"), float("-inf")]
```

I add "inf" and bool cases for IRParser::parseScalarLiteral in irparser.cpp.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163924
Approved by: https://github.com/ezyang
This commit is contained in:
Weinan Liu 2025-10-27 05:10:19 +00:00 committed by PyTorch MergeBot
parent c58d0ad85d
commit fa4cb91846
2 changed files with 76 additions and 1 deletions

View File

@ -12,7 +12,11 @@ import torch.testing._internal.jit_utils
from jit.test_module_interface import TestModuleInterface # noqa: F401
from torch import jit
from torch.testing import FileCheck
from torch.testing._internal.common_utils import freeze_rng_state, raise_on_run_directly
from torch.testing._internal.common_utils import (
freeze_rng_state,
raise_on_run_directly,
skipIfTorchDynamo,
)
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
@ -433,6 +437,54 @@ class TestMisc(JitTestCase):
self.assertTrue(ret.numel() == 1)
self.assertTrue(len(ret.size()) == 1)
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
def test_parse_ir_single_inf(self):
ir = """
graph():
%12 : float = prim::Constant[value=inf]()
return (%12)
"""
graph = torch._C.parse_ir(ir, True)
func = torch._C._create_function_from_graph("forward", graph)
ret = func()
self.assertTrue(ret == float("inf"))
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
def test_parse_ir_single_minus_inf(self):
ir = """
graph():
%12 : float = prim::Constant[value=-inf]()
return (%12)
"""
graph = torch._C.parse_ir(ir, True)
func = torch._C._create_function_from_graph("forward", graph)
ret = func()
self.assertTrue(ret == float("-inf"))
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
def test_parse_ir_bool_true(self):
ir = """
graph():
%12 : bool = prim::Constant[value=True]()
return (%12)
"""
graph = torch._C.parse_ir(ir, True)
func = torch._C._create_function_from_graph("forward", graph)
ret = func()
self.assertTrue(ret == True) # noqa: E712
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
def test_parse_ir_bool_false(self):
ir = """
graph():
%12 : bool = prim::Constant[value=False]()
return (%12)
"""
graph = torch._C.parse_ir(ir, True)
func = torch._C._create_function_from_graph("forward", graph)
ret = func()
self.assertTrue(ret == False) # noqa: E712
def test_script_many_decorators(self):
def no_op_decorator(f):
return f

View File

@ -182,9 +182,25 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
r.s = parseStringLiteral(token.range, token.text());
L.next();
return r;
case TK_TRUE:
r.k = AttributeKind::i;
r.i = 1;
L.next();
return r;
case TK_FALSE:
r.k = AttributeKind::i;
r.i = 0;
L.next();
return r;
case '-':
str = "-";
L.next();
if (L.cur().kind == TK_IDENT && L.cur().text() == "inf") {
r.k = AttributeKind::f;
r.f = -std::numeric_limits<double>::infinity();
L.next();
return r;
}
if (L.cur().kind != TK_NUMBER) {
throw(
ErrorReport(token.range)
@ -238,6 +254,13 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
L.next();
return r;
case TK_IDENT:
if (L.cur().text() == "inf") {
r.k = AttributeKind::f;
r.f = std::numeric_limits<double>::infinity();
L.next();
return r;
}
// Type literal
r.k = AttributeKind::ty;
type_alias = type_parser.parseType();