This commit is contained in:
arkohut 2024-09-11 23:20:49 +08:00
commit 4a8d8fa54a
15 changed files with 656 additions and 176 deletions

View File

@ -1,3 +1,69 @@
# Memos
A project to index everything to make it like another memory. The project contains two parts:
1. `screen recorder`: which takes screenshots every 5 seconds and saves them to `~/.memos/screenshots` by default.
2. `memos server`: a web service that can index the screenshots and other files, providing a web interface to search the records.
There is a product called [Rewind](https://www.rewind.ai/) that is similar to memos, but memos aims to give you control over all your data.
## Install
### Install Typesense
```bash
export TYPESENSE_API_KEY=xyz
mkdir "$(pwd)"/typesense-data
docker run -d -p 8108:8108 \
-v"$(pwd)"/typesense-data:/data typesense/typesense:27.0 \
--add-host=host.docker.internal:host-gateway \
--data-dir /data \
--api-key=$TYPESENSE_API_KEY \
--enable-cors
```
### Install Memos
```bash
pip install memos
```
## How to use
To use memos, you need to initialize it first. Make sure you have started `typesense`.
### 1. Initialize Memos
```bash
memos init
```
This will create a folder `~/.memos` and put the config file there.
### 2. Start Screen Recorder
```bash
memos-record
```
This will start a screen recorder, which will take a screenshot every 5 seconds and save it to `~/.memos/screenshots` by default.
### 3. Start Memos Server
```bash
memos serve
```
This will start a web server, and you can access the web interface at `http://localhost:8080`.
The default username and password are `admin` and `changeme`.
### Index the screenshots
```bash
memos scan
memos index
```
Refresh the page and try a search.

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import os import os
import time
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
@ -549,6 +548,9 @@ def index(
library_id: int, library_id: int,
folders: List[int] = typer.Option(None, "--folder", "-f"), folders: List[int] = typer.Option(None, "--folder", "-f"),
force: bool = typer.Option(False, "--force", help="Force update all indexes"), force: bool = typer.Option(False, "--force", help="Force update all indexes"),
batchsize: int = typer.Option(
4, "--batchsize", "-bs", help="Number of entities to index in a batch"
),
): ):
print(f"Indexing library {library_id}") print(f"Indexing library {library_id}")
@ -608,9 +610,8 @@ def index(
pbar.refresh() pbar.refresh()
# Index each entity # Index each entity
batch_size = settings.batchsize for i in range(0, len(entities), batchsize):
for i in range(0, len(entities), batch_size): batch = entities[i : i + batchsize]
batch = entities[i : i + batch_size]
to_index = [] to_index = []
for entity in batch: for entity in batch:
@ -722,18 +723,22 @@ def create(name: str, webhook_url: str, description: str = ""):
@plugin_app.command("bind") @plugin_app.command("bind")
def bind( def bind(
library_id: int = typer.Option(..., "--lib", help="ID of the library"), library_id: int = typer.Option(..., "--lib", help="ID of the library"),
plugin_id: int = typer.Option(..., "--plugin", help="ID of the plugin"), plugin: str = typer.Option(..., "--plugin", help="ID or name of the plugin"),
): ):
try:
plugin_id = int(plugin)
plugin_param = {"plugin_id": plugin_id}
except ValueError:
plugin_param = {"plugin_name": plugin}
response = httpx.post( response = httpx.post(
f"{BASE_URL}/libraries/{library_id}/plugins", f"{BASE_URL}/libraries/{library_id}/plugins",
json={"plugin_id": plugin_id}, json=plugin_param,
) )
if 200 <= response.status_code < 300: if response.status_code == 204:
print("Plugin bound to library successfully") print("Plugin bound to library successfully")
else: else:
print( print(f"Failed to bind plugin to library: {response.status_code} - {response.text}")
f"Failed to bind plugin to library: {response.status_code} - {response.text}"
)
@plugin_app.command("unbind") @plugin_app.command("unbind")
@ -763,5 +768,85 @@ def init():
print("Initialization failed. Please check the error messages above.") print("Initialization failed. Please check the error messages above.")
@app.command("scan")
def scan_default_library(force: bool = False):
    """
    Scan the screenshots directory and add it to the library if empty.
    """
    # Flow: look up the library named settings.default_library; create it when
    # missing; attach settings.screenshots_dir when it has no folders; then
    # delegate to scan(). Every failed HTTP call is printed and aborts the run.
    listing = httpx.get(f"{BASE_URL}/libraries")
    if listing.status_code != 200:
        print(f"Failed to retrieve libraries: {listing.status_code} - {listing.text}")
        return

    # First library whose name matches the configured default, if any.
    default_library = None
    for candidate in listing.json():
        if candidate["name"] == settings.default_library:
            default_library = candidate
            break

    if not default_library:
        # The default library does not exist yet: create it without folders.
        created = httpx.post(
            f"{BASE_URL}/libraries",
            json={"name": settings.default_library, "folders": []},
        )
        if created.status_code != 200:
            print(
                f"Failed to create default library: {created.status_code} - {created.text}"
            )
            return
        default_library = created.json()

        # NOTE(review): indentation is not visible in this view; binding the
        # default plugins is assumed to happen only for a freshly created
        # library -- confirm against the real file.
        for plugin in settings.default_plugins:
            bind(default_library["id"], plugin)

    if not default_library["folders"]:
        # An empty library gets the screenshots directory as its folder.
        screenshots_dir = Path(settings.screenshots_dir).resolve()
        added = httpx.post(
            f"{BASE_URL}/libraries/{default_library['id']}/folders",
            json={"folders": [str(screenshots_dir)]},
        )
        if added.status_code != 200:
            print(
                f"Failed to add screenshots directory: {added.status_code} - {added.text}"
            )
            return
        print(f"Added screenshots directory: {screenshots_dir}")

    print(f"Scanning library: {default_library['name']}")
    scan(default_library["id"], plugins=None, folders=None, force=force)
@app.command("index")
def index_default_library(
    batchsize: int = typer.Option(
        4, "--batchsize", "-bs", help="Number of entities to index in a batch"
    ),
    force: bool = typer.Option(False, "--force", help="Force update all indexes"),
):
    """
    Index the default library for memos.
    """
    # Resolve the library named settings.default_library, then hand off to
    # index(); nothing is created here, a missing library just ends the run.
    listing = httpx.get(f"{BASE_URL}/libraries")
    if listing.status_code != 200:
        print(f"Failed to retrieve libraries: {listing.status_code} - {listing.text}")
        return

    default_library = None
    for candidate in listing.json():
        if candidate["name"] == settings.default_library:
            default_library = candidate
            break

    if not default_library:
        print("Default library does not exist.")
        return

    index(default_library["id"], force=force, folders=None, batchsize=batchsize)
if __name__ == "__main__": if __name__ == "__main__":
app() app()

View File

@ -1,6 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Tuple, Type from typing import Tuple, Type, List
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
PydanticBaseSettingsSource, PydanticBaseSettingsSource,
@ -17,24 +17,28 @@ class VLMSettings(BaseModel):
modelname: str = "moondream" modelname: str = "moondream"
endpoint: str = "http://localhost:11434" endpoint: str = "http://localhost:11434"
token: str = "" token: str = ""
concurrency: int = 4 concurrency: int = 1
force_jpeg: bool = False force_jpeg: bool = False
use_local: bool = True
use_modelscope: bool = False
class OCRSettings(BaseModel): class OCRSettings(BaseModel):
enabled: bool = True enabled: bool = True
endpoint: str = "http://localhost:5555/predict" endpoint: str = "http://localhost:5555/predict"
token: str = "" token: str = ""
concurrency: int = 4 concurrency: int = 1
use_local: bool = True use_local: bool = True
use_gpu: bool = False use_gpu: bool = False
force_jpeg: bool = False force_jpeg: bool = False
class EmbeddingSettings(BaseModel): class EmbeddingSettings(BaseModel):
enabled: bool = True
num_dim: int = 768 num_dim: int = 768
ollama_endpoint: str = "http://localhost:11434" endpoint: str = "http://localhost:11434/api/embed"
ollama_model: str = "nextfire/paraphrase-multilingual-minilm" model: str = "jinaai/jina-embeddings-v2-base-zh"
use_modelscope: bool = False
class Settings(BaseSettings): class Settings(BaseSettings):
@ -46,6 +50,9 @@ class Settings(BaseSettings):
base_dir: str = str(Path.home() / ".memos") base_dir: str = str(Path.home() / ".memos")
database_path: str = os.path.join(base_dir, "database.db") database_path: str = os.path.join(base_dir, "database.db")
default_library: str = "screenshots"
screenshots_dir: str = os.path.join(base_dir, "screenshots")
typesense_host: str = "localhost" typesense_host: str = "localhost"
typesense_port: str = "8108" typesense_port: str = "8108"
typesense_protocol: str = "http" typesense_protocol: str = "http"
@ -66,11 +73,13 @@ class Settings(BaseSettings):
# Embedding settings # Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings() embedding: EmbeddingSettings = EmbeddingSettings()
batchsize: int = 4 batchsize: int = 1
auth_username: str = "admin" auth_username: str = "admin"
auth_password: SecretStr = SecretStr("changeme") auth_password: SecretStr = SecretStr("changeme")
default_plugins: List[str] = ["builtin_vlm", "builtin_ocr"]
@classmethod @classmethod
def settings_customise_sources( def settings_customise_sources(
cls, cls,
@ -93,6 +102,19 @@ def dict_representer(dumper, data):
yaml.add_representer(OrderedDict, dict_representer) yaml.add_representer(OrderedDict, dict_representer)
# Custom representer for SecretStr
def secret_str_representer(dumper, data):
    """Represent *data* as a plain YAML string scalar holding its secret value.

    The value returned by ``data.get_secret_value()`` is emitted unmasked.
    """
    plain_value = data.get_secret_value()
    return dumper.represent_scalar("tag:yaml.org,2002:str", plain_value)
# Custom constructor for SecretStr
def secret_str_constructor(loader, node):
    """Build a ``SecretStr`` from the scalar value of a YAML *node*."""
    return SecretStr(loader.construct_scalar(node))
# Register the SecretStr representer so SecretStr values can be dumped to YAML.
# NOTE(review): only the representer is registered here; secret_str_constructor
# is not registered on this line -- confirm whether that is intentional.
yaml.add_representer(SecretStr, secret_str_representer)
def create_default_config(): def create_default_config():
config_path = Path.home() / ".memos" / "config.yaml" config_path = Path.home() / ".memos" / "config.yaml"
if not config_path.exists(): if not config_path.exists():

View File

@ -46,14 +46,19 @@ def parse_date_fields(entity):
} }
def get_embeddings(texts: List[str]) -> List[List[float]]: async def get_embeddings(texts: List[str]) -> List[List[float]]:
print(f"Getting embeddings for {len(texts)} texts") print(f"Getting embeddings for {len(texts)} texts")
ollama_endpoint = settings.embedding.ollama_endpoint
ollama_model = settings.embedding.ollama_model if settings.embedding.enabled:
with httpx.Client() as client: endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
response = client.post( else:
f"{ollama_endpoint}/api/embed", endpoint = settings.embedding.endpoint
json={"model": ollama_model, "input": texts},
model = settings.embedding.model
async with httpx.AsyncClient() as client:
response = await client.post(
endpoint,
json={"model": model, "input": texts},
timeout=30 timeout=30
) )
if response.status_code == 200: if response.status_code == 200:
@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries):
return metadata_text return metadata_text
def bulk_upsert(client, entities): async def bulk_upsert(client, entities):
documents = [] documents = []
metadata_texts = [] metadata_texts = []
entities_with_metadata = [] entities_with_metadata = []
@ -142,7 +147,7 @@ def bulk_upsert(client, entities):
).model_dump(mode="json") ).model_dump(mode="json")
) )
embeddings = get_embeddings(metadata_texts) embeddings = await get_embeddings(metadata_texts)
for doc, embedding, entity in zip(documents, embeddings, entities): for doc, embedding, entity in zip(documents, embeddings, entities):
if entity in entities_with_metadata: if entity in entities_with_metadata:
doc["embedding"] = embedding doc["embedding"] = embedding
@ -259,7 +264,7 @@ def list_all_entities(
) )
def search_entities( async def search_entities(
client, client,
q: str, q: str,
library_ids: List[int] = None, library_ids: List[int] = None,
@ -287,7 +292,7 @@ def search_entities(
filter_by_str = " && ".join(filter_by) if filter_by else "" filter_by_str = " && ".join(filter_by) if filter_by else ""
# Convert q to embedding using get_embeddings and take the first embedding # Convert q to embedding using get_embeddings and take the first embedding
embedding = get_embeddings([q])[0] embedding = (await get_embeddings([q]))[0]
common_search_params = { common_search_params = {
"collection": TYPESENSE_COLLECTION_NAME, "collection": TYPESENSE_COLLECTION_NAME,

View File

@ -15,7 +15,7 @@ from typing import List
from .schemas import MetadataSource, MetadataType from .schemas import MetadataSource, MetadataType
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from .config import get_database_path from .config import get_database_path, settings
class Base(DeclarativeBase): class Base(DeclarativeBase):
@ -75,16 +75,14 @@ class EntityModel(Base):
"FolderModel", back_populates="entities" "FolderModel", back_populates="entities"
) )
metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship( metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship(
"EntityMetadataModel", "EntityMetadataModel", lazy="joined", cascade="all, delete-orphan"
lazy="joined",
cascade="all, delete-orphan"
) )
tags: Mapped[List["TagModel"]] = relationship( tags: Mapped[List["TagModel"]] = relationship(
"TagModel", "TagModel",
secondary="entity_tags", secondary="entity_tags",
lazy="joined", lazy="joined",
cascade="all, delete", cascade="all, delete",
overlaps="entities" overlaps="entities",
) )
# 添加索引 # 添加索引
@ -160,30 +158,61 @@ def init_database():
"""Initialize the database.""" """Initialize the database."""
db_path = get_database_path() db_path = get_database_path()
engine = create_engine(f"sqlite:///{db_path}") engine = create_engine(f"sqlite:///{db_path}")
try: try:
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
print(f"Database initialized successfully at {db_path}") print(f"Database initialized successfully at {db_path}")
# Initialize default plugins # Initialize default plugins
Session = sessionmaker(bind=engine) Session = sessionmaker(bind=engine)
with Session() as session: with Session() as session:
initialize_default_plugins(session) default_plugins = initialize_default_plugins(session)
init_default_libraries(session, default_plugins)
return True return True
except OperationalError as e: except OperationalError as e:
print(f"Error initializing database: {e}") print(f"Error initializing database: {e}")
return False return False
def initialize_default_plugins(session): def initialize_default_plugins(session):
default_plugins = [ default_plugins = [
PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"), PluginModel(
PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"), name="builtin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"
),
PluginModel(
name="builtin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"
),
] ]
for plugin in default_plugins: for plugin in default_plugins:
existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first() existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first()
if not existing_plugin: if not existing_plugin:
session.add(plugin) session.add(plugin)
session.commit() session.commit()
return default_plugins
def init_default_libraries(session, default_plugins):
    """Ensure the default library exists and has the default plugins bound.

    Args:
        session: SQLAlchemy session used for all queries and inserts.
        default_plugins: plugin objects whose ``name`` identifies the stored
            plugin rows to bind.

    The library named ``settings.default_library`` is created when missing.
    Each default plugin found in the database is then bound to it, unless that
    binding already exists, so running initialization again does not add
    duplicate rows. All changes are committed once at the end.
    """
    default_libraries = [
        LibraryModel(name=settings.default_library),
    ]

    for library in default_libraries:
        stored_library = (
            session.query(LibraryModel).filter_by(name=library.name).first()
        )
        if stored_library is None:
            session.add(library)
            # Flush so the database assigns the primary key; the real id is
            # used below instead of assuming the default library has id 1.
            session.flush()
            stored_library = library

        for plugin in default_plugins:
            stored_plugin = (
                session.query(PluginModel).filter_by(name=plugin.name).first()
            )
            if stored_plugin is None:
                # Plugin row is not in the database; nothing to bind.
                continue

            existing_binding = (
                session.query(LibraryPluginModel)
                .filter_by(library_id=stored_library.id, plugin_id=stored_plugin.id)
                .first()
            )
            if existing_binding is None:
                session.add(
                    LibraryPluginModel(
                        library_id=stored_library.id, plugin_id=stored_plugin.id
                    )
                )

    session.commit()

View File

@ -0,0 +1,146 @@
import asyncio
from typing import List
from fastapi import APIRouter, HTTPException
import logging
import uvicorn
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from pydantic import BaseModel
from modelscope import snapshot_download
PLUGIN_NAME = "embedding"
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
# Module-level plugin state. init_plugin() fills in the configuration values
# and init_embedding_model() sets `model` and `device`.
enabled = False
model = None
num_dim = None
endpoint = None
model_name = None
device = None
use_modelscope = None
# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def init_embedding_model():
    """Load the embedding model and move it to the selected device.

    Device preference is CUDA, then MPS, then CPU. When ``use_modelscope`` is
    set, the model is first downloaded with ``snapshot_download`` and loaded
    from the returned directory; otherwise ``model_name`` is passed to
    ``SentenceTransformer`` directly. Sets the module globals ``model`` and
    ``device``.
    """
    global model, device, use_modelscope

    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)

    if use_modelscope:
        model_dir = snapshot_download(model_name)
        logger.info(f"Model downloaded from ModelScope to: {model_dir}")
    else:
        model_dir = model_name
        logger.info(f"Using model: {model_dir}")

    model = SentenceTransformer(model_dir, trust_remote_code=True)
    model.to(device)
    logger.info(f"Embedding model initialized on device: {device}")
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
    """Encode *input_texts* with the loaded model and return L2-normalized vectors.

    Rows whose norm is zero are divided by 1 instead, so they stay all-zero
    rather than producing a division by zero.
    """
    vectors = model.encode(input_texts, convert_to_tensor=True).cpu().numpy()

    # Per-row L2 norm, kept 2-D so it broadcasts against the rows.
    row_norms = np.linalg.norm(vectors, ord=2, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1

    return (vectors / row_norms).tolist()
# Request body of the embed endpoint: the list of texts to encode.
class EmbeddingRequest(BaseModel):
    input: List[str]
# Response body of the embed endpoint: the vectors produced by
# generate_embeddings(), each a list of floats.
class EmbeddingResponse(BaseModel):
    embeddings: List[List[float]]
# Health-check endpoint: always reports healthy, plus the module-level
# `enabled` flag set by init_plugin().
@router.get("/")
async def read_root():
    return {"healthy": True, "enabled": enabled}
# Embedding endpoint, registered both with and without the trailing slash.
# Returns an empty result for empty input; otherwise runs the blocking model
# call in the default executor. Any failure becomes an HTTP 500 whose detail
# carries the error message.
@router.post("", include_in_schema=False)
@router.post("/", response_model=EmbeddingResponse)
async def embed(request: EmbeddingRequest):
    try:
        if not request.input:
            return EmbeddingResponse(embeddings=[])

        # Run the embedding generation in a separate thread to avoid blocking.
        # get_running_loop() is the supported way to obtain the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None, generate_embeddings, request.input
        )
        return EmbeddingResponse(embeddings=embeddings)
    except Exception as e:
        # Keep the traceback in the server log and chain the cause so the
        # original error is not lost behind the HTTP 500.
        logger.exception("Error generating embeddings")
        raise HTTPException(
            status_code=500, detail=f"Error generating embeddings: {str(e)}"
        ) from e
def init_plugin(config):
    """Copy the embedding settings from *config* into module state and log them.

    Reads ``enabled``, ``num_dim``, ``endpoint``, ``model`` and
    ``use_modelscope`` from *config*. The model is loaded through
    ``init_embedding_model()`` only when the plugin is enabled.
    """
    global enabled, num_dim, endpoint, model_name, use_modelscope

    enabled = config.enabled
    num_dim = config.num_dim
    endpoint = config.endpoint
    model_name = config.model
    use_modelscope = config.use_modelscope

    if enabled:
        init_embedding_model()

    startup_report = (
        "Embedding plugin initialized",
        f"Enabled: {enabled}",
        f"Number of dimensions: {num_dim}",
        f"Endpoint: {endpoint}",
        f"Model: {model_name}",
        f"Use ModelScope: {use_modelscope}",
    )
    for line in startup_report:
        logger.info(line)
# Standalone entry point: parse command-line options, initialize the plugin,
# and serve only the embedding router with uvicorn.
if __name__ == "__main__":
    import argparse
    from fastapi import FastAPI

    parser = argparse.ArgumentParser(description="Embedding Plugin Configuration")
    parser.add_argument(
        "--num-dim", type=int, default=768, help="Number of embedding dimensions"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="jinaai/jina-embeddings-v2-base-zh",
        help="Embedding model name",
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    parser.add_argument(
        "--use-modelscope", action="store_true", help="Use ModelScope to download the model"
    )
    args = parser.parse_args()

    # Minimal stand-in exposing the attributes that init_plugin() reads.
    class Config:
        def __init__(self, args):
            # Always enabled when run as a standalone server.
            self.enabled = True
            self.num_dim = args.num_dim
            # Placeholder value; init_plugin() only stores and logs it.
            self.endpoint = "what ever"
            self.model = args.model
            self.use_modelscope = args.use_modelscope

    init_plugin(Config(args))
    app = FastAPI()
    app.include_router(router)
    uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -141,6 +141,7 @@ async def ocr(entity: Entity, request: Request):
} }
] ]
}, },
timeout=30,
) )
# Check if the patch request was successful # Check if the patch request was successful
@ -178,7 +179,7 @@ def init_plugin(config):
ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path'])) ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path']))
ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path'])) ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path']))
# Save the updated config to a temporary file # Save the updated config to a temporary file with strings wrapped in double quotes
temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml") temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml")
with open(temp_config_path, 'w') as f: with open(temp_config_path, 'w') as f:
yaml.safe_dump(ocr_config, f) yaml.safe_dump(ocr_config, f)

