mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add support for ir scalar literal parsing for inf/-inf/True/False (#163924)
Currently the ir parser doesn't support parse ir like
```
graph():
%12 : float = prim::Constant[value=-inf]()
%13 : float = prim::Constant[value=inf]()
%14 : bool = prim::Constant[value=True]()
%15 : bool = prim::Constant[value=False]()
return (%12)
```
So the python script below will throw error.
```
#!/bin/env python
import torch
def test():
return [True, False]
f = torch.jit.script(test)
torch._C._jit_pass_constant_propagation(f.graph)
ts_str = f.graph.__repr__()
print(ts_str)
ts = torch.parse_ir(ts_str)
func = torch._C._create_function_from_graph("forward", ts)
ret = func()
assert ret == [True, False]
def test():
return [float("inf"), float("-inf")]
f = torch.jit.script(test)
torch._C._jit_pass_constant_propagation(f.graph)
ts_str = f.graph.__repr__()
print(ts_str)
ts = torch.parse_ir(ts_str)
func = torch._C._create_function_from_graph("forward", ts)
ret = func()
assert ret == [float("inf"), float("-inf")]
```
I add "inf" and bool cases for IRParser::parseScalarLiteral in irparser.cpp.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163924
Approved by: https://github.com/ezyang
This commit is contained in:
parent
c58d0ad85d
commit
fa4cb91846
|
|
@ -12,7 +12,11 @@ import torch.testing._internal.jit_utils
|
||||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||||
from torch import jit
|
from torch import jit
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
from torch.testing._internal.common_utils import freeze_rng_state, raise_on_run_directly
|
from torch.testing._internal.common_utils import (
|
||||||
|
freeze_rng_state,
|
||||||
|
raise_on_run_directly,
|
||||||
|
skipIfTorchDynamo,
|
||||||
|
)
|
||||||
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
|
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -433,6 +437,54 @@ class TestMisc(JitTestCase):
|
||||||
self.assertTrue(ret.numel() == 1)
|
self.assertTrue(ret.numel() == 1)
|
||||||
self.assertTrue(len(ret.size()) == 1)
|
self.assertTrue(len(ret.size()) == 1)
|
||||||
|
|
||||||
|
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
|
||||||
|
def test_parse_ir_single_inf(self):
|
||||||
|
ir = """
|
||||||
|
graph():
|
||||||
|
%12 : float = prim::Constant[value=inf]()
|
||||||
|
return (%12)
|
||||||
|
"""
|
||||||
|
graph = torch._C.parse_ir(ir, True)
|
||||||
|
func = torch._C._create_function_from_graph("forward", graph)
|
||||||
|
ret = func()
|
||||||
|
self.assertTrue(ret == float("inf"))
|
||||||
|
|
||||||
|
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
|
||||||
|
def test_parse_ir_single_minus_inf(self):
|
||||||
|
ir = """
|
||||||
|
graph():
|
||||||
|
%12 : float = prim::Constant[value=-inf]()
|
||||||
|
return (%12)
|
||||||
|
"""
|
||||||
|
graph = torch._C.parse_ir(ir, True)
|
||||||
|
func = torch._C._create_function_from_graph("forward", graph)
|
||||||
|
ret = func()
|
||||||
|
self.assertTrue(ret == float("-inf"))
|
||||||
|
|
||||||
|
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
|
||||||
|
def test_parse_ir_bool_true(self):
|
||||||
|
ir = """
|
||||||
|
graph():
|
||||||
|
%12 : bool = prim::Constant[value=True]()
|
||||||
|
return (%12)
|
||||||
|
"""
|
||||||
|
graph = torch._C.parse_ir(ir, True)
|
||||||
|
func = torch._C._create_function_from_graph("forward", graph)
|
||||||
|
ret = func()
|
||||||
|
self.assertTrue(ret == True) # noqa: E712
|
||||||
|
|
||||||
|
@skipIfTorchDynamo("The test case only test the parser. No need to wrap dynamo.")
|
||||||
|
def test_parse_ir_bool_false(self):
|
||||||
|
ir = """
|
||||||
|
graph():
|
||||||
|
%12 : bool = prim::Constant[value=False]()
|
||||||
|
return (%12)
|
||||||
|
"""
|
||||||
|
graph = torch._C.parse_ir(ir, True)
|
||||||
|
func = torch._C._create_function_from_graph("forward", graph)
|
||||||
|
ret = func()
|
||||||
|
self.assertTrue(ret == False) # noqa: E712
|
||||||
|
|
||||||
def test_script_many_decorators(self):
|
def test_script_many_decorators(self):
|
||||||
def no_op_decorator(f):
|
def no_op_decorator(f):
|
||||||
return f
|
return f
|
||||||
|
|
|
||||||
|
|
@ -182,9 +182,25 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
|
||||||
r.s = parseStringLiteral(token.range, token.text());
|
r.s = parseStringLiteral(token.range, token.text());
|
||||||
L.next();
|
L.next();
|
||||||
return r;
|
return r;
|
||||||
|
case TK_TRUE:
|
||||||
|
r.k = AttributeKind::i;
|
||||||
|
r.i = 1;
|
||||||
|
L.next();
|
||||||
|
return r;
|
||||||
|
case TK_FALSE:
|
||||||
|
r.k = AttributeKind::i;
|
||||||
|
r.i = 0;
|
||||||
|
L.next();
|
||||||
|
return r;
|
||||||
case '-':
|
case '-':
|
||||||
str = "-";
|
str = "-";
|
||||||
L.next();
|
L.next();
|
||||||
|
if (L.cur().kind == TK_IDENT && L.cur().text() == "inf") {
|
||||||
|
r.k = AttributeKind::f;
|
||||||
|
r.f = -std::numeric_limits<double>::infinity();
|
||||||
|
L.next();
|
||||||
|
return r;
|
||||||
|
}
|
||||||
if (L.cur().kind != TK_NUMBER) {
|
if (L.cur().kind != TK_NUMBER) {
|
||||||
throw(
|
throw(
|
||||||
ErrorReport(token.range)
|
ErrorReport(token.range)
|
||||||
|
|
@ -238,6 +254,13 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
|
||||||
L.next();
|
L.next();
|
||||||
return r;
|
return r;
|
||||||
case TK_IDENT:
|
case TK_IDENT:
|
||||||
|
if (L.cur().text() == "inf") {
|
||||||
|
r.k = AttributeKind::f;
|
||||||
|
r.f = std::numeric_limits<double>::infinity();
|
||||||
|
L.next();
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
// Type literal
|
// Type literal
|
||||||
r.k = AttributeKind::ty;
|
r.k = AttributeKind::ty;
|
||||||
type_alias = type_parser.parseType();
|
type_alias = type_parser.parseType();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user