Updates LLVM usage to match
[22079e3f3698](https://github.com/llvm/llvm-project/commit/22079e3f3698)

PiperOrigin-RevId: 826294004
This commit is contained in:
A. Unique TensorFlower 2025-10-30 20:26:34 -07:00 committed by TensorFlower Gardener
parent 6d86cff5f3
commit b2334ac330
6 changed files with 799 additions and 1176 deletions

View File

@ -1,576 +1 @@
Auto generated patch. Do not edit or delete it, even if empty.
diff -ruN --strip-trailing-cr a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py
--- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py
+++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py
@@ -10,8 +10,8 @@
import subprocess
import signal
import sys
+import threading
import warnings
-import selectors
import time
from typing import (
Any,
@@ -139,6 +139,35 @@
outfile.write("\n")
+def read_packet(
+ f: IO[bytes], trace_file: Optional[IO[str]] = None
+) -> Optional[ProtocolMessage]:
+ """Decode a JSON packet that starts with the content length and is
+ followed by the JSON bytes from a file 'f'. Returns None on EOF.
+ """
+ line = f.readline().decode("utf-8")
+ if len(line) == 0:
+ return None # EOF.
+
+ # Watch for line that starts with the prefix
+ prefix = "Content-Length: "
+ if line.startswith(prefix):
+ # Decode length of JSON bytes
+ length = int(line[len(prefix) :])
+ # Skip empty line
+ separator = f.readline().decode()
+ if separator != "":
+ Exception("malformed DAP content header, unexpected line: " + separator)
+ # Read JSON bytes
+ json_str = f.read(length).decode()
+ if trace_file:
+ trace_file.write("from adapter:\n%s\n" % (json_str))
+ # Decode the JSON bytes into a python dictionary
+ return json.loads(json_str)
+
+ raise Exception("unexpected malformed message from lldb-dap: " + line)
+
+
def packet_type_is(packet, packet_type):
return "type" in packet and packet["type"] == packet_type
@@ -170,8 +199,16 @@
self.log_file = log_file
self.send = send
self.recv = recv
- self.selector = selectors.DefaultSelector()
- self.selector.register(recv, selectors.EVENT_READ)
+
+ # Packets that have been received and processed but have not yet been
+ # requested by a test case.
+ self._pending_packets: List[Optional[ProtocolMessage]] = []
+ # Received packets that have not yet been processed.
+ self._recv_packets: List[Optional[ProtocolMessage]] = []
+ # Used as a mutex for _recv_packets and for notify when _recv_packets
+ # changes.
+ self._recv_condition = threading.Condition()
+ self._recv_thread = threading.Thread(target=self._read_packet_thread)
# session state
self.init_commands = init_commands
@@ -197,6 +234,9 @@
# keyed by breakpoint id
self.resolved_breakpoints: dict[str, Breakpoint] = {}
+ # trigger enqueue thread
+ self._recv_thread.start()
+
@classmethod
def encode_content(cls, s: str) -> bytes:
return ("Content-Length: %u\r\n\r\n%s" % (len(s), s)).encode("utf-8")
@@ -212,46 +252,17 @@
f"seq mismatch in response {command['seq']} != {response['request_seq']}"
)
- def _read_packet(
- self,
- timeout: float = DEFAULT_TIMEOUT,
- ) -> Optional[ProtocolMessage]:
- """Decode a JSON packet that starts with the content length and is
- followed by the JSON bytes from self.recv. Returns None on EOF.
- """
-
- ready = self.selector.select(timeout)
- if not ready:
- warnings.warn(
- "timeout occurred waiting for a packet, check if the test has a"
- " negative assertion and see if it can be inverted.",
- stacklevel=4,
- )
- return None # timeout
-
- line = self.recv.readline().decode("utf-8")
- if len(line) == 0:
- return None # EOF.
-
- # Watch for line that starts with the prefix
- prefix = "Content-Length: "
- if line.startswith(prefix):
- # Decode length of JSON bytes
- length = int(line[len(prefix) :])
- # Skip empty line
- separator = self.recv.readline().decode()
- if separator != "":
- Exception("malformed DAP content header, unexpected line: " + separator)
- # Read JSON bytes
- json_str = self.recv.read(length).decode()
- if self.trace_file:
- self.trace_file.write(
- "%s from adapter:\n%s\n" % (time.time(), json_str)
- )
- # Decode the JSON bytes into a python dictionary
- return json.loads(json_str)
-
- raise Exception("unexpected malformed message from lldb-dap: " + line)
+ def _read_packet_thread(self):
+ try:
+ while True:
+ packet = read_packet(self.recv, trace_file=self.trace_file)
+ # `packet` will be `None` on EOF. We want to pass it down to
+ # handle_recv_packet anyway so the main thread can handle unexpected
+ # termination of lldb-dap and stop waiting for new packets.
+ if not self._handle_recv_packet(packet):
+ break
+ finally:
+ dump_dap_log(self.log_file)
def get_modules(
self, start_module: Optional[int] = None, module_count: Optional[int] = None
@@ -299,6 +310,34 @@
output += self.get_output(category, clear=clear)
return output
+ def _enqueue_recv_packet(self, packet: Optional[ProtocolMessage]):
+ with self.recv_condition:
+ self.recv_packets.append(packet)
+ self.recv_condition.notify()
+
+ def _handle_recv_packet(self, packet: Optional[ProtocolMessage]) -> bool:
+ """Handles an incoming packet.
+
+ Called by the read thread that is waiting for all incoming packets
+ to store the incoming packet in "self._recv_packets" in a thread safe
+ way. This function will then signal the "self._recv_condition" to
+ indicate a new packet is available.
+
+ Args:
+ packet: A new packet to store.
+
+ Returns:
+ True if the caller should keep calling this function for more
+ packets.
+ """
+ with self._recv_condition:
+ self._recv_packets.append(packet)
+ self._recv_condition.notify()
+ # packet is None on EOF
+ return packet is not None and not (
+ packet["type"] == "response" and packet["command"] == "disconnect"
+ )
+
def _recv_packet(
self,
*,
@@ -322,34 +361,46 @@
The first matching packet for the given predicate, if specified,
otherwise None.
"""
- deadline = time.time() + timeout
-
- while time.time() < deadline:
- packet = self._read_packet(timeout=deadline - time.time())
- if packet is None:
- return None
- self._process_recv_packet(packet)
- if not predicate or predicate(packet):
- return packet
+ assert (
+ threading.current_thread != self._recv_thread
+ ), "Must not be called from the _recv_thread"
+
+ def process_until_match():
+ self._process_recv_packets()
+ for i, packet in enumerate(self._pending_packets):
+ if packet is None:
+ # We need to return a truthy value to break out of the
+ # wait_for, use `EOFError` as an indicator of EOF.
+ return EOFError()
+ if predicate and predicate(packet):
+ self._pending_packets.pop(i)
+ return packet
+
+ with self._recv_condition:
+ packet = self._recv_condition.wait_for(process_until_match, timeout)
+ return None if isinstance(packet, EOFError) else packet
- def _process_recv_packet(self, packet) -> None:
+ def _process_recv_packets(self) -> None:
"""Process received packets, updating the session state."""
- if packet and ("seq" not in packet or packet["seq"] == 0):
- warnings.warn(
- f"received a malformed packet, expected 'seq != 0' for {packet!r}"
- )
- # Handle events that may modify any stateful properties of
- # the DAP session.
- if packet and packet["type"] == "event":
- self._handle_event(packet)
- elif packet and packet["type"] == "request":
- # Handle reverse requests and keep processing.
- self._handle_reverse_request(packet)
+ with self._recv_condition:
+ for packet in self._recv_packets:
+ if packet and ("seq" not in packet or packet["seq"] == 0):
+ warnings.warn(
+ f"received a malformed packet, expected 'seq != 0' for {packet!r}"
+ )
+ # Handle events that may modify any stateful properties of
+ # the DAP session.
+ if packet and packet["type"] == "event":
+ self._handle_event(packet)
+ elif packet and packet["type"] == "request":
+ # Handle reverse requests and keep processing.
+ self._handle_reverse_request(packet)
+ # Move the packet to the pending queue.
+ self._pending_packets.append(packet)
+ self._recv_packets.clear()
def _handle_event(self, packet: Event) -> None:
"""Handle any events that modify debug session state we track."""
- self.events.append(packet)
-
event = packet["event"]
body: Optional[Dict] = packet.get("body", None)
@@ -402,8 +453,6 @@
self.invalidated_event = packet
elif event == "memory":
self.memory_event = packet
- elif event == "module":
- self.module_events.append(packet)
def _handle_reverse_request(self, request: Request) -> None:
if request in self.reverse_requests:
@@ -472,14 +521,18 @@
Returns the seq number of the request.
"""
- packet["seq"] = self.sequence
- self.sequence += 1
+ # Set the seq for requests.
+ if packet["type"] == "request":
+ packet["seq"] = self.sequence
+ self.sequence += 1
+ else:
+ packet["seq"] = 0
# Encode our command dictionary as a JSON string
json_str = json.dumps(packet, separators=(",", ":"))
if self.trace_file:
- self.trace_file.write("%s to adapter:\n%s\n" % (time.time(), json_str))
+ self.trace_file.write("to adapter:\n%s\n" % (json_str))
length = len(json_str)
if length > 0:
@@ -860,8 +913,6 @@
if restartArguments:
command_dict["arguments"] = restartArguments
- # Clear state, the process is about to restart...
- self._process_continued(True)
response = self._send_recv(command_dict)
# Caller must still call wait_for_stopped.
return response
@@ -1428,10 +1479,8 @@
def terminate(self):
self.send.close()
- self.recv.close()
- self.selector.close()
- if self.log_file:
- dump_dap_log(self.log_file)
+ if self._recv_thread.is_alive():
+ self._recv_thread.join()
def request_setInstructionBreakpoints(self, memory_reference=[]):
breakpoints = []
@@ -1528,7 +1577,6 @@
stdout=subprocess.PIPE,
stderr=sys.stderr,
env=adapter_env,
- bufsize=0,
)
if connection is None:
diff -ruN --strip-trailing-cr a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py
--- a/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py
+++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py
@@ -416,7 +416,7 @@
return self.dap_server.wait_for_stopped()
def continue_to_breakpoint(self, breakpoint_id: str):
- self.continue_to_breakpoints([breakpoint_id])
+ self.continue_to_breakpoints((breakpoint_id))
def continue_to_breakpoints(self, breakpoint_ids):
self.do_continue()
diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py b/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py
--- a/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py
+++ b/lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py
@@ -81,20 +81,24 @@
breakpoint["verified"], "expect foo breakpoint to not be verified"
)
+ # Flush the breakpoint events.
+ self.dap_server.wait_for_breakpoint_events()
+
# Continue to the breakpoint
- self.continue_to_breakpoint(foo_bp_id)
- self.continue_to_next_stop() # foo_bp2
- self.continue_to_breakpoint(main_bp_id)
- self.continue_to_exit()
+ self.continue_to_breakpoints(dap_breakpoint_ids)
- bp_events = [e for e in self.dap_server.events if e["event"] == "breakpoint"]
+ verified_breakpoint_ids = []
+ unverified_breakpoint_ids = []
+ for breakpoint_event in self.dap_server.wait_for_breakpoint_events():
+ breakpoint = breakpoint_event["body"]["breakpoint"]
+ id = breakpoint["id"]
+ if breakpoint["verified"]:
+ verified_breakpoint_ids.append(id)
+ else:
+ unverified_breakpoint_ids.append(id)
- main_bp_events = [
- e for e in bp_events if e["body"]["breakpoint"]["id"] == main_bp_id
- ]
- foo_bp_events = [
- e for e in bp_events if e["body"]["breakpoint"]["id"] == foo_bp_id
- ]
+ self.assertIn(main_bp_id, unverified_breakpoint_ids)
+ self.assertIn(foo_bp_id, unverified_breakpoint_ids)
- self.assertTrue(main_bp_events)
- self.assertTrue(foo_bp_events)
+ self.assertIn(main_bp_id, verified_breakpoint_ids)
+ self.assertIn(foo_bp_id, verified_breakpoint_ids)
diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py b/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py
--- a/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py
+++ b/lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py
@@ -156,7 +156,6 @@
self.build_and_launch(
program, debuggerRoot=program_parent_dir, initCommands=commands
)
- self.continue_to_exit()
output = self.get_console()
self.assertTrue(output and len(output) > 0, "expect console output")
lines = output.splitlines()
@@ -172,6 +171,7 @@
% (program_parent_dir, line[len(prefix) :]),
)
self.assertTrue(found, "verified lldb-dap working directory")
+ self.continue_to_exit()
def test_sourcePath(self):
"""
diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py b/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py
--- a/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py
+++ b/lldb/test/API/tools/lldb-dap/module/TestDAP_module.py
@@ -64,18 +64,19 @@
self.assertEqual(program, program_module["path"])
self.assertIn("addressRange", program_module)
- self.continue_to_exit()
-
# Collect all the module names we saw as events.
module_new_names = []
module_changed_names = []
- for module_event in self.dap_server.module_events:
+ module_event = self.dap_server.wait_for_event(["module"])
+ while module_event is not None:
reason = module_event["body"]["reason"]
if reason == "new":
module_new_names.append(module_event["body"]["module"]["name"])
elif reason == "changed":
module_changed_names.append(module_event["body"]["module"]["name"])
+ module_event = self.dap_server.wait_for_event(["module"])
+
# Make sure we got an event for every active module.
self.assertNotEqual(len(module_new_names), 0)
for module in active_modules:
@@ -85,6 +86,7 @@
# symbols got added.
self.assertNotEqual(len(module_changed_names), 0)
self.assertIn(program_module["name"], module_changed_names)
+ self.continue_to_exit()
@skipIfWindows
def test_modules(self):
diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py b/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py
--- a/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py
+++ b/lldb/test/API/tools/lldb-dap/module-event/TestDAP_module_event.py
@@ -1,58 +1,58 @@
-"""
-Test 'module' events for dynamically loaded libraries.
-"""
-
+import dap_server
from lldbsuite.test.decorators import *
from lldbsuite.test.lldbtest import *
+from lldbsuite.test import lldbutil
import lldbdap_testcase
+import re
class TestDAP_module_event(lldbdap_testcase.DAPTestCaseBase):
- def lookup_module_id(self, name):
- """Returns the identifier for the first module event starting with the given name."""
- for event in self.dap_server.module_events:
- if self.get_dict_value(event, ["body", "module", "name"]).startswith(name):
- return self.get_dict_value(event, ["body", "module", "id"])
- self.fail(f"No module events matching name={name}")
-
- def module_events(self, id):
- """Finds all module events by identifier."""
- return [
- event
- for event in self.dap_server.module_events
- if self.get_dict_value(event, ["body", "module", "id"]) == id
- ]
-
- def module_reasons(self, events):
- """Returns the list of 'reason' values from the given events."""
- return [event["body"]["reason"] for event in events]
-
@skipIfWindows
def test_module_event(self):
- """
- Test that module events are fired on target load and when the list of
- dynamic libraries updates while running.
- """
program = self.getBuildArtifact("a.out")
self.build_and_launch(program)
- # We can analyze the order of events after the process exits.
- self.continue_to_exit()
- a_out_id = self.lookup_module_id("a.out")
- a_out_events = self.module_events(id=a_out_id)
+ source = "main.cpp"
+ breakpoint1_line = line_number(source, "// breakpoint 1")
+ breakpoint2_line = line_number(source, "// breakpoint 2")
+ breakpoint3_line = line_number(source, "// breakpoint 3")
- self.assertIn(
- "new",
- self.module_reasons(a_out_events),
- "Expected a.out to load during the debug session.",
+ breakpoint_ids = self.set_source_breakpoints(
+ source, [breakpoint1_line, breakpoint2_line, breakpoint3_line]
)
+ self.continue_to_breakpoints(breakpoint_ids)
- libother_id = self.lookup_module_id(
- "libother." # libother.so or libother.dylib based on OS.
- )
- libother_events = self.module_events(id=libother_id)
- self.assertEqual(
- self.module_reasons(libother_events),
- ["new", "removed"],
- "Expected libother to be loaded then unloaded during the debug session.",
- )
+ # We're now stopped at breakpoint 1 before the dlopen. Flush all the module events.
+ event = self.dap_server.wait_for_event(["module"])
+ while event is not None:
+ event = self.dap_server.wait_for_event(["module"])
+
+ # Continue to the second breakpoint, before the dlclose.
+ self.continue_to_breakpoints(breakpoint_ids)
+
+ # Make sure we got a module event for libother.
+ event = self.dap_server.wait_for_event(["module"])
+ self.assertIsNotNone(event, "didn't get a module event")
+ module_name = event["body"]["module"]["name"]
+ module_id = event["body"]["module"]["id"]
+ self.assertEqual(event["body"]["reason"], "new")
+ self.assertIn("libother", module_name)
+
+ # Continue to the third breakpoint, after the dlclose.
+ self.continue_to_breakpoints(breakpoint_ids)
+
+ # Make sure we got a module event for libother.
+ event = self.dap_server.wait_for_event(["module"])
+ self.assertIsNotNone(event, "didn't get a module event")
+ reason = event["body"]["reason"]
+ self.assertEqual(reason, "removed")
+ self.assertEqual(event["body"]["module"]["id"], module_id)
+
+ # The removed module event should omit everything but the module id and name
+ # as they are required fields.
+ module_data = event["body"]["module"]
+ required_keys = ["id", "name"]
+ self.assertListEqual(list(module_data.keys()), required_keys)
+ self.assertEqual(module_data["name"], "", "expects empty name.")
+
+ self.continue_to_exit()
diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py b/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py
--- a/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py
+++ b/lldb/test/API/tools/lldb-dap/restart/TestDAP_restart_console.py
@@ -30,11 +30,7 @@
if reason == "entry":
seen_stopped_event += 1
- self.assertEqual(
- seen_stopped_event,
- 1,
- f"expect only one stopped entry event in {stopped_events}",
- )
+ self.assertEqual(seen_stopped_event, 1, "expect only one stopped entry event.")
@skipIfAsan
@skipIfWindows
@@ -96,13 +92,11 @@
self.build_and_launch(program, console="integratedTerminal", stopOnEntry=True)
[bp_main] = self.set_function_breakpoints(["main"])
- self.dap_server.request_configurationDone()
- stopped_threads = list(self.dap_server.thread_stop_reasons.values())
+ self.dap_server.request_continue() # sends configuration done
+ stopped_events = self.dap_server.wait_for_stopped()
# We should be stopped at the entry point.
- self.assertEqual(
- len(stopped_threads), 1, "Expected the main thread to be stopped on entry."
- )
- self.assertEqual(stopped_threads[0]["reason"], "entry")
+ self.assertGreaterEqual(len(stopped_events), 0, "expect stopped events")
+ self.verify_stopped_on_entry(stopped_events)
# Then, if we continue, we should hit the breakpoint at main.
self.dap_server.request_continue()
@@ -111,12 +105,8 @@
# Restart and check that we still get a stopped event before reaching
# main.
self.dap_server.request_restart()
- stopped_threads = list(self.dap_server.thread_stop_reasons.values())
- # We should be stopped at the entry point.
- self.assertEqual(
- len(stopped_threads), 1, "Expected the main thread to be stopped on entry."
- )
- self.assertEqual(stopped_threads[0]["reason"], "entry")
+ stopped_events = self.dap_server.wait_for_stopped()
+ self.verify_stopped_on_entry(stopped_events)
# continue to main
self.dap_server.request_continue()
diff -ruN --strip-trailing-cr a/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py b/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py
--- a/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py
+++ b/lldb/test/API/tools/lldb-dap/send-event/TestDAP_sendEvent.py
@@ -32,7 +32,7 @@
],
)
self.set_source_breakpoints(source, [breakpoint_line])
- self.do_continue()
+ self.continue_to_next_stop()
custom_event = self.dap_server.wait_for_event(
filter=["my-custom-event-no-body"]

View File

@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "4c46ae394841521914e0e8575e7619a1c0d1149d"
LLVM_SHA256 = "55c824ce2e1a3afafa4e108532f4eff9f194d20d44d1c5ddc6107bb23d7c6c2a"
LLVM_COMMIT = "22079e3f3698d5c367c7b67f63de8c838791ae76"
LLVM_SHA256 = "d5616e9c0f4b761f13da5535a0d9ec94acf4ae5226bbec3e47ac2929ea60cac2"
tf_http_archive(
name = name,

File diff suppressed because it is too large Load Diff

View File

@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
def repo():
SHARDY_COMMIT = "38c7fc678b8f647a484fc2772f7e5ee1aca29c1e"
SHARDY_SHA256 = "f0afb19e3b1b0677736b806790a519c3b272b22361a2f7c872f11ce0930bb6be"
SHARDY_COMMIT = "f1c951d74cf5c67b9e0f776e21fe2316d5c69f37"
SHARDY_SHA256 = "0b8b96710a2f2eec4581186e4e773aa4c4cfe6ae5e9681b7803e9b8336ead2f7"
tf_http_archive(
name = "shardy",

View File

@ -1,3 +1,88 @@
diff --ruN a/stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir
--- stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir
+++ stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir
@@ -768,7 +768,7 @@
// CHECK-PRIMITIVE: %[[MAP:.+]] = linalg.map
// CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor<?xi1>)
-// CHECK-PRIMITIVE-NEXT: (%[[A:.+]]: complex<f32>, %[[B:.+]]: complex<f32>) {
+// CHECK-PRIMITIVE-NEXT: (%[[A:.+]]: complex<f32>, %[[B:.+]]: complex<f32>, %{{.+}}: i1) {
// CHECK-PRIMITIVE: %[[RE1:.+]] = complex.re %[[A]] : complex<f32>
// CHECK-PRIMITIVE: %[[RE2:.+]] = complex.re %[[B]] : complex<f32>
// CHECK-PRIMITIVE: %[[CMP:.+]] = arith.cmpf oeq, %[[RE1]], %[[RE2]] : f32
diff --ruN a/stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir b/stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir
--- stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir
+++ stablehlo/stablehlo/conversions/linalg/tests/pointwise.mlir
@@ -714,7 +714,7 @@
// CHECK-PRIMITIVE: linalg.map
// CHECK-PRIMITIVE-SAME: ins(
// CHECK-PRIMITIVE-SAME: outs(
-// CHECK-PRIMITIVE-NEXT: (%[[LHS_IN:[a-zA-Z0-9]*]]: bf16, %[[RHS_IN:.*]]: bf16) {
+// CHECK-PRIMITIVE-NEXT: (%[[LHS_IN:[a-zA-Z0-9]*]]: bf16, %[[RHS_IN:.*]]: bf16, %[[RESULT_OUT:.*]]: i1) {
// CHECK-PRIMITIVE-NEXT: %[[LHS_INT:.*]] = arith.bitcast %[[LHS_IN]] : bf16 to i16
// CHECK-PRIMITIVE-NEXT: %[[LHS_CMP:.*]] = arith.cmpi slt, %[[LHS_INT]], %[[C0]] : i16
// CHECK-PRIMITIVE-NEXT: %[[LHS_SUB:.*]] = arith.subi %[[C32767]], %[[LHS_INT]] : i16
@@ -937,7 +937,7 @@
// CHECK-PRIMITIVE-SAME: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>)
// CHECK-PRIMITIVE-SAME: outs(%[[DST]] : tensor<2x?xf32>)
// CHECK-PRIMITIVE-SAME: {someattr}
-// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32) {
+// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[RESULT_OUT:.*]]: f32) {
// CHECK-PRIMITIVE: %[[RES:.*]] = arith.select %[[PRED_ELEM]], %[[LHS_]], %[[RHS_]] : f32
// CHECK-PRIMITIVE: linalg.yield %[[RES]]
@@ -978,7 +978,7 @@
// CHECK-PRIMITIVE-SAME: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>)
// CHECK-PRIMITIVE-SAME: outs(%[[DST]] : tensor<2x?xf32>)
// CHECK-PRIMITIVE-SAME: {someattr}
-// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32) {
+// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[RESULT_OUT:.*]]: f32) {
// CHECK-PRIMITIVE: linalg.yield %[[LHS_]]
// -----
@@ -1416,7 +1416,7 @@
// CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty
// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
-// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32)
+// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %[[RESULT_OUT:.*]]: f32)
// CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
// CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32
// CHECK-PRIMITIVE: linalg.yield %[[MIN]]
@@ -1478,7 +1478,7 @@
// CHECK-PRIMITIVE-DAG: %[[SCALAR_LB:.*]] = tensor.extract %[[LB]]
// CHECK-PRIMITIVE-DAG: %[[SCALAR_UB:.*]] = tensor.extract %[[UB]]
// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[X]] : tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>)
-// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32)
+// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32, %[[RESULT_OUT:.*]]: f32)
// CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
// CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32
// CHECK-PRIMITIVE: linalg.yield %[[MIN]]
@@ -1554,7 +1554,7 @@
// CHECK: linalg.yield %[[V_NOT]] : i32
// CHECK-PRIMITIVE: %[[CST_N1:.+]] = arith.constant -1 : i32
// CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: (%[[IN:.+]]: i32)
+ // CHECK-PRIMITIVE: (%[[IN:.+]]: i32, %[[RESULT_OUT:.+]]: i32)
// CHECK-PRIMITIVE: %[[V_NOT:.+]] = arith.xori %[[IN]], %[[CST_N1]] : i32
// CHECK-PRIMITIVE: linalg.yield %[[V_NOT]] : i32
%0 = "stablehlo.not"(%arg) : (tensor<2x2xi32>) -> tensor<2x2xi32>
diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
--- stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+++ stablehlo/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
@@ -1748,6 +1748,12 @@
rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
+ auto& blocks = linalgOp.getMapper().getBlocks();
+ if (blocks.empty()) {
+ return rewriter.notifyMatchFailure(op, "expected at least one block");
+ }
+ blocks.front().addArgument(resultType.getElementType(), loc);
+
auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType,
linalgOp.getResults());
rewriter.replaceOp(op, result);
diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp
--- stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp
+++ stablehlo/stablehlo/conversions/linalg/transforms/StablehloToArith.cpp

View File

@ -2,7 +2,7 @@
func.func @no_inputs() -> tensor<4xf32> {
%init = arith.constant dense<[1.0,2.0,3.0,4.0]> : tensor<4xf32>
%zero = linalg.map outs(%init:tensor<4xf32>)() {
%zero = linalg.map outs(%init:tensor<4xf32>)(%out: f32) {
%0 = arith.constant 0.0: f32
linalg.yield %0: f32
}
@ -19,7 +19,7 @@ func.func @binary() -> tensor<4xi32> {
%rhs = arith.constant dense<[10, 20, 30, 40]> : tensor<4xi32>
%add = linalg.map ins(%lhs, %rhs: tensor<4xi32>, tensor<4xi32>)
outs(%init: tensor<4xi32>)
(%lhs_elem: i32, %rhs_elem: i32) {
(%lhs_elem: i32, %rhs_elem: i32, %out: i32) {
%0 = arith.addi %lhs_elem, %rhs_elem: i32
linalg.yield %0: i32
}
@ -36,7 +36,7 @@ func.func @memref() -> memref<4xi32> {
%rhs = arith.constant dense<[10, 20, 30, 40]> : memref<4xi32>
linalg.map ins(%lhs, %rhs: memref<4xi32>, memref<4xi32>)
outs(%alloc: memref<4xi32>)
(%lhs_elem: i32, %rhs_elem: i32) {
(%lhs_elem: i32, %rhs_elem: i32, %out: i32) {
%0 = arith.muli %lhs_elem, %rhs_elem: i32
linalg.yield %0: i32
}
@ -49,7 +49,7 @@ func.func @memref() -> memref<4xi32> {
func.func @index() -> memref<4xindex> {
%alloc = memref.alloc() : memref<4xindex>
linalg.map outs(%alloc: memref<4xindex>)() {
linalg.map outs(%alloc: memref<4xindex>)(%out: index) {
%0 = linalg.index 0 : index
linalg.yield %0: index
}
@ -63,7 +63,8 @@ func.func @index() -> memref<4xindex> {
func.func @vector() -> memref<4xvector<2xindex>> {
%c = arith.constant dense<42> : vector<2xindex>
%alloc = memref.alloc() : memref<4xvector<2xindex>>
linalg.map outs(%alloc: memref<4xvector<2xindex>>)() {
linalg.map outs(%alloc: memref<4xvector<2xindex>>)
(%out: vector<2xindex>) {
linalg.yield %c: vector<2xindex>
}
func.return %alloc : memref<4xvector<2xindex>>