Allow summarize & streaming w/WebLLM

Dustin Brett 2024-10-24 07:33:04 -07:00
parent 0e291d6011
commit 1023c9f55f
2 changed files with 54 additions and 27 deletions

View File

@@ -82,8 +82,13 @@ const useFileContextMenu = (
 ): ContextMenuCapture => {
   const { minimize, open, url: changeUrl } = useProcesses();
   const processesRef = useProcessesRef();
-  const { setCursor, setForegroundId, setWallpaper, updateRecentFiles } =
-    useSession();
+  const {
+    aiEnabled,
+    setCursor,
+    setForegroundId,
+    setWallpaper,
+    updateRecentFiles,
+  } = useSession();
   const baseName = basename(path);
   const isFocusedEntry = focusedEntries.includes(baseName);
   const openFile = useFile(url, path);
@@ -491,8 +496,7 @@ const useFileContextMenu = (
       }
       if (
-        hasWindowAI &&
-        "summarizer" in window.ai &&
+        (aiEnabled || (hasWindowAI && "summarizer" in window.ai)) &&
         TEXT_FILE_EXTENSIONS.has(urlExtension)
       ) {
         const aiCommand = (command: string): void => {
@@ -512,7 +516,7 @@ const useFileContextMenu = (
         menuItems.unshift(MENU_SEPERATOR, {
           label: `AI (${AI_STAGE})`,
           menu: [
-            ...("summarizer" in window.ai
+            ...(aiEnabled || (hasWindowAI && "summarizer" in window.ai)
               ? [
                   {
                     action: () => aiCommand("Summarize"),
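
The two hunks above now share one availability test: WebLLM ("aiEnabled", toggled in session settings) can serve Summarize without any browser API, while the built-in summarizer still requires the Chrome Prompt API ("window.ai"). A minimal sketch of that check, with declare standing in for the flags the hook already has in scope and canSummarize a hypothetical name:

declare const aiEnabled: boolean; // session flag: WebLLM is allowed
declare const hasWindowAI: boolean; // Chrome Prompt API was detected

// Summarize is offered when either backend can handle the text file.
const canSummarize: boolean =
  aiEnabled ||
  (hasWindowAI && "summarizer" in (window as unknown as { ai: object }).ai);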
@@ -641,6 +645,7 @@ const useFileContextMenu = (
       return menuItems[0] === MENU_SEPERATOR ? menuItems.slice(1) : menuItems;
     }),
     [
+      aiEnabled,
       archiveFiles,
       baseName,
       changeUrl,

View File

@ -1,4 +1,6 @@
import {
type ChatCompletion,
type ChatCompletionChunk,
type ChatCompletionMessageParam,
type MLCEngine,
} from "@mlc-ai/web-llm";
@@ -49,10 +51,11 @@ let cancel = false;
 let responding = false;
 let sessionId = 0;
-let session: AILanguageModel | ChatCompletionMessageParam[] | undefined;
+let session: AILanguageModel | undefined;
 let summarizer: AISummarizer | undefined;
-let prompts: (AILanguageModelAssistantPrompt | AILanguageModelUserPrompt)[] =
-  [];
+let prompts:
+  | (AILanguageModelAssistantPrompt | AILanguageModelUserPrompt)[]
+  | ChatCompletionMessageParam[] = [];
 let engine: MLCEngine;
 let markedLoaded = false;
@@ -86,7 +89,7 @@ globalThis.addEventListener(
       session = await globalThis.ai.languageModel.create(config);
     } else {
-      session = [SYSTEM_PROMPT];
+      prompts = [SYSTEM_PROMPT];
       if (!engine) {
         const { CreateMLCEngine } = await import("@mlc-ai/web-llm");
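
The hunk above is cut off where the engine is created, so here is a hedged sketch of the lazy bootstrap it begins. getEngine is a hypothetical helper, and the model parameter stands for the worker's WEB_LLM_MODEL id; CreateMLCEngine and initProgressCallback are real @mlc-ai/web-llm APIs:

import { CreateMLCEngine, type MLCEngine } from "@mlc-ai/web-llm";

let engine: MLCEngine | undefined;

// Create the engine once, on first prompt; model weights load here.
const getEngine = async (
  model: string, // e.g. the worker's WEB_LLM_MODEL
  onProgress: (status: string) => void
): Promise<MLCEngine> => {
  engine ??= await CreateMLCEngine(model, {
    initProgressCallback: ({ text }) => onProgress(text),
  });

  return engine;
};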
@@ -101,7 +104,10 @@ globalThis.addEventListener(
       }
     }
-    let response: string | ReadableStream<string> = "";
+    let response:
+      | string
+      | ReadableStream<string>
+      | AsyncIterable<ChatCompletionChunk> = "";
     let retry = 0;
     const rebuildSession = async (): Promise<void> => {
       (session as AILanguageModel)?.destroy();
@@ -115,7 +121,10 @@ globalThis.addEventListener(
         ...CONVO_STYLE_TEMPS[data.style],
         initialPrompts: [
           SYSTEM_PROMPT as unknown as AILanguageModelAssistantPrompt,
-          ...prompts,
+          ...(prompts as (
+            | AILanguageModelAssistantPrompt
+            | AILanguageModelUserPrompt
+          )[]),
         ],
       };
@@ -159,25 +168,29 @@ globalThis.addEventListener(
             (await aiAssistant.prompt(data.text, aiOptions)) || "";
         }
       } else {
-        (session as ChatCompletionMessageParam[]).push({
-          content: data.text,
+        prompts.push({
+          content: data.summarizeText
+            ? `Summarize:\n\n${data.summarizeText}`
+            : data.text,
           role: "user",
         });
-        const {
-          choices: [{ message }],
-          // eslint-disable-next-line no-await-in-loop
-        } = await engine.chat.completions.create({
+        const stream = Boolean(data.streamId);
+        // eslint-disable-next-line no-await-in-loop
+        const completions = await engine.chat.completions.create({
           logprobs: true,
-          messages: session as ChatCompletionMessageParam[],
+          messages: prompts as ChatCompletionMessageParam[],
+          stream,
+          stream_options: { include_usage: false },
           temperature: CONVO_STYLE_TEMPS[data.style].temperature,
           top_logprobs: CONVO_STYLE_TEMPS[data.style].topK,
           ...WEB_LLM_MODEL_CONFIG[WEB_LLM_MODEL],
         });
-        (session as ChatCompletionMessageParam[]).push(message);
-        response = message.content || "";
+        response = stream
+          ? (completions as AsyncIterable<ChatCompletionChunk>)
+          : (completions as ChatCompletion).choices[0].message.content ||
+            "";
       }
     } catch (error) {
       console.error("Failed to get prompt response.", error);
@@ -204,22 +217,31 @@ globalThis.addEventListener(
             response: message,
             streamId,
           });
-          prompts.push(
-            { content: data.text, role: "user" },
-            { content: message, role: "assistant" }
-          );
+          if (prompts[prompts.length - 1]?.role !== "user") {
+            prompts.push({ content: data.text, role: "user" });
+          }
+          prompts.push({ content: message, role: "assistant" });
         };
         if (response && typeof response === "string") {
           sendMessage(response);
         } else {
           try {
+            let reply = "";
             // @ts-expect-error ReadableStream will have an asyncIterator if Prompt API exists
             // eslint-disable-next-line @typescript-eslint/await-thenable
             for await (const chunk of response) {
               if (cancel) break;
-              sendMessage(chunk as string, data.streamId);
+              if (typeof chunk === "string") {
+                sendMessage(chunk, data.streamId);
+              } else {
+                reply +=
+                  (chunk as ChatCompletionChunk).choices[0]?.delta.content ||
+                  "";
+                sendMessage(reply, data.streamId);
+              }
             }
           } catch (error) {
             console.error("Failed to stream prompt response.", error);