switched to the chat method, implementing context

This commit is contained in:
tcsenpai 2024-10-06 20:37:32 +02:00
parent 77bd6d43ef
commit 5d865e90f0
2 changed files with 52 additions and 71 deletions

View File

@ -1,68 +1,61 @@
from ollama_client import OllamaClient
import ollama import ollama
from termcolor import colored from termcolor import colored
import datetime import datetime
class AIConversation: class AIConversation:
def __init__( def __init__(self, model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint):
self, ollama_endpoint, model_1, model_2, system_prompt_1, system_prompt_2
):
self.ollama_client = OllamaClient(ollama_endpoint)
self.model_1 = model_1 self.model_1 = model_1
self.model_2 = model_2 self.model_2 = model_2
self.system_prompt_1 = system_prompt_1 self.system_prompt_1 = system_prompt_1
self.system_prompt_2 = system_prompt_2 self.system_prompt_2 = system_prompt_2
self.current_model = self.model_1 self.current_model = self.model_1
self.current_system_prompt = self.system_prompt_1 self.messages_1 = [{"role": "system", "content": system_prompt_1}]
self.messages_2 = [{"role": "system", "content": system_prompt_2}]
self.client = ollama.Client(ollama_endpoint)
def start_conversation(self, initial_message, num_exchanges=0): def start_conversation(self, initial_message, num_exchanges=0):
current_message = initial_message current_message = initial_message
current_model = self.model_1
current_system_prompt = self.system_prompt_1
color_1 = "cyan" color_1 = "cyan"
color_2 = "yellow" color_2 = "yellow"
messages = [] conversation_log = []
# Appending the initial message to the conversation log in the system prompt
self.messages_1[0]["content"] += f"\n\nInitial message: {current_message}"
self.messages_2[0]["content"] += f"\n\nInitial message: {current_message}"
print(colored(f"Starting conversation with: {current_message}", "green")) print(colored(f"Starting conversation with: {current_message}", "green"))
print(colored("Press CTRL+C to stop the conversation.", "red")) print(colored("Press CTRL+C to stop the conversation.", "red"))
print() print()
role = "user"
try: try:
i = 0 i = 0
while num_exchanges == 0 or i < num_exchanges: while num_exchanges == 0 or i < num_exchanges:
response = self.ollama_client.generate( if self.current_model == self.model_1:
current_model, current_message, current_system_prompt messages = self.messages_1
) other_messages = self.messages_2
color = color_1
# Adding the name of the model to the response
response = f"{current_model}: {response}"
model_name = f"{current_model.upper()}:"
formatted_response = f"{model_name}\n{response}\n"
if current_model == self.model_1:
print(colored(formatted_response, color_1))
else: else:
print(colored(formatted_response, color_2)) messages = self.messages_2
other_messages = self.messages_1
color = color_2
messages.append({"role": role, "content": formatted_response}) messages.append({"role": "user", "content": current_message})
other_messages.append({"role": "assistant", "content": current_message})
# Switch roles response = self.client.chat(model=self.current_model, messages=messages)
if role == "user": response_content = response['message']['content']
role = "assistant"
else:
role = "user"
current_message = response model_name = f"{self.current_model.upper()}:"
if current_model == self.model_1: formatted_response = f"{model_name}\n{response_content}\n"
current_model = self.model_2
current_system_prompt = self.system_prompt_2 print(colored(formatted_response, color))
else: conversation_log.append({"role": "assistant", "content": formatted_response})
current_model = self.model_1
current_system_prompt = self.system_prompt_1 messages.append({"role": "assistant", "content": response_content})
other_messages.append({"role": "user", "content": response_content})
current_message = response_content
self.current_model = self.model_2 if self.current_model == self.model_1 else self.model_1
print(colored("---", "magenta")) print(colored("---", "magenta"))
print() print()
@ -77,41 +70,29 @@ class AIConversation:
print(colored("\nConversation stopped by user.", "red")) print(colored("\nConversation stopped by user.", "red"))
print(colored("Conversation ended.", "green")) print(colored("Conversation ended.", "green"))
self.save_conversation_log(messages) self.save_conversation_log(conversation_log)
def stream_conversation(self, current_message):
response = self.ollama_client.generate(
self.current_model, current_message, self.current_system_prompt
)
model_name = f"{self.current_model.upper()}:"
formatted_response = f"{model_name}\n{response}\n"
yield formatted_response
if self.current_model == self.model_1:
self.current_model = self.model_2
self.current_system_prompt = self.system_prompt_2
else:
self.current_model = self.model_1
self.current_system_prompt = self.system_prompt_1
yield "---\n"
def get_conversation_response(self, current_message): def get_conversation_response(self, current_message):
response = self.ollama_client.generate( if self.current_model == self.model_1:
self.current_model, current_message, self.current_system_prompt messages = self.messages_1
) other_messages = self.messages_2
else:
messages = self.messages_2
other_messages = self.messages_1
messages.append({"role": "user", "content": current_message})
other_messages.append({"role": "assistant", "content": current_message})
response = self.client.chat(model=self.current_model, messages=messages)
response_content = response['message']['content']
model_name = f"{self.current_model.upper()}:" model_name = f"{self.current_model.upper()}:"
formatted_response = f"{model_name}\n{response}\n" formatted_response = f"{model_name}\n{response_content}\n"
if self.current_model == self.model_1: messages.append({"role": "assistant", "content": response_content})
self.current_model = self.model_2 other_messages.append({"role": "user", "content": response_content})
self.current_system_prompt = self.system_prompt_2
else: self.current_model = self.model_2 if self.current_model == self.model_1 else self.model_1
self.current_model = self.model_1
self.current_system_prompt = self.system_prompt_1
return formatted_response return formatted_response
@ -134,4 +115,4 @@ class AIConversation:
with open(filename, "w") as f: with open(filename, "w") as f:
f.write(log_content) f.write(log_content)
print(f"Conversation log saved to {filename}") print(f"Conversation log saved to {filename}")

View File

@ -35,7 +35,7 @@ def run_cli():
initial_prompt = os.getenv("INITIAL_PROMPT", "Let's discuss the future of AI. What are your thoughts on its potential impact on society?") initial_prompt = os.getenv("INITIAL_PROMPT", "Let's discuss the future of AI. What are your thoughts on its potential impact on society?")
conversation = AIConversation(ollama_endpoint, model_1, model_2, system_prompt_1, system_prompt_2) conversation = AIConversation(model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint)
conversation.start_conversation(initial_prompt, num_exchanges=0) conversation.start_conversation(initial_prompt, num_exchanges=0)
def run_streamlit(): def run_streamlit():