#!/usr/bin/python

import os
import sys
import re
import subprocess
import threading
import parted
import _ped
import shutil
import tempfile
import datetime
import signal
import time

class Logging:
    """Minimal logger: append every message to a file, echo some to stdout.

    Message levels: 1 error, 2 warning, 3 info, 4 debug.  A message is
    printed when this logger's loglevel is >= the message level; it is
    written to the log file (if any) regardless of level.
    """
    def __init__(self, loglevel=1, logfile="/var/log/prepare_pstorage_drive.log"):
        # None means "no log file"; an empty logfile name disables file logging.
        # Set before open() so __del__ is safe even if open() raises.
        self.logfile_fd = None
        if logfile != "":
            self.logfile_fd = open(logfile, "a+")
        self.loglevel = loglevel

    def __del__(self):
        if self.logfile_fd is not None:
            self.logfile_fd.close()

    def log_it(self, message, loglevel):
        """Write message (timestamped) to the log file; print it if loglevel allows."""
        if self.logfile_fd is not None:
            self.logfile_fd.write("%s\t%s\n" % (datetime.datetime.today(), message))
            # Flush per line: several Logging instances may append to the same
            # file through separate buffers, and unflushed buffers would let
            # their lines land out of order (or be lost on abrupt exit).
            self.logfile_fd.flush()
        if self.loglevel >= loglevel:
            print(message)

    def debug(self, message):
        self.log_it(message, 4)

    def info(self, message):
        self.log_it(message, 3)

    def warning(self, message):
        self.log_it(message, 2)

    def error(self, message):
        self.log_it(message, 1)

def isEfi():
    """Return True when /sys/firmware/efi exists, i.e. the kernel reports EFI firmware."""
    return os.path.exists("/sys/firmware/efi")

class thread(threading.Thread):
    """Copy a pipe to a log method and a descriptor, line by line.

    inputd is a raw descriptor (the read end of a pipe); it is wrapped in a
    file object here and closed when copying finishes.  Each line read is
    passed to logmethod without its trailing newline and written unchanged
    to the raw descriptor outputd.
    """
    def __init__(self, inputd, outputd, logmethod, command):
        threading.Thread.__init__(self)
        self.inputd = os.fdopen(inputd, "rb")
        self.outputd = outputd
        self.logmethod = logmethod
        self.running = True
        self.command = command

    def _forward(self, data):
        """Write all of data to outputd; os.write() may accept only part of it."""
        while data:
            written = os.write(self.outputd, data)
            data = data[written:]

    def run(self):
        try:
            while self.running:
                try:
                    data = self.inputd.readline()
                except IOError:
                    self.logmethod("Failed to read from pipe during a call to %s." \
                                    % self.command)
                    break
                if not data:
                    # EOF: every write end of the pipe has been closed.
                    self.running = False
                    continue

                self.logmethod(data.rstrip(b'\n'))
                if self.outputd is None:
                    continue
                try:
                    self._forward(data)
                except OSError as e:
                    # Stop forwarding but keep draining the pipe: if this
                    # thread died, the child could block on a full pipe and
                    # whoever waits for the child would hang.
                    self.logmethod("Failed to forward output of %s: %s"
                                   % (self.command, e))
                    self.outputd = None
        finally:
            self.inputd.close()

    def stop(self):
        self.running = False
        return self

