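"""Local RAG question-answering app built with Streamlit, ChromaDB, and Ollama.

A PDF is uploaded, split into chunks, and embedded into a persistent ChromaDB
collection; questions are answered by retrieving and re-ranking relevant chunks
and grounding a local Ollama model on them.
"""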
import os
import tempfile

import chromadb
import ollama
import streamlit as st
from chromadb.utils.embedding_functions.ollama_embedding_function import (
    OllamaEmbeddingFunction,
)
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder
from streamlit.runtime.uploaded_file_manager import UploadedFile

system_prompt = """
You are an AI assistant tasked with providing detailed answers based solely on the given context. Your goal is to analyze the information provided and formulate a comprehensive, well-structured response to the question.

The context will be passed as "Context:"
The user question will be passed as "Question:"

To answer the question:
1. Thoroughly analyze the context, identifying key information relevant to the question.
2. Organize your thoughts and plan your response to ensure a logical flow of information.
3. Formulate a detailed answer that directly addresses the question, using only the information provided in the context.
4. Ensure your answer is comprehensive, covering all relevant aspects found in the context.
5. If the context doesn't contain sufficient information to fully answer the question, state this clearly in your response.

Format your response as follows:
1. Use clear, concise language.
2. Organize your answer into paragraphs for readability.
3. Use bullet points or numbered lists where appropriate to break down complex information.
4. If relevant, include any headings or subheadings to structure your response.
5. Ensure proper grammar, punctuation, and spelling throughout your answer.

Important: Base your entire response solely on the information provided in the context. Do not include any external knowledge or assumptions not present in the given text.
"""


def process_document(uploaded_file: UploadedFile) -> list[Document]:
    """Processes an uploaded PDF file by converting it to text chunks.

    Takes an uploaded PDF file, saves it temporarily, loads and splits the content
    into text chunks using recursive character splitting.

    Args:
        uploaded_file: A Streamlit UploadedFile object containing the PDF file

    Returns:
        A list of Document objects containing the chunked text from the PDF

    Raises:
        IOError: If there are issues reading/writing the temporary file
    """
    # Store uploaded file as a temp file
    temp_file = tempfile.NamedTemporaryFile("wb", suffix=".pdf", delete=False)
    temp_file.write(uploaded_file.read())
    temp_file.close()  # Flush the bytes to disk before PyMuPDF opens the file

    loader = PyMuPDFLoader(temp_file.name)
    docs = loader.load()
    os.unlink(temp_file.name)  # Delete temp file
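
    # ~400-character chunks with 100 characters of overlap keep individual
    # chunks small for embedding while preserving context across boundaries.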
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=400,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "?", "!", " ", ""],
    )
    return text_splitter.split_documents(docs)


def get_vector_collection() -> chromadb.Collection:
    """Gets or creates a ChromaDB collection for vector storage.

    Creates an Ollama embedding function using the nomic-embed-text model and initializes
    a persistent ChromaDB client. Returns a collection that can be used to store and
    query document embeddings.

    Returns:
        chromadb.Collection: A ChromaDB collection configured with the Ollama embedding
            function and cosine similarity space.
    """
    ollama_ef = OllamaEmbeddingFunction(
        url="http://localhost:11434/api/embeddings",
        model_name="nomic-embed-text:latest",
    )

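    # Embeddings are persisted on disk under ./chromadb, so documents only need
    # to be processed once; cosine distance is used for the HNSW index.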
    chroma_client = chromadb.PersistentClient(path="./chromadb")
    return chroma_client.get_or_create_collection(
        name="rag_app",
        embedding_function=ollama_ef,
        metadata={"hnsw:space": "cosine"},
    )


def add_to_vector_collection(all_splits: list[Document], file_name: str):
    """Adds document splits to a vector collection for semantic search.

    Takes a list of document splits and adds them to a ChromaDB vector collection
    along with their metadata and unique IDs based on the filename.

    Args:
        all_splits: List of Document objects containing text chunks and metadata
        file_name: String identifier used to generate unique IDs for the chunks

    Returns:
        None. Displays a success message via Streamlit when complete.

    Raises:
        ChromaDBError: If there are issues upserting documents to the collection
    """
    collection = get_vector_collection()
    documents, metadatas, ids = [], [], []

    for idx, split in enumerate(all_splits):
        documents.append(split.page_content)
        metadatas.append(split.metadata)
        ids.append(f"{file_name}_{idx}")

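    # upsert (rather than add) keeps re-processing the same file idempotent:
    # chunks with existing IDs are updated instead of duplicated.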
    collection.upsert(
        documents=documents,
        metadatas=metadatas,
        ids=ids,
    )
    st.success("Data added to the vector store!")


def query_collection(prompt: str, n_results: int = 10):
    """Queries the vector collection with a given prompt to retrieve relevant documents.

    Args:
        prompt: The search query text to find relevant documents.
        n_results: Maximum number of results to return. Defaults to 10.

    Returns:
        dict: Query results containing documents, distances and metadata from the collection.

    Raises:
        ChromaDBError: If there are issues querying the collection.
    """
    collection = get_vector_collection()
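    # ChromaDB embeds the query text with the collection's embedding function
    # and returns the n_results nearest chunks by cosine distance.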
    results = collection.query(query_texts=[prompt], n_results=n_results)
    return results


def call_llm(context: str, prompt: str):
    """Calls the language model with context and prompt to generate a response.

    Uses Ollama to stream responses from a language model by providing context and a
    question prompt. The model uses a system prompt to format and ground its responses appropriately.

    Args:
        context: String containing the relevant context for answering the question
        prompt: String containing the user's question

    Yields:
        String chunks of the generated response as they become available from the model

    Raises:
        OllamaError: If there are issues communicating with the Ollama API
    """
    response = ollama.chat(
        model="granite3-dense:latest",
        stream=True,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": f"Context: {context}, Question: {prompt}",
            },
        ],
    )
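    # With stream=True, ollama.chat returns an iterator of partial responses;
    # relay each piece of content until the model signals it is done.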
    for chunk in response:
        if chunk["done"] is False:
            yield chunk["message"]["content"]
        else:
            break


def re_rank_cross_encoders(prompt: str, documents: list[str]) -> tuple[str, list[int]]:
    """Re-ranks documents using a cross-encoder model for more accurate relevance scoring.

    Uses the MS MARCO MiniLM cross-encoder model to re-rank the input documents based on
    their relevance to the query prompt. Returns the concatenated text of the top 3 most
    relevant documents along with their indices.

    Args:
        prompt: The query text the documents are ranked against.
        documents: List of document strings to be re-ranked.

    Returns:
        tuple: A tuple containing:
            - relevant_text (str): Concatenated text from the top 3 ranked documents
            - relevant_text_ids (list[int]): List of indices for the top ranked documents

    Raises:
        ValueError: If documents list is empty
        RuntimeError: If cross-encoder model fails to load or rank documents
    """
    relevant_text = ""
    relevant_text_ids = []

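    # The cross-encoder scores each (query, document) pair jointly, which is more
    # accurate than the embedding distance used for the initial retrieval;
    # the model is downloaded from the Hugging Face Hub on first use.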
    encoder_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    ranks = encoder_model.rank(prompt, documents, top_k=3)
    for rank in ranks:
        relevant_text += documents[rank["corpus_id"]]
        relevant_text_ids.append(rank["corpus_id"])

    return relevant_text, relevant_text_ids


if __name__ == "__main__":
    st.set_page_config(page_title="RAG Question Answer")

    # Document Upload Area
    with st.sidebar:
        uploaded_file = st.file_uploader(
            "**📑 Upload PDF files for QnA**", type=["pdf"], accept_multiple_files=False
        )

        process = st.button(
            "⚡️ Process",
        )
        if uploaded_file and process:
            normalize_uploaded_file_name = uploaded_file.name.translate(
                str.maketrans({"-": "_", ".": "_", " ": "_"})
            )
            all_splits = process_document(uploaded_file)
            add_to_vector_collection(all_splits, normalize_uploaded_file_name)

    # Question and Answer Area
    st.header("🗣️ RAG Question Answer")
    prompt = st.text_area("**Ask a question related to your document:**")
    ask = st.button(
        "🔥 Ask",
    )

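    # RAG flow: retrieve candidate chunks, re-rank them with the cross-encoder,
    # then ground the LLM's answer on the most relevant text.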
    if ask and prompt:
        results = query_collection(prompt)
        context = results.get("documents")[0]
        relevant_text, relevant_text_ids = re_rank_cross_encoders(prompt, context)
        response = call_llm(context=relevant_text, prompt=prompt)
        st.write_stream(response)

        with st.expander("See retrieved documents"):
            st.write(results)

        with st.expander("See most relevant document ids"):
            st.write(relevant_text_ids)
            st.write(relevant_text)