#!/usr/bin/python3

from urllib.request import urlretrieve
from pathlib import Path
from copy import deepcopy
from xml.etree import ElementTree
import json
import os
import hashlib
import shutil
import libvirt
import sys
import click
from socketserver import TCPServer
from http.server import SimpleHTTPRequestHandler
import functools
from threading import Thread
import gnupg
import subprocess
import string
import time

# Register a libvirt error handler that does nothing.
# NOTE(review): presumably this stops libvirt from reporting errors on its
# own; the code below catches libvirt.libvirtError and reports it via click.
libvirt.registerErrorHandler(lambda x, y: None, None)

class Config(object):
    """JSON configuration stored in a ``config`` file inside a directory.

    A configuration is addressed by name relative to the class attribute
    ``base``. Nested options are addressed with ``/``-separated paths (for
    example ``vm/port``). Reading a nested dictionary returns a *partial*
    ``Config`` (non-empty ``prefix``) which cannot be saved or deleted.
    """

    # Directory under which named configurations are looked up. For this
    # class it is the directory containing the running script.
    base = Path(sys.argv[0]).parent

    def __init__(self, base, prefix, data):
        """Wrap already-loaded data.

        base   -- directory holding the ``config`` file
        prefix -- option path of this view inside the root ("" for the root)
        data   -- dictionary with the options
        """
        self.base = base
        self.prefix = prefix
        self.data = data

    @classmethod
    def open(cls, name):
        """Load the existing configuration called 'name'.

        Raises click.ClickException if its ``config`` file does not exist.
        """
        base = (cls.base / name).expanduser()
        file = base / 'config'
        if not file.exists():
            raise click.ClickException(f"'{file}' not found")
        with open(file, 'r') as f:
            data = json.load(f)
        return cls(base, "", data)

    @classmethod
    def create(cls, name):
        """Create the directory for a new, empty configuration called 'name'.

        Raises click.ClickException if the directory already exists. Nothing
        is written until save() is called.
        """
        base = (cls.base / name).expanduser()
        if base.exists():
            raise click.ClickException(f"'{base}' already exists")
        base.mkdir(parents = True, exist_ok = False)
        return cls(base, "", {})

    def copy(self, cfg, ctx = None):
        """Populate this configuration from the files of 'cfg'.

        Every entry of cfg's directory is copied here, the copied ``config``
        is loaded, each file listed under its ``templates`` key is rendered
        in place with 'ctx', the ``templates`` key is dropped and the result
        is saved.

        ctx -- mapping used for template substitution (defaults to empty)
        """
        # A fresh dictionary per call instead of a shared mutable default.
        if ctx is None:
            ctx = {}
        for item in cfg.base.iterdir():
            shutil.copy(item, self.base / item.name)
        with open(self.base / 'config', 'r') as f:
            self.data = json.load(f)
        if 'templates' in self.data:
            for item in self.data.get('templates'):
                self.render(self.base / item, ctx)
            del self.data['templates']
        self.save()

    def render(self, path, ctx):
        """Substitute ``$name`` placeholders of the file at 'path' in place.

        Uses string.Template.substitute(), so a placeholder missing from
        'ctx' raises KeyError.
        """
        with open(path, 'r') as f:
            tpl = string.Template(f.read())
        data = tpl.substitute(ctx)
        with open(path, 'w') as f:
            f.write(data)

    def save(self):
        """Write the options to the ``config`` file.

        Raises click.ClickException when called on a partial configuration.
        """
        if len(self.prefix) > 0:
            raise click.ClickException("Cannot save partial configuration")
        with open(self.base / 'config', 'w') as f:
            json.dump(self.data, f)

    def delete(self):
        """Remove the configuration directory and everything inside it.

        Deleting a configuration whose directory is already gone is a no-op,
        so cleanup paths may call this more than once. Raises
        click.ClickException when called on a partial configuration.
        """
        if len(self.prefix) > 0:
            raise click.ClickException("Cannot delete partial configuration")
        # rmtree also copes with sub-directories, which unlink()/rmdir()
        # over the top-level entries could not remove.
        if self.base.exists():
            shutil.rmtree(self.base)

    def get(self, name, default = None):
        """Return the option at the ``/``-separated path 'name'.

        If a key is missing, 'default' is returned; a default of None means
        the option is required and click.ClickException is raised instead.
        Dictionaries are returned wrapped in a partial Config.
        """
        data = self.data
        for key in name.split('/'):
            if not isinstance(data, dict):
                raise click.ClickException(f"Invalid option '{name}'")
            if key not in data:
                if default is None:
                    raise click.ClickException(f"Option '{name}' not found")
                data = default
                break
            data = data[key]
        if isinstance(data, dict):
            if len(self.prefix) > 0:
                name = f'{self.prefix}/{name}'
            data = Config(self.base, name, data)
        return data

    def set(self, name, value):
        """Store 'value' at the ``/``-separated path 'name'.

        Missing intermediate dictionaries are created. Raises
        click.ClickException if any component of the path, including the one
        that must hold the last key, is not a dictionary.
        """
        keys = name.split('/')
        last = keys.pop()
        data = self.data
        for key in keys:
            if not isinstance(data, dict):
                raise click.ClickException(f"Invalid option '{name}'")
            data = data.setdefault(key, {})
        # The container of the final key needs the same validation,
        # otherwise the assignment fails with a raw TypeError.
        if not isinstance(data, dict):
            raise click.ClickException(f"Invalid option '{name}'")
        data[last] = value