def execRedirect(command, argv, stdin = None, stdout = None, stderr = None):
    """Run command with argv, logging its output and copying it to targets.

    stdin:  a readable path (opened read-only; an unreadable path falls back
            to this process's stdin), a raw descriptor or a file object;
            anything else means this process's stdin.
    stdout, stderr:
            a path (created if missing, appended to), a raw descriptor or a
            file object; anything else means this process's stdout/stderr.
            When both are the same path, one descriptor is shared.

    Every line the child writes is logged (stdout via program_log.info,
    stderr via program_log.error) and copied to the matching target.  The
    child runs in "/" with LC_ALL=C.

    Returns the child's exit status.  Raises RuntimeError if the child cannot
    be started.  Descriptors opened here from paths are always closed;
    descriptors and file objects supplied by the caller are left open.
    """
    argv = list(argv)
    opened = []  # descriptors this call opened from paths; closed on exit

    def output_fd(target, default):
        # The copy threads write with os.write(), so every target is reduced
        # to a raw descriptor (file objects are unwrapped with fileno()).
        if isinstance(target, str):
            # O_APPEND: never overwrite an existing file in place from offset 0.
            fd = os.open(target, os.O_RDWR | os.O_CREAT | os.O_APPEND, 0o644)
            opened.append(fd)
            return fd
        if isinstance(target, int):
            return target
        if target is not None and hasattr(target, "fileno"):
            return target.fileno()
        return default.fileno()

    try:
        if isinstance(stdin, str):
            if os.access(stdin, os.R_OK):
                stdin = os.open(stdin, os.O_RDONLY)
                opened.append(stdin)
            else:
                stdin = sys.stdin.fileno()
        elif not (isinstance(stdin, int) or hasattr(stdin, "fileno")):
            # Descriptors and file objects are handed to Popen as they are.
            stdin = sys.stdin.fileno()

        stdout_fd = output_fd(stdout, sys.stdout)
        if isinstance(stderr, str) and isinstance(stdout, str) and stderr == stdout:
            stderr_fd = stdout_fd  # same path: share one descriptor
        else:
            stderr_fd = output_fd(stderr, sys.stderr)

        program_log.info("Running command... %s" % ([command] + argv,))

        pstdout, pstdin = os.pipe()
        perrout, perrin = os.pipe()

        env = os.environ.copy()
        # Set C locale
        env.update({"LC_ALL": "C"})

        # The read ends of the pipes are handed to the copy threads.
        proc_std = thread(pstdout, stdout_fd, program_log.info, command)
        proc_err = thread(perrout, stderr_fd, program_log.error, command)
        proc_std.start()
        proc_err.start()

        try:
            proc = subprocess.Popen([command] + argv, stdin=stdin,
                                    stdout=pstdin,
                                    stderr=perrin,
                                    cwd="/",
                                    env=env)
            ret = proc.wait()
        except OSError as e:
            errstr = "Error running command %s: %s" % (command, e.strerror)
            log.error(errstr)
            program_log.error(errstr)
            raise RuntimeError(errstr)
        finally:
            # Runs on every exit path: closing our copies of the write ends
            # lets the copy threads reach EOF once the child is gone (or at
            # once if it never started), so the joins cannot hang.
            os.close(pstdin)
            os.close(perrin)
            proc_std.join()
            proc_err.join()

        return ret
    finally:
        for fd in opened:
            os.close(fd)

def execCapture(command, argv, stdin = None, stderr = None, fatal = False):
    """Run command with argv and return what it wrote to stdout.

    stdin:  a readable path (opened read-only; an unreadable path falls back
            to this process's stdin), a raw descriptor or a file object;
            anything else means this process's stdin.
    stderr: a path (created if missing, appended to), a raw descriptor or a
            file object; anything else means this process's stderr.  The
            child's stderr output is copied there.
    fatal:  if true, a non-zero exit status raises RuntimeError.

    Output lines are logged (stdout via program_log.info, stderr via
    program_log.error).  The child runs in "/" with LC_ALL=C.

    Returns the captured stdout ("" if none).  Raises RuntimeError if the
    child cannot be started, or (with fatal) if it exits non-zero.
    """
    argv = list(argv)
    opened = []  # descriptors this call opened from paths; closed on exit

    try:
        if isinstance(stdin, str):
            if os.access(stdin, os.R_OK):
                stdin = os.open(stdin, os.O_RDONLY)
                opened.append(stdin)
            else:
                stdin = sys.stdin.fileno()
        elif not (isinstance(stdin, int) or hasattr(stdin, "fileno")):
            # Descriptors and file objects are handed to Popen as they are.
            stdin = sys.stdin.fileno()

        if isinstance(stderr, str):
            # O_APPEND: never overwrite an existing file in place from offset 0.
            stderr = os.open(stderr, os.O_RDWR | os.O_CREAT | os.O_APPEND, 0o644)
            opened.append(stderr)
        elif isinstance(stderr, int):
            pass
        elif stderr is not None and hasattr(stderr, "fileno"):
            # os.write() below needs a raw descriptor, not a file object.
            stderr = stderr.fileno()
        else:
            stderr = sys.stderr.fileno()

        program_log.info("Running command... %s" % ([command] + argv,))

        env = os.environ.copy()
        # Set C locale
        env.update({"LC_ALL": "C"})

        try:
            proc = subprocess.Popen([command] + argv, stdin=stdin,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    cwd="/",
                                    env=env)
            # communicate() reads both pipes to EOF and waits for the child,
            # so a single call collects everything.
            outStr, errStr = proc.communicate()
            for line in outStr.splitlines():
                program_log.info(line)
            if errStr:
                for line in errStr.splitlines():
                    program_log.error(line)
                os.write(stderr, errStr)
            if proc.returncode and fatal:
                raise OSError(proc.returncode, errStr)
        except OSError as e:
            errstr = "Error running command %s: %s" % (command, e.strerror)
            log.error(errstr)
            raise RuntimeError(errstr)

        return outStr or ""
    finally:
        for fd in opened:
            os.close(fd)

def create_input(data):
    """Return the read end of a pipe pre-filled with data plus a newline.

    Meant to be used as a child's stdin: the write end is already closed,
    so the reader sees EOF right after the data.  The caller owns the
    returned descriptor and must close it.

    The whole payload is written before any reader exists, so it must fit
    in the pipe's capacity (see pipe(7)); a larger payload would block here.
    """
    payload = data if isinstance(data, bytes) else data.encode()
    payload += b'\n'
    read_fd, write_fd = os.pipe()
    try:
        # os.write() may accept only part of the buffer; loop until done.
        while payload:
            written = os.write(write_fd, payload)
            payload = payload[written:]
    except Exception:
        # Do not leak the read end when the caller never receives it.
        os.close(read_fd)
        raise
    finally:
        os.close(write_fd)
    return read_fd


