Merge pull request #702 from BBC-Esq/update-langchain-embedding-classes

Automatically select the correct langchain embedding library
PromptEngineer 2024-02-02 21:12:38 -08:00 committed by GitHub
commit 8450efc0bc
2 changed files with 56 additions and 15 deletions


@@ -153,17 +153,36 @@ def main(device_type):
logging.info(f"Loaded {len(documents)} documents from {SOURCE_DIRECTORY}")
logging.info(f"Split into {len(texts)} chunks of text")
# Create embeddings
embeddings = HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
)
# change the embedding type here if you are running into issues.
# These are much smaller embeddings and will work for most applications
# If you use HuggingFaceEmbeddings, make sure to also use the same in the
# run_localGPT.py file.
"""
(1) Chooses an appropriate langchain library based on the enbedding model name. Matching code is contained within fun_localGPT.py.
(2) Provides additional arguments for instructor and BGE models to improve results, pursuant to the instructions contained on
their respective huggingface repository, project page or github repository.
"""
if "instructor" in EMBEDDING_MODEL_NAME:
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
embed_instruction='Represent the document for retrieval:',
query_instruction='Represent the question for retrieving supporting documents:'
)
# embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
elif "bge" in EMBEDDING_MODEL_NAME:
query_instruction = 'Represent this sentence for searching relevant passages:'
return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
query_instruction='Represent this sentence for searching relevant passages:'
)
else:
return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
)
db = Chroma.from_documents(
texts,
@@ -171,8 +190,6 @@ def main(device_type):
persist_directory=PERSIST_DIRECTORY,
client_settings=CHROMA_SETTINGS,
)
if __name__ == "__main__":
logging.basicConfig(


@@ -119,9 +119,33 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
- The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
"""
embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
# uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
# embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
"""
(1) Chooses an appropriate langchain library based on the enbedding model name. Matching code is contained within ingest.py.
(2) Provides additional arguments for instructor and BGE models to improve results, pursuant to the instructions contained on
their respective huggingface repository, project page or github repository.
"""
if "instructor" in EMBEDDING_MODEL_NAME:
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
embed_instruction='Represent the document for retrieval:',
query_instruction='Represent the question for retrieving supporting documents:'
)
elif "bge" in EMBEDDING_MODEL_NAME:
return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
query_instruction='Represent this sentence for searching relevant passages:'
)
else:
return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
)
# load the vectorstore
db = Chroma(
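
Since ingest.py and run_localGPT.py now carry identical branching logic, a minimal sketch of how it could be consolidated into one shared helper is shown below; the get_embeddings name, its placement in a standalone module, and the constants import are assumptions for illustration, not part of this commit.

# Illustrative helper both scripts could import; assumes EMBEDDING_MODEL_NAME
# is exposed by the project's constants module, as it is elsewhere in the repo.
from langchain.embeddings import (
    HuggingFaceBgeEmbeddings,
    HuggingFaceEmbeddings,
    HuggingFaceInstructEmbeddings,
)

from constants import EMBEDDING_MODEL_NAME


def get_embeddings(device_type="cuda"):
    """Return the langchain embeddings object that matches EMBEDDING_MODEL_NAME."""
    if "instructor" in EMBEDDING_MODEL_NAME:
        # Instructor models take separate document and query instructions.
        return HuggingFaceInstructEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={"device": device_type},
            embed_instruction='Represent the document for retrieval:',
            query_instruction='Represent the question for retrieving supporting documents:',
        )
    if "bge" in EMBEDDING_MODEL_NAME:
        # BGE models only prepend an instruction to queries.
        return HuggingFaceBgeEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={"device": device_type},
            query_instruction='Represent this sentence for searching relevant passages:',
        )
    # Plain sentence-transformers models need no instructions.
    return HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device": device_type},
    )

Both files would then reduce to embeddings = get_embeddings(device_type), which keeps ingest-time and query-time embeddings in lockstep.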