View File

@ -9,6 +9,20 @@ import logging
import uvicorn import uvicorn
import os import os
import io import io
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from modelscope import snapshot_download
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports of *filename*, without ``flash_attn`` for Florence-2.

    Delegates to ``get_imports``. For files whose name ends with
    ``modeling_florence2.py`` the ``flash_attn`` entry is dropped from the
    result. If that entry is not present, the list is returned unchanged
    instead of raising ``ValueError``.
    """
    imports = get_imports(filename)
    if str(filename).endswith("modeling_florence2.py") and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
PLUGIN_NAME = "vlm" PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容" PROMPT = "描述这张图片的内容"
@ -21,6 +35,10 @@ token = None
concurrency = None concurrency = None
semaphore = None semaphore = None
force_jpeg = None force_jpeg = None
use_local = None
florence_model = None
florence_processor = None
torch_dtype = None
# Configure logger # Configure logger
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -35,18 +53,18 @@ def image2base64(img_path):
with Image.open(img_path) as img: with Image.open(img_path) as img:
if force_jpeg: if force_jpeg:
# Convert image to RGB mode (removes alpha channel if present) # Convert image to RGB mode (removes alpha channel if present)
img = img.convert('RGB') img = img.convert("RGB")
# Save as JPEG in memory # Save as JPEG in memory
buffer = io.BytesIO() buffer = io.BytesIO()
img.save(buffer, format='JPEG') img.save(buffer, format="JPEG")
buffer.seek(0) buffer.seek(0)
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8') encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
else: else:
# Use original format # Use original format
buffer = io.BytesIO() buffer = io.BytesIO()
img.save(buffer, format=img.format) img.save(buffer, format=img.format)
buffer.seek(0) buffer.seek(0)
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8') encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
return encoded_string return encoded_string
except Exception as e: except Exception as e:
logger.error(f"Error processing image {img_path}: {str(e)}") logger.error(f"Error processing image {img_path}: {str(e)}")
@ -79,12 +97,57 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
async def predict( async def predict(
endpoint: str, modelname: str, img_path: str, token: Optional[str] = None endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
if use_local:
return await predict_local(img_path)
else:
return await predict_remote(endpoint, modelname, img_path, token)
async def predict_local(img_path: str) -> Optional[str]:
    """Caption the image at *img_path* with the locally loaded Florence model.

    Args:
        img_path: path of the image file to describe.

    Returns:
        The text stored under the ``<MORE_DETAILED_CAPTION>`` key of the
        post-processed result ("" when that key is missing), or ``None`` when
        any step raises; the error is logged.
    """
    # NOTE(review): nothing in this body is awaited, so the model call runs
    # synchronously inside the coroutine -- confirm this is acceptable.
    try:
        task_prompt = "<MORE_DETAILED_CAPTION>"

        # Open the image in a context manager so the file handle is released
        # once the processor has consumed the pixels.
        with Image.open(img_path) as image:
            inputs = florence_processor(
                text=task_prompt, images=image, return_tensors="pt"
            ).to(florence_model.device, torch_dtype)
            image_size = (image.width, image.height)

        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )

        generated_texts = florence_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )

        parsed_answer = florence_processor.post_process_generation(
            generated_texts[0],
            task=task_prompt,
            image_size=image_size,
        )
        return parsed_answer.get(task_prompt, "")
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {str(e)}")
        return None
async def predict_remote(
endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]: ) -> Optional[str]:
img_base64 = image2base64(img_path) img_base64 = image2base64(img_path)
if not img_base64: if not img_base64:
return None return None
mime_type = "image/jpeg" if force_jpeg else "image/jpeg" # Default to JPEG if force_jpeg is True mime_type = (
"image/jpeg" if force_jpeg else "image/jpeg"
) # Default to JPEG if force_jpeg is True
if not force_jpeg: if not force_jpeg:
# Only determine MIME type if not forcing JPEG # Only determine MIME type if not forcing JPEG
@ -167,9 +230,9 @@ async def vlm(entity: Entity, request: Request):
vlm_result = await predict(endpoint, modelname, entity.filepath, token=token) vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)
print(vlm_result) logger.info(vlm_result)
if not vlm_result: if not vlm_result:
print(f"No VLM result found for file: {entity.filepath}") logger.info(f"No VLM result found for file: {entity.filepath}")
return {metadata_field_name: "{}"} return {metadata_field_name: "{}"}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@ -199,14 +262,55 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config): def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore, force_jpeg global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
modelname = config.modelname modelname = config.modelname
endpoint = config.endpoint endpoint = config.endpoint
token = config.token token = config.token
concurrency = config.concurrency concurrency = config.concurrency
force_jpeg = config.force_jpeg force_jpeg = config.force_jpeg
use_local = config.use_local
use_modelscope = config.use_modelscope
semaphore = asyncio.Semaphore(concurrency) semaphore = asyncio.Semaphore(concurrency)
if use_local:
# 检测可用的设备
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
torch_dtype = (
torch.float32
if (
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
)
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
else torch.float16
)
logger.info(f"Using device: {device}")
if use_modelscope:
model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
else:
model_dir = "microsoft/Florence-2-base-ft"
logger.info(f"Using model: {model_dir}")
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
florence_model = AutoModelForCausalLM.from_pretrained(
model_dir,
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device)
florence_processor = AutoProcessor.from_pretrained(
model_dir, trust_remote_code=True
)
logger.info("Florence model and processor initialized")
# Print the parameters # Print the parameters
logger.info("VLM plugin initialized") logger.info("VLM plugin initialized")
logger.info(f"Model Name: {modelname}") logger.info(f"Model Name: {modelname}")
@ -214,6 +318,8 @@ def init_plugin(config):
logger.info(f"Token: {token}") logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}") logger.info(f"Concurrency: {concurrency}")
logger.info(f"Force JPEG: {force_jpeg}") logger.info(f"Force JPEG: {force_jpeg}")
logger.info(f"Use Local: {use_local}")
logger.info(f"Use ModelScope: {use_modelscope}")
if __name__ == "__main__": if __name__ == "__main__":
@ -232,6 +338,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on" "--port", type=int, default=8000, help="Port to run the server on"
) )
parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,4 +1,11 @@
from pydantic import BaseModel, ConfigDict, DirectoryPath, HttpUrl, Field from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
HttpUrl,
Field,
model_validator,
)
from typing import List, Optional, Any, Dict from typing import List, Optional, Any, Dict
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@ -79,7 +86,18 @@ class NewPluginParam(BaseModel):
class NewLibraryPluginParam(BaseModel): class NewLibraryPluginParam(BaseModel):
plugin_id: int plugin_id: Optional[int] = None
plugin_name: Optional[str] = None
@model_validator(mode="after")
def check_either_id_or_name(self):
    """Require exactly one of ``plugin_id`` / ``plugin_name``.

    Raises:
        ValueError: when both fields are falsy, or when both are set
            (neither is ``None``).
    """
    if not self.plugin_id and not self.plugin_name:
        raise ValueError("Either plugin_id or plugin_name must be provided")

    if not (self.plugin_id is None or self.plugin_name is None):
        raise ValueError("Only one of plugin_id or plugin_name should be provided")

    return self
class Folder(BaseModel): class Folder(BaseModel):
@ -214,15 +232,18 @@ class FacetCount(BaseModel):
highlighted: str highlighted: str
value: str value: str
class FacetStats(BaseModel): class FacetStats(BaseModel):
total_values: int total_values: int
class Facet(BaseModel): class Facet(BaseModel):
counts: List[FacetCount] counts: List[FacetCount]
field_name: str field_name: str
sampled: bool sampled: bool
stats: FacetStats stats: FacetStats
class TextMatchInfo(BaseModel): class TextMatchInfo(BaseModel):
best_field_score: str best_field_score: str
best_field_weight: int best_field_weight: int
@ -232,9 +253,11 @@ class TextMatchInfo(BaseModel):
tokens_matched: int tokens_matched: int
typo_prefix_score: int typo_prefix_score: int
class HybridSearchInfo(BaseModel): class HybridSearchInfo(BaseModel):
rank_fusion_score: float rank_fusion_score: float
class SearchHit(BaseModel): class SearchHit(BaseModel):
document: EntitySearchResult document: EntitySearchResult
highlight: Dict[str, Any] = {} highlight: Dict[str, Any] = {}
@ -243,12 +266,14 @@ class SearchHit(BaseModel):
text_match: Optional[int] = None text_match: Optional[int] = None
text_match_info: Optional[TextMatchInfo] = None text_match_info: Optional[TextMatchInfo] = None
class RequestParams(BaseModel): class RequestParams(BaseModel):
collection_name: str collection_name: str
first_q: str first_q: str
per_page: int per_page: int
q: str q: str
class SearchResult(BaseModel): class SearchResult(BaseModel):
facet_counts: List[Facet] facet_counts: List[Facet]
found: int found: int
@ -257,4 +282,4 @@ class SearchResult(BaseModel):
page: int page: int
request_params: RequestParams request_params: RequestParams
search_cutoff: bool search_cutoff: bool
search_time_ms: int search_time_ms: int

View File

@ -23,6 +23,7 @@ import typesense
from .config import get_database_path, settings from .config import get_database_path, settings
from memos.plugins.vlm import main as vlm_main from memos.plugins.vlm import main as vlm_main
from memos.plugins.ocr import main as ocr_main from memos.plugins.ocr import main as ocr_main
from memos.plugins.embedding import main as embedding_main
from . import crud from . import crud
from . import indexing from . import indexing
from .schemas import ( from .schemas import (
@ -84,18 +85,6 @@ app.mount(
"/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True) "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
) )
# Add VLM plugin router
if settings.vlm.enabled:
print("VLM plugin is enabled")
vlm_main.init_plugin(settings.vlm)
app.include_router(vlm_main.router, prefix="/plugins/vlm")
# Add OCR plugin router
if settings.ocr.enabled:
print("OCR plugin is enabled")
ocr_main.init_plugin(settings.ocr)
app.include_router(ocr_main.router, prefix="/plugins/ocr")
@app.get("/favicon.png", response_class=FileResponse) @app.get("/favicon.png", response_class=FileResponse)
async def favicon_png(): async def favicon_png():
@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense(
) )
try: try:
indexing.bulk_upsert(client, entities) await indexing.bulk_upsert(client, entities)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder(
@app.get("/search", response_model=SearchResult, tags=["search"]) @app.get("/search", response_model=SearchResult, tags=["search"])
async def search_entities( async def search_entities_route(
q: str, q: str,
library_ids: str = Query(None, description="Comma-separated list of library IDs"), library_ids: str = Query(None, description="Comma-separated list of library IDs"),
folder_ids: str = Query(None, description="Comma-separated list of folder IDs"), folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
@ -502,7 +491,7 @@ async def search_entities(
[date.strip() for date in created_dates.split(",")] if created_dates else None [date.strip() for date in created_dates.split(",")] if created_dates else None
) )
try: try:
return indexing.search_entities( return await indexing.search_entities(
client, client,
q, q,
library_ids, library_ids,
@ -613,12 +602,29 @@ def add_library_plugin(
library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db) library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db)
): ):
library = crud.get_library_by_id(library_id, db) library = crud.get_library_by_id(library_id, db)
if any(plugin.id == new_plugin.plugin_id for plugin in library.plugins): if library is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
)
plugin = None
if new_plugin.plugin_id is not None:
plugin = crud.get_plugin_by_id(new_plugin.plugin_id, db)
elif new_plugin.plugin_name is not None:
plugin = crud.get_plugin_by_name(new_plugin.plugin_name, db)
if plugin is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found"
)
if any(p.id == plugin.id for p in library.plugins):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Plugin already exists in the library", detail="Plugin already exists in the library",
) )
crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db)
crud.add_plugin_to_library(library_id, plugin.id, db)
@app.delete( @app.delete(
@ -727,6 +733,24 @@ def run_server():
print(f"VLM plugin enabled: {settings.vlm}") print(f"VLM plugin enabled: {settings.vlm}")
print(f"OCR plugin enabled: {settings.ocr}") print(f"OCR plugin enabled: {settings.ocr}")
# Add VLM plugin router
if settings.vlm.enabled:
print("VLM plugin is enabled")
vlm_main.init_plugin(settings.vlm)
app.include_router(vlm_main.router, prefix="/plugins/vlm")
# Add OCR plugin router
if settings.ocr.enabled:
print("OCR plugin is enabled")
ocr_main.init_plugin(settings.ocr)
app.include_router(ocr_main.router, prefix="/plugins/ocr")
# Add Embedding plugin router
if settings.embedding.enabled:
print("Embedding plugin is enabled")
embedding_main.init_plugin(settings.embedding)
app.include_router(embedding_main.router, prefix="/plugins/embed")
uvicorn.run( uvicorn.run(
"memos.server:app", "memos.server:app",
host=settings.server_host, # Use the new server_host setting host=settings.server_host, # Use the new server_host setting

View File

@ -1,7 +1,6 @@
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer
import numpy as np import numpy as np
import httpx import httpx
import torch import torch
@ -25,31 +24,15 @@ elif torch.backends.mps.is_available():
else: else:
device = torch.device("cpu") device = torch.device("cpu")
torch_dtype = "auto" torch_dtype = (
torch.float32
if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
else torch.float16
)
print(f"Using device: {device}") print(f"Using device: {device}")
def init_embedding_model(model_name: str = "jinaai/jina-embeddings-v2-base-zh"):
    """Load the sentence-embedding model and move it to the selected device.

    Args:
        model_name: Model identifier passed to ``SentenceTransformer``.
            Defaults to the model that was previously hard-coded here.

    Returns:
        The loaded ``SentenceTransformer`` instance, moved to the
        module-level ``device``.
    """
    # NOTE(review): trust_remote_code=True is kept from the original call;
    # it allows code shipped with the model to run while loading — confirm
    # this is acceptable for any non-default model_name.
    model = SentenceTransformer(model_name, trust_remote_code=True)
    model.to(device)
    return model
embedding_model = init_embedding_model() # Initialize the model
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
    """Encode texts into L2-normalized embedding vectors.

    Args:
        input_texts: Texts to embed.

    Returns:
        One embedding per input text, as plain lists of floats. Each
        vector is divided by its L2 norm; vectors whose norm is zero are
        returned unchanged. An empty input yields an empty list.
    """
    # Nothing to encode: skip the model call and the row-wise norm, which
    # expects a 2-D array.
    if not input_texts:
        return []

    embeddings = embedding_model.encode(input_texts, convert_to_tensor=True)
    embeddings = embeddings.cpu().numpy()

    # Normalize each row; a zero norm is replaced by 1 to avoid dividing
    # by zero.
    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
    norms[norms == 0] = 1
    embeddings = embeddings / norms
    return embeddings.tolist()
# Add a configuration option to choose the model # Add a configuration option to choose the model
parser = argparse.ArgumentParser(description="Run the server with specified model") parser = argparse.ArgumentParser(description="Run the server with specified model")
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model") parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
@ -63,8 +46,11 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
if use_florence_model: if use_florence_model:
# Load Florence-2 model # Load Florence-2 model
florence_model = AutoModelForCausalLM.from_pretrained( florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True "microsoft/Florence-2-base-ft",
).to(device) torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device, torch_dtype)
florence_processor = AutoProcessor.from_pretrained( florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True "microsoft/Florence-2-base-ft", trust_remote_code=True
) )
@ -74,7 +60,7 @@ else:
"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4", "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map="auto", device_map="auto",
).to(device) ).to(device, torch_dtype)
qwen2vl_processor = AutoProcessor.from_pretrained( qwen2vl_processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4" "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"
) )
@ -139,9 +125,9 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
text = qwen2vl_processor.apply_chat_template( text = qwen2vl_processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
image_inputs, video_inputs = process_vision_info(messages) image_inputs, video_inputs = process_vision_info(messages)
inputs = qwen2vl_processor( inputs = qwen2vl_processor(
text=[text], text=[text],
images=image_inputs, images=image_inputs,
@ -152,12 +138,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
inputs = inputs.to(device) inputs = inputs.to(device)
generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512)) generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
generated_ids_trimmed = [ generated_ids_trimmed = [
out_ids[len(in_ids) :] out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids) for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
] ]
output_text = qwen2vl_processor.batch_decode( output_text = qwen2vl_processor.batch_decode(
generated_ids_trimmed, generated_ids_trimmed,
skip_special_tokens=True, skip_special_tokens=True,
@ -170,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
app = FastAPI() app = FastAPI()
class EmbeddingRequest(BaseModel):
    """Request body accepted by the ``/api/embed`` endpoint."""

    # Texts to embed; the field name is part of the request schema.
    input: List[str]
class EmbeddingResponse(BaseModel):
    """Response body returned by the ``/api/embed`` endpoint."""

    # One list of floats per embedded input text.
    embeddings: List[List[float]]
@app.post("/api/embed", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Return embeddings for the texts in the request body.

    Args:
        request: Body carrying the list of texts under ``input``.

    Returns:
        An ``EmbeddingResponse``; its ``embeddings`` list is empty when
        ``request.input`` is empty.

    Raises:
        HTTPException: With status 500 if generating the embeddings or
            building the response fails.
    """
    # An empty input needs no model call.
    if not request.input:
        return EmbeddingResponse(embeddings=[])

    try:
        embeddings = generate_embeddings(request.input)
        return EmbeddingResponse(embeddings=embeddings)
    except Exception as e:
        # Chain the original error so the traceback keeps the root cause.
        raise HTTPException(
            status_code=500, detail=f"Error generating embeddings: {str(e)}"
        ) from e
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: List[Dict[str, Any]] messages: List[Dict[str, Any]]
@ -275,10 +239,14 @@ async def chat_completions(request: ChatCompletionRequest):
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
parser = argparse.ArgumentParser(description="Run the server with specified model and port") parser = argparse.ArgumentParser(
description="Run the server with specified model and port"
)
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model") parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model") parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
args = parser.parse_args() args = parser.parse_args()
if args.florence and args.qwen2vl: if args.florence and args.qwen2vl:

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "memos" name = "memos"
version = "0.5.0" version = "0.6.10"
description = "A package for memos" description = "A package for memos"
readme = "README.md" readme = "README.md"
authors = [{ name = "arkohut" }] authors = [{ name = "arkohut" }]
@ -36,6 +36,13 @@ dependencies = [
"pyobjc; sys_platform == 'darwin'", "pyobjc; sys_platform == 'darwin'",
"pyobjc-core; sys_platform == 'darwin'", "pyobjc-core; sys_platform == 'darwin'",
"pyobjc-framework-Quartz; sys_platform == 'darwin'", "pyobjc-framework-Quartz; sys_platform == 'darwin'",
"sentence-transformers",
"torch",
"numpy",
"timm",
"einops",
"modelscope",
"mss",
] ]
[project.urls] [project.urls]
@ -50,3 +57,4 @@ include = ["memos*", "screen_recorder*"]
[tool.setuptools.package-data] [tool.setuptools.package-data]
"*" = ["static/**/*"] "*" = ["static/**/*"]
"memos.plugins.ocr" = ["*.yaml", "models/*.onnx"]

