Mirror of https://github.com/tcsenpai/pensieve.git, synced 2025-06-06 03:05:25 +00:00

Merge branch 'master' of https://github.com/arkohut/memos

This commit is contained in:
commit 4a8d8fa54a

README.md: 68 lines changed
@@ -1,3 +1,69 @@
# Memos

A project to index everything to make it like another memory.
A project to index everything to make it like another memory. The project contains two parts:

1. `screen recorder`: takes screenshots every 5 seconds and saves them to `~/.memos/screenshots` by default.
2. `memos server`: a web service that indexes the screenshots and other files, providing a web interface to search the records.

There is a product called [Rewind](https://www.rewind.ai/) that is similar to Memos, but Memos aims to give you control over all your data.

## Install

### Install Typesense

```bash
export TYPESENSE_API_KEY=xyz

mkdir "$(pwd)"/typesense-data

docker run -d -p 8108:8108 \
  -v"$(pwd)"/typesense-data:/data typesense/typesense:27.0 \
  --add-host=host.docker.internal:host-gateway \
  --data-dir /data \
  --api-key=$TYPESENSE_API_KEY \
  --enable-cors
```
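Before moving on, you can sanity-check that Typesense is up (its standard `/health` endpoint should answer once the container is ready):

```bash
# Expect {"ok":true} when Typesense is ready to accept requests
curl http://localhost:8108/health
```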
### Install Memos

```bash
pip install memos
```

## How to use

To use memos, you need to initialize it first. Make sure you have started `typesense`.

### 1. Initialize Memos

```bash
memos init
```

This will create a folder `~/.memos` and put the config file there.
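For reference, here is a minimal sketch of what `~/.memos/config.yaml` may look like, assuming the defaults from the `Settings` model in `memos/config.py` shown further down this page (the exact keys and layout written by `memos init` may differ):

```yaml
# Hypothetical config sketch based on the Settings defaults below
base_dir: ~/.memos
default_library: screenshots
screenshots_dir: ~/.memos/screenshots
typesense_host: localhost
typesense_port: "8108"
typesense_protocol: http
auth_username: admin
auth_password: changeme
batchsize: 1
default_plugins:
  - builtin_vlm
  - builtin_ocr
```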
### 2. Start Screen Recorder

```bash
memos-record
```

This will start a screen recorder, which takes screenshots every 5 seconds and saves them to `~/.memos/screenshots` by default.

### 3. Start Memos Server

```bash
memos serve
```

This will start a web server, and you can access the web interface at `http://localhost:8080`.
The default username and password are `admin` and `changeme`.

### Index the screenshots

```bash
memos scan
memos index
```

Refresh the page and try a search.
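You can also hit the search API directly. A sketch, assuming the server protects routes with the basic-auth credentials above (the `/search` route and its `q` and `library_ids` query parameters appear in the server diff further down; the auth mechanism is an assumption):

```bash
# Hybrid full-text + embedding search, limited to library 1
curl -u admin:changeme \
  "http://localhost:8080/search?q=invoice&library_ids=1"
```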
@@ -1,6 +1,5 @@
import asyncio
import os
import time
import logging
from datetime import datetime, timezone
from pathlib import Path

@@ -549,6 +548,9 @@ def index(
    library_id: int,
    folders: List[int] = typer.Option(None, "--folder", "-f"),
    force: bool = typer.Option(False, "--force", help="Force update all indexes"),
    batchsize: int = typer.Option(
        4, "--batchsize", "-bs", help="Number of entities to index in a batch"
    ),
):
    print(f"Indexing library {library_id}")

@@ -608,9 +610,8 @@ def index(
        pbar.refresh()

    # Index each entity
    batch_size = settings.batchsize
    for i in range(0, len(entities), batch_size):
        batch = entities[i : i + batch_size]
    for i in range(0, len(entities), batchsize):
        batch = entities[i : i + batchsize]
        to_index = []

        for entity in batch:

@@ -722,18 +723,22 @@ def create(name: str, webhook_url: str, description: str = ""):
@plugin_app.command("bind")
def bind(
    library_id: int = typer.Option(..., "--lib", help="ID of the library"),
    plugin_id: int = typer.Option(..., "--plugin", help="ID of the plugin"),
    plugin: str = typer.Option(..., "--plugin", help="ID or name of the plugin"),
):
    try:
        plugin_id = int(plugin)
        plugin_param = {"plugin_id": plugin_id}
    except ValueError:
        plugin_param = {"plugin_name": plugin}

    response = httpx.post(
        f"{BASE_URL}/libraries/{library_id}/plugins",
        json={"plugin_id": plugin_id},
        json=plugin_param,
    )
    if 200 <= response.status_code < 300:
    if response.status_code == 204:
        print("Plugin bound to library successfully")
    else:
        print(
            f"Failed to bind plugin to library: {response.status_code} - {response.text}"
        )
        print(f"Failed to bind plugin to library: {response.status_code} - {response.text}")
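With this change, `--plugin` accepts either a numeric ID or a plugin name: the value is first parsed as an `int`, and on `ValueError` it falls back to sending `plugin_name`. A hypothetical invocation, assuming `plugin_app` is mounted as a `plugin` subcommand of the `memos` CLI (the mounting is not shown in this diff):

```bash
# Bind by numeric ID (parsed as an int)
memos plugin bind --lib 1 --plugin 2

# Bind by name (sent as {"plugin_name": "builtin_ocr"})
memos plugin bind --lib 1 --plugin builtin_ocr
```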
@plugin_app.command("unbind")

@@ -763,5 +768,85 @@ def init():
    print("Initialization failed. Please check the error messages above.")


@app.command("scan")
def scan_default_library(force: bool = False):
    """
    Scan the screenshots directory and add it to the library if empty.
    """
    # Get the default library
    response = httpx.get(f"{BASE_URL}/libraries")
    if response.status_code != 200:
        print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
        return

    libraries = response.json()
    default_library = next(
        (lib for lib in libraries if lib["name"] == settings.default_library), None
    )

    if not default_library:
        # Create the default library if it doesn't exist
        response = httpx.post(
            f"{BASE_URL}/libraries",
            json={"name": settings.default_library, "folders": []},
        )
        if response.status_code != 200:
            print(
                f"Failed to create default library: {response.status_code} - {response.text}"
            )
            return
        default_library = response.json()

    for plugin in settings.default_plugins:
        bind(default_library["id"], plugin)

    # Check if the library is empty
    if not default_library["folders"]:
        # Add the screenshots directory to the library
        screenshots_dir = Path(settings.screenshots_dir).resolve()
        response = httpx.post(
            f"{BASE_URL}/libraries/{default_library['id']}/folders",
            json={"folders": [str(screenshots_dir)]},
        )
        if response.status_code != 200:
            print(
                f"Failed to add screenshots directory: {response.status_code} - {response.text}"
            )
            return
        print(f"Added screenshots directory: {screenshots_dir}")

    # Scan the library
    print(f"Scanning library: {default_library['name']}")
    scan(default_library["id"], plugins=None, folders=None, force=force)


@app.command("index")
def index_default_library(
    batchsize: int = typer.Option(
        4, "--batchsize", "-bs", help="Number of entities to index in a batch"
    ),
    force: bool = typer.Option(False, "--force", help="Force update all indexes"),
):
    """
    Index the default library for memos.
    """
    # Get the default library
    response = httpx.get(f"{BASE_URL}/libraries")
    if response.status_code != 200:
        print(f"Failed to retrieve libraries: {response.status_code} - {response.text}")
        return

    libraries = response.json()
    default_library = next(
        (lib for lib in libraries if lib["name"] == settings.default_library), None
    )

    if not default_library:
        print("Default library does not exist.")
        return

    index(default_library["id"], force=force, folders=None, batchsize=batchsize)


if __name__ == "__main__":
    app()
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Tuple, Type
from typing import Tuple, Type, List
from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,

@@ -17,24 +17,28 @@ class VLMSettings(BaseModel):
    modelname: str = "moondream"
    endpoint: str = "http://localhost:11434"
    token: str = ""
    concurrency: int = 4
    concurrency: int = 1
    force_jpeg: bool = False
    use_local: bool = True
    use_modelscope: bool = False


class OCRSettings(BaseModel):
    enabled: bool = True
    endpoint: str = "http://localhost:5555/predict"
    token: str = ""
    concurrency: int = 4
    concurrency: int = 1
    use_local: bool = True
    use_gpu: bool = False
    force_jpeg: bool = False


class EmbeddingSettings(BaseModel):
    enabled: bool = True
    num_dim: int = 768
    ollama_endpoint: str = "http://localhost:11434"
    ollama_model: str = "nextfire/paraphrase-multilingual-minilm"
    endpoint: str = "http://localhost:11434/api/embed"
    model: str = "jinaai/jina-embeddings-v2-base-zh"
    use_modelscope: bool = False


class Settings(BaseSettings):

@@ -46,6 +50,9 @@ class Settings(BaseSettings):

    base_dir: str = str(Path.home() / ".memos")
    database_path: str = os.path.join(base_dir, "database.db")
    default_library: str = "screenshots"
    screenshots_dir: str = os.path.join(base_dir, "screenshots")

    typesense_host: str = "localhost"
    typesense_port: str = "8108"
    typesense_protocol: str = "http"

@@ -66,11 +73,13 @@ class Settings(BaseSettings):
    # Embedding settings
    embedding: EmbeddingSettings = EmbeddingSettings()

    batchsize: int = 4
    batchsize: int = 1

    auth_username: str = "admin"
    auth_password: SecretStr = SecretStr("changeme")

    default_plugins: List[str] = ["builtin_vlm", "builtin_ocr"]

    @classmethod
    def settings_customise_sources(
        cls,

@@ -93,6 +102,19 @@ def dict_representer(dumper, data):
yaml.add_representer(OrderedDict, dict_representer)


# Custom representer for SecretStr
def secret_str_representer(dumper, data):
    return dumper.represent_scalar("tag:yaml.org,2002:str", data.get_secret_value())

# Custom constructor for SecretStr
def secret_str_constructor(loader, node):
    value = loader.construct_scalar(node)
    return SecretStr(value)

# Register the representer and constructor only for specific fields
yaml.add_representer(SecretStr, secret_str_representer)


def create_default_config():
    config_path = Path.home() / ".memos" / "config.yaml"
    if not config_path.exists():
@@ -46,14 +46,19 @@ def parse_date_fields(entity):
    }


def get_embeddings(texts: List[str]) -> List[List[float]]:
async def get_embeddings(texts: List[str]) -> List[List[float]]:
    print(f"Getting embeddings for {len(texts)} texts")
    ollama_endpoint = settings.embedding.ollama_endpoint
    ollama_model = settings.embedding.ollama_model
    with httpx.Client() as client:
        response = client.post(
            f"{ollama_endpoint}/api/embed",
            json={"model": ollama_model, "input": texts},

    if settings.embedding.enabled:
        endpoint = f"http://{settings.server_host}:{settings.server_port}/plugins/embed"
    else:
        endpoint = settings.embedding.endpoint

    model = settings.embedding.model
    async with httpx.AsyncClient() as client:
        response = await client.post(
            endpoint,
            json={"model": model, "input": texts},
            timeout=30
        )
        if response.status_code == 200:

@@ -99,7 +104,7 @@ def generate_metadata_text(metadata_entries):
    return metadata_text


def bulk_upsert(client, entities):
async def bulk_upsert(client, entities):
    documents = []
    metadata_texts = []
    entities_with_metadata = []

@@ -142,7 +147,7 @@ def bulk_upsert(client, entities):
            ).model_dump(mode="json")
        )

    embeddings = get_embeddings(metadata_texts)
    embeddings = await get_embeddings(metadata_texts)
    for doc, embedding, entity in zip(documents, embeddings, entities):
        if entity in entities_with_metadata:
            doc["embedding"] = embedding

@@ -259,7 +264,7 @@ def list_all_entities(
    )


def search_entities(
async def search_entities(
    client,
    q: str,
    library_ids: List[int] = None,

@@ -287,7 +292,7 @@ def search_entities(
    filter_by_str = " && ".join(filter_by) if filter_by else ""

    # Convert q to embedding using get_embeddings and take the first embedding
    embedding = get_embeddings([q])[0]
    embedding = (await get_embeddings([q]))[0]

    common_search_params = {
        "collection": TYPESENSE_COLLECTION_NAME,
@@ -15,7 +15,7 @@ from typing import List
from .schemas import MetadataSource, MetadataType
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from .config import get_database_path
from .config import get_database_path, settings


class Base(DeclarativeBase):

@@ -75,16 +75,14 @@ class EntityModel(Base):
        "FolderModel", back_populates="entities"
    )
    metadata_entries: Mapped[List["EntityMetadataModel"]] = relationship(
        "EntityMetadataModel",
        lazy="joined",
        cascade="all, delete-orphan"
        "EntityMetadataModel", lazy="joined", cascade="all, delete-orphan"
    )
    tags: Mapped[List["TagModel"]] = relationship(
        "TagModel",
        secondary="entity_tags",
        "TagModel",
        secondary="entity_tags",
        lazy="joined",
        cascade="all, delete",
        overlaps="entities"
        overlaps="entities",
    )

    # Add indexes

@@ -160,30 +158,61 @@ def init_database():
    """Initialize the database."""
    db_path = get_database_path()
    engine = create_engine(f"sqlite:///{db_path}")

    try:
        Base.metadata.create_all(engine)
        print(f"Database initialized successfully at {db_path}")

        # Initialize default plugins
        Session = sessionmaker(bind=engine)
        with Session() as session:
            initialize_default_plugins(session)

            default_plugins = initialize_default_plugins(session)
            init_default_libraries(session, default_plugins)

        return True
    except OperationalError as e:
        print(f"Error initializing database: {e}")
        return False


def initialize_default_plugins(session):
    default_plugins = [
        PluginModel(name="buildin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"),
        PluginModel(name="buildin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"),
        PluginModel(
            name="builtin_vlm", description="VLM Plugin", webhook_url="/plugins/vlm"
        ),
        PluginModel(
            name="builtin_ocr", description="OCR Plugin", webhook_url="/plugins/ocr"
        ),
    ]

    for plugin in default_plugins:
        existing_plugin = session.query(PluginModel).filter_by(name=plugin.name).first()
        if not existing_plugin:
            session.add(plugin)

        session.commit()

    session.commit()

    return default_plugins


def init_default_libraries(session, default_plugins):
    default_libraries = [
        LibraryModel(name=settings.default_library),
    ]

    for library in default_libraries:
        existing_library = (
            session.query(LibraryModel).filter_by(name=library.name).first()
        )
        if not existing_library:
            session.add(library)

    for plugin in default_plugins:
        bind_response = session.query(PluginModel).filter_by(name=plugin.name).first()
        if bind_response:
            library_plugin = LibraryPluginModel(
                library_id=1, plugin_id=bind_response.id
            )  # Assuming library_id=1 for default libraries
            session.add(library_plugin)

    session.commit()
memos/plugins/embedding/main.py (new file, 146 lines)

@@ -0,0 +1,146 @@
import asyncio
from typing import List
from fastapi import APIRouter, HTTPException
import logging
import uvicorn
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from pydantic import BaseModel
from modelscope import snapshot_download

PLUGIN_NAME = "embedding"

router = APIRouter(tags=[PLUGIN_NAME], responses={404: {"description": "Not found"}})

# Global variables
enabled = False
model = None
num_dim = None
endpoint = None
model_name = None
device = None
use_modelscope = None

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def init_embedding_model():
    global model, device, use_modelscope
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    if use_modelscope:
        model_dir = snapshot_download(model_name)
        logger.info(f"Model downloaded from ModelScope to: {model_dir}")
    else:
        model_dir = model_name
        logger.info(f"Using model: {model_dir}")

    model = SentenceTransformer(model_dir, trust_remote_code=True)
    model.to(device)
    logger.info(f"Embedding model initialized on device: {device}")


def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
    embeddings = model.encode(input_texts, convert_to_tensor=True)
    embeddings = embeddings.cpu().numpy()
    # Normalize embeddings
    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
    norms[norms == 0] = 1
    embeddings = embeddings / norms
    return embeddings.tolist()


class EmbeddingRequest(BaseModel):
    input: List[str]


class EmbeddingResponse(BaseModel):
    embeddings: List[List[float]]


@router.get("/")
async def read_root():
    return {"healthy": True, "enabled": enabled}


@router.post("", include_in_schema=False)
@router.post("/", response_model=EmbeddingResponse)
async def embed(request: EmbeddingRequest):
    try:
        if not request.input:
            return EmbeddingResponse(embeddings=[])

        # Run the embedding generation in a separate thread to avoid blocking
        loop = asyncio.get_event_loop()
        embeddings = await loop.run_in_executor(None, generate_embeddings, request.input)
        return EmbeddingResponse(embeddings=embeddings)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error generating embeddings: {str(e)}"
        )


def init_plugin(config):
    global enabled, num_dim, endpoint, model_name, use_modelscope
    enabled = config.enabled
    num_dim = config.num_dim
    endpoint = config.endpoint
    model_name = config.model
    use_modelscope = config.use_modelscope

    if enabled:
        init_embedding_model()

    logger.info("Embedding plugin initialized")
    logger.info(f"Enabled: {enabled}")
    logger.info(f"Number of dimensions: {num_dim}")
    logger.info(f"Endpoint: {endpoint}")
    logger.info(f"Model: {model_name}")
    logger.info(f"Use ModelScope: {use_modelscope}")


if __name__ == "__main__":
    import argparse
    from fastapi import FastAPI

    parser = argparse.ArgumentParser(description="Embedding Plugin Configuration")
    parser.add_argument(
        "--num-dim", type=int, default=768, help="Number of embedding dimensions"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="jinaai/jina-embeddings-v2-base-zh",
        help="Embedding model name",
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    parser.add_argument(
        "--use-modelscope", action="store_true", help="Use ModelScope to download the model"
    )

    args = parser.parse_args()

    class Config:
        def __init__(self, args):
            self.enabled = True
            self.num_dim = args.num_dim
            self.endpoint = "what ever"
            self.model = args.model
            self.use_modelscope = args.use_modelscope

    init_plugin(Config(args))

    app = FastAPI()
    app.include_router(router)

    uvicorn.run(app, host="0.0.0.0", port=args.port)
@@ -141,6 +141,7 @@ async def ocr(entity: Entity, request: Request):
                }
            ]
        },
        timeout=30,
    )

    # Check if the patch request was successful

@@ -178,7 +179,7 @@ def init_plugin(config):
        ocr_config['Cls']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Cls']['model_path']))
        ocr_config['Rec']['model_path'] = os.path.join(model_dir, os.path.basename(ocr_config['Rec']['model_path']))

        # Save the updated config to a temporary file
        # Save the updated config to a temporary file with strings wrapped in double quotes
        temp_config_path = os.path.join(os.path.dirname(__file__), "temp_ppocr.yaml")
        with open(temp_config_path, 'w') as f:
            yaml.safe_dump(ocr_config, f)
@@ -9,6 +9,20 @@ import logging
import uvicorn
import os
import io
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from modelscope import snapshot_download

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports


PLUGIN_NAME = "vlm"
PROMPT = "描述这张图片的内容"

@@ -21,6 +35,10 @@ token = None
concurrency = None
semaphore = None
force_jpeg = None
use_local = None
florence_model = None
florence_processor = None
torch_dtype = None

# Configure logger
logging.basicConfig(level=logging.INFO)

@@ -35,18 +53,18 @@ def image2base64(img_path):
        with Image.open(img_path) as img:
            if force_jpeg:
                # Convert image to RGB mode (removes alpha channel if present)
                img = img.convert('RGB')
                img = img.convert("RGB")
                # Save as JPEG in memory
                buffer = io.BytesIO()
                img.save(buffer, format='JPEG')
                img.save(buffer, format="JPEG")
                buffer.seek(0)
                encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
                encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
            else:
                # Use original format
                buffer = io.BytesIO()
                img.save(buffer, format=img.format)
                buffer.seek(0)
                encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
                encoded_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return encoded_string
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {str(e)}")

@@ -79,12 +97,57 @@ async def fetch(endpoint: str, client, request_data, headers: Optional[dict] = N

async def predict(
    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
    if use_local:
        return await predict_local(img_path)
    else:
        return await predict_remote(endpoint, modelname, img_path, token)


async def predict_local(img_path: str) -> Optional[str]:
    try:
        image = Image.open(img_path)
        task_prompt = "<MORE_DETAILED_CAPTION>"
        prompt = task_prompt + ""

        inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(
            florence_model.device, torch_dtype
        )

        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )

        generated_texts = florence_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )

        parsed_answer = florence_processor.post_process_generation(
            generated_texts[0],
            task=task_prompt,
            image_size=(image.width, image.height),
        )

        return parsed_answer.get(task_prompt, "")
    except Exception as e:
        logger.error(f"Error processing image {img_path}: {str(e)}")
        return None


async def predict_remote(
    endpoint: str, modelname: str, img_path: str, token: Optional[str] = None
) -> Optional[str]:
    img_base64 = image2base64(img_path)
    if not img_base64:
        return None

    mime_type = "image/jpeg" if force_jpeg else "image/jpeg"  # Default to JPEG if force_jpeg is True
    mime_type = (
        "image/jpeg" if force_jpeg else "image/jpeg"
    )  # Default to JPEG if force_jpeg is True

    if not force_jpeg:
        # Only determine MIME type if not forcing JPEG

@@ -167,9 +230,9 @@ async def vlm(entity: Entity, request: Request):

    vlm_result = await predict(endpoint, modelname, entity.filepath, token=token)

    print(vlm_result)
    logger.info(vlm_result)
    if not vlm_result:
        print(f"No VLM result found for file: {entity.filepath}")
        logger.info(f"No VLM result found for file: {entity.filepath}")
        return {metadata_field_name: "{}"}

    async with httpx.AsyncClient() as client:

@@ -199,14 +262,55 @@ async def vlm(entity: Entity, request: Request):


def init_plugin(config):
    global modelname, endpoint, token, concurrency, semaphore, force_jpeg
    global modelname, endpoint, token, concurrency, semaphore, force_jpeg, use_local, florence_model, florence_processor, torch_dtype

    modelname = config.modelname
    endpoint = config.endpoint
    token = config.token
    concurrency = config.concurrency
    force_jpeg = config.force_jpeg
    use_local = config.use_local
    use_modelscope = config.use_modelscope
    semaphore = asyncio.Semaphore(concurrency)

    if use_local:
        # Detect the available device
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

        torch_dtype = (
            torch.float32
            if (
                torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6
            )
            or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
            else torch.float16
        )
        logger.info(f"Using device: {device}")

        if use_modelscope:
            model_dir = snapshot_download('AI-ModelScope/Florence-2-base-ft')
            logger.info(f"Model downloaded from ModelScope to: {model_dir}")
        else:
            model_dir = "microsoft/Florence-2-base-ft"
            logger.info(f"Using model: {model_dir}")

        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
            florence_model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                torch_dtype=torch_dtype,
                attn_implementation="sdpa",
                trust_remote_code=True,
            ).to(device)
            florence_processor = AutoProcessor.from_pretrained(
                model_dir, trust_remote_code=True
            )
        logger.info("Florence model and processor initialized")

    # Print the parameters
    logger.info("VLM plugin initialized")
    logger.info(f"Model Name: {modelname}")

@@ -214,6 +318,8 @@ def init_plugin(config):
    logger.info(f"Token: {token}")
    logger.info(f"Concurrency: {concurrency}")
    logger.info(f"Force JPEG: {force_jpeg}")
    logger.info(f"Use Local: {use_local}")
    logger.info(f"Use ModelScope: {use_modelscope}")


if __name__ == "__main__":

@@ -232,6 +338,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    parser.add_argument("--use-modelscope", action="store_true", help="Use ModelScope to download the model")

    args = parser.parse_args()
@@ -1,4 +1,11 @@
from pydantic import BaseModel, ConfigDict, DirectoryPath, HttpUrl, Field
from pydantic import (
    BaseModel,
    ConfigDict,
    DirectoryPath,
    HttpUrl,
    Field,
    model_validator,
)
from typing import List, Optional, Any, Dict
from datetime import datetime
from enum import Enum

@@ -79,7 +86,18 @@ class NewPluginParam(BaseModel):


class NewLibraryPluginParam(BaseModel):
    plugin_id: int
    plugin_id: Optional[int] = None
    plugin_name: Optional[str] = None

    @model_validator(mode="after")
    def check_either_id_or_name(self):
        plugin_id = self.plugin_id
        plugin_name = self.plugin_name
        if not (plugin_id or plugin_name):
            raise ValueError("Either plugin_id or plugin_name must be provided")
        if plugin_id is not None and plugin_name is not None:
            raise ValueError("Only one of plugin_id or plugin_name should be provided")
        return self


class Folder(BaseModel):

@@ -214,15 +232,18 @@ class FacetCount(BaseModel):
    highlighted: str
    value: str


class FacetStats(BaseModel):
    total_values: int


class Facet(BaseModel):
    counts: List[FacetCount]
    field_name: str
    sampled: bool
    stats: FacetStats


class TextMatchInfo(BaseModel):
    best_field_score: str
    best_field_weight: int

@@ -232,9 +253,11 @@ class TextMatchInfo(BaseModel):
    tokens_matched: int
    typo_prefix_score: int


class HybridSearchInfo(BaseModel):
    rank_fusion_score: float


class SearchHit(BaseModel):
    document: EntitySearchResult
    highlight: Dict[str, Any] = {}

@@ -243,12 +266,14 @@ class SearchHit(BaseModel):
    text_match: Optional[int] = None
    text_match_info: Optional[TextMatchInfo] = None


class RequestParams(BaseModel):
    collection_name: str
    first_q: str
    per_page: int
    q: str


class SearchResult(BaseModel):
    facet_counts: List[Facet]
    found: int

@@ -257,4 +282,4 @@ class SearchResult(BaseModel):
    page: int
    request_params: RequestParams
    search_cutoff: bool
    search_time_ms: int
    search_time_ms: int
@@ -23,6 +23,7 @@ import typesense
from .config import get_database_path, settings
from memos.plugins.vlm import main as vlm_main
from memos.plugins.ocr import main as ocr_main
from memos.plugins.embedding import main as embedding_main
from . import crud
from . import indexing
from .schemas import (

@@ -84,18 +85,6 @@ app.mount(
    "/_app", StaticFiles(directory=os.path.join(current_dir, "static/_app"), html=True)
)

# Add VLM plugin router
if settings.vlm.enabled:
    print("VLM plugin is enabled")
    vlm_main.init_plugin(settings.vlm)
    app.include_router(vlm_main.router, prefix="/plugins/vlm")

# Add OCR plugin router
if settings.ocr.enabled:
    print("OCR plugin is enabled")
    ocr_main.init_plugin(settings.ocr)
    app.include_router(ocr_main.router, prefix="/plugins/ocr")


@app.get("/favicon.png", response_class=FileResponse)
async def favicon_png():

@@ -411,7 +400,7 @@ async def batch_sync_entities_to_typesense(
    )

    try:
        indexing.bulk_upsert(client, entities)
        await indexing.bulk_upsert(client, entities)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

@@ -481,7 +470,7 @@ def list_entitiy_indices_in_folder(


@app.get("/search", response_model=SearchResult, tags=["search"])
async def search_entities(
async def search_entities_route(
    q: str,
    library_ids: str = Query(None, description="Comma-separated list of library IDs"),
    folder_ids: str = Query(None, description="Comma-separated list of folder IDs"),

@@ -502,7 +491,7 @@ async def search_entities(
        [date.strip() for date in created_dates.split(",")] if created_dates else None
    )
    try:
        return indexing.search_entities(
        return await indexing.search_entities(
            client,
            q,
            library_ids,

@@ -613,12 +602,29 @@ def add_library_plugin(
    library_id: int, new_plugin: NewLibraryPluginParam, db: Session = Depends(get_db)
):
    library = crud.get_library_by_id(library_id, db)
    if any(plugin.id == new_plugin.plugin_id for plugin in library.plugins):
    if library is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Library not found"
        )

    plugin = None
    if new_plugin.plugin_id is not None:
        plugin = crud.get_plugin_by_id(new_plugin.plugin_id, db)
    elif new_plugin.plugin_name is not None:
        plugin = crud.get_plugin_by_name(new_plugin.plugin_name, db)

    if plugin is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found"
        )

    if any(p.id == plugin.id for p in library.plugins):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Plugin already exists in the library",
        )
    crud.add_plugin_to_library(library_id, new_plugin.plugin_id, db)

    crud.add_plugin_to_library(library_id, plugin.id, db)


@app.delete(

@@ -727,6 +733,24 @@ def run_server():
    print(f"VLM plugin enabled: {settings.vlm}")
    print(f"OCR plugin enabled: {settings.ocr}")

    # Add VLM plugin router
    if settings.vlm.enabled:
        print("VLM plugin is enabled")
        vlm_main.init_plugin(settings.vlm)
        app.include_router(vlm_main.router, prefix="/plugins/vlm")

    # Add OCR plugin router
    if settings.ocr.enabled:
        print("OCR plugin is enabled")
        ocr_main.init_plugin(settings.ocr)
        app.include_router(ocr_main.router, prefix="/plugins/ocr")

    # Add Embedding plugin router
    if settings.embedding.enabled:
        print("Embedding plugin is enabled")
        embedding_main.init_plugin(settings.embedding)
        app.include_router(embedding_main.router, prefix="/plugins/embed")

    uvicorn.run(
        "memos.server:app",
        host=settings.server_host,  # Use the new server_host setting
@@ -1,7 +1,6 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer
import numpy as np
import httpx
import torch

@@ -25,31 +24,15 @@ elif torch.backends.mps.is_available():
else:
    device = torch.device("cpu")

torch_dtype = "auto"
torch_dtype = (
    torch.float32
    if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] <= 6)
    or (not torch.cuda.is_available() and not torch.backends.mps.is_available())
    else torch.float16
)
print(f"Using device: {device}")


def init_embedding_model():
    model = SentenceTransformer(
        "jinaai/jina-embeddings-v2-base-zh", trust_remote_code=True
    )
    model.to(device)
    return model


embedding_model = init_embedding_model()  # Initialize the model


def generate_embeddings(input_texts: List[str]) -> List[List[float]]:
    embeddings = embedding_model.encode(input_texts, convert_to_tensor=True)
    embeddings = embeddings.cpu().numpy()
    # normalized embeddings
    norms = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
    norms[norms == 0] = 1
    embeddings = embeddings / norms
    return embeddings.tolist()


# Add a configuration option to choose the model
parser = argparse.ArgumentParser(description="Run the server with specified model")
parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")

@@ -63,8 +46,11 @@ use_florence_model = args.florence if (args.florence or args.qwen2vl) else True
if use_florence_model:
    # Load Florence-2 model
    florence_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-base-ft", torch_dtype=torch_dtype, trust_remote_code=True
    ).to(device)
        "microsoft/Florence-2-base-ft",
        torch_dtype=torch_dtype,
        attn_implementation="sdpa",
        trust_remote_code=True,
    ).to(device, torch_dtype)
    florence_processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base-ft", trust_remote_code=True
    )

@@ -74,7 +60,7 @@ else:
        "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
        torch_dtype=torch_dtype,
        device_map="auto",
    ).to(device)
    ).to(device, torch_dtype)
    qwen2vl_processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"
    )

@@ -139,9 +125,9 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
    text = qwen2vl_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)

    inputs = qwen2vl_processor(
        text=[text],
        images=image_inputs,

@@ -152,12 +138,12 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
    inputs = inputs.to(device)

    generated_ids = qwen2vl_model.generate(**inputs, max_new_tokens=(max_tokens or 512))

    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = qwen2vl_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,

@@ -170,28 +156,6 @@ async def generate_qwen2vl_result(text_input, image_input, max_tokens):
app = FastAPI()


class EmbeddingRequest(BaseModel):
    input: List[str]


class EmbeddingResponse(BaseModel):
    embeddings: List[List[float]]


@app.post("/api/embed", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    try:
        if not request.input:
            return EmbeddingResponse(embeddings=[])

        embeddings = generate_embeddings(request.input)  # Use the new method
        return EmbeddingResponse(embeddings=embeddings)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error generating embeddings: {str(e)}"
        )


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, Any]]

@@ -275,10 +239,14 @@ async def chat_completions(request: ChatCompletionRequest):
if __name__ == "__main__":
    import uvicorn

    parser = argparse.ArgumentParser(description="Run the server with specified model and port")
    parser = argparse.ArgumentParser(
        description="Run the server with specified model and port"
    )
    parser.add_argument("--florence", action="store_true", help="Use Florence-2 model")
    parser.add_argument("--qwen2vl", action="store_true", help="Use Qwen2VL model")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    args = parser.parse_args()

    if args.florence and args.qwen2vl:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "memos"
version = "0.5.0"
version = "0.6.10"
description = "A package for memos"
readme = "README.md"
authors = [{ name = "arkohut" }]

@@ -36,6 +36,13 @@ dependencies = [
    "pyobjc; sys_platform == 'darwin'",
    "pyobjc-core; sys_platform == 'darwin'",
    "pyobjc-framework-Quartz; sys_platform == 'darwin'",
    "sentence-transformers",
    "torch",
    "numpy",
    "timm",
    "einops",
    "modelscope",
    "mss",
]

[project.urls]

@@ -50,3 +57,4 @@ include = ["memos*", "screen_recorder*"]

[tool.setuptools.package-data]
"*" = ["static/**/*"]
"memos.plugins.ocr" = ["*.yaml", "models/*.onnx"]
@@ -4,11 +4,11 @@ import time
import logging
import platform
import subprocess
from PIL import Image, ImageGrab
from PIL import Image
import imagehash
from memos.utils import write_image_metadata
from screeninfo import get_monitors
import ctypes
from mss import mss

if platform.system() == "Windows":
    import win32gui

@@ -173,57 +173,49 @@ def take_screenshot_windows(
    app_name,
    window_title,
):
    for monitor in get_monitors():
        safe_monitor_name = "".join(
            c for c in monitor.name if c.isalnum() or c in ("_", "-")
        )
        logging.info(f"Processing monitor: {safe_monitor_name}")
    with mss() as sct:
        for i, monitor in enumerate(sct.monitors[1:], 1):  # Skip the first monitor (entire screen)
            safe_monitor_name = f"monitor_{i}"
            logging.info(f"Processing monitor: {safe_monitor_name}")

        webp_filename = os.path.join(
            base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp"
        )

        img = ImageGrab.grab(
            bbox=(
                monitor.x,
                monitor.y,
                monitor.x + monitor.width,
                monitor.y + monitor.height,
            webp_filename = os.path.join(
                base_dir, date, f"screenshot-{timestamp}-of-{safe_monitor_name}.webp"
            )
        )
        img = img.convert("RGB")
        current_hash = str(imagehash.phash(img))

        if (
            safe_monitor_name in previous_hashes
            and imagehash.hex_to_hash(current_hash)
            - imagehash.hex_to_hash(previous_hashes[safe_monitor_name])
            < threshold
        ):
            logging.info(
                f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping."
            img = sct.grab(monitor)
            img = Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
            current_hash = str(imagehash.phash(img))

            if (
                safe_monitor_name in previous_hashes
                and imagehash.hex_to_hash(current_hash)
                - imagehash.hex_to_hash(previous_hashes[safe_monitor_name])
                < threshold
            ):
                logging.info(
                    f"Screenshot for {safe_monitor_name} is similar to the previous one. Skipping."
            )
            yield safe_monitor_name, None, "Skipped (similar to previous)"
            continue

        previous_hashes[safe_monitor_name] = current_hash
        screen_sequences[safe_monitor_name] = (
            screen_sequences.get(safe_monitor_name, 0) + 1
        )
                yield safe_monitor_name, None, "Skipped (similar to previous)"
                continue

            previous_hashes[safe_monitor_name] = current_hash
            screen_sequences[safe_monitor_name] = (
                screen_sequences.get(safe_monitor_name, 0) + 1
            )
        metadata = {
            "timestamp": timestamp,
            "active_app": app_name,
            "active_window": window_title,
            "screen_name": safe_monitor_name,
            "sequence": screen_sequences[safe_monitor_name],
        }

            metadata = {
                "timestamp": timestamp,
                "active_app": app_name,
                "active_window": window_title,
                "screen_name": safe_monitor_name,
                "sequence": screen_sequences[safe_monitor_name],
            }
        img.save(webp_filename, format="WebP", quality=85)
        write_image_metadata(webp_filename, metadata)
        save_screen_sequences(base_dir, screen_sequences, date)

            img.save(webp_filename, format="WebP", quality=85)
            write_image_metadata(webp_filename, metadata)
            save_screen_sequences(base_dir, screen_sequences, date)

        yield safe_monitor_name, webp_filename, "Saved"
            yield safe_monitor_name, webp_filename, "Saved"


def take_screenshot(
@@ -9,6 +9,8 @@ from screen_recorder.common import (
    take_screenshot,
    is_screen_locked,
)
from pathlib import Path
from memos.config import settings

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

@@ -51,12 +53,12 @@ def main():
        "--threshold", type=int, default=4, help="Threshold for image similarity"
    )
    parser.add_argument(
        "--base-dir", type=str, default="~/tmp", help="Base directory for screenshots"
        "--base-dir", type=str, help="Base directory for screenshots"
    )
    parser.add_argument("--once", action="store_true", help="Run once and exit")
    args = parser.parse_args()

    base_dir = os.path.expanduser(args.base_dir)
    base_dir = os.path.expanduser(args.base_dir) if args.base_dir else settings.screenshots_dir
    previous_hashes = load_previous_hashes(base_dir)

    if args.once:
@@ -112,7 +112,7 @@
    id={`library-${library.id}`}
    bind:checked={selectedLibraries[library.id]}
/>
<Label for={`library-${library.id}`} class="flex items-center text-sm">{library.name}</Label>
<Label for={`library-${library.id}`} class="flex items-center text-sm">{library.name}#{library.id}</Label>
</div>
{/each}
</div>