#!/usr/bin/env python3
# Copyright (c) 2017, VMware, Inc. or its affiliates.

import os
import sys
from functools import reduce


class ValidationException(Exception):
    """Error raised when a cgroup configuration check fails.

    The description is passed to ``Exception`` and additionally kept on the
    ``message`` attribute so callers can read it directly.
    """

    def __init__(self, message):
        # Record the text first, then let the base class populate ``args``.
        self.message = message
        super().__init__(message)


class Dummy(object):
    """Stand-in validator used on platforms where the cgroup checks do not apply."""

    @staticmethod
    def validate_all():
        """Abort, reporting that resource groups are unsupported.

        Raises:
            SystemExit: always, carrying the explanatory message as its code
                (the interpreter prints it to stderr and exits with status 1).
        """
        # Use sys.exit instead of the bare exit() builtin: exit() is injected
        # by the `site` module for interactive sessions and is not available
        # when the interpreter runs with -S or is embedded.
        sys.exit("resource group is not supported on this platform")


class CgroupValidation(object):
    """Base class holding helpers shared by the cgroup validators."""

    @staticmethod
    def detect_cgroup_mount_point():
        """
        Locate the cgroup mount point by scanning /proc/self/mounts.

        The first matching entry wins:
          * a "cgroup" filesystem returns the parent directory of its
            mount target;
          * a "cgroup2" filesystem returns the mount target itself, unless
            the last path component of the target is "unified".

        Return an empty string if the mounts file does not exist or no
        entry matches.
        """
        mounts_file = "/proc/self/mounts"
        if not os.path.exists(mounts_file):
            return ""

        with open(mounts_file) as mounts:
            for entry in mounts:
                # Whitespace separated fields; index 1 is the mount target
                # and index 2 is the filesystem type.
                fields = entry.split()
                target, fs_type = fields[1], fields[2]

                if fs_type == "cgroup":
                    return os.path.dirname(target)

                if fs_type == "cgroup2":
                    last_component = target.rsplit("/", 1)[-1]
                    if last_component != "unified":
                        return target

        return ""


