This commit is contained in:
arkohut 2024-09-11 23:20:49 +08:00
commit 4a8d8fa54a
15 changed files with 656 additions and 176 deletions

View File

@ -1,3 +1,69 @@
# Memos
A project to index everything to make it like another memory.
A project to index everything to make it like another memory. The project contains two parts:
1. `screen recorder`: which takes screenshots every 5 seconds and saves them to `~/.memos/screenshots` by default.
2. `memos server`: a web service that can index the screenshots and other files, providing a web interface to search the records.
There is a product called [Rewind](https://www.rewind.ai/) that is similar to memos, but memos aims to give you control over all your data.
## Install
### Install Typesense
```bash
export TYPESENSE_API_KEY=xyz
mkdir "$(pwd)"/typesense-data
docker run -d -p 8108:8108 \
-v"$(pwd)"/typesense-data:/data typesense/typesense:27.0 \
--add-host=host.docker.internal:host-gateway \
--data-dir /data \
--api-key=$TYPESENSE_API_KEY \
--enable-cors
```
### Install Memos
```bash
pip install memos
```
## How to use
To use memos, you need to initialize it first. Make sure you have started `typesense`.
### 1. Initialize Memos
```bash
memos init
```
This will create a folder `~/.memos` and put the config file there.
### 2. Start Screen Recorder
```bash
memos-record
```
This will start a screen recorder, which will take screenshots every 5 seconds and save it at `~/.memos/screenshots` by default.
### 3. Start Memos Server
```bash
memos serve
```
This will start a web server, and you can access the web interface at `http://localhost:8080`.
The default username and password is `admin` and `changeme`.
### Index the screenshots
```bash
memos scan
memos index
```
Refresh the page, and do some search.

View File

@ -1,6 +1,5 @@
import asyncio
import os
import time
import logging
from datetime import datetime, timezone
from pathlib import Path
@ -549,6 +548,9 @@ def index(
library_id: int,
folders: List[int] = typer.Option(None, "--folder", "-f"),
force: bool = typer.Option(False, "--force", help="Force update all indexes"),
batchsize: int = typer.Option(
4, "--batchsize", "-bs", help="Number of entities to index in a batch"
),
):
print(f"Indexing library {library_id}")
@ -608,9 +610,8 @@ def index(
pbar.refresh()
# Index each entity
batch_size = settings.batchsize
for i in range(0, len(entities), batch_size):
batch = entities[i : i + batch_size]
for i in range(0, len(entities), batchsize):
batch = entities[i : i + batchsize]
to_index = []
for entity in batch:
@ -722,18 +723,22 @@ def create(name: str, webhook_url: str, description: str = ""):
@plugin_app.command("bind")
def bind(
library_id: int = typer.Option(..., "--lib", help="ID of the library"),
plugin_id: int = typer.Option(..., "--plugin", help="ID of the plugin"),
plugin: str = typer.Option(..., "--plugin", help="ID or name of the plugin"),
):
try:
plugin_id = int(plugin)
plugin_param = {"plugin_id": plugin_id}
except ValueError:
plugin_param = {"plugin_name": plugin}
response = httpx.post(
f"{BASE_URL}/libraries/{library_id}/plugins",
json={"plugin_id": plugin_id},
json=plugin_param,
)
if 200 <= response.status_code < 300:
if response.status_code == 204:
print("Plugin bound to library successfully")
else:
print(
f"Failed to bind plugin to library: {response.status_code} - {response.text}"
)
print(f"Failed to bind plugin to library: {response.status_code} - {response.text}")
@plugin_app.command("unbind")
@ -763,5 +768,85 @@ def init():
print("Initialization failed. Please check the error messages above.")
@app.command("scan")
def scan_default_library(force: bool = False):
"""
Scan the screenshots directory and add it to the library if empty.
"""
# Get the default library
response = httpx.get(f"{BASE_URL}/libraries")
if response.status_code != 200:
print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
return
libraries = response.json()
default_library = next(
(lib for lib in libraries if lib["name"] == settings.default_library), None
)
if not default_library:
# Create the default library if it doesn't exist
response = httpx.post(
f"{BASE_URL}/libraries",
json={"name": settings.default_library, "folders": []},
)
if response.status_code != 200:
print(
f"Failed to create default library: {response.status_code} - {response.text}"
)
return
default_library = response.json()
for plugin in settings.default_plugins:
bind(default_library["id"], plugin)
# Check if the library is empty
if not default_library["folders"]:
# Add the screenshots directory to the library
screenshots_dir = Path(settings.screenshots_dir).resolve()
response = httpx.post(
f"{BASE_URL}/libraries/{default_library['id']}/folders",
json={"folders": [str(screenshots_dir)]},
)
if response.status_code != 200:
print(
f"Failed to add screenshots directory: {response.status_code} - {response.text}"
)
return
print(f"Added screenshots directory: {screenshots_dir}")
# Scan the library
print(f"Scanning library: {default_library['name']}")
scan(default_library["id"], plugins=None, folders=None, force=force)
@app.command("index")
def index_default_library(
batchsize: int = typer.Option(
4, "--batchsize", "-bs", help="Number of entities to index in a batch"
),
force: bool = typer.Option(False, "--force", help="Force update all indexes"),
):
"""
Index the default library for memos.
"""
# Get the default library
response = httpx.get(f"{BASE_URL}/libraries")
if response.status_code != 200:
print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
return
libraries = response.json()
default_library = next(
(lib for lib in libraries if lib["name"] == settings.default_library), None
)
if not default_library:
print("Default library does not exist.")
return
index(default_library["id"], force=force, folders=None, batchsize=batchsize)
if __name__ == "__main__":
app()

View File

@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Tuple, Type
from typing import Tuple, Type, List
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
@ -17,24 +17,28 @@ class VLMSettings(BaseModel):
modelname: str = "moondream"
endpoint: str = "http://localhost:11434"
token: str = ""
concurrency: int = 4
concurrency: int = 1
force_jpeg: bool = False
use_local: bool = True
use_modelscope: bool = False
class OCRSettings(BaseModel):
enabled: bool = True
endpoint: str = "http://localhost:5555/predict"
token: str = ""
concurrency: int = 4
concurrency: int = 1
use_local: bool = True
use_gpu: bool = False
force_jpeg: bool = False
class EmbeddingSettings(BaseModel):
enabled: bool = True
num_dim: int = 768
ollama_endpoint: str = "http://localhost:11434"
ollama_model: str = "nextfire/paraphrase-multilingual-minilm"
endpoint: str = "http://localhost:11434/api/embed"
model: str = "jinaai/jina-embeddings-v2-base-zh"
use_modelscope: bool = False
class Settings(BaseSettings):
@ -46,6 +50,9 @@ class Settings(BaseSettings):
base_dir: str = str(Path.home() / ".memos")
database_path: str = os.path.join(base_dir, "database.db")
default_library: str = "screenshots"
screenshots_dir: str = os.path.join(base_dir, "screenshots")
typesense_host: str = "localhost"
typesense_port: str = "8108"
typesense_protocol: str = "http"
@ -66,11 +73,13 @@ class Settings(BaseSettings):
# Embedding settings
embedding: EmbeddingSettings = EmbeddingSettings()
batchsize: int = 4
batchsize: int = 1
auth_username: str = "admin"
auth_password: SecretStr = SecretStr("changeme")
default_plugins: List[str] = ["builtin_vlm", "builtin_ocr"]
@classmethod
def settings_customise_sources(
cls,
@ -93,6 +102,19 @@ def dict_representer(dumper, data):
yaml.add_representer(OrderedDict, dict_representer)
# Custom representer for SecretStr
def secret_str_representer(dumper, data):
return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value())
# Custom constructor for SecretStr
def secret_str_constructor(loader, node):
value = loader.construct_scalar(node)
return SecretStr(value)
# Register the representer and constructor only for specific fields
yaml.add_representer(SecretStr, secret_str_representer)
def create_default_config():
config_path = Path.home() / ".memos" / "config.yaml"
if not config_path.exists():

View File

@ -46,14 +46,19 @@ def parse_date_fields(entity):
}
def get_embeddings(texts: List[str]) -> List[List[float]]:
async def get_embeddings(texts: List[str]) -> List[List[float]]:
print(f"Getting embeddings for {len(texts)} texts")
ollama_endpoint = settings.embedding.ollama_endpoint
ollama_model = settings.embedding.ollama_model
with httpx.Client() as client:
response = client.post(
f"{ollama_endpoint}/api/embed",
json={"model": ollama_model, "input": texts},
if settings.embedding.enabled:
endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
else:
endpoint = settings.embedding.endpoint
model = settings.embedding.model
async with httpx.AsyncClient() as client:
response = await client.post(
endpoint,
json={"model": model, "input": texts},
timeout=30
)
if response.status_code == 200:
@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries):
return metadata_text
def bulk_upsert(client, entities):
async def bulk_upsert(client, entities):
documents = []
metadata_texts = []
entities_with_metadata = []
@ -142,7 +147,7 @@ def bulk_upsert(client, entities):
).model_dump(mode="json")
)
embeddings = get_embeddings(metadata_texts)
embeddings = await get_embeddings(metadata_texts)
for doc, embedding, entity in zip(documents, embeddings, entities):
if entity in entities_with_metadata:
doc["embedding"] = embedding
@ -259,7 +264,7 @@ def list_all_entities(
)
def search_entities(
async def search_entities(
client,
q: str,
library_ids: List[int] = None,
@ -287,7 +292,7 @@ def search_entities(
filter_by_str = " && ".join(filter_by) if filter_by else ""
# Convert q to embedding using get_embeddings and take the first embedding
embedding = get_embeddings([q])[0]
embedding = (await get_embeddings([q]))[0]
common_search_params = {
"collection": TYPESENSE_COLLECTION_NAME,

View File

@ -15,7 +15,7 @@ from typing import List
from .schemas import MetadataSource, MetadataType
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from .config import get_database_path
from .config import get_database_path, settings
class Base(DeclarativeBase):
@ -75,16 +75,14 @@ class EntityModel(Base):
"FolderModel", back_populates="entities"
)
metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship(
"EntityMetadataModel",
lazy="joined",
cascade="all, delete-orphan"
"EntityMetadataModel", lazy="joined", cascade="all, delete-orphan"
)
tags: Mapped[List["TagModel"]] = relationship(
"TagModel",
secondary="entity_tags",
"TagModel",
secondary="entity_tags",
lazy="joined",
cascade="all, delete",
overlaps="entities"
overlaps="entities",
)
# 添加索引
@ -160,30 +158,61 @@ def init_database():
"""Initialize the database."""
db_path = get_database_path()
engine = create_engine(f"sqlite:///{db_path}")
try:
Base.metadata.create_all(engine)
print(f"Database initialized successfully at {db_path}")
# Initialize default plugins
Session = sessionmaker(bind=engine)
with Session() as session:
initialize_default_plugins(session)
default_plugins = initialize_default_plugins(session)
init_default_libraries(session, default_plugins)
return True
except OperationalError as e:
print(f"Error initializing database: {e}")
return False
def initialize_default_plugins(session):
default_plugins = [
PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"),
PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"),
PluginModel(
name="builtin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"
),
PluginModel(
name="builtin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"
),
]
for plugin in default_plugins:
existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first()
if not existing_plugin:
session.add(plugin)
session.commit()
session.commit()
return default_plugins
def init_default_libraries(session, default_plugins):
default_libraries = [
LibraryModel(name=settings.default_library),
]
for library in default_libraries:
existing_library = (
session.query(LibraryModel).filter_by(name=library.name).first()
)
if not existing_library:
session.add(library)
for plugin in default_plugins:
bind_response = session.query(PluginModel).filter_by(name=plugin.name).first()
if bind_response:
library_plugin = LibraryPluginModel(
library_id=1, plugin_id=bind_response.id
) # Assuming library_id=1 for default libraries
session.add(library_plugin)
session.commit()

View File

@ -0,0 +1,146 @@
import asyncio
from typing import List
from fastapi import APIRouter, HTTPException
import logging
import uvicorn
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from pydantic import BaseModel
from modelscope import snapshot_download
PLUGIN_NAME = "embedding"
router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})
# Global variables
enabled = False
model = None
num_dim = None
endpoint = None
model_name = None
device = None
use_modelscope = None
# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def init_embedding_model():
global model, device, use_modelscope
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
if use_modelscope:
model_dir = snapshot_download(model_name)
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
else:
model_dir = model_name
logger.info(f"Using model: {model_dir}")
model = SentenceTransformer(model_dir, trust_remote_code=True)
model.to(device)
logger.info(f"Embedding model initialized on device: {device}")
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
embeddings = model.encode(input_texts, convert_to_tensor=True)
embeddings = embeddings.cpu().numpy()
# Normalize embeddings
norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
return embeddings.tolist()
class EmbeddingRequest(BaseModel):
input: List[str]
class EmbeddingResponse(BaseModel):
embeddings: List[List[float]]
@router.get("/")
async def read_root():
return {"healthy": True, "enabled": enabled}
@router.post("", include_in_schema=False)
@router.post("/", response_model=EmbeddingResponse)
async def embed(request: EmbeddingRequest):
try:
if not request.input:
return EmbeddingResponse(embeddings=[])
# Run the embedding generation in a separate thread to avoid blocking
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(None, generate_embeddings, request.input)
return EmbeddingResponse(embeddings=embeddings)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error generating embeddings: {str(e)}"
)
def init_plugin(config):
global enabled, num_dim, endpoint, model_name, use_modelscope
enabled = config.enabled
num_dim = config.num_dim
endpoint = config.endpoint
model_name = config.model
use_modelscope = config.use_modelscope
if enabled:
init_embedding_model()
logger.info("Embedding plugin initialized")
logger.info(f"Enabled: {enabled}")
logger.info(f"Number of dimensions: {num_dim}")
logger.info(f"Endpoint: {endpoint}")
logger.info(f"Model: {model_name}")
logger.info(f"Use ModelScope: {use_modelscope}")
if __name__ == "__main__":
import argparse
from fastapi import FastAPI
parser = argparse.ArgumentParser(description="Embedding Plugin Configuration")
parser.add_argument(
"--num-dim", type=int, default=768, help="Number of embedding dimensions"
)
parser.add_argument(
"--model",
type=str,
default="jinaai/jina-embeddings-v2-base-zh",
help="Embedding model name",
)
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
parser.add_argument(
"--use-modelscope", action="store_true", help="Use ModelScope to download the model"
)
args = parser.parse_args()
class Config:
def __init__(self, args):
self.enabled = True
self.num_dim = args.num_dim
self.endpoint = "what ever"
self.model = args.model
self.use_modelscope = args.use_modelscope
init_plugin(Config(args))
app = FastAPI()
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -141,6 +141,7 @@ async def ocr(entity: Entity, request: Request):
}
]
},
timeout=30,
)
# Check if the patch request was successful
@ -178,7 +179,7 @@ def init_plugin(config):
ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path']))
ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path']))
# Save the updated config to a temporary file
# Save the updated config to a temporary file with strings wrapped in double quotes
temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml")
with open(temp_config_path, 'w') as f:
yaml.safe_dump(ocr_config, f)

