"""suggestoor/libs/get_all_mans.py

Fetch the man page for every binary found in the FOLDERS directories, cache
each page under mans/, chunk and embed the text with Ollama, and persist the
result as a FAISS vector store. Progress is checkpointed to
embedding_progress.json so an interrupted run can resume.
"""
import json
import os
import subprocess
import sys

import dotenv
from colorama import Fore, Style
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaEmbeddings
from tqdm import tqdm

dotenv.load_dotenv()
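
# Environment variables read via os.getenv() below; the example values are
# illustrative, not taken from this repo:
#   FOLDERS=/usr/bin:/usr/local/bin    colon-separated dirs to scan for binaries
#   EMBEDDING_MODEL=nomic-embed-text   Ollama embedding model name
#   OLLAMA_HOST=http://localhost:11434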


def get_binaries():
    """List every entry found in the colon-separated FOLDERS directories."""
    folders = os.getenv("FOLDERS")
    if not folders:
        raise RuntimeError("FOLDERS environment variable is not set")
    binaries = []
    for folder in folders.split(":"):
        for file in os.listdir(folder):
            binaries.append(file)
    return binaries


def load_progress():
    """Load the progress of already processed binaries."""
    try:
        if os.path.exists("embedding_progress.json"):
            with open("embedding_progress.json", "r") as f:
                return json.load(f)
    except Exception as e:
        print(
            f"{Fore.YELLOW}Could not load progress, starting fresh: {e}{Style.RESET_ALL}"
        )
    return {"processed_binaries": [], "vector_store_exists": False}


def save_progress(processed_binaries, vector_store_exists=False):
    """Save the progress of processed binaries."""
    try:
        with open("embedding_progress.json", "w") as f:
            json.dump(
                {
                    "processed_binaries": processed_binaries,
                    "vector_store_exists": vector_store_exists,
                },
                f,
            )
    except Exception as e:
        print(f"{Fore.YELLOW}Could not save progress: {e}{Style.RESET_ALL}")


def create_embeddings_db(texts, metadatas, existing_db=None):
    """Embed the given chunks and create or extend the FAISS store on disk."""
    vector_store_path = os.path.abspath("vector_store")
    print(
        f"{Fore.CYAN}Will create/update vector store at: {vector_store_path}{Style.RESET_ALL}"
    )
    print(f"{Fore.CYAN}Creating embeddings...{Style.RESET_ALL}")
    embeddings = OllamaEmbeddings(
        model=os.getenv("EMBEDDING_MODEL"),
        base_url=os.getenv("OLLAMA_HOST"),
    )
    batch_size = 64
    total_batches = (len(texts) + batch_size - 1) // batch_size
    all_embeddings = []
    with tqdm(
        total=len(texts),
        desc="Creating vectors",
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
    ) as pbar:
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            try:
                batch_embeddings = embeddings.embed_documents(batch)
                all_embeddings.extend(batch_embeddings)
                pbar.update(len(batch))
            except Exception as e:
                print(
                    f"{Fore.RED}Error creating embeddings for batch {i // batch_size + 1}/{total_batches}: {e}{Style.RESET_ALL}"
                )
                raise
    print(f"{Fore.CYAN}Creating/updating FAISS index...{Style.RESET_ALL}")
    try:
        if existing_db is None:
            db = FAISS.from_embeddings(
                text_embeddings=list(zip(texts, all_embeddings)),
                embedding=embeddings,
                metadatas=metadatas,
            )
        else:
            db = existing_db
            db.add_embeddings(
                text_embeddings=list(zip(texts, all_embeddings)), metadatas=metadatas
            )
        print(
            f"{Fore.CYAN}Saving vector store to: {vector_store_path}{Style.RESET_ALL}"
        )
        db.save_local(vector_store_path)
        if os.path.exists(vector_store_path):
            print(
                f"{Fore.GREEN}Vector store successfully saved at: {vector_store_path}{Style.RESET_ALL}"
            )
            print(
                f"{Fore.CYAN}Vector store size: {sum(os.path.getsize(os.path.join(vector_store_path, f)) for f in os.listdir(vector_store_path)) / (1024 * 1024):.2f} MB{Style.RESET_ALL}"
            )
        return db
    except Exception as e:
        print(f"{Fore.RED}Error with FAISS index: {e}{Style.RESET_ALL}")
        raise
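

# Illustrative direct call, with hypothetical chunk/metadata values mirroring
# the shape built in get_all_mans() below:
#   db = create_embeddings_db(
#       ["LS(1)  ls - list directory contents ..."],
#       [{"binary": "ls", "chunk": 0, "total_chunks": 1}],
#   )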


def get_all_mans():
    # Load previous progress
    progress = load_progress()
    processed_binaries = progress["processed_binaries"]
    binaries = get_binaries()
    remaining_binaries = [b for b in binaries if b not in processed_binaries]
    if not remaining_binaries:
        print(f"{Fore.GREEN}All binaries already processed!{Style.RESET_ALL}")
        if progress["vector_store_exists"]:
            embeddings = OllamaEmbeddings(
                model=os.getenv("EMBEDDING_MODEL"),
                base_url=os.getenv("OLLAMA_HOST"),
            )
            # allow_dangerous_deserialization is required by newer
            # langchain_community releases to unpickle a local index.
            return FAISS.load_local(
                "vector_store", embeddings, allow_dangerous_deserialization=True
            )
        # No work left and no store on disk: return early instead of falling
        # through to the processing loop (db would otherwise be unbound).
        return None
    texts = []
    metadatas = []
    # Load existing vector store if it exists
    existing_db = None
    if progress["vector_store_exists"]:
        try:
            embeddings = OllamaEmbeddings(
                model=os.getenv("EMBEDDING_MODEL"),
                base_url=os.getenv("OLLAMA_HOST"),
            )
            existing_db = FAISS.load_local(
                "vector_store", embeddings, allow_dangerous_deserialization=True
            )
            print(f"{Fore.GREEN}Loaded existing vector store{Style.RESET_ALL}")
        except Exception as e:
            print(
                f"{Fore.YELLOW}Could not load existing vector store: {e}{Style.RESET_ALL}"
            )
    # Text splitter for chunking
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50, separators=["\n\n", "\n", " ", ""]
    )
print(f"{Fore.CYAN}Processing remaining man pages...{Style.RESET_ALL}")
try:
for binary in tqdm(
remaining_binaries,
desc="Reading man pages",
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
):
man_content = ""
if os.path.exists(f"mans/{binary}.man"):
with open(f"mans/{binary}.man", "r") as file:
man_content = file.read()
else:
man_page = subprocess.run(
["man", binary], capture_output=True, text=True
)
man_content = man_page.stdout
with open(f"mans/{binary}.man", "w") as file:
file.write(man_content)
if man_content.strip():
chunks = text_splitter.split_text(man_content)
texts.extend(chunks)
metadatas.extend(
[
{"binary": binary, "chunk": i, "total_chunks": len(chunks)}
for i in range(len(chunks))
]
)
# Save progress after each binary
processed_binaries.append(binary)
save_progress(processed_binaries, progress["vector_store_exists"])
# Create embeddings in smaller batches
if len(texts) >= 100: # Process every 100 documents
db = create_embeddings_db(texts, metadatas, existing_db)
existing_db = db
texts = []
metadatas = []
progress["vector_store_exists"] = True
save_progress(processed_binaries, True)
# Process any remaining texts
if texts:
db = create_embeddings_db(texts, metadatas, existing_db)
elif existing_db:
db = existing_db
save_progress(processed_binaries, True)
return db
    except KeyboardInterrupt:
        print(
            f"\n{Fore.YELLOW}Process interrupted! Progress has been saved.{Style.RESET_ALL}"
        )
        if texts:
            print(f"{Fore.YELLOW}Saving current batch...{Style.RESET_ALL}")
            try:
                db = create_embeddings_db(texts, metadatas, existing_db)
                save_progress(processed_binaries, True)
                print(f"{Fore.GREEN}Current batch saved successfully!{Style.RESET_ALL}")
            except Exception as e:
                print(f"{Fore.RED}Could not save current batch: {e}{Style.RESET_ALL}")
        sys.exit(1)


if __name__ == "__main__":
    db = get_all_mans()
    print(f"{Fore.GREEN}✓ Created embeddings database{Style.RESET_ALL}")