Summary:
TODO: integrate into torch.onnx.export -- separate PR
*Problem:* We have a facility for tracing the PyTorch operations executed by Python code, but there are several failure modes where the trace is not representative of the underlying computation:
* The tracer encountered dynamic control flow
* Some computation escaped the tracer, and appeared as a Constant tensor node in the graph
* Some stateful function was traced, e.g. someone did an optimization in Python by memoizing function outputs
*Objective*: In an ideal world, this whole process would be automated and the user could trust the system to magically capture the intended semantics of the program. Realistically, we will likely have to settle for a human-in-the-loop error reporting system that lets the user identify problems and modify the source code so it can be traced.
*Stage 1* (this PR): Output-level checking & graph diff. torch.jit.trace gains a kwarg 'check_inputs', which is a list of tuples of input arguments. We iterate through the list and trace the function again for each set of check inputs. We also interpret the original trace with these inputs, compare output values and graphs, and print a diff of the graphs if they differ.
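Conceptually, the check is a replay-and-compare loop over the extra inputs. A minimal sketch (illustrative only; this is not the actual `check_trace` implementation, and single-tensor outputs are assumed for brevity):
```
import numpy.testing as npt

def sketch_check_trace(python_fn, traced_fn, check_inputs, rtol=1e-7, atol=0):
    # For each extra input tuple: run the original Python function as ground
    # truth, replay the recorded trace, and compare the outputs numerically.
    # The real check also re-traces the function and diffs the two graphs.
    for inputs in check_inputs:
        ref = python_fn(*inputs)
        out = traced_fn(*inputs)
        npt.assert_allclose(out.detach().numpy(), ref.detach().numpy(),
                            rtol=rtol, atol=atol)
```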
Examples:
```
@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
def foo(x):
    y = torch.arange(0, x.shape[0]).float()
    return x + y.unsqueeze(1)
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
Graph diff:
graph(%0 : Dynamic) {
- %1 : Dynamic = prim::Constant[value= 0 1 2 [ CPULongType{3} ]]()
? ^
+ %1 : Dynamic = prim::Constant[value= 0 1 2 3 [ CPULongType{4} ]]()
? +++ ^
%2 : int = prim::Constant[value=0]()
%3 : Dynamic = aten::_cast_Float(%1, %2)
%4 : int = prim::Constant[value=1]()
%5 : Dynamic = aten::unsqueeze(%3, %4)
%6 : int = prim::Constant[value=1]()
%7 : Dynamic = aten::add(%0, %5, %6)
return (%7);
}
Node diff:
- %1 : Dynamic = prim::Constant[value= 0 1 2 [ CPULongType{3} ]]()
? ^
+ %1 : Dynamic = prim::Constant[value= 0 1 2 3 [ CPULongType{4} ]]()
? +++ ^
Trace source location:
dank.py(5): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(402): wrapper
dank.py(3): <module>
Check source location:
dank.py(5): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(281): check_trace
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(408): wrapper
dank.py(3): <module>
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
Node:
%1 : Dynamic = prim::Constant[value= 0 1 2 [ CPULongType{3} ]]()
Source Location:
dank.py(5): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(402): wrapper
dank.py(3): <module>
Comparison exception:
Not equal to tolerance rtol=1e-07, atol=0
(shapes (3,), (4,) mismatch)
x: array([0, 1, 2])
y: array([0, 1, 2, 3])
```
==
```
@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
def foo(x):
    y = x.data
    return x + y
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
Node:
%1 : Dynamic = prim::Constant[value=<Tensor>]()
Source Location:
dank.py(6): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(402): wrapper
dank.py(3): <module>
Comparison exception:
Not equal to tolerance rtol=1e-07, atol=0
(mismatch 100.0%)
x: array([0.397137, 0.956105, 0.169478, 0.560292, 0.392568, 0.108441,
0.97645 , 0.34412 , 0.951246, 0.793061, 0.557595, 0.770245],
dtype=float32)
y: array([0.243178, 0.315964, 0.972041, 0.0215 , 0.927751, 0.457512,
0.951092, 0.97883 , 0.048688, 0.118066, 0.779345, 0.271272],
dtype=float32)
```
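A typical remedy for this case is to drop the `.data` access so the second operand stays attached to the traced input instead of being baked in as a Constant (a sketch, not part of this PR):
```
@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
def foo(x):
    # using x directly (rather than x.data) keeps the operand flowing through
    # the trace, so the graph records an add of the input with itself
    return x + x
```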
==
```
import torch

@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 4),)])
def foo(x):
    for _ in range(x.size(0)):
        x = torch.neg(x)
    return x
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
ERROR: Graphs differed across invocations!
Graph diff:
graph(%0 : Dynamic) {
%1 : Dynamic = aten::neg(%0)
%2 : Dynamic = aten::neg(%1)
%3 : Dynamic = aten::neg(%2)
+ %4 : Dynamic = aten::neg(%3)
- return (%3);
? ^
+ return (%4);
? ^
}
```
==
```
import torch

def foo(x):
    if not hasattr(foo, 'cache'):
        foo.cache = torch.neg(x)
    return x + foo.cache

traced = torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])(foo)
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
ERROR: Graphs differed across invocations!
Graph diff:
graph(%0 : Dynamic) {
- %1 : Dynamic = aten::neg(%0)
+ %1 : Dynamic = prim::Constant[value=<Tensor>]()
%2 : int = prim::Constant[value=1]()
%3 : Dynamic = aten::add(%0, %1, %2)
return (%3);
}
Node diff:
- %1 : Dynamic = aten::neg(%0)
+ %1 : Dynamic = prim::Constant[value=<Tensor>]()
Trace source location:
test.py(5): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(402): wrapper
test.py(8): <module>
Check source location:
test.py(6): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(281): check_trace
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(408): wrapper
test.py(8): <module>
```
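One way the user could make this last function traceable is to drop the Python-level cache so the negation is recorded on every call (a sketch, not part of this PR):
```
import torch

def foo(x):
    # recomputing instead of memoizing keeps aten::neg in the graph on every
    # invocation, so repeated traces agree
    return x + torch.neg(x)

traced = torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])(foo)
```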
The following two examples show instances where program semantics are lost in the Python -> trace transformation, and repeated invocation does not give us useful debug information. Further design is underway for catching these scenarios.
```
import torch

@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
def foo(x):
    for i in range(3):
        x[i, :] = torch.zeros(4)
    return x
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
Exception:
Not equal to tolerance rtol=1e-07, atol=0
(mismatch 100.0%)
x: array([0.830221, 0.915481, 0.940281, 0.555241], dtype=float32)
y: array([0., 0., 0., 0.], dtype=float32)
```
==
```
import torch

@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(5, 6),)])
def foo(x):
    x.view(-1).add_(-x.view(-1))
    return x
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
Exception:
Not equal to tolerance rtol=1e-07, atol=0
(mismatch 100.0%)
x: array([0.734441, 0.445327, 0.640592, 0.30076 , 0.891674, 0.124771],
dtype=float32)
y: array([0., 0., 0., 0., 0., 0.], dtype=float32)
```
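For completeness, an out-of-place rewrite of the last example keeps the zeroing visible to the tracer (a sketch; not a general recipe for in-place code):
```
import torch

@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(5, 6),)])
def foo(x):
    # x - x allocates a fresh result, so the subtraction is recorded in the
    # graph instead of mutating the input behind the tracer's back
    return x - x
```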
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10841
Differential Revision: D9499945
Pulled By: jamesr66a
fbshipit-source-id: 1f842a32d0b0645259cc43b29700b86d99c59a45
```
// in memory description of all ATen Ops similar to Caffe2 schema
// once C10 exists this can be removed, or stubbed out, but we need
// it now to implement correct semantic checking for script
#pragma once

#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/function_schema.h"
#include "torch/csrc/jit/stack.h"

#include "ATen/ATen.h"

#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch { namespace jit {

FunctionSchema parseSchema(const std::string& schema);

using OperationCreator = std::function<Operation(Node*)>;

struct TORCH_API Operator {
  Operator(FunctionSchema schema, OperationCreator op_creator)
      : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
        op_creator_(std::move(op_creator)) {}

  Operator(const std::string& schema, OperationCreator op_creator)
      : schema_string_(schema), op_creator_(std::move(op_creator)) {}

  // Helper constructor to register `op` to run
  // run for _every_ IR Node where n.kind() == name, regardless of arguments.
  // This is accomplished by marking the schema varargs and having no required
  // arguments. This is used for things like prim::While or prim::If that can
  // take a number of different valid input types and lengths.
  Operator(Symbol name, OperationCreator op_creator)
      : Operator(FunctionSchema(name, {}, {}, true), std::move(op_creator)) {}

  Operator(FunctionSchema schema, Operation op)
      : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
        op_(std::make_shared<Operation>(std::move(op))) {}

  Operator(const std::string& schema, Operation op)
      : schema_string_(schema),
        op_(std::make_shared<Operation>(std::move(op))) {}

  bool matches(const Node* node) const;

  Operation getOperation(Node* node = nullptr) const {
    if (op_) {
      return *op_;
    }
    AT_ASSERT(node != nullptr);
    return op_creator_(node);
  }

  const FunctionSchema& schema() const {
    // we lazily parse schema initialized from strings so that
    // we do less work during static operator registration
    if (!schema_) {
      schema_ = std::make_shared<FunctionSchema>(parseSchema(schema_string_.value()));
      schema_string_ = at::nullopt;
    }
    return *schema_;
  }

private:
  mutable at::optional<std::string> schema_string_;
  // cannot use at::optional because windows has issues that require an assignment operator to be generated
  // cannot use std::unique_ptr because initializer lists of Operators end up copying the Operator
  mutable std::shared_ptr<FunctionSchema> schema_;

  // Essentially a variant<Operation, OperationCreator>.
  // NB: std::function has a default state (where it == nullptr).
  std::shared_ptr<Operation> op_;
  OperationCreator op_creator_;
};

TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name);
std::shared_ptr<Operator> findOperatorFor(const Node* node);
const Operator& getOperatorFor(const Node* node);

inline Operation getOperation(Node* node) {
  // note: getOperatorFor ensures that getOperatorFor(node).matches(node) == true
  // so the call to selectVariant is always valid.
  return getOperatorFor(node).getOperation(node);
}

void registerOperator(Operator&& op);

// XXX: this function is meant to be used with string literals only!
Operator& sig(const char *signature_literal);

struct OperatorSet {
  OperatorSet(std::initializer_list<const char *> sig_literals);
  // XXX: Returns a nullptr if no Operator in the set matches n
  Operator* find(const Node *n) const;
private:
  std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
};

}}
```
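For orientation, a minimal, hypothetical sketch of how a dispatcher might use the lookup helpers declared above; the `runNode` name is illustrative, and it assumes `Operation` (declared in stack.h) is callable on a `Stack`:
```
#include "torch/csrc/jit/operator.h"

namespace torch { namespace jit {

// Hypothetical helper: resolve and run the Operation registered for `node`.
// getOperation(node) finds the matching Operator (via getOperatorFor) and
// materializes its Operation, lazily invoking the OperationCreator if the
// Operator was registered with one.
inline void runNode(Node* node, Stack& stack) {
  Operation op = getOperation(node);
  op(stack);  // assumption: Operation is invocable on a Stack (see stack.h)
}

}}
```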