View File

@ -9,6 +9,20 @@ import logging
import uvicorn
import os
import io
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from modelscope import snapshot_download
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
imports.remove("flash_attn")
return imports
PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"
@ -21,6 +35,10 @@ token = None
concurrency = None
semaphore = None
force_jpeg = None
use_local = None
florence_model = None
florence_processor = None
torch_dtype = None
# Configure logger
logging.basicConfig(level=logging.INFO)
@ -35,18 +53,18 @@ def image2base64(img_path):
with Image.open(img_path) as img:
if force_jpeg:
# Convert image to RGB mode (removes alpha channel if present)
img = img.convert('RGB')
img = img.convert("RGB")
# Save as JPEG in memory
buffer = io.BytesIO()
img.save(buffer, format='JPEG')
img.save(buffer, format="JPEG")
buffer.seek(0)
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
# Use original format
buffer = io.BytesIO()
img.save(buffer, format=img.format)
buffer.seek(0)
encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
return encoded_string
except Exception as e:
logger.error(f"Error processing image {img_path}: {str(e)}")
@ -79,12 +97,57 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N
async def predict(
endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
if use_local:
return await predict_local(img_path)
else:
return await predict_remote(endpoint, modelname, img_path, token)
async def predict_local(img_path: str) -> Optional[str]:
try:
image = Image.open(img_path)
task_prompt = "<MORE_DETAILED_CAPTION>"
prompt = task_prompt + ""
inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
florence_model.device, torch_dtype
)
generated_ids = florence_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
do_sample=False,
num_beams=3,
)
generated_texts = florence_processor.batch_decode(
generated_ids, skip_special_tokens=False
)
parsed_answer = florence_processor.post_process_generation(
generated_texts[0],
task=task_prompt,
image_size=(image.width, image.height),
)
return parsed_answer.get(task_prompt, "")
except Exception as e:
logger.error(f"Error processing image {img_path}: {str(e)}")
return None
async def predict_remote(
endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
img_base64 = image2base64(img_path)
if not img_base64:
return None
mime_type = "image/jpeg" if force_jpeg else "image/jpeg" # Default to JPEG if force_jpeg is True
mime_type = (
"image/jpeg" if force_jpeg else "image/jpeg"
) # Default to JPEG if force_jpeg is True
if not force_jpeg:
# Only determine MIME type if not forcing JPEG
@ -167,9 +230,9 @@ async def vlm(entity: Entity, request: Request):
vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)
print(vlm_result)
logger.info(vlm_result)
if not vlm_result:
print(f"No VLM result found for file: {entity.filepath}")
logger.info(f"No VLM result found for file: {entity.filepath}")
return {metadata_field_name: "{}"}
async with httpx.AsyncClient() as client:
@ -199,14 +262,55 @@ async def vlm(entity: Entity, request: Request):
def init_plugin(config):
global modelname, endpoint, token, concurrency, semaphore, force_jpeg
global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype
modelname = config.modelname
endpoint = config.endpoint
token = config.token
concurrency = config.concurrency
force_jpeg = config.force_jpeg
use_local = config.use_local
use_modelscope = config.use_modelscope
semaphore = asyncio.Semaphore(concurrency)
if use_local:
# 检测可用的设备
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
torch_dtype = (
torch.float32
if (
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
)
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
else torch.float16
)
logger.info(f"Using device: {device}")
if use_modelscope:
model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
logger.info(f"Model downloaded from ModelScope to: {model_dir}")
else:
model_dir = "microsoft/Florence-2-base-ft"
logger.info(f"Using model: {model_dir}")
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
florence_model = AutoModelForCausalLM.from_pretrained(
model_dir,
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device)
florence_processor = AutoProcessor.from_pretrained(
model_dir, trust_remote_code=True
)
logger.info("Florence model and processor initialized")
# Print the parameters
logger.info("VLM plugin initialized")
logger.info(f"Model Name: {modelname}")
@ -214,6 +318,8 @@ def init_plugin(config):
logger.info(f"Token: {token}")
logger.info(f"Concurrency: {concurrency}")
logger.info(f"Force JPEG: {force_jpeg}")
logger.info(f"Use Local: {use_local}")
logger.info(f"Use ModelScope: {use_modelscope}")
if __name__ == "__main__":
@ -232,6 +338,7 @@ if __name__ == "__main__":
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")
args = parser.parse_args()

View File

@ -1,4 +1,11 @@
from pydantic import BaseModel, ConfigDict, DirectoryPath, HttpUrl, Field
from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
HttpUrl,
Field,
model_validator,
)
from typing import List, Optional, Any, Dict
from datetime import datetime
from enum import Enum
@ -79,7 +86,18 @@ class NewPluginParam(BaseModel):
class NewLibraryPluginParam(BaseModel):
plugin_id: int
plugin_id: Optional[int] = None
plugin_name: Optional[str] = None
@model_validator(mode="after")
def check_either_id_or_name(self):
plugin_id = self.plugin_id
plugin_name = self.plugin_name
if not (plugin_id or plugin_name):
raise ValueError("Either plugin_id or plugin_name must be provided")
if plugin_id is not None and plugin_name is not None:
raise ValueError("Only one of plugin_id or plugin_name should be provided")
return self
class Folder(BaseModel):
@ -214,15 +232,18 @@ class FacetCount(BaseModel):
highlighted: str
value: str
class FacetStats(BaseModel):
total_values: int
class Facet(BaseModel):
counts: List[FacetCount]
field_name: str
sampled: bool
stats: FacetStats
class TextMatchInfo(BaseModel):
best_field_score: str
best_field_weight: int
@ -232,9 +253,11 @@ class TextMatchInfo(BaseModel):
tokens_matched: int
typo_prefix_score: int
class HybridSearchInfo(BaseModel):
rank_fusion_score: float
class SearchHit(BaseModel):
document: EntitySearchResult
highlight: Dict[str, Any] = {}
@ -243,12 +266,14 @@ class SearchHit(BaseModel):
text_match: Optional[int] = None
text_match_info: Optional[TextMatchInfo] = None
class RequestParams(BaseModel):
collection_name: str
first_q: str
per_page: int
q: str
class SearchResult(BaseModel):
facet_counts: List[Facet]
found: int
@ -257,4 +282,4 @@ class SearchResult(BaseModel):
page: int
request_params: RequestParams
search_cutoff: bool
search_time_ms: int
search_time_ms: int

View File

@ -23,6 +23,7 @@ import typesense
from .config import get_database_path, settings
from memos.plugins.vlm import main as vlm_main
from memos.plugins.ocr import main as ocr_main
from memos.plugins.embedding import main as embedding_main
from . import crud
from . import indexing
from .schemas import (
@ -84,18 +85,6 @@ app.mount(
"/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
)
# Add VLM plugin router
if settings.vlm.enabled:
print("VLM plugin is enabled")
vlm_main.init_plugin(settings.vlm)
app.include_router(vlm_main.router, prefix="/plugins/vlm")
# Add OCR plugin router
if settings.ocr.enabled:
print("OCR plugin is enabled")
ocr_main.init_plugin(settings.ocr)
app.include_router(ocr_main.router, prefix="/plugins/ocr")
@app.get("/favicon.png", response_class=FileResponse)
async def favicon_png():
@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense(
)
try:
indexing.bulk_upsert(client, entities)
await indexing.bulk_upsert(client, entities)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder(
@app.get("/search", response_model=SearchResult, tags=["search"])
async def search_entities(
async def search_entities_route(
q: str,
library_ids: str = Query(None, description="Comma-separated list of library IDs"),
folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),
@ -502,7 +491,7 @@ async def search_entities(
[date.strip() for date in created_dates.split(",")] if created_dates else None
)
try:
return indexing.search_entities(
return await indexing.search_entities(
client,
q,
library_ids,
@ -613,12 +602,29 @@ def add_library_plugin(
library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db)
):
library = crud.get_library_by_id(library_id, db)
if any(plugin.id == new_plugin.plugin_id for plugin in library.plugins):
if library is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
)
plugin = None
if new_plugin.plugin_id is not None:
plugin = crud.get_plugin_by_id(new_plugin.plugin_id, db)
elif new_plugin.plugin_name is not None:
plugin = crud.get_plugin_by_name(new_plugin.plugin_name, db)
if plugin is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found"
)
if any(p.id == plugin.id for p in library.plugins):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Plugin already exists in the library",
)
crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db)
crud.add_plugin_to_library(library_id, plugin.id, db)
@app.delete(
@ -727,6 +733,24 @@ def run_server():
print(f"VLM plugin enabled: {settings.vlm}")
print(f"OCR plugin enabled: {settings.ocr}")
# Add VLM plugin router
if settings.vlm.enabled:
print("VLM plugin is enabled")
vlm_main.init_plugin(settings.vlm)
app.include_router(vlm_main.router, prefix="/plugins/vlm")
# Add OCR plugin router
if settings.ocr.enabled:
print("OCR plugin is enabled")
ocr_main.init_plugin(settings.ocr)
app.include_router(ocr_main.router, prefix="/plugins/ocr")
# Add Embedding plugin router
if settings.embedding.enabled:
print("Embedding plugin is enabled")
embedding_main.init_plugin(settings.embedding)
app.include_router(embedding_main.router, prefix="/plugins/embed")
uvicorn.run(
"memos.server:app",
host=settings.server_host, # Use the new server_host setting

View File

@ -1,7 +1,6 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer
import numpy as np
import httpx
import torch
@ -25,31 +24,15 @@ elif torch.backends.mps.is_available():
else:
device = torch.device("cpu")
torch_dtype = "auto"
torch_dtype = (
torch.float32
if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
else torch.float16
)
print(f"Using device: {device}")
def init_embedding_model():
model = SentenceTransformer(
"jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
)
model.to(device)
return model
embedding_model = init_embedding_model() # 初始化模型
def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
embeddings = embedding_model.encode(input_texts, convert_to_tensor=True)
embeddings = embeddings.cpu().numpy()
# normalized embeddings
norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
return embeddings.tolist()
# Add a configuration option to choose the model
parser = argparse.ArgumentParser(description="Run the server with specified model")
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
@ -63,8 +46,11 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
if use_florence_model:
# Load Florence-2 model
florence_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
"microsoft/Florence-2-base-ft",
torch_dtype=torch_dtype,
attn_implementation="sdpa",
trust_remote_code=True,
).to(device, torch_dtype)
florence_processor = AutoProcessor.from_pretrained(
"microsoft/Florence-2-base-ft", trust_remote_code=True
)
@ -74,7 +60,7 @@ else:
"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
torch_dtype=torch_dtype,
device_map="auto",
).to(device)
).to(device, torch_dtype)
qwen2vl_processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"
)
@ -139,9 +125,9 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
text = qwen2vl_processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = qwen2vl_processor(
text=[text],
images=image_inputs,
@ -152,12 +138,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
inputs = inputs.to(device)
generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = qwen2vl_processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
@ -170,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
app = FastAPI()
class EmbeddingRequest(BaseModel):
input: List[str]
class EmbeddingResponse(BaseModel):
embeddings: List[List[float]]
@app.post("/api/embed", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
try:
if not request.input:
return EmbeddingResponse(embeddings=[])
embeddings = generate_embeddings(request.input) # 使用新方法
return EmbeddingResponse(embeddings=embeddings)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error generating embeddings: {str(e)}"
)
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Dict[str, Any]]
@ -275,10 +239,14 @@ async def chat_completions(request: ChatCompletionRequest):
if __name__ == "__main__":
import uvicorn
parser = argparse.ArgumentParser(description="Run the server with specified model and port")
parser = argparse.ArgumentParser(
description="Run the server with specified model and port"
)
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
parser.add_argument(
"--port", type=int, default=8000, help="Port to run the server on"
)
args = parser.parse_args()
if args.florence and args.qwen2vl:

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "memos"
version = "0.5.0"
version = "0.6.10"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]
@ -36,6 +36,13 @@ dependencies = [
"pyobjc; sys_platform == 'darwin'",
"pyobjc-core; sys_platform == 'darwin'",
"pyobjc-framework-Quartz; sys_platform == 'darwin'",
"sentence-transformers",
"torch",
"numpy",
"timm",
"einops",
"modelscope",
"mss",
]
[project.urls]
@ -50,3 +57,4 @@ include = ["memos*", "screen_recorder*"]
[tool.setuptools.package-data]
"*" = ["static/**/*"]
"memos.plugins.ocr" = ["*.yaml", "models/*.onnx"]

