diff --git a/README.md b/README.md index 60c8cfa..1836d26 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,69 @@ # Memos -A project to index everything to make it like another memory. +A project to index everything to make it like another memory. The project contains two parts: + +1. `screen recorder`: which takes screenshots every 5 seconds and saves them to `~/.memos/screenshots` by default. +2. `memos server`: a web service that can index the screenshots and other files, providing a web interface to search the records. + +There is a product called [Rewind](https://www.rewind.ai/) that is similar to memos, but memos aims to give you control over all your data. + +## Install + +### Install Typesense + +```bash +export TYPESENSE_API_KEY=xyz + +mkdir "$(pwd)"/typesense-data + +docker run -d -p 8108:8108 \ + -v"$(pwd)"/typesense-data:/data typesense/typesense:27.0 \ + --add-host=host.docker.internal:host-gateway \ + --data-dir /data \ + --api-key=$TYPESENSE_API_KEY \ + --enable-cors +``` + +### Install Memos + +```bash +pip install memos +``` + +## How to use + +To use memos, you need to initialize it first. Make sure you have started `typesense`. + +### 1. Initialize Memos + +```bash +memos init +``` + +This will create a folder `~/.memos` and put the config file there. + +### 2. Start Screen Recorder + +```bash +memos-record +``` + +This will start a screen recorder, which will take screenshots every 5 seconds and save it at `~/.memos/screenshots` by default. + +### 3. Start Memos Server + +```bash +memos serve +``` + +This will start a web server, and you can access the web interface at `http://localhost:8080`. +The default username and password is `admin` and `changeme`. + +### Index the screenshots + +```bash +memos scan +memos index +``` + +Refresh the page, and do some search. diff --git a/memos/commands.py b/memos/commands.py index fd9c9bc..7167234 100644 --- a/memos/commands.py +++ b/memos/commands.py @@ -1,6 +1,5 @@ import asyncio import os -import time import logging from datetime import datetime, timezone from pathlib import Path @@ -549,6 +548,9 @@ def index( library_id: int, folders: List[int] = typer.Option(None, "--folder", "-f"), force: bool = typer.Option(False, "--force", help="Force update all indexes"), + batchsize: int = typer.Option( + 4, "--batchsize", "-bs", help="Number of entities to index in a batch" + ), ): print(f"Indexing library {library_id}") @@ -608,9 +610,8 @@ def index( pbar.refresh() # Index each entity - batch_size = settings.batchsize - for i in range(0, len(entities), batch_size): - batch = entities[i : i + batch_size] + for i in range(0, len(entities), batchsize): + batch = entities[i : i + batchsize] to_index = [] for entity in batch: @@ -722,18 +723,22 @@ def create(name: str, webhook_url: str, description: str = ""): @plugin_app.command("bind") def bind( library_id: int = typer.Option(..., "--lib", help="ID of the library"), - plugin_id: int = typer.Option(..., "--plugin", help="ID of the plugin"), + plugin: str = typer.Option(..., "--plugin", help="ID or name of the plugin"), ): + try: + plugin_id = int(plugin) + plugin_param = {"plugin_id": plugin_id} + except ValueError: + plugin_param = {"plugin_name": plugin} + response = httpx.post( f"{BASE_URL}/libraries/{library_id}/plugins", - json={"plugin_id": plugin_id}, + json=plugin_param, ) - if 200 <= response.status_code < 300: + if response.status_code == 204: print("Plugin bound to library successfully") else: - print( - f"Failed to bind plugin to library: {response.status_code} - {response.text}" - ) + print(f"Failed to bind plugin to library: {response.status_code} - {response.text}") @plugin_app.command("unbind") @@ -763,5 +768,85 @@ def init(): print("Initialization failed. Please check the error messages above.") +@app.command("scan") +def scan_default_library(force: bool = False): + """ + Scan the screenshots directory and add it to the library if empty. + """ + # Get the default library + response = httpx.get(f"{BASE_URL}/libraries") + if response.status_code != 200: + print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") + return + + libraries = response.json() + default_library = next( + (lib for lib in libraries if lib["name"] == settings.default_library), None + ) + + if not default_library: + # Create the default library if it doesn't exist + response = httpx.post( + f"{BASE_URL}/libraries", + json={"name": settings.default_library, "folders": []}, + ) + if response.status_code != 200: + print( + f"Failed to create default library: {response.status_code} - {response.text}" + ) + return + default_library = response.json() + + for plugin in settings.default_plugins: + bind(default_library["id"], plugin) + + # Check if the library is empty + if not default_library["folders"]: + # Add the screenshots directory to the library + screenshots_dir = Path(settings.screenshots_dir).resolve() + response = httpx.post( + f"{BASE_URL}/libraries/{default_library['id']}/folders", + json={"folders": [str(screenshots_dir)]}, + ) + if response.status_code != 200: + print( + f"Failed to add screenshots directory: {response.status_code} - {response.text}" + ) + return + print(f"Added screenshots directory: {screenshots_dir}") + + # Scan the library + print(f"Scanning library: {default_library['name']}") + scan(default_library["id"], plugins=None, folders=None, force=force) + + +@app.command("index") +def index_default_library( + batchsize: int = typer.Option( + 4, "--batchsize", "-bs", help="Number of entities to index in a batch" + ), + force: bool = typer.Option(False, "--force", help="Force update all indexes"), +): + """ + Index the default library for memos. + """ + # Get the default library + response = httpx.get(f"{BASE_URL}/libraries") + if response.status_code != 200: + print(f"Failed to retrieve libraries: {response.status_code} - {response.text}") + return + + libraries = response.json() + default_library = next( + (lib for lib in libraries if lib["name"] == settings.default_library), None + ) + + if not default_library: + print("Default library does not exist.") + return + + index(default_library["id"], force=force, folders=None, batchsize=batchsize) + + if __name__ == "__main__": app() diff --git a/memos/config.py b/memos/config.py index 79799c1..b979852 100644 --- a/memos/config.py +++ b/memos/config.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Tuple, Type +from typing import Tuple, Type, List from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -17,24 +17,28 @@ class VLMSettings(BaseModel): modelname: str = "moondream" endpoint: str = "http://localhost:11434" token: str = "" - concurrency: int = 4 + concurrency: int = 1 force_jpeg: bool = False + use_local: bool = True + use_modelscope: bool = False class OCRSettings(BaseModel): enabled: bool = True endpoint: str = "http://localhost:5555/predict" token: str = "" - concurrency: int = 4 + concurrency: int = 1 use_local: bool = True use_gpu: bool = False force_jpeg: bool = False class EmbeddingSettings(BaseModel): + enabled: bool = True num_dim: int = 768 - ollama_endpoint: str = "http://localhost:11434" - ollama_model: str = "nextfire/paraphrase-multilingual-minilm" + endpoint: str = "http://localhost:11434/api/embed" + model: str = "jinaai/jina-embeddings-v2-base-zh" + use_modelscope: bool = False class Settings(BaseSettings): @@ -46,6 +50,9 @@ class Settings(BaseSettings): base_dir: str = str(Path.home() / ".memos") database_path: str = os.path.join(base_dir, "database.db") + default_library: str = "screenshots" + screenshots_dir: str = os.path.join(base_dir, "screenshots") + typesense_host: str = "localhost" typesense_port: str = "8108" typesense_protocol: str = "http" @@ -66,11 +73,13 @@ class Settings(BaseSettings): # Embedding settings embedding: EmbeddingSettings = EmbeddingSettings() - batchsize: int = 4 + batchsize: int = 1 auth_username: str = "admin" auth_password: SecretStr = SecretStr("changeme") + default_plugins: List[str] = ["builtin_vlm", "builtin_ocr"] + @classmethod def settings_customise_sources( cls, @@ -93,6 +102,19 @@ def dict_representer(dumper, data): yaml.add_representer(OrderedDict, dict_representer) +# Custom representer for SecretStr +def secret_str_representer(dumper, data): + return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value()) + +# Custom constructor for SecretStr +def secret_str_constructor(loader, node): + value = loader.construct_scalar(node) + return SecretStr(value) + +# Register the representer and constructor only for specific fields +yaml.add_representer(SecretStr, secret_str_representer) + + def create_default_config(): config_path = Path.home() / ".memos" / "config.yaml" if not config_path.exists(): diff --git a/memos/indexing.py b/memos/indexing.py index f110273..5f5d3ba 100644 --- a/memos/indexing.py +++ b/memos/indexing.py @@ -46,14 +46,19 @@ def parse_date_fields(entity): } -def get_embeddings(texts: List[str]) -> List[List[float]]: +async def get_embeddings(texts: List[str]) -> List[List[float]]: print(f"Getting embeddings for {len(texts)} texts") - ollama_endpoint = settings.embedding.ollama_endpoint - ollama_model = settings.embedding.ollama_model - with httpx.Client() as client: - response = client.post( - f"{ollama_endpoint}/api/embed", - json={"model": ollama_model, "input": texts}, + + if settings.embedding.enabled: + endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed" + else: + endpoint = settings.embedding.endpoint + + model = settings.embedding.model + async with httpx.AsyncClient() as client: + response = await client.post( + endpoint, + json={"model": model, "input": texts}, timeout=30 ) if response.status_code == 200: @@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries): return metadata_text -def bulk_upsert(client, entities): +async def bulk_upsert(client, entities): documents = [] metadata_texts = [] entities_with_metadata = [] @@ -142,7 +147,7 @@ def bulk_upsert(client, entities): ).model_dump(mode="json") ) - embeddings = get_embeddings(metadata_texts) + embeddings = await get_embeddings(metadata_texts) for doc, embedding, entity in zip(documents, embeddings, entities): if entity in entities_with_metadata: doc["embedding"] = embedding @@ -259,7 +264,7 @@ def list_all_entities( ) -def search_entities( +async def search_entities( client, q: str, library_ids: List[int] = None, @@ -287,7 +292,7 @@ def search_entities( filter_by_str = " && ".join(filter_by) if filter_by else "" # Convert q to embedding using get_embeddings and take the first embedding - embedding = get_embeddings([q])[0] + embedding = (await get_embeddings([q]))[0] common_search_params = { "collection": TYPESENSE_COLLECTION_NAME, diff --git a/memos/models.py b/memos/models.py index bc1240b..72bc91c 100644 --- a/memos/models.py +++ b/memos/models.py @@ -15,7 +15,7 @@ from typing import List from .schemas import MetadataSource, MetadataType from sqlalchemy.exc import OperationalError from sqlalchemy.orm import sessionmaker -from .config import get_database_path +from .config import get_database_path, settings class Base(DeclarativeBase): @@ -75,16 +75,14 @@ class EntityModel(Base): "FolderModel", back_populates="entities" ) metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship( - "EntityMetadataModel", - lazy="joined", - cascade="all, delete-orphan" + "EntityMetadataModel", lazy="joined", cascade="all, delete-orphan" ) tags: Mapped[List["TagModel"]] = relationship( - "TagModel", - secondary="entity_tags", + "TagModel", + secondary="entity_tags", lazy="joined", cascade="all, delete", - overlaps="entities" + overlaps="entities", ) # 添加索引 @@ -160,30 +158,61 @@ def init_database(): """Initialize the database.""" db_path = get_database_path() engine = create_engine(f"sqlite:///{db_path}") - + try: Base.metadata.create_all(engine) print(f"Database initialized successfully at {db_path}") - + # Initialize default plugins Session = sessionmaker(bind=engine) with Session() as session: - initialize_default_plugins(session) - + default_plugins = initialize_default_plugins(session) + init_default_libraries(session, default_plugins) + return True except OperationalError as e: print(f"Error initializing database: {e}") return False + def initialize_default_plugins(session): default_plugins = [ - PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"), - PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"), + PluginModel( + name="builtin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm" + ), + PluginModel( + name="builtin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr" + ), ] for plugin in default_plugins: existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first() if not existing_plugin: session.add(plugin) - - session.commit() \ No newline at end of file + + session.commit() + + return default_plugins + + +def init_default_libraries(session, default_plugins): + default_libraries = [ + LibraryModel(name=settings.default_library), + ] + + for library in default_libraries: + existing_library = ( + session.query(LibraryModel).filter_by(name=library.name).first() + ) + if not existing_library: + session.add(library) + + for plugin in default_plugins: + bind_response = session.query(PluginModel).filter_by(name=plugin.name).first() + if bind_response: + library_plugin = LibraryPluginModel( + library_id=1, plugin_id=bind_response.id + ) # Assuming library_id=1 for default libraries + session.add(library_plugin) + + session.commit() diff --git a/memos/plugins/embedding/main.py b/memos/plugins/embedding/main.py new file mode 100644 index 0000000..c7fbec4 --- /dev/null +++ b/memos/plugins/embedding/main.py @@ -0,0 +1,146 @@ +import asyncio +from typing import List +from fastapi import APIRouter, HTTPException +import logging +import uvicorn +from sentence_transformers import SentenceTransformer +import torch +import numpy as np +from pydantic import BaseModel +from modelscope import snapshot_download + +PLUGIN_NAME = "embedding" + +router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}}) + +# Global variables +enabled = False +model = None +num_dim = None +endpoint = None +model_name = None +device = None +use_modelscope = None + +# Configure logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def init_embedding_model(): + global model, device, use_modelscope + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + + if use_modelscope: + model_dir = snapshot_download(model_name) + logger.info(f"Model downloaded from ModelScope to: {model_dir}") + else: + model_dir = model_name + logger.info(f"Using model: {model_dir}") + + model = SentenceTransformer(model_dir, trust_remote_code=True) + model.to(device) + logger.info(f"Embedding model initialized on device: {device}") + + +def generate_embeddings(input_texts: List[str]) -> List[List[float]]: + embeddings = model.encode(input_texts, convert_to_tensor=True) + embeddings = embeddings.cpu().numpy() + # Normalize embeddings + norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True) + norms[norms == 0] = 1 + embeddings = embeddings / norms + return embeddings.tolist() + + +class EmbeddingRequest(BaseModel): + input: List[str] + + +class EmbeddingResponse(BaseModel): + embeddings: List[List[float]] + + +@router.get("/") +async def read_root(): + return {"healthy": True, "enabled": enabled} + + +@router.post("", include_in_schema=False) +@router.post("/", response_model=EmbeddingResponse) +async def embed(request: EmbeddingRequest): + try: + if not request.input: + return EmbeddingResponse(embeddings=[]) + + # Run the embedding generation in a separate thread to avoid blocking + loop = asyncio.get_event_loop() + embeddings = await loop.run_in_executor(None, generate_embeddings, request.input) + return EmbeddingResponse(embeddings=embeddings) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error generating embeddings: {str(e)}" + ) + + +def init_plugin(config): + global enabled, num_dim, endpoint, model_name, use_modelscope + enabled = config.enabled + num_dim = config.num_dim + endpoint = config.endpoint + model_name = config.model + use_modelscope = config.use_modelscope + + if enabled: + init_embedding_model() + + logger.info("Embedding plugin initialized") + logger.info(f"Enabled: {enabled}") + logger.info(f"Number of dimensions: {num_dim}") + logger.info(f"Endpoint: {endpoint}") + logger.info(f"Model: {model_name}") + logger.info(f"Use ModelScope: {use_modelscope}") + + +if __name__ == "__main__": + import argparse + from fastapi import FastAPI + + parser = argparse.ArgumentParser(description="Embedding Plugin Configuration") + parser.add_argument( + "--num-dim", type=int, default=768, help="Number of embedding dimensions" + ) + parser.add_argument( + "--model", + type=str, + default="jinaai/jina-embeddings-v2-base-zh", + help="Embedding model name", + ) + parser.add_argument( + "--port", type=int, default=8000, help="Port to run the server on" + ) + parser.add_argument( + "--use-modelscope", action="store_true", help="Use ModelScope to download the model" + ) + + args = parser.parse_args() + + class Config: + def __init__(self, args): + self.enabled = True + self.num_dim = args.num_dim + self.endpoint = "what ever" + self.model = args.model + self.use_modelscope = args.use_modelscope + + init_plugin(Config(args)) + + app = FastAPI() + app.include_router(router) + + uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py index d28cfb5..5ec4ef5 100644 --- a/memos/plugins/ocr/main.py +++ b/memos/plugins/ocr/main.py @@ -141,6 +141,7 @@ async def ocr(entity: Entity, request: Request): } ] }, + timeout=30, ) # Check if the patch request was successful @@ -178,7 +179,7 @@ def init_plugin(config): ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path'])) ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path'])) - # Save the updated config to a temporary file + # Save the updated config to a temporary file with strings wrapped in double quotes temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml") with open(temp_config_path, 'w') as f: yaml.safe_dump(ocr_config, f) diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py index f24ea70..fc905a4 100644 --- a/memos/plugins/vlm/main.py +++ b/memos/plugins/vlm/main.py @@ -9,6 +9,20 @@ import logging import uvicorn import os import io +import torch +from transformers import AutoModelForCausalLM, AutoProcessor + +from unittest.mock import patch +from transformers.dynamic_module_utils import get_imports +from modelscope import snapshot_download + +def fixed_get_imports(filename: str | os.PathLike) -> list[str]: + if not str(filename).endswith("modeling_florence2.py"): + return get_imports(filename) + imports = get_imports(filename) + imports.remove("flash_attn") + return imports + PLUGIN_NAME = "vlm" PROMPT = "描述这张图片的内容" @@ -21,6 +35,10 @@ token = None concurrency = None semaphore = None force_jpeg = None +use_local = None +florence_model = None +florence_processor = None +torch_dtype = None # Configure logger logging.basicConfig(level=logging.INFO) @@ -35,18 +53,18 @@ def image2base64(img_path): with Image.open(img_path) as img: if force_jpeg: # Convert image to RGB mode (removes alpha channel if present) - img = img.convert('RGB') + img = img.convert("RGB") # Save as JPEG in memory buffer = io.BytesIO() - img.save(buffer, format='JPEG') + img.save(buffer, format="JPEG") buffer.seek(0) - encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8') + encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8") else: # Use original format buffer = io.BytesIO() img.save(buffer, format=img.format) buffer.seek(0) - encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8') + encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8") return encoded_string except Exception as e: logger.error(f"Error processing image {img_path}: {str(e)}") @@ -79,12 +97,57 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N async def predict( endpoint: str, modelname: str, img_path: str, token: Optional[str] = None +) -> Optional[str]: + if use_local: + return await predict_local(img_path) + else: + return await predict_remote(endpoint, modelname, img_path, token) + + +async def predict_local(img_path: str) -> Optional[str]: + try: + image = Image.open(img_path) + task_prompt = "" + prompt = task_prompt + "" + + inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to( + florence_model.device, torch_dtype + ) + + generated_ids = florence_model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + do_sample=False, + num_beams=3, + ) + + generated_texts = florence_processor.batch_decode( + generated_ids, skip_special_tokens=False + ) + + parsed_answer = florence_processor.post_process_generation( + generated_texts[0], + task=task_prompt, + image_size=(image.width, image.height), + ) + + return parsed_answer.get(task_prompt, "") + except Exception as e: + logger.error(f"Error processing image {img_path}: {str(e)}") + return None + + +async def predict_remote( + endpoint: str, modelname: str, img_path: str, token: Optional[str] = None ) -> Optional[str]: img_base64 = image2base64(img_path) if not img_base64: return None - mime_type = "image/jpeg" if force_jpeg else "image/jpeg" # Default to JPEG if force_jpeg is True + mime_type = ( + "image/jpeg" if force_jpeg else "image/jpeg" + ) # Default to JPEG if force_jpeg is True if not force_jpeg: # Only determine MIME type if not forcing JPEG @@ -167,9 +230,9 @@ async def vlm(entity: Entity, request: Request): vlm_result = await predict(endpoint, modelname, entity.filepath, token=token) - print(vlm_result) + logger.info(vlm_result) if not vlm_result: - print(f"No VLM result found for file: {entity.filepath}") + logger.info(f"No VLM result found for file: {entity.filepath}") return {metadata_field_name: "{}"} async with httpx.AsyncClient() as client: @@ -199,14 +262,55 @@ async def vlm(entity: Entity, request: Request): def init_plugin(config): - global modelname, endpoint, token, concurrency, semaphore, force_jpeg + global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype + modelname = config.modelname endpoint = config.endpoint token = config.token concurrency = config.concurrency force_jpeg = config.force_jpeg + use_local = config.use_local + use_modelscope = config.use_modelscope semaphore = asyncio.Semaphore(concurrency) + if use_local: + # 检测可用的设备 + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + + torch_dtype = ( + torch.float32 + if ( + torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6 + ) + or (not torch.cuda.is_available() and not torch.backends.mps.is_available()) + else torch.float16 + ) + logger.info(f"Using device: {device}") + + if use_modelscope: + model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft') + logger.info(f"Model downloaded from ModelScope to: {model_dir}") + else: + model_dir = "microsoft/Florence-2-base-ft" + logger.info(f"Using model: {model_dir}") + + with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): + florence_model = AutoModelForCausalLM.from_pretrained( + model_dir, + torch_dtype=torch_dtype, + attn_implementation="sdpa", + trust_remote_code=True, + ).to(device) + florence_processor = AutoProcessor.from_pretrained( + model_dir, trust_remote_code=True + ) + logger.info("Florence model and processor initialized") + # Print the parameters logger.info("VLM plugin initialized") logger.info(f"Model Name: {modelname}") @@ -214,6 +318,8 @@ def init_plugin(config): logger.info(f"Token: {token}") logger.info(f"Concurrency: {concurrency}") logger.info(f"Force JPEG: {force_jpeg}") + logger.info(f"Use Local: {use_local}") + logger.info(f"Use ModelScope: {use_modelscope}") if __name__ == "__main__": @@ -232,6 +338,7 @@ if __name__ == "__main__": parser.add_argument( "--port", type=int, default=8000, help="Port to run the server on" ) + parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model") args = parser.parse_args() diff --git a/memos/schemas.py b/memos/schemas.py index edd040f..b1b6042 100644 --- a/memos/schemas.py +++ b/memos/schemas.py @@ -1,4 +1,11 @@ -from pydantic import BaseModel, ConfigDict, DirectoryPath, HttpUrl, Field +from pydantic import ( + BaseModel, + ConfigDict, + DirectoryPath, + HttpUrl, + Field, + model_validator, +) from typing import List, Optional, Any, Dict from datetime import datetime from enum import Enum @@ -79,7 +86,18 @@ class NewPluginParam(BaseModel): class NewLibraryPluginParam(BaseModel): - plugin_id: int + plugin_id: Optional[int] = None + plugin_name: Optional[str] = None + + @model_validator(mode="after") + def check_either_id_or_name(self): + plugin_id = self.plugin_id + plugin_name = self.plugin_name + if not (plugin_id or plugin_name): + raise ValueError("Either plugin_id or plugin_name must be provided") + if plugin_id is not None and plugin_name is not None: + raise ValueError("Only one of plugin_id or plugin_name should be provided") + return self class Folder(BaseModel): @@ -214,15 +232,18 @@ class FacetCount(BaseModel): highlighted: str value: str + class FacetStats(BaseModel): total_values: int + class Facet(BaseModel): counts: List[FacetCount] field_name: str sampled: bool stats: FacetStats + class TextMatchInfo(BaseModel): best_field_score: str best_field_weight: int @@ -232,9 +253,11 @@ class TextMatchInfo(BaseModel): tokens_matched: int typo_prefix_score: int + class HybridSearchInfo(BaseModel): rank_fusion_score: float + class SearchHit(BaseModel): document: EntitySearchResult highlight: Dict[str, Any] = {} @@ -243,12 +266,14 @@ class SearchHit(BaseModel): text_match: Optional[int] = None text_match_info: Optional[TextMatchInfo] = None + class RequestParams(BaseModel): collection_name: str first_q: str per_page: int q: str + class SearchResult(BaseModel): facet_counts: List[Facet] found: int @@ -257,4 +282,4 @@ class SearchResult(BaseModel): page: int request_params: RequestParams search_cutoff: bool - search_time_ms: int \ No newline at end of file + search_time_ms: int diff --git a/memos/server.py b/memos/server.py index 9702b68..f40982e 100644 --- a/memos/server.py +++ b/memos/server.py @@ -23,6 +23,7 @@ import typesense from .config import get_database_path, settings from memos.plugins.vlm import main as vlm_main from memos.plugins.ocr import main as ocr_main +from memos.plugins.embedding import main as embedding_main from . import crud from . import indexing from .schemas import ( @@ -84,18 +85,6 @@ app.mount( "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True) ) -# Add VLM plugin router -if settings.vlm.enabled: - print("VLM plugin is enabled") - vlm_main.init_plugin(settings.vlm) - app.include_router(vlm_main.router, prefix="/plugins/vlm") - -# Add OCR plugin router -if settings.ocr.enabled: - print("OCR plugin is enabled") - ocr_main.init_plugin(settings.ocr) - app.include_router(ocr_main.router, prefix="/plugins/ocr") - @app.get("/favicon.png", response_class=FileResponse) async def favicon_png(): @@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense( ) try: - indexing.bulk_upsert(client, entities) + await indexing.bulk_upsert(client, entities) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder( @app.get("/search", response_model=SearchResult, tags=["search"]) -async def search_entities( +async def search_entities_route( q: str, library_ids: str = Query(None, description="Comma-separated list of library IDs"), folder_ids: str = Query(None, description="Comma-separated list of folder IDs"), @@ -502,7 +491,7 @@ async def search_entities( [date.strip() for date in created_dates.split(",")] if created_dates else None ) try: - return indexing.search_entities( + return await indexing.search_entities( client, q, library_ids, @@ -613,12 +602,29 @@ def add_library_plugin( library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db) ): library = crud.get_library_by_id(library_id, db) - if any(plugin.id == new_plugin.plugin_id for plugin in library.plugins): + if library is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Library not found" + ) + + plugin = None + if new_plugin.plugin_id is not None: + plugin = crud.get_plugin_by_id(new_plugin.plugin_id, db) + elif new_plugin.plugin_name is not None: + plugin = crud.get_plugin_by_name(new_plugin.plugin_name, db) + + if plugin is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found" + ) + + if any(p.id == plugin.id for p in library.plugins): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Plugin already exists in the library", ) - crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db) + + crud.add_plugin_to_library(library_id, plugin.id, db) @app.delete( @@ -727,6 +733,24 @@ def run_server(): print(f"VLM plugin enabled: {settings.vlm}") print(f"OCR plugin enabled: {settings.ocr}") + # Add VLM plugin router + if settings.vlm.enabled: + print("VLM plugin is enabled") + vlm_main.init_plugin(settings.vlm) + app.include_router(vlm_main.router, prefix="/plugins/vlm") + + # Add OCR plugin router + if settings.ocr.enabled: + print("OCR plugin is enabled") + ocr_main.init_plugin(settings.ocr) + app.include_router(ocr_main.router, prefix="/plugins/ocr") + + # Add Embedding plugin router + if settings.embedding.enabled: + print("Embedding plugin is enabled") + embedding_main.init_plugin(settings.embedding) + app.include_router(embedding_main.router, prefix="/plugins/embed") + uvicorn.run( "memos.server:app", host=settings.server_host, # Use the new server_host setting diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py index 81df3aa..b936668 100644 --- a/memos_ml_backends/server.py +++ b/memos_ml_backends/server.py @@ -1,7 +1,6 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel from typing import List, Dict, Any, Optional -from sentence_transformers import SentenceTransformer import numpy as np import httpx import torch @@ -25,31 +24,15 @@ elif torch.backends.mps.is_available(): else: device = torch.device("cpu") -torch_dtype = "auto" +torch_dtype = ( + torch.float32 + if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6) + or (not torch.cuda.is_available() and not torch.backends.mps.is_available()) + else torch.float16 +) print(f"Using device: {device}") -def init_embedding_model(): - model = SentenceTransformer( - "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True - ) - model.to(device) - return model - - -embedding_model = init_embedding_model() # 初始化模型 - - -def generate_embeddings(input_texts: List[str]) -> List[List[float]]: - embeddings = embedding_model.encode(input_texts, convert_to_tensor=True) - embeddings = embeddings.cpu().numpy() - # normalized embeddings - norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True) - norms[norms == 0] = 1 - embeddings = embeddings / norms - return embeddings.tolist() - - # Add a configuration option to choose the model parser = argparse.ArgumentParser(description="Run the server with specified model") parser.add_argument("--florence", action="store_true", help="Use Florence-2 model") @@ -63,8 +46,11 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True if use_florence_model: # Load Florence-2 model florence_model = AutoModelForCausalLM.from_pretrained( - "microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True - ).to(device) + "microsoft/Florence-2-base-ft", + torch_dtype=torch_dtype, + attn_implementation="sdpa", + trust_remote_code=True, + ).to(device, torch_dtype) florence_processor = AutoProcessor.from_pretrained( "microsoft/Florence-2-base-ft", trust_remote_code=True ) @@ -74,7 +60,7 @@ else: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4", torch_dtype=torch_dtype, device_map="auto", - ).to(device) + ).to(device, torch_dtype) qwen2vl_processor = AutoProcessor.from_pretrained( "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4" ) @@ -139,9 +125,9 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens): text = qwen2vl_processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - + image_inputs, video_inputs = process_vision_info(messages) - + inputs = qwen2vl_processor( text=[text], images=image_inputs, @@ -152,12 +138,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens): inputs = inputs.to(device) generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512)) - + generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] - + output_text = qwen2vl_processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, @@ -170,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens): app = FastAPI() -class EmbeddingRequest(BaseModel): - input: List[str] - - -class EmbeddingResponse(BaseModel): - embeddings: List[List[float]] - - -@app.post("/api/embed", response_model=EmbeddingResponse) -async def create_embeddings(request: EmbeddingRequest): - try: - if not request.input: - return EmbeddingResponse(embeddings=[]) - - embeddings = generate_embeddings(request.input) # 使用新方法 - return EmbeddingResponse(embeddings=embeddings) - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Error generating embeddings: {str(e)}" - ) - - class ChatCompletionRequest(BaseModel): model: str messages: List[Dict[str, Any]] @@ -275,10 +239,14 @@ async def chat_completions(request: ChatCompletionRequest): if __name__ == "__main__": import uvicorn - parser = argparse.ArgumentParser(description="Run the server with specified model and port") + parser = argparse.ArgumentParser( + description="Run the server with specified model and port" + ) parser.add_argument("--florence", action="store_true", help="Use Florence-2 model") parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model") - parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") + parser.add_argument( + "--port", type=int, default=8000, help="Port to run the server on" + ) args = parser.parse_args() if args.florence and args.qwen2vl: diff --git a/pyproject.toml b/pyproject.toml index d6d509b..1a07432 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "memos" -version = "0.5.0" +version = "0.6.10" description = "A package for memos" readme = "README.md" authors = [{ name = "arkohut" }] @@ -36,6 +36,13 @@ dependencies = [ "pyobjc; sys_platform == 'darwin'", "pyobjc-core; sys_platform == 'darwin'", "pyobjc-framework-Quartz; sys_platform == 'darwin'", + "sentence-transformers", + "torch", + "numpy", + "timm", + "einops", + "modelscope", + "mss", ] [project.urls] @@ -50,3 +57,4 @@ include = ["memos*", "screen_recorder*"] [tool.setuptools.package-data] "*" = ["static/**/*"] +"memos.plugins.ocr" = ["*.yaml", "models/*.onnx"] diff --git a/screen_recorder/common.py b/screen_recorder/common.py index 99e670a..f87ca44 100644 --- a/screen_recorder/common.py +++ b/screen_recorder/common.py @@ -4,11 +4,11 @@ import time import logging import platform import subprocess -from PIL import Image, ImageGrab +from PIL import Image import imagehash from memos.utils import write_image_metadata -from screeninfo import get_monitors import ctypes +from mss import mss if platform.system() == "Windows": import win32gui @@ -173,57 +173,49 @@ def take_screenshot_windows( app_name, window_title, ): - for monitor in get_monitors(): - safe_monitor_name = "".join( - c for c in monitor.name if c.isalnum() or c in ("_", "-") - ) - logging.info(f"Processing monitor: {safe_monitor_name}") + with mss() as sct: + for i, monitor in enumerate(sct.monitors[1:], 1): # Skip the first monitor (entire screen) + safe_monitor_name = f"monitor_{i}" + logging.info(f"Processing monitor: {safe_monitor_name}") - webp_filename = os.path.join( - base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp" - ) - - img = ImageGrab.grab( - bbox=( - monitor.x, - monitor.y, - monitor.x + monitor.width, - monitor.y + monitor.height, + webp_filename = os.path.join( + base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp" ) - ) - img = img.convert("RGB") - current_hash = str(imagehash.phash(img)) - if ( - safe_monitor_name in previous_hashes - and imagehash.hex_to_hash(current_hash) - - imagehash.hex_to_hash(previous_hashes[safe_monitor_name]) - < threshold - ): - logging.info( - f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping." + img = sct.grab(monitor) + img = Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX") + current_hash = str(imagehash.phash(img)) + + if ( + safe_monitor_name in previous_hashes + and imagehash.hex_to_hash(current_hash) + - imagehash.hex_to_hash(previous_hashes[safe_monitor_name]) + < threshold + ): + logging.info( + f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping." + ) + yield safe_monitor_name, None, "Skipped (similar to previous)" + continue + + previous_hashes[safe_monitor_name] = current_hash + screen_sequences[safe_monitor_name] = ( + screen_sequences.get(safe_monitor_name, 0) + 1 ) - yield safe_monitor_name, None, "Skipped (similar to previous)" - continue - previous_hashes[safe_monitor_name] = current_hash - screen_sequences[safe_monitor_name] = ( - screen_sequences.get(safe_monitor_name, 0) + 1 - ) + metadata = { + "timestamp": timestamp, + "active_app": app_name, + "active_window": window_title, + "screen_name": safe_monitor_name, + "sequence": screen_sequences[safe_monitor_name], + } - metadata = { - "timestamp": timestamp, - "active_app": app_name, - "active_window": window_title, - "screen_name": safe_monitor_name, - "sequence": screen_sequences[safe_monitor_name], - } + img.save(webp_filename, format="WebP", quality=85) + write_image_metadata(webp_filename, metadata) + save_screen_sequences(base_dir, screen_sequences, date) - img.save(webp_filename, format="WebP", quality=85) - write_image_metadata(webp_filename, metadata) - save_screen_sequences(base_dir, screen_sequences, date) - - yield safe_monitor_name, webp_filename, "Saved" + yield safe_monitor_name, webp_filename, "Saved" def take_screenshot( diff --git a/screen_recorder/record.py b/screen_recorder/record.py index 8bc101a..ba141a0 100644 --- a/screen_recorder/record.py +++ b/screen_recorder/record.py @@ -9,6 +9,8 @@ from screen_recorder.common import ( take_screenshot, is_screen_locked, ) +from pathlib import Path +from memos.config import settings logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -51,12 +53,12 @@ def main(): "--threshold", type=int, default=4, help="Threshold for image similarity" ) parser.add_argument( - "--base-dir", type=str, default="~/tmp", help="Base directory for screenshots" + "--base-dir", type=str, help="Base directory for screenshots" ) parser.add_argument("--once", action="store_true", help="Run once and exit") args = parser.parse_args() - base_dir = os.path.expanduser(args.base_dir) + base_dir = os.path.expanduser(args.base_dir) if args.base_dir else settings.screenshots_dir previous_hashes = load_previous_hashes(base_dir) if args.once: diff --git a/web/src/lib/components/LibraryFilter.svelte b/web/src/lib/components/LibraryFilter.svelte index 1772171..fea7c9f 100644 --- a/web/src/lib/components/LibraryFilter.svelte +++ b/web/src/lib/components/LibraryFilter.svelte @@ -112,7 +112,7 @@ id={`library-${library.id}`} bind:checked={selectedLibraries[library.id]} /> - + {/each}