Updates LLVM usage to match
[42a8ff877d47](https://github.com/llvm/llvm-project/commit/42a8ff877d47)

PiperOrigin-RevId: 826574010
This commit is contained in:
A. Unique TensorFlower 2025-10-31 11:53:16 -07:00 committed by TensorFlower Gardener
parent 6ff7f9c87f
commit e0f6a6c7f3
3 changed files with 9 additions and 713 deletions

View File

@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "22079e3f3698d5c367c7b67f63de8c838791ae76"
LLVM_SHA256 = "d5616e9c0f4b761f13da5535a0d9ec94acf4ae5226bbec3e47ac2929ea60cac2"
LLVM_COMMIT = "42a8ff877d47131ecb1280a1cc7e5e3c3bca6952"
LLVM_SHA256 = "f768c5c3b987f68318b8ab3dd4530e54988dfe7d6bfb9b7c9c96acf503367d50"
tf_http_archive(
name = name,

View File

@ -1,719 +1,15 @@
diff --git a/shardy/integrations/python/jax/mpmd/ops.py b/shardy/integrations/python/jax/mpmd/ops.py
index 60ae866..ea144d5 100644
--- a/shardy/integrations/python/jax/mpmd/ops.py
+++ b/shardy/integrations/python/jax/mpmd/ops.py
@@ -29,13 +29,11 @@ from jax._src import util
from jax._src.interpreters import ad as internal_ad
from jax._src.interpreters import batching as internal_batching
from jax._src.interpreters import partial_eval as internal_pe
+import jax._src.lib.mlir.dialects as jax_mlir_dialects
import jax.extend as jex
from jax.extend import linear_util as lu
from jax.extend import source_info_util as siu
from jax.extend.core import primitives
-from jax.extend.mlir import ir
-from jax.extend.mlir.dialects import func as func_dialect
-from jax.extend.mlir.dialects import mpmd
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir as jax_mlir
@@ -45,6 +43,10 @@ import jaxtyping
from shardy.integrations.python.jax.mpmd import utils
+ir = jax_mlir.ir
+mpmd = jax_mlir_dialects.mpmd
+func_dialect = jax_mlir_dialects.func
+
PyTree = jaxtyping.PyTree
X = TypeVar('X')
Y = TypeVar('Y')
diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch
index 4117b0a..509398d 100644
--- a/third_party/llvm/generated.patch
+++ b/third_party/llvm/generated.patch
@@ -1,576 +1 @@
Auto generated patch. Do not edit or delete it, even if empty.
-diff -ruN --strip-trailing-cr a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py
---- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py
-+++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py
-@@ -10,8 +10,8 @@
- import subprocess
- import signal
- import sys
-+import threading
- import warnings
--import selectors
- import time
- from typing import (
- Any,
-@@ -139,6 +139,35 @@
- outfile.write("\n")
-
-
-+def read_packet(
-+ f: IO[bytes], trace_file: Optional[IO[str]] = None
-+) -> Optional[ProtocolMessage]:
-+ """Decode a JSON packet that starts with the content length and is
-+ followed by the JSON bytes from a file 'f'. Returns None on EOF.
-+ """
-+ line = f.readline().decode("utf-8")
-+ if len(line) == 0:
-+ return None # EOF.
-+
-+ # Watch for line that starts with the prefix
-+ prefix = "Content-Length: "
-+ if line.startswith(prefix):
-+ # Decode length of JSON bytes
-+ length = int(line[len(prefix) :])
-+ # Skip empty line
-+ separator = f.readline().decode()
-+ if separator != "":
-+ Exception("malformed DAP content header, unexpected line: " + separator)
-+ # Read JSON bytes
-+ json_str = f.read(length).decode()
-+ if trace_file:
-+ trace_file.write("from adapter:\n%s\n" % (json_str))
-+ # Decode the JSON bytes into a python dictionary
-+ return json.loads(json_str)
-+
-+ raise Exception("unexpected malformed message from lldb-dap: " + line)
-+
-+
- def packet_type_is(packet, packet_type):
- return "type" in packet and packet["type"] == packet_type
-
-@@ -170,8 +199,16 @@
- self.log_file = log_file
- self.send = send
- self.recv = recv
-- self.selector = selectors.DefaultSelector()
-- self.selector.register(recv, selectors.EVENT_READ)
-+
-+ # Packets that have been received and processed but have not yet been
-+ # requested by a test case.
-+ self._pending_packets: List[Optional[ProtocolMessage]] = []
-+ # Received packets that have not yet been processed.
-+ self._recv_packets: List[Optional[ProtocolMessage]] = []
-+ # Used as a mutex for _recv_packets and for notify when _recv_packets
-+ # changes.
-+ self._recv_condition = threading.Condition()
-+ self._recv_thread = threading.Thread(target=self._read_packet_thread)
-
- # session state
- self.init_commands = init_commands
-@@ -197,6 +234,9 @@
- # keyed by breakpoint id
- self.resolved_breakpoints: dict[str, Breakpoint] = {}
-
-+ # trigger enqueue thread
-+ self._recv_thread.start()
-+
- @classmethod
- def encode_content(cls, s: str) -> bytes:
- return ("Content-Length: %u\r\n\r\n%s" % (len(s), s)).encode("utf-8")
-@@ -212,46 +252,17 @@
- f"seq mismatch in response {command['seq']} != {response['request_seq']}"
- )
-
-- def _read_packet(
-- self,
-- timeout: float = DEFAULT_TIMEOUT,
-- ) -> Optional[ProtocolMessage]:
-- """Decode a JSON packet that starts with the content length and is
-- followed by the JSON bytes from self.recv. Returns None on EOF.
-- """
--
-- ready = self.selector.select(timeout)
-- if not ready:
-- warnings.warn(
-- "timeout occurred waiting for a packet, check if the test has a"
-- " negative assertion and see if it can be inverted.",
-- stacklevel=4,
-- )
-- return None # timeout
--
-- line = self.recv.readline().decode("utf-8")
-- if len(line) == 0:
-- return None # EOF.
--
-- # Watch for line that starts with the prefix
-- prefix = "Content-Length: "
-- if line.startswith(prefix):
-- # Decode length of JSON bytes
-- length = int(line[len(prefix) :])
-- # Skip empty line
-- separator = self.recv.readline().decode()
-- if separator != "":
-- Exception("malformed DAP content header, unexpected line: " + separator)
-- # Read JSON bytes
-- json_str = self.recv.read(length).decode()
-- if self.trace_file:
-- self.trace_file.write(
-- "%s from adapter:\n%s\n" % (time.time(), json_str)
-- )
-- # Decode the JSON bytes into a python dictionary
-- return json.loads(json_str)
--
-- raise Exception("unexpected malformed message from lldb-dap: " + line)
-+ def _read_packet_thread(self):
-+ try:
-+ while True:
-+ packet = read_packet(self.recv, trace_file=self.trace_file)
-+ # `packet` will be `None` on EOF. We want to pass it down to
-+ # handle_recv_packet anyway so the main thread can handle unexpected
-+ # termination of lldb-dap and stop waiting for new packets.
-+ if not self._handle_recv_packet(packet):
-+ break
-+ finally:
-+ dump_dap_log(self.log_file)
-
- def get_modules(
- self, start_module: Optional[int] = None, module_count: Optional[int] = None
-@@ -299,6 +310,34 @@
- output += self.get_output(category, clear=clear)
- return output
-
-+ def _enqueue_recv_packet(self, packet: Optional[ProtocolMessage]):
-+ with self.recv_condition:
-+ self.recv_packets.append(packet)
-+ self.recv_condition.notify()
-+
-+ def _handle_recv_packet(self, packet: Optional[ProtocolMessage]) -> bool:
-+ """Handles an incoming packet.
-+
-+ Called by the read thread that is waiting for all incoming packets
-+ to store the incoming packet in "self._recv_packets" in a thread safe
-+ way. This function will then signal the "self._recv_condition" to
-+ indicate a new packet is available.
-+
-+ Args:
-+ packet: A new packet to store.
-+
-+ Returns:
-+ True if the caller should keep calling this function for more
-+ packets.
-+ """
-+ with self._recv_condition:
-+ self._recv_packets.append(packet)
-+ self._recv_condition.notify()
-+ # packet is None on EOF
-+ return packet is not None and not (
-+ packet["type"] == "response" and packet["command"] == "disconnect"
-+ )
-+
- def _recv_packet(
- self,
- *,
-@@ -322,34 +361,46 @@
- The first matching packet for the given predicate, if specified,
- otherwise None.
- """
-- deadline = time.time() + timeout
--
-- while time.time() < deadline:
-- packet = self._read_packet(timeout=deadline - time.time())
-- if packet is None:
-- return None
-- self._process_recv_packet(packet)
-- if not predicate or predicate(packet):
-- return packet
-+ assert (
-+ threading.current_thread != self._recv_thread
-+ ), "Must not be called from the _recv_thread"
-+
-+ def process_until_match():
-+ self._process_recv_packets()
-+ for i, packet in enumerate(self._pending_packets):
-+ if packet is None:
-+ # We need to return a truthy value to break out of the
-+ # wait_for, use `EOFError` as an indicator of EOF.
-+ return EOFError()
-+ if predicate and predicate(packet):
-+ self._pending_packets.pop(i)
-+ return packet
-+
-+ with self._recv_condition:
-+ packet = self._recv_condition.wait_for(process_until_match, timeout)
-+ return None if isinstance(packet, EOFError) else packet
-
-- def _process_recv_packet(self, packet) -> None:
-+ def _process_recv_packets(self) -> None:
- """Process received packets, updating the session state."""
-- if packet and ("seq" not in packet or packet["seq"] == 0):
-- warnings.warn(
-- f"received a malformed packet, expected 'seq != 0' for {packet!r}"
-- )
-- # Handle events that may modify any stateful properties of
-- # the DAP session.
-- if packet and packet["type"] == "event":
-- self._handle_event(packet)
-- elif packet and packet["type"] == "request":
-- # Handle reverse requests and keep processing.
-- self._handle_reverse_request(packet)
-+ with self._recv_condition:
-+ for packet in self._recv_packets:
-+ if packet and ("seq" not in packet or packet["seq"] == 0):
-+ warnings.warn(
-+ f"received a malformed packet, expected 'seq != 0' for {packet!r}"
-+ )
-+ # Handle events that may modify any stateful properties of
-+ # the DAP session.
-+ if packet and packet["type"] == "event":
-+ self._handle_event(packet)
-+ elif packet and packet["type"] == "request":
-+ # Handle reverse requests and keep processing.
-+ self._handle_reverse_request(packet)
-+ # Move the packet to the pending queue.
-+ self._pending_packets.append(packet)
-+ self._recv_packets.clear()
-
- def _handle_event(self, packet: Event) -> None:
- """Handle any events that modify debug session state we track."""
-- self.events.append(packet)
--
- event = packet["event"]
- body: Optional[Dict] = packet.get("body", None)
-
-@@ -402,8 +453,6 @@
- self.invalidated_event = packet
- elif event == "memory":
- self.memory_event = packet
-- elif event == "module":
-- self.module_events.append(packet)
-
- def _handle_reverse_request(self, request: Request) -> None:
- if request in self.reverse_requests:
-@@ -472,14 +521,18 @@
-
- Returns the seq number of the request.
- """
-- packet["seq"] = self.sequence
-- self.sequence += 1
-+ # Set the seq for requests.
-+ if packet["type"] == "request":
-+ packet["seq"] = self.sequence
-+ self.sequence += 1
-+ else:
-+ packet["seq"] = 0
-
- # Encode our command dictionary as a JSON string
- json_str = json.dumps(packet, separators=(",", ":"))
-
- if self.trace_file:
-- self.trace_file.write("%s to adapter:\n%s\n" % (time.time(), json_str))
-+ self.trace_file.write("to adapter:\n%s\n" % (json_str))
-
- length = len(json_str)
- if length > 0:
-@@ -860,8 +913,6 @@
- if restartArguments:
- command_dict["arguments"] = restartArguments
-
-- # Clear state, the process is about to restart...
-- self._process_continued(True)
- response = self._send_recv(command_dict)
- # Caller must still call wait_for_stopped.
- return response
-@@ -1428,10 +1479,8 @@
-
- def terminate(self):
- self.send.close()
-- self.recv.close()
-- self.selector.close()
-- if self.log_file:
-- dump_dap_log(self.log_file)
-+ if self._recv_thread.is_alive():
-+ self._recv_thread.join()
-
- def request_setInstructionBreakpoints(self, memory_reference=[]):
- breakpoints = []
-@@ -1528,7 +1577,6 @@
- stdout=subprocess.PIPE,
- stderr=sys.stderr,
- env=adapter_env,
-- bufsize=0,
- )
-
- if connection is None:
-diff -ruN --strip-trailing-cr a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py
---- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py
-+++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py
-@@ -416,7 +416,7 @@
- return self.dap_server.wait_for_stopped()
-
- def continue_to_breakpoint(self, breakpoint_id: str):
-- self.continue_to_breakpoints([breakpoint_id])
-+ self.continue_to_breakpoints((breakpoint_id))
-
- def continue_to_breakpoints(self, breakpoint_ids):
- self.do_continue()
-diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py b/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py
---- a/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py
-+++ b/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py
-@@ -81,20 +81,24 @@
- breakpoint["verified"], "expect foo breakpoint to not be verified"
- )
-
-+ # Flush the breakpoint events.
-+ self.dap_server.wait_for_breakpoint_events()
-+
- # Continue to the breakpoint
-- self.continue_to_breakpoint(foo_bp_id)
-- self.continue_to_next_stop() # foo_bp2
-- self.continue_to_breakpoint(main_bp_id)
-- self.continue_to_exit()
-+ self.continue_to_breakpoints(dap_breakpoint_ids)
-
-- bp_events = [e for e in self.dap_server.events if e["event"] == "breakpoint"]
-+ verified_breakpoint_ids = []
-+ unverified_breakpoint_ids = []
-+ for breakpoint_event in self.dap_server.wait_for_breakpoint_events():
-+ breakpoint = breakpoint_event["body"]["breakpoint"]
-+ id = breakpoint["id"]
-+ if breakpoint["verified"]:
-+ verified_breakpoint_ids.append(id)
-+ else:
-+ unverified_breakpoint_ids.append(id)
-
-- main_bp_events = [
-- e for e in bp_events if e["body"]["breakpoint"]["id"] == main_bp_id
-- ]
-- foo_bp_events = [
-- e for e in bp_events if e["body"]["breakpoint"]["id"] == foo_bp_id
-- ]
-+ self.assertIn(main_bp_id, unverified_breakpoint_ids)
-+ self.assertIn(foo_bp_id, unverified_breakpoint_ids)
-
-- self.assertTrue(main_bp_events)
-- self.assertTrue(foo_bp_events)
-+ self.assertIn(main_bp_id, verified_breakpoint_ids)
-+ self.assertIn(foo_bp_id, verified_breakpoint_ids)
-diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py b/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py
---- a/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py
-+++ b/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py
-@@ -156,7 +156,6 @@
- self.build_and_launch(
- program, debuggerRoot=program_parent_dir, initCommands=commands
- )
-- self.continue_to_exit()
- output = self.get_console()
- self.assertTrue(output and len(output) > 0, "expect console output")
- lines = output.splitlines()
-@@ -172,6 +171,7 @@
- % (program_parent_dir, line[len(prefix) :]),
- )
- self.assertTrue(found, "verified lldb-dap working directory")
-+ self.continue_to_exit()
-
- def test_sourcePath(self):
- """
-diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py b/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py
---- a/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py
-+++ b/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py
-@@ -64,18 +64,19 @@
- self.assertEqual(program, program_module["path"])
- self.assertIn("addressRange", program_module)
-
-- self.continue_to_exit()
--
- # Collect all the module names we saw as events.
- module_new_names = []
- module_changed_names = []
-- for module_event in self.dap_server.module_events:
-+ module_event = self.dap_server.wait_for_event(["module"])
-+ while module_event is not None:
- reason = module_event["body"]["reason"]
- if reason == "new":
- module_new_names.append(module_event["body"]["module"]["name"])
- elif reason == "changed":
- module_changed_names.append(module_event["body"]["module"]["name"])
-
-+ module_event = self.dap_server.wait_for_event(["module"])
-+
- # Make sure we got an event for every active module.
- self.assertNotEqual(len(module_new_names), 0)
- for module in active_modules:
-@@ -85,6 +86,7 @@
- # symbols got added.
- self.assertNotEqual(len(module_changed_names), 0)
- self.assertIn(program_module["name"], module_changed_names)
-+ self.continue_to_exit()
-
- @skipIfWindows
- def test_modules(self):
-diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py b/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py
---- a/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py
-+++ b/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py
-@@ -1,58 +1,58 @@
--"""
--Test 'module' events for dynamically loaded libraries.
--"""
--
-+import dap_server
- from lldbsuite.test.decorators import *
- from lldbsuite.test.lldbtest import *
-+from lldbsuite.test import lldbutil
- import lldbdap_testcase
-+import re
-
-
- class TestDAP_module_event(lldbdap_testcase.DAPTestCaseBase):
-- def lookup_module_id(self, name):
-- """Returns the identifier for the first module event starting with the given name."""
-- for event in self.dap_server.module_events:
-- if self.get_dict_value(event, ["body", "module", "name"]).startswith(name):
-- return self.get_dict_value(event, ["body", "module", "id"])
-- self.fail(f"No module events matching name={name}")
--
-- def module_events(self, id):
-- """Finds all module events by identifier."""
-- return [
-- event
-- for event in self.dap_server.module_events
-- if self.get_dict_value(event, ["body", "module", "id"]) == id
-- ]
--
-- def module_reasons(self, events):
-- """Returns the list of 'reason' values from the given events."""
-- return [event["body"]["reason"] for event in events]
--
- @skipIfWindows
- def test_module_event(self):
-- """
-- Test that module events are fired on target load and when the list of
-- dynamic libraries updates while running.
-- """
- program = self.getBuildArtifact("a.out")
- self.build_and_launch(program)
-- # We can analyze the order of events after the process exits.
-- self.continue_to_exit()
-
-- a_out_id = self.lookup_module_id("a.out")
-- a_out_events = self.module_events(id=a_out_id)
-+ source = "main.cpp"
-+ breakpoint1_line = line_number(source, "// breakpoint 1")
-+ breakpoint2_line = line_number(source, "// breakpoint 2")
-+ breakpoint3_line = line_number(source, "// breakpoint 3")
-
-- self.assertIn(
-- "new",
-- self.module_reasons(a_out_events),
-- "Expected a.out to load during the debug session.",
-+ breakpoint_ids = self.set_source_breakpoints(
-+ source, [breakpoint1_line, breakpoint2_line, breakpoint3_line]
- )
-+ self.continue_to_breakpoints(breakpoint_ids)
-
-- libother_id = self.lookup_module_id(
-- "libother." # libother.so or libother.dylib based on OS.
-- )
-- libother_events = self.module_events(id=libother_id)
-- self.assertEqual(
-- self.module_reasons(libother_events),
-- ["new", "removed"],
-- "Expected libother to be loaded then unloaded during the debug session.",
-- )
-+ # We're now stopped at breakpoint 1 before the dlopen. Flush all the module events.
-+ event = self.dap_server.wait_for_event(["module"])
-+ while event is not None:
-+ event = self.dap_server.wait_for_event(["module"])
-+
-+ # Continue to the second breakpoint, before the dlclose.
-+ self.continue_to_breakpoints(breakpoint_ids)
-+
-+ # Make sure we got a module event for libother.
-+ event = self.dap_server.wait_for_event(["module"])
-+ self.assertIsNotNone(event, "didn't get a module event")
-+ module_name = event["body"]["module"]["name"]
-+ module_id = event["body"]["module"]["id"]
-+ self.assertEqual(event["body"]["reason"], "new")
-+ self.assertIn("libother", module_name)
-+
-+ # Continue to the third breakpoint, after the dlclose.
-+ self.continue_to_breakpoints(breakpoint_ids)
-+
-+ # Make sure we got a module event for libother.
-+ event = self.dap_server.wait_for_event(["module"])
-+ self.assertIsNotNone(event, "didn't get a module event")
-+ reason = event["body"]["reason"]
-+ self.assertEqual(reason, "removed")
-+ self.assertEqual(event["body"]["module"]["id"], module_id)
-+
-+ # The removed module event should omit everything but the module id and name
-+ # as they are required fields.
-+ module_data = event["body"]["module"]
-+ required_keys = ["id", "name"]
-+ self.assertListEqual(list(module_data.keys()), required_keys)
-+ self.assertEqual(module_data["name"], "", "expects empty name.")
-+
-+ self.continue_to_exit()
-diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py b/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py
---- a/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py
-+++ b/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py
-@@ -30,11 +30,7 @@
- if reason == "entry":
- seen_stopped_event += 1
-
-- self.assertEqual(
-- seen_stopped_event,
-- 1,
-- f"expect only one stopped entry event in {stopped_events}",
-- )
-+ self.assertEqual(seen_stopped_event, 1, "expect only one stopped entry event.")
-
- @skipIfAsan
- @skipIfWindows
-@@ -96,13 +92,11 @@
- self.build_and_launch(program, console="integratedTerminal", stopOnEntry=True)
- [bp_main] = self.set_function_breakpoints(["main"])
-
-- self.dap_server.request_configurationDone()
-- stopped_threads = list(self.dap_server.thread_stop_reasons.values())
-+ self.dap_server.request_continue() # sends configuration done
-+ stopped_events = self.dap_server.wait_for_stopped()
- # We should be stopped at the entry point.
-- self.assertEqual(
-- len(stopped_threads), 1, "Expected the main thread to be stopped on entry."
-- )
-- self.assertEqual(stopped_threads[0]["reason"], "entry")
-+ self.assertGreaterEqual(len(stopped_events), 0, "expect stopped events")
-+ self.verify_stopped_on_entry(stopped_events)
-
- # Then, if we continue, we should hit the breakpoint at main.
- self.dap_server.request_continue()
-@@ -111,12 +105,8 @@
- # Restart and check that we still get a stopped event before reaching
- # main.
- self.dap_server.request_restart()
-- stopped_threads = list(self.dap_server.thread_stop_reasons.values())
-- # We should be stopped at the entry point.
-- self.assertEqual(
-- len(stopped_threads), 1, "Expected the main thread to be stopped on entry."
-- )
-- self.assertEqual(stopped_threads[0]["reason"], "entry")
-+ stopped_events = self.dap_server.wait_for_stopped()
-+ self.verify_stopped_on_entry(stopped_events)
-
- # continue to main
- self.dap_server.request_continue()
-diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py b/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py
---- a/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py
-+++ b/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py
-@@ -32,7 +32,7 @@
- ],
- )
- self.set_source_breakpoints(source, [breakpoint_line])
-- self.do_continue()
-+ self.continue_to_next_stop()
-
- custom_event = self.dap_server.wait_for_event(
- filter=["my-custom-event-no-body"]
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index d012c7c..0859165 100644
index 0859165..32e6a7a 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
def repo(name):
"""Imports LLVM."""
- LLVM_COMMIT = "4c46ae394841521914e0e8575e7619a1c0d1149d"
- LLVM_SHA256 = "55c824ce2e1a3afafa4e108532f4eff9f194d20d44d1c5ddc6107bb23d7c6c2a"
+ LLVM_COMMIT = "22079e3f3698d5c367c7b67f63de8c838791ae76"
+ LLVM_SHA256 = "d5616e9c0f4b761f13da5535a0d9ec94acf4ae5226bbec3e47ac2929ea60cac2"
- LLVM_COMMIT = "22079e3f3698d5c367c7b67f63de8c838791ae76"
- LLVM_SHA256 = "d5616e9c0f4b761f13da5535a0d9ec94acf4ae5226bbec3e47ac2929ea60cac2"
+ LLVM_COMMIT = "42a8ff877d47131ecb1280a1cc7e5e3c3bca6952"
+ LLVM_SHA256 = "f768c5c3b987f68318b8ab3dd4530e54988dfe7d6bfb9b7c9c96acf503367d50"
tf_http_archive(
name = name,
diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index c445b82..a42cbd7 100755
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -1,3 +1,88 @@
+diff --ruN a/stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir
+--- stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir
++++ stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir
+@@ -768,7 +768,7 @@
+ // CHECK-PRIMITIVE: %[[MAP:.+]] = linalg.map
+ // CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]]
+ // CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor<?xi1>)
+-// CHECK-PRIMITIVE-NEXT: (%[[A:.+]]: complex<f32>, %[[B:.+]]: complex<f32>) {
++// CHECK-PRIMITIVE-NEXT: (%[[A:.+]]: complex<f32>, %[[B:.+]]: complex<f32>, %{{.+}}: i1) {
+ // CHECK-PRIMITIVE: %[[RE1:.+]] = complex.re %[[A]] : complex<f32>
+ // CHECK-PRIMITIVE: %[[RE2:.+]] = complex.re %[[B]] : complex<f32>
+ // CHECK-PRIMITIVE: %[[CMP:.+]] = arith.cmpf oeq, %[[RE1]], %[[RE2]] : f32
+diff --ruN a/stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir b/stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir
+--- stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir
++++ stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir
+@@ -714,7 +714,7 @@
+ // CHECK-PRIMITIVE: linalg.map
+ // CHECK-PRIMITIVE-SAME: ins(
+ // CHECK-PRIMITIVE-SAME: outs(
+-// CHECK-PRIMITIVE-NEXT: (%[[LHS_IN:[a-zA-Z0-9]*]]: bf16, %[[RHS_IN:.*]]: bf16) {
++// CHECK-PRIMITIVE-NEXT: (%[[LHS_IN:[a-zA-Z0-9]*]]: bf16, %[[RHS_IN:.*]]: bf16, %[[RESULT_OUT:.*]]: i1) {
+ // CHECK-PRIMITIVE-NEXT: %[[LHS_INT:.*]] = arith.bitcast %[[LHS_IN]] : bf16 to i16
+ // CHECK-PRIMITIVE-NEXT: %[[LHS_CMP:.*]] = arith.cmpi slt, %[[LHS_INT]], %[[C0]] : i16
+ // CHECK-PRIMITIVE-NEXT: %[[LHS_SUB:.*]] = arith.subi %[[C32767]], %[[LHS_INT]] : i16
+@@ -937,7 +937,7 @@
+ // CHECK-PRIMITIVE-SAME: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>)
+ // CHECK-PRIMITIVE-SAME: outs(%[[DST]] : tensor<2x?xf32>)
+ // CHECK-PRIMITIVE-SAME: {someattr}
+-// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32) {
++// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[RESULT_OUT:.*]]: f32) {
+ // CHECK-PRIMITIVE: %[[RES:.*]] = arith.select %[[PRED_ELEM]], %[[LHS_]], %[[RHS_]] : f32
+ // CHECK-PRIMITIVE: linalg.yield %[[RES]]
+
+@@ -978,7 +978,7 @@
+ // CHECK-PRIMITIVE-SAME: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>)
+ // CHECK-PRIMITIVE-SAME: outs(%[[DST]] : tensor<2x?xf32>)
+ // CHECK-PRIMITIVE-SAME: {someattr}
+-// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32) {
++// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[RESULT_OUT:.*]]: f32) {
+ // CHECK-PRIMITIVE: linalg.yield %[[LHS_]]
+
+ // -----
+@@ -1416,7 +1416,7 @@
+
+ // CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty
+ // CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
+-// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32)
++// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %[[RESULT_OUT:.*]]: f32)
+ // CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
+ // CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32
+ // CHECK-PRIMITIVE: linalg.yield %[[MIN]]
+@@ -1478,7 +1478,7 @@
+ // CHECK-PRIMITIVE-DAG: %[[SCALAR_LB:.*]] = tensor.extract %[[LB]]
+ // CHECK-PRIMITIVE-DAG: %[[SCALAR_UB:.*]] = tensor.extract %[[UB]]
+ // CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[X]] : tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>)
+-// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32)
++// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32, %[[RESULT_OUT:.*]]: f32)
+ // CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
+ // CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32
+ // CHECK-PRIMITIVE: linalg.yield %[[MIN]]
+@@ -1554,7 +1554,7 @@
+ // CHECK: linalg.yield %[[V_NOT]] : i32
+ // CHECK-PRIMITIVE: %[[CST_N1:.+]] = arith.constant -1 : i32
+ // CHECK-PRIMITIVE: linalg.map
+- // CHECK-PRIMITIVE: (%[[IN:.+]]: i32)
++ // CHECK-PRIMITIVE: (%[[IN:.+]]: i32, %[[RESULT_OUT:.+]]: i32)
+ // CHECK-PRIMITIVE: %[[V_NOT:.+]] = arith.xori %[[IN]], %[[CST_N1]] : i32
+ // CHECK-PRIMITIVE: linalg.yield %[[V_NOT]] : i32
+ %0 = "stablehlo.not"(%arg) : (tensor<2x2xi32>) -> tensor<2x2xi32>
+diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+--- stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
++++ stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+@@ -1748,6 +1748,12 @@
+
+ rewriter.applySignatureConversion(&region.front(), signatureConverter,
+ getTypeConverter());
++ auto& blocks = linalgOp.getMapper().getBlocks();
++ if (blocks.empty()) {
++ return rewriter.notifyMatchFailure(op, "expected at least one block");
++ }
++ blocks.front().addArgument(resultType.getElementType(), loc);
++
+ auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType,
+ linalgOp.getResults());
+ rewriter.replaceOp(op, result);
diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp
--- stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp
+++ stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp

View File

@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
def repo():
SHARDY_COMMIT = "f1c951d74cf5c67b9e0f776e21fe2316d5c69f37"
SHARDY_SHA256 = "0b8b96710a2f2eec4581186e4e773aa4c4cfe6ae5e9681b7803e9b8336ead2f7"
SHARDY_COMMIT = "e269b4c1968c930518c42c02bfdcdf0d921793de"
SHARDY_SHA256 = "bdf22ae5d5a1ecacdca762da892e2291a7f82ddc42a23b1ca096dadb490d6068"
tf_http_archive(
name = "shardy",