From 77fdb14395d794d68e2f4a1b69063de2206bdd06 Mon Sep 17 00:00:00 2001 From: tcsenpai Date: Thu, 22 Aug 2024 22:45:59 +0200 Subject: [PATCH] - Multiple OCR passes - Better preprocessing --- main.py | 212 +++++++++++------- .../screenshots_will_be_added_here | 0 2 files changed, 132 insertions(+), 80 deletions(-) delete mode 100644 static/screenshots/screenshots_will_be_added_here diff --git a/main.py b/main.py index eab0669..62c9883 100644 --- a/main.py +++ b/main.py @@ -18,6 +18,8 @@ import argparse from itertools import islice import cv2 import numpy as np +import dbus +import re # Configuration SCREENSHOT_INTERVAL = 5 * 60 # 5 minutes @@ -43,7 +45,7 @@ app.config['CELERYD_CONCURRENCY'] = 2 # Limit to 2 concurrent workers celery = Celery(app.name, broker=app.config['CELERY_BROKER_URL']) celery.conf.update(app.config) -# Database initialization +# Database initialization function def init_db(): with sqlite3.connect(DATABASE) as conn: conn.execute('''CREATE TABLE IF NOT EXISTS screenshots @@ -55,6 +57,7 @@ def init_db(): init_db() +# Function to ensure NLTK data is downloaded def ensure_nltk_data(): try: _create_unverified_https_context = ssl._create_unverified_https_context @@ -69,6 +72,7 @@ def ensure_nltk_data(): ensure_nltk_data() +# Function to get existing words from the database def get_existing_words(): with sqlite3.connect(DATABASE) as conn: cur = conn.cursor() @@ -80,6 +84,7 @@ def get_existing_words(): existing_words.update(json.loads(tags[0])) return existing_words +# Function to generate tags from OCR text def generate_tags(ocr_text): tokens = word_tokenize(ocr_text.lower()) stop_words = set(stopwords.words('english')) @@ -90,21 +95,18 @@ def generate_tags(ocr_text): # Global variable to store the OCR engine ocr_engine = None +# Function to initialize the OCR engine def initialize_ocr_engine(): global ocr_engine ocr_engine = pytesseract print("Tesseract OCR initialized successfully.") -@signals.worker_process_init.connect -def init_worker(**kwargs): - global ocr_engine - initialize_ocr_engine() - +# Celery task to process screenshots @celery.task def process_screenshot(image_path): try: logger.info(f"Performing OCR on {image_path}") - ocr_text = perform_ocr(image_path) + ocr_text = multi_pass_ocr(image_path) tags = generate_tags(ocr_text) logger.info(f"OCR completed for {image_path}") @@ -118,104 +120,141 @@ def process_screenshot(image_path): logger.error(f"Error performing OCR on {image_path}: {e}") return "" +# Function to preprocess the image for OCR def preprocess_image(image_path): - # Read the image img = cv2.imread(image_path) - # Convert to grayscale gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - - # Apply thresholding to preprocess the image - gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1] - - # Apply dilation and erosion to remove some noise - kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3,3)) - gray = cv2.dilate(gray, kernel, iterations=1) - gray = cv2.erode(gray, kernel, iterations=1) - - # Apply median blur to remove noise - gray = cv2.medianBlur(gray, 3) - - # Scale the image - gray = cv2.resize(gray, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC) - - return gray + # Apply adaptive thresholding + thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2) + # Denoise + denoised = cv2.fastNlMeansDenoising(thresh, None, 10, 7, 21) + # Increase resolution + scaled = cv2.resize(denoised, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC) + return scaled +# Function to isolate text regions in the image def fast_isolate_text_regions(img): - # Edge detection edges = cv2.Canny(img, 100, 200) - - # Dilate edges to connect text regions kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5,5)) dilated = cv2.dilate(edges, kernel, iterations=3) - - # Find contours contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - - # Create mask mask = np.zeros(img.shape, dtype=np.uint8) - for contour in contours: x, y, w, h = cv2.boundingRect(contour) area = w * h aspect_ratio = w / float(h) - - # Filter contours based on area and aspect ratio if 100 < area < 50000 and 0.1 < aspect_ratio < 10: cv2.rectangle(mask, (x, y), (x + w, y + h), (255, 255, 255), -1) - - # Apply the mask to the original image result = cv2.bitwise_and(img, mask) - return result -def perform_ocr(image_path): - # Preprocess the image - preprocessed = preprocess_image(image_path) - - # Isolate text regions - text_regions = fast_isolate_text_regions(preprocessed) - - # Save the preprocessed image temporarily - temp_file = f"temp_{os.getpid()}.png" - cv2.imwrite(temp_file, text_regions) - - try: - # Perform OCR on the preprocessed image - custom_config = r'--oem 3 --psm 6' - text = pytesseract.image_to_string(Image.open(temp_file), config=custom_config) - return text - finally: - # Clean up the temporary file - os.remove(temp_file) +# Function to perform OCR on an image +def perform_ocr(image): + custom_config = r'--oem 3 --psm 6 -l eng' # Assume English language + text = pytesseract.image_to_string(Image.fromarray(image), config=custom_config) + return clean_ocr_text(text) -# Screenshot function +# Function to clean OCR text +def clean_ocr_text(text): + # Remove non-printable characters + text = ''.join(char for char in text if char.isprintable()) + # Remove extra whitespace + text = ' '.join(text.split()) + # Remove single characters (often errors) + text = re.sub(r'\b\w\b', '', text) + return text + +# Function to perform multiple OCR passes +def multi_pass_ocr(image_path): + img = cv2.imread(image_path) + results = [] + + # Original image + results.append(perform_ocr(img)) + + # Preprocessed image + preprocessed = preprocess_image(image_path) + results.append(perform_ocr(preprocessed)) + + # Grayscale + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + results.append(perform_ocr(gray)) + + # Thresholded + _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + results.append(perform_ocr(thresh)) + + # Combine results + combined = ' '.join(set(' '.join(results).split())) + return clean_ocr_text(combined) + +# Function to get the session type (X11 or Wayland) +def get_session_type(): + try: + return subprocess.check_output(["echo $XDG_SESSION_TYPE"], shell=True).decode().strip() + except subprocess.CalledProcessError: + logger.warning("Failed to determine session type. Assuming X11.") + return "x11" + +# Function to check if the system is idle +def is_system_idle(idle_time_minutes=5): + print("Checking if system is idle") + idle_time_ms = idle_time_minutes * 60 * 1000 + session_type = get_session_type() + + if session_type == "x11": + try: + idle_time = int(subprocess.check_output(['xprintidle']).decode().strip()) + return idle_time > idle_time_ms + except (subprocess.CalledProcessError, FileNotFoundError): + logger.error("Failed to check idle state with xprintidle. Make sure it's installed.") + return False + elif session_type == "wayland": + try: + bus = dbus.SessionBus() + screensaver_proxy = bus.get_object('org.freedesktop.ScreenSaver', '/org/freedesktop/ScreenSaver') + screensaver_interface = dbus.Interface(screensaver_proxy, dbus_interface='org.freedesktop.ScreenSaver') + idle_time = screensaver_interface.GetSessionIdleTime() + return idle_time > idle_time_ms + except dbus.exceptions.DBusException as e: + logger.error(f"Failed to check idle state via DBus: {e}") + return False + else: + logger.warning(f"Unknown session type: {session_type}. Assuming not idle.") + return False + +# Function to take screenshots def take_screenshot(): while True: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = os.path.abspath(f"{SCREENSHOT_DIR}/screenshot_{timestamp}.png") - try: - subprocess.run([ - "spectacle", - "-b", # background mode - "-n", # no notification - "-o", filename, # output file - "-f" # full screen - ], check=True) - logger.info(f"Screenshot saved: {filename}") - - # Store screenshot info in database - with sqlite3.connect(DATABASE) as conn: - conn.execute("INSERT INTO screenshots (filename, timestamp) VALUES (?, ?)", - (filename, timestamp)) - - # Trigger async OCR task - process_screenshot.delay(filename) - except subprocess.CalledProcessError as e: - logger.error(f"Error taking screenshot: {e}") + if not is_system_idle(): # Only take screenshot if system is not idle + print("System is not idle, taking screenshot") + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = os.path.abspath(f"{SCREENSHOT_DIR}/screenshot_{timestamp}.png") + try: + subprocess.run([ + "spectacle", + "-b", # background mode + "-n", # no notification + "-o", filename, # output file + "-f" # full screen + ], check=True) + logger.info(f"Screenshot saved: {filename}") + + # Store screenshot info in database + with sqlite3.connect(DATABASE) as conn: + conn.execute("INSERT INTO screenshots (filename, timestamp) VALUES (?, ?)", + (filename, timestamp)) + + # Trigger async OCR task + process_screenshot.delay(filename) + except subprocess.CalledProcessError as e: + logger.error(f"Error taking screenshot: {e}") + else: + logger.info("System is idle, skipping screenshot") time.sleep(SCREENSHOT_INTERVAL) -# Web routes +# Flask route for the main page @app.route('/') def index(): with sqlite3.connect(DATABASE) as conn: @@ -225,6 +264,7 @@ def index(): screenshots = [{'filename': os.path.basename(s[0]), 'timestamp': s[1], 'formatted_timestamp': format_timestamp(s[1]), 'ocr_status': bool(s[2]), 'tags': json.loads(s[3]) if s[3] else []} for s in screenshots] return render_template('index.html', screenshots=screenshots) +# Flask route for search functionality @app.route('/search', methods=['POST']) def search(): query = request.form.get('query', '').lower() @@ -234,6 +274,7 @@ def search(): results = cur.fetchall() return jsonify([{'filename': os.path.basename(r[0]), 'timestamp': r[1], 'formatted_timestamp': format_timestamp(r[1]), 'ocr_status': bool(r[2]), 'tags': json.loads(r[3]) if r[3] else []} for r in results]) +# Function to batch process screenshots def batch_process_screenshots(batch_size=5): with sqlite3.connect(DATABASE) as conn: cur = conn.cursor() @@ -247,11 +288,13 @@ def batch_process_screenshots(batch_size=5): for batch in chunks(screenshots, batch_size): group(process_screenshot.s(screenshot[0]) for screenshot in batch)().get() +# Flask route to trigger OCR for all unprocessed images @app.route('/ocr-all', methods=['POST']) def ocr_all(): batch_process_screenshots.delay() return jsonify({"message": "OCR started for all unprocessed images in batches."}) +# Celery task for batch processing screenshots @celery.task def batch_process_screenshots(batch_size=5): with sqlite3.connect(DATABASE) as conn: @@ -266,6 +309,7 @@ def batch_process_screenshots(batch_size=5): for batch in chunks(screenshots, batch_size): group(process_screenshot.s(screenshot[0]) for screenshot in batch)().get() +# Flask route to delete all screenshots @app.route('/delete-all', methods=['POST']) def delete_all(): with sqlite3.connect(DATABASE) as conn: @@ -277,6 +321,7 @@ def delete_all(): cur.execute("DELETE FROM screenshots") return jsonify({"message": "All screenshots deleted."}) +# Flask route to set screenshot interval @app.route('/set-interval', methods=['POST']) def set_interval(): interval = request.form.get('interval', type=int) @@ -286,6 +331,7 @@ def set_interval(): return jsonify({"message": f"Screenshot interval set to {interval} seconds."}) return jsonify({"message": "Invalid interval."}) +# Flask route for status updates @app.route('/status-updates') def status_updates(): def generate(): @@ -311,6 +357,7 @@ def status_updates(): return Response(generate(), mimetype='text/event-stream') +# Flask route to delete all screenshots and reset the database @app.route('/delete-all-and-reset-db', methods=['POST']) def delete_all_and_reset_db(): try: @@ -329,6 +376,7 @@ def delete_all_and_reset_db(): except Exception as e: return jsonify({"message": f"An error occurred: {str(e)}"}), 500 +# Flask route to filter screenshots by tag @app.route('/filter-by-tag/') def filter_by_tag(tag): with sqlite3.connect(DATABASE) as conn: @@ -343,6 +391,7 @@ def filter_by_tag(tag): 'tags': json.loads(s[4]) if s[4] else [] } for s in screenshots]) +# Flask route to get all unique tags @app.route('/get-all-tags') def get_all_tags(): with sqlite3.connect(DATABASE) as conn: @@ -355,6 +404,7 @@ def get_all_tags(): unique_tags.update(json.loads(tags[0])) return jsonify(list(unique_tags)) +# Flask route to update tags for a screenshot @app.route('/update_tags', methods=['POST']) def update_tags(): data = request.json @@ -373,6 +423,7 @@ def update_tags(): except Exception as e: return jsonify({"success": False, "message": str(e)}), 500 +# Flask route to get information about a specific screenshot @app.route('/get_screenshot_info/') def get_screenshot_info(filename): with sqlite3.connect(DATABASE) as conn: @@ -392,6 +443,7 @@ def get_screenshot_info(filename): else: return jsonify({"error": "Screenshot not found"}), 404 +# Function to format timestamp def format_timestamp(timestamp): try: dt = datetime.strptime(timestamp, "%Y%m%d_%H%M%S") diff --git a/static/screenshots/screenshots_will_be_added_here b/static/screenshots/screenshots_will_be_added_here deleted file mode 100644 index e69de29..0000000