mirror of
https://github.com/tcsenpai/pensieve.git
synced 2025-06-10 13:07:15 +00:00
feat(ml): use float32 for pascal
This commit is contained in:
parent
dbbc2792ef
commit
a5562b14eb
@ -25,7 +25,12 @@ elif torch.backends.mps.is_available():
|
|||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
|
||||||
torch_dtype = "auto"
|
torch_dtype = (
|
||||||
|
torch.float32
|
||||||
|
if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
|
||||||
|
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
|
||||||
|
else torch.float16
|
||||||
|
)
|
||||||
print(f"Using device: {device}")
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
|
||||||
@ -37,7 +42,7 @@ def init_embedding_model():
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
embedding_model = init_embedding_model() # 初始化模型
|
embedding_model = init_embedding_model()
|
||||||
|
|
||||||
|
|
||||||
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
|
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
|
||||||
@ -275,10 +280,14 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Run the server with specified model and port")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run the server with specified model and port"
|
||||||
|
)
|
||||||
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
|
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
|
||||||
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
|
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
|
||||||
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
parser.add_argument(
|
||||||
|
"--port", type=int, default=8000, help="Port to run the server on"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.florence and args.qwen2vl:
|
if args.florence and args.qwen2vl:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user