Mirror of https://github.com/DustinBrett/daedalOS.git (synced 2025-12-06 12:20:20 +01:00)

Use latest WebLLM implementation

This commit is contained in:
parent fe3d2cb92a
commit 2cd3676c0e
@@ -61,7 +61,7 @@ const Chat: FC<ComponentProcessProps> = ({ id }) => {
     error: aiError,
     name,
     resetError,
-  } = useInference(apiKey, engine);
+  } = useInference(engine.startsWith("WebLLM") ? engine : apiKey, engine);
   const messagesRef = useRef<HTMLUListElement>(null);
   const inputRef = useRef<HTMLTextAreaElement>(null);
   const [input, setInput] = useState<string>("");
@@ -395,12 +395,19 @@ const Chat: FC<ComponentProcessProps> = ({ id }) => {
       ? [
           {
             label: "Set AI Engine",
-            menu: ["HuggingFace", "OpenAI", "WebLLM"].map((engineName) => ({
+            menu: [
+              "HuggingFace",
+              "OpenAI",
+              "WebLLM [RedPajama 3B]",
+              "WebLLM [Vicuna 7B]",
+            ].map((engineName) => ({
               action: () => {
                 setAiApi(
                   `${engineName}:${
                     // eslint-disable-next-line no-alert
-                    engineName === "WebLLM" ? "" : prompt("API Key") || ""
+                    engineName.startsWith("WebLLM")
+                      ? ""
+                      : prompt("API Key") || ""
                   }`
                 );
@@ -414,7 +421,7 @@ const Chat: FC<ComponentProcessProps> = ({ id }) => {
               label: engineName,
             })),
           },
-          ...(["OpenAI", "WebLLM"].includes(name)
+          ...(name === "OpenAI"
             ? [
                 {
                   action: () => {
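For reference, setAiApi above stores a single "engineName:key" string, and the WebLLM entries leave the key empty. The parseAiApi helper below is hypothetical (not part of this commit); it only illustrates how that string splits back into the engine name and API key that useInference ends up receiving:

  // Hypothetical helper; assumes stored values like "OpenAI:sk-..." or "WebLLM [Vicuna 7B]:".
  const parseAiApi = (aiApi: string): { apiKey: string; engine: string } => {
    const [engine, ...rest] = aiApi.split(":");

    return { apiKey: rest.join(":"), engine };
  };

  // Per the diff above, WebLLM engines pass their own name to useInference; other engines pass the key.
  const { apiKey, engine } = parseAiApi("WebLLM [Vicuna 7B]:");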
@@ -6,7 +6,7 @@ type WorkerMessage = { data: Log | string };

 declare global {
   interface Window {
-    webLLM?: Worker;
+    webLLM?: Record<string, Worker>;
   }
 }
@@ -16,6 +16,8 @@ const DEFAULT_GREETING = {
 } as Message;

 export class WebLLM implements Engine {
+  private model = "";
+
   private worker?: Worker = undefined;

   private isChatting = false;
@@ -26,15 +28,20 @@ export class WebLLM implements Engine {
     this.reset();
   }

+  public constructor(model: string) {
+    this.model = model;
+  }
+
   public async init(): Promise<void> {
-    window.webLLM =
-      window.webLLM ||
+    window.webLLM = window.webLLM || {};
+    window.webLLM[this.model] =
+      window.webLLM[this.model] ||
       new Worker(
         new URL("hooks/useInference/WebLLM.worker.ts", import.meta.url),
-        { name: "WebLLM" }
+        { name: this.model, type: "module" }
       );
-    this.worker = window.webLLM;
-    this.worker.postMessage({ type: "init" });
+    this.worker = window.webLLM[this.model];
+    this.worker.postMessage({ model: this.model, type: "init" });

     // eslint-disable-next-line unicorn/no-useless-promise-resolve-reject
     return Promise.resolve();
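A minimal usage sketch of the per-model worker caching added above; the engine labels are taken from the menu introduced in this commit, but the calling code itself is illustrative only:

  // One module worker per model label, cached on window.webLLM so repeat
  // constructions of the same model reuse the existing worker.
  const vicuna = new WebLLM("WebLLM [Vicuna 7B]");
  await vicuna.init(); // posts { model: "WebLLM [Vicuna 7B]", type: "init" }

  const redPajama = new WebLLM("WebLLM [RedPajama 3B]");
  await redPajama.init(); // spawns a second worker, keyed independently by its own model label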
@@ -1,46 +1,79 @@
-const libs = [
-  "/System/tvm/tvmjs_runtime.wasi.js",
-  "/System/tvm/tvmjs.bundle.js",
-  "/Program Files/WebLLM/llm_chat.js",
-  "/Program Files/WebLLM/sentencepiece.js",
-];
+import type { InitProgressReport } from "@mlc-ai/web-llm";
+import { ChatModule } from "@mlc-ai/web-llm";

-const runLLM = async (message: string): Promise<string> => {
-  globalThis.tvmjsGlobalEnv.message = message;
+type Data = { model?: string; prompt?: string; type: string };

-  await globalThis.tvmjsGlobalEnv.asyncOnGenerate();
+const CACHE_WARNING =
+  "It can take a while when we first visit this page to populate the cache. Later refreshes will become faster.";
+const DEFAULT_MODEL = "RedPajama 3B";

-  return globalThis.tvmjsGlobalEnv.response;
+const configMap: Record<string, string> = {
+  "[RedPajama 3B]": "RedPajama-INCITE-Chat-3B-v1-q4f32_0",
+  "[Vicuna 7B]": "vicuna-v1-7b-q4f32_0",
 };
+const config = {
+  model_lib_map: {
+    "RedPajama-INCITE-Chat-3B-v1-q4f32_0":
+      "/Program Files/WebLLM/RedPajama-INCITE-Chat-3B-v1-q4f32_0-webgpu.wasm",
+    "vicuna-v1-7b-q4f32_0":
+      "/Program Files/WebLLM/vicuna-v1-7b-q4f32_0-webgpu.wasm",
+  },
+  model_list: [
+    {
+      local_id: "RedPajama-INCITE-Chat-3B-v1-q4f32_0",
+      model_url:
+        "https://huggingface.co/mlc-ai/mlc-chat-RedPajama-INCITE-Chat-3B-v1-q4f32_0/resolve/main/",
+    },
+    {
+      local_id: "vicuna-v1-7b-q4f32_0",
+      model_url:
+        "https://huggingface.co/mlc-ai/mlc-chat-vicuna-v1-7b-q4f32_0/resolve/main/",
+    },
+  ],
+};
+
+const generateProgressCallback = (step: number): void => {
+  globalThis.postMessage({
+    message: `Generating (Step ${step})`,
+    type: "[progress]",
+  });
+};
+
+const initProgressCallback = (report: InitProgressReport): void => {
+  globalThis.postMessage({
+    message: report.text.replace(CACHE_WARNING, ""),
+    type: "[init]",
+  });
+};

 let initalized = false;
+let startedChat = false;
+let chatModule: ChatModule;
+let chatModel: string;

 globalThis.addEventListener(
   "message",
-  ({ data }: { data: { prompt?: string; type: string } }) => {
+  async ({ data }: { data: Data }) => {
     if (!initalized && data.type === "init") {
       initalized = true;

-      globalThis.tvmjsGlobalEnv = globalThis.tvmjsGlobalEnv || {};
-      globalThis.tvmjsGlobalEnv.logger = (type: string, message: string) =>
-        globalThis.postMessage({ message, type });
-
-      globalThis.importScripts(...libs);
-
-      globalThis.tvmjsGlobalEnv.sentencePieceProcessor = (url: string) =>
-        globalThis.sentencepiece.sentencePieceProcessor(url);
+      chatModel =
+        configMap[data.model?.replace("WebLLM ", "") || DEFAULT_MODEL];
+      chatModule = new ChatModule();
+      chatModule.setInitProgressCallback(initProgressCallback);
     } else if (data.type === "reset") {
-      globalThis.tvmjsGlobalEnv.response = "";
-      globalThis.tvmjsGlobalEnv.message = "";
-      globalThis.tvmjsGlobalEnv.systemPrompt = "";
-
-      globalThis.tvmjsGlobalEnv.asyncOnReset();
-    } else if (data.prompt) {
-      if (data.type === "system") {
-        globalThis.tvmjsGlobalEnv.systemPrompt = data.prompt;
-      } else if (data.type === "chat") {
-        runLLM(data.prompt).then(globalThis.postMessage);
+      await chatModule.interruptGenerate();
+      await chatModule.resetChat();
+    } else if (data.prompt && data.type === "chat") {
+      // TODO: Support changing system prompt
+      if (!startedChat) {
+        await chatModule.reload(chatModel, undefined, config);
+        startedChat = true;
       }
+
+      chatModule
+        .generate(data.prompt, generateProgressCallback)
+        .then(globalThis.postMessage);
     }
   },
   { passive: true }
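The worker above is driven entirely by postMessage. Below is a hypothetical caller-side sketch of the protocol; the worker URL and options mirror the WebLLM engine class earlier in this commit, everything else is illustrative:

  const worker = new Worker(
    new URL("hooks/useInference/WebLLM.worker.ts", import.meta.url),
    { name: "WebLLM [RedPajama 3B]", type: "module" }
  );

  worker.addEventListener("message", ({ data }) => {
    // data is either a { message, type } init/progress log or the generated string
    console.log(data);
  });

  worker.postMessage({ model: "WebLLM [RedPajama 3B]", type: "init" });
  worker.postMessage({ prompt: "Hello there", type: "chat" });
  worker.postMessage({ type: "reset" }); // interrupts generation and clears chat state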
@@ -38,7 +38,12 @@ type Inference = {
   resetError: () => void;
 };

-const Engines = { HuggingFace, OpenAI, WebLLM } as Record<string, EngineClass>;
+const Engines = {
+  HuggingFace,
+  OpenAI,
+  "WebLLM [RedPajama 3B]": WebLLM,
+  "WebLLM [Vicuna 7B]": WebLLM,
+} as Record<string, EngineClass>;

 export const useInference = (apiKey = "", engine = ""): Inference => {
   const [error, setError] = useState<number>(0);
@@ -52,10 +57,10 @@ export const useInference = (apiKey = "", engine = ""): Inference => {

   if (engine && engine in Engines) {
     currentEngine =
-      engine === "WebLLM" && !hasWebGPU
+      engine.startsWith("WebLLM") && !hasWebGPU
         ? DEFAULT_NON_WEBGPU_ENGINE
         : engine;
-  } else if (currentEngine === "WebLLM" && !hasWebGPU) {
+  } else if (currentEngine.startsWith("WebLLM") && !hasWebGPU) {
     currentEngine = DEFAULT_NON_WEBGPU_ENGINE;
   }

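A sketch of how the widened Engines map and the WebGPU fallback above can be read together; the constructor argument is an assumption, since the hook's instantiation code sits outside this diff:

  // Hypothetical: both WebLLM keys resolve to the same class, and the key
  // doubles as the model label handed to the engine (per the Chat change above).
  const selected = hasWebGPU ? "WebLLM [Vicuna 7B]" : DEFAULT_NON_WEBGPU_ENGINE;
  const EngineConstructor = Engines[selected];
  const engineInstance = new EngineConstructor(selected); // assumed single-argument constructor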
@@ -49,6 +49,10 @@ const nextConfig = {
         }
       })
     );
+    config.resolve.fallback = {
+      module: false,
+      perf_hooks: false,
+    };

     return config;
   },
@@ -40,6 +40,7 @@
     "*.{js,ts,tsx}": "eslint --fix"
   },
   "dependencies": {
+    "@mlc-ai/web-llm": "^0.1.1",
     "@monaco-editor/react": "^4.5.1",
     "@panzoom/panzoom": "^4.5.1",
     "@prettier/plugin-xml": "^2.2.0",
Binary file not shown.
@@ -1,8 +0,0 @@
-{
-  "url_dict":{
-    "vicuna-v1-7b-q4f32_0": "/Program Files/WebLLM/vicuna-7b/model_config.json"
-  },
-  "model_lib_map":{
-    "vicuna-v1-7b-q4f32_0": "/Program Files/WebLLM/vicuna-7b/vicuna-v1-7b-q4f32_0-webgpu.wasm"
-  }
-}
@ -1,648 +0,0 @@
|
|||
/**
|
||||
* Helper to keep track of history conversations.
|
||||
*/
|
||||
class Conversation {
|
||||
constructor(config) {
|
||||
this.system = config.system;
|
||||
this.roles = config.roles;
|
||||
this.offset = config.offset;
|
||||
this.seps = config.seps;
|
||||
this.convId = null;
|
||||
this.messages = [];
|
||||
this.contextWindowStart = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt arrays with the first one as system.
|
||||
*
|
||||
* @returns The prompt array.
|
||||
*/
|
||||
getPromptArray() {
|
||||
if (this.seps.length == 0) {
|
||||
throw Error("Need seps to work")
|
||||
}
|
||||
let ret = [this.system + this.seps[0]];
|
||||
|
||||
for (let i = 0; i < this.messages.length; ++i) {
|
||||
const item = this.messages[i];
|
||||
const role = item[0];
|
||||
const message = item[1];
|
||||
if (message !== undefined && message != "") {
|
||||
ret.push(role + ": " + message + this.seps[i % this.seps.length]);
|
||||
} else {
|
||||
ret.push(role + ":");
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt arrays that has not been fed as input
|
||||
*
|
||||
* @returns The prompt array.
|
||||
*/
|
||||
getPromptArrayUnproccessed() {
|
||||
if (this.seps.length == 0) {
|
||||
throw Error("Need seps to work")
|
||||
}
|
||||
if (this.messages.length < 3) {
|
||||
throw Error("needs to call getLastPromptArray for the first message");
|
||||
}
|
||||
let ret = [this.seps[this.seps.length - 1]];
|
||||
for (let i = this.messages.length - 2; i < this.messages.length; ++i) {
|
||||
const item = this.messages[i];
|
||||
const role = item[0];
|
||||
const message = item[1];
|
||||
if (message !== undefined && message != "") {
|
||||
ret.push(role + ": " + message + this.seps[i % this.seps.length]);
|
||||
} else {
|
||||
ret.push(role + ":");
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Get last prompt array with prefix as system.
|
||||
*
|
||||
* @returns The prompt array.
|
||||
*/
|
||||
getLastPromptArray() {
|
||||
if (this.seps.length == 0) {
|
||||
throw Error("Need seps to work")
|
||||
}
|
||||
let ret = [this.system + this.seps[0]];
|
||||
|
||||
for (let i = this.messages.length - 2; i < this.messages.length; ++i) {
|
||||
const item = this.messages[i];
|
||||
const role = item[0];
|
||||
const message = item[1];
|
||||
if (message !== undefined && message != "") {
|
||||
ret.push(role + ": " + message + this.seps[i % this.seps.length]);
|
||||
} else {
|
||||
ret.push(role + ":");
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
reset() {
|
||||
this.messages = [];
|
||||
}
|
||||
|
||||
getStopStr() {
|
||||
return this.seps[this.seps.length - 1];
|
||||
}
|
||||
|
||||
appendMessage(role, message) {
|
||||
this.messages.push([role, message]);
|
||||
}
|
||||
}
|
||||
|
||||
function getConversation(conv_template, maxWindowLength = 512) {
|
||||
if (conv_template == "vicuna-v1.1") {
|
||||
return new Conversation({
|
||||
system: globalThis.tvmjsGlobalEnv.systemPrompt || "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
roles: ["USER", "ASSISTANT"],
|
||||
maxWindowLength: maxWindowLength,
|
||||
messages: [],
|
||||
offset: 0,
|
||||
seps: [" ", "</s>"],
|
||||
});
|
||||
} else if (conv_template == "wizardlm") {
|
||||
return new Conversation({
|
||||
system: globalThis.tvmjsGlobalEnv.systemPrompt || "You are an AI assistant that gives helpful, detailed, and polite answers to the user's questions.",
|
||||
roles: ["", "### Response"],
|
||||
maxWindowLength: maxWindowLength,
|
||||
messages: [],
|
||||
offset: 0,
|
||||
seps: ["\n\n", "</s>"],
|
||||
})
|
||||
} else {
|
||||
throw Error("Unknown model "+ model);
|
||||
}
|
||||
};
|
||||
|
||||
class LLMChatPipeline {
|
||||
constructor(tvm, tokenizer, cacheMetadata, config) {
|
||||
if (cacheMetadata == undefined) {
|
||||
throw Error("Expect cacheMetadata");
|
||||
}
|
||||
this.tvm = tvm;
|
||||
this.logger = globalThis.tvmjsGlobalEnv.logger || console.log;
|
||||
this.tokenizer = tokenizer;
|
||||
this.bosTokenId = 1;
|
||||
this.eosTokenId = 2;
|
||||
|
||||
this.temperature = config.temperature;
|
||||
this.top_p = config.top_p;
|
||||
this.maxWindowLength = config.max_seq_len;
|
||||
this.maxGenLength = config.maxGenLength;
|
||||
this.meanGenLength = config.meanGenLength;
|
||||
this.streamInterval = 1;
|
||||
|
||||
this.decodingTotalTime = 0;
|
||||
this.decodingTotalTokens = 0;
|
||||
this.encodingTotalTime = 0;
|
||||
this.encodingTotalTokens = 0;
|
||||
|
||||
this.conversation = getConversation(config.conv_template, this.maxWindowLength);
|
||||
|
||||
this.device = this.tvm.webgpu();
|
||||
this.vm = this.tvm.detachFromCurrentScope(
|
||||
this.tvm.createVirtualMachine(this.device)
|
||||
);
|
||||
this.encoding = this.tvm.detachFromCurrentScope(
|
||||
this.vm.getFunction("encoding")
|
||||
);
|
||||
this.decoding = this.tvm.detachFromCurrentScope(
|
||||
this.vm.getFunction("decoding")
|
||||
);
|
||||
this.params = this.tvm.detachFromCurrentScope(
|
||||
this.tvm.getParamsFromCache("param", cacheMetadata.ParamSize)
|
||||
);
|
||||
const fcreateCache = this.vm.getFunction("create_kv_cache");
|
||||
this.fclearKVCaches = this.tvm.detachFromCurrentScope(
|
||||
this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
|
||||
);
|
||||
|
||||
// use extern config for now
|
||||
this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
|
||||
// fill with pad token
|
||||
this.logitsOnCPU = undefined;
|
||||
|
||||
this.kvCacheLength = 0;
|
||||
this.clearCache = true
|
||||
}
|
||||
|
||||
|
||||
dispose() {
|
||||
// note: tvm instance is not owned by this class
|
||||
this.params.dispose();
|
||||
this.decoding.dispose();
|
||||
this.encoding.dispose();
|
||||
this.vm.dispose();
|
||||
this.kvCache.dispose();
|
||||
this.fclearKVCaches.dispose();
|
||||
if (this.logitsOnCPU != undefined) {
|
||||
this.logitsOnCPU.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
#clearKVCache() {
|
||||
this.fclearKVCaches(this.kvCache);
|
||||
this.kvCacheLength = 0;
|
||||
}
|
||||
|
||||
#forward(inputs, curPos) {
|
||||
this.tvm.beginScope();
|
||||
var retValue;
|
||||
const seqLenShape = this.tvm.makeShapeTuple([curPos]);
|
||||
if (inputs.shape[1] > 1) {
|
||||
retValue = this.encoding(
|
||||
inputs, seqLenShape, this.kvCache, this.params
|
||||
);
|
||||
} else {
|
||||
retValue = this.decoding(
|
||||
inputs, seqLenShape, this.kvCache, this.params
|
||||
);
|
||||
}
|
||||
const logits = this.tvm.detachFromCurrentScope(retValue.get(0));
|
||||
this.tvm.endScope();
|
||||
this.tvm.attachToCurrentScope(logits);
|
||||
return logits;
|
||||
}
|
||||
|
||||
// NOTE: caller must call device.sync()
|
||||
#updateLogitsOnCPU(logits) {
|
||||
if (this.logitsOnCPU == undefined) {
|
||||
this.logitsOnCPU = this.tvm.detachFromCurrentScope(
|
||||
this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu())
|
||||
);
|
||||
} else {
|
||||
if (logits.shape[0] != this.logitsOnCPU.shape[0]) {
|
||||
throw Error("We expect the size of logits to remain unchanged");
|
||||
}
|
||||
}
|
||||
this.logitsOnCPU.copyFrom(logits);
|
||||
}
|
||||
|
||||
async sampleTokenFromLogits(logits, temperature = 0.8, top_p = 0.95) {
|
||||
this.tvm.beginScope();
|
||||
this.#updateLogitsOnCPU(logits);
|
||||
this.tvm.endScope();
|
||||
await this.device.sync();
|
||||
return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
|
||||
}
|
||||
|
||||
async getInputTokens() {
|
||||
let tokens = [this.bosTokenId];
|
||||
let prompts = ""
|
||||
if (this.conversation.messages.length <= 2) {
|
||||
prompts = this.conversation.getPromptArray();
|
||||
} else {
|
||||
tokens.pop();
|
||||
prompts = this.conversation.getPromptArrayUnproccessed();
|
||||
}
|
||||
tokens.push(...await this.tokenizer.encodeIds(prompts[0]));
|
||||
let ctxLength = tokens.length;
|
||||
let context = [];
|
||||
let need_shift_window = false;
|
||||
for (let i = prompts.length - 1; i > 0; --i) {
|
||||
const encoded = this.tokenizer.encodeIds(prompts[i]);
|
||||
ctxLength += encoded.length;
|
||||
if (this.kvCacheLength + ctxLength + this.meanGenLength >= this.maxWindowLength) {
|
||||
need_shift_window = true;
|
||||
break;
|
||||
}
|
||||
context.unshift(encoded);
|
||||
}
|
||||
if (!need_shift_window) {
|
||||
for (const ctx of context) {
|
||||
tokens.push(...ctx);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
// need shift window and re-encode
|
||||
this.kvCacheLength = 0;
|
||||
this.clearCache = true;
|
||||
// abandon all tokens we collected
|
||||
tokens = [this.bosTokenId]
|
||||
let all_prompts = this.conversation.getPromptArray();
|
||||
tokens.push(...await this.tokenizer.encodeIds(all_prompts[0]));
|
||||
context = [];
|
||||
ctxLength = tokens.length;
|
||||
//only keep 10% of the window context
|
||||
const fill_factor = 0.1
|
||||
for (let i = all_prompts.length - 1; i > 0; --i) {
|
||||
const encoded = this.tokenizer.encodeIds(all_prompts[i]);
|
||||
ctxLength += encoded.length;
|
||||
if (ctxLength >= fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) {
|
||||
break;
|
||||
}
|
||||
context.unshift(encoded);
|
||||
}
|
||||
for (const ctx of context) {
|
||||
tokens.push(...ctx);
|
||||
}
|
||||
if (tokens.length + this.meanGenLength >= this.maxWindowLength) {
|
||||
throw Error("Exceed max window length curr=" + tokens.length);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
resetChat() {
|
||||
this.conversation.reset();
|
||||
this.#clearKVCache();
|
||||
this.decodingTotalTime = 0;
|
||||
this.encodingTotalTime = 0;
|
||||
this.decodingTotalTokens = 0;
|
||||
this.encodingTotalTokens = 0;
|
||||
}
|
||||
|
||||
async generate(inputPrompt, callbackUpdateResponse) {
|
||||
this.conversation.appendMessage(this.conversation.roles[0], inputPrompt);
|
||||
this.conversation.appendMessage(this.conversation.roles[1], "");
|
||||
const stopStr = this.conversation.getStopStr();
|
||||
const tokens = await this.getInputTokens();
|
||||
const inputTokenLength = tokens.length;
|
||||
|
||||
var outputPrompt = "";
|
||||
if (this.clearCache) {
|
||||
this.#clearKVCache();
|
||||
this.clearCache = false;
|
||||
}
|
||||
const maxGenLen = Math.min(this.maxGenLength, this.maxWindowLength - tokens.length);
|
||||
if (maxGenLen < this.meanGenLength) {
|
||||
throw Error("Too small window size config");
|
||||
}
|
||||
let step = 0;
|
||||
for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
|
||||
this.tvm.beginScope();
|
||||
var inputData;
|
||||
|
||||
let tstart = performance.now();
|
||||
if (step == 0) {
|
||||
inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
|
||||
inputData.copyFrom(tokens);
|
||||
} else {
|
||||
inputData = this.tvm.empty([1, 1], "int32", this.device);
|
||||
inputData.copyFrom(tokens.slice(tokens.length - 1));
|
||||
}
|
||||
const logits = this.tvm.detachFromCurrentScope(
|
||||
this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
|
||||
);
|
||||
this.tvm.endScope();
|
||||
|
||||
const nextToken = await this.sampleTokenFromLogits(logits, this.temperature, this.top_p);
|
||||
logits.dispose();
|
||||
|
||||
tokens.push(nextToken);
|
||||
const outputTokens = tokens.slice(inputTokenLength);
|
||||
outputPrompt = this.tokenizer.decodeIds(outputTokens);
|
||||
|
||||
if (nextToken == this.eosTokenId) break;
|
||||
|
||||
const stopPos = outputPrompt.lastIndexOf(stopStr);
|
||||
if (stopPos != -1) {
|
||||
outputPrompt = outputPrompt.substring(0, stopPos);
|
||||
break;
|
||||
}
|
||||
let tend = performance.now();
|
||||
if (step != 0) {
|
||||
this.decodingTotalTokens += 1;
|
||||
this.decodingTotalTime += (tend - tstart) / 1000;
|
||||
} else {
|
||||
this.encodingTotalTime += (tend - tstart) / 1000;
|
||||
this.encodingTotalTokens += inputTokenLength;
|
||||
}
|
||||
|
||||
if (step % this.streamInterval == 0) {
|
||||
callbackUpdateResponse(step, outputPrompt);
|
||||
}
|
||||
}
|
||||
this.kvCacheLength += tokens.length - 1;
|
||||
this.conversation.messages[this.conversation.messages.length - 1][1] = outputPrompt;
|
||||
return outputPrompt;
|
||||
}
|
||||
|
||||
async evaluate() {
|
||||
// run a canonical evaluation of the flow
|
||||
this.#clearKVCache();
|
||||
const testPrompt = "The capital of Canada is";
|
||||
const ids = await this.tokenizer.encodeIds(testPrompt);
|
||||
const inputPromptSize = ids.length;
|
||||
const tokens = Array.from(ids);
|
||||
tokens.unshift(this.bosTokenId);
|
||||
if (tokens.length == 0) {
|
||||
throw Error("empty token");
|
||||
}
|
||||
|
||||
this.tvm.beginScope();
|
||||
const inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
|
||||
inputData.copyFrom(tokens);
|
||||
const encodingStart = performance.now();
|
||||
this.#forward(inputData, tokens.length);
|
||||
this.tvm.endScope();
|
||||
await this.device.sync();
|
||||
|
||||
const decodingStart = performance.now();
|
||||
|
||||
this.tvm.beginScope();
|
||||
const firstSampleToken = this.tvm.empty([1, 1], "int32", this.device).copyFrom([6234]);
|
||||
this.#updateLogitsOnCPU(this.#forward(firstSampleToken, tokens.length + 1));
|
||||
await this.device.sync();
|
||||
this.tvm.endScope();
|
||||
|
||||
const decodingEnd = performance.now();
|
||||
const msg = (
|
||||
`encoding-time=${((decodingStart - encodingStart) / 1000).toFixed(4)} sec` +
|
||||
`decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`
|
||||
);
|
||||
|
||||
// simply log tokens for eyeballing.
|
||||
console.log("Logits:");
|
||||
console.log(this.logitsOnCPU.toArray());
|
||||
console.log(msg);
|
||||
}
|
||||
|
||||
/**
|
||||
* async preload webgpu pipelines when possible.
|
||||
*/
|
||||
async asyncLoadWebGPUPiplines() {
|
||||
await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
|
||||
}
|
||||
|
||||
runtimeStatsText() {
|
||||
return (
|
||||
`encoding: ${(this.encodingTotalTokens / this.encodingTotalTime).toFixed(4)} tokens/sec, ` +
|
||||
`decoding: ${(this.decodingTotalTokens / this.decodingTotalTime).toFixed(4)} tokens/sec`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A instance that can be used to facilitate deployment.
|
||||
*/
|
||||
class LLMChatInstance {
|
||||
constructor() {
|
||||
this.requestInProgress = false;
|
||||
this.config = undefined;
|
||||
this.tvm = undefined;
|
||||
this.pipeline = undefined;
|
||||
this.logger = globalThis.tvmjsGlobalEnv.logger || console.log;
|
||||
this.debugTest = false;
|
||||
this.model_name = globalThis.tvmjsGlobalEnv.modelName || "vicuna-v1-7b-q4f32_0";
|
||||
}
|
||||
|
||||
reboot() {
|
||||
this.config = undefined;
|
||||
this.pipeline = undefined;
|
||||
if (this.tvm !== undefined) {
|
||||
this.tvm.dispose();
|
||||
this.tvm = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize TVM
|
||||
* @param wasmUrl URL to wasm source.
|
||||
* @param cacheUrl URL to NDArray cache.
|
||||
* @param logger Custom logger.
|
||||
*/
|
||||
async #asyncInitTVM(wasmUrl, cacheUrl) {
|
||||
if (this.tvm !== undefined) {
|
||||
return;
|
||||
}
|
||||
this.logger = globalThis.tvmjsGlobalEnv.logger || console.log;
|
||||
|
||||
const wasmSource = await (
|
||||
await fetch(wasmUrl)
|
||||
).arrayBuffer();
|
||||
const tvm = await tvmjs.instantiate(
|
||||
new Uint8Array(wasmSource),
|
||||
new EmccWASI(),
|
||||
this.logger
|
||||
);
|
||||
// intialize WebGPU
|
||||
try {
|
||||
const output = await tvmjs.detectGPUDevice();
|
||||
if (output !== undefined) {
|
||||
var label = "WebGPU";
|
||||
if (output.adapterInfo.description.length != 0) {
|
||||
label += " - " + output.adapterInfo.description;
|
||||
} else {
|
||||
label += " - " + output.adapterInfo.vendor;
|
||||
}
|
||||
this.appendMessage("init", "Initialize GPU device: " + label);
|
||||
tvm.initWebGPU(output.device);
|
||||
} else {
|
||||
this.appendMessage("error", "This browser env do not support WebGPU");
|
||||
this.reset();
|
||||
throw Error("This browser env do not support WebGPU");
|
||||
}
|
||||
} catch (err) {
|
||||
this.appendMessage("error", "Find an error initializing the WebGPU device " + err.toString());
|
||||
this.logger("[error]", err.toString());
|
||||
console.log(err);
|
||||
this.reset();
|
||||
throw Error("Find an error initializing WebGPU: " + err.toString());
|
||||
}
|
||||
this.tvm = tvm;
|
||||
const initProgressCallback = (report) => {
|
||||
this.updateLastMessage("init", report.text);
|
||||
}
|
||||
tvm.registerInitProgressCallback(initProgressCallback);
|
||||
|
||||
await tvm.fetchNDArrayCache(cacheUrl, tvm.webgpu());
|
||||
}
|
||||
/**
|
||||
* Async initialize instance.
|
||||
*/
|
||||
async asyncInit() {
|
||||
if (this.pipeline !== undefined) return;
|
||||
await this.#asyncInitConfig();
|
||||
await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl);
|
||||
await this.#asyncInitPipeline();
|
||||
}
|
||||
|
||||
/**
|
||||
* Async initialize config
|
||||
*/
|
||||
async #asyncInitConfig() {
|
||||
if (this.config !== undefined) return;
|
||||
const global_config = await (await fetch("/Program Files/WebLLM/global_config.json")).json();
|
||||
this.config = await (await fetch(global_config.url_dict[this.model_name])).json();
|
||||
this.config.wasmUrl = global_config.model_lib_map[this.config.model_lib];
|
||||
this.config.cacheUrl = this.config.model_url;
|
||||
this.config.tokenizer = this.config.tokenizer_files[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the pipeline
|
||||
*
|
||||
* @param tokenizerModel The url to tokenizer model.
|
||||
*/
|
||||
async #asyncInitPipeline() {
|
||||
if (this.pipeline !== undefined) return;
|
||||
// initialize UX and tokenizer
|
||||
const tokenizer = await tvmjsGlobalEnv.sentencePieceProcessor(this.config.tokenizer);
|
||||
this.pipeline = this.tvm.withNewScope(() => {
|
||||
return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config);
|
||||
});
|
||||
await this.pipeline.asyncLoadWebGPUPiplines();
|
||||
this.updateLastMessage("init", "All initialization finished.");
|
||||
}
|
||||
|
||||
appendMessage(kind, text) {
|
||||
if (kind == "init") {
|
||||
text = "[System Initalize] " + text;
|
||||
}
|
||||
this.logger(`[${kind}]`, text);
|
||||
}
|
||||
|
||||
updateLastMessage(kind, text) {
|
||||
if (kind == "init") {
|
||||
this.logger("[init]", text);
|
||||
} else if (kind == "left") {
|
||||
globalThis.tvmjsGlobalEnv.response = text;
|
||||
}
|
||||
}
|
||||
|
||||
async respondTestMessage(repeat) {
|
||||
const testMessage = "I am a friendly bot. Please ask questions.";
|
||||
const encodedResult = await this.pipeline.tokenizer.encodeIds(testMessage);
|
||||
|
||||
const currentIds = [];
|
||||
for (let k = 0; k < repeat; ++k) {
|
||||
for (let i = 0; i < encodedResult.length; ++i) {
|
||||
currentIds.push(encodedResult[i]);
|
||||
const msg = this.pipeline.tokenizer.decodeIds(currentIds);
|
||||
this.updateLastMessage("left", msg);
|
||||
await new Promise(resolve => setTimeout(resolve, 50));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resetChat() {
|
||||
this.pipeline.resetChat();
|
||||
}
|
||||
|
||||
/**
|
||||
* Run generate
|
||||
*/
|
||||
async generate() {
|
||||
if (this.requestInProgress) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.requestInProgress = true;
|
||||
|
||||
try {
|
||||
await this.asyncInit();
|
||||
} catch (err) {
|
||||
this.appendMessage("error", "Init error, " + err.toString());
|
||||
this.logger("[error]", err.toString());
|
||||
console.log(err);
|
||||
this.reset();
|
||||
this.requestInProgress = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.debugTest) {
|
||||
await this.pipeline.evaluate();
|
||||
this.requestInProgress = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const prompt = globalThis.tvmjsGlobalEnv.message;
|
||||
if (prompt == "") {
|
||||
this.requestInProgress = false;
|
||||
return;
|
||||
}
|
||||
|
||||
this.appendMessage("progress", "Generating...");
|
||||
const callbackUpdateResponse = (step, msg) => {
|
||||
if (msg.endsWith("##")) {
|
||||
msg = msg.substring(0, msg.length - 2);
|
||||
} else if (msg.endsWith("#")) {
|
||||
msg = msg.substring(0, msg.length - 1);
|
||||
}
|
||||
this.updateLastMessage("left", msg);
|
||||
this.appendMessage("progress", `Generating step: ${step}...`);
|
||||
};
|
||||
try {
|
||||
const output = await this.pipeline.generate(prompt, callbackUpdateResponse);
|
||||
this.updateLastMessage("left", output);
|
||||
console.info(this.pipeline.runtimeStatsText());
|
||||
} catch (err) {
|
||||
this.appendMessage("error", "Generate error, " + err.toString());
|
||||
this.logger("[error]", err.toString());
|
||||
console.log(err);
|
||||
this.reset();
|
||||
}
|
||||
this.requestInProgress = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset the instance;
|
||||
*/
|
||||
reset() {
|
||||
this.tvm = undefined;
|
||||
if (this.pipeline !== undefined) {
|
||||
this.pipeline.dispose();
|
||||
}
|
||||
this.pipeline = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
localLLMChatIntance = new LLMChatInstance();
|
||||
|
||||
tvmjsGlobalEnv.asyncOnGenerate = async function () {
|
||||
await localLLMChatIntance.generate();
|
||||
};
|
||||
|
||||
tvmjsGlobalEnv.asyncOnReset = async function () {
|
||||
await localLLMChatIntance.resetChat();
|
||||
};
|
||||
File diff suppressed because one or more lines are too long
@@ -1,14 +0,0 @@
-{
-  "model_lib": "vicuna-v1-7b-q4f32_0",
-  "model_url": "https://huggingface.co/hongyij/web-llm-test-model/resolve/main/",
-  "maxGenLength": 1024,
-  "meanGenLength": 256,
-  "max_seq_len": 2048,
-  "temperature": 0.7,
-  "top_p": 0.95,
-  "local_id": "vicuna-v1-7b-q4f32_0",
-  "tokenizer_files": [
-    "/Program Files/WebLLM/vicuna-7b/tokenizer.model"
-  ],
-  "conv_template": "vicuna-v1.1"
-}
Binary file not shown.
Binary file not shown.
BIN public/Program Files/WebLLM/vicuna-v1-7b-q4f32_0-webgpu.wasm Normal file
Binary file not shown.
@@ -1402,6 +1402,11 @@
   semver "^7.3.5"
   tar "^6.1.11"

+"@mlc-ai/web-llm@^0.1.1":
+  version "0.1.1"
+  resolved "https://registry.yarnpkg.com/@mlc-ai/web-llm/-/web-llm-0.1.1.tgz#013ca109a7b064362ab265c5011534c2e85cb529"
+  integrity sha512-1SRlSaAPXxeSBRBrzymgTevw8Diy3iVfz/vmw5yKkbALPxOeMwGi3UIIJMlA29XqW8OrYocvCZ2nXdqRFkm+dQ==
+
 "@monaco-editor/loader@^1.3.3":
   version "1.3.3"
   resolved "https://registry.yarnpkg.com/@monaco-editor/loader/-/loader-1.3.3.tgz#7f1742bd3cc21c0362a46a4056317f6e5215cfca"