implementation of DataPtr context for copy-on-write tensors (#100818)
Summary:

Copy-on-write storage
=====================

This library adds support for copy-on-write storage, i.e. lazy copies, to tensors. The design maintains the PyTorch invariant that tensors alias if and only if they share a storage. Thus, tensors that are lazy copies of one another will have distinct storages that share a data allocation.

Thread-safety
-------------

The correctness of this design hinges on the pre-existing PyTorch user requirement (and general default programming assumption) that users are responsible for guaranteeing that writes do not take place concurrently with reads and other writes.

Lazily copied tensors add a complication to this programming model because users are not required to know if lazy copies exist and are not required to serialize writes across lazy copies. For example: two tensors with distinct storages that share a copy-on-write data context may be given to different threads that may do whatever they wish to them, and the runtime is required to guarantee that this is safe.

It turns out that this is not that difficult to protect because, due to the copy-on-write requirement, we just need to materialize a tensor upon writing. This could be done entirely without synchronization if we materialized each copy; however, we have a common-sense optimization to elide the copy for the last remaining reference. This requires waiting for any pending copies.

### Thread-safety detailed design

There are two operations that affect the copy-on-write details of a tensor:

1) lazy-clone (e.g. an explicit call or a hidden implementation detail added through an operator like reshape)
2) materialization (i.e. any write to the tensor)

The key insight that we exploit is that lazy-clone is logically a read operation and materialization is logically a write operation. This means that, for a given set of tensors that share a storage, if materialization is taking place, no other read operation, including lazy-clone, can be concurrent with it.

However, this insight only applies within a set of tensors that share a storage. We also have to be concerned with tensors with different storages that share a copy-on-write context. In this world, materialization can race with lazy-clone or even other materializations. _However_, in order for this to be the case, there must be _at least_ two references to the context. This means that the context _cannot_ vanish out from under you if you are performing a lazy-clone, and hence, it only requires an atomic refcount bump.

The most complicated case is that all lazy-copies are concurrently materializing. In this case, because a write is occurring, there are no in-flight lazy-copies taking place. We must simply ensure that all lazy-copies are able to materialize (read the data) concurrently. If we didn't have the aforementioned optimization where the last copy steals the data, we could get away with no locking whatsoever: each makes a copy and decrements the refcount. However, because of the optimization, we require the loser of the materializing race to wait for the pending copies to finish, and then steal the data without copying it.

We implement this by taking a shared lock when copying the data and taking an exclusive lock when stealing the data. The exclusive lock acquisition ensures that all pending shared locks are finished before we steal the data.

Test Plan: 100% code coverage.

---

Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/100818).
* #100821
* #100820
* #100819
* __->__ #100818

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100818
Approved by: https://github.com/ezyang
This commit is contained in:
parent 87084643e5
commit 979f55d3bc
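The locking discipline described at the end of the summary is the heart of the design. Before the diff, here is a standalone, hedged sketch of that pattern (illustrative names only, not code from this PR): materializers that are not the last reference copy the data under a shared lock, while the last reference treats a single exclusive acquisition as a barrier before stealing the data.

```cpp
#include <mutex>
#include <shared_mutex>

std::shared_mutex mutex;

// Materializers that are not the last reference: copy the data while
// holding a shared lock. Any number of these can run concurrently.
void copy_path(/* ... */) {
  std::shared_lock lock(mutex);
  // ... read (copy) the shared data ...
}

// The last reference: the exclusive lock cannot be granted while any
// shared lock is outstanding, so by the time we hold it, no copy is
// still reading the data and we may steal it without copying.
void steal_path(/* ... */) {
  { std::unique_lock lock(mutex); } // barrier only; released immediately
  // ... take ownership of the data ...
}
```

Acquiring and immediately releasing the exclusive lock is enough because `std::shared_mutex` will not grant it while any shared lock is outstanding.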
```diff
@@ -29,6 +29,7 @@ file(GLOB C10_SRCS
         *.cpp
         core/*.cpp
         core/impl/*.cpp
+        core/impl/cow/*.cpp
         mobile/*.cpp
         macros/*.cpp
         util/*.cpp
@@ -37,6 +38,7 @@ file(GLOB C10_HEADERS
         *.h
         core/*.h
         core/impl/*.h
+        core/impl/cow/*.h
         mobile/*.h
         macros/*.h
         util/*.h
```
```diff
@@ -62,6 +62,7 @@ def define_targets(rules):
             exclude = [
                 "CPUAllocator.cpp",
                 "impl/alloc_cpu.cpp",
+                "impl/cow/*.cpp",
             ],
         ),
         hdrs = rules.glob(
@@ -72,6 +73,7 @@ def define_targets(rules):
             exclude = [
                 "CPUAllocator.h",
                 "impl/alloc_cpu.h",
+                "impl/cow/*.h",
             ],
         ),
         linkstatic = True,
@@ -79,6 +81,7 @@ def define_targets(rules):
         visibility = ["//visibility:public"],
         deps = [
             ":ScalarType",
+            ":impl/cow/context",
             "//c10/macros",
             "//c10/util:TypeCast",
             "//c10/util:base",
@@ -89,6 +92,23 @@ def define_targets(rules):
         alwayslink = True,
     )
 
+    rules.cc_library(
+        name = "impl/cow/context",
+        srcs = [
+            "impl/cow/context.cpp",
+            "impl/cow/deleter.cpp",
+        ],
+        hdrs = [
+            "impl/cow/context.h",
+            "impl/cow/deleter.h",
+        ],
+        deps = [
+            "//c10/macros",
+            "//c10/util:base",
+        ],
+        visibility = ["//c10/test:__pkg__"],
+    )
+
     rules.filegroup(
         name = "headers",
         srcs = rules.glob(
```
c10/core/impl/cow/README.md (new file, 67 lines)
```markdown
Copy-on-write storage
=====================
This library adds support for copy-on-write storage, i.e. lazy copies,
to tensors. The design maintains the PyTorch invariant that tensors
alias if and only if they share a storage. Thus, tensors that are lazy
copies of one another will have distinct storages that share a data
allocation.

Thread-safety
-------------
The correctness of this design hinges on the pre-existing PyTorch user
requirement (and general default programming assumption) that users
are responsible for guaranteeing that writes do not take place
concurrently with reads and other writes.

Lazily copied tensors add a complication to this programming model
because users are not required to know if lazy copies exist and are
not required to serialize writes across lazy copies. For example: two
tensors with distinct storages that share a copy-on-write data context
may be given to different threads that may do whatever they wish to
them, and the runtime is required to guarantee that this is safe.

It turns out that this is not that difficult to protect because, due
to the copy-on-write requirement, we just need to materialize a tensor
upon writing. This could be done entirely without synchronization if
we materialized each copy; however, we have a common-sense
optimization to elide the copy for the last remaining reference. This
requires waiting for any pending copies.

### Thread-safety detailed design
There are two operations that affect the copy-on-write details of a
tensor:

1) lazy-clone (e.g. an explicit call or a hidden implementation detail
   added through an operator like reshape)
2) materialization (i.e. any write to the tensor)

The key insight that we exploit is that lazy-clone is logically a read
operation and materialization is logically a write operation. This
means that, for a given set of tensors that share a storage, if
materialization is taking place, no other read operation, including
lazy-clone, can be concurrent with it.

However, this insight only applies within a set of tensors that share
a storage. We also have to be concerned with tensors with different
storages that share a copy-on-write context. In this world,
materialization can race with lazy-clone or even other
materializations. _However_, in order for this to be the case, there
must be _at least_ two references to the context. This means that the
context _cannot_ vanish out from under you if you are performing a
lazy-clone, and hence, it only requires an atomic refcount bump.

The most complicated case is that all lazy-copies are concurrently
materializing. In this case, because a write is occurring, there are
no in-flight lazy-copies taking place. We must simply ensure that all
lazy-copies are able to materialize (read the data) concurrently. If
we didn't have the aforementioned optimization where the last copy
steals the data, we could get away with no locking whatsoever: each
makes a copy and decrements the refcount. However, because of the
optimization, we require the loser of the materializing race to wait
for the pending copies to finish, and then steal the data without
copying it.

We implement this by taking a shared lock when copying the data and
taking an exclusive lock when stealing the data. The exclusive lock
acquisition ensures that all pending shared locks are finished before
we steal the data.
```
c10/core/impl/cow/context.cpp (new file, 38 lines)
```cpp
#include <c10/core/impl/cow/context.h>

#include <c10/core/impl/cow/deleter.h>
#include <c10/util/Exception.h>

namespace c10::impl {

cow::Context::Context(std::unique_ptr<void, DeleterFnPtr> data)
    : data_(std::move(data)) {
  // We never wrap a Context.
  TORCH_INTERNAL_ASSERT(data_.get_deleter() != cow::delete_context);
}

auto cow::Context::increment_refcount() -> void {
  auto refcount = ++refcount_;
  TORCH_INTERNAL_ASSERT(refcount > 1);
}

auto cow::Context::decrement_refcount()
    -> std::variant<NotLastReference, LastReference> {
  auto refcount = --refcount_;
  TORCH_INTERNAL_ASSERT(refcount >= 0, refcount);
  if (refcount == 0) {
    std::unique_lock lock(mutex_);
    auto result = std::move(data_);
    lock.unlock();
    delete this;
    return {std::move(result)};
  }

  return std::shared_lock(mutex_);
}

cow::Context::~Context() {
  TORCH_INTERNAL_ASSERT(refcount_ == 0);
}

} // namespace c10::impl
```
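Two details of `decrement_refcount` are worth calling out. The `std::unique_lock` is acquired and explicitly unlocked before `delete this`, so the exclusive acquisition acts purely as a barrier that waits out any pending shared-lock copies. And both `return` statements rely on `std::variant`'s converting constructor to select the matching alternative; a minimal standalone illustration of that idiom (not PR code):

```cpp
#include <shared_mutex>
#include <variant>

std::shared_mutex m;

// std::variant's converting constructor picks whichever alternative
// best matches the returned expression, so each branch can return its
// payload directly without spelling out the variant type.
auto lock_or_value(bool is_last)
    -> std::variant<std::shared_lock<std::shared_mutex>, int> {
  if (is_last) {
    return 42;                  // constructs the int alternative
  }
  return std::shared_lock(m);   // constructs the shared_lock alternative
}
```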
c10/core/impl/cow/context.h (new file, 58 lines)
```cpp
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/UniqueVoidPtr.h>

#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <variant>

namespace c10::impl::cow {

/// The c10::DataPtr context for copy-on-write storage.
class C10_API Context {
 public:
  /// Creates an instance, holding the pair of data and original
  /// deleter.
  ///
  /// Note that the deleter will only be called in our destructor if
  /// the last reference to this goes away without getting
  /// materialized.
  explicit Context(std::unique_ptr<void, DeleterFnPtr> data);

  /// Increments the current refcount.
  auto increment_refcount() -> void;

  // See README.md in this directory to understand the locking
  // strategy.

  /// Represents a reference to the context.
  ///
  /// This is returned by decrement_refcount to allow the caller to
  /// copy the data under the shared lock.
  using NotLastReference = std::shared_lock<std::shared_mutex>;

  /// Represents the last reference to the context.
  ///
  /// This will be returned by decrement_refcount when it is the last
  /// reference remaining and after any pending copies have completed.
  using LastReference = std::unique_ptr<void, DeleterFnPtr>;

  /// Decrements the refcount, returning a handle indicating what to
  /// do with it.
  auto decrement_refcount() -> std::variant<NotLastReference, LastReference>;

 private:
  // The destructor is hidden, this should only ever be used within
  // UniqueVoidPtr using cow::delete_context as the deleter.
  ~Context();

  std::shared_mutex mutex_;
  std::unique_ptr<void, DeleterFnPtr> data_;
  std::atomic<std::int64_t> refcount_ = 1;
};

} // namespace c10::impl::cow
```
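Later commits in the stack wire this context into `c10::DataPtr`; none of that is in this diff. Purely as a hedged sketch of how a caller might consume this API when a write forces materialization (the `clone_allocation` helper below is hypothetical, not part of this PR):

```cpp
#include <c10/core/impl/cow/context.h>

#include <cstddef>
#include <memory>
#include <utility>
#include <variant>

namespace c10::impl {

// Hypothetical helper, assumed for this sketch: copies size_bytes of
// data into a fresh allocation with its own deleter.
auto clone_allocation(void const* data, std::size_t size_bytes)
    -> std::unique_ptr<void, DeleterFnPtr>;

// Sketch of materialization: drop our reference to the context and
// either copy the data (other references remain, shared lock held) or
// steal it outright (we were the last reference).
auto materialize(cow::Context& ctx, void const* data, std::size_t size_bytes)
    -> std::unique_ptr<void, DeleterFnPtr> {
  auto ref = ctx.decrement_refcount();
  if (std::holds_alternative<cow::Context::NotLastReference>(ref)) {
    // The shared lock held inside `ref` keeps a racing stealer waiting
    // until this copy completes.
    return clone_allocation(data, size_bytes);
  }
  // Last reference: take the original allocation without copying.
  return std::move(std::get<cow::Context::LastReference>(ref));
}

} // namespace c10::impl
```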
c10/core/impl/cow/deleter.cpp (new file, 14 lines)
```cpp
#include <c10/core/impl/cow/deleter.h>

#include <c10/core/impl/cow/context.h>

namespace c10::impl {

/// Deletes a copy-on-write context.
///
/// Requires: ctx is cow::Context.
auto cow::delete_context(void* ctx) -> void {
  static_cast<cow::Context*>(ctx)->decrement_refcount();
}

} // namespace c10::impl
```
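Note that `delete_context` simply discards `decrement_refcount`'s return value: if other references remain, the returned shared lock is released immediately; if this was the last reference, the returned `unique_ptr` frees the data with its original deleter as it goes out of scope. Either way, dropping an owner that uses this deleter does the right thing.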
c10/core/impl/cow/deleter.h (new file, 21 lines)
```cpp
// This is its own header to minimize code visible in other public
// headers in the system. This is beneficial for compilation times as
// well as to avoid issues with internal Meta builds that aren't using
// C++17.

#pragma once

#include <c10/macros/Macros.h>

namespace c10 {
namespace impl {
namespace cow {

/// Deletes a copy-on-write context.
///
/// Requires: ctx is cow::Context.
auto C10_API delete_context(void* ctx) -> void;

} // namespace cow
} // namespace impl
} // namespace c10
```
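A hedged sketch of the intended pairing between the deleter and the context (the real `DataPtr` wiring lands later in the stack; `wrap_in_cow_context` below is a hypothetical name, not PR code):

```cpp
#include <c10/core/impl/cow/context.h>
#include <c10/core/impl/cow/deleter.h>

#include <memory>
#include <utility>

namespace c10::impl {

// Hypothetical sketch: take ownership of an allocation, wrap it in a
// refcounted copy-on-write context, and expose the context through a
// unique_ptr whose deleter is cow::delete_context. Dropping the
// returned pointer then decrements the context's refcount rather than
// freeing the data directly.
auto wrap_in_cow_context(std::unique_ptr<void, DeleterFnPtr> data)
    -> std::unique_ptr<void, DeleterFnPtr> {
  return {new cow::Context(std::move(data)), cow::delete_context};
}

} // namespace c10::impl
```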
```diff
@@ -9,6 +9,15 @@ def define_targets(rules):
         visibility = ["//:__pkg__"],
     )
 
+    rules.cc_test(
+        name = "core/impl/cow/context_test",
+        srcs = ["core/impl/cow/context_test.cpp"],
+        deps = [
+            "//c10/core:impl/cow/context",
+            "@com_google_googletest//:gtest_main",
+        ],
+    )
+
     rules.cc_test(
         name = "core_tests",
         size = "small",
```
c10/test/core/impl/cow/context_test.cpp (new file, 76 lines)
```cpp
#include <c10/core/impl/cow/context.h>

#include <c10/core/impl/cow/deleter.h>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

namespace c10::impl {
namespace {

class DeleteTracker {
 public:
  explicit DeleteTracker(int& delete_count) : delete_count_(delete_count) {}
  ~DeleteTracker() {
    ++delete_count_;
  }

 private:
  int& delete_count_;
};

class ContextTest : public testing::Test {
 protected:
  auto delete_count() const -> int {
    return delete_count_;
  }
  auto new_delete_tracker() -> std::unique_ptr<void, DeleterFnPtr> {
    return {new DeleteTracker(delete_count_), +[](void* ptr) {
              delete static_cast<DeleteTracker*>(ptr);
            }};
  }

 private:
  int delete_count_ = 0;
};

TEST_F(ContextTest, Basic) {
  auto& context = *new cow::Context(new_delete_tracker());
  ASSERT_THAT(delete_count(), testing::Eq(0));

  context.increment_refcount();

  {
    // This is in a sub-scope because this call to decrement_refcount
    // is expected to give us a shared lock.
    auto result = context.decrement_refcount();
    ASSERT_THAT(
        std::holds_alternative<cow::Context::NotLastReference>(result),
        testing::IsTrue());
    ASSERT_THAT(delete_count(), testing::Eq(0));
  }

  {
    auto result = context.decrement_refcount();
    ASSERT_THAT(
        std::holds_alternative<cow::Context::LastReference>(result),
        testing::IsTrue());
    // Result holds the DeleteTracker.
    ASSERT_THAT(delete_count(), testing::Eq(0));
  }

  // When result is deleted, the DeleteTracker is also deleted.
  ASSERT_THAT(delete_count(), testing::Eq(1));
}

TEST_F(ContextTest, delete_context) {
  // This is effectively the same thing as decrement_refcount() above.
  auto& context = *new cow::Context(new_delete_tracker());
  ASSERT_THAT(delete_count(), testing::Eq(0));

  cow::delete_context(&context);
  ASSERT_THAT(delete_count(), testing::Eq(1));
}

} // namespace
} // namespace c10::impl
```
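One idiom in `new_delete_tracker` above merits a note: the unary `+` in `+[](void* ptr) { ... }` forces the captureless lambda to convert to a plain function pointer, which is what `DeleterFnPtr` requires in the braced `unique_ptr` construction. A standalone illustration:

```cpp
#include <cstdlib>

using DeleterFnPtr = void (*)(void*); // same shape as c10::DeleterFnPtr

int main() {
  // Unary '+' applies to the function pointer that the captureless
  // lambda converts to, yielding exactly a DeleterFnPtr.
  DeleterFnPtr deleter = +[](void* ptr) { std::free(ptr); };
  deleter(std::malloc(16));
}
```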
setup.py (1 line changed)
```diff
@@ -1109,6 +1109,7 @@ def main():
                 'include/ATen/core/dispatch/*.h',
                 'include/ATen/core/op_registration/*.h',
                 'include/c10/core/impl/*.h',
+                'include/c10/core/impl/cow/*.h',
                 'include/c10/util/*.h',
                 'include/c10/cuda/*.h',
                 'include/c10/cuda/impl/*.h',
```