Mirror of https://github.com/tcsenpai/DualMind.git (synced 2025-06-06 18:45:22 +00:00)

commit f211d1fb2a
parent 33743388ad

    add context size to options
ai_conversation.py

@@ -12,6 +12,7 @@ class AIConversation:
         system_prompt_2,
         ollama_endpoint,
         max_tokens=4000,
+        limit_tokens=True
     ):
         # Initialize conversation parameters and Ollama client
         self.model_1 = model_1
@@ -25,14 +26,14 @@ class AIConversation:
         self.ollama_endpoint = ollama_endpoint
         self.tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
         self.max_tokens = max_tokens
+        self.limit_tokens = limit_tokens

     def count_tokens(self, messages):
         # Count the total number of tokens in the messages
         return sum(len(self.tokenizer.encode(msg["content"])) for msg in messages)

     def trim_messages(self, messages):
         # Trim messages to stay within the token limit
-        if self.count_tokens(messages) > self.max_tokens:
+        if self.limit_tokens and self.count_tokens(messages) > self.max_tokens:
             print(colored(f"[SYSTEM] Max tokens reached. Sliding context window...", "magenta"))

             # Keep the system prompt (first message)
@@ -52,6 +53,7 @@ class AIConversation:
         return messages

     def start_conversation(self, initial_message, num_exchanges=0, options=None):

         # Main conversation loop
         current_message = initial_message
         color_1, color_2 = "cyan", "yellow"
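
In short, trimming is now opt-out: when limit_tokens is False, trim_messages leaves the history untouched. A minimal standalone sketch of the gated sliding window follows; the trimming loop sits below the visible hunk, so its exact form is an assumption, and only the names mirror the diff:

import tiktoken

def trim_messages(messages, max_tokens=4000, limit_tokens=True):
    # Sketch: gate the trim on the new flag, then drop the oldest
    # non-system messages until the history fits under max_tokens.
    # (The loop itself is assumed; only the gate appears in the hunk.)
    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")

    def count(msgs):
        return sum(len(tokenizer.encode(m["content"])) for m in msgs)

    if not limit_tokens:
        return messages
    while count(messages) > max_tokens and len(messages) > 2:
        messages.pop(1)  # keep the system prompt (first message)
    return messages
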
main.py (26 changed lines)

@@ -3,15 +3,21 @@ import json
 from dotenv import load_dotenv, set_key
 from ai_conversation import AIConversation


 def load_system_prompt(filename):
     """Load the system prompt from a file."""
     with open(filename, "r") as file:
         return file.read().strip()


-def load_options_from_json(filename):
+def load_options_from_json(filename, max_tokens, limit_tokens):
     """Load options from a JSON file."""
     with open(filename, "r") as file:
-        return json.load(file)
+        options = json.load(file)
+    if limit_tokens:
+        options["num_ctx"] = max_tokens
+    return options


 def main():
     # Load environment variables
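
The effect: options.json keeps only static sampling options, and num_ctx (Ollama's context-window size) is injected from MAX_TOKENS when trimming is enabled. A hypothetical usage, assuming an options.json containing {"temperature": 0.8, "top_p": 0.9}:

options = load_options_from_json("options.json", max_tokens=4000, limit_tokens=True)
print(options)  # {'temperature': 0.8, 'top_p': 0.9, 'num_ctx': 4000}
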
@@ -29,20 +35,28 @@ def main():
     )
     max_tokens = int(os.getenv("MAX_TOKENS", 4000))
     print(f"Max tokens: {max_tokens}")
+    limit_tokens = os.getenv("LIMIT_TOKENS", True)
+    limit_tokens = limit_tokens.lower() == "true"
+    print(f"Limit tokens: {limit_tokens}")
     # Load options from JSON file
-    options = load_options_from_json("options.json")
+    options = load_options_from_json("options.json", max_tokens, limit_tokens)
     print(f"Options: {options}")

     # Initialize the AI conversation object
     conversation = AIConversation(
-        model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint, max_tokens
+        model_1,
+        model_2,
+        system_prompt_1,
+        system_prompt_2,
+        ollama_endpoint,
+        max_tokens,
+        limit_tokens,
     )


     # Run the appropriate interface based on command-line arguments
     run_cli(conversation, initial_prompt, options)


 def run_cli(conversation, initial_prompt, options):
     """Run the conversation in command-line interface mode."""
     load_dotenv()
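
One caveat in the new environment handling: os.getenv("LIMIT_TOKENS", True) returns the bool True when the variable is unset, so the subsequent .lower() call raises AttributeError. A string default sidesteps the crash; a sketch of the safer variant:

import os

# Sketch: default to the string "true" so .lower() always runs on a str,
# whether or not LIMIT_TOKENS is set in the environment.
limit_tokens = os.getenv("LIMIT_TOKENS", "true").lower() == "true"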