Added model path

Added the path to download the model.
PromptEngineer 2023-09-17 01:41:41 -07:00
parent 9577b4abd1
commit 4f9fb00b3a
5 changed files with 55 additions and 45 deletions

Binary file not shown.

Binary file not shown.

constants.py

@@ -14,6 +14,8 @@ SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/SOURCE_DOCUMENTS"
 PERSIST_DIRECTORY = f"{ROOT_DIRECTORY}/DB"
 
+MODELS_PATH = "./models"
+
 # Can be changed to a specific number
 INGEST_THREADS = os.cpu_count() or 8
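
For orientation, a minimal sketch of how a shared constant like MODELS_PATH gets consumed; this mirrors the hf_hub_download change in load_models.py below, and the repo and file names are illustrative placeholders, not values from this commit:

    # sketch: download a model file into the shared cache directory
    from huggingface_hub import hf_hub_download

    from constants import MODELS_PATH

    model_path = hf_hub_download(
        repo_id="TheBloke/Llama-2-7B-Chat-GGUF",  # placeholder repo_id
        filename="llama-2-7b-chat.Q4_K_M.gguf",   # placeholder filename
        resume_download=True,
        cache_dir=MODELS_PATH,  # downloads now land under ./models
    )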

load_models.py

@@ -9,7 +9,7 @@ from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
 )
-from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH
+from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH, MODELS_PATH
 
 
 def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):
@@ -40,7 +40,7 @@ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):
             repo_id=model_id,
             filename=model_basename,
             resume_download=True,
-            cache_dir="./models",
+            cache_dir=MODELS_PATH,
         )
         kwargs = {
             "model_path": model_path,
@@ -140,7 +140,7 @@ def load_full_model(model_id, model_basename, device_type, logging):
             device_map="auto",
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
-            cache_dir="./models/",
+            cache_dir=MODELS_PATH,
             # trust_remote_code=True, # set these if you are using NVIDIA GPU
             # load_in_4bit=True,
             # bnb_4bit_quant_type="nf4",
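
The same constant now replaces both hard-coded cache paths. As a quick check that the two call sites agree, here is a minimal sketch of the full-model load with the new constant; the model name is an illustrative placeholder, and cache_dir is a standard argument of transformers' from_pretrained:

    # sketch: full-precision model load pointing at the shared cache
    import torch
    from transformers import AutoModelForCausalLM

    from constants import MODELS_PATH

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # placeholder model_id
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        cache_dir=MODELS_PATH,  # previously the hard-coded "./models/"
    )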

run_localGPT.py

@@ -5,29 +5,35 @@ import torch
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.llms import HuggingFacePipeline
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler  # for streaming response
 from langchain.callbacks.manager import CallbackManager
 
 callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
 from prompt_template_utils import get_prompt_template
 
 # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
-from transformers import (GenerationConfig,
-                          pipeline,
-                          )
+from transformers import (
+    GenerationConfig,
+    pipeline,
+)
 
-from load_models import (load_quantized_model_gguf_ggml,
-                         load_quantized_model_qptq,
-                         load_full_model,
-                         )
+from load_models import (
+    load_quantized_model_gguf_ggml,
+    load_quantized_model_qptq,
+    load_full_model,
+)
 
-from constants import (EMBEDDING_MODEL_NAME,
-                       PERSIST_DIRECTORY,
-                       MODEL_ID,
-                       MODEL_BASENAME,
-                       MAX_NEW_TOKENS
-                       )
+from constants import (
+    EMBEDDING_MODEL_NAME,
+    PERSIST_DIRECTORY,
+    MODEL_ID,
+    MODEL_BASENAME,
+    MAX_NEW_TOKENS,
+    MODELS_PATH,
+)
 
 
 def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
     """
@@ -89,7 +95,7 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     """
     Initializes and returns a retrieval-based Question Answering (QA) pipeline.
 
     This function sets up a QA system that retrieves relevant information using embeddings
     from the HuggingFace library. It then answers questions based on the retrieved information.
 
     Parameters:
@@ -108,40 +114,43 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
     """
-    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME,
-                                               model_kwargs={"device": device_type})
+    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
     # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
     # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
 
     # load the vectorstore
-    db = Chroma(persist_directory=PERSIST_DIRECTORY,
-                embedding_function=embeddings,)
+    db = Chroma(
+        persist_directory=PERSIST_DIRECTORY,
+        embedding_function=embeddings,
+    )
     retriever = db.as_retriever()
 
     # get the prompt template and memory if set by the user.
-    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type,
-                                         history=use_history)
+    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type, history=use_history)
 
     # load the llm pipeline
-    llm = load_model(device_type,
-                     model_id=MODEL_ID,
-                     model_basename=MODEL_BASENAME,
-                     LOGGING=logging)
+    llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME, LOGGING=logging)
 
     if use_history:
-        qa = RetrievalQA.from_chain_type(llm=llm,
-                                         chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
-                                         retriever=retriever,
-                                         return_source_documents=True,  # verbose=True,
-                                         callbacks=callback_manager,
-                                         chain_type_kwargs={"prompt": prompt, "memory": memory},)
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
+            retriever=retriever,
+            return_source_documents=True,  # verbose=True,
+            callbacks=callback_manager,
+            chain_type_kwargs={"prompt": prompt, "memory": memory},
+        )
     else:
-        qa = RetrievalQA.from_chain_type(llm=llm,
-                                         chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
-                                         retriever=retriever,
-                                         return_source_documents=True,  # verbose=True,
-                                         callbacks=callback_manager,
-                                         chain_type_kwargs={"prompt": prompt,},)
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
+            retriever=retriever,
+            return_source_documents=True,  # verbose=True,
+            callbacks=callback_manager,
+            chain_type_kwargs={
+                "prompt": prompt,
+            },
+        )
 
     return qa
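
The reformatted block above is where the chain is assembled; for context, a quick usage sketch of what retrieval_qa_pipline returns (the query string is illustrative; RetrievalQA chains from this LangChain era are invoked as callables and return a dict):

    # sketch: querying the chain returned by retrieval_qa_pipline
    qa = retrieval_qa_pipline("cuda", use_history=False, promptTemplate_type="llama")
    res = qa("What does MODELS_PATH point to?")  # illustrative query
    answer, docs = res["result"], res["source_documents"]
    print(answer)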
@@ -214,13 +223,12 @@ def main(device_type, show_sources, use_history):
     logging.info(f"Use history set to: {use_history}")
 
     # check if models directory do not exist, create a new one and store models here.
-    if not os.path.exists("./models"):
-        os.mkdir("models")
+    if not os.path.exists(MODELS_PATH):
+        os.mkdir(MODELS_PATH)
 
     qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama")
     # Interactive questions and answers
     while True:
         query = input("\nEnter a query: ")
         if query == "exit":
             break
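
A side note on the directory check in this last hunk: switching the literals to MODELS_PATH also tidies up the old code, which tested "./models" but created "models" (the same directory, spelled two ways). If one wanted to harden this further, os.makedirs with exist_ok=True avoids the check-then-create race entirely; a sketch of that alternative idiom, not what this commit does:

    import os

    from constants import MODELS_PATH

    # creates parent directories as needed; a no-op if it already exists
    os.makedirs(MODELS_PATH, exist_ok=True)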