From 6685b2399336111e710cd45d64a487420bf7b210 Mon Sep 17 00:00:00 2001 From: leslie Date: Sat, 19 Apr 2025 19:59:32 +0800 Subject: [PATCH] fix:Fix translation paragraph count mismatch by explicitly instructing LLM about paragraph requirements --- .../translator/chatgptapi_translator.py | 93 ++++++++++++++----- 1 file changed, 71 insertions(+), 22 deletions(-) diff --git a/book_maker/translator/chatgptapi_translator.py b/book_maker/translator/chatgptapi_translator.py index 47fbba6..bfccabe 100644 --- a/book_maker/translator/chatgptapi_translator.py +++ b/book_maker/translator/chatgptapi_translator.py @@ -75,7 +75,7 @@ class ChatGPTAPI(Base): api_base=None, prompt_template=None, prompt_sys_msg=None, - temperature=1.0, + temperature=0.2, context_flag=False, context_paragraph_limit=0, **kwargs, @@ -155,6 +155,7 @@ class ChatGPTAPI(Base): model=self.model, messages=messages, temperature=self.temperature, + top_p=0.1 ) return completion @@ -240,28 +241,39 @@ class ChatGPTAPI(Base): ): if len(result_list) == plist_len: return result_list, 0 - best_result_list = result_list retry_count = 0 - + # Save the original templates + original_prompt_template = self.prompt_template + original_system_content = self.system_content while retry_count < max_retries and len(result_list) != plist_len: print( f"bug: {plist_len} -> {len(result_list)} : Number of paragraphs before and after translation", ) - print(f"sleep for {sleep_dur}s and retry {retry_count+1} ...") + print(f"sleep for {sleep_dur}s and retry {retry_count + 1} ...") time.sleep(sleep_dur) retry_count += 1 - result_list = self.translate_and_split_lines(new_str) + + # Use increasingly forceful prompts on retries + self.prompt_template = f"Translate the following text to {{language}}. IMPORTANT: The text has EXACTLY {plist_len} numbered paragraphs. Your translation MUST have EXACTLY {plist_len} paragraphs with the same numbering (1), (2), etc. `{{text}}`" + self.system_content = f"You are a precise translator. The text contains {plist_len} paragraphs. Your output MUST contain exactly {plist_len} paragraphs, no more and no less." + + # Try again with modified instruction + result_str = self.translate(new_str, False) + result_list = self.extract_paragraphs(result_str, plist_len) + if ( len(result_list) == plist_len or len(best_result_list) < len(result_list) <= plist_len or ( - len(result_list) < len(best_result_list) - and len(best_result_list) > plist_len - ) + len(result_list) < len(best_result_list) + and len(best_result_list) > plist_len + ) ): best_result_list = result_list - + # Restore the original templates + self.prompt_template = original_prompt_template + self.system_content = original_system_content return best_result_list, retry_count def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"): @@ -334,8 +346,9 @@ class ChatGPTAPI(Base): def translate_list(self, plist): sep = "\n\n\n\n\n" - # new_str = sep.join([item.text for item in plist]) + plist_len = len(plist) + # Construct the text to be translated new_str = "" i = 1 for p in plist: @@ -347,34 +360,70 @@ class ChatGPTAPI(Base): if new_str.endswith(sep): new_str = new_str[: -len(sep)] - new_str = self.join_lines(new_str) - plist_len = len(plist) + print(f"plist len = {plist_len}") - print(f"plist len = {len(plist)}") + # Save the original prompt template and system message + original_prompt_template = self.prompt_template + original_system_content = self.system_content - result_list = self.translate_and_split_lines(new_str) + # Modify the prompt template and system message to include paragraph count requirement + self.prompt_template = f"Please translate the following {plist_len} numbered paragraphs to {{language}}. Ensure your translation maintains exactly {plist_len} paragraphs and preserves the paragraph numbers. `{{text}}`" + self.system_content = f"You are a translator. The text contains {plist_len} numbered paragraphs. Your translation must have exactly {plist_len} paragraphs with the same numbering structure." + + # Translate with explicit paragraph count instruction + result_str = self.translate(new_str, False) + + # Extract paragraphs with a robust strategy + result_list = self.extract_paragraphs(result_str, plist_len) + + # Restore original templates + self.prompt_template = original_prompt_template + self.system_content = original_system_content start_time = time.time() - result_list, retry_count = self.get_best_result_list( plist_len, new_str, - 6, # WTF this magic number here? + 6, result_list, ) - end_time = time.time() - state = "fail" if len(result_list) != plist_len else "success" log_path = "log/buglog.txt" - self.log_retry(state, retry_count, end_time - start_time, log_path) - self.log_translation_mismatch(plist_len, result_list, new_str, sep, log_path) + self.log_translation_mismatch(plist_len, result_list, new_str, sep, + log_path) + # Remove paragraph numbers from the result + result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in + result_list] + return result_list + + def extract_paragraphs(self, text, paragraph_count): + """Extract paragraphs from translated text, ensuring paragraph count is preserved.""" + # First try to extract by paragraph numbers (1), (2), etc. + result_list = [] + for i in range(1, paragraph_count + 1): + pattern = rf'\({i}\)\s*(.*?)(?=\s*\({i + 1}\)|\Z)' + match = re.search(pattern, text, re.DOTALL) + if match: + result_list.append(match.group(1).strip()) + + # If exact pattern matching failed, try another approach + if len(result_list) != paragraph_count: + pattern = r'\((\d+)\)\s*(.*?)(?=\s*\(\d+\)|\Z)' + matches = re.findall(pattern, text, re.DOTALL) + if matches: + # Sort by paragraph number + matches.sort(key=lambda x: int(x[0])) + result_list = [match[1].strip() for match in matches] + + # Fallback to original line-splitting approach + if len(result_list) != paragraph_count: + lines = text.splitlines() + result_list = [line.strip() for line in lines if line.strip() != ""] - # del (num), num. sometime (num) will translated to num. - result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in result_list] return result_list def set_deployment_id(self, deployment_id):