class CgroupValidationVersionOne(CgroupValidation):
    """
    Validate the cgroup v1 setup needed by resource groups.

    On construction the cgroup mount point and the per-component
    directories are detected; validate_all() then checks the permissions of
    the toplevel gpdb cgroup dirs and the cpu/cpuset mount hierarchy.
    """

    # Each line of this file has the form
    # "hierarchy-id:controller-list:cgroup-path".
    _PROC_CGROUP_PATH = "/proc/1/cgroup"

    # (component, path below the component dir, required access mode).
    # Directory paths must end with '/'. The order is the order in which
    # the checks run, so it decides which failure is reported first.
    #
    # The checks should keep in sync with
    # src/backend/utils/resgroup/cgroup-ops-v1.c
    _PERMISSION_CHECKS = (
        ("cpu", "gpdb/", "rwx"),
        ("cpu", "gpdb/cgroup.procs", "rw"),
        ("cpu", "gpdb/cpu.cfs_period_us", "rw"),
        ("cpu", "gpdb/cpu.cfs_quota_us", "rw"),
        ("cpu", "gpdb/cpu.shares", "rw"),

        ("cpuacct", "gpdb/", "rwx"),
        ("cpuacct", "gpdb/cgroup.procs", "rw"),
        ("cpuacct", "gpdb/cpuacct.usage", "r"),
        ("cpuacct", "gpdb/cpuacct.stat", "r"),

        ("cpuset", "gpdb/", "rwx"),
        ("cpuset", "gpdb/cgroup.procs", "rw"),
        ("cpuset", "gpdb/cpuset.cpus", "rw"),
        ("cpuset", "gpdb/cpuset.mems", "rw"),

        ("memory", "gpdb/", "rwx"),
        ("memory", "gpdb/cgroup.procs", "rw"),
        ("memory", "gpdb/memory.usage_in_bytes", "r"),
    )

    def __init__(self):
        self.mount_point = self.detect_cgroup_mount_point()
        # Maps a mode letter to the matching os.access() flag.
        self.tab = {"r": os.R_OK, "w": os.W_OK, "x": os.X_OK, "f": os.F_OK}
        self.impl = "cgroup"
        self.error_prefix = " is not properly configured: "

        # Prefer the dirs reported by /proc/1/cgroup; if any of them is
        # unusable fall back to the root of each component hierarchy.
        self.component_dirs = self.detect_comp_dirs()
        if not self.validate_comp_dirs():
            self.component_dirs = self.fallback_comp_dirs()

    def validate_all(self):
        """
        Check the permissions of the toplevel gpdb cgroup dirs.

        The checks should keep in sync with
        src/backend/utils/resgroup/cgroup-ops-v1.c

        Raise ValidationException on the first failed check.
        """

        if not self.mount_point:
            self.die("failed to detect cgroup mount point.")

        if not self.component_dirs:
            self.die("failed to detect cgroup component dirs.")

        for comp, path, mode in self._PERMISSION_CHECKS:
            self.validate_permission(comp, path, mode)

        self.validate_comp_hierarchy()

    def die(self, msg):
        """Raise ValidationException with the standard message prefix."""
        raise ValidationException("cgroup is not properly configured: {}".format(msg))

    def validate_permission(self, comp, path, mode):
        """
        Validate permission on path.
        If path is a dir it must end with '/'.

        comp is the cgroup component name, path is relative to that
        component's dir and mode is a string made of the letters in
        self.tab ('r', 'w', 'x', 'f').

        Raise ValidationException if the component is unknown, the path
        does not exist or the required access is not granted.
        """
        if comp not in self.component_dirs:
            self.die("can't find dir of cgroup component {}".format(comp))

        component_dir = self.component_dirs[comp]
        fullpath = os.path.join(self.mount_point, comp, component_dir, path)
        # endswith() also copes with an empty path, unlike path[-1].
        pathtype = "directory" if path.endswith("/") else "file"
        # OR together the os.access() flags of every requested mode letter.
        mode_bits = reduce(lambda acc, bit: acc | bit,
                           (self.tab[flag] for flag in mode), 0)

        try:
            if not os.path.exists(fullpath):
                self.die("{} '{}' does not exist".format(pathtype, fullpath))

            if not os.access(fullpath, mode_bits):
                self.die("{} '{}' permission denied: require permission '{}'".
                         format(pathtype, fullpath, mode))
        except IOError as e:
            self.die("can't check permission on {} '{}': {}".format(pathtype, fullpath, str(e)))

    def validate_comp_dirs(self):
        """
        Validate existence of cgroup component dirs.

        Return True if all the components' dir exist and have good permission,
        otherwise return False.
        """

        for comp in self.required_comps():
            if comp not in self.component_dirs:
                return False

            component_dir = self.component_dirs[comp]
            fullpath = os.path.join(self.mount_point, comp, component_dir, 'gpdb')

            if not os.access(fullpath, os.R_OK | os.W_OK | os.X_OK):
                return False

        return True

    def validate_comp_hierarchy(self):
        """
        Validate the mount hierarchy of cpu and cpuset subsystem.

        Raise an error if cpu and cpuset are mounted on the same hierarchy,
        or if /proc/1/cgroup does not exist.
        """

        path = self._PROC_CGROUP_PATH
        if not os.path.exists(path):
            # Adjacent literals instead of a backslash continuation inside
            # the string, which would embed the source indentation in the
            # message.
            self.die("can't check component mount hierarchy: "
                     "file '{}' doesn't exist".format(path))

        for complist, _ in self._read_proc_cgroup(path):
            if "cpu" in complist and "cpuset" in complist:
                self.die("can't mount 'cpu' and 'cpuset' on the same hierarchy")

    @staticmethod
    def _read_proc_cgroup(path):
        """
        Parse a /proc/<pid>/cgroup style file.

        Return a list of (component names, cgroup path) tuples, one per
        usable line. Lines are skipped when they are blank, do not have
        three ':' separated fields, or have a component field that is empty
        or contains '=' (e.g. a named hierarchy such as "name=systemd").
        """
        entries = []
        # Read everything up front so the file is closed before any caller
        # gets a chance to raise.
        with open(path) as f:
            for line in f:
                # maxsplit keeps any ':' inside the cgroup path intact.
                fields = line.strip().split(":", 2)
                if len(fields) != 3:
                    continue
                compnames, comppath = fields[1], fields[2]
                if not compnames or '=' in compnames:
                    continue
                entries.append((compnames.split(','), comppath))
        return entries

    @staticmethod
    def detect_comp_dirs():
        """
        Detect the cgroup dir of every component listed in /proc/1/cgroup.

        Return a dict mapping component name to its cgroup path with
        leading and trailing path separators removed; the dict is empty if
        /proc/1/cgroup does not exist.
        """
        component_dirs = {}
        proc_path = CgroupValidationVersionOne._PROC_CGROUP_PATH

        if not os.path.exists(proc_path):
            return component_dirs

        for complist, comppath in CgroupValidationVersionOne._read_proc_cgroup(proc_path):
            for compname in complist:
                component_dirs[compname] = comppath.strip(os.path.sep)

        return component_dirs

    @staticmethod
    def required_comps():
        """Return the names of the cgroup components that must be present."""
        comps = ["cpu", "cpuacct", "cpuset", "memory"]
        return comps

    def fallback_comp_dirs(self):
        """
        Return the fallback component dirs: every required component
        mapped to "", i.e. the root of its own hierarchy.
        """
        return {comp: "" for comp in self.required_comps()}


if __name__ == '__main__':
    # The cgroup checks only run on Linux; every other platform reports
    # that resource groups are unsupported.
    if sys.platform.startswith('linux'):
        try:
            CgroupValidationVersionOne().validate_all()
        except ValidationException as e:
            # sys.exit with a string prints it to stderr and exits with
            # status 1. The bare exit() builtin is provided by the `site`
            # module for interactive use and may be absent (python -S).
            sys.exit(e.message)
    else:
        Dummy().validate_all()