View File

@ -4,11 +4,11 @@ import time
import logging
import platform
import subprocess
from PIL import Image, ImageGrab
from PIL import Image
import imagehash
from memos.utils import write_image_metadata
from screeninfo import get_monitors
import ctypes
from mss import mss
if platform.system() == "Windows":
import win32gui
@ -173,57 +173,49 @@ def take_screenshot_windows(
app_name,
window_title,
):
for monitor in get_monitors():
safe_monitor_name = "".join(
c for c in monitor.name if c.isalnum() or c in ("_", "-")
)
logging.info(f"Processing monitor: {safe_monitor_name}")
with mss() as sct:
for i, monitor in enumerate(sct.monitors[1:], 1): # Skip the first monitor (entire screen)
safe_monitor_name = f"monitor_{i}"
logging.info(f"Processing monitor: {safe_monitor_name}")
webp_filename = os.path.join(
base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp"
)
img = ImageGrab.grab(
bbox=(
monitor.x,
monitor.y,
monitor.x + monitor.width,
monitor.y + monitor.height,
webp_filename = os.path.join(
base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp"
)
)
img = img.convert("RGB")
current_hash = str(imagehash.phash(img))
if (
safe_monitor_name in previous_hashes
and imagehash.hex_to_hash(current_hash)
- imagehash.hex_to_hash(previous_hashes[safe_monitor_name])
< threshold
):
logging.info(
f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping."
img = sct.grab(monitor)
img = Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
current_hash = str(imagehash.phash(img))
if (
safe_monitor_name in previous_hashes
and imagehash.hex_to_hash(current_hash)
- imagehash.hex_to_hash(previous_hashes[safe_monitor_name])
< threshold
):
logging.info(
f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping."
)
yield safe_monitor_name, None, "Skipped (similar to previous)"
continue
previous_hashes[safe_monitor_name] = current_hash
screen_sequences[safe_monitor_name] = (
screen_sequences.get(safe_monitor_name, 0) + 1
)
yield safe_monitor_name, None, "Skipped (similar to previous)"
continue
previous_hashes[safe_monitor_name] = current_hash
screen_sequences[safe_monitor_name] = (
screen_sequences.get(safe_monitor_name, 0) + 1
)
metadata = {
"timestamp": timestamp,
"active_app": app_name,
"active_window": window_title,
"screen_name": safe_monitor_name,
"sequence": screen_sequences[safe_monitor_name],
}
metadata = {
"timestamp": timestamp,
"active_app": app_name,
"active_window": window_title,
"screen_name": safe_monitor_name,
"sequence": screen_sequences[safe_monitor_name],
}
img.save(webp_filename, format="WebP", quality=85)
write_image_metadata(webp_filename, metadata)
save_screen_sequences(base_dir, screen_sequences, date)
img.save(webp_filename, format="WebP", quality=85)
write_image_metadata(webp_filename, metadata)
save_screen_sequences(base_dir, screen_sequences, date)
yield safe_monitor_name, webp_filename, "Saved"
yield safe_monitor_name, webp_filename, "Saved"
def take_screenshot(

View File

@ -9,6 +9,8 @@ from screen_recorder.common import (
take_screenshot,
is_screen_locked,
)
from pathlib import Path
from memos.config import settings
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@ -51,12 +53,12 @@ def main():
"--threshold", type=int, default=4, help="Threshold for image similarity"
)
parser.add_argument(
"--base-dir", type=str, default="~/tmp", help="Base directory for screenshots"
"--base-dir", type=str, help="Base directory for screenshots"
)
parser.add_argument("--once", action="store_true", help="Run once and exit")
args = parser.parse_args()
base_dir = os.path.expanduser(args.base_dir)
base_dir = os.path.expanduser(args.base_dir) if args.base_dir else settings.screenshots_dir
previous_hashes = load_previous_hashes(base_dir)
if args.once:

View File

@ -112,7 +112,7 @@
id={`library-${library.id}`}
bind:checked={selectedLibraries[library.id]}
/>
<Label for={`library-${library.id}`} class="flex items-center text-sm">{library.name}</Label>
<Label for={`library-${library.id}`} class="flex items-center text-sm">{library.name}#{library.id}</Label>
</div>
{/each}
</div>