class UserConfig(Config):
    """Config whose named directories are looked up under ~/.glusterfs/gftest."""
    # '~' is expanded by Config.open()/Config.create() via expanduser().
    base = Path('~/.glusterfs/gftest')

class Image(object):
    """Base VM image, downloaded once into ``~/.glusterfs/cache/<name>``.

    'cfg' is the image section of a configuration. The keys read here are
    ``url`` (required), ``checksums`` (optional URL of a checksum list),
    ``key`` (optional URL of the GPG key that signed the checksum list) and
    ``fingerprint`` (read only when ``key`` is present).
    """

    def __init__(self, cfg):
        """Fetch (or reuse from the cache) the image described by 'cfg'."""
        self.cfg = cfg
        self.path = self.get_image(cfg.get('url'))

    @staticmethod
    def _cache_dir(name):
        """Return the cache directory for image 'name', creating it if needed."""
        base = Path('~/.glusterfs/cache').expanduser() / name
        base.mkdir(parents = True, exist_ok = True)
        return base

    @staticmethod
    def _download(url, path):
        """Download 'url' to 'path' without ever exposing a partial file.

        The data is written to a sibling ``.part`` file that is renamed over
        'path' only after the transfer finished, so an interrupted download
        is not mistaken for a cached file on the next run.
        """
        tmp = path.with_name(path.name + '.part')
        try:
            urlretrieve(url, tmp)
            tmp.replace(path)
        finally:
            # Still present only if the download or the rename failed.
            if tmp.exists():
                tmp.unlink()

    def get_key(self, gpg, url, fingerprint):
        """Make sure the key with 'fingerprint' is in the GPG keyring.

        If it is not there yet, the key is downloaded from 'url' and imported
        after checking that the file contains exactly that one key.
        'fingerprint' must be lower case. Raises click.ClickException on any
        mismatch or if the import fails.
        """
        if any(key['fingerprint'].lower() == fingerprint for key in gpg.list_keys()):
            return
        path, _ = urlretrieve(url)
        try:
            info = gpg.scan_keys(path)
            if len(info) != 1:
                raise click.ClickException("Unexpected number of keys")
            if info[0]['fingerprint'].lower() != fingerprint:
                raise click.ClickException("Unexpected key fingerprint")
            with open(path, 'rb') as f:
                data = f.read()
            result = gpg.import_keys(data)
        finally:
            os.unlink(path)
        if result.imported != 1:
            raise click.ClickException("Key import failed")

    def verify_signature(self, path, url, fingerprint):
        """Verify the GPG signature of the file at 'path'.

        'url' and 'fingerprint' identify the signing key (see get_key()).
        If verification fails the file is removed and click.ClickException
        is raised.
        """
        click.echo("Validating checksum's signature...")
        gpg = gnupg.GPG()
        self.get_key(gpg, url, fingerprint)
        with open(path, 'rb') as f:
            ver = gpg.verify_file(f)
        # NOTE(review): only 'valid' is checked; the signer is not compared
        # against 'fingerprint' - confirm whether that is intended.
        if not ver.valid:
            os.unlink(path)
            raise click.ClickException("Invalid file signature")

    def get_checksum(self, url, name):
        """Return the lower-case checksum listed for image 'name'.

        The checksum list is downloaded from 'url' unless already cached, and
        its signature is verified when the configuration has a ``key``. Lines
        are expected to hold exactly two fields: checksum and file name.
        If 'name' is not listed the cached list is removed and
        click.ClickException is raised.
        """
        path = self._cache_dir(name) / 'checksums'
        if not path.exists():
            self._download(url, path)
        key = self.cfg.get('key', False)
        if key is not False:
            fp = self.cfg.get('fingerprint').replace(' ', '').lower()
            self.verify_signature(path, key, fp)
        with open(path, 'r') as f:
            for line in f:
                data = line.split()
                if (len(data) == 2) and (data[1] == name):
                    return data[0].lower()
        os.unlink(path)
        raise click.ClickException("Image checksum not found")

    def compute_checksum(self, path):
        """Return the lower-case SHA-256 hex digest of the file at 'path'."""
        click.echo("Verifying VM image's checksum...")
        sha256 = hashlib.sha256()
        with open(path, 'rb') as f:
            # Read in 128 KiB chunks to avoid loading the image in memory.
            for chunk in iter(lambda: f.read(131072), b''):
                sha256.update(chunk)
        return sha256.hexdigest().lower()

    def get_image(self, url):
        """Return the local path of the image at 'url', downloading if needed.

        The last component of 'url' names the cache directory. When the
        configuration has ``checksums`` the image is verified against it; on
        mismatch the cached image is removed and click.ClickException raised.
        """
        name = url.split('/')[-1]
        path = self._cache_dir(name) / 'image.qcow2'
        if not path.exists():
            click.echo("Downloading VM image...")
            self._download(url, path)
        cs = self.cfg.get('checksums', False)
        if cs is not False:
            expected = self.get_checksum(cs, name)
            if self.compute_checksum(path) != expected:
                os.unlink(path)
                raise click.ClickException("Image checksum doesn't match")
        return path

    def create(self, path, size):
        """Create ``image.qcow2`` in directory 'path', backed by this image.

        size -- virtual disk size in GiB
        Returns the path of the new disk. Raises
        subprocess.CalledProcessError if qemu-img fails.
        """
        click.echo("Creating VM disk...")
        path = path / 'image.qcow2'
        subprocess.run([
            'qemu-img',
            'create',
            '-q',
            '-f', 'qcow2',
            '-F', 'qcow2',
            '-b', str(self.path),
            str(path),
            f'{size}G'
        ], check = True)
        return path

