diff --git a/book_maker/translator/chatgptapi_translator.py b/book_maker/translator/chatgptapi_translator.py index 47fbba6..2dc9f14 100644 --- a/book_maker/translator/chatgptapi_translator.py +++ b/book_maker/translator/chatgptapi_translator.py @@ -230,40 +230,6 @@ class ChatGPTAPI(Base): lines = [line.strip() for line in lines if line.strip() != ""] return lines - def get_best_result_list( - self, - plist_len, - new_str, - sleep_dur, - result_list, - max_retries=15, - ): - if len(result_list) == plist_len: - return result_list, 0 - - best_result_list = result_list - retry_count = 0 - - while retry_count < max_retries and len(result_list) != plist_len: - print( - f"bug: {plist_len} -> {len(result_list)} : Number of paragraphs before and after translation", - ) - print(f"sleep for {sleep_dur}s and retry {retry_count+1} ...") - time.sleep(sleep_dur) - retry_count += 1 - result_list = self.translate_and_split_lines(new_str) - if ( - len(result_list) == plist_len - or len(best_result_list) < len(result_list) <= plist_len - or ( - len(result_list) < len(best_result_list) - and len(best_result_list) > plist_len - ) - ): - best_result_list = result_list - - return best_result_list, retry_count - def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"): if retry_count == 0: return @@ -333,48 +299,131 @@ class ChatGPTAPI(Base): return new_text def translate_list(self, plist): - sep = "\n\n\n\n\n" - # new_str = sep.join([item.text for item in plist]) + plist_len = len(plist) - new_str = "" - i = 1 - for p in plist: + # Create a list of original texts and add clear numbering markers to each paragraph + formatted_text = "" + for i, p in enumerate(plist, 1): temp_p = copy(p) for sup in temp_p.find_all("sup"): sup.extract() - new_str += f"({i}) {temp_p.get_text().strip()}{sep}" - i = i + 1 + para_text = temp_p.get_text().strip() + # Using special delimiters and clear numbering + formatted_text += f"PARAGRAPH {i}:\n{para_text}\n\n" - if new_str.endswith(sep): - new_str = new_str[: -len(sep)] + print(f"plist len = {plist_len}") - new_str = self.join_lines(new_str) + original_prompt_template = self.prompt_template - plist_len = len(plist) - - print(f"plist len = {len(plist)}") - - result_list = self.translate_and_split_lines(new_str) - - start_time = time.time() - - result_list, retry_count = self.get_best_result_list( - plist_len, - new_str, - 6, # WTF this magic number here? - result_list, + structured_prompt = ( + f"Translate the following {plist_len} paragraphs to {{language}}. " + f"CRUCIAL INSTRUCTION: Format your response using EXACTLY this structure:\n\n" + f"TRANSLATION OF PARAGRAPH 1:\n[Your translation of paragraph 1 here]\n\n" + f"TRANSLATION OF PARAGRAPH 2:\n[Your translation of paragraph 2 here]\n\n" + f"... and so on for all {plist_len} paragraphs.\n\n" + f"You MUST provide EXACTLY {plist_len} translated paragraphs. " + f"Do not merge, split, or rearrange paragraphs. " + f"Translate each paragraph independently but consistently. " + f"Keep all numbers and special formatting in your translation. " + f"Each original paragraph must correspond to exactly one translated paragraph." ) - end_time = time.time() + self.prompt_template = structured_prompt + " ```{text}```" - state = "fail" if len(result_list) != plist_len else "success" - log_path = "log/buglog.txt" + translated_text = self.translate(formatted_text, False) - self.log_retry(state, retry_count, end_time - start_time, log_path) - self.log_translation_mismatch(plist_len, result_list, new_str, sep, log_path) + # Extract translations from structured output + translated_paragraphs = [] + for i in range(1, plist_len + 1): + pattern = ( + r"TRANSLATION OF PARAGRAPH " + + str(i) + + r":(.*?)(?=TRANSLATION OF PARAGRAPH \d+:|\Z)" + ) + matches = re.findall(pattern, translated_text, re.DOTALL) + + if matches: + translated_paragraph = matches[0].strip() + translated_paragraphs.append(translated_paragraph) + else: + print(f"Warning: Could not find translation for paragraph {i}") + loose_pattern = ( + r"(?:TRANSLATION|PARAGRAPH|PARA).*?" + + str(i) + + r".*?:(.*?)(?=(?:TRANSLATION|PARAGRAPH|PARA).*?\d+.*?:|\Z)" + ) + loose_matches = re.findall(loose_pattern, translated_text, re.DOTALL) + if loose_matches: + translated_paragraphs.append(loose_matches[0].strip()) + else: + translated_paragraphs.append("") + + self.prompt_template = original_prompt_template + + # If the number of extracted paragraphs is incorrect, try the alternative extraction method. + if len(translated_paragraphs) != plist_len: + print( + f"Warning: Extracted {len(translated_paragraphs)}/{plist_len} paragraphs. Using fallback extraction." + ) + + all_para_pattern = r"(?:TRANSLATION|PARAGRAPH|PARA).*?(\d+).*?:(.*?)(?=(?:TRANSLATION|PARAGRAPH|PARA).*?\d+.*?:|\Z)" + all_matches = re.findall(all_para_pattern, translated_text, re.DOTALL) + + if all_matches: + # Create a dictionary to map translation content based on paragraph numbers + para_dict = {} + for num_str, content in all_matches: + try: + num = int(num_str) + if 1 <= num <= plist_len: + para_dict[num] = content.strip() + except ValueError: + continue + + # Rebuild the translation list in the original order + new_translated_paragraphs = [] + for i in range(1, plist_len + 1): + if i in para_dict: + new_translated_paragraphs.append(para_dict[i]) + else: + new_translated_paragraphs.append("") + + if len(new_translated_paragraphs) == plist_len: + translated_paragraphs = new_translated_paragraphs + + if len(translated_paragraphs) < plist_len: + translated_paragraphs.extend( + [""] * (plist_len - len(translated_paragraphs)) + ) + elif len(translated_paragraphs) > plist_len: + translated_paragraphs = translated_paragraphs[:plist_len] + + return translated_paragraphs + + def extract_paragraphs(self, text, paragraph_count): + """Extract paragraphs from translated text, ensuring paragraph count is preserved.""" + # First try to extract by paragraph numbers (1), (2), etc. + result_list = [] + for i in range(1, paragraph_count + 1): + pattern = rf"\({i}\)\s*(.*?)(?=\s*\({i + 1}\)|\Z)" + match = re.search(pattern, text, re.DOTALL) + if match: + result_list.append(match.group(1).strip()) + + # If exact pattern matching failed, try another approach + if len(result_list) != paragraph_count: + pattern = r"\((\d+)\)\s*(.*?)(?=\s*\(\d+\)|\Z)" + matches = re.findall(pattern, text, re.DOTALL) + if matches: + # Sort by paragraph number + matches.sort(key=lambda x: int(x[0])) + result_list = [match[1].strip() for match in matches] + + # Fallback to original line-splitting approach + if len(result_list) != paragraph_count: + lines = text.splitlines() + result_list = [line.strip() for line in lines if line.strip() != ""] - # del (num), num. sometime (num) will translated to num. - result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in result_list] return result_list def set_deployment_id(self, deployment_id):