add context size to options

tcsenpai 2024-10-14 12:02:09 +02:00
parent 33743388ad
commit f211d1fb2a
2 changed files with 25 additions and 9 deletions

ai_conversation.py

@@ -12,6 +12,7 @@ class AIConversation:
         system_prompt_2,
         ollama_endpoint,
         max_tokens=4000,
+        limit_tokens=True,
     ):
         # Initialize conversation parameters and Ollama client
         self.model_1 = model_1
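The constructor's new limit_tokens flag defaults to True, so existing callers keep the current trimming behavior; passing False opts out. A minimal instantiation sketch (parameter names are from the diff; the model names and endpoint are hypothetical):

conversation = AIConversation(
    model_1="llama3",                          # hypothetical model name
    model_2="mistral",                         # hypothetical model name
    system_prompt_1="You are agent A.",
    system_prompt_2="You are agent B.",
    ollama_endpoint="http://localhost:11434",  # Ollama's default local address
    max_tokens=4000,
    limit_tokens=False,                        # opt out of context trimming
)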
@@ -25,14 +26,14 @@ class AIConversation:
         self.ollama_endpoint = ollama_endpoint
         self.tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
         self.max_tokens = max_tokens
+        self.limit_tokens = limit_tokens

     def count_tokens(self, messages):
         # Count the total number of tokens in the messages
         return sum(len(self.tokenizer.encode(msg["content"])) for msg in messages)

     def trim_messages(self, messages):
         # Trim messages to stay within the token limit
-        if self.count_tokens(messages) > self.max_tokens:
+        if self.limit_tokens and self.count_tokens(messages) > self.max_tokens:
             print(colored(f"[SYSTEM] Max tokens reached. Sliding context window...", "magenta"))

             # Keep the system prompt (first message)

@@ -52,6 +53,7 @@ class AIConversation:
         return messages

     def start_conversation(self, initial_message, num_exchanges=0, options=None):
+        # Main conversation loop
         current_message = initial_message
         color_1, color_2 = "cyan", "yellow"
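The hunk elides the body of the sliding window itself. A self-contained sketch of the trimming logic, assuming messages[0] holds the system prompt that must survive every trim (the module-level helpers here are illustrative stand-ins for the class methods):

import tiktoken

tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")

def count_tokens(messages):
    # Total tokens across all message contents
    return sum(len(tokenizer.encode(m["content"])) for m in messages)

def trim_messages(messages, max_tokens=4000, limit_tokens=True):
    # With the new flag, trimming only happens when limit_tokens is set
    if limit_tokens and count_tokens(messages) > max_tokens:
        system_prompt, rest = messages[0], messages[1:]
        # Drop the oldest non-system messages until the total fits
        while rest and count_tokens([system_prompt] + rest) > max_tokens:
            rest.pop(0)
        messages = [system_prompt] + rest
    return messages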

main.py

@@ -3,15 +3,21 @@ import json
 from dotenv import load_dotenv, set_key
 from ai_conversation import AIConversation


 def load_system_prompt(filename):
     """Load the system prompt from a file."""
     with open(filename, "r") as file:
         return file.read().strip()


-def load_options_from_json(filename):
+def load_options_from_json(filename, max_tokens, limit_tokens):
     """Load options from a JSON file."""
     with open(filename, "r") as file:
-        return json.load(file)
+        options = json.load(file)
+    if limit_tokens:
+        options["num_ctx"] = max_tokens
+    return options


 def main():
     # Load environment variables
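num_ctx is Ollama's context-window size option, so injecting it here keeps the server-side context aligned with the client-side trimming budget. A usage sketch (the options.json contents are hypothetical):

# Assuming options.json contains {"temperature": 0.7, "top_p": 0.9}
options = load_options_from_json("options.json", max_tokens=4000, limit_tokens=True)
print(options)  # {'temperature': 0.7, 'top_p': 0.9, 'num_ctx': 4000}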
@@ -29,20 +35,28 @@ def main():
     )

     max_tokens = int(os.getenv("MAX_TOKENS", 4000))
     print(f"Max tokens: {max_tokens}")
+    limit_tokens = os.getenv("LIMIT_TOKENS", "true")
+    limit_tokens = limit_tokens.lower() == "true"
+    print(f"Limit tokens: {limit_tokens}")

     # Load options from JSON file
-    options = load_options_from_json("options.json")
+    options = load_options_from_json("options.json", max_tokens, limit_tokens)
     print(f"Options: {options}")

     # Initialize the AI conversation object
     conversation = AIConversation(
-        model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint, max_tokens
+        model_1,
+        model_2,
+        system_prompt_1,
+        system_prompt_2,
+        ollama_endpoint,
+        max_tokens,
+        limit_tokens,
     )

     # Run the appropriate interface based on command-line arguments
     run_cli(conversation, initial_prompt, options)


 def run_cli(conversation, initial_prompt, options):
     """Run the conversation in command-line interface mode."""
     load_dotenv()
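The diff does not show where options is ultimately consumed, but with the official ollama Python client the merged dict would typically be forwarded on each chat call, roughly like this sketch (endpoint and model name are hypothetical, not the repo's code):

import ollama

client = ollama.Client(host="http://localhost:11434")  # hypothetical endpoint
response = client.chat(
    model="llama3",  # hypothetical model
    messages=[{"role": "user", "content": "Hello"}],
    options=options,  # carries num_ctx when limit_tokens is enabled
)
print(response["message"]["content"])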