class HTTPRequestHandler(SimpleHTTPRequestHandler):
    """Serve and accept files below a fixed set of named root directories.

    'handlers' maps the first component of the request path to a local
    directory: ``/<key>/<rest>`` refers to ``<handlers[key]>/<rest>``.
    GET/HEAD are handled by SimpleHTTPRequestHandler through
    translate_path(); PUT stores the request body.
    """

    def __init__(self, handlers, *args, **kwargs):
        # Must be set before the base constructor runs: the base class
        # handles the request from inside its __init__.
        self.handlers = handlers
        super().__init__(*args, **kwargs)

    def log_message(self, *args):
        """Discard all request logging."""
        pass

    def resolve_path(self, path):
        """Map a request path to a local path, or None if it is not served.

        The query string is ignored. Paths whose first component is not a
        known handler, or that contain a ``..`` component (which would
        escape the handler's directory), are rejected.
        """
        items = path.split('?', 1)[0].split('/')
        if (len(items) <= 1) or (items[1] not in self.handlers):
            return None
        rest = items[2:]
        if '..' in rest:
            return None
        path = '/'.join(rest)
        return f'{self.handlers[items[1]]}/{path}'

    def translate_path(self, path):
        """Local path for GET/HEAD; unknown paths map to a missing file."""
        path = self.resolve_path(path)
        if path is None:
            path = '/file-not-found'
        return path

    def do_PUT(self):
        """Store the request body at the resolved path.

        Replies 201 "Uploaded" on success. Replies 400 "Failed" when the
        path is not served, Content-Length is missing or invalid, or the
        client stopped sending before the announced length was received.
        """
        path = self.resolve_path(self.path)
        try:
            length = int(self.headers.get('Content-Length'))
        except (TypeError, ValueError):
            length = -1
        if (path is None) or (length < 0):
            self.send_response(400)
            msg = "Failed"
        else:
            Path(path).parent.mkdir(parents = True, exist_ok = True)
            with open(path, 'wb') as f:
                while length > 0:
                    data = self.rfile.read(min(length, 131072))
                    if not data:
                        # Peer closed early: without this the loop would
                        # spin forever on empty reads.
                        break
                    f.write(data)
                    length -= len(data)
            if length > 0:
                self.send_response(400)
                msg = "Failed"
            else:
                self.send_response(201)
                msg = "Uploaded"
        self.end_headers()
        self.wfile.write(f"{msg}\n".encode('utf-8'))

