#!/usr/bin/python3

import sys
import os
from multiprocessing import Process, Queue, Event
from threading import Thread
import time
import requests
import subprocess
import signal
from pathlib import Path

# Base URL of the host-side service used to exchange test statistics and the
# monitor log. When the variable is unset (None) every HTTP exchange in this
# script is skipped.
HOST_URL = os.environ.get('HOST_URL')

# Load the device-mapper modules before any test runs. Without '-a' modprobe
# interprets only the first positional argument as a module name and passes
# the remaining ones to it as module parameters, so 'dm_snapshot' would never
# be loaded.
subprocess.run(['modprobe', '-a', 'dm_thin_pool', 'dm_snapshot'], check = True)

class TestStats(object):
    """Accumulated timing statistics for tests, keyed by test name.

    ``data`` maps a test name to a dict with the keys ``count`` (recorded
    runs), ``elapsed`` (sum of the recorded durations), ``retries`` (sum of
    the recorded retry values) and ``avg`` (``elapsed / count``, or infinity
    while nothing has been recorded). ``selected`` lists the names chosen for
    the current run.

    Used as a context manager, the statistics are fetched from
    ``HOST_URL`` on entry and sent back on exit; both steps are skipped when
    ``HOST_URL`` is None.
    """

    def __init__(self):
        self.data = {}
        self.selected = []

    def __enter__(self):
        # Start from an empty state, then load whatever the host has stored.
        self.data, self.selected = {}, []
        self.get()
        return self

    def __exit__(self, exc_type, exc_value, exc_bt):
        # Statistics are sent back whether or not the body raised.
        self.put()

    def __iter__(self):
        """Iterate over the selected names, highest average duration first.

        Names without any recorded run have an infinite average and therefore
        come first; equal averages are ordered by name. ``selected`` is
        sorted in place.
        """
        def priority(test):
            return (-self.data[test]['avg'], test)

        self.selected.sort(key=priority)
        return iter(self.selected)

    def get(self):
        """Replace ``data`` with the statistics stored on the host, if any."""
        if HOST_URL is None:
            return
        response = requests.get(f'{HOST_URL}/run/tests.json')
        if response.status_code == 200:
            self.data = response.json()

    def put(self):
        """Send every entry that has at least one recorded run to the host."""
        if HOST_URL is None:
            return
        recorded = {
            name: item
            for name, item in self.data.items()
            if item['count'] > 0
        }
        requests.put(f'{HOST_URL}/run/tests.json', json=recorded)

    def get_data(self, name):
        """Return the entry for ``name``, creating an empty one if missing."""
        return self.data.setdefault(
            name,
            {'count': 0, 'elapsed': 0, 'retries': 0, 'avg': float('inf')},
        )

    def select(self, name):
        """Mark ``name`` as part of the current run."""
        self.get_data(name)
        self.selected.append(name)

    def add(self, name, elapsed, retries):
        """Record one run of ``name`` and refresh its average duration."""
        entry = self.get_data(name)
        entry['count'] += 1
        entry['elapsed'] += elapsed
        entry['retries'] += retries
        entry['avg'] = entry['elapsed'] / entry['count']

