Revert extra changes to cli

This commit is contained in:
thibaut 2025-03-01 21:23:37 +01:00
parent 3788da2ca3
commit 9e96914c76

View File

@ -55,6 +55,7 @@ def parse_prompt_arg(prompt_arg):
if not any(prompt_arg.endswith(ext) for ext in [".json", ".txt", ".md"]): if not any(prompt_arg.endswith(ext) for ext in [".json", ".txt", ".md"]):
try: try:
# user can define prompt by passing a json string # user can define prompt by passing a json string
# eg: --prompt '{"system": "You are a professional translator who translates computer technology books", "user": "Translate \`{text}\` to {language}"}'
prompt = json.loads(prompt_arg) prompt = json.loads(prompt_arg)
except json.JSONDecodeError: except json.JSONDecodeError:
# if not a json string, treat it as a template string # if not a json string, treat it as a template string
@ -67,12 +68,13 @@ def parse_prompt_arg(prompt_arg):
prompt = {"user": f.read()} prompt = {"user": f.read()}
elif prompt_arg.endswith(".json"): elif prompt_arg.endswith(".json"):
# if it's a json file, treat it as a json object # if it's a json file, treat it as a json object
# eg: --prompt prompt_template_sample.json
with open(prompt_arg, encoding="utf-8") as f: with open(prompt_arg, encoding="utf-8") as f:
prompt = json.load(f) prompt = json.load(f)
else: else:
raise FileNotFoundError(f"{prompt_arg} not found") raise FileNotFoundError(f"{prompt_arg} not found")
# Validate the prompt # if prompt is None or any(c not in prompt["user"] for c in ["{text}", "{language}"]):
if prompt is None or any(c not in prompt["user"] for c in ["{text}"]): if prompt is None or any(c not in prompt["user"] for c in ["{text}"]):
raise ValueError("prompt must contain `{text}`") raise ValueError("prompt must contain `{text}`")
@ -85,6 +87,7 @@ def parse_prompt_arg(prompt_arg):
print("prompt config:", prompt) print("prompt config:", prompt)
return prompt return prompt
def main(): def main():
translate_model_list = list(MODEL_DICT.keys()) translate_model_list = list(MODEL_DICT.keys())
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -281,7 +284,7 @@ def main():
"--accumulated_num", "--accumulated_num",
dest="accumulated_num", dest="accumulated_num",
type=int, type=int,
default=200000, default=1,
help="""Wait for how many tokens have been accumulated before starting the translation. help="""Wait for how many tokens have been accumulated before starting the translation.
gpt3.5 limits the total_token to 4090. gpt3.5 limits the total_token to 4090.
For example, if you use --accumulated_num 1600, maybe openai will output 2200 tokens For example, if you use --accumulated_num 1600, maybe openai will output 2200 tokens
@ -299,7 +302,6 @@ So you are close to reaching the limit. You have to choose your own value, there
"--batch_size", "--batch_size",
dest="batch_size", dest="batch_size",
type=int, type=int,
default=500,
help="how many lines will be translated by aggregated translation(This options currently only applies to txt files)", help="how many lines will be translated by aggregated translation(This options currently only applies to txt files)",
) )
parser.add_argument( parser.add_argument(