class HTTPServer(object):
    """Context manager running an HTTP server on a background thread.

    The server listens on 127.0.0.1 on a port chosen by the system, which is
    available as ``port`` inside the ``with`` block. 'handlers' is passed to
    every HTTPRequestHandler.
    """

    def __init__(self, handlers):
        self.handlers = handlers

    def __call__(self, *args, **kwargs):
        # The instance itself is used as TCPServer's handler factory so
        # that 'handlers' can be injected into each request handler.
        return HTTPRequestHandler(self.handlers, *args, **kwargs)

    def __enter__(self):
        self.httpd = TCPServer(("127.0.0.1", 0), self)
        self.port = self.httpd.server_address[1]
        self.thread = Thread(target = self.server)
        self.thread.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_bt):
        try:
            self.httpd.shutdown()
            self.thread.join()
        finally:
            # shutdown() only stops serve_forever(); the listening socket
            # has to be released explicitly.
            self.httpd.server_close()

    def server(self):
        """Thread body: serve requests until shutdown() is called."""
        self.httpd.serve_forever()

class VM(object):
    """A libvirt domain together with its user configuration.

    Commands are executed in the guest through ssh to root@127.0.0.1 using
    the port and key stored in the configuration (``vm/port``, ``vm/key``).
    """

    # Options passed to every ssh invocation with '-o'.
    options = [
        'UserKnownHostsFile=/dev/null',
        'StrictHostKeyChecking=no',
        'LogLevel=ERROR'
    ]

    def __init__(self, cfg, name, conn, dom):
        """Wrap an existing domain.

        cfg  -- configuration providing at least ``vm/network``
        name -- domain name, used in messages
        conn -- libvirt connection
        dom  -- libvirt domain
        """
        self.cfg = cfg
        self.name = name
        self.conn = conn
        self.dom = dom
        self.network = cfg.get('vm/network')
        # Network address without its last octet, e.g. '192.168.254'.
        self.prefix = '.'.join(self.network.split('.')[:-1])
        # Address '<prefix>.2'; the commands use it to build the URLs handed
        # to the guest (presumably the host as seen from the guest).
        self.host = f'{self.prefix}.2'

    @classmethod
    def install(cls, cfg, path, name, conn, timeout):
        """Define the domain from the XML file at 'path' and run its install.

        The domain is started and must reach the shut-off state within
        'timeout' seconds. On any failure the domain is killed and removed
        (including its configuration) and the error is re-raised.
        """
        click.echo("Installing VM...")

        with open(path, 'r') as f:
            spec = f.read()

        try:
            dom = conn.defineXML(spec)
        except libvirt.libvirtError as exc:
            raise click.ClickException(f"Unable to create '{name}': {exc.get_error_message()}")

        vm = cls(cfg, name, conn, dom)

        try:
            timeout = vm.start(timeout)
            vm.wait(libvirt.VIR_DOMAIN_SHUTOFF, timeout)
        except BaseException:
            # Clean up on interruption (Ctrl-C) as well, then propagate.
            vm.kill(ignore = True)
            vm.destroy(ignore = True)
            raise

        return vm

    @classmethod
    def create(cls, cfg, name, cpus, memory, disk, network, port, key, timeout):
        """Create and install a new VM called 'name' from template 'cfg'.

        cfg     -- the ``vm`` section of a template configuration
        cpus, memory, disk -- sizes; None takes the template value or the
                   defaults 2, 8 and 20
        network -- network address; only its first three octets are used
        port    -- stored as ``vm/port`` and passed to templates as host_port
        key     -- path of the private ssh key; '<key>.pub' must exist
        timeout -- seconds allowed for the installation

        Raises click.ClickException if the public key is missing or the
        configuration already exists. On failure the user configuration is
        removed again.
        """
        pubkey = Path(key + '.pub').expanduser()
        if not pubkey.exists():
            raise click.ClickException(f"Public key '{pubkey}' not found")
        with open(pubkey, 'r') as f:
            pubkey = f.read()

        conn = libvirt.open()
        arch = cfg.get('arch', 'x86_64')
        machine = cfg.get('machine', 'q35')
        xml = conn.getDomainCapabilities(arch = arch, machine = machine, virttype = 'kvm')
        data = ElementTree.fromstring(xml)
        emulator = data.findtext('./path')

        net = '.'.join(network.split('.')[:-1])

        if cpus is None:
            cpus = cfg.get('cpus', 2)
        if memory is None:
            memory = cfg.get('memory', 8)
        if disk is None:
            disk = cfg.get('disk', 20)

        user = UserConfig.create(name)

        try:
            img = Image(cfg.get('image')).create(user.base, disk)

            # The HTTP server exposes the configuration directory while the
            # installation runs (see 'cloud_init_datasource' below).
            with HTTPServer({ 'config': user.base}) as httpd:
                ctx = {
                    'vm_name': name,
                    'vm_arch': arch,
                    'vm_machine': machine,
                    'vm_emulator': emulator,
                    'vm_memory': memory,
                    'vm_cpus': cpus,
                    'vm_disk': img,
                    'vm_network': f'{net}.0/24',
                    'vm_key': pubkey,
                    'host_port': port,
                    'httpd_address': f'{net}.2:{httpd.port}',
                    'cloud_init_datasource': f'ds=nocloud-net;s=http://{net}.2:{httpd.port}/config/'
                }
                user.copy(cfg, ctx)
                user.set('vm/network', f'{net}.0')
                user.set('vm/port', port)
                user.set('vm/key', str(key))
                user.set('vm/cpus', str(cpus))
                user.set('vm/memory', str(memory))
                user.set('vm/disk', str(disk))
                user.save()

                return cls.install(user, user.base / cfg.get('domain'), name, conn, timeout)
        except BaseException:
            # install() may already have removed the directory through
            # destroy(); deleting it again would raise and hide the error.
            if user.base.exists():
                user.delete()
            raise

    @classmethod
    def open(cls, name):
        """Return the existing VM called 'name'.

        Raises click.ClickException if its configuration or its libvirt
        domain cannot be found.
        """
        user = UserConfig.open(name)
        conn = libvirt.open()
        try:
            dom = conn.lookupByName(name)
        except libvirt.libvirtError as exc:
            raise click.ClickException(f"Unable to access '{name}': {exc.get_error_message()}")
        except Exception:
            raise click.ClickException(f"Unable to access '{name}'")
        return cls(user, name, conn, dom)

    def ping(self):
        """Return True if 'uname -a' can be run in the guest over ssh.

        The command output is printed on success. Only ordinary errors count
        as "not reachable"; KeyboardInterrupt propagates so that loops
        polling ping() can be interrupted.
        """
        try:
            print(self.run('uname -a', capture = True, error = False))
            return True
        except Exception:
            return False

    def start(self, timeout):
        """Start the domain and wait until it has an address in its network.

        Returns the remaining part of 'timeout' (seconds); returns it
        untouched if the domain is already running. Raises
        click.ClickException on libvirt errors or timeout.
        """
        try:
            if self.dom.state()[0] == libvirt.VIR_DOMAIN_RUNNING:
                return timeout
            self.dom.create()
        except libvirt.libvirtError as exc:
            raise click.ClickException(f"Unable to start '{self.name}': {exc.get_error_message()}")
        except Exception:
            raise click.ClickException(f"Unable to start '{self.name}'")

        timeout = self.wait(libvirt.VIR_DOMAIN_RUNNING, timeout)

        while True:
            try:
                iflist = self.dom.interfaceAddresses(libvirt.VIR_DOMAIN_INTERFACE_ADDRESSES_SRC_AGENT)
            except Exception:
                # The query is retried until it succeeds or time runs out.
                iflist = []
            for iface in iflist:
                if 'addrs' not in iflist[iface]:
                    continue

                addrs = iflist[iface]['addrs']
                if addrs is None:
                    continue

                for addr in addrs:
                    if 'addr' not in addr:
                        continue

                    # Compare the first three octets with our network.
                    addr = '.'.join(addr['addr'].split('.')[:-1])
                    if addr == self.prefix:
                        return timeout

            if timeout <= 0:
                raise click.ClickException(f"Timed out waiting for '{self.name}'")
            time.sleep(0.2)
            timeout -= 0.2

    def stop(self, timeout):
        """Request a guest shutdown and wait for the shut-off state.

        Returns the remaining part of 'timeout'. Raises click.ClickException
        on libvirt errors or timeout.
        """
        try:
            if self.dom.state()[0] == libvirt.VIR_DOMAIN_SHUTOFF:
                return timeout
            self.dom.shutdown()
        except libvirt.libvirtError as exc:
            raise click.ClickException(f"Unable to stop '{self.name}': {exc.get_error_message()}")
        except Exception:
            raise click.ClickException(f"Unable to stop '{self.name}'")

        return self.wait(libvirt.VIR_DOMAIN_SHUTOFF, timeout)

    def kill(self, ignore = False):
        """Forcibly stop the domain (libvirt 'destroy').

        With ignore = True errors are silently discarded, otherwise they are
        raised as click.ClickException.
        """
        try:
            self.dom.destroy()
        except libvirt.libvirtError as exc:
            if not ignore:
                raise click.ClickException(f"Unable to kill '{self.name}': {exc.get_error_message()}")
        except Exception:
            if not ignore:
                raise click.ClickException(f"Unable to kill '{self.name}'")

    def destroy(self, ignore = False):
        """Undefine the domain and delete its configuration.

        With ignore = True errors are silently discarded, otherwise they are
        raised as click.ClickException.
        """
        try:
            self.dom.undefine()
            self.cfg.delete()
        except libvirt.libvirtError as exc:
            if not ignore:
                raise click.ClickException(f"Unable to destroy '{self.name}': {exc.get_error_message()}")
        except Exception:
            if not ignore:
                raise click.ClickException(f"Unable to destroy '{self.name}'")

    def wait(self, state, timeout):
        """Poll every 0.2 s until the domain is in 'state'.

        Returns the remaining part of 'timeout'; raises click.ClickException
        once it is exhausted.
        """
        while self.dom.state()[0] != state:
            if timeout <= 0:
                raise click.ClickException(f"Timed out waiting for '{self.name}'")
            time.sleep(0.2)
            timeout -= 0.2
        return timeout

    def run(self, cmd, capture = False, error = True, ignore = False, env = None):
        """Run the shell command 'cmd' in the guest through ssh.

        capture -- return the command's stdout instead of showing it
        error   -- when False, stderr is discarded
        ignore  -- when False, a non-zero exit status raises
                   subprocess.CalledProcessError
        env     -- optional mapping of variables set with 'env' for 'cmd'

        Returns the captured stdout, or None when capture is False.
        """
        stdout = subprocess.PIPE if capture else None
        stderr = None if error else subprocess.DEVNULL

        cmdline = [
            'ssh',
            '-qt',
            '-p', str(self.cfg.get('vm/port')),
            '-i', self.cfg.get('vm/key')
        ]
        for opt in VM.options:
            cmdline.extend(['-o', opt])
        cmdline.append('root@127.0.0.1')
        if env is not None:
            cmdline.append('env')
            cmdline.extend([f'{n}={env[n]}' for n in env])
        cmdline.append(cmd)

        proc = subprocess.run(cmdline, stdout = stdout, stderr = stderr, check = not ignore, text = True, encoding = 'utf-8')
        return proc.stdout

# Root command group: every sub-command below is registered on it with
# @gftest.command(). The reported version is '0.1'.
@click.group()
@click.version_option('0.1')
def gftest():
    pass

# Create a new VM 'name' from a template, restart it after the installation
# and build the 'glusterfs/builder' container image inside it.
#
# The installation is bounded by --timeout. After the restart, waiting for
# ssh to answer is bounded by --timeout as well, instead of polling forever.
# On any failure (including Ctrl-C) the VM is killed and removed again.
@gftest.command()
@click.option('-k', '--key', default = '~/.ssh/id_rsa')
@click.option('-n', '--net', default = '192.168.254.0')
@click.option('-p', '--port', type = int, default = 2222)
@click.option('-t', '--template', default = 'centos7')
@click.option('-c', '--cpus', type = int)
@click.option('-m', '--memory', type = int)
@click.option('-d', '--disk', type = int)
@click.option('-o', '--timeout', type = int, default = 120)
@click.argument('name', required = True)
def create(key, net, port, template, cpus, memory, disk, timeout, name):
    cfg = Config.open(template)
    vm = VM.create(cfg.get('vm'), name, cpus, memory, disk, net, port, key, timeout)
    click.echo("Restarting VM...")
    try:
        vm.start(45)
        remaining = timeout
        while not vm.ping():
            if remaining <= 0:
                raise click.ClickException(f"Timed out waiting for '{name}'")
            time.sleep(1)
            remaining -= 1
        container = vm.cfg.get('containers/images/builder')
        # Serve the VM's configuration directory so that podman can fetch
        # the container file over HTTP.
        with HTTPServer({'config': vm.cfg.base}) as httpd:
            click.echo("Creating builder container...")
            vm.run(f'podman build --squash -t glusterfs/builder -f http://{vm.host}:{httpd.port}/config/{container} /root/')
    except BaseException:
        vm.kill(ignore = True)
        vm.destroy(ignore = True)
        raise