View File

@ -4,11 +4,11 @@ import time
import logging import logging
import platform import platform
import subprocess import subprocess
from PIL import Image, ImageGrab from PIL import Image
import imagehash import imagehash
from memos.utils import write_image_metadata from memos.utils import write_image_metadata
from screeninfo import get_monitors
import ctypes import ctypes
from mss import mss
if platform.system() == "Windows": if platform.system() == "Windows":
import win32gui import win32gui
@ -173,57 +173,49 @@ def take_screenshot_windows(
app_name, app_name,
window_title, window_title,
): ):
for monitor in get_monitors(): with mss() as sct:
safe_monitor_name = "".join( for i, monitor in enumerate(sct.monitors[1:], 1): # Skip the first monitor (entire screen)
c for c in monitor.name if c.isalnum() or c in ("_", "-") safe_monitor_name = f"monitor_{i}"
) logging.info(f"Processing monitor: {safe_monitor_name}")
logging.info(f"Processing monitor: {safe_monitor_name}")
webp_filename = os.path.join( webp_filename = os.path.join(
base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp" base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp"
)
img = ImageGrab.grab(
bbox=(
monitor.x,
monitor.y,
monitor.x + monitor.width,
monitor.y + monitor.height,
) )
)
img = img.convert("RGB")
current_hash = str(imagehash.phash(img))
if ( img = sct.grab(monitor)
safe_monitor_name in previous_hashes img = Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
and imagehash.hex_to_hash(current_hash) current_hash = str(imagehash.phash(img))
- imagehash.hex_to_hash(previous_hashes[safe_monitor_name])
< threshold if (
): safe_monitor_name in previous_hashes
logging.info( and imagehash.hex_to_hash(current_hash)
f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping." - imagehash.hex_to_hash(previous_hashes[safe_monitor_name])
< threshold
):
logging.info(
f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping."
)
yield safe_monitor_name, None, "Skipped (similar to previous)"
continue
previous_hashes[safe_monitor_name] = current_hash
screen_sequences[safe_monitor_name] = (
screen_sequences.get(safe_monitor_name, 0) + 1
) )
yield safe_monitor_name, None, "Skipped (similar to previous)"
continue
previous_hashes[safe_monitor_name] = current_hash metadata = {
screen_sequences[safe_monitor_name] = ( "timestamp": timestamp,
screen_sequences.get(safe_monitor_name, 0) + 1 "active_app": app_name,
) "active_window": window_title,
"screen_name": safe_monitor_name,
"sequence": screen_sequences[safe_monitor_name],
}
metadata = { img.save(webp_filename, format="WebP", quality=85)
"timestamp": timestamp, write_image_metadata(webp_filename, metadata)
"active_app": app_name, save_screen_sequences(base_dir, screen_sequences, date)
"active_window": window_title,
"screen_name": safe_monitor_name,
"sequence": screen_sequences[safe_monitor_name],
}
img.save(webp_filename, format="WebP", quality=85) yield safe_monitor_name, webp_filename, "Saved"
write_image_metadata(webp_filename, metadata)
save_screen_sequences(base_dir, screen_sequences, date)
yield safe_monitor_name, webp_filename, "Saved"
def take_screenshot( def take_screenshot(

View File

@ -9,6 +9,8 @@ from screen_recorder.common import (
take_screenshot, take_screenshot,
is_screen_locked, is_screen_locked,
) )
from pathlib import Path
from memos.config import settings
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@ -51,12 +53,12 @@ def main():
"--threshold", type=int, default=4, help="Threshold for image similarity" "--threshold", type=int, default=4, help="Threshold for image similarity"
) )
parser.add_argument( parser.add_argument(
"--base-dir", type=str, default="~/tmp", help="Base directory for screenshots" "--base-dir", type=str, help="Base directory for screenshots"
) )
parser.add_argument("--once", action="store_true", help="Run once and exit") parser.add_argument("--once", action="store_true", help="Run once and exit")
args = parser.parse_args() args = parser.parse_args()
base_dir = os.path.expanduser(args.base_dir) base_dir = os.path.expanduser(args.base_dir) if args.base_dir else settings.screenshots_dir
previous_hashes = load_previous_hashes(base_dir) previous_hashes = load_previous_hashes(base_dir)
if args.once: if args.once:

View File

@ -112,7 +112,7 @@
id={`library-${library.id}`} id={`library-${library.id}`}
bind:checked={selectedLibraries[library.id]} bind:checked={selectedLibraries[library.id]}
/> />
<Label for={`library-${library.id}`} class="flex items-center text-sm">{library.name}</Label> <Label for={`library-${library.id}`} class="flex items-center text-sm">{library.name}#{library.id}</Label>
</div> </div>
{/each} {/each}
</div> </div>