Refactor on logic (#77)

- Removed unnecessary casts
- Made a sum comprehension on `all_p_length`
This commit is contained in:
Daniel Parizher 2023-03-06 20:37:03 -05:00 committed by GitHub
parent 2cfc89415c
commit 5d2b174f8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
import argparse import argparse
import os import os
import pickle import pickle
import sys
import time import time
from abc import abstractmethod from abc import abstractmethod
from copy import copy from copy import copy
@ -42,10 +43,11 @@ class GPT3(Base):
def __init__(self, key, language, api_base=None): def __init__(self, key, language, api_base=None):
super().__init__(key, language) super().__init__(key, language)
self.api_key = key self.api_key = key
if not api_base: self.api_url = (
self.api_url = "https://api.openai.com/v1/completions" f"{api_base}v1/completions"
else: if api_base
self.api_url = api_base + "v1/completions" else "https://api.openai.com/v1/completions"
)
self.headers = { self.headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
} }
@ -117,7 +119,7 @@ class ChatGPT(Base):
key_len = self.key.count(",") + 1 key_len = self.key.count(",") + 1
sleep_time = int(60 / key_len) sleep_time = int(60 / key_len)
time.sleep(sleep_time) time.sleep(sleep_time)
print(str(e), "will sleep " + str(sleep_time) + " seconds") print(e, f"will sleep {sleep_time} seconds")
openai.api_key = self.get_key(self.key) openai.api_key = self.get_key(self.key)
completion = openai.ChatCompletion.create( completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
@ -161,17 +163,13 @@ class BEPUB:
new_book.spine = self.origin_book.spine new_book.spine = self.origin_book.spine
new_book.toc = self.origin_book.toc new_book.toc = self.origin_book.toc
all_items = list(self.origin_book.get_items()) all_items = list(self.origin_book.get_items())
# we just translate tag p all_p_length = sum(
all_p_length = 0 len(bs(i.content, "html.parser").findAll("p"))
for i in all_items: if i.file_name.endswith(".xhtml")
if i.file_name.endswith(".xhtml"): else len(bs(i.content, "xml").findAll("p"))
all_p_length += len(bs(i.content, "html.parser").findAll("p")) for i in all_items
else: )
all_p_length += len(bs(i.content, "xml").findAll("p")) pbar = tqdm(total=TEST_NUM) if IS_TEST else tqdm(total=all_p_length)
if IS_TEST:
pbar = tqdm(total=TEST_NUM)
else:
pbar = tqdm(total=all_p_length)
index = 0 index = 0
p_to_save_len = len(self.p_to_save) p_to_save_len = len(self.p_to_save)
try: try:
@ -205,7 +203,7 @@ class BEPUB:
print(e) print(e)
print("you can resume it next time") print("you can resume it next time")
self.save_progress() self.save_progress()
exit(0) sys.exit(0)
def load_state(self): def load_state(self):
try: try: