Use latest WebLLM implementation

Dustin Brett 2023-05-24 22:54:21 -07:00
parent fe3d2cb92a
commit 2cd3676c0e
15 changed files with 104 additions and 804 deletions

View File

@@ -61,7 +61,7 @@ const Chat: FC<ComponentProcessProps> = ({ id }) => {
error: aiError,
name,
resetError,
} = useInference(apiKey, engine);
} = useInference(engine.startsWith("WebLLM") ? engine : apiKey, engine);
const messagesRef = useRef<HTMLUListElement>(null);
const inputRef = useRef<HTMLTextAreaElement>(null);
const [input, setInput] = useState<string>("");
@@ -395,12 +395,19 @@ const Chat: FC<ComponentProcessProps> = ({ id }) => {
? [
{
label: "Set AI Engine",
menu: ["HuggingFace", "OpenAI", "WebLLM"].map((engineName) => ({
menu: [
"HuggingFace",
"OpenAI",
"WebLLM [RedPajama 3B]",
"WebLLM [Vicuna 7B]",
].map((engineName) => ({
action: () => {
setAiApi(
`${engineName}:${
// eslint-disable-next-line no-alert
engineName === "WebLLM" ? "" : prompt("API Key") || ""
engineName.startsWith("WebLLM")
? ""
: prompt("API Key") || ""
}`
);
@@ -414,7 +421,7 @@ const Chat: FC<ComponentProcessProps> = ({ id }) => {
label: engineName,
})),
},
...(["OpenAI", "WebLLM"].includes(name)
...(name === "OpenAI"
? [
{
action: () => {
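
A minimal sketch of what the change above implies (the parseAiApi helper below is hypothetical, not part of this commit): the menu action stores the selection as an "engineName:key" string via setAiApi, and entries that start with "WebLLM" carry the model in the engine name itself, so no API key is prompted for and useInference receives the engine string where a key would otherwise go.

// Hypothetical TypeScript sketch; only the "engineName:key" format and the
// startsWith("WebLLM") check come from the diff above.
type AiApi = { apiKey: string; engine: string };

const parseAiApi = (aiApi: string): AiApi => {
  // e.g. "OpenAI:sk-..." or "WebLLM [Vicuna 7B]:" (WebLLM entries store no key)
  const [engine = "", ...rest] = aiApi.split(":");
  return { apiKey: rest.join(":"), engine };
};

const { apiKey, engine } = parseAiApi("WebLLM [Vicuna 7B]:");
// WebLLM engines pass their own name in place of an API key, since the model
// is selected by engine name rather than authenticated remotely.
const keyArgument = engine.startsWith("WebLLM") ? engine : apiKey;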

View File

@@ -6,7 +6,7 @@ type WorkerMessage = { data: Log | string };
declare global {
interface Window {
webLLM?: Worker;
webLLM?: Record<string, Worker>;
}
}
@@ -16,6 +16,8 @@ const DEFAULT_GREETING = {
} as Message;
export class WebLLM implements Engine {
private model = "";
private worker?: Worker = undefined;
private isChatting = false;
@@ -26,15 +28,20 @@ export class WebLLM implements Engine {
this.reset();
}
public constructor(model: string) {
this.model = model;
}
public async init(): Promise<void> {
window.webLLM =
window.webLLM ||
window.webLLM = window.webLLM || {};
window.webLLM[this.model] =
window.webLLM[this.model] ||
new Worker(
new URL("hooks/useInference/WebLLM.worker.ts", import.meta.url),
{ name: "WebLLM" }
{ name: this.model, type: "module" }
);
this.worker = window.webLLM;
this.worker.postMessage({ type: "init" });
this.worker = window.webLLM[this.model];
this.worker.postMessage({ model: this.model, type: "init" });
// eslint-disable-next-line unicorn/no-useless-promise-resolve-reject
return Promise.resolve();
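
A consolidated sketch of the worker bookkeeping in this hunk (the getWebLLMWorker helper is hypothetical): window.webLLM becomes a map from model name to Worker, each worker is created lazily as a module worker so it can import @mlc-ai/web-llm, and the init message tells it which model to load.

declare global {
  interface Window {
    webLLM?: Record<string, Worker>;
  }
}

// Lazily create and cache one Worker per WebLLM model.
export const getWebLLMWorker = (model: string): Worker => {
  window.webLLM = window.webLLM || {};
  window.webLLM[model] =
    window.webLLM[model] ||
    new Worker(
      new URL("hooks/useInference/WebLLM.worker.ts", import.meta.url),
      // type: "module" allows ES imports (ChatModule) inside the worker
      { name: model, type: "module" }
    );
  return window.webLLM[model];
};

// The engine then asks that worker to initialize its model:
// getWebLLMWorker(model).postMessage({ model, type: "init" });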

View File

@@ -1,46 +1,79 @@
const libs = [
"/System/tvm/tvmjs_runtime.wasi.js",
"/System/tvm/tvmjs.bundle.js",
"/Program Files/WebLLM/llm_chat.js",
"/Program Files/WebLLM/sentencepiece.js",
];
import type { InitProgressReport } from "@mlc-ai/web-llm";
import { ChatModule } from "@mlc-ai/web-llm";
const runLLM = async (message: string): Promise<string> => {
globalThis.tvmjsGlobalEnv.message = message;
type Data = { model?: string; prompt?: string; type: string };
await globalThis.tvmjsGlobalEnv.asyncOnGenerate();
const CACHE_WARNING =
"It can take a while when we first visit this page to populate the cache. Later refreshes will become faster.";
const DEFAULT_MODEL = "RedPajama 3B";
return globalThis.tvmjsGlobalEnv.response;
const configMap: Record<string, string> = {
"[RedPajama 3B]": "RedPajama-INCITE-Chat-3B-v1-q4f32_0",
"[Vicuna 7B]": "vicuna-v1-7b-q4f32_0",
};
const config = {
model_lib_map: {
"RedPajama-INCITE-Chat-3B-v1-q4f32_0":
"/Program Files/WebLLM/RedPajama-INCITE-Chat-3B-v1-q4f32_0-webgpu.wasm",
"vicuna-v1-7b-q4f32_0":
"/Program Files/WebLLM/vicuna-v1-7b-q4f32_0-webgpu.wasm",
},
model_list: [
{
local_id: "RedPajama-INCITE-Chat-3B-v1-q4f32_0",
model_url:
"https://huggingface.co/mlc-ai/mlc-chat-RedPajama-INCITE-Chat-3B-v1-q4f32_0/resolve/main/",
},
{
local_id: "vicuna-v1-7b-q4f32_0",
model_url:
"https://huggingface.co/mlc-ai/mlc-chat-vicuna-v1-7b-q4f32_0/resolve/main/",
},
],
};
const generateProgressCallback = (step: number): void => {
globalThis.postMessage({
message: `Generating (Step ${step})`,
type: "[progress]",
});
};
const initProgressCallback = (report: InitProgressReport): void => {
globalThis.postMessage({
message: report.text.replace(CACHE_WARNING, ""),
type: "[init]",
});
};
let initalized = false;
let startedChat = false;
let chatModule: ChatModule;
let chatModel: string;
globalThis.addEventListener(
"message",
({ data }: { data: { prompt?: string; type: string } }) => {
async ({ data }: { data: Data }) => {
if (!initalized && data.type === "init") {
initalized = true;
globalThis.tvmjsGlobalEnv = globalThis.tvmjsGlobalEnv || {};
globalThis.tvmjsGlobalEnv.logger = (type: string, message: string) =>
globalThis.postMessage({ message, type });
globalThis.importScripts(...libs);
globalThis.tvmjsGlobalEnv.sentencePieceProcessor = (url: string) =>
globalThis.sentencepiece.sentencePieceProcessor(url);
chatModel =
configMap[data.model?.replace("WebLLM ", "") || DEFAULT_MODEL];
chatModule = new ChatModule();
chatModule.setInitProgressCallback(initProgressCallback);
} else if (data.type === "reset") {
globalThis.tvmjsGlobalEnv.response = "";
globalThis.tvmjsGlobalEnv.message = "";
globalThis.tvmjsGlobalEnv.systemPrompt = "";
globalThis.tvmjsGlobalEnv.asyncOnReset();
} else if (data.prompt) {
if (data.type === "system") {
globalThis.tvmjsGlobalEnv.systemPrompt = data.prompt;
} else if (data.type === "chat") {
runLLM(data.prompt).then(globalThis.postMessage);
await chatModule.interruptGenerate();
await chatModule.resetChat();
} else if (data.prompt && data.type === "chat") {
// TODO: Support changing system prompt
if (!startedChat) {
await chatModule.reload(chatModel, undefined, config);
startedChat = true;
}
chatModule
.generate(data.prompt, generateProgressCallback)
.then(globalThis.postMessage);
}
},
{ passive: true }
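
Stripped of the worker plumbing, the core of the new implementation is @mlc-ai/web-llm's ChatModule. A minimal sketch of the same reload/generate flow, reduced to the RedPajama model and reusing the self-hosted wasm path and weight URL from the config above:

import type { InitProgressReport } from "@mlc-ai/web-llm";
import { ChatModule } from "@mlc-ai/web-llm";

const appConfig = {
  model_lib_map: {
    "RedPajama-INCITE-Chat-3B-v1-q4f32_0":
      "/Program Files/WebLLM/RedPajama-INCITE-Chat-3B-v1-q4f32_0-webgpu.wasm",
  },
  model_list: [
    {
      local_id: "RedPajama-INCITE-Chat-3B-v1-q4f32_0",
      model_url:
        "https://huggingface.co/mlc-ai/mlc-chat-RedPajama-INCITE-Chat-3B-v1-q4f32_0/resolve/main/",
    },
  ],
};

const ask = async (prompt: string): Promise<string> => {
  const chat = new ChatModule();
  chat.setInitProgressCallback((report: InitProgressReport) =>
    console.log(report.text)
  );
  // reload() fetches and caches the weights on first use, then compiles the
  // WebGPU pipeline named in model_lib_map.
  await chat.reload("RedPajama-INCITE-Chat-3B-v1-q4f32_0", undefined, appConfig);
  // generate() reports progress per decoding step and resolves with the reply.
  return chat.generate(prompt, (step) =>
    console.log(`Generating (Step ${step})`)
  );
};

In the worker itself the resolved reply is simply posted back to the main thread, and a "reset" message interrupts generation and resets the ChatModule rather than tearing the worker down.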

View File

@@ -38,7 +38,12 @@ type Inference = {
resetError: () => void;
};
const Engines = { HuggingFace, OpenAI, WebLLM } as Record<string, EngineClass>;
const Engines = {
HuggingFace,
OpenAI,
"WebLLM [RedPajama 3B]": WebLLM,
"WebLLM [Vicuna 7B]": WebLLM,
} as Record<string, EngineClass>;
export const useInference = (apiKey = "", engine = ""): Inference => {
const [error, setError] = useState<number>(0);
@@ -52,10 +57,10 @@ export const useInference = (apiKey = "", engine = ""): Inference => {
if (engine && engine in Engines) {
currentEngine =
engine === "WebLLM" && !hasWebGPU
engine.startsWith("WebLLM") && !hasWebGPU
? DEFAULT_NON_WEBGPU_ENGINE
: engine;
} else if (currentEngine === "WebLLM" && !hasWebGPU) {
} else if (currentEngine.startsWith("WebLLM") && !hasWebGPU) {
currentEngine = DEFAULT_NON_WEBGPU_ENGINE;
}
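
The hasWebGPU flag and DEFAULT_NON_WEBGPU_ENGINE are defined outside this hunk. As a hedged sketch, assuming the flag comes from a navigator.gpu adapter check and that the fallback constant names one of the hosted engines:

// Sketch only: detectWebGPU and the fallback value are assumptions, not code
// from this commit.
const detectWebGPU = async (): Promise<boolean> => {
  // navigator.gpu is only typed when @webgpu/types is installed; treat it loosely.
  const gpu = (navigator as { gpu?: { requestAdapter(): Promise<unknown> } }).gpu;
  if (!gpu) return false;
  try {
    // requestAdapter() resolves to null when no suitable adapter is available.
    return (await gpu.requestAdapter()) !== null;
  } catch {
    return false;
  }
};

const DEFAULT_NON_WEBGPU_ENGINE = "HuggingFace"; // assumed value

// Any engine whose name starts with "WebLLM" requires WebGPU; otherwise fall back.
const resolveEngine = (engine: string, hasWebGPU: boolean): string =>
  engine.startsWith("WebLLM") && !hasWebGPU ? DEFAULT_NON_WEBGPU_ENGINE : engine;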

View File

@@ -49,6 +49,10 @@ const nextConfig = {
}
})
);
config.resolve.fallback = {
module: false,
perf_hooks: false,
};
return config;
},
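
Presumably these fallbacks are needed because @mlc-ai/web-llm's bundled TVM runtime references the Node-only module and perf_hooks packages, which have no browser equivalent; mapping them to false makes webpack substitute empty modules. In isolation the tweak looks roughly like this sketch (not the project's full next.config.js):

const nextConfig = {
  webpack: (config) => {
    // Stub out Node built-ins pulled in by @mlc-ai/web-llm so the browser
    // bundle still resolves.
    config.resolve.fallback = {
      module: false,
      perf_hooks: false,
    };
    return config;
  },
};

module.exports = nextConfig;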

View File

@@ -40,6 +40,7 @@
"*.{js,ts,tsx}": "eslint --fix"
},
"dependencies": {
"@mlc-ai/web-llm": "^0.1.1",
"@monaco-editor/react": "^4.5.1",
"@panzoom/panzoom": "^4.5.1",
"@prettier/plugin-xml": "^2.2.0",

View File

@@ -1,8 +0,0 @@
{
"url_dict":{
"vicuna-v1-7b-q4f32_0": "/Program Files/WebLLM/vicuna-7b/model_config.json"
},
"model_lib_map":{
"vicuna-v1-7b-q4f32_0": "/Program Files/WebLLM/vicuna-7b/vicuna-v1-7b-q4f32_0-webgpu.wasm"
}
}

View File

@@ -1,648 +0,0 @@
/**
* Helper to keep track of history conversations.
*/
class Conversation {
constructor(config) {
this.system = config.system;
this.roles = config.roles;
this.offset = config.offset;
this.seps = config.seps;
this.convId = null;
this.messages = [];
this.contextWindowStart = 0;
}
/**
* Get prompt arrays with the first one as system.
*
* @returns The prompt array.
*/
getPromptArray() {
if (this.seps.length == 0) {
throw Error("Need seps to work")
}
let ret = [this.system + this.seps[0]];
for (let i = 0; i < this.messages.length; ++i) {
const item = this.messages[i];
const role = item[0];
const message = item[1];
if (message !== undefined && message != "") {
ret.push(role + ": " + message + this.seps[i % this.seps.length]);
} else {
ret.push(role + ":");
}
}
return ret;
}
/**
* Get prompt arrays that has not been fed as input
*
* @returns The prompt array.
*/
getPromptArrayUnproccessed() {
if (this.seps.length == 0) {
throw Error("Need seps to work")
}
if (this.messages.length < 3) {
throw Error("needs to call getLastPromptArray for the first message");
}
let ret = [this.seps[this.seps.length - 1]];
for (let i = this.messages.length - 2; i < this.messages.length; ++i) {
const item = this.messages[i];
const role = item[0];
const message = item[1];
if (message !== undefined && message != "") {
ret.push(role + ": " + message + this.seps[i % this.seps.length]);
} else {
ret.push(role + ":");
}
}
return ret;
}
/**
* Get last prompt array with prefix as system.
*
* @returns The prompt array.
*/
getLastPromptArray() {
if (this.seps.length == 0) {
throw Error("Need seps to work")
}
let ret = [this.system + this.seps[0]];
for (let i = this.messages.length - 2; i < this.messages.length; ++i) {
const item = this.messages[i];
const role = item[0];
const message = item[1];
if (message !== undefined && message != "") {
ret.push(role + ": " + message + this.seps[i % this.seps.length]);
} else {
ret.push(role + ":");
}
}
return ret;
}
reset() {
this.messages = [];
}
getStopStr() {
return this.seps[this.seps.length - 1];
}
appendMessage(role, message) {
this.messages.push([role, message]);
}
}
function getConversation(conv_template, maxWindowLength = 512) {
if (conv_template == "vicuna-v1.1") {
return new Conversation({
system: globalThis.tvmjsGlobalEnv.systemPrompt || "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles: ["USER", "ASSISTANT"],
maxWindowLength: maxWindowLength,
messages: [],
offset: 0,
seps: [" ", "</s>"],
});
} else if (conv_template == "wizardlm") {
return new Conversation({
system: globalThis.tvmjsGlobalEnv.systemPrompt || "You are an AI assistant that gives helpful, detailed, and polite answers to the user's questions.",
roles: ["", "### Response"],
maxWindowLength: maxWindowLength,
messages: [],
offset: 0,
seps: ["\n\n", "</s>"],
})
} else {
throw Error("Unknown model "+ model);
}
};
class LLMChatPipeline {
constructor(tvm, tokenizer, cacheMetadata, config) {
if (cacheMetadata == undefined) {
throw Error("Expect cacheMetadata");
}
this.tvm = tvm;
this.logger = globalThis.tvmjsGlobalEnv.logger || console.log;
this.tokenizer = tokenizer;
this.bosTokenId = 1;
this.eosTokenId = 2;
this.temperature = config.temperature;
this.top_p = config.top_p;
this.maxWindowLength = config.max_seq_len;
this.maxGenLength = config.maxGenLength;
this.meanGenLength = config.meanGenLength;
this.streamInterval = 1;
this.decodingTotalTime = 0;
this.decodingTotalTokens = 0;
this.encodingTotalTime = 0;
this.encodingTotalTokens = 0;
this.conversation = getConversation(config.conv_template, this.maxWindowLength);
this.device = this.tvm.webgpu();
this.vm = this.tvm.detachFromCurrentScope(
this.tvm.createVirtualMachine(this.device)
);
this.encoding = this.tvm.detachFromCurrentScope(
this.vm.getFunction("encoding")
);
this.decoding = this.tvm.detachFromCurrentScope(
this.vm.getFunction("decoding")
);
this.params = this.tvm.detachFromCurrentScope(
this.tvm.getParamsFromCache("param", cacheMetadata.ParamSize)
);
const fcreateCache = this.vm.getFunction("create_kv_cache");
this.fclearKVCaches = this.tvm.detachFromCurrentScope(
this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
);
// use extern config for now
this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
// fill with pad token
this.logitsOnCPU = undefined;
this.kvCacheLength = 0;
this.clearCache = true
}
dispose() {
// note: tvm instance is not owned by this class
this.params.dispose();
this.decoding.dispose();
this.encoding.dispose();
this.vm.dispose();
this.kvCache.dispose();
this.fclearKVCaches.dispose();
if (this.logitsOnCPU != undefined) {
this.logitsOnCPU.dispose();
}
}
#clearKVCache() {
this.fclearKVCaches(this.kvCache);
this.kvCacheLength = 0;
}
#forward(inputs, curPos) {
this.tvm.beginScope();
var retValue;
const seqLenShape = this.tvm.makeShapeTuple([curPos]);
if (inputs.shape[1] > 1) {
retValue = this.encoding(
inputs, seqLenShape, this.kvCache, this.params
);
} else {
retValue = this.decoding(
inputs, seqLenShape, this.kvCache, this.params
);
}
const logits = this.tvm.detachFromCurrentScope(retValue.get(0));
this.tvm.endScope();
this.tvm.attachToCurrentScope(logits);
return logits;
}
// NOTE: caller must call device.sync()
#updateLogitsOnCPU(logits) {
if (this.logitsOnCPU == undefined) {
this.logitsOnCPU = this.tvm.detachFromCurrentScope(
this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu())
);
} else {
if (logits.shape[0] != this.logitsOnCPU.shape[0]) {
throw Error("We expect the size of logits to remain unchanged");
}
}
this.logitsOnCPU.copyFrom(logits);
}
async sampleTokenFromLogits(logits, temperature = 0.8, top_p = 0.95) {
this.tvm.beginScope();
this.#updateLogitsOnCPU(logits);
this.tvm.endScope();
await this.device.sync();
return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
}
async getInputTokens() {
let tokens = [this.bosTokenId];
let prompts = ""
if (this.conversation.messages.length <= 2) {
prompts = this.conversation.getPromptArray();
} else {
tokens.pop();
prompts = this.conversation.getPromptArrayUnproccessed();
}
tokens.push(...await this.tokenizer.encodeIds(prompts[0]));
let ctxLength = tokens.length;
let context = [];
let need_shift_window = false;
for (let i = prompts.length - 1; i > 0; --i) {
const encoded = this.tokenizer.encodeIds(prompts[i]);
ctxLength += encoded.length;
if (this.kvCacheLength + ctxLength + this.meanGenLength >= this.maxWindowLength) {
need_shift_window = true;
break;
}
context.unshift(encoded);
}
if (!need_shift_window) {
for (const ctx of context) {
tokens.push(...ctx);
}
return tokens;
}
// need shift window and re-encode
this.kvCacheLength = 0;
this.clearCache = true;
// abandon all tokens we collected
tokens = [this.bosTokenId]
let all_prompts = this.conversation.getPromptArray();
tokens.push(...await this.tokenizer.encodeIds(all_prompts[0]));
context = [];
ctxLength = tokens.length;
//only keep 10% of the window context
const fill_factor = 0.1
for (let i = all_prompts.length - 1; i > 0; --i) {
const encoded = this.tokenizer.encodeIds(all_prompts[i]);
ctxLength += encoded.length;
if (ctxLength >= fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) {
break;
}
context.unshift(encoded);
}
for (const ctx of context) {
tokens.push(...ctx);
}
if (tokens.length + this.meanGenLength >= this.maxWindowLength) {
throw Error("Exceed max window length curr=" + tokens.length);
}
return tokens;
}
resetChat() {
this.conversation.reset();
this.#clearKVCache();
this.decodingTotalTime = 0;
this.encodingTotalTime = 0;
this.decodingTotalTokens = 0;
this.encodingTotalTokens = 0;
}
async generate(inputPrompt, callbackUpdateResponse) {
this.conversation.appendMessage(this.conversation.roles[0], inputPrompt);
this.conversation.appendMessage(this.conversation.roles[1], "");
const stopStr = this.conversation.getStopStr();
const tokens = await this.getInputTokens();
const inputTokenLength = tokens.length;
var outputPrompt = "";
if (this.clearCache) {
this.#clearKVCache();
this.clearCache = false;
}
const maxGenLen = Math.min(this.maxGenLength, this.maxWindowLength - tokens.length);
if (maxGenLen < this.meanGenLength) {
throw Error("Too small window size config");
}
let step = 0;
for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
this.tvm.beginScope();
var inputData;
let tstart = performance.now();
if (step == 0) {
inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
inputData.copyFrom(tokens);
} else {
inputData = this.tvm.empty([1, 1], "int32", this.device);
inputData.copyFrom(tokens.slice(tokens.length - 1));
}
const logits = this.tvm.detachFromCurrentScope(
this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
);
this.tvm.endScope();
const nextToken = await this.sampleTokenFromLogits(logits, this.temperature, this.top_p);
logits.dispose();
tokens.push(nextToken);
const outputTokens = tokens.slice(inputTokenLength);
outputPrompt = this.tokenizer.decodeIds(outputTokens);
if (nextToken == this.eosTokenId) break;
const stopPos = outputPrompt.lastIndexOf(stopStr);
if (stopPos != -1) {
outputPrompt = outputPrompt.substring(0, stopPos);
break;
}
let tend = performance.now();
if (step != 0) {
this.decodingTotalTokens += 1;
this.decodingTotalTime += (tend - tstart) / 1000;
} else {
this.encodingTotalTime += (tend - tstart) / 1000;
this.encodingTotalTokens += inputTokenLength;
}
if (step % this.streamInterval == 0) {
callbackUpdateResponse(step, outputPrompt);
}
}
this.kvCacheLength += tokens.length - 1;
this.conversation.messages[this.conversation.messages.length - 1][1] = outputPrompt;
return outputPrompt;
}
async evaluate() {
// run a canonical evaluation of the flow
this.#clearKVCache();
const testPrompt = "The capital of Canada is";
const ids = await this.tokenizer.encodeIds(testPrompt);
const inputPromptSize = ids.length;
const tokens = Array.from(ids);
tokens.unshift(this.bosTokenId);
if (tokens.length == 0) {
throw Error("empty token");
}
this.tvm.beginScope();
const inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
inputData.copyFrom(tokens);
const encodingStart = performance.now();
this.#forward(inputData, tokens.length);
this.tvm.endScope();
await this.device.sync();
const decodingStart = performance.now();
this.tvm.beginScope();
const firstSampleToken = this.tvm.empty([1, 1], "int32", this.device).copyFrom([6234]);
this.#updateLogitsOnCPU(this.#forward(firstSampleToken, tokens.length + 1));
await this.device.sync();
this.tvm.endScope();
const decodingEnd = performance.now();
const msg = (
`encoding-time=${((decodingStart - encodingStart) / 1000).toFixed(4)} sec` +
`decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`
);
// simply log tokens for eyeballing.
console.log("Logits:");
console.log(this.logitsOnCPU.toArray());
console.log(msg);
}
/**
* async preload webgpu pipelines when possible.
*/
async asyncLoadWebGPUPiplines() {
await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
}
runtimeStatsText() {
return (
`encoding: ${(this.encodingTotalTokens / this.encodingTotalTime).toFixed(4)} tokens/sec, ` +
`decoding: ${(this.decodingTotalTokens / this.decodingTotalTime).toFixed(4)} tokens/sec`
)
}
}
/**
* A instance that can be used to facilitate deployment.
*/
class LLMChatInstance {
constructor() {
this.requestInProgress = false;
this.config = undefined;
this.tvm = undefined;
this.pipeline = undefined;
this.logger = globalThis.tvmjsGlobalEnv.logger || console.log;
this.debugTest = false;
this.model_name = globalThis.tvmjsGlobalEnv.modelName || "vicuna-v1-7b-q4f32_0";
}
reboot() {
this.config = undefined;
this.pipeline = undefined;
if (this.tvm !== undefined) {
this.tvm.dispose();
this.tvm = undefined;
}
}
/**
* Initialize TVM
* @param wasmUrl URL to wasm source.
* @param cacheUrl URL to NDArray cache.
* @param logger Custom logger.
*/
async #asyncInitTVM(wasmUrl, cacheUrl) {
if (this.tvm !== undefined) {
return;
}
this.logger = globalThis.tvmjsGlobalEnv.logger || console.log;
const wasmSource = await (
await fetch(wasmUrl)
).arrayBuffer();
const tvm = await tvmjs.instantiate(
new Uint8Array(wasmSource),
new EmccWASI(),
this.logger
);
// intialize WebGPU
try {
const output = await tvmjs.detectGPUDevice();
if (output !== undefined) {
var label = "WebGPU";
if (output.adapterInfo.description.length != 0) {
label += " - " + output.adapterInfo.description;
} else {
label += " - " + output.adapterInfo.vendor;
}
this.appendMessage("init", "Initialize GPU device: " + label);
tvm.initWebGPU(output.device);
} else {
this.appendMessage("error", "This browser env do not support WebGPU");
this.reset();
throw Error("This browser env do not support WebGPU");
}
} catch (err) {
this.appendMessage("error", "Find an error initializing the WebGPU device " + err.toString());
this.logger("[error]", err.toString());
console.log(err);
this.reset();
throw Error("Find an error initializing WebGPU: " + err.toString());
}
this.tvm = tvm;
const initProgressCallback = (report) => {
this.updateLastMessage("init", report.text);
}
tvm.registerInitProgressCallback(initProgressCallback);
await tvm.fetchNDArrayCache(cacheUrl, tvm.webgpu());
}
/**
* Async initialize instance.
*/
async asyncInit() {
if (this.pipeline !== undefined) return;
await this.#asyncInitConfig();
await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl);
await this.#asyncInitPipeline();
}
/**
* Async initialize config
*/
async #asyncInitConfig() {
if (this.config !== undefined) return;
const global_config = await (await fetch("/Program Files/WebLLM/global_config.json")).json();
this.config = await (await fetch(global_config.url_dict[this.model_name])).json();
this.config.wasmUrl = global_config.model_lib_map[this.config.model_lib];
this.config.cacheUrl = this.config.model_url;
this.config.tokenizer = this.config.tokenizer_files[0];
}
/**
* Initialize the pipeline
*
* @param tokenizerModel The url to tokenizer model.
*/
async #asyncInitPipeline() {
if (this.pipeline !== undefined) return;
// initialize UX and tokenizer
const tokenizer = await tvmjsGlobalEnv.sentencePieceProcessor(this.config.tokenizer);
this.pipeline = this.tvm.withNewScope(() => {
return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config);
});
await this.pipeline.asyncLoadWebGPUPiplines();
this.updateLastMessage("init", "All initialization finished.");
}
appendMessage(kind, text) {
if (kind == "init") {
text = "[System Initalize] " + text;
}
this.logger(`[${kind}]`, text);
}
updateLastMessage(kind, text) {
if (kind == "init") {
this.logger("[init]", text);
} else if (kind == "left") {
globalThis.tvmjsGlobalEnv.response = text;
}
}
async respondTestMessage(repeat) {
const testMessage = "I am a friendly bot. Please ask questions.";
const encodedResult = await this.pipeline.tokenizer.encodeIds(testMessage);
const currentIds = [];
for (let k = 0; k < repeat; ++k) {
for (let i = 0; i < encodedResult.length; ++i) {
currentIds.push(encodedResult[i]);
const msg = this.pipeline.tokenizer.decodeIds(currentIds);
this.updateLastMessage("left", msg);
await new Promise(resolve => setTimeout(resolve, 50));
}
}
}
resetChat() {
this.pipeline.resetChat();
}
/**
* Run generate
*/
async generate() {
if (this.requestInProgress) {
return;
}
this.requestInProgress = true;
try {
await this.asyncInit();
} catch (err) {
this.appendMessage("error", "Init error, " + err.toString());
this.logger("[error]", err.toString());
console.log(err);
this.reset();
this.requestInProgress = false;
return;
}
if (this.debugTest) {
await this.pipeline.evaluate();
this.requestInProgress = false;
return;
}
const prompt = globalThis.tvmjsGlobalEnv.message;
if (prompt == "") {
this.requestInProgress = false;
return;
}
this.appendMessage("progress", "Generating...");
const callbackUpdateResponse = (step, msg) => {
if (msg.endsWith("##")) {
msg = msg.substring(0, msg.length - 2);
} else if (msg.endsWith("#")) {
msg = msg.substring(0, msg.length - 1);
}
this.updateLastMessage("left", msg);
this.appendMessage("progress", `Generating step: ${step}...`);
};
try {
const output = await this.pipeline.generate(prompt, callbackUpdateResponse);
this.updateLastMessage("left", output);
console.info(this.pipeline.runtimeStatsText());
} catch (err) {
this.appendMessage("error", "Generate error, " + err.toString());
this.logger("[error]", err.toString());
console.log(err);
this.reset();
}
this.requestInProgress = false;
}
/**
* Reset the instance;
*/
reset() {
this.tvm = undefined;
if (this.pipeline !== undefined) {
this.pipeline.dispose();
}
this.pipeline = undefined;
}
}
localLLMChatIntance = new LLMChatInstance();
tvmjsGlobalEnv.asyncOnGenerate = async function () {
await localLLMChatIntance.generate();
};
tvmjsGlobalEnv.asyncOnReset = async function () {
await localLLMChatIntance.resetChat();
};

File diff suppressed because one or more lines are too long

View File

@@ -1,14 +0,0 @@
{
"model_lib": "vicuna-v1-7b-q4f32_0",
"model_url": "https://huggingface.co/hongyij/web-llm-test-model/resolve/main/",
"maxGenLength": 1024,
"meanGenLength": 256,
"max_seq_len": 2048,
"temperature": 0.7,
"top_p": 0.95,
"local_id": "vicuna-v1-7b-q4f32_0",
"tokenizer_files": [
"/Program Files/WebLLM/vicuna-7b/tokenizer.model"
],
"conv_template": "vicuna-v1.1"
}

View File

@@ -1402,6 +1402,11 @@
semver "^7.3.5"
tar "^6.1.11"
"@mlc-ai/web-llm@^0.1.1":
version "0.1.1"
resolved "https://registry.yarnpkg.com/@mlc-ai/web-llm/-/web-llm-0.1.1.tgz#013ca109a7b064362ab265c5011534c2e85cb529"
integrity sha512-1SRlSaAPXxeSBRBrzymgTevw8Diy3iVfz/vmw5yKkbALPxOeMwGi3UIIJMlA29XqW8OrYocvCZ2nXdqRFkm+dQ==
"@monaco-editor/loader@^1.3.3":
version "1.3.3"
resolved "https://registry.yarnpkg.com/@monaco-editor/loader/-/loader-1.3.3.tgz#7f1742bd3cc21c0362a46a4056317f6e5215cfca"