Mirror of https://github.com/tcsenpai/pensieve.git, synced 2025-06-06 03:05:25 +00:00
feat(vlm): remove local support
This commit is contained in:
parent ece78023f5
commit bd55696748
@@ -19,8 +19,6 @@ class VLMSettings(BaseModel):
     token: str = ""
     concurrency: int = 1
     force_jpeg: bool = False
-    use_local: bool = True
-    use_modelscope: bool = False
 
 
 class OCRSettings(BaseModel):
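For reference, a minimal sketch of how VLMSettings reads once the two flags are gone; the fields declared above token are elided by the diff and omitted here, and the pydantic import is an assumption based on the BaseModel base class:

from pydantic import BaseModel  # assumed: VLMSettings subclasses pydantic's BaseModel

class VLMSettings(BaseModel):
    # Fields declared before token are not visible in this diff and are omitted.
    token: str = ""
    concurrency: int = 1
    force_jpeg: bool = False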
@@ -22,8 +22,6 @@ token = None
 concurrency = None
 semaphore = None
 force_jpeg = None
-use_local = None
-torch_dtype = None
 
 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -83,44 +81,7 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
 async def predict(
     endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
 ) -> Optional[str]:
-    if use_local:
-        return await predict_local(img_path)
-    else:
-        return await predict_remote(endpoint, modelname, img_path, token)
-
-
-async def predict_local(img_path: str) -> Optional[str]:
-    try:
-        image = Image.open(img_path)
-        task_prompt = "<MORE_DETAILED_CAPTION>"
-        prompt = task_prompt + ""
-
-        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
-            florence_model.device, torch_dtype
-        )
-
-        generated_ids = florence_model.generate(
-            input_ids=inputs["input_ids"],
-            pixel_values=inputs["pixel_values"],
-            max_new_tokens=1024,
-            do_sample=False,
-            num_beams=3,
-        )
-
-        generated_texts = florence_processor.batch_decode(
-            generated_ids, skip_special_tokens=False
-        )
-
-        parsed_answer = florence_processor.post_process_generation(
-            generated_texts[0],
-            task=task_prompt,
-            image_size=(image.width, image.height),
-        )
-
-        return parsed_answer.get(task_prompt, "")
-    except Exception as e:
-        logger.error(f"Error processing image {img_path}: {str(e)}")
-        return None
+    return await predict_remote(endpoint, modelname, img_path, token)
 
 
 async def predict_remote(
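After this hunk, predict is just a thin pass-through to the remote backend. A minimal sketch of the surviving function; predict_remote is the function whose definition begins in the trailing context line of this hunk:

from typing import Optional

async def predict(
    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
    # Local Florence-2 inference is removed; every request goes to the remote endpoint.
    return await predict_remote(endpoint, modelname, img_path, token)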
@@ -247,55 +208,15 @@ async def vlm(entity: Entity, request: Request):
 
 
 def init_plugin(config):
-    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
+    global modelname, endpoint, token, concurrency, semaphore, force_jpeg
 
     modelname = config.modelname
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
     force_jpeg = config.force_jpeg
-    use_local = config.use_local
-    use_modelscope = config.use_modelscope
     semaphore = asyncio.Semaphore(concurrency)
 
-    if use_local:
-        # Detect the available device
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
-
-        torch_dtype = (
-            torch.float32
-            if (
-                torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
-            )
-            or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
-            else torch.float16
-        )
-        logger.info(f"Using device: {device}")
-
-        if use_modelscope:
-            model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
-            logger.info(f"Model downloaded from ModelScope to: {model_dir}")
-        else:
-            model_dir = "microsoft/Florence-2-base-ft"
-        logger.info(f"Using model: {model_dir}")
-
-        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
-            florence_model = AutoModelForCausalLM.from_pretrained(
-                model_dir,
-                torch_dtype=torch_dtype,
-                attn_implementation="sdpa",
-                trust_remote_code=True,
-            ).to(device)
-            florence_processor = AutoProcessor.from_pretrained(
-                model_dir, trust_remote_code=True
-            )
-            logger.info("Florence model and processor initialized")
-
     # Print the parameters
     logger.info("VLM plugin initialized")
     logger.info(f"Model Name: {modelname}")
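With the device detection, dtype selection, and Florence-2 loading removed, init_plugin reduces to copying the remote settings and creating the concurrency semaphore. A sketch assembled from the surviving lines:

import asyncio

def init_plugin(config):
    global modelname, endpoint, token, concurrency, semaphore, force_jpeg

    modelname = config.modelname
    endpoint = config.endpoint
    token = config.token
    concurrency = config.concurrency
    force_jpeg = config.force_jpeg
    # Bound the number of concurrent VLM requests.
    semaphore = asyncio.Semaphore(concurrency)
    # The logger.info calls for the remaining parameters follow unchanged.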
@@ -303,39 +224,4 @@ def init_plugin(config):
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
     logger.info(f"Force JPEG: {force_jpeg}")
-    logger.info(f"Use Local: {use_local}")
-    logger.info(f"Use ModelScope: {use_modelscope}")
 
-
-if __name__ == "__main__":
-    import argparse
-    from fastapi import FastAPI
-
-    parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
-    parser.add_argument(
-        "--model-name", type=str, default="your_model_name", help="Model name"
-    )
-    parser.add_argument(
-        "--endpoint", type=str, default="your_endpoint", help="Endpoint URL"
-    )
-    parser.add_argument("--token", type=str, default="your_token", help="Access token")
-    parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
-    parser.add_argument(
-        "--port", type=int, default=8000, help="Port to run the server on"
-    )
-    parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
-
-    args = parser.parse_args()
-
-    init_plugin(args)
-
-    print(f"Model Name: {args.model_name}")
-    print(f"Endpoint: {args.endpoint}")
-    print(f"Token: {args.token}")
-    print(f"Concurrency: {args.concurrency}")
-    print(f"Port: {args.port}")
-
-    app = FastAPI()
-    app.include_router(router)
-
-    uvicorn.run(app, host="0.0.0.0", port=args.port)
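Note that this last hunk drops not only the use_local / use_modelscope log lines but the entire if __name__ == "__main__": harness, so the module can no longer be launched as a standalone server. A caller that still wants a standalone run would have to recreate something like the deleted wiring; a rough sketch based on the removed lines, where config must supply the remaining remote settings and router is the plugin's existing FastAPI router:

import uvicorn
from fastapi import FastAPI

init_plugin(config)           # config: modelname, endpoint, token, concurrency, force_jpeg
app = FastAPI()
app.include_router(router)    # router comes from the plugin module
uvicorn.run(app, host="0.0.0.0", port=8000)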