mirror of https://github.com/tcsenpai/DualMind.git
synced 2025-06-07 11:05:21 +00:00
Added token count and trimmer
parent 5324358e37
commit 648e03ae28
README.md
@@ -86,7 +86,6 @@ The appearance of the Streamlit interface can be customized by modifying the `st
 - `main.py`: Entry point of the application
 - `ai_conversation.py`: Core logic for AI conversations
-- `ollama_client.py`: Client for interacting with the Ollama API
 - `streamlit_app.py`: Streamlit web interface implementation
 - `style/custom.css`: Custom styles for the web interface
 - `run_cli.sh`: Shell script to run the CLI version
ai_conversation.py
@@ -1,9 +1,19 @@
 import ollama
 from termcolor import colored
 import datetime
+import tiktoken  # Used for token counting

 class AIConversation:
-    def __init__(self, model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint):
+    def __init__(
+        self,
+        model_1,
+        model_2,
+        system_prompt_1,
+        system_prompt_2,
+        ollama_endpoint,
+        max_tokens=4000,
+    ):
+        # Initialize conversation parameters and Ollama client
         self.model_1 = model_1
         self.model_2 = model_2
         self.system_prompt_1 = system_prompt_1
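The reworked constructor takes an optional max_tokens cap alongside the existing arguments. A minimal usage sketch (the model tags, prompts, and endpoint below are illustrative placeholders, not values from this repo):

from ai_conversation import AIConversation

conversation = AIConversation(
    model_1="llama3",      # hypothetical Ollama model tag
    model_2="mistral",     # hypothetical Ollama model tag
    system_prompt_1="You are a curious philosopher.",
    system_prompt_2="You are a pragmatic engineer.",
    ollama_endpoint="http://localhost:11434",  # Ollama's default local endpoint
    max_tokens=4000,       # matches the new default
)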
@@ -13,14 +23,31 @@ class AIConversation:
         self.messages_2 = [{"role": "system", "content": system_prompt_2}]
         self.client = ollama.Client(ollama_endpoint)
         self.ollama_endpoint = ollama_endpoint
+        self.tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+        self.max_tokens = max_tokens
+
+    def count_tokens(self, messages):
+        # Count the total number of tokens in the messages
+        return sum(len(self.tokenizer.encode(msg["content"])) for msg in messages)
+
+    def trim_messages(self, messages):
+        # Trim messages to stay within the token limit
+        if self.count_tokens(messages) > self.max_tokens:
+            print(colored(f"[SYSTEM] Max tokens reached. Trimming messages...", "magenta"))
+            while self.count_tokens(messages) > self.max_tokens:
+                if len(messages) > 1:
+                    messages.pop(1)  # Remove the oldest non-system message
+                else:
+                    break  # Avoid removing the system message
+        return messages
+
     def start_conversation(self, initial_message, num_exchanges=0):
+        # Main conversation loop
         current_message = initial_message
-        color_1 = "cyan"
-        color_2 = "yellow"
+        color_1, color_2 = "cyan", "yellow"
         conversation_log = []

-        # Appending the initial message to the conversation log in the system prompt
+        # Add initial message to system prompts
         self.messages_1[0]["content"] += f"\n\nInitial message: {current_message}"
         self.messages_2[0]["content"] += f"\n\nInitial message: {current_message}"
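count_tokens sums the tiktoken length of every message body, and trim_messages pops index 1 (the oldest non-system entry) until the total fits under the cap. A self-contained sketch of the same loop, using toy data:

import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")

def count_tokens(messages):
    # Total tokens across all message contents
    return sum(len(enc.encode(m["content"])) for m in messages)

history = [
    {"role": "system", "content": "You are AI 1."},
    {"role": "user", "content": "hello " * 200},
    {"role": "assistant", "content": "hi there " * 200},
]

MAX_TOKENS = 250  # toy limit for demonstration
while count_tokens(history) > MAX_TOKENS and len(history) > 1:
    history.pop(1)  # index 0 holds the system prompt, so it survives

print(count_tokens(history), [m["role"] for m in history])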
@@ -32,54 +59,57 @@ class AIConversation:
         i = 0
         active_ai = 1  # Starting with AI 1
         while num_exchanges == 0 or i < num_exchanges:
-            if active_ai == 0:
-                name = "AI 1"
-                messages = self.messages_1
-                other_messages = self.messages_2
-                color = color_1
-            else:
-                name = "AI 2"
-                messages = self.messages_2
-                other_messages = self.messages_1
-                color = color_2
+            # Set up current AI's parameters
+            name = "AI 1" if active_ai == 0 else "AI 2"
+            messages = self.messages_1 if active_ai == 0 else self.messages_2
+            other_messages = self.messages_2 if active_ai == 0 else self.messages_1
+            color = color_1 if active_ai == 0 else color_2

+            # Add user message to conversation history
             messages.append({"role": "user", "content": current_message})
             other_messages.append({"role": "assistant", "content": current_message})

-            #print(colored(f"Conversation with {name} ({self.current_model})", "blue"))
+            # Trim messages and get token count
+            messages = self.trim_messages(messages)
+            token_count = self.count_tokens(messages)
+            print(colored(f"Context token count: {token_count}", "magenta"))
+
+            # Generate AI response
             response = self.client.chat(
                 model=self.current_model,
                 messages=messages,
                 options={
-                    "temperature": 0.7,  # Adjust this value to control randomness
+                    "temperature": 0.7,  # Control randomness
                     "repeat_penalty": 1.2,  # Penalize repetition
-                }
+                },
             )
-            response_content = response['message']['content']
+            response_content = response["message"]["content"]

             # Post-process to remove repetition
             response_content = self.remove_repetition(response_content)

+            # Format and print the response
             model_name = f"{self.current_model.upper()} ({name}):"
             formatted_response = f"{model_name}\n{response_content}\n"

             print(colored(formatted_response, color))
-            conversation_log.append({"role": "assistant", "content": formatted_response})
+            conversation_log.append(
+                {"role": "assistant", "content": formatted_response}
+            )

+            # Update conversation history
             messages.append({"role": "assistant", "content": response_content})
             other_messages.append({"role": "user", "content": response_content})

             current_message = response_content

-            # Switching the AI
+            # Switch to the other AI for the next turn
             self.current_model = self.model_2 if active_ai == 1 else self.model_1
             active_ai = 1 if active_ai == 0 else 0

             print(colored("---", "magenta"))
             print()

+            # Check for conversation end condition
             if current_message.strip().endswith("{{end_conversation}}"):
                 print(colored("Conversation ended by the AI.", "green"))
                 break
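The loop's stop condition keys off a literal {{end_conversation}} marker at the end of a reply; anything else keeps the exchange going. A quick illustration with toy strings:

for reply in ["Farewell! {{end_conversation}}", "Tell me more."]:
    if reply.strip().endswith("{{end_conversation}}"):
        print("stop:", reply)
    else:
        print("continue:", reply)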
@@ -92,8 +122,8 @@ class AIConversation:
         print(colored("Conversation ended.", "green"))
         self.save_conversation_log(conversation_log)

     def save_conversation_log(self, messages, filename=None):
+        # Save the conversation log to a file
         if filename is None:
             timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
             filename = f"conversation_log_{timestamp}.txt"
@@ -115,7 +145,7 @@ class AIConversation:
         print(f"Conversation log saved to {filename}")

     def remove_repetition(self, text):
-        # Split the text into sentences
+        # Remove repeated sentences while preserving order
         split_tokens = [".", "!", "?"]
         sentences = []
         current_sentence = ""
@@ -134,4 +164,4 @@ class AIConversation:
             unique_sentences.append(sentence)

         # Join the sentences back together
-        return ' '.join(unique_sentences)
+        return " ".join(unique_sentences)
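Only fragments of remove_repetition appear in these hunks; the visible pieces (split_tokens, current_sentence, unique_sentences, the final join) suggest a split-then-deduplicate pass. A hedged reconstruction of that idea, not the repo's exact code:

def remove_repetition(text):
    # Split on sentence-ending punctuation, keeping the delimiter
    split_tokens = [".", "!", "?"]
    sentences, current_sentence = [], ""
    for ch in text:
        current_sentence += ch
        if ch in split_tokens:
            sentences.append(current_sentence.strip())
            current_sentence = ""
    if current_sentence.strip():
        sentences.append(current_sentence.strip())

    # Keep the first occurrence of each sentence, preserving order
    seen, unique_sentences = set(), []
    for sentence in sentences:
        if sentence not in seen:
            seen.add(sentence)
            unique_sentences.append(sentence)
    return " ".join(unique_sentences)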
main.py (66 lines changed)
@@ -4,56 +4,58 @@ from dotenv import load_dotenv, set_key
 from ai_conversation import AIConversation

 def load_system_prompt(filename):
-    with open(filename, 'r') as file:
+    """Load the system prompt from a file."""
+    with open(filename, "r") as file:
         return file.read().strip()

 def main():
+    # Load environment variables
+    load_dotenv()
+
+    # Retrieve configuration from environment variables
+    ollama_endpoint = os.getenv("OLLAMA_ENDPOINT")
+    model_1 = os.getenv("MODEL_1")
+    model_2 = os.getenv("MODEL_2")
+    system_prompt_1 = load_system_prompt("system_prompt_1.txt")
+    system_prompt_2 = load_system_prompt("system_prompt_2.txt")
+    initial_prompt = os.getenv(
+        "INITIAL_PROMPT",
+        "Let's discuss the future of AI. What are your thoughts on its potential impact on society?",
+    )
+    max_tokens = int(os.getenv("MAX_TOKENS", 4000))
+    print(f"Max tokens: {max_tokens}")
+
+    # Initialize the AI conversation object
+    conversation = AIConversation(
+        model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint, max_tokens
+    )
+
+    # Set up command-line argument parser
     parser = argparse.ArgumentParser(description="AI Conversation")
     parser.add_argument("--cli", action="store_true", help="Run in CLI mode")
-    parser.add_argument("--streamlit", action="store_true", help="Run in Streamlit mode")
+    parser.add_argument(
+        "--streamlit", action="store_true", help="Run in Streamlit mode"
+    )
     args = parser.parse_args()

+    # Run the appropriate interface based on command-line arguments
     if args.cli:
-        run_cli()
+        run_cli(conversation, initial_prompt)
     elif args.streamlit:
-        run_streamlit()
+        run_streamlit(conversation, initial_prompt)
     else:
         print("Please specify either --cli or --streamlit mode.")

-def run_cli():
+def run_cli(conversation, initial_prompt):
+    """Run the conversation in command-line interface mode."""
     load_dotenv()
-
-    ollama_endpoint = os.getenv("OLLAMA_ENDPOINT")
-    model_1 = os.getenv("MODEL_1")
-    model_2 = os.getenv("MODEL_2")
-
-    system_prompt_1_file = os.getenv("CUSTOM_SYSTEM_PROMPT_1", "system_prompt_1.txt")
-    system_prompt_2_file = os.getenv("CUSTOM_SYSTEM_PROMPT_2", "system_prompt_2.txt")
-
-    system_prompt_1 = load_system_prompt(system_prompt_1_file)
-    system_prompt_2 = load_system_prompt(system_prompt_2_file)
-
-    initial_prompt = os.getenv("INITIAL_PROMPT", "Let's discuss the future of AI. What are your thoughts on its potential impact on society?")
-
-    conversation = AIConversation(model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint)
     conversation.start_conversation(initial_prompt, num_exchanges=0)

-def run_streamlit():
+def run_streamlit(conversation, initial_prompt):
+    """Run the conversation in Streamlit interface mode."""
     import streamlit as st
     from streamlit_app import streamlit_interface

-    load_dotenv()
-
-    ollama_endpoint = os.getenv("OLLAMA_ENDPOINT")
-    model_1 = os.getenv("MODEL_1")
-    model_2 = os.getenv("MODEL_2")
-
-    system_prompt_1 = load_system_prompt("system_prompt_1.txt")
-    system_prompt_2 = load_system_prompt("system_prompt_2.txt")
-
-    initial_prompt = os.getenv("INITIAL_PROMPT", "Let's discuss the future of AI. What are your thoughts on its potential impact on society?")
-
-    conversation = AIConversation(ollama_endpoint, model_1, model_2, system_prompt_1, system_prompt_2)
     streamlit_interface(conversation, initial_prompt)

 if __name__ == "__main__":
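Hoisting all environment parsing into main() removes the duplicated os.getenv blocks and, incidentally, fixes the old run_streamlit path, which passed ollama_endpoint as the first argument to AIConversation even though the signature expects it last. A sample .env consistent with the variables read above (values are illustrative):

OLLAMA_ENDPOINT=http://localhost:11434
MODEL_1=llama3
MODEL_2=mistral
INITIAL_PROMPT=Let's discuss the future of AI. What are your thoughts on its potential impact on society?
MAX_TOKENS=4000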
requirements.txt
@@ -3,3 +3,4 @@ requests
 termcolor
 streamlit
 Pillow
+tiktoken
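With tiktoken added to requirements.txt, a fresh pip install -r requirements.txt pulls in the tokenizer that count_tokens relies on.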