From a5562b14eb333571dad8f0ccd4253c45bd4dff10 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Fri, 6 Sep 2024 23:42:46 +0800
Subject: [PATCH 01/39] feat(ml): use float32 for pascal
---
memos_ml_backends/server.py | 25 +++++++++++++++++--------
1 file changed, 17 insertions(+), 8 deletions(-)
diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py
index 81df3aa..8eb0653 100644
--- a/memos_ml_backends/server.py
+++ b/memos_ml_backends/server.py
@@ -25,7 +25,12 @@ elif torch.backends.mps.is_available():
else:
device = torch.device("cpu")
-torch_dtype = "auto"
+torch_dtype = (
+ torch.float32
+ if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
+ or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
+ else torch.float16
+)
print(f"Using device: {device}")
@@ -37,7 +42,7 @@ def init_embedding_model():
return model
-embedding_model = init_embedding_model() # 初始化模型
+embedding_model = init_embedding_model()
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
@@ -139,9 +144,9 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
text = qwen2vl_processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
-
+
image_inputs, video_inputs = process_vision_info(messages)
-
+
inputs = qwen2vl_processor(
text=[text],
images=image_inputs,
@@ -152,12 +157,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
inputs = inputs.to(device)
generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
-
+
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
-
+
output_text = qwen2vl_processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
@@ -275,10 +280,14 @@ async def chat_completions(request: ChatCompletionRequest):
if __name__ == "__main__":
import uvicorn
- parser = argparse.ArgumentParser(description="Run the server with specified model and port")
+ parser = argparse.ArgumentParser(
+ description="Run the server with specified model and port"
+ )
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
- parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
+ parser.add_argument(
+ "--port", type=int, default=8000, help="Port to run the server on"
+ )
args = parser.parse_args()
if args.florence and args.qwen2vl:
From 57110db73f95cff4572e7a63fdee1d9051e87baf Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Sun, 8 Sep 2024 22:56:57 +0800
Subject: [PATCH 02/39] feat(web): show lib id
---
web/src/lib/components/LibraryFilter.svelte | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/web/src/lib/components/LibraryFilter.svelte b/web/src/lib/components/LibraryFilter.svelte
index 1772171..fea7c9f 100644
--- a/web/src/lib/components/LibraryFilter.svelte
+++ b/web/src/lib/components/LibraryFilter.svelte
@@ -112,7 +112,7 @@
id={`library-${library.id}`}
bind:checked={selectedLibraries[library.id]}
/>
-
+
{/each}
From 69aca0153af25ff3f245d7b01789e4d53ab9a84a Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 19:43:17 +0800
Subject: [PATCH 03/39] feat: make embedding a default plugin
---
memos/config.py | 5 +-
memos/indexing.py | 27 ++++---
memos/plugins/embedding/main.py | 131 ++++++++++++++++++++++++++++++++
memos/server.py | 37 +++++----
memos_ml_backends/server.py | 49 +-----------
pyproject.toml | 3 +
6 files changed, 179 insertions(+), 73 deletions(-)
create mode 100644 memos/plugins/embedding/main.py
diff --git a/memos/config.py b/memos/config.py
index 79799c1..2c5bde6 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -32,9 +32,10 @@ class OCRSettings(BaseModel):
class EmbeddingSettings(BaseModel):
+ enabled: bool = True
num_dim: int = 768
- ollama_endpoint: str = "http://localhost:11434"
- ollama_model: str = "nextfire/paraphrase-multilingual-minilm"
+ endpoint: str = "http://localhost:11434/api/embed"
+ model: str = "jinaai/jina-embeddings-v2-base-zh"
class Settings(BaseSettings):
diff --git a/memos/indexing.py b/memos/indexing.py
index f110273..5f5d3ba 100644
--- a/memos/indexing.py
+++ b/memos/indexing.py
@@ -46,14 +46,19 @@ def parse_date_fields(entity):
}
-def get_embeddings(texts: List[str]) -> List[List[float]]:
+async def get_embeddings(texts: List[str]) -> List[List[float]]:
print(f"Getting embeddings for {len(texts)} texts")
- ollama_endpoint = settings.embedding.ollama_endpoint
- ollama_model = settings.embedding.ollama_model
- with httpx.Client() as client:
- response = client.post(
- f"{ollama_endpoint}/api/embed",
- json={"model": ollama_model, "input": texts},
+
+ if settings.embedding.enabled:
+ endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
+ else:
+ endpoint = settings.embedding.endpoint
+
+ model = settings.embedding.model
+ async with httpx.AsyncClient() as client:
+ response = await client.post(
+ endpoint,
+ json={"model": model, "input": texts},
timeout=30
)
if response.status_code == 200:
@@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries):
return metadata_text
-def bulk_upsert(client, entities):
+async def bulk_upsert(client, entities):
documents = []
metadata_texts = []
entities_with_metadata = []
@@ -142,7 +147,7 @@ def bulk_upsert(client, entities):
).model_dump(mode="json")
)
- embeddings = get_embeddings(metadata_texts)
+ embeddings = await get_embeddings(metadata_texts)
for doc, embedding, entity in zip(documents, embeddings, entities):
if entity in entities_with_metadata:
doc["embedding"] = embedding
@@ -259,7 +264,7 @@ def list_all_entities(
)
-def search_entities(
+async def search_entities(
client,
q: str,
library_ids: List[int] = None,
@@ -287,7 +292,7 @@ def search_entities(
filter_by_str = " && ".join(filter_by) if filter_by else ""
# Convert q to embedding using get_embeddings and take the first embedding
- embedding = get_embeddings([q])[0]
+ embedding = (await get_embeddings([q]))[0]
common_search_params = {
"collection": TYPESENSE_COLLECTION_NAME,
diff --git a/memos/plugins/embedding/main.py b/memos/plugins/embedding/main.py
new file mode 100644
index 0000000..89e837a
--- /dev/null
+++ b/memos/plugins/embedding/main.py
@@ -0,0 +1,131 @@
+import asyncio
+from typing import List
+from fastapi import APIRouter, HTTPException
+import logging
+import uvicorn
+from sentence_transformers import SentenceTransformer
+import torch
+import numpy as np
+from pydantic import BaseModel
+
+PLUGIN_NAME = "embedding"
+
+router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
+
+# Global variables
+enabled = False
+model = None
+num_dim = None
+endpoint = None
+model_name = None
+device = None
+
+# Configure logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def init_embedding_model():
+ global model, device
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+ device = torch.device("mps")
+ else:
+ device = torch.device("cpu")
+
+ model = SentenceTransformer(model_name, trust_remote_code=True)
+ model.to(device)
+ logger.info(f"Embedding model initialized on device: {device}")
+
+
+def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
+ embeddings = model.encode(input_texts, convert_to_tensor=True)
+ embeddings = embeddings.cpu().numpy()
+ # Normalize embeddings
+ norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
+ norms[norms == 0] = 1
+ embeddings = embeddings / norms
+ return embeddings.tolist()
+
+
+class EmbeddingRequest(BaseModel):
+ input: List[str]
+
+
+class EmbeddingResponse(BaseModel):
+ embeddings: List[List[float]]
+
+
+@router.get("/")
+async def read_root():
+ return {"healthy": True, "enabled": enabled}
+
+
+@router.post("", include_in_schema=False)
+@router.post("/", response_model=EmbeddingResponse)
+async def embed(request: EmbeddingRequest):
+ try:
+ if not request.input:
+ return EmbeddingResponse(embeddings=[])
+
+ # Run the embedding generation in a separate thread to avoid blocking
+ loop = asyncio.get_event_loop()
+ embeddings = await loop.run_in_executor(None, generate_embeddings, request.input)
+ return EmbeddingResponse(embeddings=embeddings)
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Error generating embeddings: {str(e)}"
+ )
+
+
+def init_plugin(config):
+ global enabled, num_dim, endpoint, model_name
+ enabled = config.enabled
+ num_dim = config.num_dim
+ endpoint = config.endpoint
+ model_name = config.model
+
+ if enabled:
+ init_embedding_model()
+
+ logger.info("Embedding plugin initialized")
+ logger.info(f"Enabled: {enabled}")
+ logger.info(f"Number of dimensions: {num_dim}")
+ logger.info(f"Endpoint: {endpoint}")
+ logger.info(f"Model: {model_name}")
+
+
+if __name__ == "__main__":
+ import argparse
+ from fastapi import FastAPI
+
+ parser = argparse.ArgumentParser(description="Embedding Plugin Configuration")
+ parser.add_argument(
+ "--num-dim", type=int, default=768, help="Number of embedding dimensions"
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="jinaai/jina-embeddings-v2-base-zh",
+ help="Embedding model name",
+ )
+ parser.add_argument(
+ "--port", type=int, default=8000, help="Port to run the server on"
+ )
+
+ args = parser.parse_args()
+
+ class Config:
+ def __init__(self, args):
+ self.enabled = True
+ self.num_dim = args.num_dim
+ self.endpoint = "what ever"
+ self.model = args.model
+
+ init_plugin(Config(args))
+
+ app = FastAPI()
+ app.include_router(router)
+
+ uvicorn.run(app, host="0.0.0.0", port=args.port)
diff --git a/memos/server.py b/memos/server.py
index 9702b68..d48e355 100644
--- a/memos/server.py
+++ b/memos/server.py
@@ -23,6 +23,7 @@ import typesense
from .config import get_database_path, settings
from memos.plugins.vlm import main as vlm_main
from memos.plugins.ocr import main as ocr_main
+from memos.plugins.embedding import main as embedding_main
from . import crud
from . import indexing
from .schemas import (
@@ -84,18 +85,6 @@ app.mount(
"/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
)
-# Add VLM plugin router
-if settings.vlm.enabled:
- print("VLM plugin is enabled")
- vlm_main.init_plugin(settings.vlm)
- app.include_router(vlm_main.router, prefix="/plugins/vlm")
-
-# Add OCR plugin router
-if settings.ocr.enabled:
- print("OCR plugin is enabled")
- ocr_main.init_plugin(settings.ocr)
- app.include_router(ocr_main.router, prefix="/plugins/ocr")
-
@app.get("/favicon.png", response_class=FileResponse)
async def favicon_png():
@@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense(
)
try:
- indexing.bulk_upsert(client, entities)
+ await indexing.bulk_upsert(client, entities)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder(
@app.get("/search", response_model=SearchResult, tags=["search"])
-async def search_entities(
+async def search_entities_route(
q: str,
library_ids: str = Query(None, description="Comma-separated list of library IDs"),
folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
@@ -502,7 +491,7 @@ async def search_entities(
[date.strip() for date in created_dates.split(",")] if created_dates else None
)
try:
- return indexing.search_entities(
+ return await indexing.search_entities(
client,
q,
library_ids,
@@ -727,6 +716,24 @@ def run_server():
print(f"VLM plugin enabled: {settings.vlm}")
print(f"OCR plugin enabled: {settings.ocr}")
+ # Add VLM plugin router
+ if settings.vlm.enabled:
+ print("VLM plugin is enabled")
+ vlm_main.init_plugin(settings.vlm)
+ app.include_router(vlm_main.router, prefix="/plugins/vlm")
+
+ # Add OCR plugin router
+ if settings.ocr.enabled:
+ print("OCR plugin is enabled")
+ ocr_main.init_plugin(settings.ocr)
+ app.include_router(ocr_main.router, prefix="/plugins/ocr")
+
+ # Add Embedding plugin router
+ if settings.embedding.enabled:
+ print("Embedding plugin is enabled")
+ embedding_main.init_plugin(settings.embedding)
+ app.include_router(embedding_main.router, prefix="/plugins/embed")
+
uvicorn.run(
"memos.server:app",
host=settings.server_host, # Use the new server_host setting
diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py
index 8eb0653..7a60825 100644
--- a/memos_ml_backends/server.py
+++ b/memos_ml_backends/server.py
@@ -1,7 +1,6 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
-from sentence_transformers import SentenceTransformer
import numpy as np
import httpx
import torch
@@ -34,27 +33,6 @@ torch_dtype = (
print(f"Using device: {device}")
-def init_embedding_model():
- model = SentenceTransformer(
- "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
- )
- model.to(device)
- return model
-
-
-embedding_model = init_embedding_model()
-
-
-def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
- embeddings = embedding_model.encode(input_texts, convert_to_tensor=True)
- embeddings = embeddings.cpu().numpy()
- # normalized embeddings
- norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
- norms[norms == 0] = 1
- embeddings = embeddings / norms
- return embeddings.tolist()
-
-
# Add a configuration option to choose the model
parser = argparse.ArgumentParser(description="Run the server with specified model")
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
@@ -68,7 +46,10 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
if use_florence_model:
# Load Florence-2 model
florence_model = AutoModelForCausalLM.from_pretrained(
- "microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True
+ "microsoft/Florence-2-base-ft",
+ torch_dtype=torch_dtype,
+ attn_implementation="sdpa",
+ trust_remote_code=True,
).to(device)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
@@ -175,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
app = FastAPI()
-class EmbeddingRequest(BaseModel):
- input: List[str]
-
-
-class EmbeddingResponse(BaseModel):
- embeddings: List[List[float]]
-
-
-@app.post("/api/embed", response_model=EmbeddingResponse)
-async def create_embeddings(request: EmbeddingRequest):
- try:
- if not request.input:
- return EmbeddingResponse(embeddings=[])
-
- embeddings = generate_embeddings(request.input) # 使用新方法
- return EmbeddingResponse(embeddings=embeddings)
- except Exception as e:
- raise HTTPException(
- status_code=500, detail=f"Error generating embeddings: {str(e)}"
- )
-
-
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Dict[str, Any]]
diff --git a/pyproject.toml b/pyproject.toml
index d6d509b..c52a9ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,9 @@ dependencies = [
"pyobjc; sys_platform == 'darwin'",
"pyobjc-core; sys_platform == 'darwin'",
"pyobjc-framework-Quartz; sys_platform == 'darwin'",
+ "sentence-transformers",
+ "torch",
+ "numpy",
]
[project.urls]
From 8056f19773e002b48dc383b0d1ce667db5842e7e Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 19:45:37 +0800
Subject: [PATCH 04/39] chore: bump 0.6.0
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index c52a9ee..0c200a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.5.0"
+version = "0.6.0"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From d7e6c32e86d30117d6a4aa3aa7ec7eb032b5897d Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 19:57:07 +0800
Subject: [PATCH 05/39] feat: init default library
---
memos/commands.py | 1 -
memos/config.py | 2 ++
memos/models.py | 59 +++++++++++++++++++++++++++++++++++------------
3 files changed, 46 insertions(+), 16 deletions(-)
diff --git a/memos/commands.py b/memos/commands.py
index fd9c9bc..00753e8 100644
--- a/memos/commands.py
+++ b/memos/commands.py
@@ -1,6 +1,5 @@
import asyncio
import os
-import time
import logging
from datetime import datetime, timezone
from pathlib import Path
diff --git a/memos/config.py b/memos/config.py
index 2c5bde6..458ee8d 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -47,6 +47,8 @@ class Settings(BaseSettings):
base_dir: str = str(Path.home() / ".memos")
database_path: str = os.path.join(base_dir, "database.db")
+ default_library: str = "screenshots"
+
typesense_host: str = "localhost"
typesense_port: str = "8108"
typesense_protocol: str = "http"
diff --git a/memos/models.py b/memos/models.py
index bc1240b..740e8e6 100644
--- a/memos/models.py
+++ b/memos/models.py
@@ -15,7 +15,7 @@ from typing import List
from .schemas import MetadataSource, MetadataType
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
-from .config import get_database_path
+from .config import get_database_path, settings
class Base(DeclarativeBase):
@@ -75,16 +75,14 @@ class EntityModel(Base):
"FolderModel", back_populates="entities"
)
metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship(
- "EntityMetadataModel",
- lazy="joined",
- cascade="all, delete-orphan"
+ "EntityMetadataModel", lazy="joined", cascade="all, delete-orphan"
)
tags: Mapped[List["TagModel"]] = relationship(
- "TagModel",
- secondary="entity_tags",
+ "TagModel",
+ secondary="entity_tags",
lazy="joined",
cascade="all, delete",
- overlaps="entities"
+ overlaps="entities",
)
# 添加索引
@@ -160,30 +158,61 @@ def init_database():
"""Initialize the database."""
db_path = get_database_path()
engine = create_engine(f"sqlite:///{db_path}")
-
+
try:
Base.metadata.create_all(engine)
print(f"Database initialized successfully at {db_path}")
-
+
# Initialize default plugins
Session = sessionmaker(bind=engine)
with Session() as session:
- initialize_default_plugins(session)
-
+ default_plugins = initialize_default_plugins(session)
+ init_default_libraries(session, default_plugins)
+
return True
except OperationalError as e:
print(f"Error initializing database: {e}")
return False
+
def initialize_default_plugins(session):
default_plugins = [
- PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"),
- PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"),
+ PluginModel(
+ name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"
+ ),
+ PluginModel(
+ name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"
+ ),
]
for plugin in default_plugins:
existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first()
if not existing_plugin:
session.add(plugin)
-
- session.commit()
\ No newline at end of file
+
+ session.commit()
+
+ return default_plugins
+
+
+def init_default_libraries(session, default_plugins):
+ default_libraries = [
+ LibraryModel(name=settings.default_library),
+ ]
+
+ for library in default_libraries:
+ existing_library = (
+ session.query(LibraryModel).filter_by(name=library.name).first()
+ )
+ if not existing_library:
+ session.add(library)
+
+ for plugin in default_plugins:
+ bind_response = session.query(PluginModel).filter_by(name=plugin.name).first()
+ if bind_response:
+ library_plugin = LibraryPluginModel(
+ library_id=1, plugin_id=bind_response.id
+ ) # Assuming library_id=1 for default libraries
+ session.add(library_plugin)
+
+ session.commit()
From 7e43bc086126b865613d636039715e6ea93a1432 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 20:30:58 +0800
Subject: [PATCH 06/39] feat(ml_backend): move florence 2 as a default vlm
plugin
---
memos/config.py | 1 +
memos/plugins/ocr/main.py | 1 +
memos/plugins/vlm/main.py | 100 +++++++++++++++++++++++++++++++++---
memos_ml_backends/server.py | 4 +-
4 files changed, 96 insertions(+), 10 deletions(-)
diff --git a/memos/config.py b/memos/config.py
index 458ee8d..db5ffbc 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -19,6 +19,7 @@ class VLMSettings(BaseModel):
token: str = ""
concurrency: int = 4
force_jpeg: bool = False
+ use_local: bool = True
class OCRSettings(BaseModel):
diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py
index 399d23a..a0b207d 100644
--- a/memos/plugins/ocr/main.py
+++ b/memos/plugins/ocr/main.py
@@ -145,6 +145,7 @@ async def ocr(entity: Entity, request: Request):
}
]
},
+ timeout=30,
)
# Check if the patch request was successful
diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py
index f24ea70..ad636b6 100644
--- a/memos/plugins/vlm/main.py
+++ b/memos/plugins/vlm/main.py
@@ -9,6 +9,8 @@ import logging
import uvicorn
import os
import io
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor
PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"
@@ -21,6 +23,10 @@ token = None
concurrency = None
semaphore = None
force_jpeg = None
+use_local = None
+florence_model = None
+florence_processor = None
+torch_dtype = None
# Configure logger
logging.basicConfig(level=logging.INFO)
@@ -35,18 +41,18 @@ def image2base64(img_path):
with Image.open(img_path) as img:
if force_jpeg:
# Convert image to RGB mode (removes alpha channel if present)
- img = img.convert('RGB')
+ img = img.convert("RGB")
# Save as JPEG in memory
buffer = io.BytesIO()
- img.save(buffer, format='JPEG')
+ img.save(buffer, format="JPEG")
buffer.seek(0)
- encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
+ encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
# Use original format
buffer = io.BytesIO()
img.save(buffer, format=img.format)
buffer.seek(0)
- encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
+ encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
return encoded_string
except Exception as e:
logger.error(f"Error processing image {img_path}: {str(e)}")
@@ -79,12 +85,57 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
async def predict(
endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
+) -> Optional[str]:
+ if use_local:
+ return await predict_local(img_path)
+ else:
+ return await predict_remote(endpoint, modelname, img_path, token)
+
+
+async def predict_local(img_path: str) -> Optional[str]:
+ try:
+ image = Image.open(img_path)
+ task_prompt = ""
+ prompt = task_prompt + ""
+
+ inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
+ florence_model.device, torch_dtype
+ )
+
+ generated_ids = florence_model.generate(
+ input_ids=inputs["input_ids"],
+ pixel_values=inputs["pixel_values"],
+ max_new_tokens=1024,
+ do_sample=False,
+ num_beams=3,
+ )
+
+ generated_texts = florence_processor.batch_decode(
+ generated_ids, skip_special_tokens=False
+ )
+
+ parsed_answer = florence_processor.post_process_generation(
+ generated_texts[0],
+ task=task_prompt,
+ image_size=(image.width, image.height),
+ )
+
+ return parsed_answer.get(task_prompt, "")
+ except Exception as e:
+ logger.error(f"Error processing image {img_path}: {str(e)}")
+ return None
+
+
+async def predict_remote(
+ endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
img_base64 = image2base64(img_path)
if not img_base64:
return None
- mime_type = "image/jpeg" if force_jpeg else "image/jpeg" # Default to JPEG if force_jpeg is True
+ mime_type = (
+ "image/jpeg" if force_jpeg else "image/jpeg"
+ ) # Default to JPEG if force_jpeg is True
if not force_jpeg:
# Only determine MIME type if not forcing JPEG
@@ -167,9 +218,9 @@ async def vlm(entity: Entity, request: Request):
vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)
- print(vlm_result)
+ logger.info(vlm_result)
if not vlm_result:
- print(f"No VLM result found for file: {entity.filepath}")
+ logger.info(f"No VLM result found for file: {entity.filepath}")
return {metadata_field_name: "{}"}
async with httpx.AsyncClient() as client:
@@ -199,14 +250,46 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config):
- global modelname, endpoint, token, concurrency, semaphore, force_jpeg
+ global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
+
modelname = config.modelname
endpoint = config.endpoint
token = config.token
concurrency = config.concurrency
force_jpeg = config.force_jpeg
+ use_local = config.use_local
semaphore = asyncio.Semaphore(concurrency)
+ if use_local:
+ # 检测可用的设备
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+ device = torch.device("mps")
+ else:
+ device = torch.device("cpu")
+
+ torch_dtype = (
+ torch.float32
+ if (
+ torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
+ )
+ or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
+ else torch.float16
+ )
+ logger.info(f"Using device: {device}")
+
+ florence_model = AutoModelForCausalLM.from_pretrained(
+ "microsoft/Florence-2-base-ft",
+ torch_dtype=torch_dtype,
+ attn_implementation="sdpa",
+ trust_remote_code=True,
+ ).to(device)
+ florence_processor = AutoProcessor.from_pretrained(
+ "microsoft/Florence-2-base-ft", trust_remote_code=True
+ )
+ logger.info("Florence model and processor initialized")
+
# Print the parameters
logger.info("VLM plugin initialized")
logger.info(f"Model Name: {modelname}")
@@ -214,6 +297,7 @@ def init_plugin(config):
logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}")
logger.info(f"Force JPEG: {force_jpeg}")
+ logger.info(f"Use Local: {use_local}")
if __name__ == "__main__":
diff --git a/memos_ml_backends/server.py b/memos_ml_backends/server.py
index 7a60825..b936668 100644
--- a/memos_ml_backends/server.py
+++ b/memos_ml_backends/server.py
@@ -50,7 +50,7 @@ if use_florence_model:
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
- ).to(device)
+ ).to(device, torch_dtype)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
)
@@ -60,7 +60,7 @@ else:
"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
torch_dtype=torch_dtype,
device_map="auto",
- ).to(device)
+ ).to(device, torch_dtype)
qwen2vl_processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"
)
From 81ba2cd2d23f38f73d59d3fa21c4d59b2a0501e0 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 20:58:06 +0800
Subject: [PATCH 07/39] feat: add default library and command shortcuts for it
---
memos/commands.py | 72 +++++++++++++++++++++++++++++++++++++++
memos/config.py | 1 +
screen_recorder/record.py | 6 ++--
3 files changed, 77 insertions(+), 2 deletions(-)
diff --git a/memos/commands.py b/memos/commands.py
index 00753e8..8092918 100644
--- a/memos/commands.py
+++ b/memos/commands.py
@@ -762,5 +762,77 @@ def init():
print("Initialization failed. Please check the error messages above.")
+@app.command("scan")
+def scan_default_library():
+ """
+ Scan the screenshots directory and add it to the library if empty.
+ """
+ # Get the default library
+ response = httpx.get(f"{BASE_URL}/libraries")
+ if response.status_code != 200:
+ print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
+ return
+
+ libraries = response.json()
+ default_library = next(
+ (lib for lib in libraries if lib["name"] == settings.default_library), None
+ )
+
+ if not default_library:
+ # Create the default library if it doesn't exist
+ response = httpx.post(
+ f"{BASE_URL}/libraries",
+ json={"name": settings.default_library, "folders": []},
+ )
+ if response.status_code != 200:
+ print(
+ f"Failed to create default library: {response.status_code} - {response.text}"
+ )
+ return
+ default_library = response.json()
+
+ # Check if the library is empty
+ if not default_library["folders"]:
+ # Add the screenshots directory to the library
+ screenshots_dir = Path(settings.screenshots_dir).resolve()
+ response = httpx.post(
+ f"{BASE_URL}/libraries/{default_library['id']}/folders",
+ json={"folders": [str(screenshots_dir)]},
+ )
+ if response.status_code != 200:
+ print(
+ f"Failed to add screenshots directory: {response.status_code} - {response.text}"
+ )
+ return
+ print(f"Added screenshots directory: {screenshots_dir}")
+
+ # Scan the library
+ print(f"Scanning library: {default_library['name']}")
+ scan(default_library["id"], plugins=None, folders=None)
+
+
+@app.command("index")
+def index_default_library():
+ """
+ Index the default library for memos.
+ """
+ # Get the default library
+ response = httpx.get(f"{BASE_URL}/libraries")
+ if response.status_code != 200:
+ print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
+ return
+
+ libraries = response.json()
+ default_library = next(
+ (lib for lib in libraries if lib["name"] == settings.default_library), None
+ )
+
+ if not default_library:
+ print("Default library does not exist.")
+ return
+
+ index(default_library["id"], force=False, folders=None)
+
+
if __name__ == "__main__":
app()
diff --git a/memos/config.py b/memos/config.py
index db5ffbc..309a373 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -49,6 +49,7 @@ class Settings(BaseSettings):
base_dir: str = str(Path.home() / ".memos")
database_path: str = os.path.join(base_dir, "database.db")
default_library: str = "screenshots"
+ screenshots_dir: str = os.path.join(base_dir, "screenshots")
typesense_host: str = "localhost"
typesense_port: str = "8108"
diff --git a/screen_recorder/record.py b/screen_recorder/record.py
index 8bc101a..ba141a0 100644
--- a/screen_recorder/record.py
+++ b/screen_recorder/record.py
@@ -9,6 +9,8 @@ from screen_recorder.common import (
take_screenshot,
is_screen_locked,
)
+from pathlib import Path
+from memos.config import settings
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -51,12 +53,12 @@ def main():
"--threshold", type=int, default=4, help="Threshold for image similarity"
)
parser.add_argument(
- "--base-dir", type=str, default="~/tmp", help="Base directory for screenshots"
+ "--base-dir", type=str, help="Base directory for screenshots"
)
parser.add_argument("--once", action="store_true", help="Run once and exit")
args = parser.parse_args()
- base_dir = os.path.expanduser(args.base_dir)
+ base_dir = os.path.expanduser(args.base_dir) if args.base_dir else settings.screenshots_dir
previous_hashes = load_previous_hashes(base_dir)
if args.once:
From aba3c03c1f5810ebd9ca472c327b19d8532e051b Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 21:05:50 +0800
Subject: [PATCH 08/39] docs: update readme
---
README.md | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 66 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 60c8cfa..a8af81f 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,68 @@
# Memos
-A project to index everything to make it like another memory.
+A project to index everything to make it like another memory. The project contains two parts:
+
+1. `screen recorder`: which will take screenshots every 5 seconds and save it at `~/tmp` by default.
+2. `memos server`: a web service which can index the screenshots and other files, and provide a web interface to search the records.
+
+There is a product called [Rewind](https://www.rewind.ai/) which is similar to memos. But memos try to make all the data controlled by yourself.
+
+## Install
+
+### Install Typesense
+
+```bash
+export TYPESENSE_API_KEY=xyz
+
+mkdir "$(pwd)"/typesense-data
+
+docker run -d -p 8108:8108 \
+ -v"$(pwd)"/typesense-data:/data typesense/typesense:27.0 \
+ --add-host=host.docker.internal:host-gateway \
+ --data-dir /data \
+ --api-key=$TYPESENSE_API_KEY \
+ --enable-cors
+```
+
+### Install Memos
+
+```bash
+pip install memos
+```
+
+## How to use
+
+To use memos, you need to initialize it first. Make sure you have started `typesense`.
+
+### 1. Initialize Memos
+
+```bash
+memos init
+```
+
+This will create a folder `~/.memos` and put the config file there.
+
+### 2. Start Screen Recorder
+
+```bash
+memos record
+```
+
+This will start a screen recorder, which will take screenshots every 5 seconds and save it at `~/.memos/screenshots` by default.
+
+### 3. Start Memos Server
+
+```bash
+memos serve
+```
+
+This will start a web server, and you can access the web interface at `http://localhost:8080`.
+
+### Index the screenshots
+
+```bash
+memos scan
+memos index
+```
+
+Refresh the page, and do some search.
From 3f08b79bac6de5d137b0cd3443fe8428ff7a5b66 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 21:06:38 +0800
Subject: [PATCH 09/39] chore: bump version
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 0c200a4..303d8ca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.0"
+version = "0.6.1"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From 0677931ca5a204bc5138569afe7ea6c334965c90 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 21:08:02 +0800
Subject: [PATCH 10/39] chore: typo
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index a8af81f..2cd2574 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ This will create a folder `~/.memos` and put the config file there.
### 2. Start Screen Recorder
```bash
-memos record
+memos-record
```
This will start a screen recorder, which will take screenshots every 5 seconds and save it at `~/.memos/screenshots` by default.
From f9e2b2261bb434fec0bfe30283e20e1c2e768443 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 21:09:54 +0800
Subject: [PATCH 11/39] docs: typo
---
README.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 2cd2574..f6d3a49 100644
--- a/README.md
+++ b/README.md
@@ -2,10 +2,10 @@
A project to index everything to make it like another memory. The project contains two parts:
-1. `screen recorder`: which will take screenshots every 5 seconds and save it at `~/tmp` by default.
-2. `memos server`: a web service which can index the screenshots and other files, and provide a web interface to search the records.
+1. `screen recorder`: which takes screenshots every 5 seconds and saves them to `~/.memos/screenshots` by default.
+2. `memos server`: a web service that can index the screenshots and other files, providing a web interface to search the records.
-There is a product called [Rewind](https://www.rewind.ai/) which is similar to memos. But memos try to make all the data controlled by yourself.
+There is a product called [Rewind](https://www.rewind.ai/) that is similar to memos, but memos aims to give you control over all your data.
## Install
From 6107c22defad2ae86a458055f480f9f3f616406e Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 22:26:26 +0800
Subject: [PATCH 12/39] fix: support parse secret
---
memos/config.py | 17 ++++++++++++++++-
1 file changed, 16 insertions(+), 1 deletion(-)
diff --git a/memos/config.py b/memos/config.py
index 309a373..ec9170b 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -50,7 +50,7 @@ class Settings(BaseSettings):
database_path: str = os.path.join(base_dir, "database.db")
default_library: str = "screenshots"
screenshots_dir: str = os.path.join(base_dir, "screenshots")
-
+
typesense_host: str = "localhost"
typesense_port: str = "8108"
typesense_protocol: str = "http"
@@ -98,6 +98,21 @@ def dict_representer(dumper, data):
yaml.add_representer(OrderedDict, dict_representer)
+# Custom representer for SecretStr
+def secret_str_representer(dumper, data):
+ return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value())
+
+
+# Custom constructor for SecretStr
+def secret_str_constructor(loader, node):
+ value = loader.construct_scalar(node)
+ return SecretStr(value)
+
+
+yaml.add_representer(SecretStr, secret_str_representer)
+yaml.add_constructor("tag:yaml.org,2002:str", secret_str_constructor)
+
+
def create_default_config():
config_path = Path.home() / ".memos" / "config.yaml"
if not config_path.exists():
From 5b1194f1bc6ad055fe69a9a2207ef1c31cc73721 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 22:26:48 +0800
Subject: [PATCH 13/39] fix: skip flash attn
https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d
---
memos/plugins/vlm/main.py | 33 +++++++++++++++++++++++----------
1 file changed, 23 insertions(+), 10 deletions(-)
diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py
index ad636b6..6190a73 100644
--- a/memos/plugins/vlm/main.py
+++ b/memos/plugins/vlm/main.py
@@ -12,6 +12,18 @@ import io
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
+from unittest.mock import patch
+from transformers.dynamic_module_utils import get_imports
+
+
+def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
+ if not str(filename).endswith("modeling_florence2.py"):
+ return get_imports(filename)
+ imports = get_imports(filename)
+ imports.remove("flash_attn")
+ return imports
+
+
PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"
@@ -251,7 +263,7 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
-
+
modelname = config.modelname
endpoint = config.endpoint
token = config.token
@@ -279,15 +291,16 @@ def init_plugin(config):
)
logger.info(f"Using device: {device}")
- florence_model = AutoModelForCausalLM.from_pretrained(
- "microsoft/Florence-2-base-ft",
- torch_dtype=torch_dtype,
- attn_implementation="sdpa",
- trust_remote_code=True,
- ).to(device)
- florence_processor = AutoProcessor.from_pretrained(
- "microsoft/Florence-2-base-ft", trust_remote_code=True
- )
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+ florence_model = AutoModelForCausalLM.from_pretrained(
+ "microsoft/Florence-2-base-ft",
+ torch_dtype=torch_dtype,
+ attn_implementation="sdpa",
+ trust_remote_code=True,
+ ).to(device)
+ florence_processor = AutoProcessor.from_pretrained(
+ "microsoft/Florence-2-base-ft", trust_remote_code=True
+ )
logger.info("Florence model and processor initialized")
# Print the parameters
From a3394b9250e90791f1311bac14cfef6b7be13991 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 22:28:41 +0800
Subject: [PATCH 14/39] fix: add dependencies for florence 2
---
pyproject.toml | 2 ++
1 file changed, 2 insertions(+)
diff --git a/pyproject.toml b/pyproject.toml
index 303d8ca..83a9801 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,8 @@ dependencies = [
"sentence-transformers",
"torch",
"numpy",
+ "timm",
+ "einops",
]
[project.urls]
From 939bcc0e818f105c34d844cc8a99dbb2b81b87da Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 22:29:02 +0800
Subject: [PATCH 15/39] chore: bump 0.6.2
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 83a9801..3317fa6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.1"
+version = "0.6.2"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From 77b1f34a6268b7265aad4cadbb3ea307e84908f0 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 22:59:24 +0800
Subject: [PATCH 16/39] fix: include yaml in ocr
---
pyproject.toml | 1 +
1 file changed, 1 insertion(+)
diff --git a/pyproject.toml b/pyproject.toml
index 3317fa6..382e235 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,3 +55,4 @@ include = ["memos*", "screen_recorder*"]
[tool.setuptools.package-data]
"*" = ["static/**/*"]
+"memos.plugins.ocr" = ["*.yaml"]
From d08c9e0e3630a1534260e786deb9ab511b9ca95d Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 22:59:45 +0800
Subject: [PATCH 17/39] chore: bump 0.6.3
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 382e235..88fdf77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.2"
+version = "0.6.3"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From 3912d165f67a3d5ef50b43190811624a1eef38bc Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 23:27:40 +0800
Subject: [PATCH 18/39] Revert "fix: skip flash attn
https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d"
This reverts commit 5b1194f1bc6ad055fe69a9a2207ef1c31cc73721.
---
memos/plugins/vlm/main.py | 33 ++++++++++-----------------------
1 file changed, 10 insertions(+), 23 deletions(-)
diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py
index 6190a73..ad636b6 100644
--- a/memos/plugins/vlm/main.py
+++ b/memos/plugins/vlm/main.py
@@ -12,18 +12,6 @@ import io
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
-from unittest.mock import patch
-from transformers.dynamic_module_utils import get_imports
-
-
-def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
- if not str(filename).endswith("modeling_florence2.py"):
- return get_imports(filename)
- imports = get_imports(filename)
- imports.remove("flash_attn")
- return imports
-
-
PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"
@@ -263,7 +251,7 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
-
+
modelname = config.modelname
endpoint = config.endpoint
token = config.token
@@ -291,16 +279,15 @@ def init_plugin(config):
)
logger.info(f"Using device: {device}")
- with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
- florence_model = AutoModelForCausalLM.from_pretrained(
- "microsoft/Florence-2-base-ft",
- torch_dtype=torch_dtype,
- attn_implementation="sdpa",
- trust_remote_code=True,
- ).to(device)
- florence_processor = AutoProcessor.from_pretrained(
- "microsoft/Florence-2-base-ft", trust_remote_code=True
- )
+ florence_model = AutoModelForCausalLM.from_pretrained(
+ "microsoft/Florence-2-base-ft",
+ torch_dtype=torch_dtype,
+ attn_implementation="sdpa",
+ trust_remote_code=True,
+ ).to(device)
+ florence_processor = AutoProcessor.from_pretrained(
+ "microsoft/Florence-2-base-ft", trust_remote_code=True
+ )
logger.info("Florence model and processor initialized")
# Print the parameters
From a30fe62bc3e59414b163a452df7a132676f2200b Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 23:39:04 +0800
Subject: [PATCH 19/39] fix: yaml parse related
---
memos/config.py | 4 +---
memos/plugins/ocr/main.py | 4 ++--
2 files changed, 3 insertions(+), 5 deletions(-)
diff --git a/memos/config.py b/memos/config.py
index ec9170b..b840254 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -102,15 +102,13 @@ yaml.add_representer(OrderedDict, dict_representer)
def secret_str_representer(dumper, data):
return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value())
-
# Custom constructor for SecretStr
def secret_str_constructor(loader, node):
value = loader.construct_scalar(node)
return SecretStr(value)
-
+# Register the representer and constructor only for specific fields
yaml.add_representer(SecretStr, secret_str_representer)
-yaml.add_constructor("tag:yaml.org,2002:str", secret_str_constructor)
def create_default_config():
diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py
index a0b207d..b40be50 100644
--- a/memos/plugins/ocr/main.py
+++ b/memos/plugins/ocr/main.py
@@ -183,10 +183,10 @@ def init_plugin(config):
ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path']))
ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path']))
- # Save the updated config to a temporary file
+ # Save the updated config to a temporary file with strings wrapped in double quotes
temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml")
with open(temp_config_path, 'w') as f:
- yaml.safe_dump(ocr_config, f)
+ yaml.safe_dump(ocr_config, f, default_style='"')
ocr = RapidOCR(config_path=temp_config_path)
thread_pool = ThreadPoolExecutor(max_workers=concurrency)
From 4b189c22d224e0d91a7b6287c00bed79c848c117 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Mon, 9 Sep 2024 23:39:41 +0800
Subject: [PATCH 20/39] chore: bump 0.6.4
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 88fdf77..bc5223e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.3"
+version = "0.6.4"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From 41c7e136d9eac9370a12dd1458662492de022a38 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 00:31:56 +0800
Subject: [PATCH 21/39] fix: skip flash attn
https://huggingface.co/microsoft/Florence-2-base-ft/discussions/13#66836a8d67d2ccf03a96df8d
---
memos/plugins/vlm/main.py | 33 +++++++++++++++++++++++----------
1 file changed, 23 insertions(+), 10 deletions(-)
diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py
index ad636b6..6190a73 100644
--- a/memos/plugins/vlm/main.py
+++ b/memos/plugins/vlm/main.py
@@ -12,6 +12,18 @@ import io
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
+from unittest.mock import patch
+from transformers.dynamic_module_utils import get_imports
+
+
+def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
+ if not str(filename).endswith("modeling_florence2.py"):
+ return get_imports(filename)
+ imports = get_imports(filename)
+ imports.remove("flash_attn")
+ return imports
+
+
PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"
@@ -251,7 +263,7 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
-
+
modelname = config.modelname
endpoint = config.endpoint
token = config.token
@@ -279,15 +291,16 @@ def init_plugin(config):
)
logger.info(f"Using device: {device}")
- florence_model = AutoModelForCausalLM.from_pretrained(
- "microsoft/Florence-2-base-ft",
- torch_dtype=torch_dtype,
- attn_implementation="sdpa",
- trust_remote_code=True,
- ).to(device)
- florence_processor = AutoProcessor.from_pretrained(
- "microsoft/Florence-2-base-ft", trust_remote_code=True
- )
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+ florence_model = AutoModelForCausalLM.from_pretrained(
+ "microsoft/Florence-2-base-ft",
+ torch_dtype=torch_dtype,
+ attn_implementation="sdpa",
+ trust_remote_code=True,
+ ).to(device)
+ florence_processor = AutoProcessor.from_pretrained(
+ "microsoft/Florence-2-base-ft", trust_remote_code=True
+ )
logger.info("Florence model and processor initialized")
# Print the parameters
From 93060af86adaba36864308bac33b0cdeb0d10013 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 00:33:37 +0800
Subject: [PATCH 22/39] fix: include onnx models
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index bc5223e..f50f59d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,4 +55,4 @@ include = ["memos*", "screen_recorder*"]
[tool.setuptools.package-data]
"*" = ["static/**/*"]
-"memos.plugins.ocr" = ["*.yaml"]
+"memos.plugins.ocr" = ["*.yaml", "*.onnx"]
From aa9cc028c6bc3289b8a2032602f869dcc4311416 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 00:34:03 +0800
Subject: [PATCH 23/39] chore: bump 0.6.5
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index f50f59d..76c0100 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.4"
+version = "0.6.5"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From a8c42d854618e295cfe5de8b3af5608af6dbbde0 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 00:40:12 +0800
Subject: [PATCH 24/39] fix: include onnx models
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 76c0100..9a41324 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,4 +55,4 @@ include = ["memos*", "screen_recorder*"]
[tool.setuptools.package-data]
"*" = ["static/**/*"]
-"memos.plugins.ocr" = ["*.yaml", "*.onnx"]
+"memos.plugins.ocr" = ["*.yaml", "models/*.onnx"]
From 95047de12ba9bb48a832f303917617533437f973 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 00:40:34 +0800
Subject: [PATCH 25/39] chore: bump 0.6.6
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 9a41324..1bcac45 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.5"
+version = "0.6.6"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From 3a408f0be70cfa139a5aa3e51dae5e301eddb366 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 00:50:42 +0800
Subject: [PATCH 26/39] docs: tell default user / password
---
README.md | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index f6d3a49..d763168 100644
--- a/README.md
+++ b/README.md
@@ -56,7 +56,8 @@ This will start a screen recorder, which will take screenshots every 5 seconds a
memos serve
```
-This will start a web server, and you can access the web interface at `http://localhost:8080`.
+This will start a web server, and you can access the web interface at `http://localhost:8080`.
+The default username and password is `admin` and `changeme`.
### Index the screenshots
From f5aae87f40933ff0f27776c3fc32842d889e1289 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 00:54:20 +0800
Subject: [PATCH 27/39] feat: use concurrency of 1 by default
---
memos/config.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/memos/config.py b/memos/config.py
index b840254..a572e6f 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -17,7 +17,7 @@ class VLMSettings(BaseModel):
modelname: str = "moondream"
endpoint: str = "http://localhost:11434"
token: str = ""
- concurrency: int = 4
+ concurrency: int = 1
force_jpeg: bool = False
use_local: bool = True
@@ -26,7 +26,7 @@ class OCRSettings(BaseModel):
enabled: bool = True
endpoint: str = "http://localhost:5555/predict"
token: str = ""
- concurrency: int = 4
+ concurrency: int = 1
use_local: bool = True
use_gpu: bool = False
force_jpeg: bool = False
@@ -71,7 +71,7 @@ class Settings(BaseSettings):
# Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings()
- batchsize: int = 4
+ batchsize: int = 1
auth_username: str = "admin"
auth_password: SecretStr = SecretStr("changeme")
From 241911d1d22dd0aed69946c96df71ae35e2982ac Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 00:58:57 +0800
Subject: [PATCH 28/39] feat(index): use different bs for embedding
---
memos/commands.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/memos/commands.py b/memos/commands.py
index 8092918..d44dc61 100644
--- a/memos/commands.py
+++ b/memos/commands.py
@@ -548,6 +548,9 @@ def index(
library_id: int,
folders: List[int] = typer.Option(None, "--folder", "-f"),
force: bool = typer.Option(False, "--force", help="Force update all indexes"),
+ batchsize: int = typer.Option(
+ 4, "--batchsize", "-bs", help="Number of entities to index in a batch"
+ ),
):
print(f"Indexing library {library_id}")
@@ -607,9 +610,8 @@ def index(
pbar.refresh()
# Index each entity
- batch_size = settings.batchsize
- for i in range(0, len(entities), batch_size):
- batch = entities[i : i + batch_size]
+ for i in range(0, len(entities), batchsize):
+ batch = entities[i : i + batchsize]
to_index = []
for entity in batch:
From 378e5bf445b800f2928b2d5363c846474295b50b Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 01:32:00 +0800
Subject: [PATCH 29/39] feat: support modelscope
---
memos/config.py | 2 ++
memos/plugins/embedding/main.py | 21 ++++++++++++++++++---
memos/plugins/vlm/main.py | 16 +++++++++++++---
pyproject.toml | 1 +
4 files changed, 34 insertions(+), 6 deletions(-)
diff --git a/memos/config.py b/memos/config.py
index a572e6f..6986b7a 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -51,6 +51,8 @@ class Settings(BaseSettings):
default_library: str = "screenshots"
screenshots_dir: str = os.path.join(base_dir, "screenshots")
+ use_modelscope: bool = False
+
typesense_host: str = "localhost"
typesense_port: str = "8108"
typesense_protocol: str = "http"
diff --git a/memos/plugins/embedding/main.py b/memos/plugins/embedding/main.py
index 89e837a..c7fbec4 100644
--- a/memos/plugins/embedding/main.py
+++ b/memos/plugins/embedding/main.py
@@ -7,6 +7,7 @@ from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from pydantic import BaseModel
+from modelscope import snapshot_download
PLUGIN_NAME = "embedding"
@@ -19,6 +20,7 @@ num_dim = None
endpoint = None
model_name = None
device = None
+use_modelscope = None
# Configure logger
logging.basicConfig(level=logging.INFO)
@@ -26,7 +28,7 @@ logger = logging.getLogger(__name__)
def init_embedding_model():
- global model, device
+ global model, device, use_modelscope
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
@@ -34,7 +36,14 @@ def init_embedding_model():
else:
device = torch.device("cpu")
- model = SentenceTransformer(model_name, trust_remote_code=True)
+ if use_modelscope:
+ model_dir = snapshot_download(model_name)
+ logger.info(f"Model downloaded from ModelScope to: {model_dir}")
+ else:
+ model_dir = model_name
+ logger.info(f"Using model: {model_dir}")
+
+ model = SentenceTransformer(model_dir, trust_remote_code=True)
model.to(device)
logger.info(f"Embedding model initialized on device: {device}")
@@ -80,11 +89,12 @@ async def embed(request: EmbeddingRequest):
def init_plugin(config):
- global enabled, num_dim, endpoint, model_name
+ global enabled, num_dim, endpoint, model_name, use_modelscope
enabled = config.enabled
num_dim = config.num_dim
endpoint = config.endpoint
model_name = config.model
+ use_modelscope = config.use_modelscope
if enabled:
init_embedding_model()
@@ -94,6 +104,7 @@ def init_plugin(config):
logger.info(f"Number of dimensions: {num_dim}")
logger.info(f"Endpoint: {endpoint}")
logger.info(f"Model: {model_name}")
+ logger.info(f"Use ModelScope: {use_modelscope}")
if __name__ == "__main__":
@@ -113,6 +124,9 @@ if __name__ == "__main__":
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
+ parser.add_argument(
+ "--use-modelscope", action="store_true", help="Use ModelScope to download the model"
+ )
args = parser.parse_args()
@@ -122,6 +136,7 @@ if __name__ == "__main__":
self.num_dim = args.num_dim
self.endpoint = "what ever"
self.model = args.model
+ self.use_modelscope = args.use_modelscope
init_plugin(Config(args))
diff --git a/memos/plugins/vlm/main.py b/memos/plugins/vlm/main.py
index 6190a73..fc905a4 100644
--- a/memos/plugins/vlm/main.py
+++ b/memos/plugins/vlm/main.py
@@ -14,7 +14,7 @@ from transformers import AutoModelForCausalLM, AutoProcessor
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
-
+from modelscope import snapshot_download
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
@@ -270,6 +270,7 @@ def init_plugin(config):
concurrency = config.concurrency
force_jpeg = config.force_jpeg
use_local = config.use_local
+ use_modelscope = config.use_modelscope
semaphore = asyncio.Semaphore(concurrency)
if use_local:
@@ -291,15 +292,22 @@ def init_plugin(config):
)
logger.info(f"Using device: {device}")
+ if use_modelscope:
+ model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
+ logger.info(f"Model downloaded from ModelScope to: {model_dir}")
+ else:
+ model_dir = "microsoft/Florence-2-base-ft"
+ logger.info(f"Using model: {model_dir}")
+
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
florence_model = AutoModelForCausalLM.from_pretrained(
- "microsoft/Florence-2-base-ft",
+ model_dir,
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device)
florence_processor = AutoProcessor.from_pretrained(
- "microsoft/Florence-2-base-ft", trust_remote_code=True
+ model_dir, trust_remote_code=True
)
logger.info("Florence model and processor initialized")
@@ -311,6 +319,7 @@ def init_plugin(config):
logger.info(f"Concurrency: {concurrency}")
logger.info(f"Force JPEG: {force_jpeg}")
logger.info(f"Use Local: {use_local}")
+ logger.info(f"Use ModelScope: {use_modelscope}")
if __name__ == "__main__":
@@ -329,6 +338,7 @@ if __name__ == "__main__":
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
+ parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
args = parser.parse_args()
diff --git a/pyproject.toml b/pyproject.toml
index 1bcac45..5e44755 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
"numpy",
"timm",
"einops",
+ "modelscope",
]
[project.urls]
From 513b41e3001c936994e7051496ac51609595ad5f Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 01:32:22 +0800
Subject: [PATCH 30/39] chore: bump 0.6.7
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 5e44755..6a7509c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.6"
+version = "0.6.7"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From 6bbf1a8a6805a7e808fd35769518d529909b6c30 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 01:59:37 +0800
Subject: [PATCH 31/39] fix: use modelscope for sub config
---
memos/config.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/memos/config.py b/memos/config.py
index 6986b7a..aa57145 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -20,6 +20,7 @@ class VLMSettings(BaseModel):
concurrency: int = 1
force_jpeg: bool = False
use_local: bool = True
+ use_modelscope: bool = False
class OCRSettings(BaseModel):
@@ -37,6 +38,7 @@ class EmbeddingSettings(BaseModel):
num_dim: int = 768
endpoint: str = "http://localhost:11434/api/embed"
model: str = "jinaai/jina-embeddings-v2-base-zh"
+ use_modelscope: bool = False
class Settings(BaseSettings):
@@ -51,8 +53,6 @@ class Settings(BaseSettings):
default_library: str = "screenshots"
screenshots_dir: str = os.path.join(base_dir, "screenshots")
- use_modelscope: bool = False
-
typesense_host: str = "localhost"
typesense_port: str = "8108"
typesense_protocol: str = "http"
From 23612d1fd5a0c4ec6da5da697cc28c19b3777a18 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 01:59:59 +0800
Subject: [PATCH 32/39] chore: bump 0.6.8
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 6a7509c..bed964e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.7"
+version = "0.6.8"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From 7f78eb1a4ad6a01e2eaaf7d6675583ff928cad89 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 11:54:56 +0800
Subject: [PATCH 33/39] chore: update yaml generate config
---
memos/plugins/ocr/main.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/memos/plugins/ocr/main.py b/memos/plugins/ocr/main.py
index b40be50..86bea9f 100644
--- a/memos/plugins/ocr/main.py
+++ b/memos/plugins/ocr/main.py
@@ -186,7 +186,7 @@ def init_plugin(config):
# Save the updated config to a temporary file with strings wrapped in double quotes
temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml")
with open(temp_config_path, 'w') as f:
- yaml.safe_dump(ocr_config, f, default_style='"')
+ yaml.safe_dump(ocr_config, f)
ocr = RapidOCR(config_path=temp_config_path)
thread_pool = ThreadPoolExecutor(max_workers=concurrency)
From 5882950c39b2487a4013d9fcc75ff47cf51029ef Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 12:39:08 +0800
Subject: [PATCH 34/39] feat(cli): add extra args
---
memos/commands.py | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/memos/commands.py b/memos/commands.py
index d44dc61..47242be 100644
--- a/memos/commands.py
+++ b/memos/commands.py
@@ -765,7 +765,7 @@ def init():
@app.command("scan")
-def scan_default_library():
+def scan_default_library(force: bool = False):
"""
Scan the screenshots directory and add it to the library if empty.
"""
@@ -810,11 +810,16 @@ def scan_default_library():
# Scan the library
print(f"Scanning library: {default_library['name']}")
- scan(default_library["id"], plugins=None, folders=None)
+ scan(default_library["id"], plugins=None, folders=None, force=force)
@app.command("index")
-def index_default_library():
+def index_default_library(
+ batchsize: int = typer.Option(
+ 4, "--batchsize", "-bs", help="Number of entities to index in a batch"
+ ),
+ force: bool = typer.Option(False, "--force", help="Force update all indexes"),
+):
"""
Index the default library for memos.
"""
@@ -833,7 +838,7 @@ def index_default_library():
print("Default library does not exist.")
return
- index(default_library["id"], force=False, folders=None)
+ index(default_library["id"], force=force, folders=None, batchsize=batchsize)
if __name__ == "__main__":
From fca387b22d7f574ce81d5542e6460766b67973c2 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 13:55:52 +0800
Subject: [PATCH 35/39] feat: support bind plugin by name
---
memos/commands.py | 16 ++++++++++------
memos/schemas.py | 31 ++++++++++++++++++++++++++++---
memos/server.py | 21 +++++++++++++++++++--
3 files changed, 57 insertions(+), 11 deletions(-)
diff --git a/memos/commands.py b/memos/commands.py
index 47242be..744bd9d 100644
--- a/memos/commands.py
+++ b/memos/commands.py
@@ -723,18 +723,22 @@ def create(name: str, webhook_url: str, description: str = ""):
@plugin_app.command("bind")
def bind(
library_id: int = typer.Option(..., "--lib", help="ID of the library"),
- plugin_id: int = typer.Option(..., "--plugin", help="ID of the plugin"),
+ plugin: str = typer.Option(..., "--plugin", help="ID or name of the plugin"),
):
+ try:
+ plugin_id = int(plugin)
+ plugin_param = {"plugin_id": plugin_id}
+ except ValueError:
+ plugin_param = {"plugin_name": plugin}
+
response = httpx.post(
f"{BASE_URL}/libraries/{library_id}/plugins",
- json={"plugin_id": plugin_id},
+ json=plugin_param,
)
- if 200 <= response.status_code < 300:
+ if response.status_code == 204:
print("Plugin bound to library successfully")
else:
- print(
- f"Failed to bind plugin to library: {response.status_code} - {response.text}"
- )
+ print(f"Failed to bind plugin to library: {response.status_code} - {response.text}")
@plugin_app.command("unbind")
diff --git a/memos/schemas.py b/memos/schemas.py
index edd040f..b1b6042 100644
--- a/memos/schemas.py
+++ b/memos/schemas.py
@@ -1,4 +1,11 @@
-from pydantic import BaseModel, ConfigDict, DirectoryPath, HttpUrl, Field
+from pydantic import (
+ BaseModel,
+ ConfigDict,
+ DirectoryPath,
+ HttpUrl,
+ Field,
+ model_validator,
+)
from typing import List, Optional, Any, Dict
from datetime import datetime
from enum import Enum
@@ -79,7 +86,18 @@ class NewPluginParam(BaseModel):
class NewLibraryPluginParam(BaseModel):
- plugin_id: int
+ plugin_id: Optional[int] = None
+ plugin_name: Optional[str] = None
+
+ @model_validator(mode="after")
+ def check_either_id_or_name(self):
+ plugin_id = self.plugin_id
+ plugin_name = self.plugin_name
+ if not (plugin_id or plugin_name):
+ raise ValueError("Either plugin_id or plugin_name must be provided")
+ if plugin_id is not None and plugin_name is not None:
+ raise ValueError("Only one of plugin_id or plugin_name should be provided")
+ return self
class Folder(BaseModel):
@@ -214,15 +232,18 @@ class FacetCount(BaseModel):
highlighted: str
value: str
+
class FacetStats(BaseModel):
total_values: int
+
class Facet(BaseModel):
counts: List[FacetCount]
field_name: str
sampled: bool
stats: FacetStats
+
class TextMatchInfo(BaseModel):
best_field_score: str
best_field_weight: int
@@ -232,9 +253,11 @@ class TextMatchInfo(BaseModel):
tokens_matched: int
typo_prefix_score: int
+
class HybridSearchInfo(BaseModel):
rank_fusion_score: float
+
class SearchHit(BaseModel):
document: EntitySearchResult
highlight: Dict[str, Any] = {}
@@ -243,12 +266,14 @@ class SearchHit(BaseModel):
text_match: Optional[int] = None
text_match_info: Optional[TextMatchInfo] = None
+
class RequestParams(BaseModel):
collection_name: str
first_q: str
per_page: int
q: str
+
class SearchResult(BaseModel):
facet_counts: List[Facet]
found: int
@@ -257,4 +282,4 @@ class SearchResult(BaseModel):
page: int
request_params: RequestParams
search_cutoff: bool
- search_time_ms: int
\ No newline at end of file
+ search_time_ms: int
diff --git a/memos/server.py b/memos/server.py
index d48e355..f40982e 100644
--- a/memos/server.py
+++ b/memos/server.py
@@ -602,12 +602,29 @@ def add_library_plugin(
library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db)
):
library = crud.get_library_by_id(library_id, db)
- if any(plugin.id == new_plugin.plugin_id for plugin in library.plugins):
+ if library is None:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
+ )
+
+ plugin = None
+ if new_plugin.plugin_id is not None:
+ plugin = crud.get_plugin_by_id(new_plugin.plugin_id, db)
+ elif new_plugin.plugin_name is not None:
+ plugin = crud.get_plugin_by_name(new_plugin.plugin_name, db)
+
+ if plugin is None:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found"
+ )
+
+ if any(p.id == plugin.id for p in library.plugins):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Plugin already exists in the library",
)
- crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db)
+
+ crud.add_plugin_to_library(library_id, plugin.id, db)
@app.delete(
From 167fd3105358e7fbb7a3559c29f01b105470fb64 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 15:47:36 +0800
Subject: [PATCH 36/39] feat: add builtin plugins for default library
---
memos/commands.py | 3 +++
memos/config.py | 4 +++-
memos/models.py | 4 ++--
3 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/memos/commands.py b/memos/commands.py
index 744bd9d..7167234 100644
--- a/memos/commands.py
+++ b/memos/commands.py
@@ -797,6 +797,9 @@ def scan_default_library(force: bool = False):
return
default_library = response.json()
+ for plugin in settings.default_plugins:
+ bind(default_library["id"], plugin)
+
# Check if the library is empty
if not default_library["folders"]:
# Add the screenshots directory to the library
diff --git a/memos/config.py b/memos/config.py
index aa57145..b979852 100644
--- a/memos/config.py
+++ b/memos/config.py
@@ -1,6 +1,6 @@
import os
from pathlib import Path
-from typing import Tuple, Type
+from typing import Tuple, Type, List
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
@@ -78,6 +78,8 @@ class Settings(BaseSettings):
auth_username: str = "admin"
auth_password: SecretStr = SecretStr("changeme")
+ default_plugins: List[str] = ["builtin_vlm", "builtin_ocr"]
+
@classmethod
def settings_customise_sources(
cls,
diff --git a/memos/models.py b/memos/models.py
index 740e8e6..72bc91c 100644
--- a/memos/models.py
+++ b/memos/models.py
@@ -178,10 +178,10 @@ def init_database():
def initialize_default_plugins(session):
default_plugins = [
PluginModel(
- name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"
+ name="builtin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"
),
PluginModel(
- name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"
+ name="builtin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"
),
]
From c24d9cc2e917ac2eed51ddf95f6ccf6469b42d0f Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Tue, 10 Sep 2024 18:28:47 +0800
Subject: [PATCH 37/39] chore: bump 0.6.9
---
README.md | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index d763168..1836d26 100644
--- a/README.md
+++ b/README.md
@@ -56,7 +56,7 @@ This will start a screen recorder, which will take screenshots every 5 seconds a
memos serve
```
-This will start a web server, and you can access the web interface at `http://localhost:8080`.
+This will start a web server, and you can access the web interface at `http://localhost:8080`.
The default username and password is `admin` and `changeme`.
### Index the screenshots
diff --git a/pyproject.toml b/pyproject.toml
index bed964e..bcc2ce6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.8"
+version = "0.6.9"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
From 30f9a45d8a2bfd43cbfada916b8b042b3c6fcfa5 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Wed, 11 Sep 2024 17:50:33 +0800
Subject: [PATCH 38/39] feat: use mss instead of imagegrab
---
pyproject.toml | 1 +
screen_recorder/common.py | 84 ++++++++++++++++++---------------------
2 files changed, 39 insertions(+), 46 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index bcc2ce6..3f2bf4b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
"timm",
"einops",
"modelscope",
+ "mss",
]
[project.urls]
diff --git a/screen_recorder/common.py b/screen_recorder/common.py
index 99e670a..f87ca44 100644
--- a/screen_recorder/common.py
+++ b/screen_recorder/common.py
@@ -4,11 +4,11 @@ import time
import logging
import platform
import subprocess
-from PIL import Image, ImageGrab
+from PIL import Image
import imagehash
from memos.utils import write_image_metadata
-from screeninfo import get_monitors
import ctypes
+from mss import mss
if platform.system() == "Windows":
import win32gui
@@ -173,57 +173,49 @@ def take_screenshot_windows(
app_name,
window_title,
):
- for monitor in get_monitors():
- safe_monitor_name = "".join(
- c for c in monitor.name if c.isalnum() or c in ("_", "-")
- )
- logging.info(f"Processing monitor: {safe_monitor_name}")
+ with mss() as sct:
+ for i, monitor in enumerate(sct.monitors[1:], 1): # Skip the first monitor (entire screen)
+ safe_monitor_name = f"monitor_{i}"
+ logging.info(f"Processing monitor: {safe_monitor_name}")
- webp_filename = os.path.join(
- base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp"
- )
-
- img = ImageGrab.grab(
- bbox=(
- monitor.x,
- monitor.y,
- monitor.x + monitor.width,
- monitor.y + monitor.height,
+ webp_filename = os.path.join(
+ base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp"
)
- )
- img = img.convert("RGB")
- current_hash = str(imagehash.phash(img))
- if (
- safe_monitor_name in previous_hashes
- and imagehash.hex_to_hash(current_hash)
- - imagehash.hex_to_hash(previous_hashes[safe_monitor_name])
- < threshold
- ):
- logging.info(
- f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping."
+ img = sct.grab(monitor)
+ img = Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
+ current_hash = str(imagehash.phash(img))
+
+ if (
+ safe_monitor_name in previous_hashes
+ and imagehash.hex_to_hash(current_hash)
+ - imagehash.hex_to_hash(previous_hashes[safe_monitor_name])
+ < threshold
+ ):
+ logging.info(
+ f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping."
+ )
+ yield safe_monitor_name, None, "Skipped (similar to previous)"
+ continue
+
+ previous_hashes[safe_monitor_name] = current_hash
+ screen_sequences[safe_monitor_name] = (
+ screen_sequences.get(safe_monitor_name, 0) + 1
)
- yield safe_monitor_name, None, "Skipped (similar to previous)"
- continue
- previous_hashes[safe_monitor_name] = current_hash
- screen_sequences[safe_monitor_name] = (
- screen_sequences.get(safe_monitor_name, 0) + 1
- )
+ metadata = {
+ "timestamp": timestamp,
+ "active_app": app_name,
+ "active_window": window_title,
+ "screen_name": safe_monitor_name,
+ "sequence": screen_sequences[safe_monitor_name],
+ }
- metadata = {
- "timestamp": timestamp,
- "active_app": app_name,
- "active_window": window_title,
- "screen_name": safe_monitor_name,
- "sequence": screen_sequences[safe_monitor_name],
- }
+ img.save(webp_filename, format="WebP", quality=85)
+ write_image_metadata(webp_filename, metadata)
+ save_screen_sequences(base_dir, screen_sequences, date)
- img.save(webp_filename, format="WebP", quality=85)
- write_image_metadata(webp_filename, metadata)
- save_screen_sequences(base_dir, screen_sequences, date)
-
- yield safe_monitor_name, webp_filename, "Saved"
+ yield safe_monitor_name, webp_filename, "Saved"
def take_screenshot(
From 40fad9176f4f100d4e4f0a9514da4e26b437c622 Mon Sep 17 00:00:00 2001
From: arkohut <39525455+arkohut@users.noreply.github.com>
Date: Wed, 11 Sep 2024 17:50:46 +0800
Subject: [PATCH 39/39] chore: bump 0.6.10
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 3f2bf4b..1a07432 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
-version = "0.6.9"
+version = "0.6.10"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]