class GlusterTesting(object):
    """Run tests in parallel, one worker process per running podman container.

    Used as a context manager. On entry one worker process is started for
    every container reported by ``podman ps``, plus two threads: ``monitor``
    (samples system load into ``/tmp/monitor`` once per second) and
    ``collect`` (prints each result and records it in ``stats``). Test names
    are queued with ``process()``. On exit every worker receives a ``None``
    sentinel, the workers are joined, and then both threads are stopped.

    Args:
        stats: object providing a ``selected`` list and an
            ``add(name, elapsed, retries)`` method (see ``collect``).
    """

    def __init__(self, stats):
        self.stats = stats
        self.procs = []
        self.requests = Queue()   # test names for the workers; None = stop
        self.results = Queue()    # (container, name, elapsed, retries) tuples
        self.event = Event()      # set to stop the monitor thread

    def __enter__(self):
        """Start the workers and the helper threads.

        Raises:
            RuntimeError: if podman reports no running container. Without any
                worker the queued tests would never be executed.
            subprocess.CalledProcessError: if ``podman ps`` fails.
        """
        proc = subprocess.run(['podman', 'ps', '--format', '{{.Names}}'], stdout=subprocess.PIPE, check = True, universal_newlines = True, encoding = 'utf-8')
        containers = [line.strip() for line in proc.stdout.splitlines() if line.strip()]
        if not containers:
            raise RuntimeError('no running podman container found to run the tests in')
        # Worker processes are started before the helper threads are created.
        for container in containers:
            worker = Process(target = self.worker, args = (container,))
            worker.start()
            self.procs.append(worker)
        self.event.clear()
        self.monitor_thread = Thread(target = self.monitor)
        self.monitor_thread.start()
        self.collect_thread = Thread(target = self.collect)
        self.collect_thread.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_bt):
        """Stop the workers, then the collector, then the monitor.

        SIGINT is ignored in this process while shutting down and the
        previous handler is always restored afterwards, even if one of the
        shutdown steps raises.
        """
        old = signal.signal(signal.SIGINT, signal.SIG_IGN)
        try:
            # One sentinel per worker: each worker consumes exactly one and exits.
            for _ in self.procs:
                self.process(None)
            for proc in self.procs:
                proc.join()
            # All workers are done, so no further result can arrive.
            self.results.put(None)
            self.collect_thread.join()
            self.event.set()
            self.monitor_thread.join()
        finally:
            signal.signal(signal.SIGINT, old)

    def scan_loadavg(self):
        """Return ``[load1, load5, load15, runnable, total]`` from /proc/loadavg."""
        with open('/proc/loadavg', 'r') as f:
            data = f.readline().split()[:4]
        # The fourth field has the form "<runnable>/<total>".
        runnable, total = data[3].split('/')
        return [float(data[0]), float(data[1]), float(data[2]), int(runnable), int(total)]

    def scan_meminfo(self):
        """Return ``[MemFree, MemAvailable]`` from /proc/meminfo, each divided by 1024.

        A field that is not present is reported as ``-1 / 1024``.
        """
        free = -1
        avail = -1
        missing = 2
        with open('/proc/meminfo', 'r') as f:
            # Iterate lazily: reading stops as soon as both fields are found.
            for line in f:
                parts = line.split(':', 1)
                if len(parts) != 2:
                    continue
                name = parts[0].strip()
                if name == 'MemFree':
                    free = int(parts[1].split()[0])
                elif name == 'MemAvailable':
                    avail = int(parts[1].split()[0])
                else:
                    continue
                missing -= 1
                if missing == 0:
                    break
        return [free / 1024, avail / 1024]

    def scan_stat(self):
        """Return the first eight counters of the aggregate ``cpu`` line of /proc/stat.

        Raises:
            RuntimeError: if the file contains no ``cpu`` line.
        """
        with open('/proc/stat', 'r') as f:
            for line in f:
                data = line.split()
                if data and data[0] == 'cpu':
                    return [int(x) for x in data[1:9]]
        raise RuntimeError("no aggregate 'cpu' line found in /proc/stat")

    def monitor(self):
        """Append one sample per second to /tmp/monitor until ``event`` is set.

        Each line holds, separated by spaces: the timestamp, the two values
        of ``scan_meminfo()``, the five values of ``scan_loadavg()`` and, for
        each of the eight counters of ``scan_stat()``, its increase since the
        previous sample expressed as a percentage of one CPU.
        """
        with open('/tmp/monitor', 'w') as state:
            # Clock ticks per second divided by 100, so that
            # ticks / (seconds * hz) is directly a percentage.
            hz = os.sysconf(os.sysconf_names['SC_CLK_TCK']) / 100
            previous = self.scan_stat()
            last = time.time()
            while not self.event.wait(1):
                now = time.time()
                delay = (now - last) * hz

                data = [f'{now:.6f}']
                data.extend([f'{x:8.2f}' for x in self.scan_meminfo()])
                load = self.scan_loadavg()
                data.extend([f'{x:6.2f}' for x in load[:3]])
                data.extend([f'{x:6d}' for x in load[3:]])
                cpu = self.scan_stat()
                data.extend([f'{(cpu[i] - previous[i]) / delay:6.2f}' for i in range(8)])

                state.write(' '.join(data) + '\n')

                last = now
                previous = cpu

    def collect(self):
        """Print and record every result until the ``None`` sentinel arrives.

        A negative ``retries`` value marks a failed test; its absolute value
        is what gets printed and recorded. A test that passed with a value
        of 1 is shown in green, any other pass in yellow.
        """
        start = time.time()
        total = len(self.stats.selected)
        count = 0
        for container, name, elapsed, retries in iter(self.results.get, None):
            if retries < 0:
                res = '31;1mFAILED'
                retries = -retries
            elif retries == 1:
                res = '32;1mPASSED'
            else:
                res = '33;1mPASSED'

            count += 1

            runtime = (time.time() - start) / 60
            msg = f"{count:4d}/{total} {runtime:5.1f} [{container}] {elapsed:6.1f} {retries} \x1b[{res}\x1b[0m {name}\n"
            sys.stdout.write(msg)
            # Keep the progress visible when stdout is a pipe or a log file.
            sys.stdout.flush()

            self.stats.add(name, elapsed, retries)

    def process(self, name):
        """Queue ``name`` for execution (``None`` tells one worker to stop)."""
        self.requests.put(name)

    def worker(self, container):
        """Worker process body: run queued tests in ``container`` until ``None``.

        For every test a ``(container, name, elapsed_seconds, retries)``
        tuple is pushed to ``results``. A KeyboardInterrupt ends the worker
        silently.
        """
        try:
            for name in iter(self.requests.get, None):
                started = time.time()
                retries = self.launch(container, name)
                elapsed = time.time() - started

                self.results.put((container, name, elapsed, retries))
        except KeyboardInterrupt:
            pass

    def launch(self, container, name):
        """Run test ``name`` inside ``container``, trying at most twice.

        Returns:
            The attempt number (1 or 2) on which the command exited with
            status 0, or -2 if both attempts failed.
        """
        cmd = ['podman', 'exec']
        if HOST_URL is not None:
            cmd.extend(['-e', f'HOST_URL={HOST_URL}'])
        cmd.extend([container, '/root/glusterfs/tools/tests/prove_run', name])
        for attempt in (1, 2):
            proc = subprocess.run(cmd, stdout = subprocess.DEVNULL, stderr = subprocess.DEVNULL, check = False, universal_newlines = True, encoding = 'utf-8')
            if proc.returncode == 0:
                return attempt
        return -2

