mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32760 Minor changes to the way ops are implemented to remove incidental use of Node* in the operator implementation. Current state for operators that previously took Node: ``` TBD: USES NODE: prim::DifferentiableGraph(...) -> (...) USES NODE: prim::profile(...) -> (...) USES NODE: prim::FusionGroup(...) -> (...) USES NODE: prim::PythonOp(...) -> (...) USES NODE: prim::ImplicitTensorToNum(Tensor a) -> Scalar # next PR Should be made interpreter primitives: USES NODE: prim::TupleUnpack(...) -> (...) USES NODE: prim::TupleSlice(...) -> (...) USES NODE: prim::TupleConstruct(...) -> (...) USES NODE: prim::ListUnpack(...) -> (...) USES NODE: prim::ListConstruct(...) -> (...) USES NODE: prim::DictConstruct(...) -> (...) USES NODE: prim::Constant() -> (...) USES NODE: prim::isinstance(...) -> (...) USES NODE: prim::CreateObject(...) -> (...) USES NODE: prim::fork(...) -> (...) USES NODE: aten::warn(str message, *, int stacklevel=2) -> () # need stack level information, so ideally in interpreter so it can look at the stack Should be made into vararg operators, i.e. the operators last argument should be an IValue that contains the number of arguments. USES NODE: prim::FusedConcat(...) -> (...) USES NODE: prim::MMTreeReduce(...) -> (...) USES NODE: prim::MMBatchSide(...) -> (...) USES NODE: prim::ConstantChunk(...) -> (...) USES NODE: prim::AutogradAnyNonZero(...) -> bool USES NODE: prim::BroadcastSizes(...) -> (...) USES NODE: prim::ChunkSizes(...) -> (...) USES NODE: aten::format(str self, ...) -> str USES NODE: prim::Print(...) -> (...) fixed: USES NODE: aten::extend(Tensor[](a!) self, Tensor [] other) -> () USES NODE: aten::copy(Tensor[](a) self) -> Tensor[] USES NODE: aten::extend(int[](a!) self, int [] other) -> () USES NODE: aten::copy(int[](a) self) -> int[] USES NODE: aten::extend(float[](a!) self, float [] other) -> () USES NODE: aten::copy(float[](a) self) -> float[] USES NODE: aten::extend(bool[](a!) self, bool [] other) -> () USES NODE: aten::copy(bool[](a) self) -> bool[] USES NODE: aten::extend(t[](a!) self, t [] other) -> () USES NODE: aten::copy(t[](a) self) -> t[] USES NODE: aten::keys(Dict(str, t) self) -> str[](*) USES NODE: aten::values(Dict(str, t) self) -> t[](*) USES NODE: aten::dict((str, tVal)[] inputs) -> Dict(str, tVal) USES NODE: aten::keys(Dict(int, t) self) -> int[](*) USES NODE: aten::values(Dict(int, t) self) -> t[](*) USES NODE: aten::dict((int, tVal)[] inputs) -> Dict(int, tVal) USES NODE: aten::keys(Dict(float, t) self) -> float[](*) USES NODE: aten::values(Dict(float, t) self) -> t[](*) USES NODE: aten::dict((float, tVal)[] inputs) -> Dict(float, tVal) USES NODE: aten::keys(Dict(Tensor, t) self) -> Tensor[](*) USES NODE: aten::values(Dict(Tensor, t) self) -> t[](*) USES NODE: aten::dict((Tensor, tVal)[] inputs) -> Dict(Tensor, tVal) USES NODE: aten::test_vartype2(t a, t[] b) -> (t[]) USES NODE: aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor USES NODE: aten::_ncf_view(Tensor self, int[] input_shape, int normalized_ndim) -> Tensor USES NODE: prim::is_none(int? a) -> bool USES NODE: aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::__interpolate(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::__interpolate(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor USES NODE: aten::sorted(t[](a) self) -> (t[]) USES NODE: aten::sort(t[](a!) self, bool reverse=False) -> () USES NODE: aten::test_vartype(t[] a, t b) -> (t) USES NODE: prim::unchecked_unwrap_optional(t(a)? optional) -> t(a) USES NODE: prim::unchecked_cast(...) -> (...) USES NODE: aten::dict() -> Dict(str, Tensor) USES NODE: prim::Load(...) -> (...) USES NODE: prim::Store(...) -> (...) USES NODE: prim::Drop(...) -> (...) USES NODE: aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor USES NODE: aten::as_tensor(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor ``` Test Plan: Imported from OSS Differential Revision: D19615387 Pulled By: zdevito fbshipit-source-id: 95298c3c4249b9f812c332d13f0fb79daeecb662
89 lines
2.2 KiB
C++
89 lines
2.2 KiB
C++
#include <torch/csrc/jit/ir.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
#include <torch/jit.h>
|
|
#include "test/cpp/jit/test_base.h"
|
|
#include "torch/csrc/jit/custom_operator.h"
|
|
|
|
#include <sstream>
|
|
#include <string>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
void testSchemaMatching() {
|
|
{
|
|
RegisterOperators reg({
|
|
Operator(
|
|
"aten::test_vartype(t[] a, t b) -> (t)",
|
|
[](Stack& stack) {
|
|
c10::List<double> list;
|
|
double a;
|
|
pop(stack, list, a);
|
|
push(stack, a);
|
|
return 0;
|
|
}),
|
|
});
|
|
script::Module m("m");
|
|
m.define(R"(
|
|
def test(self):
|
|
a = (1.0, 2.0)
|
|
return torch.test_vartype(a, 2.0)
|
|
)");
|
|
auto result = m.run_method("test");
|
|
TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0);
|
|
|
|
const std::string error_example = R"JIT(
|
|
def test_2(self):
|
|
a = (1.0, 2.0)
|
|
non_float = (1, 1)
|
|
return torch.test_vartype(a, non_float)
|
|
)JIT";
|
|
|
|
std::string err = "";
|
|
try {
|
|
m.define(error_example);
|
|
} catch (const std::exception &e) {
|
|
err = e.what();
|
|
}
|
|
TORCH_INTERNAL_ASSERT(err.find("previously matched to type") != std::string::npos);
|
|
}
|
|
{
|
|
RegisterOperators reg({
|
|
Operator(
|
|
"aten::test_vartype2(t a, t[] b) -> (t[])",
|
|
[](Stack& stack) {
|
|
double a;
|
|
c10::List<double> list;
|
|
pop(stack, a, list);
|
|
push(stack, a);
|
|
return 0;
|
|
}),
|
|
});
|
|
script::Module m("m");
|
|
m.define(R"JIT(
|
|
def test(self):
|
|
a = (1.0, 2.0)
|
|
return torch.test_vartype2(3.0, a)
|
|
)JIT");
|
|
auto result = m.run_method("test");
|
|
TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0);
|
|
|
|
static const auto error_exam2 = R"JIT(
|
|
def test_2(self):
|
|
a = (1, 2)
|
|
return torch.test_vartype2(3.0, a)
|
|
)JIT";
|
|
|
|
|
|
std::string err = "";
|
|
try {
|
|
m.define(error_exam2);
|
|
} catch (const std::exception &e) {
|
|
err = e.what();
|
|
}
|
|
TORCH_INTERNAL_ASSERT(err.find("previously matched to type") != std::string::npos);
|
|
}
|
|
}
|
|
} // namespace jit
|
|
} // namespace torch
|