#!/usr/bin/env python3

import argparse
import json
import os
import random
import subprocess
import time
from pathlib import Path

import plotext as plt

def find_files_within_size_range(directory, min_size_kb=0, max_size_kb=float('inf')):
    """Recursively collect JPEG files whose size lies within a KB range.

    A file qualifies when its name ends in ``.jpg`` or ``.jpeg``
    (case-insensitive) and ``min_size_kb <= size_kb <= max_size_kb``, where
    ``size_kb`` is the byte size divided by 1024. Entries whose size cannot be
    read (broken symlinks, permission errors, files removed during the walk)
    are skipped instead of aborting the scan.

    A progress line is printed after every 1000 files examined (all files,
    not only JPEGs).

    Args:
        directory: Root directory to walk.
        min_size_kb: Inclusive lower size bound in KB.
        max_size_kb: Inclusive upper size bound in KB (unbounded by default).

    Returns:
        List of matching file paths, in ``os.walk`` order.
    """
    result = []
    scanned = 0
    for root, _, files in os.walk(directory):
        for name in files:
            if name.lower().endswith(('.jpg', '.jpeg')):
                path = os.path.join(root, name)
                try:
                    size_kb = os.path.getsize(path) / 1024
                except OSError:
                    # Unreadable entry (e.g. dangling symlink): skip it so one
                    # bad file does not kill a long scan.
                    pass
                else:
                    if min_size_kb <= size_kb <= max_size_kb:
                        result.append(path)
            scanned += 1
            if scanned % 1000 == 0:
                print(f'Files: {scanned:5}, hits: {len(result):5}')
    return result

def _drop_page_cache():
    """Best-effort: flush dirty pages, then ask the kernel to drop its caches."""
    # drop_caches only discards clean objects, so `sync` runs first to make
    # dirty pages droppable. Needs sudo; a failure is reported, not raised.
    completed = subprocess.run(
        ["sudo", "sh", "-c", "sync; echo 3 > /proc/sys/vm/drop_caches"],
        check=False,
    )
    if completed.returncode != 0:
        print(f"Warning: failed to drop caches (exit code {completed.returncode})")


def measure_latency(file_list, clear_cache=False):
    """Time a full binary read of each file in ``file_list``.

    Args:
        file_list: Paths to read, in order.
        clear_cache: If true, try to drop the OS page cache before each read
            (Linux ``/proc/sys/vm/drop_caches`` via sudo) so reads hit storage.

    Returns:
        One dict per file with keys ``'file'`` (the path), ``'size_kb'``
        (bytes read / 1024) and ``'latency_ms'`` (elapsed time of open+read).
    """
    latencies = []
    for file in file_list:
        if clear_cache:
            _drop_page_cache()
        print(f"Reading {file}")
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        with open(file, 'rb') as f:
            data = f.read()
        elapsed = time.perf_counter() - start
        latencies.append({
            'file': file,
            'size_kb': len(data) / 1024,
            'latency_ms': elapsed * 1000,
        })
    return latencies

def save_results(results, output_file):
    """Serialize the measurement records to ``output_file`` as indented JSON.

    Any existing file is overwritten; afterwards the number of saved records
    is printed.
    """
    with open(output_file, 'w') as handle:
        json.dump(results, handle, indent=2)
    count = len(results)
    print(f"\nSaved {count} results to {output_file}")

def plot_latency_terminal(file_path, label=None):
    """Add one size-vs-latency scatter series from a JSON result file.

    The series is drawn onto the current plotext figure; the caller is
    responsible for clearing, titling and showing it.

    Args:
        file_path: JSON list of records with ``'size_kb'`` and ``'latency_ms'``.
        label: Series label; when falsy the file name stem is used.
    """
    with open(file_path) as fh:
        records = json.load(fh)
    x_values = [rec['size_kb'] for rec in records]
    y_values = [rec['latency_ms'] for rec in records]

    series_label = label if label else Path(file_path).stem
    plt.scatter(x_values, y_values, label=series_label, marker='small')

