mirror of https://github.com/zebrajr/localGPT.git
synced 2025-12-06 12:20:53 +01:00
Added model path
Added the path to which models are downloaded.
This commit is contained in:
parent 9577b4abd1
commit 4f9fb00b3a
SOURCE_DOCUMENTS/Orca_paper.pdf (BIN, normal file)
Binary file not shown.
constants.py
@@ -14,6 +14,8 @@ SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/SOURCE_DOCUMENTS"
 PERSIST_DIRECTORY = f"{ROOT_DIRECTORY}/DB"

+MODELS_PATH = "./models"
+
 # Can be changed to a specific number
 INGEST_THREADS = os.cpu_count() or 8
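Note: the point of hoisting the literal "./models" into a single MODELS_PATH constant is that every download site can import it instead of repeating the string. A minimal usage sketch, mirroring the hf_hub_download call changed below (the repo_id and filename values are placeholders, not part of this commit):

    # sketch: resolving a GGUF model file into the shared MODELS_PATH cache
    from huggingface_hub import hf_hub_download

    from constants import MODELS_PATH

    # placeholder identifiers -- in the repo these come from MODEL_ID / MODEL_BASENAME
    model_path = hf_hub_download(
        repo_id="TheBloke/Llama-2-7B-Chat-GGUF",
        filename="llama-2-7b-chat.Q4_K_M.gguf",
        resume_download=True,
        cache_dir=MODELS_PATH,  # previously the hard-coded "./models"
    )
    print(model_path)  # absolute path to the cached file under ./models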
load_models.py
@@ -9,7 +9,7 @@ from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
 )
-from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH
+from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, N_GPU_LAYERS, N_BATCH, MODELS_PATH


 def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):
@@ -40,7 +40,7 @@ def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):
             repo_id=model_id,
             filename=model_basename,
             resume_download=True,
-            cache_dir="./models",
+            cache_dir=MODELS_PATH,
         )
         kwargs = {
             "model_path": model_path,
@@ -140,7 +140,7 @@ def load_full_model(model_id, model_basename, device_type, logging):
             device_map="auto",
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
-            cache_dir="./models/",
+            cache_dir=MODELS_PATH,
             # trust_remote_code=True, # set these if you are using NVIDIA GPU
             # load_in_4bit=True,
             # bnb_4bit_quant_type="nf4",
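Note: the same constant also replaces the trailing-slash variant "./models/" used for full (non-quantized) models; cache_dir behaves the same for transformers' from_pretrained. A rough sketch of that path, with a placeholder model id standing in for MODEL_ID:

    # sketch: full-model download cached under MODELS_PATH (model id is a placeholder)
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from constants import MODELS_PATH

    model_id = "TheBloke/vicuna-7B-1.1-HF"  # placeholder, not from this commit
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=MODELS_PATH)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        cache_dir=MODELS_PATH,  # previously the literal "./models/"
    )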
run_localGPT.py
@@ -7,27 +7,33 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.llms import HuggingFacePipeline
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler  # for streaming response
 from langchain.callbacks.manager import CallbackManager

 callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

 from prompt_template_utils import get_prompt_template

 # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
-from transformers import (GenerationConfig,
+from transformers import (
+    GenerationConfig,
     pipeline,
-    )
+)

-from load_models import (load_quantized_model_gguf_ggml,
+from load_models import (
+    load_quantized_model_gguf_ggml,
     load_quantized_model_qptq,
     load_full_model,
-    )
+)

-from constants import (EMBEDDING_MODEL_NAME,
+from constants import (
+    EMBEDDING_MODEL_NAME,
     PERSIST_DIRECTORY,
     MODEL_ID,
     MODEL_BASENAME,
-    MAX_NEW_TOKENS
-    )
+    MAX_NEW_TOKENS,
+    MODELS_PATH,
+)


 def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
     """
@@ -108,40 +114,43 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
     """

-    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME,
-                                               model_kwargs={"device": device_type})
+    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
     # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
     # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

     # load the vectorstore
-    db = Chroma(persist_directory=PERSIST_DIRECTORY,
-                embedding_function=embeddings,)
+    db = Chroma(
+        persist_directory=PERSIST_DIRECTORY,
+        embedding_function=embeddings,
+    )
     retriever = db.as_retriever()

     # get the prompt template and memory if set by the user.
-    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type,
-                                         history=use_history)
+    prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type, history=use_history)

     # load the llm pipeline
-    llm = load_model(device_type,
-                     model_id=MODEL_ID,
-                     model_basename=MODEL_BASENAME,
-                     LOGGING=logging)
+    llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME, LOGGING=logging)

     if use_history:
-        qa = RetrievalQA.from_chain_type(llm=llm,
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
             chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
             retriever=retriever,
-            return_source_documents=True,# verbose=True,
+            return_source_documents=True,  # verbose=True,
             callbacks=callback_manager,
-            chain_type_kwargs={"prompt": prompt, "memory": memory},)
+            chain_type_kwargs={"prompt": prompt, "memory": memory},
+        )
     else:
-        qa = RetrievalQA.from_chain_type(llm=llm,
+        qa = RetrievalQA.from_chain_type(
+            llm=llm,
             chain_type="stuff",  # try other chains types as well. refine, map_reduce, map_rerank
             retriever=retriever,
-            return_source_documents=True,# verbose=True,
+            return_source_documents=True,  # verbose=True,
             callbacks=callback_manager,
-            chain_type_kwargs={"prompt": prompt,},)
+            chain_type_kwargs={
+                "prompt": prompt,
+            },
+        )

     return qa
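Note: the hunk above only reflows this function, so its call pattern is unchanged. A rough usage sketch, mirroring the interactive loop in main() (the device type and query text are illustrative):

    # sketch: building the chain and asking one question
    qa = retrieval_qa_pipline(device_type="cuda", use_history=False, promptTemplate_type="llama")

    res = qa("What does the Orca paper propose?")  # example query
    answer, docs = res["result"], res["source_documents"]
    print(answer)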
@@ -214,13 +223,12 @@ def main(device_type, show_sources, use_history):
     logging.info(f"Use history set to: {use_history}")

     # check if models directory do not exist, create a new one and store models here.
-    if not os.path.exists("./models"):
-        os.mkdir("models")
+    if not os.path.exists(MODELS_PATH):
+        os.mkdir(MODELS_PATH)

     qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama")
     # Interactive questions and answers
     while True:
-
         query = input("\nEnter a query: ")
         if query == "exit":
             break
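Note: one small caveat with the exists-then-mkdir pattern above is that os.mkdir only creates the final path segment and can race with a concurrent start. A more defensive variant, offered purely as a sketch and not part of this commit, would be:

    import os

    from constants import MODELS_PATH

    # creates intermediate directories too, and is a no-op if the path already exists
    os.makedirs(MODELS_PATH, exist_ok=True)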