diff --git a/third_party/xla/xla/python/transfer/socket-server.cc b/third_party/xla/xla/python/transfer/socket-server.cc index 2a082f0f5a7..a5397b8c0ad 100644 --- a/third_party/xla/xla/python/transfer/socket-server.cc +++ b/third_party/xla/xla/python/transfer/socket-server.cc @@ -114,13 +114,22 @@ class SocketServer::SocketNetworkState : public SocketFdPacketState { return SendRawFrame(std::move(opacket)); } - tsl::RCReference GetNextDest(size_t req_id, size_t offset, - size_t size, bool is_largest) { + std::optional> GetNextDest( + size_t req_id, size_t offset, size_t size, bool is_largest) { tsl::RCReference dest; { absl::MutexLock l(mu_); + if (is_poisoned_) { + return std::nullopt; + } auto it = dests_.find(req_id); - CHECK(it != dests_.end()); + if (it == dests_.end()) { + Shutdown(SHUT_RDWR); + is_poisoned_ = true; + poison_status_ = + absl::InternalError("SocketServer: it != dests_.end()"); + return std::nullopt; + } if (is_largest) { it->second.transferred_size += offset; } else { @@ -273,7 +282,10 @@ class SocketServer::SocketNetworkState : public SocketFdPacketState { void HandlePacket(const SocketTransferPacketErrorHeader& packet) { auto dest = GetNextDest(packet.req_id(), packet.offset(), packet.size(), packet.is_largest()); - dest->Poison(absl::InternalError( + if (!dest.has_value()) { + return; + } + (*dest)->Poison(absl::InternalError( absl::StrCat("Error while transferring: ", packet.error_message()))); } @@ -293,9 +305,12 @@ class SocketServer::SocketNetworkState : public SocketFdPacketState { void HandlePacket(const SocketTransferPacketHeader& packet) { auto dest = GetNextDest(packet.req_id(), packet.offset(), packet.size(), packet.is_largest()); + if (!dest.has_value()) { + return; + } bulk_transport_->Recv( packet.size(), packet.bulk_transport_id(), - [offset = packet.offset(), dest = std::move(dest)]( + [offset = packet.offset(), dest = *std::move(dest)]( absl::StatusOr msgor) { if (!msgor.ok()) { dest->Poison(msgor.status());