From 648e03ae28b60852ddab2d3b755774fa6b8ab7f6 Mon Sep 17 00:00:00 2001
From: tcsenpai
Date: Sun, 6 Oct 2024 22:59:55 +0200
Subject: [PATCH] Added token count and trimmer
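
Count each context's size with tiktoken before every exchange and trim
the oldest non-system messages once the total exceeds a configurable
budget, so the system prompt always survives. The limit defaults to
4000 tokens and can be overridden via the MAX_TOKENS environment
variable. main() now builds a single AIConversation instance and hands
it to both the CLI and Streamlit entry points; the stale
ollama_client.py entry is dropped from the README.

Roughly, the trimming behaves like the sketch below (values and helper
names are illustrative, not part of the patch; tiktoken's gpt-3.5-turbo
encoding only approximates the tokenizers of Ollama-served models):

    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    messages = [
        {"role": "system", "content": "You are AI 1."},  # never trimmed
        {"role": "user", "content": "an old turn " * 2000},  # oversized history
        {"role": "user", "content": "the most recent turn"},
    ]

    def count(msgs):
        return sum(len(enc.encode(m["content"])) for m in msgs)

    while count(messages) > 4000 and len(messages) > 1:
        messages.pop(1)  # drop the oldest non-system message

    # Only the system prompt and the newest turn fit the budget now.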
---
 README.md          |  1 -
 ai_conversation.py | 86 +++++++++++++++++++++++++++++++---------------
 main.py            | 68 ++++++++++++++++++------------------
 requirements.txt   |  3 +-
 4 files changed, 95 insertions(+), 63 deletions(-)

diff --git a/README.md b/README.md
index 1a4734b..0041eb8 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,6 @@ The appearance of the Streamlit interface can be customized by modifying the `st
 
 - `main.py`: Entry point of the application
 - `ai_conversation.py`: Core logic for AI conversations
-- `ollama_client.py`: Client for interacting with the Ollama API
 - `streamlit_app.py`: Streamlit web interface implementation
 - `style/custom.css`: Custom styles for the web interface
 - `run_cli.sh`: Shell script to run the CLI version
diff --git a/ai_conversation.py b/ai_conversation.py
index dc13657..f25b96a 100644
--- a/ai_conversation.py
+++ b/ai_conversation.py
@@ -1,9 +1,19 @@
 import ollama
 from termcolor import colored
 import datetime
+import tiktoken  # Used for token counting
 
 class AIConversation:
-    def __init__(self, model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint):
+    def __init__(
+        self,
+        model_1,
+        model_2,
+        system_prompt_1,
+        system_prompt_2,
+        ollama_endpoint,
+        max_tokens=4000,
+    ):
+        # Initialize conversation parameters and Ollama client
         self.model_1 = model_1
         self.model_2 = model_2
         self.system_prompt_1 = system_prompt_1
@@ -13,14 +23,31 @@ class AIConversation:
         self.messages_2 = [{"role": "system", "content": system_prompt_2}]
         self.client = ollama.Client(ollama_endpoint)
         self.ollama_endpoint = ollama_endpoint
+        self.tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+        self.max_tokens = max_tokens
+
+    def count_tokens(self, messages):
+        # Count the total number of tokens in the messages
+        return sum(len(self.tokenizer.encode(msg["content"])) for msg in messages)
+
+    def trim_messages(self, messages):
+        # Trim messages to stay within the token limit
+        if self.count_tokens(messages) > self.max_tokens:
+            print(colored("[SYSTEM] Max tokens reached. Trimming messages...", "magenta"))
+            while self.count_tokens(messages) > self.max_tokens:
+                if len(messages) > 1:
+                    messages.pop(1)  # Remove the oldest non-system message
+                else:
+                    break  # Avoid removing the system message
+        return messages
 
     def start_conversation(self, initial_message, num_exchanges=0):
+        # Main conversation loop
         current_message = initial_message
-        color_1 = "cyan"
-        color_2 = "yellow"
+        color_1, color_2 = "cyan", "yellow"
         conversation_log = []
 
-        # Appending the initial message to the conversation log in the system prompt
+        # Add initial message to system prompts
         self.messages_1[0]["content"] += f"\n\nInitial message: {current_message}"
         self.messages_2[0]["content"] += f"\n\nInitial message: {current_message}"
 
@@ -30,56 +57,59 @@ class AIConversation:
 
         try:
             i = 0
-            active_ai = 1 # Starting with AI 1
+            active_ai = 1  # Starting with AI 1
             while num_exchanges == 0 or i < num_exchanges:
-
-
-                if active_ai == 0:
-                    name = "AI 1"
-                    messages = self.messages_1
-                    other_messages = self.messages_2
-                    color = color_1
-                else:
-                    name = "AI 2"
-                    messages = self.messages_2
-                    other_messages = self.messages_1
-                    color = color_2
+                # Set up current AI's parameters
+                name = "AI 1" if active_ai == 0 else "AI 2"
+                messages = self.messages_1 if active_ai == 0 else self.messages_2
+                other_messages = self.messages_2 if active_ai == 0 else self.messages_1
+                color = color_1 if active_ai == 0 else color_2
 
+                # Add user message to conversation history
                 messages.append({"role": "user", "content": current_message})
                 other_messages.append({"role": "assistant", "content": current_message})
 
-                #print(colored(f"Conversation with {name} ({self.current_model})", "blue"))
+                # Trim messages and get token count
+                messages = self.trim_messages(messages)
+                token_count = self.count_tokens(messages)
+                print(colored(f"Context token count: {token_count}", "magenta"))
+
+                # Generate AI response
                 response = self.client.chat(
-                    model=self.current_model, 
+                    model=self.current_model,
                     messages=messages,
                     options={
-                        "temperature": 0.7, # Adjust this value to control randomness
+                        "temperature": 0.7,  # Control randomness
                         "repeat_penalty": 1.2,  # Penalize repetition
-                    }
+                    },
                 )
-                response_content = response['message']['content']
+                response_content = response["message"]["content"]
 
                 # Post-process to remove repetition
                 response_content = self.remove_repetition(response_content)
 
+                # Format and print the response
                 model_name = f"{self.current_model.upper()} ({name}):"
                 formatted_response = f"{model_name}\n{response_content}\n"
-
                 print(colored(formatted_response, color))
-                conversation_log.append({"role": "assistant", "content": formatted_response})
+                conversation_log.append(
+                    {"role": "assistant", "content": formatted_response}
+                )
 
+                # Update conversation history
                 messages.append({"role": "assistant", "content": response_content})
                 other_messages.append({"role": "user", "content": response_content})
 
                 current_message = response_content
 
-                # Switching the AI
+                # Switch to the other AI for the next turn
                 self.current_model = self.model_2 if active_ai == 1 else self.model_1
                 active_ai = 1 if active_ai == 0 else 0
 
                 print(colored("---", "magenta"))
                 print()
 
+                # Check for conversation end condition
                 if current_message.strip().endswith("{{end_conversation}}"):
                     print(colored("Conversation ended by the AI.", "green"))
                     break
@@ -92,8 +122,8 @@ class AIConversation:
         print(colored("Conversation ended.", "green"))
         self.save_conversation_log(conversation_log)
 
-
     def save_conversation_log(self, messages, filename=None):
+        # Save the conversation log to a file
         if filename is None:
             timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
             filename = f"conversation_log_{timestamp}.txt"
@@ -115,7 +145,7 @@ class AIConversation:
         print(f"Conversation log saved to {filename}")
 
     def remove_repetition(self, text):
-        # Split the text into sentences
+        # Remove repeated sentences while preserving order
         split_tokens = [".", "!", "?"]
         sentences = []
         current_sentence = ""
@@ -134,4 +164,4 @@ class AIConversation:
                 unique_sentences.append(sentence)
 
         # Join the sentences back together
-        return ' '.join(unique_sentences)
+        return " ".join(unique_sentences)
\ No newline at end of file
diff --git a/main.py b/main.py
index c560308..e1b81fc 100644
--- a/main.py
+++ b/main.py
@@ -4,57 +4,59 @@ from dotenv import load_dotenv, set_key
 from ai_conversation import AIConversation
 
 
 def load_system_prompt(filename):
-    with open(filename, 'r') as file:
+    """Load the system prompt from a file."""
+    with open(filename, "r") as file:
         return file.read().strip()
 
 
 def main():
+    # Load environment variables
+    load_dotenv()
+
+    # Retrieve configuration from environment variables
+    ollama_endpoint = os.getenv("OLLAMA_ENDPOINT")
+    model_1 = os.getenv("MODEL_1")
+    model_2 = os.getenv("MODEL_2")
+    system_prompt_1 = load_system_prompt("system_prompt_1.txt")
+    system_prompt_2 = load_system_prompt("system_prompt_2.txt")
+    initial_prompt = os.getenv(
+        "INITIAL_PROMPT",
+        "Let's discuss the future of AI. What are your thoughts on its potential impact on society?",
+    )
+    max_tokens = int(os.getenv("MAX_TOKENS", 4000))
+    print(f"Max tokens: {max_tokens}")
+
+    # Initialize the AI conversation object
+    conversation = AIConversation(
+        model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint, max_tokens
+    )
+
+    # Set up command-line argument parser
     parser = argparse.ArgumentParser(description="AI Conversation")
     parser.add_argument("--cli", action="store_true", help="Run in CLI mode")
-    parser.add_argument("--streamlit", action="store_true", help="Run in Streamlit mode")
+    parser.add_argument(
+        "--streamlit", action="store_true", help="Run in Streamlit mode"
+    )
     args = parser.parse_args()
 
+    # Run the appropriate interface based on command-line arguments
     if args.cli:
-        run_cli()
+        run_cli(conversation, initial_prompt)
     elif args.streamlit:
-        run_streamlit()
+        run_streamlit(conversation, initial_prompt)
     else:
         print("Please specify either --cli or --streamlit mode.")
 
 
-def run_cli():
+def run_cli(conversation, initial_prompt):
+    """Run the conversation in command-line interface mode."""
     load_dotenv()
-
-    ollama_endpoint = os.getenv("OLLAMA_ENDPOINT")
-    model_1 = os.getenv("MODEL_1")
-    model_2 = os.getenv("MODEL_2")
-
-    system_prompt_1_file = os.getenv("CUSTOM_SYSTEM_PROMPT_1", "system_prompt_1.txt")
-    system_prompt_2_file = os.getenv("CUSTOM_SYSTEM_PROMPT_2", "system_prompt_2.txt")
-
-    system_prompt_1 = load_system_prompt(system_prompt_1_file)
-    system_prompt_2 = load_system_prompt(system_prompt_2_file)
-
-    initial_prompt = os.getenv("INITIAL_PROMPT", "Let's discuss the future of AI. What are your thoughts on its potential impact on society?")
-
-    conversation = AIConversation(model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint)
     conversation.start_conversation(initial_prompt, num_exchanges=0)
 
 
-def run_streamlit():
+def run_streamlit(conversation, initial_prompt):
+    """Run the conversation in Streamlit interface mode."""
     import streamlit as st
     from streamlit_app import streamlit_interface
 
-    load_dotenv()
-
-    ollama_endpoint = os.getenv("OLLAMA_ENDPOINT")
-    model_1 = os.getenv("MODEL_1")
-    model_2 = os.getenv("MODEL_2")
-
-    system_prompt_1 = load_system_prompt("system_prompt_1.txt")
-    system_prompt_2 = load_system_prompt("system_prompt_2.txt")
-
-    initial_prompt = os.getenv("INITIAL_PROMPT", "Let's discuss the future of AI. What are your thoughts on its potential impact on society?")
-
-    conversation = AIConversation(ollama_endpoint, model_1, model_2, system_prompt_1, system_prompt_2)
     streamlit_interface(conversation, initial_prompt)
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/requirements.txt b/requirements.txt
index 39be983..4abe148 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ python-dotenv
 requests
 termcolor
 streamlit
-Pillow
\ No newline at end of file
+Pillow
+tiktoken
\ No newline at end of file