[xla:cpu] Fix data race in ThunkExecutor

Also add tsl::down_pointer_cast to improve usability.

PiperOrigin-RevId: 822257137
This commit is contained in:
Eugene Zhulenev 2025-10-21 13:38:05 -07:00 committed by TensorFlower Gardener
parent 5776d2771c
commit 0fc052399b
3 changed files with 19 additions and 9 deletions

View File

@ -233,11 +233,12 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> ThunkExecutor::TracedExecute(
// When thunk execution completes, create a consumer traceme to capture the // When thunk execution completes, create a consumer traceme to capture the
// end event. // end event.
execute_event.AndThen([context_id = producer.GetContextId(), &thunk] { execute_event.AndThen(
tsl::profiler::TraceMeConsumer( [context_id = producer.GetContextId(), op_name = thunk.info().op_name] {
[&] { return absl::StrFormat("end: %s", thunk.info().op_name); }, tsl::profiler::TraceMeConsumer(
tsl::profiler::ContextType::kGeneric, context_id); [&] { return absl::StrFormat("end: %s", op_name); },
}); tsl::profiler::ContextType::kGeneric, context_id);
});
return execute_event; return execute_event;
} }

View File

@ -1540,8 +1540,8 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
}); });
} }
auto* cpu_executable = auto cpu_executable =
tsl::down_cast<cpu::CpuExecutable*>(cpu_executable_.get()); tsl::down_pointer_cast<cpu::CpuExecutable>(cpu_executable_);
// `buffer_alloc` and `buffer_alloc_and_copy` are used to do real memory // `buffer_alloc` and `buffer_alloc_and_copy` are used to do real memory
// allocation and copy work. // allocation and copy work.
BufferAlloc buffer_alloc; BufferAlloc buffer_alloc;
@ -1755,7 +1755,6 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
buffer_alloc_and_copy = std::move(buffer_alloc_and_copy), buffer_alloc_and_copy = std::move(buffer_alloc_and_copy),
buffer_table = std::move(buffer_table), buffer_table = std::move(buffer_table),
run_options = std::move(run_options), run_options = std::move(run_options),
cpu_executable_copy = cpu_executable_,
device_assignment = std::move(device_assignment), device_assignment = std::move(device_assignment),
cpu_run_options = std::move(cpu_run_options), cpu_run_options = std::move(cpu_run_options),
compute_reservation = std::move(compute_reservation), compute_reservation = std::move(compute_reservation),

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <assert.h> // for use with down_cast<> #include <assert.h> // for use with down_cast<>
#include <memory>
#include <type_traits> #include <type_traits>
namespace tensorflow { namespace tensorflow {
@ -87,10 +88,19 @@ inline To down_cast(From& f) {
return static_cast<To>(f); return static_cast<To>(f);
} }
// A `down_cast` version for `std::shared_ptr`.
template <typename To, typename From>
std::shared_ptr<To> down_pointer_cast(const std::shared_ptr<From>& from) {
auto* ptr =
down_cast<typename std::shared_ptr<To>::element_type*>(from.get());
return std::shared_ptr<To>{from, ptr};
}
} // namespace tensorflow } // namespace tensorflow
namespace tsl { namespace tsl {
using ::tensorflow::down_cast; using ::tensorflow::down_cast;
} using ::tensorflow::down_pointer_cast;
} // namespace tsl
#endif // XLA_TSL_PLATFORM_DEFAULT_CASTS_H_ #endif // XLA_TSL_PLATFORM_DEFAULT_CASTS_H_