mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Use == operator to test type equivalance in pytorch_jni_common.cpp (#71508)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71508 "==" is the more universal way to test type equalities, and also ::get() doesn't incur any refcount overhead now, so we can swtich to == instead of relying on type kinds. ghstack-source-id: 147353057 Test Plan: CI buck test xplat/caffe2/android:pytorch_jni_common_test Differential Revision: D33672433 fbshipit-source-id: 5973fd40de48b8017b5c3ebaa55bcf5b4b391aa3
This commit is contained in:
parent
5c679f2bea
commit
db84a4b566
|
|
@ -0,0 +1,18 @@
|
||||||
|
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <ATen/core/type_factory.h>
|
||||||
|
#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h"
|
||||||
|
|
||||||
|
using namespace ::testing;
|
||||||
|
|
||||||
|
TEST(pytorch_jni_common_test, newJIValueFromAtIValue) {
|
||||||
|
auto dict = c10::impl::GenericDict(
|
||||||
|
c10::dynT<c10::IntType>(), c10::dynT<c10::StringType>());
|
||||||
|
auto dictCallback = [](auto&&) {
|
||||||
|
return facebook::jni::local_ref<pytorch_jni::JIValue>{};
|
||||||
|
};
|
||||||
|
EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue(
|
||||||
|
dict, dictCallback, dictCallback));
|
||||||
|
}
|
||||||
|
|
@ -287,8 +287,51 @@ class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
|
||||||
at::Tensor tensor_;
|
at::Tensor tensor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromStringDict(
|
||||||
|
c10::Dict<c10::IValue, c10::IValue> dict) {
|
||||||
|
static auto jMethodDictStringKey =
|
||||||
|
JIValue::javaClassStatic()
|
||||||
|
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JMap<
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
||||||
|
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||||
|
"dictStringKeyFrom");
|
||||||
|
|
||||||
|
auto jmap = JHashMap<
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
||||||
|
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||||
|
for (auto& pair : dict) {
|
||||||
|
jmap->put(
|
||||||
|
facebook::jni::make_jstring(pair.key().toString()->string()),
|
||||||
|
JIValue::newJIValueFromAtIValue(pair.value()));
|
||||||
|
}
|
||||||
|
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
|
||||||
|
}
|
||||||
|
|
||||||
|
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromIntDict(
|
||||||
|
c10::Dict<c10::IValue, c10::IValue> dict) {
|
||||||
|
static auto jMethodDictLongKey =
|
||||||
|
JIValue::javaClassStatic()
|
||||||
|
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JMap<
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
||||||
|
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
||||||
|
"dictLongKeyFrom");
|
||||||
|
auto jmap = JHashMap<
|
||||||
|
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
||||||
|
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
||||||
|
for (auto& pair : dict) {
|
||||||
|
jmap->put(
|
||||||
|
facebook::jni::JLong::valueOf(pair.key().toInt()),
|
||||||
|
JIValue::newJIValueFromAtIValue(pair.value()));
|
||||||
|
}
|
||||||
|
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
|
||||||
|
}
|
||||||
|
|
||||||
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
||||||
const at::IValue& ivalue) {
|
const at::IValue& ivalue,
|
||||||
|
DictCallback stringDictCallback,
|
||||||
|
DictCallback intDictCallback) {
|
||||||
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
|
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
|
||||||
if (ivalue.isNone()) {
|
if (ivalue.isNone()) {
|
||||||
static auto jMethodOptionalNull =
|
static auto jMethodOptionalNull =
|
||||||
|
|
@ -427,49 +470,16 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
||||||
"Unknown IValue-Dict key type");
|
"Unknown IValue-Dict key type");
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto keyTypeKind = keyType->kind();
|
if (*keyType == *c10::StringType::get()) {
|
||||||
if (c10::TypeKind::StringType == keyTypeKind) {
|
return stringDictCallback(std::move(dict));
|
||||||
static auto jMethodDictStringKey =
|
} else if (*keyType == *c10::IntType::get()) {
|
||||||
JIValue::javaClassStatic()
|
return intDictCallback(std::move(dict));
|
||||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
|
||||||
facebook::jni::alias_ref<facebook::jni::JMap<
|
|
||||||
facebook::jni::alias_ref<
|
|
||||||
facebook::jni::JString::javaobject>,
|
|
||||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
|
||||||
"dictStringKeyFrom");
|
|
||||||
|
|
||||||
auto jmap = JHashMap<
|
|
||||||
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
|
|
||||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
|
||||||
for (auto& pair : dict) {
|
|
||||||
jmap->put(
|
|
||||||
facebook::jni::make_jstring(pair.key().toString()->string()),
|
|
||||||
JIValue::newJIValueFromAtIValue(pair.value()));
|
|
||||||
}
|
|
||||||
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
|
|
||||||
} else if (c10::TypeKind::IntType == keyTypeKind) {
|
|
||||||
static auto jMethodDictLongKey =
|
|
||||||
JIValue::javaClassStatic()
|
|
||||||
->getStaticMethod<facebook::jni::local_ref<JIValue>(
|
|
||||||
facebook::jni::alias_ref<facebook::jni::JMap<
|
|
||||||
facebook::jni::alias_ref<
|
|
||||||
facebook::jni::JLong::javaobject>,
|
|
||||||
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
|
|
||||||
"dictLongKeyFrom");
|
|
||||||
auto jmap = JHashMap<
|
|
||||||
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
|
|
||||||
facebook::jni::alias_ref<JIValue::javaobject>>::create();
|
|
||||||
for (auto& pair : dict) {
|
|
||||||
jmap->put(
|
|
||||||
facebook::jni::JLong::valueOf(pair.key().toInt()),
|
|
||||||
JIValue::newJIValueFromAtIValue(pair.value()));
|
|
||||||
}
|
|
||||||
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
facebook::jni::throwNewJavaException(
|
facebook::jni::throwNewJavaException(
|
||||||
facebook::jni::gJavaLangIllegalArgumentException,
|
facebook::jni::gJavaLangIllegalArgumentException,
|
||||||
"Unsupported IValue-Dict key type");
|
"Unsupported IValue-Dict key type: %s",
|
||||||
|
keyType->str().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
facebook::jni::throwNewJavaException(
|
facebook::jni::throwNewJavaException(
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <c10/util/FunctionRef.h>
|
||||||
#include <fbjni/fbjni.h>
|
#include <fbjni/fbjni.h>
|
||||||
#include <torch/csrc/api/include/torch/types.h>
|
#include <torch/csrc/api/include/torch/types.h>
|
||||||
#include "caffe2/serialize/read_adapter_interface.h"
|
#include "caffe2/serialize/read_adapter_interface.h"
|
||||||
|
|
@ -93,6 +94,9 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
|
||||||
};
|
};
|
||||||
|
|
||||||
class JIValue : public facebook::jni::JavaClass<JIValue> {
|
class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||||
|
using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>(
|
||||||
|
c10::Dict<c10::IValue, c10::IValue>)>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
|
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
|
||||||
|
|
||||||
|
|
@ -115,10 +119,18 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
|
||||||
constexpr static int kTypeCodeDictLongKey = 14;
|
constexpr static int kTypeCodeDictLongKey = 14;
|
||||||
|
|
||||||
static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
|
static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
|
||||||
const at::IValue& ivalue);
|
const at::IValue& ivalue,
|
||||||
|
DictCallback stringDictCallback = newJIValueFromStringDict,
|
||||||
|
DictCallback intDictCallback = newJIValueFromIntDict);
|
||||||
|
|
||||||
static at::IValue JIValueToAtIValue(
|
static at::IValue JIValueToAtIValue(
|
||||||
facebook::jni::alias_ref<JIValue> jivalue);
|
facebook::jni::alias_ref<JIValue> jivalue);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static facebook::jni::local_ref<JIValue> newJIValueFromStringDict(
|
||||||
|
c10::Dict<c10::IValue, c10::IValue>);
|
||||||
|
static facebook::jni::local_ref<JIValue> newJIValueFromIntDict(
|
||||||
|
c10::Dict<c10::IValue, c10::IValue>);
|
||||||
};
|
};
|
||||||
|
|
||||||
void common_registerNatives();
|
void common_registerNatives();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user