[caffe2] add an EstimateAllBlobSizes operator (#59775)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59775

This operator is similar to `GetAllBlobNames`, but also returns the estimated size required to serialize each blob. One goal of this operator is to allow checkpoint-saving logic to estimate the amount of space/bandwidth required to save a checkpoint when first starting training, without actually serializing any blobs yet.

Currently the checkpointing logic uses `GetAllBlobNames` to determine the blobs to checkpoint. It can instead be updated to use `EstimateAllBlobSizes` to also get an estimate of how much space the checkpoint will require.

ghstack-source-id: 132275153

Test Plan: Included a new unit test.

Reviewed By: mraway

Differential Revision: D29020227

fbshipit-source-id: 811e5d86c4b59183e84e6424c48c97739be09043
This commit is contained in:
parent
fe4ded01f7
commit
fadaa52f64
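A minimal usage sketch of the new operator from Python, based on the operator schema and the unit test in this diff: checkpointing code can run EstimateAllBlobSizes to gauge checkpoint size before saving anything. The example blob contents, the output names "blob_names"/"blob_sizes", and the printed summary are illustrative assumptions, not code from this commit.

import numpy as np
from caffe2.python import core, workspace

# Feed a couple of example blobs into the workspace (illustrative data).
workspace.FeedBlob("weights", np.random.rand(1000, 256).astype(np.float32))
workspace.FeedBlob("bias", np.random.rand(256).astype(np.float32))

# EstimateAllBlobSizes takes no inputs and produces two outputs:
# a string tensor of blob names and an int64 tensor of estimated sizes.
estimate_op = core.CreateOperator(
    "EstimateAllBlobSizes",
    [],
    ["blob_names", "blob_sizes"],
)
workspace.RunOperatorOnce(estimate_op)

names = workspace.FetchBlob("blob_names")
sizes = workspace.FetchBlob("blob_sizes")
for name, size in zip(names, sizes):
    print(f"{name.decode('utf-8')}: ~{int(size)} bytes")
print(f"estimated checkpoint size: ~{int(sizes.sum())} bytes")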
@@ -115,6 +115,124 @@ c10::ArrayRef<T> GetTensorDataRange(
  return c10::ArrayRef<T>(tensor.template data<T>() + start, numElements);
}

template <typename T>
bool EnableByteEncoding() {
  // if typeSize == 1, endianness does not matter. Else check for endianness.
  if (sizeof(T) > 1 && !kIsLittleEndian) {
    return false;
  }
  return FLAGS_caffe2_serialize_using_bytes_as_holder;
}

bool EnableByteEncodingFloat16() {
  if (!kIsLittleEndian) {
    return false;
  }
  // Check if special casing for float is enabled if
  // caffe2_serialize_using_bytes_as_holder is not enabled.
  return FLAGS_caffe2_serialize_using_bytes_as_holder ||
      FLAGS_caffe2_serialize_fp16_as_bytes;
}

size_t EstimatePerElementSize(
    const Tensor& tensor,
    const BlobSerializationOptions& options) {
  const TensorProto::DataType data_type = TypeMetaToDataType(tensor.dtype());
  switch (data_type) {
    case TensorProto_DataType_FLOAT:
#ifdef USE_FBGEMM
      if (options.float_format() ==
          BlobSerializationOptions_FloatFormat_FLOAT_BFLOAT16) {
        // Each element is serialized as a 2-byte bfloat16
        return sizeof(uint16_t);
      }
#endif
      return sizeof(float);
    case TensorProto_DataType_INT32:
      // protobuf will use varint encoding, so it won't be a fixed field width
      // per integer, and will use between 1 and 5 bytes.  Just return 4 bytes
      // as an estimate.  With randomized data the actual value may be higher
      // than this, since around half the numbers will have the high bit set
      // and would require 5 bytes to encode.
      return sizeof(int32_t);
    case TensorProto_DataType_INT64:
      // Same varint reasoning as for the INT32 case.
      return sizeof(int64_t);
    case TensorProto_DataType_STRING:
      // We unfortunately cannot estimate the size well for strings, without
      // knowing the individual element lengths.  Just return 50 bytes per
      // string as a guess.
      return 50;
    case TensorProto_DataType_BOOL:
      // Depending on EnableByteEncoding() this is either serialized in
      // byte_data or int32_data, but in either case it takes 1 byte per
      // element (since bool values will only take 1 byte when varint encoded
      // in int32_data).
      return 1;
    case TensorProto_DataType_UINT8:
      if (EnableByteEncoding<uint8_t>()) {
        return 1;
      } else {
        // Unfortunately when storing uint8_t values in int32_data any values
        // over 127 will require 2 bytes to store due to varint encoding.
        // With random data we would expect around 1.5 bytes per element.
        // Round up to 2.
        return 2;
      }
    case TensorProto_DataType_INT8:
      if (EnableByteEncoding<int8_t>()) {
        return 1;
      } else {
        // Unfortunately when storing int8_t values in int32_data any negative
        // values will require 2 bytes to store due to varint encoding.  With
        // random data we would expect around 1.5 bytes per element.  Round up
        // to 2.
        return 2;
      }
    case TensorProto_DataType_UINT16:
      if (EnableByteEncoding<uint16_t>()) {
        return 2;
      } else {
        // With random data, varint encoding will end up requiring closer to 3
        // bytes per element.
        return 3;
      }
    case TensorProto_DataType_INT16:
      if (EnableByteEncoding<int16_t>()) {
        return 2;
      } else {
        // With random data, varint encoding will end up requiring closer to 3
        // bytes per element.
        return 3;
      }
    case TensorProto_DataType_FLOAT16:
      if (EnableByteEncodingFloat16()) {
        return 2;
      } else {
        // The data will be stored as uint16_t values in the int32_data.
        // Due to varint encoding many values may require 3 bytes.
        return 3;
      }
    case TensorProto_DataType_DOUBLE:
      return sizeof(double);
    case TensorProto_DataType_UNDEFINED:
      return tensor.itemsize();
    case TensorProto_DataType_BYTE:
    case TensorProto_DataType_ZERO_COLLISION_HASH:
    case TensorProto_DataType_REBATCHING_BUFFER:
      // These data types should never be hit during serialization
      LOG(ERROR) << "unexpected tensor data type during serialization size "
                    "estimation: "
                 << static_cast<int>(data_type);
      return 0;
  }

  LOG(ERROR) << "unknown tensor data type during serialization size "
                "estimation: "
             << static_cast<int>(data_type);
  return 0;
}

} // namespace

/**
@@ -144,6 +262,17 @@ class StringSerializer : public BlobSerializerBase {
    blob_proto.set_content(*static_cast<const std::string*>(pointer));
    acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
  }

  size_t EstimateSerializedBlobSize(
      const void* pointer,
      TypeMeta,
      c10::string_view name,
      const BlobSerializationOptions&) override {
    auto* str = static_cast<const std::string*>(pointer);
    // Add 20 for the "std::string" type field plus other overhead for the
    // BlobProto message serialization.
    return name.size() + str->size() + 20;
  }
};

/**
@@ -205,6 +334,20 @@ std::string SerializeBlob(const Blob& blob, const string& name) {
  return SerializeBlob(blob.GetRaw(), blob.meta(), name);
}

size_t EstimateSerializedBlobSize(
    const Blob& blob,
    c10::string_view name,
    const BlobSerializationOptions& options) {
  std::unique_ptr<BlobSerializerBase> serializer{
      CreateSerializer(blob.meta().id())};
  if (!serializer) {
    LOG(ERROR) << "No known serializer for " << blob.meta().name();
    return 0;
  }
  return serializer->EstimateSerializedBlobSize(
      blob.GetRaw(), blob.meta(), name, options);
}

void TensorSerializer::Serialize(
    const void* pointer,
    TypeMeta typeMeta,
@@ -296,27 +439,34 @@ void TensorSerializer::SerializeWithOptions(
#endif
}

size_t TensorSerializer::EstimateSerializedBlobSize(
    const void* pointer,
    TypeMeta typeMeta,
    c10::string_view name,
    const BlobSerializationOptions& options) {
  CAFFE_ENFORCE(typeMeta.Match<Tensor>());
  const auto& tensor = *static_cast<const Tensor*>(pointer);

  auto chunk_size = options.chunk_size();
  if (chunk_size == kNoChunking) {
    chunk_size = tensor.numel() + 1; // to account for empty tensors
  } else if (chunk_size == kDefaultChunkSize) {
    chunk_size = FLAGS_caffe2_tensor_chunk_size;
  }

  // There is a small amount of fixed overhead per chunk to serialize the
  // fixed TensorProto message data independent from the chunk contents.
  // This normally appears to be around 50 bytes.
  // The blob name is also written out in the BlobProto for each chunk.
  constexpr size_t protobuf_overhead_per_chunk = 50;
  size_t num_chunks = (tensor.numel() + (chunk_size - 1)) / chunk_size;
  size_t overhead = num_chunks * (name.size() + protobuf_overhead_per_chunk);

  return overhead + tensor.numel() * EstimatePerElementSize(tensor, options);
}

namespace {

template <typename T>
bool EnableByteEncoding() {
  // if typeSize == 1, endianness does not matter. Else check for endianness.
  if (sizeof(T) > 1 && !kIsLittleEndian) {
    return false;
  }
  return FLAGS_caffe2_serialize_using_bytes_as_holder;
}

bool EnableByteEncodingFloat16() {
  if (!kIsLittleEndian) {
    return false;
  }
  // Check if special casing for float is enabled if
  // caffe2_serialize_using_bytes_as_holder is not enabled.
  return FLAGS_caffe2_serialize_using_bytes_as_holder ||
      FLAGS_caffe2_serialize_fp16_as_bytes;
}

template <typename T, typename S = T>
void SerializeUsingBytesOrInt32(
    bool enableByteEncoding,
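As a worked example of the estimate computed above (numbers taken from the unit test at the end of this diff, assuming fbgemm is available so "float1" is serialized as 2-byte bfloat16): the blob "float1" holds 4000 floats and is saved in 500-element chunks, so num_chunks = 8, overhead = 8 * (len("float1") + 50) = 8 * 56 = 448 bytes, and the total estimate is 448 + 4000 * 2 = 8448 bytes.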
@@ -51,6 +51,11 @@ TORCH_API void SerializeBlob(
    BlobSerializerBase::SerializationAcceptor acceptor,
    const BlobSerializationOptions& options);

TORCH_API size_t EstimateSerializedBlobSize(
    const Blob& blob,
    c10::string_view name,
    const BlobSerializationOptions& options);

/**
 * @brief Convenience function to serialize a blob to a string.
 *
@@ -137,6 +142,12 @@ class TORCH_API TensorSerializer : public BlobSerializerBase {
    Serialize(tensor, name, proto, options, chunkBegin, chunkSize);
  }

  size_t EstimateSerializedBlobSize(
      const void* pointer,
      TypeMeta typeMeta,
      c10::string_view name,
      const BlobSerializationOptions& options) override;

 private:
  // A utility function to store the device context details.
  void StoreDeviceDetail(const Tensor& input, TensorProto* proto);
@@ -3,7 +3,8 @@
#include <string>
#include <functional>

#include "c10/util/Registry.h"
#include <c10/util/Registry.h>
#include <c10/util/string_view.h>
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2_pb.h"

@@ -61,6 +62,20 @@ class BlobSerializerBase {
    // Base implementation.
    Serialize(pointer, typeMeta, name, acceptor);
  }

  virtual size_t EstimateSerializedBlobSize(
      const void* /*pointer*/,
      TypeMeta /*typeMeta*/,
      c10::string_view /*name*/,
      const BlobSerializationOptions& /*options*/) {
    // Base implementation.
    // This returns 0 just to allow us to roll this out without needing to
    // define an implementation for all serializer types.  Returning a size
    // of 0 for less-commonly used blob types is acceptable for now.
    // Eventually it would be nice to ensure that this method is implemented
    // for all serializers and then make this method pure virtual.
    return 0;
  }
};

// The Blob serialization registry and serializer creator functions.
@@ -199,6 +199,53 @@ bool SaveOpImpl::RunOnDevice() {

} // namespace internal

namespace {
class EstimateAllBlobSizesOp final : public Operator<CPUContext> {
 public:
  explicit EstimateAllBlobSizesOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        include_shared_(GetSingleArgument<int>("include_shared", true)),
        ws_(ws) {
    auto options_data = GetSingleArgument<string>("options", "");
    if (!options_data.empty()) {
      if (!options_.ParseFromString(options_data)) {
        CAFFE_ENFORCE(false, "unable to parse serialization options");
      }
    }
  }

  bool RunOnDevice() override {
    const auto& blob_names =
        include_shared_ ? ws_->Blobs() : ws_->LocalBlobs();
    auto* names_out = Output(
        0, {static_cast<int64_t>(blob_names.size())}, at::dtype<std::string>());
    auto* sizes_out = Output(
        1, {static_cast<int64_t>(blob_names.size())}, at::dtype<int64_t>());
    BlobSerializationOptions default_options;
    for (size_t idx = 0; idx < blob_names.size(); ++idx) {
      const auto& name = blob_names[idx];
      auto* blob = ws_->GetBlob(name);
      if (!blob) {
        LOG(ERROR) << "unable to find blob " << name
                   << " when estimating serialization size";
        continue;
      }

      names_out->template mutable_data<std::string>()[idx] = name;
      const auto& blob_serialization_options =
          internal::GetBlobOptions(name, options_, default_options);
      sizes_out->template mutable_data<int64_t>()[idx] =
          EstimateSerializedBlobSize(*blob, name, blob_serialization_options);
    }
    return true;
  }

 private:
  bool include_shared_{true};
  Workspace* ws_{nullptr};
  SerializationOptions options_;
};
} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CPU_OPERATOR(DBExists, DBExistsOp<CPUContext>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -210,6 +257,8 @@ REGISTER_CPU_OPERATOR(Checkpoint, CheckpointOp<CPUContext>);
// CPU Operator old name: do NOT use, we may deprecate this later.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CPU_OPERATOR(Snapshot, CheckpointOp<CPUContext>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CPU_OPERATOR(EstimateAllBlobSizes, EstimateAllBlobSizesOp);

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
OPERATOR_SCHEMA(DBExists)
@@ -455,6 +504,26 @@ counter). This is determined whether we need to do checkpointing.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
OPERATOR_SCHEMA(Snapshot);

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
OPERATOR_SCHEMA(EstimateAllBlobSizes)
    .NumInputs(0)
    .NumOutputs(2)
    .SetDoc(R"DOC(
Returns two outputs: a 1D tensor of strings containing the names
of each blob in the active workspace, and a 1D tensor of integers containing the
estimated serialized size of each blob (in bytes).
)DOC")
    .Arg(
        "include_shared",
        "(bool, default true) Whether to include blobs "
        "inherited from parent workspaces.")
    .Arg(
        "options",
        "(string, default empty) A BlobSerializationOptions message specifying "
        "options for how specific blobs should be serialized.")
    .Output(0, "blob_names", "1D tensor of strings containing blob names.")
    .Output(1, "blob_sizes", "1D tensor of int64_t containing blob sizes.");

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
NO_GRADIENT(Load);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -465,5 +534,7 @@ SHOULD_NOT_DO_GRADIENT(Save);
SHOULD_NOT_DO_GRADIENT(Checkpoint);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
SHOULD_NOT_DO_GRADIENT(Snapshot);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
SHOULD_NOT_DO_GRADIENT(EstimateAllBlobSizesOp);

} // namespace caffe2
@@ -693,6 +693,129 @@ class TestLoadSave(TestLoadSaveBase):
            workspace.FetchBlob("float1"), float_data, decimal=2
        )

    def testEstimateBlobSizes(self) -> None:
        # Create some blobs to test with
        float_data = np.random.random_sample(4000).astype(np.float32)
        workspace.FeedBlob("float1", float_data)
        workspace.FeedBlob("float2", float_data)
        workspace.FeedBlob(
            "float3", np.random.random_sample(2).astype(np.float32)
        )
        workspace.FeedBlob(
            "ui16", np.random.randint(0, 0xffff, size=1024, dtype=np.uint16)
        )

        # Estimate the serialized size of the data.
        # Request bfloat16 serialization for one of the float blobs, just to
        # exercise size estimation when using this option.
        options = caffe2_pb2.SerializationOptions(
            options=[
                BlobSerializationOptions(
                    blob_name_regex="float1",
                    float_format=BlobSerializationOptions.FLOAT_BFLOAT16,
                    chunk_size=500,
                ),
            ],
        )
        get_blobs_op = core.CreateOperator(
            "EstimateAllBlobSizes",
            [],
            ["blob_names", "blob_sizes"],
            options=options,
        )
        self.assertTrue(workspace.RunOperatorOnce(get_blobs_op))
        blob_names = workspace.FetchBlob("blob_names")
        blob_sizes = workspace.FetchBlob("blob_sizes")

        sizes_by_name: Dict[str, int] = {}
        for idx, name in enumerate(blob_names):
            sizes_by_name[name.decode("utf-8")] = blob_sizes[idx]

        # Note that the output blob list will include our output blob names.
        expected_blobs = [
            "float1", "float2", "float3", "ui16",
            "blob_names", "blob_sizes"
        ]
        self.assertEqual(set(sizes_by_name.keys()), set(expected_blobs))

        def check_expected_blob_size(
            name: str, num_elems: int, elem_size: int, num_chunks: int = 1
        ) -> None:
            # The estimation code applies a fixed 50-byte per-chunk overhead
            # to account for the extra space required for other fixed
            # TensorProto message fields.
            per_chunk_overhead = 50
            expected_size = (
                (num_chunks * (len(name) + per_chunk_overhead))
                + (num_elems * elem_size)
            )
            self.assertEqual(
                sizes_by_name[name],
                expected_size,
                f"expected size mismatch for {name}"
            )

        check_expected_blob_size("ui16", 1024, 3)
        check_expected_blob_size("float2", 4000, 4)
        check_expected_blob_size("float3", 2, 4)

        # Our serialization options request to split float1 into 500-element
        # chunks when saving it.  If fbgemm is available then the float1 blob
        # will be serialized using 2 bytes per element instead of 4 bytes.
        float1_num_chunks = 4000 // 500
        if workspace.has_fbgemm:
            check_expected_blob_size("float1", 4000, 2, float1_num_chunks)
        else:
            check_expected_blob_size("float1", 4000, 4, float1_num_chunks)

        check_expected_blob_size("blob_names", len(expected_blobs), 50)
        check_expected_blob_size("blob_sizes", len(expected_blobs), 8)

        # Now actually save the blobs so we can compare our estimates
        # to how big the serialized data actually is.
        tmp_folder = self.make_tempdir()
        tmp_file = str(tmp_folder / "save.output")
        save_op = core.CreateOperator(
            "Save",
            list(sizes_by_name.keys()),
            [],
            absolute_path=1,
            db=tmp_file,
            db_type=self._db_type,
            options=options,
        )
        self.assertTrue(workspace.RunOperatorOnce(save_op))

        blob_chunks = self._read_chunk_info(Path(tmp_file))
        saved_sizes: Dict[str, int] = {}
        for blob_name, chunks in blob_chunks.items():
            total_size = sum(chunk.value_size for chunk in chunks)
            saved_sizes[blob_name] = total_size

        # For sanity checking, ensure that our estimates aren't
        # extremely far off
        for name in expected_blobs:
            estimated_size = sizes_by_name[name]
            saved_size = saved_sizes[name]
            difference = abs(estimated_size - saved_size)
            error_pct = 100.0 * (difference / saved_size)
            print(
                f"{name}: estimated={estimated_size} actual={saved_size} "
                f"error={error_pct:.2f}%"
            )
            # Don't check the blob_names blob.  It is a string tensor, and we
            # can't estimate string tensor sizes very well without knowing the
            # individual string lengths.  (Currently it requires 102 bytes to
            # save, but we estimate 360).
            if name == "blob_names":
                continue
            # Check that we are within 100 bytes, or within 25%.
            # We are generally quite close for tensors with fixed-width fields
            # (like float), but a little farther off for tensors that use
            # varint encoding.
            if difference > 100:
                self.assertLess(error_pct, 25.0)


if __name__ == '__main__':
    unittest.main()