mirror of
https://github.com/maglore9900/max_headroom.git
synced 2025-06-06 19:45:31 +00:00
added spotify and langgraph
This commit is contained in:
parent
61102738b7
commit
5e52804933
@ -1,39 +0,0 @@
|
||||
import requests
|
||||
import winsound
|
||||
|
||||
# Use the API endpoint to generate TTS
|
||||
url = "http://127.0.0.1:7851/api/tts-generate"
|
||||
|
||||
# Prepare the form data
|
||||
data = {
|
||||
"text_input": "This is a test stream.",
|
||||
"text_filtering": "standard",
|
||||
"character_voice_gen": "maxheadroom_00000005.wav",
|
||||
"narrator_enabled": "false",
|
||||
"narrator_voice_gen": "male_01.wav",
|
||||
"text_not_inside": "character",
|
||||
"language": "en",
|
||||
"output_file_name": "stream_output",
|
||||
"output_file_timestamp": "true",
|
||||
"autoplay": "false",
|
||||
"autoplay_volume": "0.8"
|
||||
}
|
||||
|
||||
# Send the POST request to generate TTS
|
||||
response = requests.post(url, data=data)
|
||||
|
||||
# Check if the request was successful
|
||||
if response.status_code == 200:
|
||||
# Parse the JSON response to get the file URL
|
||||
result = response.json()
|
||||
audio_url = result['output_file_url']
|
||||
|
||||
# Download the audio file
|
||||
audio_response = requests.get(audio_url)
|
||||
|
||||
# Save the audio file locally
|
||||
with open("output.wav", "wb") as f:
|
||||
f.write(audio_response.content)
|
||||
winsound.PlaySound('output.wav', winsound.SND_FILENAME)
|
||||
else:
|
||||
print(f"Failed with status code {response.status_code}: {response.text}")
|
10
main.py
10
main.py
@ -1,13 +1,17 @@
|
||||
from modules import adapter, speak
|
||||
from modules import adapter, speak, spotify
|
||||
|
||||
|
||||
sp = speak.Speak()
|
||||
ad = adapter.Adapter("openai")
|
||||
spot = spotify.Spotify()
|
||||
|
||||
|
||||
while True:
|
||||
text = sp.listen()
|
||||
response = ad.chat(text)
|
||||
sp.max_headroom(response)
|
||||
if text and "max" in text.lower():
|
||||
response = ad.chat(text)
|
||||
|
||||
# sp.max_headroom(response)
|
||||
sp.glitch_stream_output(response)
|
||||
|
||||
print("Listening again...")
|
19
main2.py
Normal file
19
main2.py
Normal file
@ -0,0 +1,19 @@
|
||||
from modules import agent, speak
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
||||
sp = speak.Speak()
|
||||
graph = agent.Agent("openai")
|
||||
|
||||
|
||||
while True:
|
||||
text = sp.listen()
|
||||
if text and "max" in text.lower():
|
||||
response = loop.run_until_complete(graph.invoke_agent(text))
|
||||
if response:
|
||||
sp.glitch_stream_output(response)
|
||||
|
||||
print("Listening again...")
|
@ -1,114 +1,27 @@
|
||||
from pathlib import Path
|
||||
import environ
|
||||
import os
|
||||
# import psycopg2
|
||||
from typing import Dict, List, Optional, Tuple, Annotated
|
||||
|
||||
# import pandas as pd # Uncomment this if you need pandas
|
||||
|
||||
# langchain imports
|
||||
# from langchain.agents import AgentExecutor, tool, create_openai_functions_agent
|
||||
# from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
|
||||
# from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
||||
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
|
||||
from langchain.retrievers.multi_query import MultiQueryRetriever
|
||||
from langchain_community.document_loaders import (
|
||||
CSVLoader,
|
||||
PyPDFLoader,
|
||||
TextLoader,
|
||||
UnstructuredMarkdownLoader,
|
||||
UnstructuredODTLoader,
|
||||
UnstructuredPowerPointLoader,
|
||||
UnstructuredWordDocumentLoader,
|
||||
UnstructuredExcelLoader,
|
||||
Docx2txtLoader,
|
||||
AzureAIDocumentIntelligenceLoader,
|
||||
)
|
||||
|
||||
# from langchain.sql_database import SQLDatabase
|
||||
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain.text_splitter import (
|
||||
CharacterTextSplitter,
|
||||
RecursiveCharacterTextSplitter,
|
||||
)
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_experimental.sql import SQLDatabaseChain
|
||||
|
||||
# from langchain_openai import ChatOpenAI
|
||||
|
||||
# sqlalchemy imports
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from langsmith import traceable
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import math
|
||||
import json
|
||||
|
||||
from time import time
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
env = environ.Env()
|
||||
environ.Env.read_env()
|
||||
|
||||
|
||||
class Adapter:
|
||||
def __init__(self, llm_type):
|
||||
self.llm_text = llm_type
|
||||
#! IP and Credentials for DB
|
||||
# self.engine = create_engine(f"postgresql+psycopg2://postgres:{env('DBPASS')}@10.0.0.141:9999/{env('DBNAME')}")
|
||||
# self.engine = create_engine(
|
||||
# f"postgresql+psycopg2://postgres:{env('DBPASS')}@10.0.0.141:9999/{env('DBNAME')}"
|
||||
# )
|
||||
# #! max string length
|
||||
# self.db = SQLDatabase(engine=self.engine, max_string_length=1024)
|
||||
# self.db_params = {
|
||||
# "dbname": env("DBNAME"),
|
||||
# "user": "postgres",
|
||||
# "password": env("DBPASS"),
|
||||
# "host": "10.0.0.141", # or your database host
|
||||
# "port": "9999", # or your database port
|
||||
# }
|
||||
# self.conn = psycopg2.connect(**self.db_params)
|
||||
# self.cursor = self.conn.cursor()
|
||||
|
||||
if self.llm_text.lower() == "openai":
|
||||
from langchain_openai import OpenAIEmbeddings, OpenAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
self.llm = OpenAI(temperature=0, openai_api_key=env("OPENAI_API_KEY"))
|
||||
self.prompt = ChatPromptTemplate.from_template(
|
||||
"answer the following request: {topic}"
|
||||
)
|
||||
self.llm_chat = ChatOpenAI(
|
||||
temperature=0.3, openai_api_key=env("OPENAI_API_KEY")
|
||||
temperature=0.3, model="gpt-4o-mini", openai_api_key=env("OPENAI_API_KEY")
|
||||
)
|
||||
self.embedding = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||
elif self.llm_text.lower() == "azure":
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
|
||||
# self.llm = AzureChatOpenAI(azure_deployment=env("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"), openai_api_version=env("AZURE_OPENAI_API_VERSION"))
|
||||
self.llm_chat = AzureChatOpenAI(
|
||||
temperature=0,
|
||||
azure_deployment=env("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"),
|
||||
openai_api_version=env("AZURE_OPENAI_API_VERSION"),
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
self.embedding = AzureOpenAIEmbeddings(
|
||||
temperature=0,
|
||||
azure_deployment=env("AZURE_OPENAI_CHAT_DEPLOYMENT_EMBED_NAME"),
|
||||
openai_api_version=env("AZURE_OPENAI_API_VERSION"),
|
||||
)
|
||||
elif self.llm_text.lower() == "local":
|
||||
from langchain_community.llms import Ollama
|
||||
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
|
||||
llm_model = "llama3"
|
||||
# llm_model = "notus"
|
||||
self.llm = Ollama(base_url="http://10.0.0.231:11434", model=llm_model)
|
||||
@ -142,766 +55,11 @@ class Adapter:
|
||||
else:
|
||||
raise ValueError("Invalid LLM")
|
||||
|
||||
def load_document(self, filename):
|
||||
file_path = "uploads/" + filename
|
||||
# Map file extensions to their corresponding loader classes
|
||||
|
||||
loaders = {
|
||||
".pdf": PyPDFLoader,
|
||||
".txt": TextLoader,
|
||||
".csv": CSVLoader,
|
||||
".doc": UnstructuredWordDocumentLoader,
|
||||
".docx": UnstructuredWordDocumentLoader,
|
||||
".md": UnstructuredMarkdownLoader,
|
||||
".odt": UnstructuredODTLoader,
|
||||
".ppt": UnstructuredPowerPointLoader,
|
||||
".pptx": UnstructuredPowerPointLoader,
|
||||
".xlsx": UnstructuredExcelLoader,
|
||||
}
|
||||
|
||||
# Identify the loader based on file extension
|
||||
for extension, loader_cls in loaders.items():
|
||||
if filename.endswith(extension):
|
||||
loader = loader_cls(file_path)
|
||||
documents = loader.load()
|
||||
break
|
||||
else:
|
||||
# If no loader is found for the file extension
|
||||
raise ValueError("Invalid file type")
|
||||
# print(f"documents: {documents}")
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=1000, chunk_overlap=30
|
||||
)
|
||||
return text_splitter.split_documents(documents=documents)
|
||||
|
||||
def add_to_datastore(self):
|
||||
try:
|
||||
filename = input("Enter the name of the document (.pdf or .txt):\n")
|
||||
docs = self.load_document(filename)
|
||||
#! permanent vector store
|
||||
datastore_name = os.path.splitext(filename) + "_datastore"
|
||||
vectorstore = FAISS.from_documents(docs, self.embedding)
|
||||
vectorstore.save_local(datastore_name)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
def add_many_to_datastore(self, src_path, dest_path=None):
|
||||
start = time()
|
||||
vectorstore = None
|
||||
count = 0
|
||||
if not dest_path:
|
||||
dest_path = src_path
|
||||
datastore_name = dest_path + "_datastore"
|
||||
entries = os.listdir(src_path)
|
||||
# print(entries)
|
||||
files = [
|
||||
entry for entry in entries if os.path.isfile(os.path.join(src_path, entry))
|
||||
]
|
||||
for each in files:
|
||||
try:
|
||||
# print(each)
|
||||
doc = self.load_document(f"{src_path}/{each}")
|
||||
# print(doc)
|
||||
if not Path(datastore_name).exists():
|
||||
vectorstore = FAISS.from_documents(doc, self.embedding)
|
||||
vectorstore.save_local(datastore_name)
|
||||
else:
|
||||
if vectorstore is None:
|
||||
vectorstore = FAISS.load_local(datastore_name, self.embedding)
|
||||
tmp_vectorstore = FAISS.from_documents(doc, self.embedding)
|
||||
vectorstore.merge_from(tmp_vectorstore)
|
||||
vectorstore.save_local(datastore_name)
|
||||
count += 1
|
||||
print(count)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
end = time()
|
||||
print(end - start)
|
||||
|
||||
def query_datastore(self, query, datastore):
|
||||
try:
|
||||
retriever = FAISS.load_local(datastore, self.embedding).as_retriever()
|
||||
qa = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
llm=self.llm, chain_type="stuff", retriever=retriever, verbose=True
|
||||
)
|
||||
# qa = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff", retriever=retriever, verbose=True)
|
||||
if self.llm_text.lower() == "openai" or self.llm_text.lower() == "hybrid":
|
||||
# result = qa.invoke(query)['result']
|
||||
result = qa.invoke(query)
|
||||
else:
|
||||
result = qa.invoke(query)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
def agent_query_doc(self, query, doc):
|
||||
qa = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
llm=self.llm, chain_type="stuff", retriever=doc, verbose=True
|
||||
)
|
||||
result = qa.invoke(query)
|
||||
return result
|
||||
|
||||
def vector_doc(self, filename):
|
||||
doc = self.load_document(filename)
|
||||
retriever = FAISS.from_documents(doc, self.embedding).as_retriever()
|
||||
# retriever = self.hybrid_retrievers(doc, "doc")
|
||||
return retriever
|
||||
|
||||
def query_doc(self, query, filename, doc):
|
||||
# from langchain_community.vectorstores import Qdrant
|
||||
|
||||
try:
|
||||
print(f"query: {query}")
|
||||
print(f"filename: {filename}")
|
||||
# doc = self.load_document(filename, file_path)
|
||||
|
||||
#! permanent vector store
|
||||
# print(f"here is the document data {doc}")
|
||||
# vectorstore = FAISS.from_documents(docs, self.embedding)
|
||||
# vectorstore.save_local("faiss_index_constitution")
|
||||
# persisted_vectorstore = FAISS.load_local("faiss_index_constitution", self.embedding)
|
||||
#! impermanent vector store
|
||||
retriever = FAISS.from_documents(doc, self.embedding).as_retriever()
|
||||
# retriever = self.hybrid_retrievers(doc)
|
||||
#! qdrant options instead of FAISS, need to explore more metadata options for sources
|
||||
# qdrant = Qdrant.from_documents(
|
||||
# doc,
|
||||
# self.embedding,
|
||||
# location=":memory:", # Local mode with in-memory storage only
|
||||
# collection_name="my_documents",
|
||||
# )
|
||||
# retriever = qdrant.as_retriever()
|
||||
# qa = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff", retriever=retriever, verbose=True)
|
||||
qa = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
llm=self.llm, chain_type="stuff", retriever=retriever, verbose=True
|
||||
)
|
||||
query = qa.invoke(query)
|
||||
# result = query['answer']
|
||||
# source = query['sources']
|
||||
# return result+"\nSource:"+source
|
||||
return query
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
def query_db(self, query):
|
||||
"""Answer all Risk Management Framework (RMF) control and CCI related questions."""
|
||||
|
||||
QUERY = """
|
||||
Given an input question, first create a syntactically correct postgresql query to run, then look at all of the results of the query. Return an answer for all matches.
|
||||
|
||||
When returning an answer always format the response like this.
|
||||
RMF Control: <rmf_control>
|
||||
CCI: <rmf_control_cci>
|
||||
Assessment Procedurse: <assessment_procedures for rmf_control_cci>
|
||||
Implementation Guidance: <implementation_guidance for rmf_control_cci>
|
||||
|
||||
|
||||
DO NOT LIMIT the length of the SQL query or the response.
|
||||
{question}
|
||||
"""
|
||||
|
||||
db_chain = SQLDatabaseChain.from_llm(self.llm, self.db, verbose=True)
|
||||
try:
|
||||
question = QUERY.format(question=query)
|
||||
if self.llm_text.lower() == "openai" or self.llm_text.lower() == "hybrid":
|
||||
result = str(db_chain.invoke(question)["result"])
|
||||
else:
|
||||
result = db_chain.invoke(question)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
def compare(self, query, db):
|
||||
try:
|
||||
docs = self.load_document("test.txt")
|
||||
for each in docs:
|
||||
print(each.page_content)
|
||||
response = self.query_db(f"\n{query} {each.page_content}\n", db)
|
||||
return response
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
# def tokenize(self, data):
|
||||
# #! chunk and vector raw data
|
||||
# try:
|
||||
# for each in list(data):
|
||||
# print(each)
|
||||
# # results = self.embedding.embed_query(data)
|
||||
# # print(results[:5])
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
|
||||
#! modified with formatter
|
||||
def chain_query_db(self, prompt):
|
||||
#! use LLM to translate user question into SQL query
|
||||
QUERY = """
|
||||
Given an input question, create a syntactically correct postgresql query to run. Do not limit the return DO NOT USE UNION. DO NOT LIMIT the length of the SQL query or the response. Do NOT assume RMF control number or any other data types.
|
||||
{question}
|
||||
"""
|
||||
db_query = SQLDatabaseChain.from_llm(
|
||||
self.llm, self.db, verbose=False, return_sql=True
|
||||
)
|
||||
try:
|
||||
question = QUERY.format(question=prompt)
|
||||
if self.llm_text.lower() == "openai" or self.llm_text.lower() == "hybrid":
|
||||
result = db_query.invoke(question)["result"]
|
||||
else:
|
||||
result = db_query.invoke(question)
|
||||
# print(f"this is the result query: {result}")
|
||||
self.cursor.execute(f"{result};")
|
||||
db_data = self.cursor.fetchall()
|
||||
db_data = sorted(db_data)
|
||||
print(f"-------- db_data: {db_data}\n")
|
||||
formated = self.query_db_format_response(result, db_data)
|
||||
print(f"formated response: {formated}")
|
||||
# return(db_data)
|
||||
return formated
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
#! new helper function
|
||||
def query_extractor(self, sql_query):
|
||||
# Split the query at 'UNION', if 'UNION' is not present, this will simply take the entire query
|
||||
parts = sql_query.split(" UNION ")
|
||||
column_names = []
|
||||
|
||||
# Only process the last part after the last 'UNION'
|
||||
if len(parts) > 1:
|
||||
part = parts[-1] # This gets the last segment after the UNION
|
||||
else:
|
||||
part = parts[
|
||||
0
|
||||
] # This handles cases without any UNION, taking the whole query
|
||||
|
||||
# Extract the text between 'SELECT' and 'FROM'
|
||||
selected_part = part.split("SELECT")[1].split("FROM")[0].strip()
|
||||
# Split the selected part on commas to get individual column names
|
||||
columns = [column.strip() for column in selected_part.split(",")]
|
||||
# Remove table aliases and extra quotes if present
|
||||
for column in columns:
|
||||
# Remove table prefix if exists (e.g., table_name.column_name)
|
||||
if "." in column:
|
||||
column = column.split(".")[-1]
|
||||
# Strip quotes and whitespaces around the column names
|
||||
clean_column = column.strip().strip('"').strip()
|
||||
# Append all columns to the list, allowing duplicates
|
||||
column_names.append(clean_column)
|
||||
|
||||
return column_names
|
||||
|
||||
#! response formatter
|
||||
def query_db_format_response(self, sql_query, response):
|
||||
sql_query_list = self.query_extractor(sql_query)
|
||||
print(f"sql response: {response}")
|
||||
print(f"SQL Query List: {sql_query_list}")
|
||||
columns = sql_query_list
|
||||
data_dict = {}
|
||||
control_list = [
|
||||
"rmf_control_number",
|
||||
"rmf_control_family",
|
||||
"rmf_control_title",
|
||||
"rmf_control_text",
|
||||
"confidentiality",
|
||||
"integrity",
|
||||
"availability",
|
||||
"supplementary_guidance",
|
||||
"criticality",
|
||||
]
|
||||
cci_list = [
|
||||
"rmf_control_cci",
|
||||
"rmf_control_cci_def",
|
||||
"implementation_guidance",
|
||||
"assessment_procedures",
|
||||
"confidentiality",
|
||||
"integrity",
|
||||
"availability",
|
||||
]
|
||||
|
||||
for record in response:
|
||||
record_dict = {column: record[idx] for idx, column in enumerate(columns)}
|
||||
rmf_control_number = record_dict.get("rmf_control_text_indicator")
|
||||
|
||||
print(f"rmf_control_text_indicator: {rmf_control_number}")
|
||||
# print(f"record: {record}")
|
||||
if not rmf_control_number:
|
||||
rmf_control_number = record_dict.get("rmf_control_number")
|
||||
print(f"rmf_control_number: {rmf_control_number}")
|
||||
else:
|
||||
match = re.search(r"rmf_control_number\s*=\s*\'([^\']*)\'", sql_query)
|
||||
if match:
|
||||
rmf_control_number = match.group(1)
|
||||
print(f"rmf_control_group: {rmf_control_number}")
|
||||
rmf_control_cci = record_dict.pop("rmf_control_cci", None)
|
||||
|
||||
if rmf_control_number:
|
||||
# Ensure a dictionary exists for this control number
|
||||
if rmf_control_number not in data_dict:
|
||||
data_dict[rmf_control_number] = {"CCI": {}}
|
||||
|
||||
# Handle CCI values specifically
|
||||
if rmf_control_cci:
|
||||
# Ensure a dictionary exists for this CCI under the control number
|
||||
if rmf_control_cci not in data_dict[rmf_control_number]["CCI"]:
|
||||
data_dict[rmf_control_number]["CCI"][rmf_control_cci] = {}
|
||||
|
||||
# Populate the CCI dictionary with relevant data from record_dict
|
||||
for key in record_dict:
|
||||
if key in cci_list:
|
||||
# Initialize or append to the list for each key
|
||||
if (
|
||||
key
|
||||
not in data_dict[rmf_control_number]["CCI"][
|
||||
rmf_control_cci
|
||||
]
|
||||
):
|
||||
data_dict[rmf_control_number]["CCI"][rmf_control_cci][
|
||||
key
|
||||
] = []
|
||||
value = record_dict[key]
|
||||
if isinstance(value, float) and math.isnan(value):
|
||||
value = None
|
||||
data_dict[rmf_control_number]["CCI"][rmf_control_cci][
|
||||
key
|
||||
].append(record_dict[key])
|
||||
|
||||
for key in record_dict:
|
||||
if key in control_list:
|
||||
if key not in data_dict[rmf_control_number]:
|
||||
data_dict[rmf_control_number][key] = []
|
||||
value = record_dict[key]
|
||||
if isinstance(value, float) and math.isnan(value):
|
||||
value = None
|
||||
if value not in data_dict[rmf_control_number][key]:
|
||||
data_dict[rmf_control_number][key].append(value)
|
||||
response = json.dumps(data_dict, indent=4)
|
||||
else:
|
||||
response = [list(item) for item in response]
|
||||
print(f"response: {response}")
|
||||
# json_output = json.dumps(data_dict, indent=4)
|
||||
# return json_output
|
||||
return response
|
||||
|
||||
def chat(self, query):
|
||||
print(f"adaptor query: {query}")
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
||||
chain = self.prompt | self.llm_chat | StrOutputParser()
|
||||
# loop = asyncio.get_running_loop()
|
||||
# Run the synchronous method in an executor
|
||||
# result = await loop.run_in_executor(None, chain.invoke({"topic": query}))
|
||||
result = chain.invoke({"topic": query})
|
||||
# print(f"adapter result: {result}")
|
||||
return result
|
||||
|
||||
#! multi-doc loader with one output, attempted to dev for general purpose, may not need it for other purposes
|
||||
def multi_doc_loader(self, files: Annotated[list, "List of files to load"]):
|
||||
print("multi_doc_loader")
|
||||
docs = []
|
||||
for file in files:
|
||||
doc = self.load_document(file)
|
||||
docs.extend(doc)
|
||||
return docs
|
||||
|
||||
#! helper function needs to have multi_doc_loader to be run first and that value to be docs
|
||||
def multi_doc_splitter(self, docs):
|
||||
print("multi_doc_splitter")
|
||||
from langchain import hub
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
|
||||
d_reversed = list(reversed(d_sorted))
|
||||
concatenated_content = "\n\n\n --- \n\n\n".join(
|
||||
[doc.page_content for doc in d_reversed]
|
||||
)
|
||||
chunk_size_tok = 2000
|
||||
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
||||
chunk_size=chunk_size_tok, chunk_overlap=0
|
||||
)
|
||||
texts_split = text_splitter.split_text(concatenated_content)
|
||||
return texts_split
|
||||
|
||||
def raptorize(self, docs):
|
||||
texts = self.multi_doc_loader(docs)
|
||||
texts_split = self.multi_doc_splitter(texts)
|
||||
import raptor
|
||||
from langchain import hub
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
rapt = raptor.Raptor(self.llm_chat, self.embedding)
|
||||
raptor_results = rapt.recursive_embed_cluster_summarize(texts_split, level=1, n_levels=3)
|
||||
print("raptor run")
|
||||
for level in sorted(raptor_results.keys()):
|
||||
# Extract summaries from the current level's DataFrame
|
||||
summaries = raptor_results[level][1]["summaries"].tolist()
|
||||
# Extend all_texts with the summaries from the current level
|
||||
texts_split.extend(summaries)
|
||||
# vectorstore = FAISS.from_texts(texts_split, self.embedding)
|
||||
# retriever = vectorstore.as_retriever()
|
||||
retriever = self.hybrid_retrievers(texts_split, "text")
|
||||
|
||||
def format_docs(docs):
|
||||
return "\n\n".join(doc.page_content for doc in docs)
|
||||
|
||||
prompt = hub.pull("rlm/rag-prompt")
|
||||
rag_chain = (
|
||||
{"context": retriever | format_docs, "question": RunnablePassthrough()}
|
||||
| prompt
|
||||
| self.llm_chat
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
ccis = """The organization conducting the inspection/assessment examines the information system to ensure the organization being inspected/assessed configures the information system to audit the execution of privileged functions.
|
||||
"""
|
||||
# For information system components that have applicable STIGs or SRGs, the organization conducting the inspection/assessment evaluates the components to ensure that the organization being inspected/assessed has configured the information system in compliance with the applicable STIGs and SRGs pertaining to CCI 2234."""
|
||||
|
||||
# Question
|
||||
print(rag_chain.invoke(f"search the document for any information that best satisifies the following Question: {ccis}. \n make sure you quote the section of the document where the information was found."))
|
||||
|
||||
def hybrid_retrievers(self, doc, type):
|
||||
from langchain.retrievers import EnsembleRetriever
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain_community.vectorstores import FAISS
|
||||
if type.lower() == "text":
|
||||
bm25_retriever = BM25Retriever.from_texts(
|
||||
doc, metadatas=[{"source": 1}] * len(doc)
|
||||
)
|
||||
bm25_retriever.k = 2
|
||||
faiss_vectorstore = FAISS.from_texts(
|
||||
doc, self.embedding, metadatas=[{"source": 2}] * len(doc)
|
||||
)
|
||||
elif type.lower() == "doc":
|
||||
bm25_retriever = BM25Retriever.from_documents(doc)
|
||||
faiss_vectorstore = FAISS.from_documents(doc, self.embedding)
|
||||
|
||||
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 2})
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
|
||||
)
|
||||
return ensemble_retriever
|
||||
|
||||
|
||||
############
|
||||
|
||||
|
||||
def vector_doc2(self, doc, retriever_type, weight=None):
|
||||
if "hybrid" in retriever_type.lower():
|
||||
if "faiss" in retriever_type.lower():
|
||||
retriever = self.hybrid_retrievers2(doc, "doc", "faiss", weight)
|
||||
elif "qdrant" in retriever_type.lower():
|
||||
retriever = self.hybrid_retrievers2(doc, "doc", "qdrant", weight)
|
||||
elif "faiss" in retriever_type.lower():
|
||||
retriever = FAISS.from_documents(doc, self.embedding).as_retriever()
|
||||
elif "chroma" in retriever_type.lower():
|
||||
from langchain_chroma import Chroma
|
||||
retriever = Chroma.from_documents(doc, self.embedding).as_retriever()
|
||||
elif "qdrant" in retriever_type.lower():
|
||||
from langchain_community.vectorstores import Qdrant
|
||||
qdrant = Qdrant.from_documents(
|
||||
doc,
|
||||
self.embedding,
|
||||
location=":memory:", # Local mode with in-memory storage only
|
||||
# collection_name="my_documents",
|
||||
)
|
||||
retriever = qdrant.as_retriever()
|
||||
|
||||
return retriever
|
||||
|
||||
def hybrid_retrievers2(self, doc, ret_type, doc_type, weight):
|
||||
from langchain.retrievers import EnsembleRetriever
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain_community.vectorstores import FAISS
|
||||
if "text" in doc_type.lower():
|
||||
bm25_retriever = BM25Retriever.from_texts(
|
||||
doc, metadatas=[{"source": 1}] * len(doc)
|
||||
)
|
||||
bm25_retriever.k = 2
|
||||
if "faiss" in ret_type.lower():
|
||||
vectorstore = FAISS.from_texts(
|
||||
doc, self.embedding, metadatas=[{"source": 2}] * len(doc)
|
||||
)
|
||||
elif "qdrant" in ret_type.lower():
|
||||
from langchain_community.vectorstores import Qdrant
|
||||
qdrant = Qdrant.from_texts(
|
||||
doc,
|
||||
self.embedding,
|
||||
location=":memory:", # Local mode with in-memory storage only
|
||||
# collection_name="my_documents",
|
||||
)
|
||||
vectorstore = qdrant
|
||||
elif "doc" in doc_type.lower():
|
||||
bm25_retriever = BM25Retriever.from_documents(doc)
|
||||
if "faiss" in ret_type.lower():
|
||||
vectorstore = FAISS.from_documents(doc, self.embedding)
|
||||
elif "qdrant" in ret_type.lower():
|
||||
from langchain_community.vectorstores import Qdrant
|
||||
qdrant = Qdrant.from_documents(
|
||||
doc,
|
||||
self.embedding,
|
||||
location=":memory:", # Local mode with in-memory storage only
|
||||
# collection_name="my_documents",
|
||||
)
|
||||
vectorstore = qdrant
|
||||
|
||||
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[bm25_retriever, retriever], weights=[(1.0-float(weight)), float(weight)]
|
||||
)
|
||||
return ensemble_retriever
|
||||
|
||||
# def raptorize2(self, query, docs, retriever_type, filename, weight=0.5):
|
||||
# texts = self.multi_doc_loader(docs)
|
||||
# texts_split = self.multi_doc_splitter(texts)
|
||||
# import raptor
|
||||
# from langchain import hub
|
||||
# from langchain_core.runnables import RunnablePassthrough
|
||||
# from langchain_core.output_parsers import StrOutputParser
|
||||
# rapt = raptor.Raptor(self.llm_chat, self.embedding)
|
||||
# raptor_results = rapt.recursive_embed_cluster_summarize(texts_split, level=1, n_levels=3)
|
||||
# print("raptor run")
|
||||
# for level in sorted(raptor_results.keys()):
|
||||
# # Extract summaries from the current level's DataFrame
|
||||
# summaries = raptor_results[level][1]["summaries"].tolist()
|
||||
# # Extend all_texts with the summaries from the current level
|
||||
# texts_split.extend(summaries)
|
||||
# # vectorstore = FAISS.from_texts(texts_split, self.embedding)
|
||||
# # retriever = vectorstore.as_retriever()
|
||||
|
||||
# if "faiss" in retriever_type.lower():
|
||||
# #! chain requires source, this is a hack, does not add source
|
||||
# # retriever = FAISS.from_texts(texts_split, self.embedding, metadatas=[{"source": 2}] * len(texts_split)).as_retriever()
|
||||
# retriever = FAISS.from_texts(texts_split, self.embedding).as_retriever()
|
||||
# elif "chroma" in retriever_type.lower():
|
||||
# from langchain_chroma import Chroma
|
||||
# retriever = Chroma.from_texts(texts_split, self.embedding).as_retriever()
|
||||
# elif "qdrant" in retriever_type.lower():
|
||||
# from langchain_community.vectorstores import Qdrant
|
||||
# qdrant = Qdrant.from_texts(
|
||||
# texts_split,
|
||||
# self.embedding,
|
||||
# location=":memory:", # Local mode with in-memory storage only
|
||||
# # collection_name="my_documents",
|
||||
# )
|
||||
# retriever = qdrant.as_retriever()
|
||||
# elif "hybrid" in retriever_type.lower():
|
||||
# if "faiss" in retriever_type.lower():
|
||||
# retriever = self.hybrid_retrievers2(texts_split, "faiss", "text", weight)
|
||||
# elif "qdrant" in retriever_type.lower():
|
||||
# retriever = self.hybrid_retrievers2(texts_split, "qdrant", "text", weight)
|
||||
|
||||
# def format_docs(docs):
|
||||
# return "\n\n".join(doc.page_content for doc in docs)
|
||||
# #! creates multiple queries based on the first
|
||||
# # retriever = MultiQueryRetriever.from_llm(
|
||||
# # llm=self.llm, retriever=retriever
|
||||
# # )
|
||||
|
||||
# #! need to find actual source for this to have value
|
||||
# # qa = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
# # llm=self.llm, chain_type="stuff", retriever=retriever, verbose=False
|
||||
# # )
|
||||
|
||||
|
||||
# prompt = hub.pull("rlm/rag-prompt")
|
||||
# rag_chain = (
|
||||
# {"context": retriever | format_docs, "question": RunnablePassthrough()}
|
||||
# | prompt
|
||||
# | self.llm_chat
|
||||
# | StrOutputParser()
|
||||
# )
|
||||
|
||||
# import time
|
||||
# start_time = time.perf_counter()
|
||||
# result = rag_chain.invoke(query)
|
||||
# # result = qa.invoke(query)
|
||||
# end_time = time.perf_counter()
|
||||
# total_time = end_time - start_time
|
||||
# return result, total_time
|
||||
|
||||
def raptorize2(self, query, docs, retriever_type, filename, weight=None):
|
||||
from langchain.schema import Document
|
||||
texts = self.multi_doc_loader(docs)
|
||||
texts_split = self.multi_doc_splitter(texts)
|
||||
import raptor
|
||||
from langchain import hub
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
rapt = raptor.Raptor(self.llm_chat, self.embedding)
|
||||
raptor_results = rapt.recursive_embed_cluster_summarize(texts_split, level=1, n_levels=3)
|
||||
print("raptor run")
|
||||
for level in sorted(raptor_results.keys()):
|
||||
# Extract summaries from the current level's DataFrame
|
||||
summaries = raptor_results[level][1]["summaries"].tolist()
|
||||
# Extend all_texts with the summaries from the current level
|
||||
texts_split.extend(summaries)
|
||||
# vectorstore = FAISS.from_texts(texts_split, self.embedding)
|
||||
# retriever = vectorstore.as_retriever()
|
||||
modified_list = []
|
||||
for each in texts_split:
|
||||
doc = Document(page_content=each, metadata={'source': filename})
|
||||
modified_list.append(doc)
|
||||
if weight is not None:
|
||||
if "doc" in filename.lower():
|
||||
vectorstore = self.hybrid_retrievers2(modified_list, retriever_type, "doc", weight)
|
||||
else:
|
||||
vectorstore = self.hybrid_retrievers2(modified_list, retriever_type, "text", weight)
|
||||
else:
|
||||
vectorstore = self.vector_doc2(modified_list, retriever_type)
|
||||
|
||||
|
||||
|
||||
#! creates multiple queries based on the first
|
||||
# retriever = MultiQueryRetriever.from_llm(
|
||||
# llm=self.llm, retriever=retriever
|
||||
# )
|
||||
|
||||
#! need to find actual source for this to have value
|
||||
qa = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
llm=self.llm, chain_type="stuff", retriever=vectorstore, verbose=False
|
||||
)
|
||||
|
||||
# def format_docs(docs):
|
||||
# return "\n\n".join(doc.page_content for doc in docs)
|
||||
|
||||
# prompt = hub.pull("rlm/rag-prompt")
|
||||
# rag_chain = (
|
||||
# {"context": retriever | format_docs, "question": RunnablePassthrough()}
|
||||
# | prompt
|
||||
# | self.llm_chat
|
||||
# | StrOutputParser()
|
||||
# )
|
||||
|
||||
import time
|
||||
start_time = time.perf_counter()
|
||||
# result = rag_chain.invoke(query)
|
||||
result = qa.invoke(query)
|
||||
end_time = time.perf_counter()
|
||||
total_time = end_time - start_time
|
||||
return result, total_time
|
||||
|
||||
def agent_query_doc2(self, query, doc):
|
||||
# doc = MultiQueryRetriever.from_llm(
|
||||
# llm=self.llm, retriever=doc
|
||||
# )
|
||||
qa = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
llm=self.llm, chain_type="stuff", retriever=doc, verbose=False
|
||||
)
|
||||
import time
|
||||
start_time = time.perf_counter()
|
||||
result = qa.invoke(query)
|
||||
end_time = time.perf_counter()
|
||||
total_time = end_time - start_time
|
||||
return result, total_time
|
||||
|
||||
def chroma_test(self, query, docs):
|
||||
from langchain_chroma import Chroma
|
||||
retriever = Chroma.from_documents(docs, self.embedding).as_retriever()
|
||||
retriever.invoke(query)
|
||||
|
||||
def adj_sentence_clustering(self, text):
|
||||
import numpy as np
|
||||
import spacy
|
||||
nlp = spacy.load('en_core_web_sm')
|
||||
def process(text):
|
||||
doc = nlp(text)
|
||||
sents = list(doc.sents)
|
||||
vecs = np.stack([sent.vector / sent.vector_norm for sent in sents])
|
||||
|
||||
return sents, vecs
|
||||
|
||||
def cluster_text(sents, vecs, threshold):
|
||||
clusters = [[0]]
|
||||
for i in range(1, len(sents)):
|
||||
if np.dot(vecs[i], vecs[i-1]) < threshold:
|
||||
clusters.append([])
|
||||
clusters[-1].append(i)
|
||||
|
||||
return clusters
|
||||
|
||||
def clean_text(text):
|
||||
# Add your text cleaning process here
|
||||
return text
|
||||
|
||||
# Initialize the clusters lengths list and final texts list
|
||||
clusters_lens = []
|
||||
final_texts = []
|
||||
|
||||
# Process the chunk
|
||||
threshold = 0.3
|
||||
sents, vecs = process(text)
|
||||
|
||||
# Cluster the sentences
|
||||
clusters = cluster_text(sents, vecs, threshold)
|
||||
|
||||
for cluster in clusters:
|
||||
cluster_txt = clean_text(' '.join([sents[i].text for i in cluster]))
|
||||
cluster_len = len(cluster_txt)
|
||||
|
||||
# Check if the cluster is too short
|
||||
if cluster_len < 60:
|
||||
continue
|
||||
|
||||
# Check if the cluster is too long
|
||||
elif cluster_len > 3000:
|
||||
threshold = 0.6
|
||||
sents_div, vecs_div = process(cluster_txt)
|
||||
reclusters = cluster_text(sents_div, vecs_div, threshold)
|
||||
|
||||
for subcluster in reclusters:
|
||||
div_txt = clean_text(' '.join([sents_div[i].text for i in subcluster]))
|
||||
div_len = len(div_txt)
|
||||
|
||||
if div_len < 60 or div_len > 3000:
|
||||
continue
|
||||
|
||||
clusters_lens.append(div_len)
|
||||
final_texts.append(div_txt)
|
||||
|
||||
else:
|
||||
clusters_lens.append(cluster_len)
|
||||
final_texts.append(cluster_txt)
|
||||
return final_texts
|
||||
|
||||
def load_document2(self, filename):
|
||||
from langchain.schema import Document
|
||||
file_path = "uploads/" + filename
|
||||
# Map file extensions to their corresponding loader classes
|
||||
|
||||
loaders = {
|
||||
".pdf": PyPDFLoader,
|
||||
".txt": TextLoader,
|
||||
".csv": CSVLoader,
|
||||
".doc": UnstructuredWordDocumentLoader,
|
||||
".docx": UnstructuredWordDocumentLoader,
|
||||
".md": UnstructuredMarkdownLoader,
|
||||
".odt": UnstructuredODTLoader,
|
||||
".ppt": UnstructuredPowerPointLoader,
|
||||
".pptx": UnstructuredPowerPointLoader,
|
||||
".xlsx": UnstructuredExcelLoader,
|
||||
}
|
||||
|
||||
# Identify the loader based on file extension
|
||||
for extension, loader_cls in loaders.items():
|
||||
if filename.endswith(extension):
|
||||
loader = loader_cls(file_path)
|
||||
documents = loader.load()
|
||||
break
|
||||
else:
|
||||
# If no loader is found for the file extension
|
||||
raise ValueError("Invalid file type")
|
||||
|
||||
# text_splitter = RecursiveCharacterTextSplitter(
|
||||
# chunk_size=1000, chunk_overlap=30
|
||||
# )
|
||||
# result = text_splitter.split_documents(documents=documents)
|
||||
|
||||
text = "".join(doc.page_content for doc in documents)
|
||||
|
||||
cluster = self.adj_sentence_clustering(text)
|
||||
|
||||
modified_list = []
|
||||
for each in cluster:
|
||||
doc = Document(page_content=each, metadata={'source': filename})
|
||||
modified_list.append(doc)
|
||||
# vectorstore = FAISS.from_documents(modified_list, self.embedding).as_retriever()
|
||||
# return vectorstore
|
||||
return modified_list
|
||||
|
191
modules/agent.py
Normal file
191
modules/agent.py
Normal file
@ -0,0 +1,191 @@
|
||||
from typing import TypedDict, Annotated, List, Union
|
||||
import json
|
||||
import operator
|
||||
from modules import adapter, spotify
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain.agents import create_openai_tools_agent
|
||||
from langchain import hub
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.graph import StateGraph, END
|
||||
import asyncio
|
||||
|
||||
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, model):
|
||||
self.ad = adapter.Adapter(model)
|
||||
self.sp = spotify.Spotify()
|
||||
self.llm = self.ad.llm_chat
|
||||
# self.final_answer_llm = self.llm.bind_tools(
|
||||
# [self.rag_final_answer_tool], tool_choice="rag_final_answer"
|
||||
# )
|
||||
|
||||
self.prompt = hub.pull("hwchase17/openai-functions-agent")
|
||||
|
||||
self.query_agent_runnable = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
tools=[
|
||||
# self.rag_final_answer_tool,
|
||||
self.spotify,
|
||||
],
|
||||
prompt=self.prompt,
|
||||
)
|
||||
self.graph = StateGraph(self.AgentState)
|
||||
self.runnable = None
|
||||
self.filename = None
|
||||
self.file_path = None
|
||||
self.doc = None
|
||||
|
||||
class AgentState(TypedDict):
|
||||
input: str
|
||||
agent_out: Union[AgentAction, AgentFinish, None]
|
||||
intermediate_steps: Annotated[List[tuple[AgentAction, str]], operator.add]
|
||||
|
||||
#! Tools
|
||||
@tool("respond")
|
||||
async def respond(self, answer: str):
|
||||
"""Returns a natural language response to the user in `answer`"""
|
||||
return ""
|
||||
|
||||
@tool("spotify")
|
||||
async def spotify(self, command: str):
|
||||
"""Use this tool to control spotify, commands include: play, pause, next, previous, favorite, search
|
||||
Only use this tool if the user says Spotify in their query"""
|
||||
return ""
|
||||
|
||||
# @tool("rag_final_answer")
|
||||
# async def rag_final_answer_tool(self, answer: str, source: str):
|
||||
# """Returns a natural language response to the user in `answer`, and a
|
||||
# `source` which provides citations for where this information came from.
|
||||
# """
|
||||
# return ""
|
||||
|
||||
|
||||
|
||||
def setup_graph(self):
|
||||
self.graph.add_node("query_agent", self.run_query_agent)
|
||||
self.graph.add_node("spotify", self.spotify_tool)
|
||||
# self.graph.add_node("rag_final_answer", self.rag_final_answer)
|
||||
# self.graph.add_node("error", self.rag_final_answer)
|
||||
self.graph.add_node("respond", self.respond)
|
||||
|
||||
self.graph.set_entry_point("query_agent")
|
||||
self.graph.add_conditional_edges(
|
||||
start_key="query_agent",
|
||||
condition=self.router,
|
||||
conditional_edge_mapping={
|
||||
"spotify": "spotify",
|
||||
# "rag_final_answer": "rag_final_answer",
|
||||
# "error": "error",
|
||||
"respond": "respond",
|
||||
},
|
||||
)
|
||||
self.graph.add_edge("spotify", END)
|
||||
# self.graph.add_edge("error", END)
|
||||
# self.graph.add_edge("rag_final_answer", END)
|
||||
# self.graph.add_edge("query_agent", END)
|
||||
self.graph.add_edge("respond", END)
|
||||
|
||||
|
||||
self.runnable = self.graph.compile()
|
||||
|
||||
async def run_query_agent(self, state: list):
|
||||
print("> run_query_agent")
|
||||
print(f"state: {state}")
|
||||
agent_out = self.query_agent_runnable.invoke(state)
|
||||
print(agent_out)
|
||||
return {"agent_out": agent_out}
|
||||
|
||||
async def spotify_tool(self, state: str):
|
||||
print("> spotify_tool")
|
||||
print(f"state: {state}")
|
||||
tool_action = state['agent_out'][0]
|
||||
command = tool_action.tool_input['command']
|
||||
print(f"command: {command}")
|
||||
# print(f"search: {search}")
|
||||
if command == "play":
|
||||
self.sp.play()
|
||||
elif command == "pause":
|
||||
self.sp.pause()
|
||||
elif command == "next":
|
||||
self.sp.next_track()
|
||||
elif command == "previous":
|
||||
self.sp.previous_track()
|
||||
elif command == "favorite":
|
||||
self.sp.favorite_current_song()
|
||||
elif command == "search":
|
||||
self.sp.search_song_and_play(search)
|
||||
else:
|
||||
print("Invalid command")
|
||||
|
||||
|
||||
async def respond(self, answer: str):
|
||||
print("> respond")
|
||||
print(f"answer: {answer}")
|
||||
# answer = answer.agent_out.return_values.get('output', None)
|
||||
agent_out = answer.get('agent_out')
|
||||
output_value = agent_out.return_values.get('output', None)
|
||||
return {"agent_out": output_value}
|
||||
|
||||
async def rag_final_answer(self, state: list):
|
||||
print("> rag final_answer")
|
||||
print(f"state: {state}")
|
||||
try:
|
||||
#! if AgentFinish and no intermediate steps then return the answer without rag_final_answer (need to develop)
|
||||
context = state.get("agent_out").return_values['output']
|
||||
if not context:
|
||||
context = state.get("agent_out")['answer']
|
||||
if not context:
|
||||
context = state.get("intermediate_steps")[-1]
|
||||
except:
|
||||
context = ""
|
||||
if "return_values" in str(state.get("agent_out")) and state["intermediate_steps"] == []:
|
||||
print("bypassing rag_final_answer")
|
||||
print(f"context: {context}")
|
||||
return {"agent_out": {"answer":context, "source": "Quick Response"}}
|
||||
else:
|
||||
prompt = f"You are a helpful assistant, Ensure the answer to user's question is in natural language, using the context provided.\n\nCONTEXT: {context}\nQUESTION: {state['input']}"
|
||||
loop = asyncio.get_running_loop()
|
||||
# Run the synchronous method in an executor
|
||||
out = await loop.run_in_executor(None, self.final_answer_llm.invoke, prompt)
|
||||
function_call = out.additional_kwargs["tool_calls"][-1]["function"]["arguments"]
|
||||
return {"agent_out": function_call}
|
||||
|
||||
async def router(self, state):
|
||||
print("> router")
|
||||
print(f"----router agent state: {state}")
|
||||
if isinstance(state["agent_out"], list):
|
||||
return state["agent_out"][-1].tool
|
||||
else:
|
||||
print("---router error")
|
||||
return "respond"
|
||||
|
||||
async def invoke_agent(self, input_data):
|
||||
if not self.runnable:
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self.setup_graph)
|
||||
|
||||
result = await self.runnable.ainvoke(
|
||||
{"input": input_data, "chat_history": [], "intermediate_steps": []}
|
||||
)
|
||||
print("-----")
|
||||
print(result)
|
||||
print("-----")
|
||||
|
||||
try:
|
||||
# Directly access the 'agent_out' key since it is a string
|
||||
agent_out = result["agent_out"]
|
||||
except KeyError:
|
||||
print("Error: 'agent_out' key not found in the result.")
|
||||
agent_out = "I'm sorry, I don't have an answer to that question."
|
||||
|
||||
# 'agent_out' is already the answer in this case
|
||||
answer = agent_out
|
||||
|
||||
print(f"answer: {answer}")
|
||||
if "ToolAgentAction" not in str(agent_out):
|
||||
return answer
|
||||
|
||||
|
||||
|
201
modules/speak.py
201
modules/speak.py
@ -3,6 +3,12 @@ import winsound
|
||||
import speech_recognition as sr
|
||||
import pyttsx3
|
||||
import os
|
||||
import vlc
|
||||
import time
|
||||
import pyaudio
|
||||
from pydub import AudioSegment
|
||||
import random
|
||||
import urllib.parse
|
||||
|
||||
class Speak:
|
||||
def __init__(self):
|
||||
@ -50,8 +56,12 @@ class Speak:
|
||||
|
||||
def listen(self):
|
||||
with self.microphone as source:
|
||||
# Adjust for ambient noise
|
||||
self.recognizer.adjust_for_ambient_noise(source, duration=1)
|
||||
|
||||
print("Listening...")
|
||||
audio = self.recognizer.listen(source)
|
||||
|
||||
try:
|
||||
text = self.recognizer.recognize_google(audio)
|
||||
print("You said: ", text)
|
||||
@ -60,3 +70,194 @@ class Speak:
|
||||
print("Sorry, I didn't get that.")
|
||||
except sr.RequestError as e:
|
||||
print("Sorry, I couldn't request results; {0}".format(e))
|
||||
|
||||
def stream_output(self, text):
|
||||
import urllib.parse
|
||||
# Example parameters
|
||||
voice = "maxheadroom_00000045.wav"
|
||||
language = "en"
|
||||
output_file = "stream_output.wav"
|
||||
|
||||
# Encode the text for URL
|
||||
encoded_text = urllib.parse.quote(text)
|
||||
|
||||
# Create the streaming URL
|
||||
streaming_url = f"http://localhost:7851/api/tts-generate-streaming?text={encoded_text}&voice={voice}&language={language}&output_file={output_file}"
|
||||
|
||||
# Create and play the audio stream using VLC
|
||||
player = vlc.MediaPlayer(streaming_url)
|
||||
|
||||
def on_end_reached(event):
|
||||
print("End of stream reached.")
|
||||
player.stop()
|
||||
|
||||
# Attach event to detect when the stream ends
|
||||
event_manager = player.event_manager()
|
||||
event_manager.event_attach(vlc.EventType.MediaPlayerEndReached, on_end_reached)
|
||||
|
||||
# Start playing the stream
|
||||
player.play()
|
||||
|
||||
# Keep the script running to allow the stream to play
|
||||
while True:
|
||||
state = player.get_state()
|
||||
if state in [vlc.State.Ended, vlc.State.Stopped, vlc.State.Error]:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
def glitch_stream_output(self, text):
|
||||
def change_pitch(sound, octaves):
|
||||
val = random.randint(0, 10)
|
||||
if val == 1:
|
||||
new_sample_rate = int(sound.frame_rate * (2.0 ** octaves))
|
||||
return sound._spawn(sound.raw_data, overrides={'frame_rate': new_sample_rate}).set_frame_rate(sound.frame_rate)
|
||||
else:
|
||||
return sound
|
||||
|
||||
# Example parameters
|
||||
voice = "maxheadroom_00000045.wav"
|
||||
language = "en"
|
||||
output_file = "stream_output.wav"
|
||||
|
||||
# Encode the text for URL
|
||||
encoded_text = urllib.parse.quote(text)
|
||||
|
||||
# Create the streaming URL
|
||||
streaming_url = f"http://localhost:7851/api/tts-generate-streaming?text={encoded_text}&voice={voice}&language={language}&output_file={output_file}"
|
||||
|
||||
# Stream the audio data
|
||||
response = requests.get(streaming_url, stream=True)
|
||||
|
||||
# Initialize PyAudio
|
||||
p = pyaudio.PyAudio()
|
||||
stream = None
|
||||
|
||||
# Process the audio stream in chunks
|
||||
chunk_size = 1024 * 6 # Adjust chunk size if needed
|
||||
audio_buffer = b''
|
||||
|
||||
for chunk in response.iter_content(chunk_size=chunk_size):
|
||||
audio_buffer += chunk
|
||||
|
||||
if len(audio_buffer) < chunk_size:
|
||||
continue
|
||||
|
||||
audio_segment = AudioSegment(
|
||||
data=audio_buffer,
|
||||
sample_width=2, # 2 bytes for 16-bit audio
|
||||
# frame_rate=44100, # Assumed frame rate, adjust as necessary
|
||||
frame_rate=24000, # Assumed frame rate, adjust as necessary
|
||||
channels=1 # Assuming mono audio
|
||||
)
|
||||
|
||||
# Randomly adjust pitch
|
||||
# octaves = random.uniform(-0.5, 0.5)
|
||||
octaves = random.uniform(-1, 1)
|
||||
modified_chunk = change_pitch(audio_segment, octaves)
|
||||
|
||||
if stream is None:
|
||||
# Define stream parameters
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=1,
|
||||
rate=modified_chunk.frame_rate,
|
||||
output=True)
|
||||
|
||||
if random.random() < 0.01: # 1% chance to trigger stutter
|
||||
repeat_times = random.randint(2, 5) # Repeat 2 to 5 times
|
||||
for _ in range(repeat_times):
|
||||
stream.write(modified_chunk.raw_data)
|
||||
|
||||
|
||||
# Play the modified chunk
|
||||
stream.write(modified_chunk.raw_data)
|
||||
|
||||
# Reset buffer
|
||||
audio_buffer = b''
|
||||
|
||||
# Final cleanup
|
||||
if stream:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
def glitch_stream_output2(self, text):
|
||||
def change_pitch(sound, octaves):
|
||||
val = random.randint(0, 10)
|
||||
if val == 1:
|
||||
new_sample_rate = int(sound.frame_rate * (2.0 ** octaves))
|
||||
return sound._spawn(sound.raw_data, overrides={'frame_rate': new_sample_rate}).set_frame_rate(sound.frame_rate)
|
||||
else:
|
||||
return sound
|
||||
|
||||
def convert_audio_format(sound, target_sample_rate=16000):
|
||||
# Ensure the audio is in PCM16 format
|
||||
sound = sound.set_sample_width(2) # PCM16 = 2 bytes per sample
|
||||
# Resample the audio to the target sample rate
|
||||
sound = sound.set_frame_rate(target_sample_rate)
|
||||
return sound
|
||||
|
||||
# Example parameters
|
||||
voice = "maxheadroom_00000045.wav"
|
||||
language = "en"
|
||||
output_file = "stream_output.wav"
|
||||
|
||||
# Encode the text for URL
|
||||
encoded_text = urllib.parse.quote(text)
|
||||
|
||||
# Create the streaming URL
|
||||
streaming_url = f"http://localhost:7851/api/tts-generate-streaming?text={encoded_text}&voice={voice}&language={language}&output_file={output_file}"
|
||||
|
||||
# Stream the audio data
|
||||
response = requests.get(streaming_url, stream=True)
|
||||
|
||||
# Initialize PyAudio
|
||||
p = pyaudio.PyAudio()
|
||||
stream = None
|
||||
|
||||
# Process the audio stream in chunks
|
||||
chunk_size = 1024 * 6 # Adjust chunk size if needed
|
||||
audio_buffer = b''
|
||||
|
||||
for chunk in response.iter_content(chunk_size=chunk_size):
|
||||
audio_buffer += chunk
|
||||
|
||||
if len(audio_buffer) < chunk_size:
|
||||
continue
|
||||
|
||||
audio_segment = AudioSegment(
|
||||
data=audio_buffer,
|
||||
sample_width=2, # 2 bytes for 16-bit audio
|
||||
frame_rate=24000, # Assumed frame rate, adjust as necessary
|
||||
channels=1 # Assuming mono audio
|
||||
)
|
||||
|
||||
# Randomly adjust pitch
|
||||
octaves = random.uniform(-1, 1)
|
||||
modified_chunk = change_pitch(audio_segment, octaves)
|
||||
|
||||
if random.random() < 0.01: # 1% chance to trigger stutter
|
||||
repeat_times = random.randint(2, 5) # Repeat 2 to 5 times
|
||||
for _ in range(repeat_times):
|
||||
stream.write(modified_chunk.raw_data)
|
||||
|
||||
# Convert to PCM16 and 16kHz sample rate after the stutter effect
|
||||
modified_chunk = convert_audio_format(modified_chunk, target_sample_rate=16000)
|
||||
|
||||
if stream is None:
|
||||
# Define stream parameters
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=1,
|
||||
rate=modified_chunk.frame_rate,
|
||||
output=True)
|
||||
|
||||
# Play the modified chunk
|
||||
stream.write(modified_chunk.raw_data)
|
||||
|
||||
# Reset buffer
|
||||
audio_buffer = b''
|
||||
|
||||
# Final cleanup
|
||||
if stream:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
91
modules/spotify.py
Normal file
91
modules/spotify.py
Normal file
@ -0,0 +1,91 @@
|
||||
import spotipy
|
||||
import environ
|
||||
from spotipy.oauth2 import SpotifyOAuth
|
||||
|
||||
env = environ.Env()
|
||||
environ.Env.read_env()
|
||||
|
||||
class Spotify:
|
||||
def __init__(self):
|
||||
self.sp = spotipy.Spotify(auth_manager=SpotifyOAuth(client_id=env("spotify_client_id"),
|
||||
client_secret=env("spotify_client_secret"),
|
||||
redirect_uri=env("spotify_redirect_uri"),
|
||||
scope="user-modify-playback-state user-read-playback-state user-library-modify"))
|
||||
def get_active_device(self):
|
||||
devices = self.sp.devices()
|
||||
if devices['devices']:
|
||||
# Select the first active device
|
||||
active_device_id = devices['devices'][0]['id']
|
||||
return active_device_id
|
||||
else:
|
||||
return None
|
||||
|
||||
def play(self):
|
||||
device_id = self.get_active_device()
|
||||
self.sp.start_playback(device_id=device_id)
|
||||
|
||||
def pause(self):
|
||||
device_id = self.get_active_device()
|
||||
self.sp.pause_playback(device_id=device_id)
|
||||
|
||||
def next_track(self):
|
||||
device_id = self.get_active_device()
|
||||
self.sp.next_track(device_id=device_id)
|
||||
|
||||
def previous_track(self):
|
||||
device_id = self.get_active_device()
|
||||
self.sp.previous_track(device_id=device_id)
|
||||
|
||||
def favorite_current_song(self):
|
||||
current_track = self.sp.current_playback()
|
||||
if current_track and current_track['item']:
|
||||
track_id = current_track['item']['id']
|
||||
self.sp.current_user_saved_tracks_add([track_id])
|
||||
print(f"Added '{current_track['item']['name']}' to favorites")
|
||||
else:
|
||||
print("No song is currently playing")
|
||||
|
||||
def search_song_and_play(self, song_name):
|
||||
results = self.sp.search(q='track:' + song_name, type='track')
|
||||
if results['tracks']['items']:
|
||||
track_uri = results['tracks']['items'][0]['uri']
|
||||
device_id = self.get_active_device()
|
||||
if device_id:
|
||||
self.sp.start_playback(device_id=device_id, uris=[track_uri])
|
||||
else:
|
||||
print("No active device found. Please start Spotify on a device and try again.")
|
||||
else:
|
||||
print(f"No results found for song: {song_name}")
|
||||
|
||||
def search_artist_and_play(self, artist_name):
|
||||
results = self.sp.search(q='artist:' + artist_name, type='artist')
|
||||
if results['artists']['items']:
|
||||
artist_uri = results['artists']['items'][0]['uri']
|
||||
device_id = self.get_active_device()
|
||||
if device_id:
|
||||
self.sp.start_playback(device_id=device_id, context_uri=artist_uri)
|
||||
else:
|
||||
print("No active device found. Please start Spotify on a device and try again.")
|
||||
else:
|
||||
print(f"No results found for artist: {artist_name}")
|
||||
|
||||
def search_album_and_play(self, album_name):
|
||||
results = self.sp.search(q='album:' + album_name, type='album')
|
||||
if results['albums']['items']:
|
||||
album_uri = results['albums']['items'][0]['uri']
|
||||
device_id = self.get_active_device()
|
||||
if device_id:
|
||||
self.sp.start_playback(device_id=device_id, context_uri=album_uri)
|
||||
else:
|
||||
print("No active device found. Please start Spotify on a device and try again.")
|
||||
else:
|
||||
print(f"No results found for album: {album_name}")
|
||||
|
||||
def favorite_current_song(self):
|
||||
current_track = self.sp.current_playback()
|
||||
if current_track and current_track['item']:
|
||||
track_id = current_track['item']['id']
|
||||
self.sp.current_user_saved_tracks_add([track_id])
|
||||
print(f"Added '{current_track['item']['name']}' to favorites")
|
||||
else:
|
||||
print("No song is currently playing")
|
@ -1,34 +0,0 @@
|
||||
import speech_recognition as sr
|
||||
import pyttsx3
|
||||
|
||||
class STT:
|
||||
def __init__(self):
|
||||
self.recognizer = sr.Recognizer()
|
||||
self.microphone = sr.Microphone()
|
||||
self.engine = pyttsx3.init()
|
||||
self.engine.setProperty('rate', 150)
|
||||
|
||||
def listen(self):
|
||||
with self.microphone as source:
|
||||
print("Listening...")
|
||||
audio = self.recognizer.listen(source)
|
||||
try:
|
||||
text = self.recognizer.recognize_google(audio)
|
||||
print("You said: ", text)
|
||||
return text
|
||||
except sr.UnknownValueError:
|
||||
print("Sorry, I didn't get that.")
|
||||
except sr.RequestError as e:
|
||||
print("Sorry, I couldn't request results; {0}".format(e))
|
||||
|
||||
def speak(self, text):
|
||||
self.engine.say(text)
|
||||
self.engine.runAndWait()
|
||||
|
||||
|
||||
# while True:
|
||||
# stt = STT()
|
||||
# text = stt.listen()
|
||||
# stt.speak(text)
|
||||
# del stt
|
||||
# print("Listening again...")
|
@ -13,3 +13,4 @@ pypdf
|
||||
langsmith
|
||||
unstructured
|
||||
python-docx
|
||||
python-vlc
|
BIN
tmp/output.wav
BIN
tmp/output.wav
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user