[xla:cpu] Fix data race in ThunkExecutor

Also add tsl::down_pointer_cast to improve usability.

PiperOrigin-RevId: 822257137

parent 5776d2771c
commit 0fc052399b
@@ -233,11 +233,12 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> ThunkExecutor::TracedExecute(
 
   // When thunk execution completes, create a consumer traceme to capture the
   // end event.
-  execute_event.AndThen([context_id = producer.GetContextId(), &thunk] {
-    tsl::profiler::TraceMeConsumer(
-        [&] { return absl::StrFormat("end: %s", thunk.info().op_name); },
-        tsl::profiler::ContextType::kGeneric, context_id);
-  });
+  execute_event.AndThen(
+      [context_id = producer.GetContextId(), op_name = thunk.info().op_name] {
+        tsl::profiler::TraceMeConsumer(
+            [&] { return absl::StrFormat("end: %s", op_name); },
+            tsl::profiler::ContextType::kGeneric, context_id);
+      });
 
   return execute_event;
 }
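Note on the fix: the completion callback previously captured `thunk` by reference and read `thunk.info().op_name` whenever the event fired, potentially racing with or outliving the thunk it referred to; the patch copies the name into the closure instead. Below is a contrived, standalone sketch of the same capture pattern, using plain standard-library types rather than the real Thunk/AsyncValue/TraceMe machinery:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Toy completion event: callbacks registered via AndThen() run later, when
// Fire() is called -- by then the registering scope may already be gone.
struct CompletionEvent {
  std::vector<std::function<void()>> callbacks;
  void AndThen(std::function<void()> cb) { callbacks.push_back(std::move(cb)); }
  void Fire() {
    for (auto& cb : callbacks) cb();
  }
};

int main() {
  CompletionEvent done;
  {
    std::string op_name = "dot.1";  // stands in for thunk.info().op_name

    // Pre-fix shape (do not do this): capture by reference; if the referred-to
    // object is destroyed or mutated concurrently before Fire(), the callback
    // reads stale or freed memory.
    // done.AndThen([&op_name] { std::cout << "end: " << op_name << "\n"; });

    // Post-fix shape: copy the name into the closure, so the callback is
    // self-contained no matter when or on which thread it runs.
    done.AndThen([op_name] { std::cout << "end: " << op_name << "\n"; });
  }  // op_name is destroyed here.

  done.Fire();  // Prints "end: dot.1" from the closure's own copy.
  return 0;
}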
third_party/xla/xla/pjrt/cpu/cpu_client.cc (vendored): 5 changed lines
@@ -1540,8 +1540,8 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
     });
   }
 
-  auto* cpu_executable =
-      tsl::down_cast<cpu::CpuExecutable*>(cpu_executable_.get());
+  auto cpu_executable =
+      tsl::down_pointer_cast<cpu::CpuExecutable>(cpu_executable_);
   // `buffer_alloc` and `buffer_alloc_and_copy` are used to do real memory
   // allocation and copy work.
   BufferAlloc buffer_alloc;
@@ -1755,7 +1755,6 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
        buffer_alloc_and_copy = std::move(buffer_alloc_and_copy),
        buffer_table = std::move(buffer_table),
        run_options = std::move(run_options),
-       cpu_executable_copy = cpu_executable_,
        device_assignment = std::move(device_assignment),
        cpu_run_options = std::move(cpu_run_options),
        compute_reservation = std::move(compute_reservation),
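Note: because `down_pointer_cast` returns a `std::shared_ptr<cpu::CpuExecutable>` that shares ownership with `cpu_executable_`, capturing that handle by value in the deferred execution closure presumably keeps the executable alive on its own, which would be why the separate `cpu_executable_copy = cpu_executable_` capture is dropped here. A minimal sketch of that lifetime argument, with generic stand-in types rather than the PjRt ones:

#include <cassert>
#include <functional>
#include <memory>

struct Executable {
  int value = 7;
};

int main() {
  auto exec = std::make_shared<Executable>();

  // Capturing the shared_ptr by value gives the closure its own owning
  // reference, so the object outlives the original handle.
  std::function<int()> deferred = [exec] { return exec->value; };

  exec.reset();              // Drop the original handle.
  assert(deferred() == 7);   // Still valid: the closure keeps it alive.
  return 0;
}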
third_party/xla/xla/tsl/platform/default/casts.h (vendored): 12 changed lines
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <assert.h>  // for use with down_cast<>
 
+#include <memory>
 #include <type_traits>
 
 namespace tensorflow {
@@ -87,10 +88,19 @@ inline To down_cast(From& f) {
   return static_cast<To>(f);
 }
 
+// A `down_cast` version for `std::shared_ptr`.
+template <typename To, typename From>
+std::shared_ptr<To> down_pointer_cast(const std::shared_ptr<From>& from) {
+  auto* ptr =
+      down_cast<typename std::shared_ptr<To>::element_type*>(from.get());
+  return std::shared_ptr<To>{from, ptr};
+}
+
 }  // namespace tensorflow
 
 namespace tsl {
 using ::tensorflow::down_cast;
-}
+using ::tensorflow::down_pointer_cast;
+}  // namespace tsl
 
 #endif  // XLA_TSL_PLATFORM_DEFAULT_CASTS_H_
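Note: the helper is built on the `std::shared_ptr` aliasing constructor. The returned pointer stores the downcast raw pointer while sharing the original control block, the same shape as `std::static_pointer_cast`. Below is a standalone usage sketch; `down_pointer_cast` is re-declared with a plain `static_cast` (instead of `tsl::down_cast`) purely so the snippet compiles on its own:

#include <iostream>
#include <memory>

// Re-declared from the diff above, with static_cast standing in for
// tsl::down_cast so the sketch is self-contained.
template <typename To, typename From>
std::shared_ptr<To> down_pointer_cast(const std::shared_ptr<From>& from) {
  auto* ptr =
      static_cast<typename std::shared_ptr<To>::element_type*>(from.get());
  // Aliasing constructor: share ownership with `from`, but point at `ptr`.
  return std::shared_ptr<To>{from, ptr};
}

struct Base {
  virtual ~Base() = default;
};
struct Derived : Base {
  const char* name = "derived";
};

int main() {
  std::shared_ptr<Base> base = std::make_shared<Derived>();
  std::shared_ptr<Derived> derived = down_pointer_cast<Derived>(base);

  std::cout << derived->name << " use_count=" << base.use_count() << "\n";
  // Prints "derived use_count=2": both handles share one control block, so
  // the object stays alive as long as either of them does.
  return 0;
}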