mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:cpu] Fix data race in ThunkExecutor
Also add tsl::down_pointer_cast to improve usability. PiperOrigin-RevId: 822257137
This commit is contained in:
parent
5776d2771c
commit
0fc052399b
|
|
@ -233,11 +233,12 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> ThunkExecutor::TracedExecute(
|
||||||
|
|
||||||
// When thunk execution completes, create a consumer traceme to capture the
|
// When thunk execution completes, create a consumer traceme to capture the
|
||||||
// end event.
|
// end event.
|
||||||
execute_event.AndThen([context_id = producer.GetContextId(), &thunk] {
|
execute_event.AndThen(
|
||||||
tsl::profiler::TraceMeConsumer(
|
[context_id = producer.GetContextId(), op_name = thunk.info().op_name] {
|
||||||
[&] { return absl::StrFormat("end: %s", thunk.info().op_name); },
|
tsl::profiler::TraceMeConsumer(
|
||||||
tsl::profiler::ContextType::kGeneric, context_id);
|
[&] { return absl::StrFormat("end: %s", op_name); },
|
||||||
});
|
tsl::profiler::ContextType::kGeneric, context_id);
|
||||||
|
});
|
||||||
|
|
||||||
return execute_event;
|
return execute_event;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
5
third_party/xla/xla/pjrt/cpu/cpu_client.cc
vendored
5
third_party/xla/xla/pjrt/cpu/cpu_client.cc
vendored
|
|
@ -1540,8 +1540,8 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* cpu_executable =
|
auto cpu_executable =
|
||||||
tsl::down_cast<cpu::CpuExecutable*>(cpu_executable_.get());
|
tsl::down_pointer_cast<cpu::CpuExecutable>(cpu_executable_);
|
||||||
// `buffer_alloc` and `buffer_alloc_and_copy` are used to do real memory
|
// `buffer_alloc` and `buffer_alloc_and_copy` are used to do real memory
|
||||||
// allocation and copy work.
|
// allocation and copy work.
|
||||||
BufferAlloc buffer_alloc;
|
BufferAlloc buffer_alloc;
|
||||||
|
|
@ -1755,7 +1755,6 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
|
||||||
buffer_alloc_and_copy = std::move(buffer_alloc_and_copy),
|
buffer_alloc_and_copy = std::move(buffer_alloc_and_copy),
|
||||||
buffer_table = std::move(buffer_table),
|
buffer_table = std::move(buffer_table),
|
||||||
run_options = std::move(run_options),
|
run_options = std::move(run_options),
|
||||||
cpu_executable_copy = cpu_executable_,
|
|
||||||
device_assignment = std::move(device_assignment),
|
device_assignment = std::move(device_assignment),
|
||||||
cpu_run_options = std::move(cpu_run_options),
|
cpu_run_options = std::move(cpu_run_options),
|
||||||
compute_reservation = std::move(compute_reservation),
|
compute_reservation = std::move(compute_reservation),
|
||||||
|
|
|
||||||
12
third_party/xla/xla/tsl/platform/default/casts.h
vendored
12
third_party/xla/xla/tsl/platform/default/casts.h
vendored
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include <assert.h> // for use with down_cast<>
|
#include <assert.h> // for use with down_cast<>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -87,10 +88,19 @@ inline To down_cast(From& f) {
|
||||||
return static_cast<To>(f);
|
return static_cast<To>(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A `down_cast` version for `std::shared_ptr`.
|
||||||
|
template <typename To, typename From>
|
||||||
|
std::shared_ptr<To> down_pointer_cast(const std::shared_ptr<From>& from) {
|
||||||
|
auto* ptr =
|
||||||
|
down_cast<typename std::shared_ptr<To>::element_type*>(from.get());
|
||||||
|
return std::shared_ptr<To>{from, ptr};
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
namespace tsl {
|
namespace tsl {
|
||||||
using ::tensorflow::down_cast;
|
using ::tensorflow::down_cast;
|
||||||
}
|
using ::tensorflow::down_pointer_cast;
|
||||||
|
} // namespace tsl
|
||||||
|
|
||||||
#endif // XLA_TSL_PLATFORM_DEFAULT_CASTS_H_
|
#endif // XLA_TSL_PLATFORM_DEFAULT_CASTS_H_
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user