diff --git a/third_party/xla/third_party/llvm/workspace.bzl b/third_party/xla/third_party/llvm/workspace.bzl index 08591657a4c..32e6a7a1c04 100644 --- a/third_party/xla/third_party/llvm/workspace.bzl +++ b/third_party/xla/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "22079e3f3698d5c367c7b67f63de8c838791ae76" - LLVM_SHA256 = "d5616e9c0f4b761f13da5535a0d9ec94acf4ae5226bbec3e47ac2929ea60cac2" + LLVM_COMMIT = "42a8ff877d47131ecb1280a1cc7e5e3c3bca6952" + LLVM_SHA256 = "f768c5c3b987f68318b8ab3dd4530e54988dfe7d6bfb9b7c9c96acf503367d50" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 619108f749e..1cb2108d447 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,719 +1,15 @@ -diff --git a/shardy/integrations/python/jax/mpmd/ops.py b/shardy/integrations/python/jax/mpmd/ops.py -index 60ae866..ea144d5 100644 ---- a/shardy/integrations/python/jax/mpmd/ops.py -+++ b/shardy/integrations/python/jax/mpmd/ops.py -@@ -29,13 +29,11 @@ from jax._src import util - from jax._src.interpreters import ad as internal_ad - from jax._src.interpreters import batching as internal_batching - from jax._src.interpreters import partial_eval as internal_pe -+import jax._src.lib.mlir.dialects as jax_mlir_dialects - import jax.extend as jex - from jax.extend import linear_util as lu - from jax.extend import source_info_util as siu - from jax.extend.core import primitives --from jax.extend.mlir import ir --from jax.extend.mlir.dialects import func as func_dialect --from jax.extend.mlir.dialects import mpmd - from jax.interpreters import ad - from jax.interpreters import batching - from jax.interpreters import mlir as jax_mlir -@@ -45,6 +43,10 @@ import jaxtyping - from shardy.integrations.python.jax.mpmd import utils - - -+ir = jax_mlir.ir -+mpmd = jax_mlir_dialects.mpmd -+func_dialect = jax_mlir_dialects.func -+ - PyTree = jaxtyping.PyTree - X = TypeVar('X') - Y = TypeVar('Y') -diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 4117b0a..509398d 100644 ---- a/third_party/llvm/generated.patch -+++ b/third_party/llvm/generated.patch -@@ -1,576 +1 @@ - Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py ----- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py --+++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py --@@ -10,8 +10,8 @@ -- import subprocess -- import signal -- import sys --+import threading -- import warnings ---import selectors -- import time -- from typing import ( -- Any, --@@ -139,6 +139,35 @@ -- outfile.write("\n") -- -- --+def read_packet( --+ f: IO[bytes], trace_file: Optional[IO[str]] = None --+) -> Optional[ProtocolMessage]: --+ """Decode a JSON packet that starts with the content length and is --+ followed by the JSON bytes from a file 'f'. Returns None on EOF. --+ """ --+ line = f.readline().decode("utf-8") --+ if len(line) == 0: --+ return None # EOF. --+ --+ # Watch for line that starts with the prefix --+ prefix = "Content-Length: " --+ if line.startswith(prefix): --+ # Decode length of JSON bytes --+ length = int(line[len(prefix) :]) --+ # Skip empty line --+ separator = f.readline().decode() --+ if separator != "": --+ Exception("malformed DAP content header, unexpected line: " + separator) --+ # Read JSON bytes --+ json_str = f.read(length).decode() --+ if trace_file: --+ trace_file.write("from adapter:\n%s\n" % (json_str)) --+ # Decode the JSON bytes into a python dictionary --+ return json.loads(json_str) --+ --+ raise Exception("unexpected malformed message from lldb-dap: " + line) --+ --+ -- def packet_type_is(packet, packet_type): -- return "type" in packet and packet["type"] == packet_type -- --@@ -170,8 +199,16 @@ -- self.log_file = log_file -- self.send = send -- self.recv = recv --- self.selector = selectors.DefaultSelector() --- self.selector.register(recv, selectors.EVENT_READ) --+ --+ # Packets that have been received and processed but have not yet been --+ # requested by a test case. --+ self._pending_packets: List[Optional[ProtocolMessage]] = [] --+ # Received packets that have not yet been processed. --+ self._recv_packets: List[Optional[ProtocolMessage]] = [] --+ # Used as a mutex for _recv_packets and for notify when _recv_packets --+ # changes. --+ self._recv_condition = threading.Condition() --+ self._recv_thread = threading.Thread(target=self._read_packet_thread) -- -- # session state -- self.init_commands = init_commands --@@ -197,6 +234,9 @@ -- # keyed by breakpoint id -- self.resolved_breakpoints: dict[str, Breakpoint] = {} -- --+ # trigger enqueue thread --+ self._recv_thread.start() --+ -- @classmethod -- def encode_content(cls, s: str) -> bytes: -- return ("Content-Length: %u\r\n\r\n%s" % (len(s), s)).encode("utf-8") --@@ -212,46 +252,17 @@ -- f"seq mismatch in response {command['seq']} != {response['request_seq']}" -- ) -- --- def _read_packet( --- self, --- timeout: float = DEFAULT_TIMEOUT, --- ) -> Optional[ProtocolMessage]: --- """Decode a JSON packet that starts with the content length and is --- followed by the JSON bytes from self.recv. Returns None on EOF. --- """ --- --- ready = self.selector.select(timeout) --- if not ready: --- warnings.warn( --- "timeout occurred waiting for a packet, check if the test has a" --- " negative assertion and see if it can be inverted.", --- stacklevel=4, --- ) --- return None # timeout --- --- line = self.recv.readline().decode("utf-8") --- if len(line) == 0: --- return None # EOF. --- --- # Watch for line that starts with the prefix --- prefix = "Content-Length: " --- if line.startswith(prefix): --- # Decode length of JSON bytes --- length = int(line[len(prefix) :]) --- # Skip empty line --- separator = self.recv.readline().decode() --- if separator != "": --- Exception("malformed DAP content header, unexpected line: " + separator) --- # Read JSON bytes --- json_str = self.recv.read(length).decode() --- if self.trace_file: --- self.trace_file.write( --- "%s from adapter:\n%s\n" % (time.time(), json_str) --- ) --- # Decode the JSON bytes into a python dictionary --- return json.loads(json_str) --- --- raise Exception("unexpected malformed message from lldb-dap: " + line) --+ def _read_packet_thread(self): --+ try: --+ while True: --+ packet = read_packet(self.recv, trace_file=self.trace_file) --+ # `packet` will be `None` on EOF. We want to pass it down to --+ # handle_recv_packet anyway so the main thread can handle unexpected --+ # termination of lldb-dap and stop waiting for new packets. --+ if not self._handle_recv_packet(packet): --+ break --+ finally: --+ dump_dap_log(self.log_file) -- -- def get_modules( -- self, start_module: Optional[int] = None, module_count: Optional[int] = None --@@ -299,6 +310,34 @@ -- output += self.get_output(category, clear=clear) -- return output -- --+ def _enqueue_recv_packet(self, packet: Optional[ProtocolMessage]): --+ with self.recv_condition: --+ self.recv_packets.append(packet) --+ self.recv_condition.notify() --+ --+ def _handle_recv_packet(self, packet: Optional[ProtocolMessage]) -> bool: --+ """Handles an incoming packet. --+ --+ Called by the read thread that is waiting for all incoming packets --+ to store the incoming packet in "self._recv_packets" in a thread safe --+ way. This function will then signal the "self._recv_condition" to --+ indicate a new packet is available. --+ --+ Args: --+ packet: A new packet to store. --+ --+ Returns: --+ True if the caller should keep calling this function for more --+ packets. --+ """ --+ with self._recv_condition: --+ self._recv_packets.append(packet) --+ self._recv_condition.notify() --+ # packet is None on EOF --+ return packet is not None and not ( --+ packet["type"] == "response" and packet["command"] == "disconnect" --+ ) --+ -- def _recv_packet( -- self, -- *, --@@ -322,34 +361,46 @@ -- The first matching packet for the given predicate, if specified, -- otherwise None. -- """ --- deadline = time.time() + timeout --- --- while time.time() < deadline: --- packet = self._read_packet(timeout=deadline - time.time()) --- if packet is None: --- return None --- self._process_recv_packet(packet) --- if not predicate or predicate(packet): --- return packet --+ assert ( --+ threading.current_thread != self._recv_thread --+ ), "Must not be called from the _recv_thread" --+ --+ def process_until_match(): --+ self._process_recv_packets() --+ for i, packet in enumerate(self._pending_packets): --+ if packet is None: --+ # We need to return a truthy value to break out of the --+ # wait_for, use `EOFError` as an indicator of EOF. --+ return EOFError() --+ if predicate and predicate(packet): --+ self._pending_packets.pop(i) --+ return packet --+ --+ with self._recv_condition: --+ packet = self._recv_condition.wait_for(process_until_match, timeout) --+ return None if isinstance(packet, EOFError) else packet -- --- def _process_recv_packet(self, packet) -> None: --+ def _process_recv_packets(self) -> None: -- """Process received packets, updating the session state.""" --- if packet and ("seq" not in packet or packet["seq"] == 0): --- warnings.warn( --- f"received a malformed packet, expected 'seq != 0' for {packet!r}" --- ) --- # Handle events that may modify any stateful properties of --- # the DAP session. --- if packet and packet["type"] == "event": --- self._handle_event(packet) --- elif packet and packet["type"] == "request": --- # Handle reverse requests and keep processing. --- self._handle_reverse_request(packet) --+ with self._recv_condition: --+ for packet in self._recv_packets: --+ if packet and ("seq" not in packet or packet["seq"] == 0): --+ warnings.warn( --+ f"received a malformed packet, expected 'seq != 0' for {packet!r}" --+ ) --+ # Handle events that may modify any stateful properties of --+ # the DAP session. --+ if packet and packet["type"] == "event": --+ self._handle_event(packet) --+ elif packet and packet["type"] == "request": --+ # Handle reverse requests and keep processing. --+ self._handle_reverse_request(packet) --+ # Move the packet to the pending queue. --+ self._pending_packets.append(packet) --+ self._recv_packets.clear() -- -- def _handle_event(self, packet: Event) -> None: -- """Handle any events that modify debug session state we track.""" --- self.events.append(packet) --- -- event = packet["event"] -- body: Optional[Dict] = packet.get("body", None) -- --@@ -402,8 +453,6 @@ -- self.invalidated_event = packet -- elif event == "memory": -- self.memory_event = packet --- elif event == "module": --- self.module_events.append(packet) -- -- def _handle_reverse_request(self, request: Request) -> None: -- if request in self.reverse_requests: --@@ -472,14 +521,18 @@ -- -- Returns the seq number of the request. -- """ --- packet["seq"] = self.sequence --- self.sequence += 1 --+ # Set the seq for requests. --+ if packet["type"] == "request": --+ packet["seq"] = self.sequence --+ self.sequence += 1 --+ else: --+ packet["seq"] = 0 -- -- # Encode our command dictionary as a JSON string -- json_str = json.dumps(packet, separators=(",", ":")) -- -- if self.trace_file: --- self.trace_file.write("%s to adapter:\n%s\n" % (time.time(), json_str)) --+ self.trace_file.write("to adapter:\n%s\n" % (json_str)) -- -- length = len(json_str) -- if length > 0: --@@ -860,8 +913,6 @@ -- if restartArguments: -- command_dict["arguments"] = restartArguments -- --- # Clear state, the process is about to restart... --- self._process_continued(True) -- response = self._send_recv(command_dict) -- # Caller must still call wait_for_stopped. -- return response --@@ -1428,10 +1479,8 @@ -- -- def terminate(self): -- self.send.close() --- self.recv.close() --- self.selector.close() --- if self.log_file: --- dump_dap_log(self.log_file) --+ if self._recv_thread.is_alive(): --+ self._recv_thread.join() -- -- def request_setInstructionBreakpoints(self, memory_reference=[]): -- breakpoints = [] --@@ -1528,7 +1577,6 @@ -- stdout=subprocess.PIPE, -- stderr=sys.stderr, -- env=adapter_env, --- bufsize=0, -- ) -- -- if connection is None: --diff -ruN --strip-trailing-cr a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py ----- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py --+++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py --@@ -416,7 +416,7 @@ -- return self.dap_server.wait_for_stopped() -- -- def continue_to_breakpoint(self, breakpoint_id: str): --- self.continue_to_breakpoints([breakpoint_id]) --+ self.continue_to_breakpoints((breakpoint_id)) -- -- def continue_to_breakpoints(self, breakpoint_ids): -- self.do_continue() --diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py b/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py ----- a/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py --+++ b/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py --@@ -81,20 +81,24 @@ -- breakpoint["verified"], "expect foo breakpoint to not be verified" -- ) -- --+ # Flush the breakpoint events. --+ self.dap_server.wait_for_breakpoint_events() --+ -- # Continue to the breakpoint --- self.continue_to_breakpoint(foo_bp_id) --- self.continue_to_next_stop() # foo_bp2 --- self.continue_to_breakpoint(main_bp_id) --- self.continue_to_exit() --+ self.continue_to_breakpoints(dap_breakpoint_ids) -- --- bp_events = [e for e in self.dap_server.events if e["event"] == "breakpoint"] --+ verified_breakpoint_ids = [] --+ unverified_breakpoint_ids = [] --+ for breakpoint_event in self.dap_server.wait_for_breakpoint_events(): --+ breakpoint = breakpoint_event["body"]["breakpoint"] --+ id = breakpoint["id"] --+ if breakpoint["verified"]: --+ verified_breakpoint_ids.append(id) --+ else: --+ unverified_breakpoint_ids.append(id) -- --- main_bp_events = [ --- e for e in bp_events if e["body"]["breakpoint"]["id"] == main_bp_id --- ] --- foo_bp_events = [ --- e for e in bp_events if e["body"]["breakpoint"]["id"] == foo_bp_id --- ] --+ self.assertIn(main_bp_id, unverified_breakpoint_ids) --+ self.assertIn(foo_bp_id, unverified_breakpoint_ids) -- --- self.assertTrue(main_bp_events) --- self.assertTrue(foo_bp_events) --+ self.assertIn(main_bp_id, verified_breakpoint_ids) --+ self.assertIn(foo_bp_id, verified_breakpoint_ids) --diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py b/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py ----- a/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py --+++ b/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py --@@ -156,7 +156,6 @@ -- self.build_and_launch( -- program, debuggerRoot=program_parent_dir, initCommands=commands -- ) --- self.continue_to_exit() -- output = self.get_console() -- self.assertTrue(output and len(output) > 0, "expect console output") -- lines = output.splitlines() --@@ -172,6 +171,7 @@ -- % (program_parent_dir, line[len(prefix) :]), -- ) -- self.assertTrue(found, "verified lldb-dap working directory") --+ self.continue_to_exit() -- -- def test_sourcePath(self): -- """ --diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py b/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py ----- a/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py --+++ b/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py --@@ -64,18 +64,19 @@ -- self.assertEqual(program, program_module["path"]) -- self.assertIn("addressRange", program_module) -- --- self.continue_to_exit() --- -- # Collect all the module names we saw as events. -- module_new_names = [] -- module_changed_names = [] --- for module_event in self.dap_server.module_events: --+ module_event = self.dap_server.wait_for_event(["module"]) --+ while module_event is not None: -- reason = module_event["body"]["reason"] -- if reason == "new": -- module_new_names.append(module_event["body"]["module"]["name"]) -- elif reason == "changed": -- module_changed_names.append(module_event["body"]["module"]["name"]) -- --+ module_event = self.dap_server.wait_for_event(["module"]) --+ -- # Make sure we got an event for every active module. -- self.assertNotEqual(len(module_new_names), 0) -- for module in active_modules: --@@ -85,6 +86,7 @@ -- # symbols got added. -- self.assertNotEqual(len(module_changed_names), 0) -- self.assertIn(program_module["name"], module_changed_names) --+ self.continue_to_exit() -- -- @skipIfWindows -- def test_modules(self): --diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py b/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py ----- a/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py --+++ b/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py --@@ -1,58 +1,58 @@ ---""" ---Test 'module' events for dynamically loaded libraries. ---""" --- --+import dap_server -- from lldbsuite.test.decorators import * -- from lldbsuite.test.lldbtest import * --+from lldbsuite.test import lldbutil -- import lldbdap_testcase --+import re -- -- -- class TestDAP_module_event(lldbdap_testcase.DAPTestCaseBase): --- def lookup_module_id(self, name): --- """Returns the identifier for the first module event starting with the given name.""" --- for event in self.dap_server.module_events: --- if self.get_dict_value(event, ["body", "module", "name"]).startswith(name): --- return self.get_dict_value(event, ["body", "module", "id"]) --- self.fail(f"No module events matching name={name}") --- --- def module_events(self, id): --- """Finds all module events by identifier.""" --- return [ --- event --- for event in self.dap_server.module_events --- if self.get_dict_value(event, ["body", "module", "id"]) == id --- ] --- --- def module_reasons(self, events): --- """Returns the list of 'reason' values from the given events.""" --- return [event["body"]["reason"] for event in events] --- -- @skipIfWindows -- def test_module_event(self): --- """ --- Test that module events are fired on target load and when the list of --- dynamic libraries updates while running. --- """ -- program = self.getBuildArtifact("a.out") -- self.build_and_launch(program) --- # We can analyze the order of events after the process exits. --- self.continue_to_exit() -- --- a_out_id = self.lookup_module_id("a.out") --- a_out_events = self.module_events(id=a_out_id) --+ source = "main.cpp" --+ breakpoint1_line = line_number(source, "// breakpoint 1") --+ breakpoint2_line = line_number(source, "// breakpoint 2") --+ breakpoint3_line = line_number(source, "// breakpoint 3") -- --- self.assertIn( --- "new", --- self.module_reasons(a_out_events), --- "Expected a.out to load during the debug session.", --+ breakpoint_ids = self.set_source_breakpoints( --+ source, [breakpoint1_line, breakpoint2_line, breakpoint3_line] -- ) --+ self.continue_to_breakpoints(breakpoint_ids) -- --- libother_id = self.lookup_module_id( --- "libother." # libother.so or libother.dylib based on OS. --- ) --- libother_events = self.module_events(id=libother_id) --- self.assertEqual( --- self.module_reasons(libother_events), --- ["new", "removed"], --- "Expected libother to be loaded then unloaded during the debug session.", --- ) --+ # We're now stopped at breakpoint 1 before the dlopen. Flush all the module events. --+ event = self.dap_server.wait_for_event(["module"]) --+ while event is not None: --+ event = self.dap_server.wait_for_event(["module"]) --+ --+ # Continue to the second breakpoint, before the dlclose. --+ self.continue_to_breakpoints(breakpoint_ids) --+ --+ # Make sure we got a module event for libother. --+ event = self.dap_server.wait_for_event(["module"]) --+ self.assertIsNotNone(event, "didn't get a module event") --+ module_name = event["body"]["module"]["name"] --+ module_id = event["body"]["module"]["id"] --+ self.assertEqual(event["body"]["reason"], "new") --+ self.assertIn("libother", module_name) --+ --+ # Continue to the third breakpoint, after the dlclose. --+ self.continue_to_breakpoints(breakpoint_ids) --+ --+ # Make sure we got a module event for libother. --+ event = self.dap_server.wait_for_event(["module"]) --+ self.assertIsNotNone(event, "didn't get a module event") --+ reason = event["body"]["reason"] --+ self.assertEqual(reason, "removed") --+ self.assertEqual(event["body"]["module"]["id"], module_id) --+ --+ # The removed module event should omit everything but the module id and name --+ # as they are required fields. --+ module_data = event["body"]["module"] --+ required_keys = ["id", "name"] --+ self.assertListEqual(list(module_data.keys()), required_keys) --+ self.assertEqual(module_data["name"], "", "expects empty name.") --+ --+ self.continue_to_exit() --diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py b/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py ----- a/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py --+++ b/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py --@@ -30,11 +30,7 @@ -- if reason == "entry": -- seen_stopped_event += 1 -- --- self.assertEqual( --- seen_stopped_event, --- 1, --- f"expect only one stopped entry event in {stopped_events}", --- ) --+ self.assertEqual(seen_stopped_event, 1, "expect only one stopped entry event.") -- -- @skipIfAsan -- @skipIfWindows --@@ -96,13 +92,11 @@ -- self.build_and_launch(program, console="integratedTerminal", stopOnEntry=True) -- [bp_main] = self.set_function_breakpoints(["main"]) -- --- self.dap_server.request_configurationDone() --- stopped_threads = list(self.dap_server.thread_stop_reasons.values()) --+ self.dap_server.request_continue() # sends configuration done --+ stopped_events = self.dap_server.wait_for_stopped() -- # We should be stopped at the entry point. --- self.assertEqual( --- len(stopped_threads), 1, "Expected the main thread to be stopped on entry." --- ) --- self.assertEqual(stopped_threads[0]["reason"], "entry") --+ self.assertGreaterEqual(len(stopped_events), 0, "expect stopped events") --+ self.verify_stopped_on_entry(stopped_events) -- -- # Then, if we continue, we should hit the breakpoint at main. -- self.dap_server.request_continue() --@@ -111,12 +105,8 @@ -- # Restart and check that we still get a stopped event before reaching -- # main. -- self.dap_server.request_restart() --- stopped_threads = list(self.dap_server.thread_stop_reasons.values()) --- # We should be stopped at the entry point. --- self.assertEqual( --- len(stopped_threads), 1, "Expected the main thread to be stopped on entry." --- ) --- self.assertEqual(stopped_threads[0]["reason"], "entry") --+ stopped_events = self.dap_server.wait_for_stopped() --+ self.verify_stopped_on_entry(stopped_events) -- -- # continue to main -- self.dap_server.request_continue() --diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py b/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py ----- a/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py --+++ b/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py --@@ -32,7 +32,7 @@ -- ], -- ) -- self.set_source_breakpoints(source, [breakpoint_line]) --- self.do_continue() --+ self.continue_to_next_stop() -- -- custom_event = self.dap_server.wait_for_event( -- filter=["my-custom-event-no-body"] diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index d012c7c..0859165 100644 +index 0859165..32e6a7a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "4c46ae394841521914e0e8575e7619a1c0d1149d" -- LLVM_SHA256 = "55c824ce2e1a3afafa4e108532f4eff9f194d20d44d1c5ddc6107bb23d7c6c2a" -+ LLVM_COMMIT = "22079e3f3698d5c367c7b67f63de8c838791ae76" -+ LLVM_SHA256 = "d5616e9c0f4b761f13da5535a0d9ec94acf4ae5226bbec3e47ac2929ea60cac2" +- LLVM_COMMIT = "22079e3f3698d5c367c7b67f63de8c838791ae76" +- LLVM_SHA256 = "d5616e9c0f4b761f13da5535a0d9ec94acf4ae5226bbec3e47ac2929ea60cac2" ++ LLVM_COMMIT = "42a8ff877d47131ecb1280a1cc7e5e3c3bca6952" ++ LLVM_SHA256 = "f768c5c3b987f68318b8ab3dd4530e54988dfe7d6bfb9b7c9c96acf503367d50" tf_http_archive( name = name, -diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch -index c445b82..a42cbd7 100755 ---- a/third_party/stablehlo/temporary.patch -+++ b/third_party/stablehlo/temporary.patch -@@ -1,3 +1,88 @@ -+diff --ruN a/stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir -+--- stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir -++++ stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir -+@@ -768,7 +768,7 @@ -+ // CHECK-PRIMITIVE: %[[MAP:.+]] = linalg.map -+ // CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] -+ // CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor) -+-// CHECK-PRIMITIVE-NEXT: (%[[A:.+]]: complex, %[[B:.+]]: complex) { -++// CHECK-PRIMITIVE-NEXT: (%[[A:.+]]: complex, %[[B:.+]]: complex, %{{.+}}: i1) { -+ // CHECK-PRIMITIVE: %[[RE1:.+]] = complex.re %[[A]] : complex -+ // CHECK-PRIMITIVE: %[[RE2:.+]] = complex.re %[[B]] : complex -+ // CHECK-PRIMITIVE: %[[CMP:.+]] = arith.cmpf oeq, %[[RE1]], %[[RE2]] : f32 -+diff --ruN a/stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir b/stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir -+--- stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir -++++ stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir -+@@ -714,7 +714,7 @@ -+ // CHECK-PRIMITIVE: linalg.map -+ // CHECK-PRIMITIVE-SAME: ins( -+ // CHECK-PRIMITIVE-SAME: outs( -+-// CHECK-PRIMITIVE-NEXT: (%[[LHS_IN:[a-zA-Z0-9]*]]: bf16, %[[RHS_IN:.*]]: bf16) { -++// CHECK-PRIMITIVE-NEXT: (%[[LHS_IN:[a-zA-Z0-9]*]]: bf16, %[[RHS_IN:.*]]: bf16, %[[RESULT_OUT:.*]]: i1) { -+ // CHECK-PRIMITIVE-NEXT: %[[LHS_INT:.*]] = arith.bitcast %[[LHS_IN]] : bf16 to i16 -+ // CHECK-PRIMITIVE-NEXT: %[[LHS_CMP:.*]] = arith.cmpi slt, %[[LHS_INT]], %[[C0]] : i16 -+ // CHECK-PRIMITIVE-NEXT: %[[LHS_SUB:.*]] = arith.subi %[[C32767]], %[[LHS_INT]] : i16 -+@@ -937,7 +937,7 @@ -+ // CHECK-PRIMITIVE-SAME: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>) -+ // CHECK-PRIMITIVE-SAME: outs(%[[DST]] : tensor<2x?xf32>) -+ // CHECK-PRIMITIVE-SAME: {someattr} -+-// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32) { -++// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[RESULT_OUT:.*]]: f32) { -+ // CHECK-PRIMITIVE: %[[RES:.*]] = arith.select %[[PRED_ELEM]], %[[LHS_]], %[[RHS_]] : f32 -+ // CHECK-PRIMITIVE: linalg.yield %[[RES]] -+ -+@@ -978,7 +978,7 @@ -+ // CHECK-PRIMITIVE-SAME: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>) -+ // CHECK-PRIMITIVE-SAME: outs(%[[DST]] : tensor<2x?xf32>) -+ // CHECK-PRIMITIVE-SAME: {someattr} -+-// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32) { -++// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[RESULT_OUT:.*]]: f32) { -+ // CHECK-PRIMITIVE: linalg.yield %[[LHS_]] -+ -+ // ----- -+@@ -1416,7 +1416,7 @@ -+ -+ // CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty -+ // CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>) -+-// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32) -++// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %[[RESULT_OUT:.*]]: f32) -+ // CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32 -+ // CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32 -+ // CHECK-PRIMITIVE: linalg.yield %[[MIN]] -+@@ -1478,7 +1478,7 @@ -+ // CHECK-PRIMITIVE-DAG: %[[SCALAR_LB:.*]] = tensor.extract %[[LB]] -+ // CHECK-PRIMITIVE-DAG: %[[SCALAR_UB:.*]] = tensor.extract %[[UB]] -+ // CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[X]] : tensor) outs(%[[INIT]] : tensor) -+-// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32) -++// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32, %[[RESULT_OUT:.*]]: f32) -+ // CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32 -+ // CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32 -+ // CHECK-PRIMITIVE: linalg.yield %[[MIN]] -+@@ -1554,7 +1554,7 @@ -+ // CHECK: linalg.yield %[[V_NOT]] : i32 -+ // CHECK-PRIMITIVE: %[[CST_N1:.+]] = arith.constant -1 : i32 -+ // CHECK-PRIMITIVE: linalg.map -+- // CHECK-PRIMITIVE: (%[[IN:.+]]: i32) -++ // CHECK-PRIMITIVE: (%[[IN:.+]]: i32, %[[RESULT_OUT:.+]]: i32) -+ // CHECK-PRIMITIVE: %[[V_NOT:.+]] = arith.xori %[[IN]], %[[CST_N1]] : i32 -+ // CHECK-PRIMITIVE: linalg.yield %[[V_NOT]] : i32 -+ %0 = "stablehlo.not"(%arg) : (tensor<2x2xi32>) -> tensor<2x2xi32> -+diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp -+--- stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp -++++ stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp -+@@ -1748,6 +1748,12 @@ -+ -+ rewriter.applySignatureConversion(®ion.front(), signatureConverter, -+ getTypeConverter()); -++ auto& blocks = linalgOp.getMapper().getBlocks(); -++ if (blocks.empty()) { -++ return rewriter.notifyMatchFailure(op, "expected at least one block"); -++ } -++ blocks.front().addArgument(resultType.getElementType(), loc); -++ -+ auto result = rewriter.createOrFold(loc, resultType, -+ linalgOp.getResults()); -+ rewriter.replaceOp(op, result); - diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp - --- stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp - +++ stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index ae1ccc95f0e..56ecbe39fb7 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "f1c951d74cf5c67b9e0f776e21fe2316d5c69f37" - SHARDY_SHA256 = "0b8b96710a2f2eec4581186e4e773aa4c4cfe6ae5e9681b7803e9b8336ead2f7" + SHARDY_COMMIT = "e269b4c1968c930518c42c02bfdcdf0d921793de" + SHARDY_SHA256 = "bdf22ae5d5a1ecacdca762da892e2291a7f82ddc42a23b1ca096dadb490d6068" tf_http_archive( name = "shardy",