Added model path

Added the path to download the model.
PromptEngineer 2023-09-17 01:41:41 -07:00
parent 9577b4abd1
commit 4f9fb00b3a
5 changed files with 55 additions and 45 deletions
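
In short, the commit replaces the hard-coded "./models" strings scattered across the scripts with a single MODELS_PATH constant defined in constants.py. A minimal sketch of the resulting pattern, using only names that appear in this diff (the snippet is illustrative, not part of the commit):

    # constants.py -- single source of truth for where model weights are downloaded/cached
    MODELS_PATH = "./models"

    # load_models.py / run_localGPT.py -- import the constant instead of repeating the literal path
    from constants import MODELS_PATH

Changing the download location now means editing one line in constants.py rather than hunting for every "./models" literal.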


constants.py

@@ -14,6 +14,8 @@ SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/SOURCE_DOCUMENTS"
 PERSIST_DIRECTORY = f"{ROOT_DIRECTORY}/DB"
 
+MODELS_PATH = "./models"
+
 # Can be changed to a specific number
 INGEST_THREADS = os.cpu_count() or 8

load_models.py

@@ -9,7 +9,7 @@ from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
 )
-from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH
+from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH, MODELS_PATH
 
 
 def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):
@@ -40,7 +40,7 @@ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):
             repo_id=model_id,
             filename=model_basename,
             resume_download=True,
-            cache_dir="./models",
+            cache_dir=MODELS_PATH,
         )
         kwargs = {
             "model_path": model_path,
@@ -140,7 +140,7 @@ def load_full_model(model_id, model_basename, device_type, logging):
             device_map="auto",
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
-            cache_dir="./models/",
+            cache_dir=MODELS_PATH,
             # trust_remote_code=True, # set these if you are using NVIDIA GPU
             # load_in_4bit=True,
             # bnb_4bit_quant_type="nf4",
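
For context, a hedged sketch of the two call sites touched above after the change. The repo id and file names below are placeholders, and the enclosing from_pretrained call is assumed to be AutoModelForCausalLM (the hunk only shows its keyword arguments):

    import torch
    from huggingface_hub import hf_hub_download
    from transformers import AutoModelForCausalLM
    from constants import MODELS_PATH

    # GGUF/GGML branch: download a single quantized file into MODELS_PATH
    model_path = hf_hub_download(
        repo_id="some-org/some-model-GGUF",   # placeholder repo id
        filename="some-model.Q4_K_M.gguf",    # placeholder basename
        resume_download=True,
        cache_dir=MODELS_PATH,
    )

    # full-model branch: transformers caches the checkpoint under the same directory
    model = AutoModelForCausalLM.from_pretrained(
        "some-org/some-model",                # placeholder model id
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        cache_dir=MODELS_PATH,
    )

Both code paths now share one cache directory, so downloaded GGUF files and full Hugging Face checkpoints end up under the same ./models folder.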

run_localGPT.py

@@ -5,29 +5,35 @@ import torch
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.llms import HuggingFacePipeline
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler # for streaming response
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler  # for streaming response
 from langchain.callbacks.manager import CallbackManager
 
 callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
 from prompt_template_utils import get_prompt_template
 
 # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
-from transformers import (GenerationConfig,
-                          pipeline,
-                          )
+from transformers import (
+    GenerationConfig,
+    pipeline,
+)
-from load_models import (load_quantized_model_gguf_ggml,
-                         load_quantized_model_qptq,
-                         load_full_model,
-                         )
+from load_models import (
+    load_quantized_model_gguf_ggml,
+    load_quantized_model_qptq,
+    load_full_model,
+)
-from constants import (EMBEDDING_MODEL_NAME,
-                       PERSIST_DIRECTORY,
-                       MODEL_ID,
-                       MODEL_BASENAME,
-                       MAX_NEW_TOKENS
-                       )
+from constants import (
+    EMBEDDING_MODEL_NAME,
+    PERSIST_DIRECTORY,
+    MODEL_ID,
+    MODEL_BASENAME,
+    MAX_NEW_TOKENS,
+    MODELS_PATH,
+)
 
 
 def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
     """
@@ -89,7 +95,7 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     """
     Initializes and returns a retrieval-based Question Answering (QA) pipeline.
 
-    This function sets up a QA system that retrieves relevant information using embeddings 
+    This function sets up a QA system that retrieves relevant information using embeddings
     from the HuggingFace library. It then answers questions based on the retrieved information.
 
     Parameters:
@@ -108,40 +114,43 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
     """
 
-    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME,
-                                               model_kwargs={"device": device_type})
+    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
     # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
     # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
 
     # load the vectorstore
-    db = Chroma(persist_directory=PERSIST_DIRECTORY,
-                embedding_function=embeddings,)
+    db = Chroma(
+        persist_directory=PERSIST_DIRECTORY,
+        embedding_function=embeddings,
+    )
     retriever = db.as_retriever()
 
     # get the prompt template and memory if set by the user.
-    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type,
-                                         history=use_history)
+    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type, history=use_history)
 
     # load the llm pipeline
-    llm = load_model(device_type,
-                     model_id=MODEL_ID,
-                     model_basename=MODEL_BASENAME,
-                     LOGGING=logging)
+    llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME, LOGGING=logging)
 
     if use_history:
-        qa = RetrievalQA.from_chain_type(llm=llm,
-                                         chain_type="stuff", # try other chains types as well. refine, map_reduce, map_rerank
-                                         retriever=retriever,
-                                         return_source_documents=True,# verbose=True,
-                                         callbacks=callback_manager,
-                                         chain_type_kwargs={"prompt": prompt, "memory": memory},)
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
+            retriever=retriever,
+            return_source_documents=True,  # verbose=True,
+            callbacks=callback_manager,
+            chain_type_kwargs={"prompt": prompt, "memory": memory},
+        )
     else:
-        qa = RetrievalQA.from_chain_type(llm=llm,
-                                         chain_type="stuff", # try other chains types as well. refine, map_reduce, map_rerank
-                                         retriever=retriever,
-                                         return_source_documents=True,# verbose=True,
-                                         callbacks=callback_manager,
-                                         chain_type_kwargs={"prompt": prompt,},)
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
+            retriever=retriever,
+            return_source_documents=True,  # verbose=True,
+            callbacks=callback_manager,
+            chain_type_kwargs={
+                "prompt": prompt,
+            },
+        )
 
     return qa
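
Per the docstring above, retrieval_qa_pipline wires the instruct embeddings, the Chroma retriever, and a RetrievalQA chain together and returns the chain. A minimal usage sketch, assuming the function is imported from this script and that a populated DB directory already exists:

    from run_localGPT import retrieval_qa_pipline  # assumed import path

    qa = retrieval_qa_pipline("cuda", use_history=False, promptTemplate_type="llama")

    res = qa("What is the document about?")  # RetrievalQA chains return a dict
    print(res["result"])                     # the generated answer
    for doc in res["source_documents"]:      # present because return_source_documents=True
        print(doc.metadata.get("source"))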
@@ -214,13 +223,12 @@ def main(device_type, show_sources, use_history):
     logging.info(f"Use history set to: {use_history}")
 
     # check if models directory do not exist, create a new one and store models here.
-    if not os.path.exists("./models"):
-        os.mkdir("models")
+    if not os.path.exists(MODELS_PATH):
+        os.mkdir(MODELS_PATH)
 
     qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama")
     # Interactive questions and answers
     while True:
         query = input("\nEnter a query: ")
         if query == "exit":
             break
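
A design note on the directory check introduced in main(): the exists()/mkdir() pair covers the flat "./models" default, while os.makedirs with exist_ok=True would also tolerate a nested MODELS_PATH. A hedged alternative sketch, not part of this commit:

    import os
    from constants import MODELS_PATH

    # creates the directory if missing, is a no-op if it already exists,
    # and creates intermediate directories should MODELS_PATH ever become nested
    os.makedirs(MODELS_PATH, exist_ok=True)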