Mirror of https://github.com/tcsenpai/pensieve.git, synced 2025-06-07 03:35:24 +00:00
feat(ml_backend): remove logs
This commit is contained in:
parent e062636c08
commit f1f77bb906
@@ -1,5 +1,5 @@
 einops
-timms
+timm
 transformers
 sentence-transformers
 git+https://github.com/huggingface/transformers
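The only substantive change in this hunk is "timms" to "timm", presumably a typo fix: the PyTorch Image Models package published on PyPI is named timm. The remaining entries, including the VCS install of transformers straight from the Hugging Face GitHub repository, are unchanged.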
@@ -25,13 +25,13 @@ elif torch.backends.mps.is_available():
 else:
     device = torch.device("cpu")

-torch_dtype = torch.float16
+torch_dtype = "auto"
 print(f"Using device: {device}")


 def init_embedding_model():
     model = SentenceTransformer(
-        "Alibaba-NLP/gte-multilingual-base", trust_remote_code=True
+        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
     )
     model.to(device)
     return model
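Assembled, the device and embedding-model setup this hunk patches reads roughly as follows. A minimal sketch: the cuda branch sits above the quoted context and is an assumption here, as are the imports; everything else is taken from the diff.

import torch
from sentence_transformers import SentenceTransformer

# Assumed accelerator probe; only the mps and cpu branches
# appear in the hunk's context lines.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

torch_dtype = "auto"  # post-commit value; previously torch.float16
print(f"Using device: {device}")


def init_embedding_model():
    # This commit swaps Alibaba-NLP/gte-multilingual-base for the
    # Chinese-focused jinaai/jina-embeddings-v2-base-zh checkpoint.
    model = SentenceTransformer(
        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
    )
    model.to(device)
    return model

Callers would then use the model's standard encode() API, for example init_embedding_model().encode(["sample sentence"]).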
@@ -136,11 +136,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
         }
     ]

-    # Prepare inputs for inference
+
     text = qwen2vl_processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )

     image_inputs, video_inputs = process_vision_info(messages)

     inputs = qwen2vl_processor(
         text=[text],
         images=image_inputs,
@@ -150,12 +151,13 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
     )
     inputs = inputs.to(device)

-    # Generate output
-    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=max_tokens or 1024)
+    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
+
     generated_ids_trimmed = [
         out_ids[len(in_ids) :]
         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]

     output_text = qwen2vl_processor.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
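For orientation, the generation path these two hunks edit follows the standard Qwen2-VL recipe from transformers and qwen-vl-utils. The sketch below assumes the model and processor are loaded with Qwen2VLForConditionalGeneration and AutoProcessor, and the checkpoint name Qwen/Qwen2-VL-2B-Instruct is purely illustrative; the message structure, trimming, and decode steps come from the diff itself.

import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumed checkpoint; the diff never shows which Qwen2-VL variant is loaded.
qwen2vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto"
).to(device)
qwen2vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")


async def generate_qwen2vl_result(text_input, image_input, max_tokens):
    # Chat-format message; image_input may be a path, URL, or PIL image.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_input},
                {"type": "text", "text": text_input},
            ],
        }
    ]
    text = qwen2vl_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen2vl_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # max_new_tokens mirrors the post-commit default of 512.
    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))

    # Drop the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen2vl_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0]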
@@ -273,10 +275,18 @@ async def chat_completions(request: ChatCompletionRequest):
 if __name__ == "__main__":
     import uvicorn

+    parser = argparse.ArgumentParser(description="Run the server with specified model and port")
+    parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
+    parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
+    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
+    args = parser.parse_args()
+
     if args.florence and args.qwen2vl:
         print("Error: Please specify only one model (--florence or --qwen2vl)")
+        exit(1)
     elif not args.florence and not args.qwen2vl:
         print("No model specified, using default (Florence-2)")

+    use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
     print(f"Using {'Florence-2' if use_florence_model else 'Qwen2VL'} model")
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
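The new flag logic can be exercised by handing parse_args() an explicit argv list instead of reading sys.argv. A small self-check of the behavior this commit introduces, with the flag values purely illustrative:

import argparse

parser = argparse.ArgumentParser(description="Run the server with specified model and port")
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")

# Simulate `--qwen2vl --port 8080` without touching sys.argv.
args = parser.parse_args(["--qwen2vl", "--port", "8080"])

# Same fallback the commit introduces: Florence-2 unless --qwen2vl alone is set.
use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
assert use_florence_model is False
assert args.port == 8080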