class PrepareHDD:
    """Prepare a whole disk as a pstorage data drive.

    The disk is zeroed at both ends and given a fresh GPT label plus an ext4
    data partition.  Unless boot is False, a FAT partition is created in
    front of the data partition.  It receives GRUB/EFI loader files whose
    config redirects to this host's main boot loader.  That loader is located
    by the label and UUID of the filesystem mounted on /boot (or on / when
    /boot is not separate).
    """

    def __init__(self, device, out="/dev/null", err="/dev/null", loglevel=3, ssd=False, boot=True):
        """Record settings for `device` and create a private temp directory.

        out/err: where output of external commands is sent.
        loglevel: console verbosity of self.log.
        ssd: if True, the data filesystem gets no pstorage-hotplug label.
        boot: if False, no boot partition or boot loader is created.
        """
        self.device = device
        self.product_name = "Parallels Cloud Server"
        # First sector of the first partition.
        self.startpart = 64
        self.boot = boot
        if self.boot:
            # Sector where the data partition starts; the boot partition spans
            # startpart..bootsize-1.  512MB in sectors, assuming 512-byte
            # sectors (floor division keeps it an integer).
            self.bootsize = 512*1024*1024 // 512
        else:
            # No boot partition: the data partition starts at startpart.
            self.bootsize = self.startpart
        # Label put on the host's main boot filesystem with e2label.
        self.bootlabel = "PCSBoot"
        self.bootuuid = "Undefined"
        # Label put on the host's EFI partition with dosfslabel; FAT labels
        # are at most 11 characters long.
        self.bootlabel_efi = "PCSBootEFI"
        self.bootuuid_efi = "UndefinedUUID"
        self.bootlabel_installed = False
        self.pstorage_cs_label = "pstorage-hotplug"
        self.stdout = out
        self.stderr = err
        self.tempdir = tempfile.mkdtemp(prefix="PrepareHDD_", dir="/tmp")
        self.mpath = self.tempdir + "/boot_cs"
        self.device_map = self.tempdir + "/device.map"
        self.mounted = False
        self.log = Logging(loglevel, logfile="")
        self.ssd = ssd

    def check(self):
        """Return True if self.device names an existing whole hd/sd/vd disk."""
        # Anchored at the end so partitions such as /dev/sda1 are refused:
        # partition paths are derived later by appending "1"/"2".
        if not re.search("^/dev/[hsv]d[a-z]+$", self.device):
            self.log.error("Given device %s is not valid" % self.device)
            return False
        if not os.path.exists(self.device):
            self.log.error("Given device %s does not exist" % self.device)
            return False

        return True

    def __del__(self):
        # Best-effort cleanup.  Each path is removed independently, so a path
        # that was never created (mpath when boot is False, device_map before
        # prepare_boot writes it) does not prevent removing tempdir.
        try:
            if not self.umount():
                # Never delete a directory that still has a filesystem on it.
                return
            paths = [(os.rmdir, self.mpath),
                     (os.unlink, self.device_map),
                     (os.rmdir, self.tempdir)]
        except Exception:
            return
        for remove, path in paths:
            try:
                remove(path)
            except Exception:
                pass

    def umount(self):
        """Unmount mpath if this object mounted it; True if nothing is left mounted."""
        if not self.mounted:
            return True

        rc = execRedirect("/bin/umount", [ self.mpath ],
            stdout=self.stdout, stderr=self.stderr)
        if rc:
            self.log.error("Failed to umount %s GRUB disk: %i..." % (self.mpath, rc))
            return False

        self.mounted = False
        return True

    def zero_device(self):
        """Overwrite the first and last MiB of the device with zeroes.

        Returns True on success.  Failures are logged, except ENOSPC, which
        still yields False.
        """
        self.log.info("Zeroing out beginning and end of %s..." % self.device)
        size = 1024 * 1024
        fd = None

        try:
            try:
                fd = os.open(self.device, os.O_RDWR)
                buf = b'\0' * size
                os.write(fd, buf)
                os.lseek(fd, -size, os.SEEK_END)
                os.write(fd, buf)
            finally:
                if fd is not None:
                    os.close(fd)
        except Exception as e:
            if getattr(e, "errno", None) != 28: # No space left in device
                self.log.error("error zeroing out %s: %s" % (self.device, e))
            return False

        return True

    def wait_for_device(self, device, timeout=60):
        """Poll once per second until `device` exists; False after `timeout` seconds."""
        self.log.info("Waiting for kernel...")

        for _ in range(timeout):
            if os.path.exists(device):
                return True
            time.sleep(1)

        if os.path.exists(device):
            return True

        self.log.error("Failed to wait for %s disk: timed out after %i seconds..."
                       % (device, timeout))
        return False

    def prepare_pstorage_data_disk(self):
        """Wipe, partition and format self.device.  Return True on success."""
        if not self.zero_device():
            return False

        first = "%s1" % self.device
        try:
            # Remove the old node so the wait below only succeeds once the
            # node for the new table shows up (presumably recreated by udev).
            if os.path.exists(first):
                os.unlink(first)
            self.log.info("Partitioning %s..." % self.device)
            device = parted.Device(self.device)
            # create GPT label
            disk = parted.freshDisk(device, "gpt")
            constraint = parted.Constraint(device=device)
            if self.boot:
                # create 1 boot partition
                geometry = parted.Geometry(device=device,
                        start=self.startpart, end=(self.bootsize - 1))
                # Make it fat
                partition_ped = _ped.Partition(disk.getPedDisk(),
                    parted.PARTITION_NORMAL,
                    geometry.start, geometry.end,
                    _ped.file_system_type_get("fat32"))
                partition_ped.set_name("EFI System Partition")
                partition = parted.Partition(PedPartition=partition_ped)
                disk.addPartition(partition=partition, constraint=constraint)

            # create main partition
            geometry = parted.Geometry(device=device,
                    start=self.bootsize, end=(constraint.maxSize - 1))
            # Make it as ext4
            filesystem = parted.FileSystem(type="ext4", geometry=geometry)
            partition = parted.Partition(disk=disk, fs=filesystem,
                    type=parted.PARTITION_NORMAL, geometry=geometry)
            disk.addPartition(partition=partition, constraint=constraint)

            # Apply
            disk.commit()
        except Exception as e:
            self.log.error("Failed to repartition disk %s:\n%s" % (self.device, e))
            return False

        if not self.wait_for_device(first):
            return False

        if not self.boot:
            return self.prepare_data(first)

        if not self._set_boot_flags(first):
            return False
        # Prepare boot
        if not self.prepare_boot(first):
            return False
        return self.prepare_data("%s2" % self.device)

    def _set_boot_flags(self, bootpart):
        """Mark bootpart bootable via fdisk (MBR) and parted (GPT); True on success."""
        try:
            os.unlink(bootpart)
            # Set active flag for GPT due to buggy BIOSes
            fdisk_input = create_input("a\n1\nw")
            try:
                execRedirect("/sbin/fdisk", [ self.device ],
                    stdin=fdisk_input,
                    stdout=self.stdout,
                    stderr=self.stderr)
            finally:
                os.close(fdisk_input)
            # fdisk's exit status is deliberately ignored: kernel lag can make
            # it report a failure even though the partition table is re-read.

            if not self.wait_for_device(bootpart):
                return False

            # Set boot flag on EFI GPT partition
            # Due to parted bug it can't be properly set on partition creation...
            device = parted.Device(self.device)
            disk = parted.Disk(device)
            partition = disk.getPartitionByPath(bootpart)
            if partition is None:
                self.log.error("Failed to find partition %s on %s" % (bootpart, self.device))
                return False
            partition.setFlag(parted.PARTITION_BOOT)
            disk.commitToDevice()
        except Exception as e:
            self.log.error("Failed to set boot flag on %s:\n%s" % (bootpart, e))
            return False

        return True

    def get_device_for_mpoint(self, mpoint):
        """Return the /dev/<letters><digits> device mounted at mpoint, or ""."""
        pattern = re.compile("^/dev/[a-z]+[0-9]+ %s " % re.escape(mpoint))
        with open("/proc/mounts", 'r') as f:
            for line in f:
                if pattern.search(line):
                    return line.split()[0]

        return ""

    def set_bootlabel(self, device):
        """Give device's filesystem the label self.bootlabel (via e2label) if needed."""
        if self.bootlabel == execCapture("/sbin/e2label", [ device ],
            stderr=self.stderr).rstrip("\n"):
            return True

        self.log.info("Creating %s label..." % device)
        rc = execRedirect("/sbin/e2label", [ device, self.bootlabel ],
            stdout=self.stdout, stderr=self.stderr)
        if rc:
            self.log.error("Failed to set boot label for %s: %i..." % (device, rc))
            return False

        return True

    def set_bootlabel_efi(self, device):
        """Give device's filesystem the label self.bootlabel_efi (via dosfslabel) if needed."""
        if self.bootlabel_efi == execCapture("/sbin/dosfslabel", [ device ],
            stderr=self.stderr).rstrip("\n"):
            return True

        self.log.info("Creating %s label..." % device)
        rc = execRedirect("/sbin/dosfslabel", [ device, self.bootlabel_efi ],
            stdout=self.stdout, stderr=self.stderr)
        if rc:
            self.log.error("Failed to set EFI boot label for %s: %i..." % (device, rc))
            return False

        return True

    def get_uuid(self, bootdev):
        """Return bootdev's lower-cased filesystem UUID per blkid, or "" if none."""
        blkid = execCapture("/sbin/blkid", [ "-o", "export", bootdev ],
            stderr=self.stderr)
        # Only a UUID= line counts; never fall back to some other line.
        uuid = ""
        for line in blkid.splitlines():
            if line.startswith("UUID="):
                uuid = line[len("UUID="):]
                break
        if uuid == "":
            self.log.error("Failed to get UUID for %s disk" % bootdev)
            return ""

        return uuid.lower()

    def boot_label(self):
        """Label the host's boot filesystem(s) and record their UUIDs.

        Uses /boot (or / when /boot is not a separate mount), plus /boot/efi
        on EFI systems.  Returns True on success.
        """
        self.log.info("Detecting labels...")
        bootdev = self.get_device_for_mpoint("/boot")

        if bootdev == "":
            bootdev = self.get_device_for_mpoint("/")
            if bootdev == "":
                self.log.error("Failed to detect boot disk")
                return False

        if not self.set_bootlabel(bootdev):
            return False

        # Get UUID
        self.bootuuid = self.get_uuid(bootdev)
        if self.bootuuid == "":
            return False

        if not isEfi():
            return True

        # EFI part
        bootdev = self.get_device_for_mpoint("/boot/efi")

        if bootdev == "":
            self.log.error("Failed to detect EFI boot disk")
            return False

        if not self.set_bootlabel_efi(bootdev):
            return False

        # Get EFI UUID
        self.bootuuid_efi = self.get_uuid(bootdev)
        if self.bootuuid_efi == "":
            return False

        return True

    def prepare_boot(self, device):
        """Format device as FAT and install the redirecting GRUB/EFI loaders."""
        self.log.info("Formatting %s partition..." % device)
        # Format additional /boot
        rc = execRedirect("/sbin/mkdosfs", [ device ],
            stdout=self.stdout, stderr=self.stderr)
        if rc:
            self.log.error("Failed to format %s disk: %i..." % (device, rc))
            return False

        # Create boot labels on main /boot
        if not self.boot_label():
            return False

        # Mount
        os.mkdir(self.mpath)
        rc = execRedirect("/bin/mount", [ device, self.mpath ],
            stdout=self.stdout, stderr=self.stderr)
        if rc:
            self.log.error("Failed to mount %s disk: %i..." % (device, rc))
            return False

        self.mounted = True

        # Create bootloader data
        try:
            self.log.info("Installing bootloader to %s..." % device)
            grubdir = self.mpath + "/grub"
            os.mkdir(grubdir)
            shutil.copy("/usr/share/grub/x86_64-unknown/stage1", grubdir)
            shutil.copy("/usr/share/grub/x86_64-unknown/stage2", grubdir)
            if os.path.exists("/boot/grub/splash.xpm.gz"):
                shutil.copy("/boot/grub/splash.xpm.gz", grubdir)
            with open(grubdir + "/grub.conf", "a+") as f:
                f.write(
                    "background FFFFFF\n"
                    "foreground 000000\n"
                    "default=0\n"
                    "timeout=5\n"
                    "splashimage=(hd0,0)/grub/splash.xpm.gz\n"
                    "hiddenmenu\n"
                    "title Redirect to %s main boot loader\n"
                    "        uuid_label_load_mbr %s %s\n" %
                        (self.product_name, self.bootuuid, self.bootlabel))
            # GRUB device map: (hd1) is the whole disk holding this partition.
            with open(self.device_map, "a+") as f:
                f.write("(hd1) %s\n" % re.sub("[0-9]+", "", device))
            # Place EFI bootloader, so user will be able to chainload original EFI loader
            efidir = self.mpath + "/EFI/BOOT"
            os.makedirs(efidir)
            # NOTE(review): copied unconditionally, so this step fails when the
            # host has no /boot/efi/EFI/BOOT/BOOTX64.efi — confirm intended.
            shutil.copy("/boot/efi/EFI/BOOT/BOOTX64.efi", efidir)
            with open(efidir + "/BOOTX64.conf", "a+") as f:
                f.write(
                    "background FFFFFF\n"
                    "foreground 000000\n"
                    "default=0\n"
                    "timeout=5\n"
                    "splashimage=(hd0,0)/grub/splash.xpm.gz\n"
                    "hiddenmenu\n"
                    "title Redirect to %s main boot loader\n"
                    "        uuid_label %s %s\n"
                    "        chainloader /EFI/BOOT/BOOTX64.efi\n\n" %
                        (self.product_name, self.bootuuid_efi, self.bootlabel_efi))
            self.umount()
        except Exception as e:
            self.log.error("Failed to prepare GRUB data for drive %s:\n%s" % (device, e))
            return False

        # Install bootloader
        grub_input = create_input("root (hd1,0)\nsetup (hd1)\nquit")
        try:
            rc = execRedirect("/sbin/grub", [ "--batch", "--device-map=%s" % self.device_map ],
                stdin=grub_input,
                stdout=self.stdout, stderr=self.stderr)
        finally:
            os.close(grub_input)
        if rc:
            self.log.error("Failed to install GRUB on %s disk: %i..." % (device, rc))
            return False

        return True

    def prepare_data(self, device):
        """Create the ext4 data filesystem on device; True on success."""
        self.log.info("Formatting %s partition..." % device)
        opts = [ device, "-q" , "-E", "lazy_itable_init=1", "-O", "uninit_bg", "-m", "0" ]
        if not self.ssd:
            opts += [ "-L", self.pstorage_cs_label ]
        # Format it
        rc = execRedirect("/sbin/mkfs.ext4", opts,
            stdout=self.stdout, stderr=self.stderr)
        if rc:
            self.log.error("Failed to format %s disk: %i..." % (device, rc))
            return False

        return True

