From f211d1fb2a474bdeb87a7455780717bbf85e0ed3 Mon Sep 17 00:00:00 2001
From: tcsenpai
Date: Mon, 14 Oct 2024 12:02:09 +0200
Subject: [PATCH] add context size to options

---
 ai_conversation.py |  8 +++++---
 main.py            | 26 ++++++++++++++++++++------
 2 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/ai_conversation.py b/ai_conversation.py
index 7458ab7..fe09853 100644
--- a/ai_conversation.py
+++ b/ai_conversation.py
@@ -11,7 +11,8 @@ class AIConversation:
         system_prompt_1,
         system_prompt_2,
         ollama_endpoint,
-        max_tokens=4000,
+        max_tokens=4000,
+        limit_tokens=True
     ):
         # Initialize conversation parameters and Ollama client
         self.model_1 = model_1
@@ -25,14 +26,14 @@ class AIConversation:
         self.ollama_endpoint = ollama_endpoint
         self.tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
         self.max_tokens = max_tokens
-
+        self.limit_tokens = limit_tokens
     def count_tokens(self, messages):
         # Count the total number of tokens in the messages
         return sum(len(self.tokenizer.encode(msg["content"])) for msg in messages)
 
     def trim_messages(self, messages):
         # Trim messages to stay within the token limit
-        if self.count_tokens(messages) > self.max_tokens:
+        if self.limit_tokens and self.count_tokens(messages) > self.max_tokens:
             print(colored(f"[SYSTEM] Max tokens reached. Sliding context window...", "magenta"))
 
             # Keep the system prompt (first message)
@@ -52,6 +53,7 @@ class AIConversation:
         return messages
 
     def start_conversation(self, initial_message, num_exchanges=0, options=None):
+        # Main conversation loop
         current_message = initial_message
         color_1, color_2 = "cyan", "yellow"
 
diff --git a/main.py b/main.py
index 6a3bfe7..6754ec5 100644
--- a/main.py
+++ b/main.py
@@ -3,15 +3,21 @@ import json
 from dotenv import load_dotenv, set_key
 from ai_conversation import AIConversation
 
+
 def load_system_prompt(filename):
     """Load the system prompt from a file."""
     with open(filename, "r") as file:
         return file.read().strip()
 
-def load_options_from_json(filename):
+
+def load_options_from_json(filename, max_tokens, limit_tokens):
     """Load options from a JSON file."""
     with open(filename, "r") as file:
-        return json.load(file)
+        options = json.load(file)
+    if limit_tokens:
+        options["num_ctx"] = max_tokens
+    return options
+
 
 def main():
     # Load environment variables
@@ -29,20 +35,28 @@ def main():
     )
     max_tokens = int(os.getenv("MAX_TOKENS", 4000))
     print(f"Max tokens: {max_tokens}")
-
+    limit_tokens = os.getenv("LIMIT_TOKENS", "true")
+    limit_tokens = limit_tokens.lower() == "true"
+    print(f"Limit tokens: {limit_tokens}")
 
     # Load options from JSON file
-    options = load_options_from_json("options.json")
+    options = load_options_from_json("options.json", max_tokens, limit_tokens)
     print(f"Options: {options}")
 
     # Initialize the AI conversation object
     conversation = AIConversation(
-        model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint, max_tokens
+        model_1,
+        model_2,
+        system_prompt_1,
+        system_prompt_2,
+        ollama_endpoint,
+        max_tokens,
+        limit_tokens,
     )
 
-    # Run the appropriate interface based on command-line arguments
     run_cli(conversation, initial_prompt, options)
 
+
def run_cli(conversation, initial_prompt, options):
     """Run the conversation in command-line interface mode."""
     load_dotenv()
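
Below is a minimal usage sketch, not part of the patch, of how the new settings flow into an Ollama request. It assumes the ollama Python client, a .env providing MAX_TOKENS=4000 and LIMIT_TOKENS=true, and an options.json holding ordinary sampling options; the endpoint, model name, and file contents are illustrative, not taken from the repository.

import json
import os

import ollama

# Hypothetical inputs: .env sets MAX_TOKENS=4000 and LIMIT_TOKENS=true,
# and options.json contains {"temperature": 0.8}.
max_tokens = int(os.getenv("MAX_TOKENS", 4000))
limit_tokens = os.getenv("LIMIT_TOKENS", "true").lower() == "true"

with open("options.json", "r") as file:
    options = json.load(file)
if limit_tokens:
    options["num_ctx"] = max_tokens  # pin the context window, as load_options_from_json does

client = ollama.Client(host="http://localhost:11434")
response = client.chat(
    model="llama3",  # hypothetical model name
    messages=[{"role": "user", "content": "Hello!"}],
    options=options,  # num_ctx now caps the context size server-side
)
print(response["message"]["content"])

With LIMIT_TOKENS=true, every chat call is issued with num_ctx equal to MAX_TOKENS, matching the sliding-window trim done client-side in trim_messages; with LIMIT_TOKENS=false, options.json is passed through unchanged and the conversation history is never trimmed.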