# Remove the VM 'name': force it off, then undefine the domain and delete
# its configuration. Errors from both steps are ignored (ignore = True).
@gftest.command()
@click.argument('name', required = True)
def destroy(name):
    vm = VM.open(name)
    vm.kill(ignore = True)
    vm.destroy(ignore = True)

# Start the VM 'name' and wait until it answers over ssh. --timeout (in
# seconds) covers both the start and the ssh wait; when it is exhausted a
# click.ClickException is raised.
@gftest.command()
@click.option('-t', '--timeout', type = int, default = 45)
@click.argument('name', required = True)
def poweron(timeout, name):
    vm = VM.open(name)
    timeout = vm.start(timeout)
    while not vm.ping():
        if timeout <= 0:
            # This is a plain function: there is no 'self', use 'name'.
            raise click.ClickException(f"Timed out waiting for '{name}'")
        time.sleep(1)
        timeout -= 1

# Ask the VM 'name' to shut down and wait up to --timeout seconds for it to
# reach the shut-off state (see VM.stop()).
@gftest.command()
@click.option('-t', '--timeout', type = int, default = 15)
@click.argument('name', required = True)
def shutdown(timeout, name):
    VM.open(name).stop(timeout)

# Forcibly stop the VM 'name' (see VM.kill()); errors are reported.
@gftest.command()
@click.argument('name', required = True)
def poweroff(name):
    VM.open(name).kill()

# Clone the current git repository into the VM 'name' at --commit and build
# the 'glusterfs/testing' container image from it. The repository's .git
# directory and the VM's configuration are served over HTTP for the guest.
# If anything fails, /root/glusterfs is removed from the guest again.
@gftest.command()
@click.option('-c', '--commit', required = True)
@click.argument('name', required = True)
def build(commit, name):
    # Refresh the info files of the repository before it is served over HTTP.
    subprocess.run(['git', 'update-server-info'], capture_output = True, check = True)
    toplevel = subprocess.run(
        ['git', 'rev-parse', '--show-toplevel'],
        capture_output = True,
        check = True,
        text = True,
        encoding = 'utf-8'
    )
    git_dir = Path(toplevel.stdout.strip()) / '.git'

    vm = VM.open(name)
    vm.run('rm -rf /root/glusterfs')
    try:
        container = vm.cfg.get('containers/images/testing')
        roots = {'config': vm.cfg.base, 'glusterfs': git_dir}
        with HTTPServer(roots) as httpd:
            vm.run(f'git clone http://{vm.host}:{httpd.port}/glusterfs /root/glusterfs')
            vm.run(f'git -C /root/glusterfs checkout {commit}')
            vm.run(f'podman build --squash -t glusterfs/testing -f http://{vm.host}:{httpd.port}/config/{container} /root/glusterfs')
    except BaseException:
        vm.run('rm -rf /root/glusterfs')
        raise

