From c92ff47afd49c6a62adb181105eb5b8b34c64e30 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Thu, 20 Jan 2022 15:40:38 -0800 Subject: [PATCH] Use == operator to test type equivalance in pytorch_jni_common.cpp (#71508) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71508 "==" is the more universal way to test type equalities, and also ::get() doesn't incur any refcount overhead now, so we can swtich to == instead of relying on type kinds. ghstack-source-id: 147353057 Test Plan: CI buck test xplat/caffe2/android:pytorch_jni_common_test Differential Revision: D33672433 fbshipit-source-id: 5973fd40de48b8017b5c3ebaa55bcf5b4b391aa3 (cherry picked from commit db84a4b566d1f2f17cda8785e11bc11739e6f50c) --- .../cpp/pytorch_jni_common_test.cpp | 18 ++++ .../src/main/cpp/pytorch_jni_common.cpp | 90 ++++++++++--------- .../src/main/cpp/pytorch_jni_common.h | 14 ++- 3 files changed, 81 insertions(+), 41 deletions(-) create mode 100644 android/pytorch_android/src/androidTest/cpp/pytorch_jni_common_test.cpp diff --git a/android/pytorch_android/src/androidTest/cpp/pytorch_jni_common_test.cpp b/android/pytorch_android/src/androidTest/cpp/pytorch_jni_common_test.cpp new file mode 100644 index 00000000000..9caf4e1c2a5 --- /dev/null +++ b/android/pytorch_android/src/androidTest/cpp/pytorch_jni_common_test.cpp @@ -0,0 +1,18 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include + +#include +#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h" + +using namespace ::testing; + +TEST(pytorch_jni_common_test, newJIValueFromAtIValue) { + auto dict = c10::impl::GenericDict( + c10::dynT(), c10::dynT()); + auto dictCallback = [](auto&&) { + return facebook::jni::local_ref{}; + }; + EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue( + dict, dictCallback, dictCallback)); +} diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp index 4592537ae68..8094f7bdc97 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp @@ -287,8 +287,51 @@ class TensorHybrid : public facebook::jni::HybridClass { at::Tensor tensor_; }; +facebook::jni::local_ref JIValue::newJIValueFromStringDict( + c10::Dict dict) { + static auto jMethodDictStringKey = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref, + facebook::jni::alias_ref>>)>( + "dictStringKeyFrom"); + + auto jmap = JHashMap< + facebook::jni::alias_ref, + facebook::jni::alias_ref>::create(); + for (auto& pair : dict) { + jmap->put( + facebook::jni::make_jstring(pair.key().toString()->string()), + JIValue::newJIValueFromAtIValue(pair.value())); + } + return jMethodDictStringKey(JIValue::javaClassStatic(), jmap); +} + +facebook::jni::local_ref JIValue::newJIValueFromIntDict( + c10::Dict dict) { + static auto jMethodDictLongKey = + JIValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::alias_ref, + facebook::jni::alias_ref>>)>( + "dictLongKeyFrom"); + auto jmap = JHashMap< + facebook::jni::alias_ref, + facebook::jni::alias_ref>::create(); + for (auto& pair : dict) { + jmap->put( + facebook::jni::JLong::valueOf(pair.key().toInt()), + JIValue::newJIValueFromAtIValue(pair.value())); + } + return jMethodDictLongKey(JIValue::javaClassStatic(), jmap); +} + facebook::jni::local_ref JIValue::newJIValueFromAtIValue( - const at::IValue& ivalue) { + const at::IValue& ivalue, + DictCallback stringDictCallback, + DictCallback intDictCallback) { Trace _s{"jni::JIValue::newJIValueFromAtIValue"}; if (ivalue.isNone()) { static auto jMethodOptionalNull = @@ -427,49 +470,16 @@ facebook::jni::local_ref JIValue::newJIValueFromAtIValue( "Unknown IValue-Dict key type"); } - const auto keyTypeKind = keyType->kind(); - if (c10::TypeKind::StringType == keyTypeKind) { - static auto jMethodDictStringKey = - JIValue::javaClassStatic() - ->getStaticMethod( - facebook::jni::alias_ref, - facebook::jni::alias_ref>>)>( - "dictStringKeyFrom"); - - auto jmap = JHashMap< - facebook::jni::alias_ref, - facebook::jni::alias_ref>::create(); - for (auto& pair : dict) { - jmap->put( - facebook::jni::make_jstring(pair.key().toString()->string()), - JIValue::newJIValueFromAtIValue(pair.value())); - } - return jMethodDictStringKey(JIValue::javaClassStatic(), jmap); - } else if (c10::TypeKind::IntType == keyTypeKind) { - static auto jMethodDictLongKey = - JIValue::javaClassStatic() - ->getStaticMethod( - facebook::jni::alias_ref, - facebook::jni::alias_ref>>)>( - "dictLongKeyFrom"); - auto jmap = JHashMap< - facebook::jni::alias_ref, - facebook::jni::alias_ref>::create(); - for (auto& pair : dict) { - jmap->put( - facebook::jni::JLong::valueOf(pair.key().toInt()), - JIValue::newJIValueFromAtIValue(pair.value())); - } - return jMethodDictLongKey(JIValue::javaClassStatic(), jmap); + if (*keyType == *c10::StringType::get()) { + return stringDictCallback(std::move(dict)); + } else if (*keyType == *c10::IntType::get()) { + return intDictCallback(std::move(dict)); } facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, - "Unsupported IValue-Dict key type"); + "Unsupported IValue-Dict key type: %s", + keyType->str().c_str()); } facebook::jni::throwNewJavaException( diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_common.h b/android/pytorch_android/src/main/cpp/pytorch_jni_common.h index befa4cccd0b..5020a1cff96 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_common.h +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_common.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include "caffe2/serialize/read_adapter_interface.h" @@ -93,6 +94,9 @@ class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { }; class JIValue : public facebook::jni::JavaClass { + using DictCallback = c10::function_ref( + c10::Dict)>; + public: constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;"; @@ -115,10 +119,18 @@ class JIValue : public facebook::jni::JavaClass { constexpr static int kTypeCodeDictLongKey = 14; static facebook::jni::local_ref newJIValueFromAtIValue( - const at::IValue& ivalue); + const at::IValue& ivalue, + DictCallback stringDictCallback = newJIValueFromStringDict, + DictCallback intDictCallback = newJIValueFromIntDict); static at::IValue JIValueToAtIValue( facebook::jni::alias_ref jivalue); + + private: + static facebook::jni::local_ref newJIValueFromStringDict( + c10::Dict); + static facebook::jni::local_ref newJIValueFromIntDict( + c10::Dict); }; void common_registerNatives();