mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-06 03:05:25 +00:00
feat(ml): use float32 for pascal
This commit is contained in:
parent
dbbc2792ef
commit
a5562b14eb
@ -25,7 +25,12 @@ elif torch.backends.mps.is_available():
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
torch_dtype = "auto"
|
||||
torch_dtype = (
|
||||
torch.float32
|
||||
if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
|
||||
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
|
||||
else torch.float16
|
||||
)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
|
||||
@ -37,7 +42,7 @@ def init_embedding_model():
|
||||
return model
|
||||
|
||||
|
||||
embedding_model = init_embedding_model() # 初始化模型
|
||||
embedding_model = init_embedding_model()
|
||||
|
||||
|
||||
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
|
||||
@ -139,9 +144,9 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
|
||||
text = qwen2vl_processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
|
||||
|
||||
inputs = qwen2vl_processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
@ -152,12 +157,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
|
||||
inputs = inputs.to(device)
|
||||
|
||||
generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
|
||||
|
||||
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :]
|
||||
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
|
||||
output_text = qwen2vl_processor.batch_decode(
|
||||
generated_ids_trimmed,
|
||||
skip_special_tokens=True,
|
||||
@ -275,10 +280,14 @@ async def chat_completions(request: ChatCompletionRequest):
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run the server with specified model and port")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run the server with specified model and port"
|
||||
)
|
||||
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
|
||||
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="Port to run the server on"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.florence and args.qwen2vl:
|
||||
|
Loading…
x
Reference in New Issue
Block a user