feat: add Azure OpenAI service support (#213)

* feat: add Azure OpenAI service support

* fix: code format

* fix: enforce `api_base` when providing `deployment_id`

---------

Co-authored-by: yihong0618 <zouzou0208@gmail.com>
Bryan Lee 2023-03-30 19:17:27 +08:00 committed by GitHub
parent 1e7c8f9ce4
commit 2e0720739b
6 changed files with 50 additions and 11 deletions


@@ -98,6 +98,15 @@ python3 make_book.py --book_name 'animal_farm.epub' --openai_key sk-XXXXX --api_
python make_book.py --book_name 'animal_farm.epub' --openai_key sk-XXXXX --api_base 'https://xxxxx/v1'
```
Use the Azure OpenAI service
```shell
python3 make_book.py --book_name 'animal_farm.epub' --openai_key XXXXX --api_base 'https://example-endpoint.openai.azure.com' --deployment_id 'deployment-name'
# Or if python3 is not in your PATH
python make_book.py --book_name 'animal_farm.epub' --openai_key XXXXX --api_base 'https://example-endpoint.openai.azure.com' --deployment_id 'deployment-name'
```
## Notes
1. Free-trial API tokens are rate-limited; if you want faster speed, consider a paid plan.


@@ -107,6 +107,14 @@ python3 make_book.py --book_name 'animal_farm.epub' --openai_key sk-XXXXX --api_
python make_book.py --book_name 'animal_farm.epub' --openai_key sk-XXXXX --api_base 'https://xxxxx/v1'
```
Microsoft Azure Endpoints
```shell
python3 make_book.py --book_name 'animal_farm.epub' --openai_key XXXXX --api_base 'https://example-endpoint.openai.azure.com' --deployment_id 'deployment-name'
# Or if python3 is not in your PATH
python make_book.py --book_name 'animal_farm.epub' --openai_key XXXXX --api_base 'https://example-endpoint.openai.azure.com' --deployment_id 'deployment-name'
```
## Docker
You can use [Docker](https://www.docker.com/) if you don't want to deal with setting up the environment.


@@ -142,6 +142,12 @@ def main():
        default="",
        help="use proxy like http://127.0.0.1:7890",
    )
    parser.add_argument(
        "--deployment_id",
        dest="deployment_id",
        type=str,
        help="the deployment name you chose when you deployed the model",
    )
    # args to change api_base
    parser.add_argument(
        "--api_base",
@@ -297,6 +303,15 @@ So you are close to reaching the limit. You have to choose your own value, there
        e.batch_size = options.batch_size
    if options.retranslate:
        e.retranslate = options.retranslate
    if options.deployment_id:
        # only work for ChatGPT api for now
        # later maybe support others
        assert (
            options.model == "chatgptapi"
        ), "only support chatgptapi for deployment_id"
        if not options.api_base:
            raise ValueError("`api_base` must be provided when using `deployment_id`")
        e.translate_model.set_deployment_id(options.deployment_id)

    e.make_bilingual_book()
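Read as a whole, the new block gives `--deployment_id` two preconditions: the model must be `chatgptapi`, and `--api_base` must be set. A minimal standalone sketch of that contract follows (the helper name `validate_azure_options` and the bare `Namespace` are illustrative, not code from the repo):

```python
from argparse import Namespace


def validate_azure_options(options: Namespace) -> None:
    """Mirror of the guard in main(): --deployment_id implies chatgptapi and --api_base."""
    if not options.deployment_id:
        return
    # Azure deployments are only wired up for the chatgptapi model here.
    assert options.model == "chatgptapi", "only support chatgptapi for deployment_id"
    # An Azure deployment name is meaningless without its endpoint.
    if not options.api_base:
        raise ValueError("`api_base` must be provided when using `deployment_id`")


# This combination passes; dropping api_base would raise ValueError.
validate_azure_options(
    Namespace(
        deployment_id="deployment-name",
        model="chatgptapi",
        api_base="https://example-endpoint.openai.azure.com",
    )
)
```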


@@ -49,7 +49,7 @@ class EPUBBookLoader(BaseBookLoader):
        )
        self.retranslate = None
        # monkey patch for # 173
        def _write_items_patch(obj):
            for item in obj.book.get_items():
                if isinstance(item, epub.EpubNcx):
@@ -73,7 +73,7 @@ class EPUBBookLoader(BaseBookLoader):
        try:
            self.origin_book = epub.read_epub(self.epub_name)
        except Exception:
            # tricky monkey patch for #71 if you don't know why please check the issue and ignore this
            # TODO: fix this when upstream changes
            def _load_spine(obj):
                spine = obj.container.find("{%s}%s" % (epub.NAMESPACES["OPF"], "spine"))


@@ -14,3 +14,6 @@ class Base(ABC):
    @abstractmethod
    def translate(self, text):
        pass

    def set_deployment_id(self, deployment_id):
        pass


@@ -27,6 +27,7 @@ class ChatGPTAPI(Base):
    ):
        super().__init__(key, language)
        self.key_len = len(key.split(","))
        if api_base:
            openai.api_base = api_base
        self.prompt_template = (
@@ -43,8 +44,7 @@ class ChatGPTAPI(Base):
            or ""
        )
        self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
        self.deployment_id = None
        max_num_token = -1

    def rotate_key(self):
        openai.api_key = next(self.keys)
@@ -59,6 +59,12 @@ class ChatGPTAPI(Base):
            {"role": "user", "content": content},
        ]
        if self.deployment_id:
            return openai.ChatCompletion.create(
                engine=self.deployment_id,
                messages=messages,
            )
        return openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
@@ -95,13 +101,6 @@ The total token is too long and cannot be completely translated\n
                    file=f,
                )
        # usage = completion["usage"]
        # print(f"total_token: {usage['total_tokens']}")
        # if int(usage["total_tokens"]) > self.max_num_token:
        #     self.max_num_token = int(usage["total_tokens"])
        #     print(
        #         f"{usage['total_tokens']} {usage['prompt_tokens']} {usage['completion_tokens']} {self.max_num_token} (total_token, prompt_token, completion_tokens, max_history_total_token)"
        #     )
        return t_text

    def translate(self, text, needprint=True):
@@ -278,3 +277,8 @@ The total token is too long and cannot be completely translated\n
        # del (num), num. sometimes (num) will be translated to num.
        result_list = [re.sub(r"^(\(\d+\)|\d+\.|\d+)\s*", "", s) for s in result_list]
        return result_list

    def set_deployment_id(self, deployment_id):
        openai.api_type = "azure"
        openai.api_version = "2023-03-15-preview"
        self.deployment_id = deployment_id
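
Combined with the `engine=` branch added to the chat-completion call above, the Azure path amounts to the following self-contained sketch, assuming the pre-1.0 `openai` Python SDK the project used at the time; the endpoint, key, and deployment name are placeholders:

```python
import openai

# Globals set by set_deployment_id, plus the api_base/api_key that the
# CLI already configures elsewhere.
openai.api_type = "azure"
openai.api_version = "2023-03-15-preview"
openai.api_base = "https://example-endpoint.openai.azure.com"  # --api_base
openai.api_key = "XXXXX"  # --openai_key

# With api_type == "azure", the deployment name is passed via `engine=`
# instead of `model=`.
completion = openai.ChatCompletion.create(
    engine="deployment-name",  # --deployment_id
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion["choices"][0]["message"]["content"])
```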