# Start --workers test containers in the VM 'name', each with its own
# --space GiB thin volume mounted at /d. The volumes live in a volume group
# created on a loop device over a zram device of space * workers + 1 GiB.
# Missing options default to the configuration values 'containers/space'
# (10) and 'containers/workers' (8).
@gftest.command()
@click.option('-s', '--space', type = int)
@click.option('-w', '--workers', type = int)
@click.argument('name', required = True)
def spawn(space, workers, name):
    vm = VM.open(name)
    if space is None:
        space = vm.cfg.get('containers/space', 10)
    if workers is None:
        workers = vm.cfg.get('containers/workers', 8)
    vm.run('modprobe zram')
    # Captured output ends with a line terminator (CR/LF because ssh runs
    # with a tty); strip it before using the device name in other commands.
    zram = vm.run(f'zramctl -f -s {space * workers + 1}G', capture = True).strip()
    loop = vm.run(f'losetup -f --show {zram}', capture = True).strip()
    vm.run(f'pvcreate {loop}')
    vm.run(f'vgcreate vg_{name} {loop}')
    vm.run(f'lvcreate -l 100%FREE -Zn -T vg_{name}/tp')
    for i in range(workers):
        vm.run(f'lvcreate -V {space}G -T vg_{name}/tp -n lv_{name}_{i:02d}')
        vm.run(f'mkfs.xfs -q -i size=512 /dev/vg_{name}/lv_{name}_{i:02d}')
        vm.run(f'mkdir -p /mnt/mnt_{i:02d}')
        vm.run(f'mount -o discard /dev/vg_{name}/lv_{name}_{i:02d} /mnt/mnt_{i:02d}')
        vm.run(f'podman run -d --rm --privileged --name {name}{i:02d} --hostname {name}{i:02d} --tmpfs /tmp:exec -v /dev:/dev -v /mnt/mnt_{i:02d}:/d:Z glusterfs/testing /sbin/init')

# Tear down everything created by 'spawn' in the VM 'name': containers,
# mounts and mount points, device-mapper devices, loop devices and finally
# the zram devices. Only a failure of the last step is tolerated.
@gftest.command()
@click.argument('name', required = True)
def kill(name):
    vm = VM.open(name)
    teardown = (
        'podman rm -f --all',
        'umount /mnt/mnt_*',
        'rmdir /mnt/mnt_*',
        'dmsetup remove_all',
        'losetup -D',
    )
    for command in teardown:
        vm.run(command)
    vm.run('zramctl -r /dev/zram*', ignore = True)

# Run the regression script in the VM 'name'. While it runs, an HTTP server
# exposes the VM's configuration and 'directory' (created if missing); its
# URL is handed to the script through the HOST_URL environment variable.
@gftest.command()
@click.argument('name', required = True)
@click.argument('directory', type = click.Path(), required = True)
def run(name, directory):
    vm = VM.open(name)
    Path(directory).mkdir(parents = True, exist_ok = True)
    roots = {'config': vm.cfg.base, 'run': directory}
    with HTTPServer(roots) as httpd:
        vm.run(
            '/root/glusterfs/tools/tests/regression',
            env = {'HOST_URL': f'http://{vm.host}:{httpd.port}'}
        )

# Open an interactive shell in the VM 'name' or, when 'container' is given,
# in that numbered container. HOST_URL points at an HTTP server exposing the
# VM's configuration for the duration of the session. The exit status of
# the shell is ignored.
@gftest.command()
@click.argument('name', required = True)
@click.argument('container', type = int, required = False)
def sh(name, container):
    vm = VM.open(name)
    with HTTPServer({'config': vm.cfg.base}) as httpd:
        if container is not None:
            vm.run(f'podman exec -ti -e HOST_URL=http://{vm.host}:{httpd.port} {name}{container:02d} /bin/bash', ignore = True)
            return
        shell_env = { 'HOST_URL': f'http://{vm.host}:{httpd.port}' }
        vm.run('bash -ils', env = shell_env, ignore = True)

# Run the click command group when the file is executed as a script.
if __name__ == "__main__":
    gftest()

