feat(scan): support batching check file exists

This commit is contained in:
arkohut 2024-07-26 18:34:59 +08:00
parent 4a90d86c16
commit badbfd70bc
2 changed files with 31 additions and 9 deletions

View File

@ -102,11 +102,11 @@ def get_entities_of_folder(
return [], 0 return [], 0
query = db.query(EntityModel).filter(EntityModel.folder_id == folder_id) query = db.query(EntityModel).filter(EntityModel.folder_id == folder_id)
total_count = query.count() total_count = query.count()
entities = query.limit(limit).offset(offset).all() entities = query.limit(limit).offset(offset).all()
return entities, total_count return entities, total_count
@ -114,6 +114,10 @@ def get_entity_by_filepath(filepath: str, db: Session) -> Entity | None:
return db.query(EntityModel).filter(EntityModel.filepath == filepath).first() return db.query(EntityModel).filter(EntityModel.filepath == filepath).first()
def get_entities_by_filepaths(filepaths: List[str], db: Session) -> List[Entity]:
return db.query(EntityModel).filter(EntityModel.filepath.in_(filepaths)).all()
def remove_entity(entity_id: int, db: Session): def remove_entity(entity_id: int, db: Session):
entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first() entity = db.query(EntityModel).filter(EntityModel.id == entity_id).first()
if entity: if entity:

View File

@ -43,8 +43,7 @@ from .logging_config import LOGGING_CONFIG
# Configure logging to include datetime # Configure logging to include datetime
logging.basicConfig( logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
level=logging.INFO
) )
engine = create_engine(f"sqlite:///{get_database_path()}") engine = create_engine(f"sqlite:///{get_database_path()}")
@ -238,10 +237,11 @@ def list_entities_in_folder(
detail="Folder not found in the specified library", detail="Folder not found in the specified library",
) )
entities, total_count = crud.get_entities_of_folder(library_id, folder_id, db, limit, offset) entities, total_count = crud.get_entities_of_folder(
library_id, folder_id, db, limit, offset
)
return JSONResponse( return JSONResponse(
content=jsonable_encoder(entities), content=jsonable_encoder(entities), headers={"X-Total-Count": str(total_count)}
headers={"X-Total-Count": str(total_count)}
) )
@ -261,6 +261,18 @@ def get_entity_by_filepath(
return entity return entity
@app.post(
"/libraries/{library_id}/entities/by-filepaths",
response_model=List[Entity],
tags=["entity"],
)
def get_entities_by_filepaths(
library_id: int, filepaths: List[str], db: Session = Depends(get_db)
):
entities = crud.get_entities_by_filepaths(filepaths, db)
return [entity for entity in entities if entity.library_id == library_id]
@app.get("/entities/{entity_id}", response_model=Entity, tags=["entity"]) @app.get("/entities/{entity_id}", response_model=Entity, tags=["entity"])
def get_entity_by_id(entity_id: int, db: Session = Depends(get_db)): def get_entity_by_id(entity_id: int, db: Session = Depends(get_db)):
entity = crud.get_entity_by_id(entity_id, db) entity = crud.get_entity_by_id(entity_id, db)
@ -539,4 +551,10 @@ def run_server():
print( print(
f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}" f"Typesense connection info: Host: {settings.typesense_host}, Port: {settings.typesense_port}, Protocol: {settings.typesense_protocol}"
) )
uvicorn.run("memos.server:app", host="0.0.0.0", port=8080, reload=False, log_config=LOGGING_CONFIG) uvicorn.run(
"memos.server:app",
host="0.0.0.0",
port=8080,
reload=False,
log_config=LOGGING_CONFIG,
)