def signal_handler(signal, frame):
    """Handle SIGINT/SIGTERM: end the current output line and exit.

    The exit status follows the shell convention 128 + signal number, so an
    interrupted run is not reported to callers as a success.
    """
    sys.stdout.write("\n")
    sys.exit(128 + signal)

# Route both interactive interrupts and termination requests through
# signal_handler (SIGINT first, then SIGTERM).
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, signal_handler)
del _signum

def usage():
    """Print command-line help to stdout."""
    # The synopsis lists every option described below, including --noboot.
    print("Usage: %s disk_drive [-y] [--ssd] [--noboot]" % sys.argv[0])
    print("    Where disk drive is drive that you want to prepare for pstorage")
    print("    -y - do not ask any questions")
    print("    --ssd - drive is SSD")
    print("    --noboot - Do not install GRUB bootloader")

if len(sys.argv) == 1:
    usage()
    sys.exit(1)

# Set log levels to exec functions
log = Logging(1)
program_log = Logging(0)

# Options may appear anywhere on the command line; the drive is the first
# argument that is not an option (so "-y /dev/sdb" works too).
options = sys.argv[1:]
drives = [arg for arg in options if not arg.startswith("-")]
if not drives:
    usage()
    sys.exit(1)
drive = drives[0]

ssd = "--ssd" in options
boot = "--noboot" not in options

hdd = PrepareHDD(drive, ssd=ssd, boot=boot)
if not hdd.check():
    sys.exit(1)

# Ask user to continue, unless -y was given
ask = "-y" not in options
if ask:
    print("ALL data on %s will be completely destroyed. Are you sure to continue? [y]" % drive)
    try:
        answer = raw_input("")
    except EOFError:
        # stdin closed / not interactive: treat as "no" instead of crashing.
        answer = ""
    if answer != "y":
        sys.exit(0)

if not hdd.prepare_pstorage_data_disk():
    sys.exit(1)

print("Done!")

sys.exit(0)
