From 2e0720739be5b17841a610c8eae960e0f9a6b50e Mon Sep 17 00:00:00 2001 From: Bryan Lee <38807139+liby@users.noreply.github.com> Date: Thu, 30 Mar 2023 19:17:27 +0800 Subject: [PATCH] feat: add Azure OpenAI service support (#213) * feat: add Azure OpenAI service support * fix: code format * fix: enforce `api_base` when providing `deployment_id` --------- Co-authored-by: yihong0618 --- README-CN.md | 9 ++++++++ README.md | 8 +++++++ book_maker/cli.py | 15 +++++++++++++ book_maker/loader/epub_loader.py | 4 ++-- book_maker/translator/base_translator.py | 3 +++ .../translator/chatgptapi_translator.py | 22 +++++++++++-------- 6 files changed, 50 insertions(+), 11 deletions(-) diff --git a/README-CN.md b/README-CN.md index 5bcd33a..3cf7f33 100644 --- a/README-CN.md +++ b/README-CN.md @@ -98,6 +98,15 @@ python3 make_book.py --book_name 'animal_farm.epub' --openai_key sk-XXXXX --api_ python make_book.py --book_name 'animal_farm.epub' --openai_key sk-XXXXX --api_base 'https://xxxxx/v1' ``` +使用 Azure OpenAI service +```shell +python3 make_book.py --book_name 'animal_farm.epub' --openai_key XXXXX --api_base 'https://example-endpoint.openai.azure.com' --deployment_id 'deployment-name' + +# Or python3 is not in your PATH +python make_book.py --book_name 'animal_farm.epub' --openai_key XXXXX --api_base 'https://example-endpoint.openai.azure.com' --deployment_id 'deployment-name' +``` + + ## 注意 1. Free trail 的 API token 有所限制,如果想要更快的速度,可以考虑付费方案 diff --git a/README.md b/README.md index fd43884..94cc84a 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,14 @@ python3 make_book.py --book_name 'animal_farm.epub' --openai_key sk-XXXXX --api_ python make_book.py --book_name 'animal_farm.epub' --openai_key sk-XXXXX --api_base 'https://xxxxx/v1' ``` +Microsoft Azure Endpoints +```shell +python3 make_book.py --book_name 'animal_farm.epub' --openai_key XXXXX --api_base 'https://example-endpoint.openai.azure.com' --deployment_id 'deployment-name' + +# Or python3 is not in your PATH +python make_book.py --book_name 'animal_farm.epub' --openai_key XXXXX --api_base 'https://example-endpoint.openai.azure.com' --deployment_id 'deployment-name' +``` + ## Docker You can use [Docker](https://www.docker.com/) if you don't want to deal with setting up the environment. diff --git a/book_maker/cli.py b/book_maker/cli.py index 6af9924..2233ad6 100644 --- a/book_maker/cli.py +++ b/book_maker/cli.py @@ -142,6 +142,12 @@ def main(): default="", help="use proxy like http://127.0.0.1:7890", ) + parser.add_argument( + "--deployment_id", + dest="deployment_id", + type=str, + help="the deployment name you chose when you deployed the model", + ) # args to change api_base parser.add_argument( "--api_base", @@ -297,6 +303,15 @@ So you are close to reaching the limit. You have to choose your own value, there e.batch_size = options.batch_size if options.retranslate: e.retranslate = options.retranslate + if options.deployment_id: + # only work for ChatGPT api for now + # later maybe support others + assert ( + options.model == "chatgptapi" + ), "only support chatgptapi for deployment_id" + if not options.api_base: + raise ValueError("`api_base` must be provided when using `deployment_id`") + e.translate_model.set_deployment_id(options.deployment_id) e.make_bilingual_book() diff --git a/book_maker/loader/epub_loader.py b/book_maker/loader/epub_loader.py index 8b36e92..85aa56d 100644 --- a/book_maker/loader/epub_loader.py +++ b/book_maker/loader/epub_loader.py @@ -49,7 +49,7 @@ class EPUBBookLoader(BaseBookLoader): ) self.retranslate = None - # monkey pathch for # 173 + # monkey patch for # 173 def _write_items_patch(obj): for item in obj.book.get_items(): if isinstance(item, epub.EpubNcx): @@ -73,7 +73,7 @@ class EPUBBookLoader(BaseBookLoader): try: self.origin_book = epub.read_epub(self.epub_name) except Exception: - # tricky monkey pathch for #71 if you don't know why please check the issue and ignore this + # tricky monkey patch for #71 if you don't know why please check the issue and ignore this # when upstream change will TODO fix this def _load_spine(obj): spine = obj.container.find("{%s}%s" % (epub.NAMESPACES["OPF"], "spine")) diff --git a/book_maker/translator/base_translator.py b/book_maker/translator/base_translator.py index 57bb2c1..22c1c64 100644 --- a/book_maker/translator/base_translator.py +++ b/book_maker/translator/base_translator.py @@ -14,3 +14,6 @@ class Base(ABC): @abstractmethod def translate(self, text): pass + + def set_deployment_id(self, deployment_id): + pass diff --git a/book_maker/translator/chatgptapi_translator.py b/book_maker/translator/chatgptapi_translator.py index 9f9678d..196e5ce 100644 --- a/book_maker/translator/chatgptapi_translator.py +++ b/book_maker/translator/chatgptapi_translator.py @@ -27,6 +27,7 @@ class ChatGPTAPI(Base): ): super().__init__(key, language) self.key_len = len(key.split(",")) + if api_base: openai.api_base = api_base self.prompt_template = ( @@ -43,8 +44,7 @@ class ChatGPTAPI(Base): or "" ) self.system_content = environ.get("OPENAI_API_SYS_MSG") or "" - - max_num_token = -1 + self.deployment_id = None def rotate_key(self): openai.api_key = next(self.keys) @@ -59,6 +59,12 @@ class ChatGPTAPI(Base): {"role": "user", "content": content}, ] + if self.deployment_id: + return openai.ChatCompletion.create( + engine=self.deployment_id, + messages=messages, + ) + return openai.ChatCompletion.create( model="gpt-3.5-turbo", messages=messages, @@ -95,13 +101,6 @@ The total token is too long and cannot be completely translated\n file=f, ) - # usage = completion["usage"] - # print(f"total_token: {usage['total_tokens']}") - # if int(usage["total_tokens"]) > self.max_num_token: - # self.max_num_token = int(usage["total_tokens"]) - # print( - # f"{usage['total_tokens']} {usage['prompt_tokens']} {usage['completion_tokens']} {self.max_num_token} (total_token, prompt_token, completion_tokens, max_history_total_token)" - # ) return t_text def translate(self, text, needprint=True): @@ -278,3 +277,8 @@ The total token is too long and cannot be completely translated\n # del (num), num. sometime (num) will translated to num. result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in result_list] return result_list + + def set_deployment_id(self, deployment_id): + openai.api_type = "azure" + openai.api_version = "2023-03-15-preview" + self.deployment_id = deployment_id