From a5562b14eb333571dad8f0ccd4253c45bd4dff10 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Fri, 6 Sep 2024 23:42:46 +0800 Subject: [PATCH 01/39] feat(ml): use float32 for pascal --- memos_ml_backends/server.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py index 81df3aa..8eb0653 100644 --- a/memos_ml_backends/server.py +++ b/memos_ml_backends/server.py @@ -25,7 +25,12 @@ elif torch.backends.mps.is_available(): else: device = torch.device("cpu") -torch_dtype = "auto" +torch_dtype = ( + torch.float32 + if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6) + or (not torch.cuda.is_available() and not torch.backends.mps.is_available()) + else torch.float16 +) print(f"Using device: {device}") @@ -37,7 +42,7 @@ def init_embedding_model(): return model -embedding_model = init_embedding_model() # 初始化模型 +embedding_model = init_embedding_model() def generate_embeddings(input_texts: List[str]) -> List[List[float]]: @@ -139,9 +144,9 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens): text = qwen2vl_processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - + image_inputs, video_inputs = process_vision_info(messages) - + inputs = qwen2vl_processor( text=[text], images=image_inputs, @@ -152,12 +157,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens): inputs = inputs.to(device) generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512)) - + generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] - + output_text = qwen2vl_processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, @@ -275,10 +280,14 @@ async def chat_completions(request: ChatCompletionRequest): if __name__ == "__main__": import uvicorn - parser = argparse.ArgumentParser(description="Run the server with specified model and port") + parser = argparse.ArgumentParser( + description="Run the server with specified model and port" + ) parser.add_argument("--florence", action="store_true", help="Use Florence-2 model") parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model") - parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") + parser.add_argument( + "--port", type=int, default=8000, help="Port to run the server on" + ) args = parser.parse_args() if args.florence and args.qwen2vl: From 57110db73f95cff4572e7a63fdee1d9051e87baf Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Sun, 8 Sep 2024 22:56:57 +0800 Subject: [PATCH 02/39] feat(web): show lib id --- web/src/lib/components/LibraryFilter.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/lib/components/LibraryFilter.svelte b/web/src/lib/components/LibraryFilter.svelte index 1772171..fea7c9f 100644 --- a/web/src/lib/components/LibraryFilter.svelte +++ b/web/src/lib/components/LibraryFilter.svelte @@ -112,7 +112,7 @@ id={`library-${library.id}`} bind:checked={selectedLibraries[library.id]} /> - + {/each} From 69aca0153af25ff3f245d7b01789e4d53ab9a84a Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:43:17 +0800 Subject: [PATCH 03/39] feat: make embedding a default plugin --- memos/config.py | 5 +- memos/indexing.py | 27 ++++--- memos/plugins/embedding/main.py | 131 ++++++++++++++++++++++++++++++++ memos/server.py | 37 +++++---- memos_ml_backends/server.py | 49 +----------- pyproject.toml | 3 + 6 files changed, 179 insertions(+), 73 deletions(-) create mode 100644 memos/plugins/embedding/main.py diff --git a/memos/config.py b/memos/config.py index 79799c1..2c5bde6 100644 --- a/memos/config.py +++ b/memos/config.py @@ -32,9 +32,10 @@ class OCRSettings(BaseModel): class EmbeddingSettings(BaseModel): + enabled: bool = True num_dim: int = 768 - ollama_endpoint: str = "http://localhost:11434" - ollama_model: str = "nextfire/paraphrase-multilingual-minilm" + endpoint: str = "http://localhost:11434/api/embed" + model: str = "jinaai/jina-embeddings-v2-base-zh" class Settings(BaseSettings): diff --git a/memos/indexing.py b/memos/indexing.py index f110273..5f5d3ba 100644 --- a/memos/indexing.py +++ b/memos/indexing.py @@ -46,14 +46,19 @@ def parse_date_fields(entity): } -def get_embeddings(texts: List[str]) -> List[List[float]]: +async def get_embeddings(texts: List[str]) -> List[List[float]]: print(f"Getting embeddings for {len(texts)} texts") - ollama_endpoint = settings.embedding.ollama_endpoint - ollama_model = settings.embedding.ollama_model - with httpx.Client() as client: - response = client.post( - f"{ollama_endpoint}/api/embed", - json={"model": ollama_model, "input": texts}, + + if settings.embedding.enabled: + endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed" + else: + endpoint = settings.embedding.endpoint + + model = settings.embedding.model + async with httpx.AsyncClient() as client: + response = await client.post( + endpoint, + json={"model": model, "input": texts}, timeout=30 ) if response.status_code == 200: @@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries): return metadata_text -def bulk_upsert(client, entities): +async def bulk_upsert(client, entities): documents = [] metadata_texts = [] entities_with_metadata = [] @@ -142,7 +147,7 @@ def bulk_upsert(client, entities): ).model_dump(mode="json") ) - embeddings = get_embeddings(metadata_texts) + embeddings = await get_embeddings(metadata_texts) for doc, embedding, entity in zip(documents, embeddings, entities): if entity in entities_with_metadata: doc["embedding"] = embedding @@ -259,7 +264,7 @@ def list_all_entities( ) -def search_entities( +async def search_entities( client, q: str, library_ids: List[int] = None, @@ -287,7 +292,7 @@ def search_entities( filter_by_str = " && ".join(filter_by) if filter_by else "" # Convert q to embedding using get_embeddings and take the first embedding - embedding = get_embeddings([q])[0] + embedding = (await get_embeddings([q]))[0] common_search_params = { "collection": TYPESENSE_COLLECTION_NAME, diff --git a/memos/plugins/embedding/main.py b/memos/plugins/embedding/main.py new file mode 100644 index 0000000..89e837a --- /dev/null +++ b/memos/plugins/embedding/main.py @@ -0,0 +1,131 @@ +import asyncio +from typing import List +from fastapi import APIRouter, HTTPException +import logging +import uvicorn +from sentence_transformers import SentenceTransformer +import torch +import numpy as np +from pydantic import BaseModel + +PLUGIN_NAME = "embedding" + +router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}}) + +# Global variables +enabled = False +model = None +num_dim = None +endpoint = None +model_name = None +device = None + +# Configure logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def init_embedding_model(): + global model, device + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + + model = SentenceTransformer(model_name, trust_remote_code=True) + model.to(device) + logger.info(f"Embedding model initialized on device: {device}") + + +def generate_embeddings(input_texts: List[str]) -> List[List[float]]: + embeddings = model.encode(input_texts, convert_to_tensor=True) + embeddings = embeddings.cpu().numpy() + # Normalize embeddings + norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True) + norms[norms == 0] = 1 + embeddings = embeddings / norms + return embeddings.tolist() + + +class EmbeddingRequest(BaseModel): + input: List[str] + + +class EmbeddingResponse(BaseModel): + embeddings: List[List[float]] + + +@router.get("/") +async def read_root(): + return {"healthy": True, "enabled": enabled} + + +@router.post("", include_in_schema=False) +@router.post("/", response_model=EmbeddingResponse) +async def embed(request: EmbeddingRequest): + try: + if not request.input: + return EmbeddingResponse(embeddings=[]) + + # Run the embedding generation in a separate thread to avoid blocking + loop = asyncio.get_event_loop() + embeddings = await loop.run_in_executor(None, generate_embeddings, request.input) + return EmbeddingResponse(embeddings=embeddings) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error generating embeddings: {str(e)}" + ) + + +def init_plugin(config): + global enabled, num_dim, endpoint, model_name + enabled = config.enabled + num_dim = config.num_dim + endpoint = config.endpoint + model_name = config.model + + if enabled: + init_embedding_model() + + logger.info("Embedding plugin initialized") + logger.info(f"Enabled: {enabled}") + logger.info(f"Number of dimensions: {num_dim}") + logger.info(f"Endpoint: {endpoint}") + logger.info(f"Model: {model_name}") + + +if __name__ == "__main__": + import argparse + from fastapi import FastAPI + + parser = argparse.ArgumentParser(description="Embedding Plugin Configuration") + parser.add_argument( + "--num-dim", type=int, default=768, help="Number of embedding dimensions" + ) + parser.add_argument( + "--model", + type=str, + default="jinaai/jina-embeddings-v2-base-zh", + help="Embedding model name", + ) + parser.add_argument( + "--port", type=int, default=8000, help="Port to run the server on" + ) + + args = parser.parse_args() + + class Config: + def __init__(self, args): + self.enabled = True + self.num_dim = args.num_dim + self.endpoint = "what ever" + self.model = args.model + + init_plugin(Config(args)) + + app = FastAPI() + app.include_router(router) + + uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/memos/server.py b/memos/server.py index 9702b68..d48e355 100644 --- a/memos/server.py +++ b/memos/server.py @@ -23,6 +23,7 @@ import typesense from .config import get_database_path, settings from memos.plugins.vlm import main as vlm_main from memos.plugins.ocr import main as ocr_main +from memos.plugins.embedding import main as embedding_main from . import crud from . import indexing from .schemas import ( @@ -84,18 +85,6 @@ app.mount( "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True) ) -# Add VLM plugin router -if settings.vlm.enabled: - print("VLM plugin is enabled") - vlm_main.init_plugin(settings.vlm) - app.include_router(vlm_main.router, prefix="/plugins/vlm") - -# Add OCR plugin router -if settings.ocr.enabled: - print("OCR plugin is enabled") - ocr_main.init_plugin(settings.ocr) - app.include_router(ocr_main.router, prefix="/plugins/ocr") - @app.get("/favicon.png", response_class=FileResponse) async def favicon_png(): @@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense( ) try: - indexing.bulk_upsert(client, entities) + await indexing.bulk_upsert(client, entities) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder( @app.get("/search", response_model=SearchResult, tags=["search"]) -async def search_entities( +async def search_entities_route( q: str, library_ids: str = Query(None, description="Comma-separated list of library IDs"), folder_ids: str = Query(None, description="Comma-separated list of folder IDs"), @@ -502,7 +491,7 @@ async def search_entities( [date.strip() for date in created_dates.split(",")] if created_dates else None ) try: - return indexing.search_entities( + return await indexing.search_entities( client, q, library_ids, @@ -727,6 +716,24 @@ def run_server(): print(f"VLM plugin enabled: {settings.vlm}") print(f"OCR plugin enabled: {settings.ocr}") + # Add VLM plugin router + if settings.vlm.enabled: + print("VLM plugin is enabled") + vlm_main.init_plugin(settings.vlm) + app.include_router(vlm_main.router, prefix="/plugins/vlm") + + # Add OCR plugin router + if settings.ocr.enabled: + print("OCR plugin is enabled") + ocr_main.init_plugin(settings.ocr) + app.include_router(ocr_main.router, prefix="/plugins/ocr") + + # Add Embedding plugin router + if settings.embedding.enabled: + print("Embedding plugin is enabled") + embedding_main.init_plugin(settings.embedding) + app.include_router(embedding_main.router, prefix="/plugins/embed") + uvicorn.run( "memos.server:app", host=settings.server_host, # Use the new server_host setting diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py index 8eb0653..7a60825 100644 --- a/memos_ml_backends/server.py +++ b/memos_ml_backends/server.py @@ -1,7 +1,6 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel from typing import List, Dict, Any, Optional -from sentence_transformers import SentenceTransformer import numpy as np import httpx import torch @@ -34,27 +33,6 @@ torch_dtype = ( print(f"Using device: {device}") -def init_embedding_model(): - model = SentenceTransformer( - "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True - ) - model.to(device) - return model - - -embedding_model = init_embedding_model() - - -def generate_embeddings(input_texts: List[str]) -> List[List[float]]: - embeddings = embedding_model.encode(input_texts, convert_to_tensor=True) - embeddings = embeddings.cpu().numpy() - # normalized embeddings - norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True) - norms[norms == 0] = 1 - embeddings = embeddings / norms - return embeddings.tolist() - - # Add a configuration option to choose the model parser = argparse.ArgumentParser(description="Run the server with specified model") parser.add_argument("--florence", action="store_true", help="Use Florence-2 model") @@ -68,7 +46,10 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True if use_florence_model: # Load Florence-2 model florence_model = AutoModelForCausalLM.from_pretrained( - "microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True + "microsoft/Florence-2-base-ft", + torch_dtype=torch_dtype, + attn_implementation="sdpa", + trust_remote_code=True, ).to(device) florence_processor = AutoProcessor.from_pretrained( "microsoft/Florence-2-base-ft", trust_remote_code=True @@ -175,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens): app = FastAPI() -class EmbeddingRequest(BaseModel): - input: List[str] - - -class EmbeddingResponse(BaseModel): - embeddings: List[List[float]] - - -@app.post("/api/embed", response_model=EmbeddingResponse) -async def create_embeddings(request: EmbeddingRequest): - try: - if not request.input: - return EmbeddingResponse(embeddings=[]) - - embeddings = generate_embeddings(request.input) # 使用新方法 - return EmbeddingResponse(embeddings=embeddings) - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Error generating embeddings: {str(e)}" - ) - - class ChatCompletionRequest(BaseModel): model: str messages: List[Dict[str, Any]] diff --git a/pyproject.toml b/pyproject.toml index d6d509b..c52a9ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,9 @@ dependencies = [ "pyobjc; sys_platform == 'darwin'", "pyobjc-core; sys_platform == 'darwin'", "pyobjc-framework-Quartz; sys_platform == 'darwin'", + "sentence-transformers", + "torch", + "numpy", ] [project.urls] From 8056f19773e002b48dc383b0d1ce667db5842e7e Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:45:37 +0800 Subject: [PATCH 04/39] chore: bump 0.6.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c52a9ee..0c200a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.5.0" +version = "0.6.0" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From d7e6c32e86d30117d6a4aa3aa7ec7eb032b5897d Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:57:07 +0800 Subject: [PATCH 05/39] feat: init default library --- memos/commands.py | 1 - memos/config.py | 2 ++ memos/models.py | 59 +++++++++++++++++++++++++++++++++++------------ 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/memos/commands.py b/memos/commands.py index fd9c9bc..00753e8 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -1,6 +1,5 @@ import asyncio import os -import time import logging from datetime import datetime, timezone from pathlib import Path diff --git a/memos/config.py b/memos/config.py index 2c5bde6..458ee8d 100644 --- a/memos/config.py +++ b/memos/config.py @@ -47,6 +47,8 @@ class Settings(BaseSettings): base_dir: str = str(Path.home() / ".memos") database_path: str = os.path.join(base_dir, "database.db") + default_library: str = "screenshots" + typesense_host: str = "localhost" typesense_port: str = "8108" typesense_protocol: str = "http" diff --git a/memos/models.py b/memos/models.py index bc1240b..740e8e6 100644 --- a/memos/models.py +++ b/memos/models.py @@ -15,7 +15,7 @@ from typing import List from .schemas import MetadataSource, MetadataType from sqlalchemy.exc import OperationalError from sqlalchemy.orm import sessionmaker -from .config import get_database_path +from .config import get_database_path, settings class Base(DeclarativeBase): @@ -75,16 +75,14 @@ class EntityModel(Base): "FolderModel", back_populates="entities" ) metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship( - "EntityMetadataModel", - lazy="joined", - cascade="all, delete-orphan" + "EntityMetadataModel", lazy="joined", cascade="all, delete-orphan" ) tags: Mapped[List["TagModel"]] = relationship( - "TagModel", - secondary="entity_tags", + "TagModel", + secondary="entity_tags", lazy="joined", cascade="all, delete", - overlaps="entities" + overlaps="entities", ) # 添加索引 @@ -160,30 +158,61 @@ def init_database(): """Initialize the database.""" db_path = get_database_path() engine = create_engine(f"sqlite:///{db_path}") - + try: Base.metadata.create_all(engine) print(f"Database initialized successfully at {db_path}") - + # Initialize default plugins Session = sessionmaker(bind=engine) with Session() as session: - initialize_default_plugins(session) - + default_plugins = initialize_default_plugins(session) + init_default_libraries(session, default_plugins) + return True except OperationalError as e: print(f"Error initializing database: {e}") return False + def initialize_default_plugins(session): default_plugins = [ - PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"), - PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"), + PluginModel( + name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm" + ), + PluginModel( + name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr" + ), ] for plugin in default_plugins: existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first() if not existing_plugin: session.add(plugin) - - session.commit() \ No newline at end of file + + session.commit() + + return default_plugins + + +def init_default_libraries(session, default_plugins): + default_libraries = [ + LibraryModel(name=settings.default_library), + ] + + for library in default_libraries: + existing_library = ( + session.query(LibraryModel).filter_by(name=library.name).first() + ) + if not existing_library: + session.add(library) + + for plugin in default_plugins: + bind_response = session.query(PluginModel).filter_by(name=plugin.name).first() + if bind_response: + library_plugin = LibraryPluginModel( + library_id=1, plugin_id=bind_response.id + ) # Assuming library_id=1 for default libraries + session.add(library_plugin) + + session.commit() From 7e43bc086126b865613d636039715e6ea93a1432 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 20:30:58 +0800 Subject: [PATCH 06/39] feat(ml_backend): move florence 2 as a default vlm plugin --- memos/config.py | 1 + memos/plugins/ocr/main.py | 1 + memos/plugins/vlm/main.py | 100 +++++++++++++++++++++++++++++++++--- memos_ml_backends/server.py | 4 +- 4 files changed, 96 insertions(+), 10 deletions(-) diff --git a/memos/config.py b/memos/config.py index 458ee8d..db5ffbc 100644 --- a/memos/config.py +++ b/memos/config.py @@ -19,6 +19,7 @@ class VLMSettings(BaseModel): token: str = "" concurrency: int = 4 force_jpeg: bool = False + use_local: bool = True class OCRSettings(BaseModel): diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py index 399d23a..a0b207d 100644 --- a/memos/plugins/ocr/main.py +++ b/memos/plugins/ocr/main.py @@ -145,6 +145,7 @@ async def ocr(entity: Entity, request: Request): } ] }, + timeout=30, ) # Check if the patch request was successful diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index f24ea70..ad636b6 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -9,6 +9,8 @@ import logging import uvicorn import os import io +import torch +from transformers import AutoModelForCausalLM, AutoProcessor PLUGIN_NAME = "vlm" PROMPT = "描述这张图片的内容" @@ -21,6 +23,10 @@ token = None concurrency = None semaphore = None force_jpeg = None +use_local = None +florence_model = None +florence_processor = None +torch_dtype = None # Configure logger logging.basicConfig(level=logging.INFO) @@ -35,18 +41,18 @@ def image2base64(img_path): with Image.open(img_path) as img: if force_jpeg: # Convert image to RGB mode (removes alpha channel if present) - img = img.convert('RGB') + img = img.convert("RGB") # Save as JPEG in memory buffer = io.BytesIO() - img.save(buffer, format='JPEG') + img.save(buffer, format="JPEG") buffer.seek(0) - encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8') + encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8") else: # Use original format buffer = io.BytesIO() img.save(buffer, format=img.format) buffer.seek(0) - encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8') + encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8") return encoded_string except Exception as e: logger.error(f"Error processing image {img_path}: {str(e)}") @@ -79,12 +85,57 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N async def predict( endpoint: str, modelname: str, img_path: str, token: Optional[str] = None +) -> Optional[str]: + if use_local: + return await predict_local(img_path) + else: + return await predict_remote(endpoint, modelname, img_path, token) + + +async def predict_local(img_path: str) -> Optional[str]: + try: + image = Image.open(img_path) + task_prompt = "" + prompt = task_prompt + "" + + inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to( + florence_model.device, torch_dtype + ) + + generated_ids = florence_model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + do_sample=False, + num_beams=3, + ) + + generated_texts = florence_processor.batch_decode( + generated_ids, skip_special_tokens=False + ) + + parsed_answer = florence_processor.post_process_generation( + generated_texts[0], + task=task_prompt, + image_size=(image.width, image.height), + ) + + return parsed_answer.get(task_prompt, "") + except Exception as e: + logger.error(f"Error processing image {img_path}: {str(e)}") + return None + + +async def predict_remote( + endpoint: str, modelname: str, img_path: str, token: Optional[str] = None ) -> Optional[str]: img_base64 = image2base64(img_path) if not img_base64: return None - mime_type = "image/jpeg" if force_jpeg else "image/jpeg" # Default to JPEG if force_jpeg is True + mime_type = ( + "image/jpeg" if force_jpeg else "image/jpeg" + ) # Default to JPEG if force_jpeg is True if not force_jpeg: # Only determine MIME type if not forcing JPEG @@ -167,9 +218,9 @@ async def vlm(entity: Entity, request: Request): vlm_result = await predict(endpoint, modelname, entity.filepath, token=token) - print(vlm_result) + logger.info(vlm_result) if not vlm_result: - print(f"No VLM result found for file: {entity.filepath}") + logger.info(f"No VLM result found for file: {entity.filepath}") return {metadata_field_name: "{}"} async with httpx.AsyncClient() as client: @@ -199,14 +250,46 @@ async def vlm(entity: Entity, request: Request): def init_plugin(config): - global modelname, endpoint, token, concurrency, semaphore, force_jpeg + global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype + modelname = config.modelname endpoint = config.endpoint token = config.token concurrency = config.concurrency force_jpeg = config.force_jpeg + use_local = config.use_local semaphore = asyncio.Semaphore(concurrency) + if use_local: + # 检测可用的设备 + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + + torch_dtype = ( + torch.float32 + if ( + torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6 + ) + or (not torch.cuda.is_available() and not torch.backends.mps.is_available()) + else torch.float16 + ) + logger.info(f"Using device: {device}") + + florence_model = AutoModelForCausalLM.from_pretrained( + "microsoft/Florence-2-base-ft", + torch_dtype=torch_dtype, + attn_implementation="sdpa", + trust_remote_code=True, + ).to(device) + florence_processor = AutoProcessor.from_pretrained( + "microsoft/Florence-2-base-ft", trust_remote_code=True + ) + logger.info("Florence model and processor initialized") + # Print the parameters logger.info("VLM plugin initialized") logger.info(f"Model Name: {modelname}") @@ -214,6 +297,7 @@ def init_plugin(config): logger.info(f"Token: {token}") logger.info(f"Concurrency: {concurrency}") logger.info(f"Force JPEG: {force_jpeg}") + logger.info(f"Use Local: {use_local}") if __name__ == "__main__": diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py index 7a60825..b936668 100644 --- a/memos_ml_backends/server.py +++ b/memos_ml_backends/server.py @@ -50,7 +50,7 @@ if use_florence_model: torch_dtype=torch_dtype, attn_implementation="sdpa", trust_remote_code=True, - ).to(device) + ).to(device, torch_dtype) florence_processor = AutoProcessor.from_pretrained( "microsoft/Florence-2-base-ft", trust_remote_code=True ) @@ -60,7 +60,7 @@ else: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4", torch_dtype=torch_dtype, device_map="auto", - ).to(device) + ).to(device, torch_dtype) qwen2vl_processor = AutoProcessor.from_pretrained( "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4" ) From 81ba2cd2d23f38f73d59d3fa21c4d59b2a0501e0 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 20:58:06 +0800 Subject: [PATCH 07/39] feat: add default library and command shortcuts for it --- memos/commands.py | 72 +++++++++++++++++++++++++++++++++++++++ memos/config.py | 1 + screen_recorder/record.py | 6 ++-- 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/memos/commands.py b/memos/commands.py index 00753e8..8092918 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -762,5 +762,77 @@ def init(): print("Initialization failed. Please check the error messages above.") +@app.command("scan") +def scan_default_library(): + """ + Scan the screenshots directory and add it to the library if empty. + """ + # Get the default library + response = httpx.get(f"{BASE_URL}/libraries") + if response.status_code != 200: + print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") + return + + libraries = response.json() + default_library = next( + (lib for lib in libraries if lib["name"] == settings.default_library), None + ) + + if not default_library: + # Create the default library if it doesn't exist + response = httpx.post( + f"{BASE_URL}/libraries", + json={"name": settings.default_library, "folders": []}, + ) + if response.status_code != 200: + print( + f"Failed to create default library: {response.status_code} - {response.text}" + ) + return + default_library = response.json() + + # Check if the library is empty + if not default_library["folders"]: + # Add the screenshots directory to the library + screenshots_dir = Path(settings.screenshots_dir).resolve() + response = httpx.post( + f"{BASE_URL}/libraries/{default_library['id']}/folders", + json={"folders": [str(screenshots_dir)]}, + ) + if response.status_code != 200: + print( + f"Failed to add screenshots directory: {response.status_code} - {response.text}" + ) + return + print(f"Added screenshots directory: {screenshots_dir}") + + # Scan the library + print(f"Scanning library: {default_library['name']}") + scan(default_library["id"], plugins=None, folders=None) + + +@app.command("index") +def index_default_library(): + """ + Index the default library for memos. + """ + # Get the default library + response = httpx.get(f"{BASE_URL}/libraries") + if response.status_code != 200: + print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") + return + + libraries = response.json() + default_library = next( + (lib for lib in libraries if lib["name"] == settings.default_library), None + ) + + if not default_library: + print("Default library does not exist.") + return + + index(default_library["id"], force=False, folders=None) + + if __name__ == "__main__": app() diff --git a/memos/config.py b/memos/config.py index db5ffbc..309a373 100644 --- a/memos/config.py +++ b/memos/config.py @@ -49,6 +49,7 @@ class Settings(BaseSettings): base_dir: str = str(Path.home() / ".memos") database_path: str = os.path.join(base_dir, "database.db") default_library: str = "screenshots" + screenshots_dir: str = os.path.join(base_dir, "screenshots") typesense_host: str = "localhost" typesense_port: str = "8108" diff --git a/screen_recorder/record.py b/screen_recorder/record.py index 8bc101a..ba141a0 100644 --- a/screen_recorder/record.py +++ b/screen_recorder/record.py @@ -9,6 +9,8 @@ from screen_recorder.common import ( take_screenshot, is_screen_locked, ) +from pathlib import Path +from memos.config import settings logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -51,12 +53,12 @@ def main(): "--threshold", type=int, default=4, help="Threshold for image similarity" ) parser.add_argument( - "--base-dir", type=str, default="~/tmp", help="Base directory for screenshots" + "--base-dir", type=str, help="Base directory for screenshots" ) parser.add_argument("--once", action="store_true", help="Run once and exit") args = parser.parse_args() - base_dir = os.path.expanduser(args.base_dir) + base_dir = os.path.expanduser(args.base_dir) if args.base_dir else settings.screenshots_dir previous_hashes = load_previous_hashes(base_dir) if args.once: From aba3c03c1f5810ebd9ca472c327b19d8532e051b Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:05:50 +0800 Subject: [PATCH 08/39] docs: update readme --- README.md | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 60c8cfa..a8af81f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,68 @@ # Memos -A project to index everything to make it like another memory. +A project to index everything to make it like another memory. The project contains two parts: + +1. `screen recorder`: which will take screenshots every 5 seconds and save it at `~/tmp` by default. +2. `memos server`: a web service which can index the screenshots and other files, and provide a web interface to search the records. + +There is a product called [Rewind](https://www.rewind.ai/) which is similar to memos. But memos try to make all the data controlled by yourself. + +## Install + +### Install Typesense + +```bash +export TYPESENSE_API_KEY=xyz + +mkdir "$(pwd)"/typesense-data + +docker run -d -p 8108:8108 \ + -v"$(pwd)"/typesense-data:/data typesense/typesense:27.0 \ + --add-host=host.docker.internal:host-gateway \ + --data-dir /data \ + --api-key=$TYPESENSE_API_KEY \ + --enable-cors +``` + +### Install Memos + +```bash +pip install memos +``` + +## How to use + +To use memos, you need to initialize it first. Make sure you have started `typesense`. + +### 1. Initialize Memos + +```bash +memos init +``` + +This will create a folder `~/.memos` and put the config file there. + +### 2. Start Screen Recorder + +```bash +memos record +``` + +This will start a screen recorder, which will take screenshots every 5 seconds and save it at `~/.memos/screenshots` by default. + +### 3. Start Memos Server + +```bash +memos serve +``` + +This will start a web server, and you can access the web interface at `http://localhost:8080`. + +### Index the screenshots + +```bash +memos scan +memos index +``` + +Refresh the page, and do some search. From 3f08b79bac6de5d137b0cd3443fe8428ff7a5b66 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:06:38 +0800 Subject: [PATCH 09/39] chore: bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0c200a4..303d8ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.0" +version = "0.6.1" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From 0677931ca5a204bc5138569afe7ea6c334965c90 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:08:02 +0800 Subject: [PATCH 10/39] chore: typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a8af81f..2cd2574 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ This will create a folder `~/.memos` and put the config file there. ### 2. Start Screen Recorder ```bash -memos record +memos-record ``` This will start a screen recorder, which will take screenshots every 5 seconds and save it at `~/.memos/screenshots` by default. From f9e2b2261bb434fec0bfe30283e20e1c2e768443 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:09:54 +0800 Subject: [PATCH 11/39] docs: typo --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2cd2574..f6d3a49 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,10 @@ A project to index everything to make it like another memory. The project contains two parts: -1. `screen recorder`: which will take screenshots every 5 seconds and save it at `~/tmp` by default. -2. `memos server`: a web service which can index the screenshots and other files, and provide a web interface to search the records. +1. `screen recorder`: which takes screenshots every 5 seconds and saves them to `~/.memos/screenshots` by default. +2. `memos server`: a web service that can index the screenshots and other files, providing a web interface to search the records. -There is a product called [Rewind](https://www.rewind.ai/) which is similar to memos. But memos try to make all the data controlled by yourself. +There is a product called [Rewind](https://www.rewind.ai/) that is similar to memos, but memos aims to give you control over all your data. ## Install From 6107c22defad2ae86a458055f480f9f3f616406e Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:26:26 +0800 Subject: [PATCH 12/39] fix: support parse secret --- memos/config.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/memos/config.py b/memos/config.py index 309a373..ec9170b 100644 --- a/memos/config.py +++ b/memos/config.py @@ -50,7 +50,7 @@ class Settings(BaseSettings): database_path: str = os.path.join(base_dir, "database.db") default_library: str = "screenshots" screenshots_dir: str = os.path.join(base_dir, "screenshots") - + typesense_host: str = "localhost" typesense_port: str = "8108" typesense_protocol: str = "http" @@ -98,6 +98,21 @@ def dict_representer(dumper, data): yaml.add_representer(OrderedDict, dict_representer) +# Custom representer for SecretStr +def secret_str_representer(dumper, data): + return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value()) + + +# Custom constructor for SecretStr +def secret_str_constructor(loader, node): + value = loader.construct_scalar(node) + return SecretStr(value) + + +yaml.add_representer(SecretStr, secret_str_representer) +yaml.add_constructor("tag:yaml.org,2002:str", secret_str_constructor) + + def create_default_config(): config_path = Path.home() / ".memos" / "config.yaml" if not config_path.exists(): From 5b1194f1bc6ad055fe69a9a2207ef1c31cc73721 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:26:48 +0800 Subject: [PATCH 13/39] fix: skip flash attn https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d --- memos/plugins/vlm/main.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index ad636b6..6190a73 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -12,6 +12,18 @@ import io import torch from transformers import AutoModelForCausalLM, AutoProcessor +from unittest.mock import patch +from transformers.dynamic_module_utils import get_imports + + +def fixed_get_imports(filename: str | os.PathLike) -> list[str]: + if not str(filename).endswith("modeling_florence2.py"): + return get_imports(filename) + imports = get_imports(filename) + imports.remove("flash_attn") + return imports + + PLUGIN_NAME = "vlm" PROMPT = "描述这张图片的内容" @@ -251,7 +263,7 @@ async def vlm(entity: Entity, request: Request): def init_plugin(config): global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype - + modelname = config.modelname endpoint = config.endpoint token = config.token @@ -279,15 +291,16 @@ def init_plugin(config): ) logger.info(f"Using device: {device}") - florence_model = AutoModelForCausalLM.from_pretrained( - "microsoft/Florence-2-base-ft", - torch_dtype=torch_dtype, - attn_implementation="sdpa", - trust_remote_code=True, - ).to(device) - florence_processor = AutoProcessor.from_pretrained( - "microsoft/Florence-2-base-ft", trust_remote_code=True - ) + with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): + florence_model = AutoModelForCausalLM.from_pretrained( + "microsoft/Florence-2-base-ft", + torch_dtype=torch_dtype, + attn_implementation="sdpa", + trust_remote_code=True, + ).to(device) + florence_processor = AutoProcessor.from_pretrained( + "microsoft/Florence-2-base-ft", trust_remote_code=True + ) logger.info("Florence model and processor initialized") # Print the parameters From a3394b9250e90791f1311bac14cfef6b7be13991 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:28:41 +0800 Subject: [PATCH 14/39] fix: add dependencies for florence 2 --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 303d8ca..83a9801 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ dependencies = [ "sentence-transformers", "torch", "numpy", + "timm", + "einops", ] [project.urls] From 939bcc0e818f105c34d844cc8a99dbb2b81b87da Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:29:02 +0800 Subject: [PATCH 15/39] chore: bump 0.6.2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 83a9801..3317fa6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.1" +version = "0.6.2" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From 77b1f34a6268b7265aad4cadbb3ea307e84908f0 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:59:24 +0800 Subject: [PATCH 16/39] fix: include yaml in ocr --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 3317fa6..382e235 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,3 +55,4 @@ include = ["memos*", "screen_recorder*"] [tool.setuptools.package-data] "*" = ["static/**/*"] +"memos.plugins.ocr" = ["*.yaml"] From d08c9e0e3630a1534260e786deb9ab511b9ca95d Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 22:59:45 +0800 Subject: [PATCH 17/39] chore: bump 0.6.3 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 382e235..88fdf77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.2" +version = "0.6.3" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From 3912d165f67a3d5ef50b43190811624a1eef38bc Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 23:27:40 +0800 Subject: [PATCH 18/39] Revert "fix: skip flash attn https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d" This reverts commit 5b1194f1bc6ad055fe69a9a2207ef1c31cc73721. --- memos/plugins/vlm/main.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index 6190a73..ad636b6 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -12,18 +12,6 @@ import io import torch from transformers import AutoModelForCausalLM, AutoProcessor -from unittest.mock import patch -from transformers.dynamic_module_utils import get_imports - - -def fixed_get_imports(filename: str | os.PathLike) -> list[str]: - if not str(filename).endswith("modeling_florence2.py"): - return get_imports(filename) - imports = get_imports(filename) - imports.remove("flash_attn") - return imports - - PLUGIN_NAME = "vlm" PROMPT = "描述这张图片的内容" @@ -263,7 +251,7 @@ async def vlm(entity: Entity, request: Request): def init_plugin(config): global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype - + modelname = config.modelname endpoint = config.endpoint token = config.token @@ -291,16 +279,15 @@ def init_plugin(config): ) logger.info(f"Using device: {device}") - with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): - florence_model = AutoModelForCausalLM.from_pretrained( - "microsoft/Florence-2-base-ft", - torch_dtype=torch_dtype, - attn_implementation="sdpa", - trust_remote_code=True, - ).to(device) - florence_processor = AutoProcessor.from_pretrained( - "microsoft/Florence-2-base-ft", trust_remote_code=True - ) + florence_model = AutoModelForCausalLM.from_pretrained( + "microsoft/Florence-2-base-ft", + torch_dtype=torch_dtype, + attn_implementation="sdpa", + trust_remote_code=True, + ).to(device) + florence_processor = AutoProcessor.from_pretrained( + "microsoft/Florence-2-base-ft", trust_remote_code=True + ) logger.info("Florence model and processor initialized") # Print the parameters From a30fe62bc3e59414b163a452df7a132676f2200b Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 23:39:04 +0800 Subject: [PATCH 19/39] fix: yaml parse related --- memos/config.py | 4 +--- memos/plugins/ocr/main.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/memos/config.py b/memos/config.py index ec9170b..b840254 100644 --- a/memos/config.py +++ b/memos/config.py @@ -102,15 +102,13 @@ yaml.add_representer(OrderedDict, dict_representer) def secret_str_representer(dumper, data): return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value()) - # Custom constructor for SecretStr def secret_str_constructor(loader, node): value = loader.construct_scalar(node) return SecretStr(value) - +# Register the representer and constructor only for specific fields yaml.add_representer(SecretStr, secret_str_representer) -yaml.add_constructor("tag:yaml.org,2002:str", secret_str_constructor) def create_default_config(): diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py index a0b207d..b40be50 100644 --- a/memos/plugins/ocr/main.py +++ b/memos/plugins/ocr/main.py @@ -183,10 +183,10 @@ def init_plugin(config): ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path'])) ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path'])) - # Save the updated config to a temporary file + # Save the updated config to a temporary file with strings wrapped in double quotes temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml") with open(temp_config_path, 'w') as f: - yaml.safe_dump(ocr_config, f) + yaml.safe_dump(ocr_config, f, default_style='"') ocr = RapidOCR(config_path=temp_config_path) thread_pool = ThreadPoolExecutor(max_workers=concurrency) From 4b189c22d224e0d91a7b6287c00bed79c848c117 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Mon, 9 Sep 2024 23:39:41 +0800 Subject: [PATCH 20/39] chore: bump 0.6.4 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 88fdf77..bc5223e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.3" +version = "0.6.4" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From 41c7e136d9eac9370a12dd1458662492de022a38 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 00:31:56 +0800 Subject: [PATCH 21/39] fix: skip flash attn https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d --- memos/plugins/vlm/main.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index ad636b6..6190a73 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -12,6 +12,18 @@ import io import torch from transformers import AutoModelForCausalLM, AutoProcessor +from unittest.mock import patch +from transformers.dynamic_module_utils import get_imports + + +def fixed_get_imports(filename: str | os.PathLike) -> list[str]: + if not str(filename).endswith("modeling_florence2.py"): + return get_imports(filename) + imports = get_imports(filename) + imports.remove("flash_attn") + return imports + + PLUGIN_NAME = "vlm" PROMPT = "描述这张图片的内容" @@ -251,7 +263,7 @@ async def vlm(entity: Entity, request: Request): def init_plugin(config): global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype - + modelname = config.modelname endpoint = config.endpoint token = config.token @@ -279,15 +291,16 @@ def init_plugin(config): ) logger.info(f"Using device: {device}") - florence_model = AutoModelForCausalLM.from_pretrained( - "microsoft/Florence-2-base-ft", - torch_dtype=torch_dtype, - attn_implementation="sdpa", - trust_remote_code=True, - ).to(device) - florence_processor = AutoProcessor.from_pretrained( - "microsoft/Florence-2-base-ft", trust_remote_code=True - ) + with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): + florence_model = AutoModelForCausalLM.from_pretrained( + "microsoft/Florence-2-base-ft", + torch_dtype=torch_dtype, + attn_implementation="sdpa", + trust_remote_code=True, + ).to(device) + florence_processor = AutoProcessor.from_pretrained( + "microsoft/Florence-2-base-ft", trust_remote_code=True + ) logger.info("Florence model and processor initialized") # Print the parameters From 93060af86adaba36864308bac33b0cdeb0d10013 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 00:33:37 +0800 Subject: [PATCH 22/39] fix: include onnx models --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bc5223e..f50f59d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,4 +55,4 @@ include = ["memos*", "screen_recorder*"] [tool.setuptools.package-data] "*" = ["static/**/*"] -"memos.plugins.ocr" = ["*.yaml"] +"memos.plugins.ocr" = ["*.yaml", "*.onnx"] From aa9cc028c6bc3289b8a2032602f869dcc4311416 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 00:34:03 +0800 Subject: [PATCH 23/39] chore: bump 0.6.5 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f50f59d..76c0100 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.4" +version = "0.6.5" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From a8c42d854618e295cfe5de8b3af5608af6dbbde0 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 00:40:12 +0800 Subject: [PATCH 24/39] fix: include onnx models --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 76c0100..9a41324 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,4 +55,4 @@ include = ["memos*", "screen_recorder*"] [tool.setuptools.package-data] "*" = ["static/**/*"] -"memos.plugins.ocr" = ["*.yaml", "*.onnx"] +"memos.plugins.ocr" = ["*.yaml", "models/*.onnx"] From 95047de12ba9bb48a832f303917617533437f973 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 00:40:34 +0800 Subject: [PATCH 25/39] chore: bump 0.6.6 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9a41324..1bcac45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.5" +version = "0.6.6" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From 3a408f0be70cfa139a5aa3e51dae5e301eddb366 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 00:50:42 +0800 Subject: [PATCH 26/39] docs: tell default user / password --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f6d3a49..d763168 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,8 @@ This will start a screen recorder, which will take screenshots every 5 seconds a memos serve ``` -This will start a web server, and you can access the web interface at `http://localhost:8080`. +This will start a web server, and you can access the web interface at `http://localhost:8080`. +The default username and password is `admin` and `changeme`. ### Index the screenshots From f5aae87f40933ff0f27776c3fc32842d889e1289 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 00:54:20 +0800 Subject: [PATCH 27/39] feat: use concurrency of 1 by default --- memos/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/memos/config.py b/memos/config.py index b840254..a572e6f 100644 --- a/memos/config.py +++ b/memos/config.py @@ -17,7 +17,7 @@ class VLMSettings(BaseModel): modelname: str = "moondream" endpoint: str = "http://localhost:11434" token: str = "" - concurrency: int = 4 + concurrency: int = 1 force_jpeg: bool = False use_local: bool = True @@ -26,7 +26,7 @@ class OCRSettings(BaseModel): enabled: bool = True endpoint: str = "http://localhost:5555/predict" token: str = "" - concurrency: int = 4 + concurrency: int = 1 use_local: bool = True use_gpu: bool = False force_jpeg: bool = False @@ -71,7 +71,7 @@ class Settings(BaseSettings): # Embedding settings embedding: EmbeddingSettings = EmbeddingSettings() - batchsize: int = 4 + batchsize: int = 1 auth_username: str = "admin" auth_password: SecretStr = SecretStr("changeme") From 241911d1d22dd0aed69946c96df71ae35e2982ac Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 00:58:57 +0800 Subject: [PATCH 28/39] feat(index): use different bs for embedding --- memos/commands.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/memos/commands.py b/memos/commands.py index 8092918..d44dc61 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -548,6 +548,9 @@ def index( library_id: int, folders: List[int] = typer.Option(None, "--folder", "-f"), force: bool = typer.Option(False, "--force", help="Force update all indexes"), + batchsize: int = typer.Option( + 4, "--batchsize", "-bs", help="Number of entities to index in a batch" + ), ): print(f"Indexing library {library_id}") @@ -607,9 +610,8 @@ def index( pbar.refresh() # Index each entity - batch_size = settings.batchsize - for i in range(0, len(entities), batch_size): - batch = entities[i : i + batch_size] + for i in range(0, len(entities), batchsize): + batch = entities[i : i + batchsize] to_index = [] for entity in batch: From 378e5bf445b800f2928b2d5363c846474295b50b Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 01:32:00 +0800 Subject: [PATCH 29/39] feat: support modelscope --- memos/config.py | 2 ++ memos/plugins/embedding/main.py | 21 ++++++++++++++++++--- memos/plugins/vlm/main.py | 16 +++++++++++++--- pyproject.toml | 1 + 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/memos/config.py b/memos/config.py index a572e6f..6986b7a 100644 --- a/memos/config.py +++ b/memos/config.py @@ -51,6 +51,8 @@ class Settings(BaseSettings): default_library: str = "screenshots" screenshots_dir: str = os.path.join(base_dir, "screenshots") + use_modelscope: bool = False + typesense_host: str = "localhost" typesense_port: str = "8108" typesense_protocol: str = "http" diff --git a/memos/plugins/embedding/main.py b/memos/plugins/embedding/main.py index 89e837a..c7fbec4 100644 --- a/memos/plugins/embedding/main.py +++ b/memos/plugins/embedding/main.py @@ -7,6 +7,7 @@ from sentence_transformers import SentenceTransformer import torch import numpy as np from pydantic import BaseModel +from modelscope import snapshot_download PLUGIN_NAME = "embedding" @@ -19,6 +20,7 @@ num_dim = None endpoint = None model_name = None device = None +use_modelscope = None # Configure logger logging.basicConfig(level=logging.INFO) @@ -26,7 +28,7 @@ logger = logging.getLogger(__name__) def init_embedding_model(): - global model, device + global model, device, use_modelscope if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): @@ -34,7 +36,14 @@ def init_embedding_model(): else: device = torch.device("cpu") - model = SentenceTransformer(model_name, trust_remote_code=True) + if use_modelscope: + model_dir = snapshot_download(model_name) + logger.info(f"Model downloaded from ModelScope to: {model_dir}") + else: + model_dir = model_name + logger.info(f"Using model: {model_dir}") + + model = SentenceTransformer(model_dir, trust_remote_code=True) model.to(device) logger.info(f"Embedding model initialized on device: {device}") @@ -80,11 +89,12 @@ async def embed(request: EmbeddingRequest): def init_plugin(config): - global enabled, num_dim, endpoint, model_name + global enabled, num_dim, endpoint, model_name, use_modelscope enabled = config.enabled num_dim = config.num_dim endpoint = config.endpoint model_name = config.model + use_modelscope = config.use_modelscope if enabled: init_embedding_model() @@ -94,6 +104,7 @@ def init_plugin(config): logger.info(f"Number of dimensions: {num_dim}") logger.info(f"Endpoint: {endpoint}") logger.info(f"Model: {model_name}") + logger.info(f"Use ModelScope: {use_modelscope}") if __name__ == "__main__": @@ -113,6 +124,9 @@ if __name__ == "__main__": parser.add_argument( "--port", type=int, default=8000, help="Port to run the server on" ) + parser.add_argument( + "--use-modelscope", action="store_true", help="Use ModelScope to download the model" + ) args = parser.parse_args() @@ -122,6 +136,7 @@ if __name__ == "__main__": self.num_dim = args.num_dim self.endpoint = "what ever" self.model = args.model + self.use_modelscope = args.use_modelscope init_plugin(Config(args)) diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index 6190a73..fc905a4 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -14,7 +14,7 @@ from transformers import AutoModelForCausalLM, AutoProcessor from unittest.mock import patch from transformers.dynamic_module_utils import get_imports - +from modelscope import snapshot_download def fixed_get_imports(filename: str | os.PathLike) -> list[str]: if not str(filename).endswith("modeling_florence2.py"): @@ -270,6 +270,7 @@ def init_plugin(config): concurrency = config.concurrency force_jpeg = config.force_jpeg use_local = config.use_local + use_modelscope = config.use_modelscope semaphore = asyncio.Semaphore(concurrency) if use_local: @@ -291,15 +292,22 @@ def init_plugin(config): ) logger.info(f"Using device: {device}") + if use_modelscope: + model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft') + logger.info(f"Model downloaded from ModelScope to: {model_dir}") + else: + model_dir = "microsoft/Florence-2-base-ft" + logger.info(f"Using model: {model_dir}") + with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): florence_model = AutoModelForCausalLM.from_pretrained( - "microsoft/Florence-2-base-ft", + model_dir, torch_dtype=torch_dtype, attn_implementation="sdpa", trust_remote_code=True, ).to(device) florence_processor = AutoProcessor.from_pretrained( - "microsoft/Florence-2-base-ft", trust_remote_code=True + model_dir, trust_remote_code=True ) logger.info("Florence model and processor initialized") @@ -311,6 +319,7 @@ def init_plugin(config): logger.info(f"Concurrency: {concurrency}") logger.info(f"Force JPEG: {force_jpeg}") logger.info(f"Use Local: {use_local}") + logger.info(f"Use ModelScope: {use_modelscope}") if __name__ == "__main__": @@ -329,6 +338,7 @@ if __name__ == "__main__": parser.add_argument( "--port", type=int, default=8000, help="Port to run the server on" ) + parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model") args = parser.parse_args() diff --git a/pyproject.toml b/pyproject.toml index 1bcac45..5e44755 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "numpy", "timm", "einops", + "modelscope", ] [project.urls] From 513b41e3001c936994e7051496ac51609595ad5f Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 01:32:22 +0800 Subject: [PATCH 30/39] chore: bump 0.6.7 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5e44755..6a7509c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.6" +version = "0.6.7" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From 6bbf1a8a6805a7e808fd35769518d529909b6c30 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 01:59:37 +0800 Subject: [PATCH 31/39] fix: use modelscope for sub config --- memos/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/memos/config.py b/memos/config.py index 6986b7a..aa57145 100644 --- a/memos/config.py +++ b/memos/config.py @@ -20,6 +20,7 @@ class VLMSettings(BaseModel): concurrency: int = 1 force_jpeg: bool = False use_local: bool = True + use_modelscope: bool = False class OCRSettings(BaseModel): @@ -37,6 +38,7 @@ class EmbeddingSettings(BaseModel): num_dim: int = 768 endpoint: str = "http://localhost:11434/api/embed" model: str = "jinaai/jina-embeddings-v2-base-zh" + use_modelscope: bool = False class Settings(BaseSettings): @@ -51,8 +53,6 @@ class Settings(BaseSettings): default_library: str = "screenshots" screenshots_dir: str = os.path.join(base_dir, "screenshots") - use_modelscope: bool = False - typesense_host: str = "localhost" typesense_port: str = "8108" typesense_protocol: str = "http" From 23612d1fd5a0c4ec6da5da697cc28c19b3777a18 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 01:59:59 +0800 Subject: [PATCH 32/39] chore: bump 0.6.8 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6a7509c..bed964e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.7" +version = "0.6.8" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From 7f78eb1a4ad6a01e2eaaf7d6675583ff928cad89 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 11:54:56 +0800 Subject: [PATCH 33/39] chore: update yaml generate config --- memos/plugins/ocr/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py index b40be50..86bea9f 100644 --- a/memos/plugins/ocr/main.py +++ b/memos/plugins/ocr/main.py @@ -186,7 +186,7 @@ def init_plugin(config): # Save the updated config to a temporary file with strings wrapped in double quotes temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml") with open(temp_config_path, 'w') as f: - yaml.safe_dump(ocr_config, f, default_style='"') + yaml.safe_dump(ocr_config, f) ocr = RapidOCR(config_path=temp_config_path) thread_pool = ThreadPoolExecutor(max_workers=concurrency) From 5882950c39b2487a4013d9fcc75ff47cf51029ef Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 12:39:08 +0800 Subject: [PATCH 34/39] feat(cli): add extra args --- memos/commands.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/memos/commands.py b/memos/commands.py index d44dc61..47242be 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -765,7 +765,7 @@ def init(): @app.command("scan") -def scan_default_library(): +def scan_default_library(force: bool = False): """ Scan the screenshots directory and add it to the library if empty. """ @@ -810,11 +810,16 @@ def scan_default_library(): # Scan the library print(f"Scanning library: {default_library['name']}") - scan(default_library["id"], plugins=None, folders=None) + scan(default_library["id"], plugins=None, folders=None, force=force) @app.command("index") -def index_default_library(): +def index_default_library( + batchsize: int = typer.Option( + 4, "--batchsize", "-bs", help="Number of entities to index in a batch" + ), + force: bool = typer.Option(False, "--force", help="Force update all indexes"), +): """ Index the default library for memos. """ @@ -833,7 +838,7 @@ def index_default_library(): print("Default library does not exist.") return - index(default_library["id"], force=False, folders=None) + index(default_library["id"], force=force, folders=None, batchsize=batchsize) if __name__ == "__main__": From fca387b22d7f574ce81d5542e6460766b67973c2 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 13:55:52 +0800 Subject: [PATCH 35/39] feat: support bind plugin by name --- memos/commands.py | 16 ++++++++++------ memos/schemas.py | 31 ++++++++++++++++++++++++++++--- memos/server.py | 21 +++++++++++++++++++-- 3 files changed, 57 insertions(+), 11 deletions(-) diff --git a/memos/commands.py b/memos/commands.py index 47242be..744bd9d 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -723,18 +723,22 @@ def create(name: str, webhook_url: str, description: str = ""): @plugin_app.command("bind") def bind( library_id: int = typer.Option(..., "--lib", help="ID of the library"), - plugin_id: int = typer.Option(..., "--plugin", help="ID of the plugin"), + plugin: str = typer.Option(..., "--plugin", help="ID or name of the plugin"), ): + try: + plugin_id = int(plugin) + plugin_param = {"plugin_id": plugin_id} + except ValueError: + plugin_param = {"plugin_name": plugin} + response = httpx.post( f"{BASE_URL}/libraries/{library_id}/plugins", - json={"plugin_id": plugin_id}, + json=plugin_param, ) - if 200 <= response.status_code < 300: + if response.status_code == 204: print("Plugin bound to library successfully") else: - print( - f"Failed to bind plugin to library: {response.status_code} - {response.text}" - ) + print(f"Failed to bind plugin to library: {response.status_code} - {response.text}") @plugin_app.command("unbind") diff --git a/memos/schemas.py b/memos/schemas.py index edd040f..b1b6042 100644 --- a/memos/schemas.py +++ b/memos/schemas.py @@ -1,4 +1,11 @@ -from pydantic import BaseModel, ConfigDict, DirectoryPath, HttpUrl, Field +from pydantic import ( + BaseModel, + ConfigDict, + DirectoryPath, + HttpUrl, + Field, + model_validator, +) from typing import List, Optional, Any, Dict from datetime import datetime from enum import Enum @@ -79,7 +86,18 @@ class NewPluginParam(BaseModel): class NewLibraryPluginParam(BaseModel): - plugin_id: int + plugin_id: Optional[int] = None + plugin_name: Optional[str] = None + + @model_validator(mode="after") + def check_either_id_or_name(self): + plugin_id = self.plugin_id + plugin_name = self.plugin_name + if not (plugin_id or plugin_name): + raise ValueError("Either plugin_id or plugin_name must be provided") + if plugin_id is not None and plugin_name is not None: + raise ValueError("Only one of plugin_id or plugin_name should be provided") + return self class Folder(BaseModel): @@ -214,15 +232,18 @@ class FacetCount(BaseModel): highlighted: str value: str + class FacetStats(BaseModel): total_values: int + class Facet(BaseModel): counts: List[FacetCount] field_name: str sampled: bool stats: FacetStats + class TextMatchInfo(BaseModel): best_field_score: str best_field_weight: int @@ -232,9 +253,11 @@ class TextMatchInfo(BaseModel): tokens_matched: int typo_prefix_score: int + class HybridSearchInfo(BaseModel): rank_fusion_score: float + class SearchHit(BaseModel): document: EntitySearchResult highlight: Dict[str, Any] = {} @@ -243,12 +266,14 @@ class SearchHit(BaseModel): text_match: Optional[int] = None text_match_info: Optional[TextMatchInfo] = None + class RequestParams(BaseModel): collection_name: str first_q: str per_page: int q: str + class SearchResult(BaseModel): facet_counts: List[Facet] found: int @@ -257,4 +282,4 @@ class SearchResult(BaseModel): page: int request_params: RequestParams search_cutoff: bool - search_time_ms: int \ No newline at end of file + search_time_ms: int diff --git a/memos/server.py b/memos/server.py index d48e355..f40982e 100644 --- a/memos/server.py +++ b/memos/server.py @@ -602,12 +602,29 @@ def add_library_plugin( library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db) ): library = crud.get_library_by_id(library_id, db) - if any(plugin.id == new_plugin.plugin_id for plugin in library.plugins): + if library is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Library not found" + ) + + plugin = None + if new_plugin.plugin_id is not None: + plugin = crud.get_plugin_by_id(new_plugin.plugin_id, db) + elif new_plugin.plugin_name is not None: + plugin = crud.get_plugin_by_name(new_plugin.plugin_name, db) + + if plugin is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found" + ) + + if any(p.id == plugin.id for p in library.plugins): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Plugin already exists in the library", ) - crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db) + + crud.add_plugin_to_library(library_id, plugin.id, db) @app.delete( From 167fd3105358e7fbb7a3559c29f01b105470fb64 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 15:47:36 +0800 Subject: [PATCH 36/39] feat: add builtin plugins for default library --- memos/commands.py | 3 +++ memos/config.py | 4 +++- memos/models.py | 4 ++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/memos/commands.py b/memos/commands.py index 744bd9d..7167234 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -797,6 +797,9 @@ def scan_default_library(force: bool = False): return default_library = response.json() + for plugin in settings.default_plugins: + bind(default_library["id"], plugin) + # Check if the library is empty if not default_library["folders"]: # Add the screenshots directory to the library diff --git a/memos/config.py b/memos/config.py index aa57145..b979852 100644 --- a/memos/config.py +++ b/memos/config.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Tuple, Type +from typing import Tuple, Type, List from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -78,6 +78,8 @@ class Settings(BaseSettings): auth_username: str = "admin" auth_password: SecretStr = SecretStr("changeme") + default_plugins: List[str] = ["builtin_vlm", "builtin_ocr"] + @classmethod def settings_customise_sources( cls, diff --git a/memos/models.py b/memos/models.py index 740e8e6..72bc91c 100644 --- a/memos/models.py +++ b/memos/models.py @@ -178,10 +178,10 @@ def init_database(): def initialize_default_plugins(session): default_plugins = [ PluginModel( - name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm" + name="builtin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm" ), PluginModel( - name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr" + name="builtin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr" ), ] From c24d9cc2e917ac2eed51ddf95f6ccf6469b42d0f Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:28:47 +0800 Subject: [PATCH 37/39] chore: bump 0.6.9 --- README.md | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d763168..1836d26 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ This will start a screen recorder, which will take screenshots every 5 seconds a memos serve ``` -This will start a web server, and you can access the web interface at `http://localhost:8080`. +This will start a web server, and you can access the web interface at `http://localhost:8080`. The default username and password is `admin` and `changeme`. ### Index the screenshots diff --git a/pyproject.toml b/pyproject.toml index bed964e..bcc2ce6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.8" +version = "0.6.9" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] From 30f9a45d8a2bfd43cbfada916b8b042b3c6fcfa5 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Wed, 11 Sep 2024 17:50:33 +0800 Subject: [PATCH 38/39] feat: use mss instead of imagegrab --- pyproject.toml | 1 + screen_recorder/common.py | 84 ++++++++++++++++++--------------------- 2 files changed, 39 insertions(+), 46 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bcc2ce6..3f2bf4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "timm", "einops", "modelscope", + "mss", ] [project.urls] diff --git a/screen_recorder/common.py b/screen_recorder/common.py index 99e670a..f87ca44 100644 --- a/screen_recorder/common.py +++ b/screen_recorder/common.py @@ -4,11 +4,11 @@ import time import logging import platform import subprocess -from PIL import Image, ImageGrab +from PIL import Image import imagehash from memos.utils import write_image_metadata -from screeninfo import get_monitors import ctypes +from mss import mss if platform.system() == "Windows": import win32gui @@ -173,57 +173,49 @@ def take_screenshot_windows( app_name, window_title, ): - for monitor in get_monitors(): - safe_monitor_name = "".join( - c for c in monitor.name if c.isalnum() or c in ("_", "-") - ) - logging.info(f"Processing monitor: {safe_monitor_name}") + with mss() as sct: + for i, monitor in enumerate(sct.monitors[1:], 1): # Skip the first monitor (entire screen) + safe_monitor_name = f"monitor_{i}" + logging.info(f"Processing monitor: {safe_monitor_name}") - webp_filename = os.path.join( - base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp" - ) - - img = ImageGrab.grab( - bbox=( - monitor.x, - monitor.y, - monitor.x + monitor.width, - monitor.y + monitor.height, + webp_filename = os.path.join( + base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp" ) - ) - img = img.convert("RGB") - current_hash = str(imagehash.phash(img)) - if ( - safe_monitor_name in previous_hashes - and imagehash.hex_to_hash(current_hash) - - imagehash.hex_to_hash(previous_hashes[safe_monitor_name]) - < threshold - ): - logging.info( - f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping." + img = sct.grab(monitor) + img = Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX") + current_hash = str(imagehash.phash(img)) + + if ( + safe_monitor_name in previous_hashes + and imagehash.hex_to_hash(current_hash) + - imagehash.hex_to_hash(previous_hashes[safe_monitor_name]) + < threshold + ): + logging.info( + f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping." + ) + yield safe_monitor_name, None, "Skipped (similar to previous)" + continue + + previous_hashes[safe_monitor_name] = current_hash + screen_sequences[safe_monitor_name] = ( + screen_sequences.get(safe_monitor_name, 0) + 1 ) - yield safe_monitor_name, None, "Skipped (similar to previous)" - continue - previous_hashes[safe_monitor_name] = current_hash - screen_sequences[safe_monitor_name] = ( - screen_sequences.get(safe_monitor_name, 0) + 1 - ) + metadata = { + "timestamp": timestamp, + "active_app": app_name, + "active_window": window_title, + "screen_name": safe_monitor_name, + "sequence": screen_sequences[safe_monitor_name], + } - metadata = { - "timestamp": timestamp, - "active_app": app_name, - "active_window": window_title, - "screen_name": safe_monitor_name, - "sequence": screen_sequences[safe_monitor_name], - } + img.save(webp_filename, format="WebP", quality=85) + write_image_metadata(webp_filename, metadata) + save_screen_sequences(base_dir, screen_sequences, date) - img.save(webp_filename, format="WebP", quality=85) - write_image_metadata(webp_filename, metadata) - save_screen_sequences(base_dir, screen_sequences, date) - - yield safe_monitor_name, webp_filename, "Saved" + yield safe_monitor_name, webp_filename, "Saved" def take_screenshot( From 40fad9176f4f100d4e4f0a9514da4e26b437c622 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Wed, 11 Sep 2024 17:50:46 +0800 Subject: [PATCH 39/39] chore: bump 0.6.10 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3f2bf4b..1a07432 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.6.9" +version = "0.6.10" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }]