[xla:cpu] Fix data race in ThunkExecutor

Also add tsl::down_pointer_cast to improve usability.

PiperOrigin-RevId: 822257137
This commit is contained in:
Eugene Zhulenev 2025-10-21 13:38:05 -07:00 committed by TensorFlower Gardener
parent 5776d2771c
commit 0fc052399b
3 changed files with 19 additions and 9 deletions

View File

@ -233,11 +233,12 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> ThunkExecutor::TracedExecute(
// When thunk execution completes, create a consumer traceme to capture the
// end event.
execute_event.AndThen([context_id = producer.GetContextId(), &thunk] {
tsl::profiler::TraceMeConsumer(
[&] { return absl::StrFormat("end: %s", thunk.info().op_name); },
tsl::profiler::ContextType::kGeneric, context_id);
});
execute_event.AndThen(
[context_id = producer.GetContextId(), op_name = thunk.info().op_name] {
tsl::profiler::TraceMeConsumer(
[&] { return absl::StrFormat("end: %s", op_name); },
tsl::profiler::ContextType::kGeneric, context_id);
});
return execute_event;
}

View File

@ -1540,8 +1540,8 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
});
}
auto* cpu_executable =
tsl::down_cast<cpu::CpuExecutable*>(cpu_executable_.get());
auto cpu_executable =
tsl::down_pointer_cast<cpu::CpuExecutable>(cpu_executable_);
// `buffer_alloc` and `buffer_alloc_and_copy` are used to do real memory
// allocation and copy work.
BufferAlloc buffer_alloc;
@ -1755,7 +1755,6 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
buffer_alloc_and_copy = std::move(buffer_alloc_and_copy),
buffer_table = std::move(buffer_table),
run_options = std::move(run_options),
cpu_executable_copy = cpu_executable_,
device_assignment = std::move(device_assignment),
cpu_run_options = std::move(cpu_run_options),
compute_reservation = std::move(compute_reservation),

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <assert.h> // for use with down_cast<>
#include <memory>
#include <type_traits>
namespace tensorflow {
@ -87,10 +88,19 @@ inline To down_cast(From& f) {
return static_cast<To>(f);
}
// A `down_cast` version for `std::shared_ptr`.
template <typename To, typename From>
std::shared_ptr<To> down_pointer_cast(const std::shared_ptr<From>& from) {
auto* ptr =
down_cast<typename std::shared_ptr<To>::element_type*>(from.get());
return std::shared_ptr<To>{from, ptr};
}
} // namespace tensorflow
namespace tsl {
using ::tensorflow::down_cast;
}
using ::tensorflow::down_pointer_cast;
} // namespace tsl
#endif // XLA_TSL_PLATFORM_DEFAULT_CASTS_H_