diff --git a/ai_conversation.py b/ai_conversation.py index 6a0dd4a..6799c42 100644 --- a/ai_conversation.py +++ b/ai_conversation.py @@ -1,68 +1,61 @@ -from ollama_client import OllamaClient import ollama from termcolor import colored import datetime - class AIConversation: - def __init__( - self, ollama_endpoint, model_1, model_2, system_prompt_1, system_prompt_2 - ): - self.ollama_client = OllamaClient(ollama_endpoint) + def __init__(self, model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint): self.model_1 = model_1 self.model_2 = model_2 self.system_prompt_1 = system_prompt_1 self.system_prompt_2 = system_prompt_2 self.current_model = self.model_1 - self.current_system_prompt = self.system_prompt_1 - + self.messages_1 = [{"role": "system", "content": system_prompt_1}] + self.messages_2 = [{"role": "system", "content": system_prompt_2}] + self.client = ollama.Client(ollama_endpoint) + def start_conversation(self, initial_message, num_exchanges=0): current_message = initial_message - current_model = self.model_1 - current_system_prompt = self.system_prompt_1 color_1 = "cyan" color_2 = "yellow" - messages = [] + conversation_log = [] + + # Appending the initial message to the conversation log in the system prompt + self.messages_1[0]["content"] += f"\n\nInitial message: {current_message}" + self.messages_2[0]["content"] += f"\n\nInitial message: {current_message}" print(colored(f"Starting conversation with: {current_message}", "green")) print(colored("Press CTRL+C to stop the conversation.", "red")) print() - role = "user" - try: i = 0 while num_exchanges == 0 or i < num_exchanges: - response = self.ollama_client.generate( - current_model, current_message, current_system_prompt - ) - - # Adding the name of the model to the response - response = f"{current_model}: {response}" - - model_name = f"{current_model.upper()}:" - formatted_response = f"{model_name}\n{response}\n" - - if current_model == self.model_1: - print(colored(formatted_response, color_1)) + if self.current_model == self.model_1: + messages = self.messages_1 + other_messages = self.messages_2 + color = color_1 else: - print(colored(formatted_response, color_2)) + messages = self.messages_2 + other_messages = self.messages_1 + color = color_2 - messages.append({"role": role, "content": formatted_response}) + messages.append({"role": "user", "content": current_message}) + other_messages.append({"role": "assistant", "content": current_message}) - # Switch roles - if role == "user": - role = "assistant" - else: - role = "user" + response = self.client.chat(model=self.current_model, messages=messages) + response_content = response['message']['content'] - current_message = response - if current_model == self.model_1: - current_model = self.model_2 - current_system_prompt = self.system_prompt_2 - else: - current_model = self.model_1 - current_system_prompt = self.system_prompt_1 + model_name = f"{self.current_model.upper()}:" + formatted_response = f"{model_name}\n{response_content}\n" + + print(colored(formatted_response, color)) + conversation_log.append({"role": "assistant", "content": formatted_response}) + + messages.append({"role": "assistant", "content": response_content}) + other_messages.append({"role": "user", "content": response_content}) + + current_message = response_content + self.current_model = self.model_2 if self.current_model == self.model_1 else self.model_1 print(colored("---", "magenta")) print() @@ -77,41 +70,29 @@ class AIConversation: print(colored("\nConversation stopped by user.", "red")) print(colored("Conversation ended.", "green")) - self.save_conversation_log(messages) - - def stream_conversation(self, current_message): - response = self.ollama_client.generate( - self.current_model, current_message, self.current_system_prompt - ) - - model_name = f"{self.current_model.upper()}:" - formatted_response = f"{model_name}\n{response}\n" - - yield formatted_response - - if self.current_model == self.model_1: - self.current_model = self.model_2 - self.current_system_prompt = self.system_prompt_2 - else: - self.current_model = self.model_1 - self.current_system_prompt = self.system_prompt_1 - - yield "---\n" + self.save_conversation_log(conversation_log) def get_conversation_response(self, current_message): - response = self.ollama_client.generate( - self.current_model, current_message, self.current_system_prompt - ) + if self.current_model == self.model_1: + messages = self.messages_1 + other_messages = self.messages_2 + else: + messages = self.messages_2 + other_messages = self.messages_1 + + messages.append({"role": "user", "content": current_message}) + other_messages.append({"role": "assistant", "content": current_message}) + + response = self.client.chat(model=self.current_model, messages=messages) + response_content = response['message']['content'] model_name = f"{self.current_model.upper()}:" - formatted_response = f"{model_name}\n{response}\n" + formatted_response = f"{model_name}\n{response_content}\n" - if self.current_model == self.model_1: - self.current_model = self.model_2 - self.current_system_prompt = self.system_prompt_2 - else: - self.current_model = self.model_1 - self.current_system_prompt = self.system_prompt_1 + messages.append({"role": "assistant", "content": response_content}) + other_messages.append({"role": "user", "content": response_content}) + + self.current_model = self.model_2 if self.current_model == self.model_1 else self.model_1 return formatted_response @@ -134,4 +115,4 @@ class AIConversation: with open(filename, "w") as f: f.write(log_content) - print(f"Conversation log saved to {filename}") \ No newline at end of file + print(f"Conversation log saved to {filename}") diff --git a/main.py b/main.py index 31ff54c..c560308 100644 --- a/main.py +++ b/main.py @@ -35,7 +35,7 @@ def run_cli(): initial_prompt = os.getenv("INITIAL_PROMPT", "Let's discuss the future of AI. What are your thoughts on its potential impact on society?") - conversation = AIConversation(ollama_endpoint, model_1, model_2, system_prompt_1, system_prompt_2) + conversation = AIConversation(model_1, model_2, system_prompt_1, system_prompt_2, ollama_endpoint) conversation.start_conversation(initial_prompt, num_exchanges=0) def run_streamlit():