mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add support for ir scalar literal parsing for inf/-inf/True/False (#163924)
Currently the ir parser doesn't support parse ir like
```
graph():
%12 : float = prim::Constant[value=-inf]()
%13 : float = prim::Constant[value=inf]()
%14 : bool = prim::Constant[value=True]()
%15 : bool = prim::Constant[value=False]()
return (%12)
```
So the python script below will throw error.
```
#!/bin/env python
import torch
def test():
return [True, False]
f = torch.jit.script(test)
torch._C._jit_pass_constant_propagation(f.graph)
ts_str = f.graph.__repr__()
print(ts_str)
ts = torch.parse_ir(ts_str)
func = torch._C._create_function_from_graph("forward", ts)
ret = func()
assert ret == [True, False]
def test():
return [float("inf"), float("-inf")]
f = torch.jit.script(test)
torch._C._jit_pass_constant_propagation(f.graph)
ts_str = f.graph.__repr__()
print(ts_str)
ts = torch.parse_ir(ts_str)
func = torch._C._create_function_from_graph("forward", ts)
ret = func()
assert ret == [float("inf"), float("-inf")]
```
I add "inf" and bool cases for IRParser::parseScalarLiteral in irparser.cpp.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163924
Approved by: https://github.com/ezyang
This commit is contained in:
parent
c58d0ad85d
commit
fa4cb91846
|
|
@ -12,7 +12,11 @@ import torch.testing._internal.jit_utils
|
|||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
from torch import jit
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import freeze_rng_state, raise_on_run_directly
|
||||
from torch.testing._internal.common_utils import (
|
||||
freeze_rng_state,
|
||||
raise_on_run_directly,
|
||||
skipIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
|
||||
|
||||
|
||||
|
|
@ -433,6 +437,54 @@ class TestMisc(JitTestCase):
|
|||
self.assertTrue(ret.numel() == 1)
|
||||
self.assertTrue(len(ret.size()) == 1)
|
||||
|
||||
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
|
||||
def test_parse_ir_single_inf(self):
|
||||
ir = """
|
||||
graph():
|
||||
%12 : float = prim::Constant[value=inf]()
|
||||
return (%12)
|
||||
"""
|
||||
graph = torch._C.parse_ir(ir, True)
|
||||
func = torch._C._create_function_from_graph("forward", graph)
|
||||
ret = func()
|
||||
self.assertTrue(ret == float("inf"))
|
||||
|
||||
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
|
||||
def test_parse_ir_single_minus_inf(self):
|
||||
ir = """
|
||||
graph():
|
||||
%12 : float = prim::Constant[value=-inf]()
|
||||
return (%12)
|
||||
"""
|
||||
graph = torch._C.parse_ir(ir, True)
|
||||
func = torch._C._create_function_from_graph("forward", graph)
|
||||
ret = func()
|
||||
self.assertTrue(ret == float("-inf"))
|
||||
|
||||
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
|
||||
def test_parse_ir_bool_true(self):
|
||||
ir = """
|
||||
graph():
|
||||
%12 : bool = prim::Constant[value=True]()
|
||||
return (%12)
|
||||
"""
|
||||
graph = torch._C.parse_ir(ir, True)
|
||||
func = torch._C._create_function_from_graph("forward", graph)
|
||||
ret = func()
|
||||
self.assertTrue(ret == True) # noqa: E712
|
||||
|
||||
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
|
||||
def test_parse_ir_bool_false(self):
|
||||
ir = """
|
||||
graph():
|
||||
%12 : bool = prim::Constant[value=False]()
|
||||
return (%12)
|
||||
"""
|
||||
graph = torch._C.parse_ir(ir, True)
|
||||
func = torch._C._create_function_from_graph("forward", graph)
|
||||
ret = func()
|
||||
self.assertTrue(ret == False) # noqa: E712
|
||||
|
||||
def test_script_many_decorators(self):
|
||||
def no_op_decorator(f):
|
||||
return f
|
||||
|
|
|
|||
|
|
@ -182,9 +182,25 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
|
|||
r.s = parseStringLiteral(token.range, token.text());
|
||||
L.next();
|
||||
return r;
|
||||
case TK_TRUE:
|
||||
r.k = AttributeKind::i;
|
||||
r.i = 1;
|
||||
L.next();
|
||||
return r;
|
||||
case TK_FALSE:
|
||||
r.k = AttributeKind::i;
|
||||
r.i = 0;
|
||||
L.next();
|
||||
return r;
|
||||
case '-':
|
||||
str = "-";
|
||||
L.next();
|
||||
if (L.cur().kind == TK_IDENT && L.cur().text() == "inf") {
|
||||
r.k = AttributeKind::f;
|
||||
r.f = -std::numeric_limits<double>::infinity();
|
||||
L.next();
|
||||
return r;
|
||||
}
|
||||
if (L.cur().kind != TK_NUMBER) {
|
||||
throw(
|
||||
ErrorReport(token.range)
|
||||
|
|
@ -238,6 +254,13 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
|
|||
L.next();
|
||||
return r;
|
||||
case TK_IDENT:
|
||||
if (L.cur().text() == "inf") {
|
||||
r.k = AttributeKind::f;
|
||||
r.f = std::numeric_limits<double>::infinity();
|
||||
L.next();
|
||||
return r;
|
||||
}
|
||||
|
||||
// Type literal
|
||||
r.k = AttributeKind::ty;
|
||||
type_alias = type_parser.parseType();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user