# Entry point: list the tests with run-tests.sh, run them in the containers
# ordered by TestStats, then upload the monitor log when HOST_URL is set.
if __name__ == "__main__":
    with TestStats() as tests:
        # run-tests.sh is located three directory levels above this script.
        # The path is resolved first: for a relative argv[0] such as
        # 'script.py', '.parent' stops at '.' and the wrong directory would
        # be used.
        script = Path(sys.argv[0]).resolve()
        cmd = [str(script.parent.parent.parent / 'run-tests.sh'), '-l']
        # Extra command-line arguments are passed through to run-tests.sh.
        cmd.extend(sys.argv[1:])
        proc = subprocess.run(cmd, stdout = subprocess.PIPE, check = True, universal_newlines = True, encoding = 'utf-8')
        count = 0
        for line in proc.stdout.splitlines():
            name = line.strip()
            # A blank line is not a test name.
            if not name:
                continue
            tests.select(name)
            count += 1

        print(f"{count} test(s) found")

        with GlusterTesting(tests) as test:
            # Iterating TestStats yields the selected names in priority order.
            for name in tests:
                test.process(name)

    if HOST_URL is not None:
        # Streamed uploads should use a file opened in binary mode.
        with open('/tmp/monitor', 'rb') as f:
            requests.put(f'{HOST_URL}/run/monitor', data = f)

