Mirror of https://github.com/zebrajr/localGPT.git
Added model path

Added the path to download the model.

commit 4f9fb00b3a (parent 9577b4abd1)
SOURCE_DOCUMENTS/Orca_paper.pdf (new file, BIN)
Binary file not shown.
constants.py

@@ -14,6 +14,8 @@ SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/SOURCE_DOCUMENTS"
 
 PERSIST_DIRECTORY = f"{ROOT_DIRECTORY}/DB"
 
+MODELS_PATH = "./models"
+
 # Can be changed to a specific number
 INGEST_THREADS = os.cpu_count() or 8
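The change itself is small: every place that previously hard-coded "./models" now reads this single constant. As a side note, a minimal sketch (an assumption, not part of the commit) of how the constant could instead be anchored to the project root, reusing the existing ROOT_DIRECTORY constant shown in the hunk context, so downloads land in the same folder regardless of the working directory:

import os

# Hypothetical alternative to the relative "./models" added above: resolve the
# models directory against the project root instead of the current working directory.
ROOT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
MODELS_PATH = os.path.join(ROOT_DIRECTORY, "models")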
load_models.py

@@ -9,7 +9,7 @@ from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
 )
-from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH
+from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH, MODELS_PATH
 
 
 def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):
@@ -40,7 +40,7 @@ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):
             repo_id=model_id,
             filename=model_basename,
             resume_download=True,
-            cache_dir="./models",
+            cache_dir=MODELS_PATH,
         )
         kwargs = {
             "model_path": model_path,
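hf_hub_download is the call that actually fetches the quantized GGUF/GGML file, so routing its cache_dir through MODELS_PATH is what makes the new constant effective. A standalone, hedged sketch of the same call with placeholder repo and file names (not values taken from this commit):

from huggingface_hub import hf_hub_download

MODELS_PATH = "./models"  # the constant introduced in constants.py

# Download (or reuse a previously cached copy of) a quantized model file.
model_path = hf_hub_download(
    repo_id="TheBloke/Llama-2-7B-Chat-GGUF",  # placeholder repo id
    filename="llama-2-7b-chat.Q4_K_M.gguf",   # placeholder basename
    resume_download=True,                     # pick up partial downloads
    cache_dir=MODELS_PATH,                    # everything lands under ./models
)
print(model_path)  # absolute path inside the cache directory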
@@ -140,7 +140,7 @@ def load_full_model(model_id, model_basename, device_type, logging):
             device_map="auto",
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
-            cache_dir="./models/",
+            cache_dir=MODELS_PATH,
             # trust_remote_code=True, # set these if you are using NVIDIA GPU
             # load_in_4bit=True,
             # bnb_4bit_quant_type="nf4",
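load_full_model makes the equivalent change for unquantized checkpoints, so both code paths now share one cache directory. A minimal sketch of a from_pretrained call using the same keyword arguments as the hunk above; the model id is a placeholder, not something from this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODELS_PATH = "./models"
MODEL_ID = "TheBloke/vicuna-7B-1.1-HF"  # placeholder model id

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=MODELS_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",          # let accelerate place layers across available devices
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    cache_dir=MODELS_PATH,      # same directory as the quantized downloads
)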
run_localGPT.py

@@ -5,29 +5,35 @@ import torch
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.llms import HuggingFacePipeline
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler  # for streaming response
 from langchain.callbacks.manager import CallbackManager
 
 callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
 from prompt_template_utils import get_prompt_template
 
 # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
-from transformers import (GenerationConfig,
-                          pipeline,
-                          )
+from transformers import (
+    GenerationConfig,
+    pipeline,
+)
 
-from load_models import (load_quantized_model_gguf_ggml,
-                         load_quantized_model_qptq,
-                         load_full_model,
-                         )
+from load_models import (
+    load_quantized_model_gguf_ggml,
+    load_quantized_model_qptq,
+    load_full_model,
+)
+
+from constants import (
+    EMBEDDING_MODEL_NAME,
+    PERSIST_DIRECTORY,
+    MODEL_ID,
+    MODEL_BASENAME,
+    MAX_NEW_TOKENS,
+    MODELS_PATH,
+)
 
-from constants import (EMBEDDING_MODEL_NAME,
-                       PERSIST_DIRECTORY,
-                       MODEL_ID,
-                       MODEL_BASENAME,
-                       MAX_NEW_TOKENS
-                       )
 
 
 def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
     """
@@ -89,7 +95,7 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     """
     Initializes and returns a retrieval-based Question Answering (QA) pipeline.
 
     This function sets up a QA system that retrieves relevant information using embeddings
     from the HuggingFace library. It then answers questions based on the retrieved information.
 
     Parameters:
@@ -108,40 +114,43 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
     """
 
-    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME,
-                                               model_kwargs={"device": device_type})
+    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
     # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
     # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
 
     # load the vectorstore
-    db = Chroma(persist_directory=PERSIST_DIRECTORY,
-                embedding_function=embeddings,)
+    db = Chroma(
+        persist_directory=PERSIST_DIRECTORY,
+        embedding_function=embeddings,
+    )
     retriever = db.as_retriever()
 
     # get the prompt template and memory if set by the user.
-    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type,
-                                         history=use_history)
+    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type, history=use_history)
 
     # load the llm pipeline
-    llm = load_model(device_type,
-                     model_id=MODEL_ID,
-                     model_basename=MODEL_BASENAME,
-                     LOGGING=logging)
+    llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME, LOGGING=logging)
 
     if use_history:
-        qa = RetrievalQA.from_chain_type(llm=llm,
-                                         chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
-                                         retriever=retriever,
-                                         return_source_documents=True,# verbose=True,
-                                         callbacks=callback_manager,
-                                         chain_type_kwargs={"prompt": prompt, "memory": memory},)
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
+            retriever=retriever,
+            return_source_documents=True,  # verbose=True,
+            callbacks=callback_manager,
+            chain_type_kwargs={"prompt": prompt, "memory": memory},
+        )
     else:
-        qa = RetrievalQA.from_chain_type(llm=llm,
-                                         chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
-                                         retriever=retriever,
-                                         return_source_documents=True,# verbose=True,
-                                         callbacks=callback_manager,
-                                         chain_type_kwargs={"prompt": prompt,},)
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
+            retriever=retriever,
+            return_source_documents=True,  # verbose=True,
+            callbacks=callback_manager,
+            chain_type_kwargs={
+                "prompt": prompt,
+            },
+        )
 
     return qa
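The RetrievalQA.from_chain_type calls above are only reformatted, not changed in behaviour, and the returned chain is still used the same way. A hedged usage sketch of the object returned by retrieval_qa_pipline(); the query text is an example, and the result keys follow LangChain's RetrievalQA convention rather than anything introduced by this commit:

# Hypothetical caller of retrieval_qa_pipline(); device and prompt style are examples.
qa = retrieval_qa_pipline(device_type="cuda", use_history=False, promptTemplate_type="llama")

res = qa("What does the Orca paper propose?")  # retrieval + generation in one call
answer = res["result"]                         # generated answer text
docs = res["source_documents"]                 # returned because return_source_documents=True
print(answer)
for doc in docs:
    print(doc.metadata.get("source"))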
@@ -214,13 +223,12 @@ def main(device_type, show_sources, use_history):
     logging.info(f"Use history set to: {use_history}")
 
     # check if models directory do not exist, create a new one and store models here.
-    if not os.path.exists("./models"):
-        os.mkdir("models")
+    if not os.path.exists(MODELS_PATH):
+        os.mkdir(MODELS_PATH)
 
     qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama")
     # Interactive questions and answers
     while True:
         query = input("\nEnter a query: ")
         if query == "exit":
             break
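With the constant used in both the existence check and the os.mkdir call, the models directory is now created in the same place the loaders cache into. A one-line alternative (an assumption, not what the commit does) that also avoids the check-then-create race:

import os

MODELS_PATH = "./models"
os.makedirs(MODELS_PATH, exist_ok=True)  # create if missing; no error if it already exists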