import re
import time
from copy import copy
from itertools import cycle
from os import environ
from pathlib import Path

from openai import AzureOpenAI, OpenAI, RateLimitError
from rich import print

from .base_translator import Base

# Environment variables that can override the user / system prompt templates.
PROMPT_ENV_MAP = {
    "user": "BBM_CHATGPTAPI_USER_MSG_TEMPLATE",
    "system": "BBM_CHATGPTAPI_SYS_MSG",
}

GPT35_MODEL_LIST = [
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k-0613",
    "gpt-3.5-turbo-0301",
    "gpt-3.5-turbo-0125",
]

GPT4_MODEL_LIST = [
    "gpt-4-1106-preview",
    "gpt-4",
    "gpt-4-32k",
    "gpt-4-0613",
    "gpt-4-32k-0613",
]


class ChatGPTAPI(Base):
    """Translate text through the OpenAI (or Azure OpenAI) chat-completions API.

    Every request rotates to the next API key and the next model of
    ``self.model_list``.

    NOTE(review): ``self.keys`` and ``self.language`` are never assigned in
    this class; they are presumably set by ``Base.__init__`` -- confirm.
    """

    DEFAULT_PROMPT = "Please help me to translate,`{text}` to {language}, please return only translated content not include the origin text"

    def __init__(
        self,
        key,
        language,
        api_base=None,
        prompt_template=None,
        prompt_sys_msg=None,
        temperature=1.0,
        **kwargs,
    ) -> None:
        """Create the client and pick the GPT-3.5 models available to it.

        Args:
            key: One API key, or several separated by commas.
            language: Target language, forwarded to ``Base``.
            api_base: Optional base URL for the OpenAI client; also reused as
                the Azure endpoint by ``set_deployment_id``.
            prompt_template: User prompt template; falls back to the
                ``BBM_CHATGPTAPI_USER_MSG_TEMPLATE`` environment variable and
                then to ``DEFAULT_PROMPT``.
            prompt_sys_msg: System prompt template; falls back to the
                ``OPENAI_API_SYS_MSG`` and ``BBM_CHATGPTAPI_SYS_MSG``
                environment variables and then to an empty string.
            temperature: Sampling temperature sent with every request.
            **kwargs: Accepted and ignored.
        """
        super().__init__(key, language)
        self.key_len = len(key.split(","))
        self.openai_client = OpenAI(api_key=key, base_url=api_base)
        self.api_base = api_base
        self.prompt_template = (
            prompt_template
            or environ.get(PROMPT_ENV_MAP["user"])
            or self.DEFAULT_PROMPT
        )
        self.prompt_sys_msg = (
            prompt_sys_msg
            or environ.get(
                "OPENAI_API_SYS_MSG",
            )  # XXX: for backward compatibility, deprecate soon
            or environ.get(PROMPT_ENV_MAP["system"])
            or ""
        )
        self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
        self.deployment_id = None
        self.temperature = temperature
        # Rotate over every available GPT-3.5 model to spread the rate limit.
        self.model_list = cycle(self._select_models(GPT35_MODEL_LIST))

    def _select_models(self, candidates):
        """Return the entries of *candidates* that the client reports as available.

        The result keeps the order of *candidates* so that model rotation is
        deterministic from run to run.
        """
        available = {
            entry["id"]
            for entry in self.openai_client.models.list().model_dump()["data"]
        }
        model_list = [name for name in candidates if name in available]
        print(f"Using model list {model_list}")
        return model_list

    def rotate_key(self):
        """Switch the client to the next API key of ``self.keys``."""
        self.openai_client.api_key = next(self.keys)

    def rotate_model(self):
        """Advance ``self.model`` to the next entry of ``self.model_list``.

        Raises:
            RuntimeError: If the model list is empty, i.e. none of the known
                model names was reported as available for this key.
        """
        try:
            self.model = next(self.model_list)
        except StopIteration:
            # A bare StopIteration carries no message; surface the real cause.
            raise RuntimeError(
                "No usable model: none of the known model names is available "
                "for this API key"
            ) from None

    def create_chat_completion(self, text):
        """Send *text* to the chat-completions endpoint and return the raw completion."""
        content = self.prompt_template.format(
            text=text, language=self.language, crlf="\n"
        )
        # The OPENAI_API_SYS_MSG value, when set, wins over the system template.
        sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
        messages = [
            {"role": "system", "content": sys_content},
            {"role": "user", "content": content},
        ]
        completion = self.openai_client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
        )
        return completion

    def get_translation(self, text):
        """Rotate key and model, request a translation and return its text.

        Returns an empty string when the response carries no content.
        """
        self.rotate_key()
        self.rotate_model()  # rotate all the model to avoid the limit

        completion = self.create_chat_completion(text)

        # TODO work well or exception finish by length limit
        content = completion.choices[0].message.content
        if content is None:
            return ""
        return content.encode("utf8").decode() or ""

    def translate(self, text, needprint=True):
        """Translate *text*, retrying on rate-limit errors.

        Args:
            text: Source text.
            needprint: When true, echo the source and the translation.

        Returns:
            The translated text, or ``None`` when a non-rate-limit exception
            occurred (the exception message is printed).

        Raises:
            RateLimitError: After three consecutive rate-limited attempts.
        """
        # todo: Determine whether to print according to the cli option
        if needprint:
            print(re.sub("\n{3,}", "\n\n", text))

        attempt_count = 0
        max_attempts = 3
        t_text = ""

        while attempt_count < max_attempts:
            try:
                t_text = self.get_translation(text)
                break
            except RateLimitError as e:
                # todo: better sleep time? why sleep always about key_len
                # 1. openai server error or own network interruption, sleep for a fixed time
                # 2. an apikey has no money or reach limit, don't sleep, just replace it with another apikey
                # 3. all apikey reach limit, then use current sleep
                sleep_time = int(60 / self.key_len)
                print(e, f"will sleep {sleep_time} seconds")
                time.sleep(sleep_time)
                attempt_count += 1
                if attempt_count == max_attempts:
                    print(f"Get {attempt_count} consecutive exceptions")
                    raise
            except Exception as e:
                print(str(e))
                return None

        # todo: Determine whether to print according to the cli option
        if needprint:
            print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")

        return t_text

    def translate_and_split_lines(self, text):
        """Translate *text* silently and return its non-blank, stripped lines.

        ``translate`` returns ``None`` after a non-rate-limit failure; that is
        treated as an empty translation so the caller can retry instead of
        crashing on ``None.splitlines()``.
        """
        result_str = self.translate(text, False) or ""
        return [line.strip() for line in result_str.splitlines() if line.strip()]

    def get_best_result_list(
        self,
        plist_len,
        new_str,
        sleep_dur,
        result_list,
        max_retries=15,
    ):
        """Re-translate until the line count matches *plist_len* or retries run out.

        Args:
            plist_len: Expected number of translated lines.
            new_str: Text to re-translate on every retry.
            sleep_dur: Seconds to sleep before each retry.
            result_list: Lines produced by the first attempt.
            max_retries: Upper bound on the number of retries.

        Returns:
            ``(best_result_list, retry_count)``; the best list is the one
            whose length is closest to *plist_len* among the attempts.
        """
        if len(result_list) == plist_len:
            return result_list, 0

        best_result_list = result_list
        retry_count = 0

        while retry_count < max_retries and len(result_list) != plist_len:
            print(
                f"bug: {plist_len} -> {len(result_list)} : Number of paragraphs before and after translation",
            )
            print(f"sleep for {sleep_dur}s and retry {retry_count+1} ...")
            time.sleep(sleep_dur)
            retry_count += 1
            result_list = self.translate_and_split_lines(new_str)
            # Keep the new attempt when it is exact, when it grows towards the
            # target without overshooting, or when it shrinks an overshoot.
            if (
                len(result_list) == plist_len
                or len(best_result_list) < len(result_list) <= plist_len
                or (
                    len(result_list) < len(best_result_list)
                    and len(best_result_list) > plist_len
                )
            ):
                best_result_list = result_list

        return best_result_list, retry_count

    def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"):
        """Append a one-line retry summary to *log_path*; no-op when nothing was retried."""
        if retry_count == 0:
            return
        print(f"retry {state}")
        # Appending fails with FileNotFoundError if the directory is missing.
        Path(log_path).parent.mkdir(parents=True, exist_ok=True)
        with open(log_path, "a", encoding="utf-8") as f:
            print(
                f"retry {state}, count = {retry_count}, time = {elapsed_time:.1f}s",
                file=f,
            )

    def log_translation_mismatch(
        self,
        plist_len,
        result_list,
        new_str,
        sep,
        log_path="log/buglog.txt",
    ):
        """Dump source paragraphs next to their translations when the counts differ.

        Does nothing when ``len(result_list) == plist_len``.
        """
        if len(result_list) == plist_len:
            return
        newlist = new_str.split(sep)
        # Appending fails with FileNotFoundError if the directory is missing.
        Path(log_path).parent.mkdir(parents=True, exist_ok=True)
        with open(log_path, "a", encoding="utf-8") as f:
            print(f"problem size: {plist_len - len(result_list)}", file=f)
            for i, source_paragraph in enumerate(newlist):
                print(source_paragraph, file=f)
                print(file=f)
                if i < len(result_list):
                    print("............................................", file=f)
                    print(result_list[i], file=f)
                    print(file=f)
                print("=============================", file=f)

        print(
            f"bug: {plist_len} paragraphs of text translated into {len(result_list)} paragraphs",
        )
        print("continue")

    def join_lines(self, text):
        """Merge consecutive non-blank lines into one space-joined line.

        Blank lines are kept as they are, so paragraph separators survive.
        Returns an empty string for empty input.
        """
        lines = text.splitlines()
        new_lines = []
        temp_line = []

        # join
        for line in lines:
            if line.strip():
                temp_line.append(line.strip())
            else:
                if temp_line:
                    new_lines.append(" ".join(temp_line))
                    temp_line = []
                new_lines.append(line)

        if temp_line:
            new_lines.append(" ".join(temp_line))

        text = "\n".join(new_lines)
        # try to fix #372
        if not text:
            return ""

        # del ^M
        text = text.replace("^M", "\r")
        lines = text.splitlines()
        # NOTE(review): str.strip() already removes "\r", so this filter never
        # drops a line; kept as-is to preserve behavior.
        filtered_lines = [line for line in lines if line.strip() != "\r"]
        new_text = "\n".join(filtered_lines)

        return new_text

    def translate_list(self, plist):
        """Translate a list of paragraph nodes in a single request.

        Each paragraph is copied, stripped of its ``<sup>`` descendants,
        prefixed with ``(n)`` and joined with a five-newline separator. The
        numbering prefix is removed from the translated lines again.

        Args:
            plist: Paragraph nodes exposing ``find_all`` and ``get_text``
                (presumably BeautifulSoup tags -- confirm against the caller).

        Returns:
            The translated lines; the length may differ from ``len(plist)``
            when the model merged or split paragraphs.
        """
        sep = "\n\n\n\n\n"
        retry_sleep_seconds = 6

        numbered = []
        for i, p in enumerate(plist, start=1):
            temp_p = copy(p)
            for sup in temp_p.find_all("sup"):
                sup.extract()
            numbered.append(f"({i}) {temp_p.get_text().strip()}")

        new_str = self.join_lines(sep.join(numbered))

        plist_len = len(plist)

        print(f"plist len = {len(plist)}")

        result_list = self.translate_and_split_lines(new_str)

        start_time = time.time()

        result_list, retry_count = self.get_best_result_list(
            plist_len,
            new_str,
            retry_sleep_seconds,
            result_list,
        )

        end_time = time.time()

        state = "fail" if len(result_list) != plist_len else "success"
        log_path = "log/buglog.txt"

        self.log_retry(state, retry_count, end_time - start_time, log_path)
        self.log_translation_mismatch(plist_len, result_list, new_str, sep, log_path)

        # del (num), num. sometime (num) will translated to num.
        # NOTE(review): the third alternative also strips bare leading digits
        # (e.g. a line starting with a year) -- confirm this is intended.
        result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in result_list]
        return result_list

    def set_deployment_id(self, deployment_id):
        """Switch to an Azure OpenAI client bound to *deployment_id*."""
        self.deployment_id = deployment_id
        self.openai_client = AzureOpenAI(
            api_key=next(self.keys),
            azure_endpoint=self.api_base,
            api_version="2023-07-01-preview",
            azure_deployment=self.deployment_id,
        )

    def set_gpt4_models(self, model="gpt4"):
        """Rotate over the available GPT-4 models instead of the GPT-3.5 ones.

        The *model* argument is accepted for compatibility and is not used.
        """
        self.model_list = cycle(self._select_models(GPT4_MODEL_LIST))