mirror of
https://github.com/tcsenpai/screenshot-timeline.git
synced 2025-06-06 03:05:29 +00:00
- Multiple OCR passes
- Better preprocessing
This commit is contained in:
parent
90bd159a77
commit
77fdb14395
166
main.py
166
main.py
@ -18,6 +18,8 @@ import argparse
|
|||||||
from itertools import islice
|
from itertools import islice
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import dbus
|
||||||
|
import re
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
SCREENSHOT_INTERVAL = 5 * 60 # 5 minutes
|
SCREENSHOT_INTERVAL = 5 * 60 # 5 minutes
|
||||||
@ -43,7 +45,7 @@ app.config['CELERYD_CONCURRENCY'] = 2 # Limit to 2 concurrent workers
|
|||||||
celery = Celery(app.name, broker=app.config['CELERY_BROKER_URL'])
|
celery = Celery(app.name, broker=app.config['CELERY_BROKER_URL'])
|
||||||
celery.conf.update(app.config)
|
celery.conf.update(app.config)
|
||||||
|
|
||||||
# Database initialization
|
# Database initialization function
|
||||||
def init_db():
|
def init_db():
|
||||||
with sqlite3.connect(DATABASE) as conn:
|
with sqlite3.connect(DATABASE) as conn:
|
||||||
conn.execute('''CREATE TABLE IF NOT EXISTS screenshots
|
conn.execute('''CREATE TABLE IF NOT EXISTS screenshots
|
||||||
@ -55,6 +57,7 @@ def init_db():
|
|||||||
|
|
||||||
init_db()
|
init_db()
|
||||||
|
|
||||||
|
# Function to ensure NLTK data is downloaded
|
||||||
def ensure_nltk_data():
|
def ensure_nltk_data():
|
||||||
try:
|
try:
|
||||||
_create_unverified_https_context = ssl._create_unverified_https_context
|
_create_unverified_https_context = ssl._create_unverified_https_context
|
||||||
@ -69,6 +72,7 @@ def ensure_nltk_data():
|
|||||||
|
|
||||||
ensure_nltk_data()
|
ensure_nltk_data()
|
||||||
|
|
||||||
|
# Function to get existing words from the database
|
||||||
def get_existing_words():
|
def get_existing_words():
|
||||||
with sqlite3.connect(DATABASE) as conn:
|
with sqlite3.connect(DATABASE) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
@ -80,6 +84,7 @@ def get_existing_words():
|
|||||||
existing_words.update(json.loads(tags[0]))
|
existing_words.update(json.loads(tags[0]))
|
||||||
return existing_words
|
return existing_words
|
||||||
|
|
||||||
|
# Function to generate tags from OCR text
|
||||||
def generate_tags(ocr_text):
|
def generate_tags(ocr_text):
|
||||||
tokens = word_tokenize(ocr_text.lower())
|
tokens = word_tokenize(ocr_text.lower())
|
||||||
stop_words = set(stopwords.words('english'))
|
stop_words = set(stopwords.words('english'))
|
||||||
@ -90,21 +95,18 @@ def generate_tags(ocr_text):
|
|||||||
# Global variable to store the OCR engine
|
# Global variable to store the OCR engine
|
||||||
ocr_engine = None
|
ocr_engine = None
|
||||||
|
|
||||||
|
# Function to initialize the OCR engine
|
||||||
def initialize_ocr_engine():
    """Point the module-level ``ocr_engine`` at the pytesseract module.

    Rebinds the global ``ocr_engine`` and prints a confirmation line.
    Takes no arguments and returns nothing.
    """
    global ocr_engine

    # The pytesseract module itself is stored as the engine handle.
    ocr_engine = pytesseract

    print("Tesseract OCR initialized successfully.")
|
||||||
|
|
||||||
@signals.worker_process_init.connect
|
# Celery task to process screenshots
|
||||||
def init_worker(**kwargs):
|
|
||||||
global ocr_engine
|
|
||||||
initialize_ocr_engine()
|
|
||||||
|
|
||||||
@celery.task
|
@celery.task
|
||||||
def process_screenshot(image_path):
|
def process_screenshot(image_path):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Performing OCR on {image_path}")
|
logger.info(f"Performing OCR on {image_path}")
|
||||||
ocr_text = perform_ocr(image_path)
|
ocr_text = multi_pass_ocr(image_path)
|
||||||
tags = generate_tags(ocr_text)
|
tags = generate_tags(ocr_text)
|
||||||
logger.info(f"OCR completed for {image_path}")
|
logger.info(f"OCR completed for {image_path}")
|
||||||
|
|
||||||
@ -118,80 +120,115 @@ def process_screenshot(image_path):
|
|||||||
logger.error(f"Error performing OCR on {image_path}: {e}")
|
logger.error(f"Error performing OCR on {image_path}: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# Function to preprocess the image for OCR
|
||||||
def preprocess_image(image_path):
    """Load an image from disk and prepare it for OCR.

    The pipeline is: grayscale conversion, adaptive Gaussian thresholding,
    non-local-means denoising, then a 2x bicubic upscale.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The processed image array produced by ``cv2.resize``.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of inside cvtColor.
        raise ValueError(f"Could not read image: {image_path}")

    # Convert to grayscale
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Apply adaptive thresholding (block size 11, constant 2)
    thresh = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
    )

    # Denoise
    denoised = cv2.fastNlMeansDenoising(thresh, None, 10, 7, 21)

    # Increase resolution
    return cv2.resize(denoised, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
|
||||||
|
|
||||||
# Apply thresholding to preprocess the image
|
# Function to isolate text regions in the image
|
||||||
gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
|
|
||||||
|
|
||||||
# Apply dilation and erosion to remove some noise
|
|
||||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3,3))
|
|
||||||
gray = cv2.dilate(gray, kernel, iterations=1)
|
|
||||||
gray = cv2.erode(gray, kernel, iterations=1)
|
|
||||||
|
|
||||||
# Apply median blur to remove noise
|
|
||||||
gray = cv2.medianBlur(gray, 3)
|
|
||||||
|
|
||||||
# Scale the image
|
|
||||||
gray = cv2.resize(gray, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
|
|
||||||
|
|
||||||
return gray
|
|
||||||
|
|
||||||
def fast_isolate_text_regions(img):
    """Mask out everything in ``img`` except rectangles that look like text.

    Canny edges are dilated so nearby strokes merge, the external contours
    of the result are boxed, and boxes passing the area / aspect-ratio
    filter are filled into a mask that is AND-ed with the input.

    Args:
        img: Image array to filter.

    Returns:
        ``img`` with pixels outside the accepted boxes set to zero.
    """
    edge_map = cv2.Canny(img, 100, 200)

    # Thicken the edges so that neighbouring characters join into blobs.
    structuring = cv2.getStructuringElement(cv2.MORPH_RECT, (5,5))
    thick_edges = cv2.dilate(edge_map, structuring, iterations=3)

    outlines, _ = cv2.findContours(thick_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    keep_mask = np.zeros(img.shape, dtype=np.uint8)
    for outline in outlines:
        left, top, width, height = cv2.boundingRect(outline)
        box_area = width * height
        box_ratio = width / float(height)

        # Skip boxes that are too small, too large, or too elongated.
        if not (100 < box_area < 50000):
            continue
        if not (0.1 < box_ratio < 10):
            continue

        cv2.rectangle(
            keep_mask,
            (left, top),
            (left + width, top + height),
            (255, 255, 255),
            -1,
        )

    return cv2.bitwise_and(img, keep_mask)
|
||||||
|
|
||||||
def perform_ocr(image_path):
|
# Function to perform OCR on an image
|
||||||
# Preprocess the image
|
def perform_ocr(image):
    """Run Tesseract on an in-memory image and return the cleaned text.

    Args:
        image: Image array. Three-channel input is assumed to be in
            OpenCV's BGR order (as returned by ``cv2.imread``) —
            NOTE(review): confirm no caller passes RGB arrays.

    Returns:
        The recognised text after ``clean_ocr_text`` has been applied.
    """
    custom_config = r'--oem 3 --psm 6 -l eng'  # Assume English language

    # PIL interprets 3-channel arrays as RGB, while OpenCV produces BGR;
    # convert so the red and blue channels are not swapped before OCR.
    if getattr(image, "ndim", 2) == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    text = pytesseract.image_to_string(Image.fromarray(image), config=custom_config)
    return clean_ocr_text(text)
|
||||||
|
|
||||||
# Isolate text regions
|
# Function to clean OCR text
|
||||||
text_regions = fast_isolate_text_regions(preprocessed)
|
def clean_ocr_text(text):
    """Normalise raw OCR output into a single line of space-separated words.

    Steps, in order:
      1. Collapse every run of whitespace (including newlines and tabs)
         into one space.
      2. Drop remaining non-printable characters.
      3. Remove isolated single word characters (often OCR errors).
      4. Collapse whitespace again and trim the ends.

    Args:
        text: Raw text, e.g. as returned by Tesseract.

    Returns:
        The cleaned text; an empty string if nothing survives.
    """
    # Whitespace must be normalised BEFORE the printable filter: newlines
    # and tabs are not printable, so filtering first would glue the last
    # word of one line onto the first word of the next.
    text = ' '.join(text.split())

    # Remove non-printable characters
    text = ''.join(char for char in text if char.isprintable())

    # Remove single characters (often errors)
    text = re.sub(r'\b\w\b', '', text)

    # Removing tokens leaves doubled / leading / trailing spaces behind.
    return ' '.join(text.split())
|
||||||
finally:
|
|
||||||
# Clean up the temporary file
|
|
||||||
os.remove(temp_file)
|
|
||||||
|
|
||||||
# Screenshot function
|
# Function to perform multiple OCR passes
|
||||||
|
def multi_pass_ocr(image_path):
    """OCR one image several ways and merge the recognised words.

    Four variants are run through ``perform_ocr``: the image as loaded,
    the output of ``preprocess_image``, a grayscale copy, and an
    Otsu-thresholded copy. The words from all passes are de-duplicated
    and the result is passed through ``clean_ocr_text``.

    Args:
        image_path: Path of the image file to process.

    Returns:
        A single string of unique words, in first-seen order.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None on failure; stop with a clear error
        # instead of failing later inside the OCR call.
        raise ValueError(f"Could not read image: {image_path}")

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Pass order: original, preprocessed, grayscale, thresholded.
    variants = (img, preprocess_image(image_path), gray, thresh)
    results = [perform_ocr(variant) for variant in variants]

    # dict.fromkeys de-duplicates while keeping first-seen order; a set
    # would scramble the words and vary between interpreter runs.
    unique_words = dict.fromkeys(' '.join(results).split())
    combined = ' '.join(unique_words)
    return clean_ocr_text(combined)
|
||||||
|
|
||||||
|
# Function to get the session type (X11 or Wayland)
|
||||||
|
def get_session_type():
    """Return the desktop session type taken from ``XDG_SESSION_TYPE``.

    The variable is read directly from the process environment rather
    than by spawning a shell to ``echo`` it. The value is stripped and
    lower-cased.

    Returns:
        The session type string (for example ``"x11"`` or ``"wayland"``).
        If the variable is unset or blank, a warning is logged and
        ``"x11"`` is returned.
    """
    session_type = os.environ.get("XDG_SESSION_TYPE", "").strip().lower()
    if not session_type:
        # The old shell "echo" never failed, so an unset variable produced
        # "" and this documented fallback was never actually taken.
        logger.warning("Failed to determine session type. Assuming X11.")
        return "x11"
    return session_type
|
||||||
|
|
||||||
|
# Function to check if the system is idle
|
||||||
|
def is_system_idle(idle_time_minutes=5):
    """Report whether the session has been idle longer than a threshold.

    On X11 the idle time is read from the ``xprintidle`` command; on
    Wayland it is requested from the ``org.freedesktop.ScreenSaver``
    D-Bus service. Any failure to obtain the idle time is logged and
    treated as "not idle".

    Args:
        idle_time_minutes: Idle threshold in minutes.

    Returns:
        True if the reported idle time exceeds the threshold, otherwise
        False (including on errors and unknown session types).
    """
    print("Checking if system is idle")
    idle_time_ms = idle_time_minutes * 60 * 1000
    session_type = get_session_type()

    if session_type == "x11":
        try:
            idle_time = int(subprocess.check_output(['xprintidle']).decode().strip())
        except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
            # ValueError covers output that is not a plain integer, which
            # previously escaped this handler and propagated to the caller.
            logger.error("Failed to check idle state with xprintidle. Make sure it's installed.")
            return False
        return idle_time > idle_time_ms

    if session_type == "wayland":
        try:
            bus = dbus.SessionBus()
            screensaver_proxy = bus.get_object('org.freedesktop.ScreenSaver', '/org/freedesktop/ScreenSaver')
            screensaver_interface = dbus.Interface(screensaver_proxy, dbus_interface='org.freedesktop.ScreenSaver')
            # NOTE(review): compared against a millisecond threshold —
            # confirm the unit GetSessionIdleTime reports on the target desktop.
            idle_time = screensaver_interface.GetSessionIdleTime()
        except dbus.exceptions.DBusException as e:
            logger.error(f"Failed to check idle state via DBus: {e}")
            return False
        return idle_time > idle_time_ms

    logger.warning(f"Unknown session type: {session_type}. Assuming not idle.")
    return False
|
||||||
|
|
||||||
|
# Function to take screenshots
|
||||||
def take_screenshot():
|
def take_screenshot():
|
||||||
while True:
|
while True:
|
||||||
|
if not is_system_idle(): # Only take screenshot if system is not idle
|
||||||
|
print("System is not idle, taking screenshot")
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
filename = os.path.abspath(f"{SCREENSHOT_DIR}/screenshot_{timestamp}.png")
|
filename = os.path.abspath(f"{SCREENSHOT_DIR}/screenshot_{timestamp}.png")
|
||||||
try:
|
try:
|
||||||
@ -213,9 +250,11 @@ def take_screenshot():
|
|||||||
process_screenshot.delay(filename)
|
process_screenshot.delay(filename)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
logger.error(f"Error taking screenshot: {e}")
|
logger.error(f"Error taking screenshot: {e}")
|
||||||
|
else:
|
||||||
|
logger.info("System is idle, skipping screenshot")
|
||||||
time.sleep(SCREENSHOT_INTERVAL)
|
time.sleep(SCREENSHOT_INTERVAL)
|
||||||
|
|
||||||
# Web routes
|
# Flask route for the main page
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index():
|
||||||
with sqlite3.connect(DATABASE) as conn:
|
with sqlite3.connect(DATABASE) as conn:
|
||||||
@ -225,6 +264,7 @@ def index():
|
|||||||
screenshots = [{'filename': os.path.basename(s[0]), 'timestamp': s[1], 'formatted_timestamp': format_timestamp(s[1]), 'ocr_status': bool(s[2]), 'tags': json.loads(s[3]) if s[3] else []} for s in screenshots]
|
screenshots = [{'filename': os.path.basename(s[0]), 'timestamp': s[1], 'formatted_timestamp': format_timestamp(s[1]), 'ocr_status': bool(s[2]), 'tags': json.loads(s[3]) if s[3] else []} for s in screenshots]
|
||||||
return render_template('index.html', screenshots=screenshots)
|
return render_template('index.html', screenshots=screenshots)
|
||||||
|
|
||||||
|
# Flask route for search functionality
|
||||||
@app.route('/search', methods=['POST'])
|
@app.route('/search', methods=['POST'])
|
||||||
def search():
|
def search():
|
||||||
query = request.form.get('query', '').lower()
|
query = request.form.get('query', '').lower()
|
||||||
@ -234,6 +274,7 @@ def search():
|
|||||||
results = cur.fetchall()
|
results = cur.fetchall()
|
||||||
return jsonify([{'filename': os.path.basename(r[0]), 'timestamp': r[1], 'formatted_timestamp': format_timestamp(r[1]), 'ocr_status': bool(r[2]), 'tags': json.loads(r[3]) if r[3] else []} for r in results])
|
return jsonify([{'filename': os.path.basename(r[0]), 'timestamp': r[1], 'formatted_timestamp': format_timestamp(r[1]), 'ocr_status': bool(r[2]), 'tags': json.loads(r[3]) if r[3] else []} for r in results])
|
||||||
|
|
||||||
|
# Function to batch process screenshots
|
||||||
def batch_process_screenshots(batch_size=5):
|
def batch_process_screenshots(batch_size=5):
|
||||||
with sqlite3.connect(DATABASE) as conn:
|
with sqlite3.connect(DATABASE) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
@ -247,11 +288,13 @@ def batch_process_screenshots(batch_size=5):
|
|||||||
for batch in chunks(screenshots, batch_size):
|
for batch in chunks(screenshots, batch_size):
|
||||||
group(process_screenshot.s(screenshot[0]) for screenshot in batch)().get()
|
group(process_screenshot.s(screenshot[0]) for screenshot in batch)().get()
|
||||||
|
|
||||||
|
# Flask route to trigger OCR for all unprocessed images
|
||||||
@app.route('/ocr-all', methods=['POST'])
|
@app.route('/ocr-all', methods=['POST'])
|
||||||
def ocr_all():
|
def ocr_all():
|
||||||
batch_process_screenshots.delay()
|
batch_process_screenshots.delay()
|
||||||
return jsonify({"message": "OCR started for all unprocessed images in batches."})
|
return jsonify({"message": "OCR started for all unprocessed images in batches."})
|
||||||
|
|
||||||
|
# Celery task for batch processing screenshots
|
||||||
@celery.task
|
@celery.task
|
||||||
def batch_process_screenshots(batch_size=5):
|
def batch_process_screenshots(batch_size=5):
|
||||||
with sqlite3.connect(DATABASE) as conn:
|
with sqlite3.connect(DATABASE) as conn:
|
||||||
@ -266,6 +309,7 @@ def batch_process_screenshots(batch_size=5):
|
|||||||
for batch in chunks(screenshots, batch_size):
|
for batch in chunks(screenshots, batch_size):
|
||||||
group(process_screenshot.s(screenshot[0]) for screenshot in batch)().get()
|
group(process_screenshot.s(screenshot[0]) for screenshot in batch)().get()
|
||||||
|
|
||||||
|
# Flask route to delete all screenshots
|
||||||
@app.route('/delete-all', methods=['POST'])
|
@app.route('/delete-all', methods=['POST'])
|
||||||
def delete_all():
|
def delete_all():
|
||||||
with sqlite3.connect(DATABASE) as conn:
|
with sqlite3.connect(DATABASE) as conn:
|
||||||
@ -277,6 +321,7 @@ def delete_all():
|
|||||||
cur.execute("DELETE FROM screenshots")
|
cur.execute("DELETE FROM screenshots")
|
||||||
return jsonify({"message": "All screenshots deleted."})
|
return jsonify({"message": "All screenshots deleted."})
|
||||||
|
|
||||||
|
# Flask route to set screenshot interval
|
||||||
@app.route('/set-interval', methods=['POST'])
|
@app.route('/set-interval', methods=['POST'])
|
||||||
def set_interval():
|
def set_interval():
|
||||||
interval = request.form.get('interval', type=int)
|
interval = request.form.get('interval', type=int)
|
||||||
@ -286,6 +331,7 @@ def set_interval():
|
|||||||
return jsonify({"message": f"Screenshot interval set to {interval} seconds."})
|
return jsonify({"message": f"Screenshot interval set to {interval} seconds."})
|
||||||
return jsonify({"message": "Invalid interval."})
|
return jsonify({"message": "Invalid interval."})
|
||||||
|
|
||||||
|
# Flask route for status updates
|
||||||
@app.route('/status-updates')
|
@app.route('/status-updates')
|
||||||
def status_updates():
|
def status_updates():
|
||||||
def generate():
|
def generate():
|
||||||
@ -311,6 +357,7 @@ def status_updates():
|
|||||||
|
|
||||||
return Response(generate(), mimetype='text/event-stream')
|
return Response(generate(), mimetype='text/event-stream')
|
||||||
|
|
||||||
|
# Flask route to delete all screenshots and reset the database
|
||||||
@app.route('/delete-all-and-reset-db', methods=['POST'])
|
@app.route('/delete-all-and-reset-db', methods=['POST'])
|
||||||
def delete_all_and_reset_db():
|
def delete_all_and_reset_db():
|
||||||
try:
|
try:
|
||||||
@ -329,6 +376,7 @@ def delete_all_and_reset_db():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"message": f"An error occurred: {str(e)}"}), 500
|
return jsonify({"message": f"An error occurred: {str(e)}"}), 500
|
||||||
|
|
||||||
|
# Flask route to filter screenshots by tag
|
||||||
@app.route('/filter-by-tag/<tag>')
|
@app.route('/filter-by-tag/<tag>')
|
||||||
def filter_by_tag(tag):
|
def filter_by_tag(tag):
|
||||||
with sqlite3.connect(DATABASE) as conn:
|
with sqlite3.connect(DATABASE) as conn:
|
||||||
@ -343,6 +391,7 @@ def filter_by_tag(tag):
|
|||||||
'tags': json.loads(s[4]) if s[4] else []
|
'tags': json.loads(s[4]) if s[4] else []
|
||||||
} for s in screenshots])
|
} for s in screenshots])
|
||||||
|
|
||||||
|
# Flask route to get all unique tags
|
||||||
@app.route('/get-all-tags')
|
@app.route('/get-all-tags')
|
||||||
def get_all_tags():
|
def get_all_tags():
|
||||||
with sqlite3.connect(DATABASE) as conn:
|
with sqlite3.connect(DATABASE) as conn:
|
||||||
@ -355,6 +404,7 @@ def get_all_tags():
|
|||||||
unique_tags.update(json.loads(tags[0]))
|
unique_tags.update(json.loads(tags[0]))
|
||||||
return jsonify(list(unique_tags))
|
return jsonify(list(unique_tags))
|
||||||
|
|
||||||
|
# Flask route to update tags for a screenshot
|
||||||
@app.route('/update_tags', methods=['POST'])
|
@app.route('/update_tags', methods=['POST'])
|
||||||
def update_tags():
|
def update_tags():
|
||||||
data = request.json
|
data = request.json
|
||||||
@ -373,6 +423,7 @@ def update_tags():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"success": False, "message": str(e)}), 500
|
return jsonify({"success": False, "message": str(e)}), 500
|
||||||
|
|
||||||
|
# Flask route to get information about a specific screenshot
|
||||||
@app.route('/get_screenshot_info/<filename>')
|
@app.route('/get_screenshot_info/<filename>')
|
||||||
def get_screenshot_info(filename):
|
def get_screenshot_info(filename):
|
||||||
with sqlite3.connect(DATABASE) as conn:
|
with sqlite3.connect(DATABASE) as conn:
|
||||||
@ -392,6 +443,7 @@ def get_screenshot_info(filename):
|
|||||||
else:
|
else:
|
||||||
return jsonify({"error": "Screenshot not found"}), 404
|
return jsonify({"error": "Screenshot not found"}), 404
|
||||||
|
|
||||||
|
# Function to format timestamp
|
||||||
def format_timestamp(timestamp):
|
def format_timestamp(timestamp):
|
||||||
try:
|
try:
|
||||||
dt = datetime.strptime(timestamp, "%Y%m%d_%H%M%S")
|
dt = datetime.strptime(timestamp, "%Y%m%d_%H%M%S")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user