mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Export read time and write time from the blobs queue. Fix queue balance stat for `blockingRead`.
87 lines
2.3 KiB
C++
87 lines
2.3 KiB
C++
/**
|
|
* Copyright (c) 2016-present, Facebook, Inc.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include <atomic>
|
|
#include <condition_variable>
|
|
#include <memory>
|
|
#include <mutex>
|
|
#include <queue>
|
|
|
|
#include "caffe2/core/blob_stats.h"
|
|
#include "caffe2/core/logging.h"
|
|
#include "caffe2/core/stats.h"
|
|
#include "caffe2/core/tensor.h"
|
|
#include "caffe2/core/workspace.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
// A thread-safe, bounded, blocking queue.
// Modelled as a circular buffer.

// Contained blobs are owned by the workspace.
// On read, the underlying data of the blobs passed in is swapped out for that
// of the queued blobs.
|
|
|
|
class BlobsQueue : public std::enable_shared_from_this<BlobsQueue> {
|
|
public:
|
|
BlobsQueue(
|
|
Workspace* ws,
|
|
const std::string& queueName,
|
|
size_t capacity,
|
|
size_t numBlobs,
|
|
bool enforceUniqueName,
|
|
const std::vector<std::string>& fieldNames = {});
|
|
|
|
~BlobsQueue() {
|
|
close();
|
|
}
|
|
|
|
bool blockingRead(
|
|
const std::vector<Blob*>& inputs,
|
|
float timeout_secs = 0.0f);
|
|
bool tryWrite(const std::vector<Blob*>& inputs);
|
|
bool blockingWrite(const std::vector<Blob*>& inputs);
|
|
void close();
|
|
size_t getNumBlobs() const {
|
|
return numBlobs_;
|
|
}
|
|
|
|
private:
|
|
bool canWrite();
|
|
void doWrite(const std::vector<Blob*>& inputs);
|
|
|
|
std::atomic<bool> closing_{false};
|
|
|
|
size_t numBlobs_;
|
|
std::mutex mutex_; // protects all variables in the class.
|
|
std::condition_variable cv_;
|
|
int64_t reader_{0};
|
|
int64_t writer_{0};
|
|
std::vector<std::vector<Blob*>> queue_;
|
|
const std::string name_;
|
|
|
|
struct QueueStats {
|
|
CAFFE_STAT_CTOR(QueueStats);
|
|
CAFFE_EXPORTED_STAT(queue_balance);
|
|
CAFFE_EXPORTED_STAT(queue_dequeued_records);
|
|
CAFFE_DETAILED_EXPORTED_STAT(queue_dequeued_bytes);
|
|
CAFFE_AVG_EXPORTED_STAT(read_time_ns);
|
|
CAFFE_AVG_EXPORTED_STAT(write_time_ns);
|
|
} stats_;
|
|
};
|
|
} // namespace caffe2
|