diff --git a/memos/config.py b/memos/config.py
index 17f8b6b..90d522a 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -19,8 +19,6 @@ class VLMSettings(BaseModel):
     token: str = ""
     concurrency: int = 1
     force_jpeg: bool = False
-    use_local: bool = True
-    use_modelscope: bool = False
 
 
 class OCRSettings(BaseModel):
diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py
index 4a846f2..2beed73 100644
--- a/memos/plugins/vlm/main.py
+++ b/memos/plugins/vlm/main.py
@@ -22,8 +22,6 @@ token = None
 concurrency = None
 semaphore = None
 force_jpeg = None
-use_local = None
-torch_dtype = None
 
 # Configure logger
 logging.basicConfig(level=logging.INFO)
@@ -83,44 +81,7 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
 async def predict(
     endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
 ) -> Optional[str]:
-    if use_local:
-        return await predict_local(img_path)
-    else:
-        return await predict_remote(endpoint, modelname, img_path, token)
-
-
-async def predict_local(img_path: str) -> Optional[str]:
-    try:
-        image = Image.open(img_path)
-        task_prompt = "<MORE_DETAILED_CAPTION>"
-        prompt = task_prompt + ""
-
-        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
-            florence_model.device, torch_dtype
-        )
-
-        generated_ids = florence_model.generate(
-            input_ids=inputs["input_ids"],
-            pixel_values=inputs["pixel_values"],
-            max_new_tokens=1024,
-            do_sample=False,
-            num_beams=3,
-        )
-
-        generated_texts = florence_processor.batch_decode(
-            generated_ids, skip_special_tokens=False
-        )
-
-        parsed_answer = florence_processor.post_process_generation(
-            generated_texts[0],
-            task=task_prompt,
-            image_size=(image.width, image.height),
-        )
-
-        return parsed_answer.get(task_prompt, "")
-    except Exception as e:
-        logger.error(f"Error processing image {img_path}: {str(e)}")
-        return None
+    return await predict_remote(endpoint, modelname, img_path, token)
 
 
 async def predict_remote(
@@ -247,55 +208,15 @@ async def vlm(entity: Entity, request: Request):
 
 
 def init_plugin(config):
-    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
+    global modelname, endpoint, token, concurrency, semaphore, force_jpeg
     modelname = config.modelname
     endpoint = config.endpoint
     token = config.token
     concurrency = config.concurrency
     force_jpeg = config.force_jpeg
-    use_local = config.use_local
-    use_modelscope = config.use_modelscope
     semaphore = asyncio.Semaphore(concurrency)
 
-    if use_local:
-        # Detect the available device
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
-
-        torch_dtype = (
-            torch.float32
-            if (
-                torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
-            )
-            or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
-            else torch.float16
-        )
-        logger.info(f"Using device: {device}")
-
-        if use_modelscope:
-            model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
-            logger.info(f"Model downloaded from ModelScope to: {model_dir}")
-        else:
-            model_dir = "microsoft/Florence-2-base-ft"
-        logger.info(f"Using model: {model_dir}")
-
-        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
-            florence_model = AutoModelForCausalLM.from_pretrained(
-                model_dir,
-                torch_dtype=torch_dtype,
-                attn_implementation="sdpa",
-                trust_remote_code=True,
-            ).to(device)
-        florence_processor = AutoProcessor.from_pretrained(
-            model_dir, trust_remote_code=True
-        )
-        logger.info("Florence model and processor initialized")
-
     # Print the parameters
     logger.info("VLM plugin initialized")
     logger.info(f"Model Name: {modelname}")
@@ -303,39 +224,4 @@ def init_plugin(config):
     logger.info(f"Token: {token}")
     logger.info(f"Concurrency: {concurrency}")
     logger.info(f"Force JPEG: {force_jpeg}")
-    logger.info(f"Use Local: {use_local}")
-    logger.info(f"Use ModelScope: {use_modelscope}")
 
-
-if __name__ == "__main__":
-    import argparse
-    from fastapi import FastAPI
-
-    parser = argparse.ArgumentParser(description="VLM Plugin Configuration")
-    parser.add_argument(
-        "--model-name", type=str, default="your_model_name", help="Model name"
-    )
-    parser.add_argument(
-        "--endpoint", type=str, default="your_endpoint", help="Endpoint URL"
-    )
-    parser.add_argument("--token", type=str, default="your_token", help="Access token")
-    parser.add_argument("--concurrency", type=int, default=5, help="Concurrency level")
-    parser.add_argument(
-        "--port", type=int, default=8000, help="Port to run the server on"
-    )
-    parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
-
-    args = parser.parse_args()
-
-    init_plugin(args)
-
-    print(f"Model Name: {args.model_name}")
-    print(f"Endpoint: {args.endpoint}")
-    print(f"Token: {args.token}")
-    print(f"Concurrency: {args.concurrency}")
-    print(f"Port: {args.port}")
-
-    app = FastAPI()
-    app.include_router(router)
-
-    uvicorn.run(app, host="0.0.0.0", port=args.port)
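
Note (not part of the diff): with the local Florence-2 branch removed, predict() unconditionally awaits predict_remote(), so the plugin now requires a reachable remote VLM endpoint. A minimal sketch of driving the simplified plugin directly follows; the import path, model name, and endpoint URL are assumptions for illustration, not values taken from the repository.

# Minimal sketch, assuming memos.plugins.vlm.main is importable as a module.
import asyncio

from memos.plugins.vlm import main as vlm_plugin  # assumed import path


class DemoConfig:
    # Hypothetical stand-in for VLMSettings after this change; only the
    # attributes that init_plugin() reads are provided.
    modelname = "minicpm-v"               # assumed model name
    endpoint = "http://localhost:11434"   # assumed endpoint URL
    token = ""
    concurrency = 1
    force_jpeg = False


vlm_plugin.init_plugin(DemoConfig())


async def caption(img_path: str):
    # predict() no longer branches on use_local; it always calls predict_remote().
    return await vlm_plugin.predict(
        vlm_plugin.endpoint, vlm_plugin.modelname, img_path, token=vlm_plugin.token
    )


if __name__ == "__main__":
    print(asyncio.run(caption("screenshot.png")))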