def compare_latency_terminal(fs1_path, fs2_path):
    """Overlay the latency scatter plots of two result files in one chart.

    Each series is labelled with the path it was loaded from.
    """
    plt.clear_figure()
    plt.title("Latency Comparison")
    plt.xlabel("File Size (KB)")
    plt.ylabel("Latency (ms)")

    for result_path in (fs1_path, fs2_path):
        plot_latency_terminal(result_path, label=result_path)

    plt.grid(True)
    plt.show()

def print_latency_summary(file_path):
    """Print a fixed-width table of every record in a latency result file.

    Columns: file base name, size in KB (1 decimal), latency in ms
    (2 decimals). Names longer than the file column are shortened with a
    trailing "..." so the numeric columns stay aligned.

    Args:
        file_path: JSON list of records with ``'file'``, ``'size_kb'`` and
            ``'latency_ms'`` keys (the format ``measure_latency`` produces).
    """
    name_width = 40
    with open(file_path) as f:
        data = json.load(f)
    header = f"{'File':{name_width}} {'Size (KB)':>10} {'Latency (ms)':>12}"
    print(header)
    # Rule spans exactly the header width.
    print("-" * len(header))
    for d in data:
        name = os.path.basename(d['file'])
        if len(name) > name_width:
            name = name[:name_width - 3] + "..."
        print(f"{name:{name_width}} {d['size_kb']:10.1f} {d['latency_ms']:12.2f}")

def main(argv=None):
    """Parse command-line arguments and dispatch to the selected subcommand.

    Subcommands:
        measure  -- pick random JPEGs within a size range under a directory,
                    time a full read of each, and save the records as JSON.
        plot     -- scatter-plot one result file in the terminal.
        compare  -- overlay two result files in one terminal plot.
        summary  -- print a per-file table for one result file.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Filesystem read latency test tool")
    subparsers = parser.add_subparsers(dest='command', required=True)

    # Measure command
    measure = subparsers.add_parser("measure", help="Run a latency test")
    measure.add_argument("directory", help="Directory to scan for JPEGs")
    measure.add_argument("-n", "--num-files", type=int, default=10, help="Number of files to test")
    measure.add_argument("--min-size", type=int, default=0, help="Minimum file size in KB")
    measure.add_argument("--max-size", type=int, default=1024*10, help="Maximum file size in KB")
    measure.add_argument("-c", "--clear-cache", action="store_true", help="Clear system cache before each read (requires sudo)")
    measure.add_argument("-o", "--output", default="latency_results.json", help="Output JSON file")

    # Plot command
    plot = subparsers.add_parser("plot", help="Plot results in terminal")
    plot.add_argument("file", help="Result file to plot")

    # Compare command
    compare = subparsers.add_parser("compare", help="Compare two result files")
    compare.add_argument("file1", help="First result file")
    compare.add_argument("file2", help="Second result file")

    # Summary command: makes print_latency_summary reachable from the CLI.
    summary = subparsers.add_parser("summary", help="Print a per-file table of a result file")
    summary.add_argument("file", help="Result file to summarize")

    args = parser.parse_args(argv)

    if args.command == "measure":
        # Reject arguments before the potentially slow directory scan:
        # random.sample() raises on a negative count, and an inverted size
        # window can never match anything.
        if args.num_files < 1:
            parser.error("--num-files must be at least 1")
        if args.min_size > args.max_size:
            parser.error("--min-size must not be greater than --max-size")
        candidates = find_files_within_size_range(args.directory, args.min_size, args.max_size)
        if not candidates:
            print("No suitable files found.")
            return
        selected = random.sample(candidates, min(args.num_files, len(candidates)))
        results = measure_latency(selected, args.clear_cache)
        save_results(results, args.output)

    elif args.command == "plot":
        plt.clear_figure()
        plt.title("Read Latency")
        plt.xlabel("File Size (KB)")
        plt.ylabel("Latency (ms)")
        plot_latency_terminal(args.file)
        plt.grid(True)
        plt.show()

    elif args.command == "compare":
        compare_latency_terminal(args.file1, args.file2)

    elif args.command == "summary":
        print_latency_summary(args.file)

if __name__ == "__main__":
    main()
