switched to the chat method, implementing context

tcsenpai 2024-10-06 20:37:32 +02:00
parent 77bd6d43ef
commit 5d865e90f0
2 changed files with 52 additions and 71 deletions
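The gist of the change: the old code made a one-shot generate() call for every turn, so each model only ever saw the latest message. The ollama chat API instead takes the full message history on every call, which is what gives each model context. A minimal sketch of the pattern, assuming a local Ollama endpoint and a placeholder model name:

import ollama

# Assumed local endpoint and model name; adjust to your setup.
client = ollama.Client("http://localhost:11434")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello! Who are you?"},
]

# chat() receives the whole history, so earlier turns act as context.
response = client.chat(model="llama3.1", messages=messages)
print(response["message"]["content"])

# Keep the reply in the history so the next call sees it too.
messages.append({"role": "assistant", "content": response["message"]["content"]})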

View File

@@ -1,68 +1,61 @@
-from ollama_client import OllamaClient
+import ollama
 from termcolor import colored
 import datetime


 class AIConversation:
-    def __init__(
-        self, ollama_endpoint, model_1, model_2, system_prompt_1, system_prompt_2
-    ):
-        self.ollama_client = OllamaClient(ollama_endpoint)
+    def __init__(self, model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint):
         self.model_1 = model_1
         self.model_2 = model_2
         self.system_prompt_1 = system_prompt_1
         self.system_prompt_2 = system_prompt_2
         self.current_model = self.model_1
         self.current_system_prompt = self.system_prompt_1
+        self.messages_1 = [{"role": "system", "content": system_prompt_1}]
+        self.messages_2 = [{"role": "system", "content": system_prompt_2}]
+        self.client = ollama.Client(ollama_endpoint)

     def start_conversation(self, initial_message, num_exchanges=0):
         current_message = initial_message
-        current_model = self.model_1
-        current_system_prompt = self.system_prompt_1
         color_1 = "cyan"
         color_2 = "yellow"
-        messages = []
+        conversation_log = []
+
+        # Appending the initial message to the conversation log in the system prompt
+        self.messages_1[0]["content"] += f"\n\nInitial message: {current_message}"
+        self.messages_2[0]["content"] += f"\n\nInitial message: {current_message}"

         print(colored(f"Starting conversation with: {current_message}", "green"))
         print(colored("Press CTRL+C to stop the conversation.", "red"))
         print()

-        role = "user"
-
         try:
             i = 0
             while num_exchanges == 0 or i < num_exchanges:
-                response = self.ollama_client.generate(
-                    current_model, current_message, current_system_prompt
-                )
-
-                # Adding the name of the model to the response
-                response = f"{current_model}: {response}"
-
-                model_name = f"{current_model.upper()}:"
-                formatted_response = f"{model_name}\n{response}\n"
-
-                if current_model == self.model_1:
-                    print(colored(formatted_response, color_1))
+                if self.current_model == self.model_1:
+                    messages = self.messages_1
+                    other_messages = self.messages_2
+                    color = color_1
                 else:
-                    print(colored(formatted_response, color_2))
+                    messages = self.messages_2
+                    other_messages = self.messages_1
+                    color = color_2

-                messages.append({"role": role, "content": formatted_response})
+                messages.append({"role": "user", "content": current_message})
+                other_messages.append({"role": "assistant", "content": current_message})

-                # Switch roles
-                if role == "user":
-                    role = "assistant"
-                else:
-                    role = "user"
+                response = self.client.chat(model=self.current_model, messages=messages)
+                response_content = response['message']['content']

-                current_message = response
-                if current_model == self.model_1:
-                    current_model = self.model_2
-                    current_system_prompt = self.system_prompt_2
-                else:
-                    current_model = self.model_1
-                    current_system_prompt = self.system_prompt_1
+                model_name = f"{self.current_model.upper()}:"
+                formatted_response = f"{model_name}\n{response_content}\n"
+                print(colored(formatted_response, color))
+
+                conversation_log.append({"role": "assistant", "content": formatted_response})
+                messages.append({"role": "assistant", "content": response_content})
+                other_messages.append({"role": "user", "content": response_content})
+
+                current_message = response_content
+                self.current_model = self.model_2 if self.current_model == self.model_1 else self.model_1

                 print(colored("---", "magenta"))
                 print()
@@ -77,41 +70,29 @@ class AIConversation:
             print(colored("\nConversation stopped by user.", "red"))

         print(colored("Conversation ended.", "green"))
-        self.save_conversation_log(messages)
-
-    def stream_conversation(self, current_message):
-        response = self.ollama_client.generate(
-            self.current_model, current_message, self.current_system_prompt
-        )
-        model_name = f"{self.current_model.upper()}:"
-        formatted_response = f"{model_name}\n{response}\n"
-        yield formatted_response
-
-        if self.current_model == self.model_1:
-            self.current_model = self.model_2
-            self.current_system_prompt = self.system_prompt_2
-        else:
-            self.current_model = self.model_1
-            self.current_system_prompt = self.system_prompt_1
-
-        yield "---\n"
+        self.save_conversation_log(conversation_log)

     def get_conversation_response(self, current_message):
-        response = self.ollama_client.generate(
-            self.current_model, current_message, self.current_system_prompt
-        )
+        if self.current_model == self.model_1:
+            messages = self.messages_1
+            other_messages = self.messages_2
+        else:
+            messages = self.messages_2
+            other_messages = self.messages_1
+
+        messages.append({"role": "user", "content": current_message})
+        other_messages.append({"role": "assistant", "content": current_message})
+
+        response = self.client.chat(model=self.current_model, messages=messages)
+        response_content = response['message']['content']
+
         model_name = f"{self.current_model.upper()}:"
-        formatted_response = f"{model_name}\n{response}\n"
+        formatted_response = f"{model_name}\n{response_content}\n"

-        if self.current_model == self.model_1:
-            self.current_model = self.model_2
-            self.current_system_prompt = self.system_prompt_2
-        else:
-            self.current_model = self.model_1
-            self.current_system_prompt = self.system_prompt_1
+        messages.append({"role": "assistant", "content": response_content})
+        other_messages.append({"role": "user", "content": response_content})
+        self.current_model = self.model_2 if self.current_model == self.model_1 else self.model_1

         return formatted_response
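Worth noting about the new bookkeeping above: each model keeps its own history, and every reply is written to both histories with mirrored roles. The speaker records its reply as "assistant" while the other model receives the same text as "user", so each side sees the exchange as a normal dialogue from its own point of view. A hypothetical condensed sketch of just that mirroring:

# Hypothetical helper mirroring the cross-append used in the diff above.
messages_1 = [{"role": "system", "content": "system prompt for model 1"}]
messages_2 = [{"role": "system", "content": "system prompt for model 2"}]

def record_turn(speaker_history, listener_history, reply):
    # The speaker sees its own words as "assistant"...
    speaker_history.append({"role": "assistant", "content": reply})
    # ...while the listener gets the same text as incoming "user" input.
    listener_history.append({"role": "user", "content": reply})

record_turn(messages_1, messages_2, "Opening argument from model 1")
record_turn(messages_2, messages_1, "Counterpoint from model 2")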

View File

@@ -35,7 +35,7 @@ def run_cli():
     initial_prompt = os.getenv("INITIAL_PROMPT", "Let's discuss the future of AI. What are your thoughts on its potential impact on society?")

-    conversation = AIConversation(ollama_endpoint, model_1, model_2, system_prompt_1, system_prompt_2)
+    conversation = AIConversation(model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint)
     conversation.start_conversation(initial_prompt, num_exchanges=0)


 def run_streamlit():