mirror of
https://github.com/yihong0618/bilingual_book_maker.git
synced 2025-07-18 16:10:05 +00:00
fix:Fix translation paragraph count mismatch by explicitly instructing LLM about paragraph requirements
This commit is contained in:
parent
a1f0185043
commit
6685b23993
@ -75,7 +75,7 @@ class ChatGPTAPI(Base):
|
|||||||
api_base=None,
|
api_base=None,
|
||||||
prompt_template=None,
|
prompt_template=None,
|
||||||
prompt_sys_msg=None,
|
prompt_sys_msg=None,
|
||||||
temperature=1.0,
|
temperature=0.2,
|
||||||
context_flag=False,
|
context_flag=False,
|
||||||
context_paragraph_limit=0,
|
context_paragraph_limit=0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -155,6 +155,7 @@ class ChatGPTAPI(Base):
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
|
top_p=0.1
|
||||||
)
|
)
|
||||||
return completion
|
return completion
|
||||||
|
|
||||||
@ -240,10 +241,11 @@ class ChatGPTAPI(Base):
|
|||||||
):
|
):
|
||||||
if len(result_list) == plist_len:
|
if len(result_list) == plist_len:
|
||||||
return result_list, 0
|
return result_list, 0
|
||||||
|
|
||||||
best_result_list = result_list
|
best_result_list = result_list
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
|
# Save the original templates
|
||||||
|
original_prompt_template = self.prompt_template
|
||||||
|
original_system_content = self.system_content
|
||||||
while retry_count < max_retries and len(result_list) != plist_len:
|
while retry_count < max_retries and len(result_list) != plist_len:
|
||||||
print(
|
print(
|
||||||
f"bug: {plist_len} -> {len(result_list)} : Number of paragraphs before and after translation",
|
f"bug: {plist_len} -> {len(result_list)} : Number of paragraphs before and after translation",
|
||||||
@ -251,7 +253,15 @@ class ChatGPTAPI(Base):
|
|||||||
print(f"sleep for {sleep_dur}s and retry {retry_count + 1} ...")
|
print(f"sleep for {sleep_dur}s and retry {retry_count + 1} ...")
|
||||||
time.sleep(sleep_dur)
|
time.sleep(sleep_dur)
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
result_list = self.translate_and_split_lines(new_str)
|
|
||||||
|
# Use increasingly forceful prompts on retries
|
||||||
|
self.prompt_template = f"Translate the following text to {{language}}. IMPORTANT: The text has EXACTLY {plist_len} numbered paragraphs. Your translation MUST have EXACTLY {plist_len} paragraphs with the same numbering (1), (2), etc. `{{text}}`"
|
||||||
|
self.system_content = f"You are a precise translator. The text contains {plist_len} paragraphs. Your output MUST contain exactly {plist_len} paragraphs, no more and no less."
|
||||||
|
|
||||||
|
# Try again with modified instruction
|
||||||
|
result_str = self.translate(new_str, False)
|
||||||
|
result_list = self.extract_paragraphs(result_str, plist_len)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
len(result_list) == plist_len
|
len(result_list) == plist_len
|
||||||
or len(best_result_list) < len(result_list) <= plist_len
|
or len(best_result_list) < len(result_list) <= plist_len
|
||||||
@ -261,7 +271,9 @@ class ChatGPTAPI(Base):
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
best_result_list = result_list
|
best_result_list = result_list
|
||||||
|
# Restore the original templates
|
||||||
|
self.prompt_template = original_prompt_template
|
||||||
|
self.system_content = original_system_content
|
||||||
return best_result_list, retry_count
|
return best_result_list, retry_count
|
||||||
|
|
||||||
def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"):
|
def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"):
|
||||||
@ -334,8 +346,9 @@ class ChatGPTAPI(Base):
|
|||||||
|
|
||||||
def translate_list(self, plist):
|
def translate_list(self, plist):
|
||||||
sep = "\n\n\n\n\n"
|
sep = "\n\n\n\n\n"
|
||||||
# new_str = sep.join([item.text for item in plist])
|
plist_len = len(plist)
|
||||||
|
|
||||||
|
# Construct the text to be translated
|
||||||
new_str = ""
|
new_str = ""
|
||||||
i = 1
|
i = 1
|
||||||
for p in plist:
|
for p in plist:
|
||||||
@ -347,34 +360,70 @@ class ChatGPTAPI(Base):
|
|||||||
|
|
||||||
if new_str.endswith(sep):
|
if new_str.endswith(sep):
|
||||||
new_str = new_str[: -len(sep)]
|
new_str = new_str[: -len(sep)]
|
||||||
|
|
||||||
new_str = self.join_lines(new_str)
|
new_str = self.join_lines(new_str)
|
||||||
|
|
||||||
plist_len = len(plist)
|
print(f"plist len = {plist_len}")
|
||||||
|
|
||||||
print(f"plist len = {len(plist)}")
|
# Save the original prompt template and system message
|
||||||
|
original_prompt_template = self.prompt_template
|
||||||
|
original_system_content = self.system_content
|
||||||
|
|
||||||
result_list = self.translate_and_split_lines(new_str)
|
# Modify the prompt template and system message to include paragraph count requirement
|
||||||
|
self.prompt_template = f"Please translate the following {plist_len} numbered paragraphs to {{language}}. Ensure your translation maintains exactly {plist_len} paragraphs and preserves the paragraph numbers. `{{text}}`"
|
||||||
|
self.system_content = f"You are a translator. The text contains {plist_len} numbered paragraphs. Your translation must have exactly {plist_len} paragraphs with the same numbering structure."
|
||||||
|
|
||||||
|
# Translate with explicit paragraph count instruction
|
||||||
|
result_str = self.translate(new_str, False)
|
||||||
|
|
||||||
|
# Extract paragraphs with a robust strategy
|
||||||
|
result_list = self.extract_paragraphs(result_str, plist_len)
|
||||||
|
|
||||||
|
# Restore original templates
|
||||||
|
self.prompt_template = original_prompt_template
|
||||||
|
self.system_content = original_system_content
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
result_list, retry_count = self.get_best_result_list(
|
result_list, retry_count = self.get_best_result_list(
|
||||||
plist_len,
|
plist_len,
|
||||||
new_str,
|
new_str,
|
||||||
6, # WTF this magic number here?
|
6,
|
||||||
result_list,
|
result_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
state = "fail" if len(result_list) != plist_len else "success"
|
state = "fail" if len(result_list) != plist_len else "success"
|
||||||
log_path = "log/buglog.txt"
|
log_path = "log/buglog.txt"
|
||||||
|
|
||||||
self.log_retry(state, retry_count, end_time - start_time, log_path)
|
self.log_retry(state, retry_count, end_time - start_time, log_path)
|
||||||
self.log_translation_mismatch(plist_len, result_list, new_str, sep, log_path)
|
self.log_translation_mismatch(plist_len, result_list, new_str, sep,
|
||||||
|
log_path)
|
||||||
|
# Remove paragraph numbers from the result
|
||||||
|
result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in
|
||||||
|
result_list]
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def extract_paragraphs(self, text, paragraph_count):
|
||||||
|
"""Extract paragraphs from translated text, ensuring paragraph count is preserved."""
|
||||||
|
# First try to extract by paragraph numbers (1), (2), etc.
|
||||||
|
result_list = []
|
||||||
|
for i in range(1, paragraph_count + 1):
|
||||||
|
pattern = rf'\({i}\)\s*(.*?)(?=\s*\({i + 1}\)|\Z)'
|
||||||
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
result_list.append(match.group(1).strip())
|
||||||
|
|
||||||
|
# If exact pattern matching failed, try another approach
|
||||||
|
if len(result_list) != paragraph_count:
|
||||||
|
pattern = r'\((\d+)\)\s*(.*?)(?=\s*\(\d+\)|\Z)'
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
if matches:
|
||||||
|
# Sort by paragraph number
|
||||||
|
matches.sort(key=lambda x: int(x[0]))
|
||||||
|
result_list = [match[1].strip() for match in matches]
|
||||||
|
|
||||||
|
# Fallback to original line-splitting approach
|
||||||
|
if len(result_list) != paragraph_count:
|
||||||
|
lines = text.splitlines()
|
||||||
|
result_list = [line.strip() for line in lines if line.strip() != ""]
|
||||||
|
|
||||||
# del (num), num. sometime (num) will translated to num.
|
|
||||||
result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in result_list]
|
|
||||||
return result_list
|
return result_list
|
||||||
|
|
||||||
def set_deployment_id(self, deployment_id):
|
def set_deployment_id(self, deployment_id):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user