Source code for accre.util

"""
General utility functions and classes that don't fit in any
other module.

In addition to the functions documented below, this module also defines
RedStr, GreenStr, YellowStr, LightPurpleStr, PurpleStr, CyanStr,
LightGrayStr, and BlackStr convenience functions which take a string
as an argument and return the string wrapped in ANSI color code characters
so that the string will appear on a compliant terminal in the specified
color.
"""
import argparse
import datetime
import calendar
from collections import namedtuple
from functools import partial
import math
import os
import random
import re
import socket
import hashlib

from accre import __version__, __title__

# Cache of the raw text of /etc/{passwd,shadow,group}, keyed by file name.
# An entry stays None until get_posixuser / get_shadowuser / get_posixgroup
# reads the corresponding file for the first time.
_posix_users = {'passwd': None, 'shadow': None, 'group': None}

# Cache of the EFF long wordlist; None until generate_password loads it
_eff_wordlist = None


# One /etc/passwd record, fields in the file's colon-separated order.
# get_posixuser converts uid and gid to int when they are not blank.
PosixUser = namedtuple('PosixUser',
    'name, password, uid, gid, gecos, homedir, shell'
)


# One /etc/shadow record, fields in the file's colon-separated order.
# get_shadowuser converts lastchange, min, max and warn to int when they
# are not blank; the remaining fields are kept as strings.
ShadowUser = namedtuple('ShadowUser',
    'name, password, lastchange, min, max, warn, inactive, expire, res'
)


# One /etc/group record. get_posixgroup converts gid to int when it is not
# blank and stores members as a tuple of names (empty tuple if none).
PosixGroup = namedtuple('PosixGroup', 'name, password, gid, members')


# Pattern for the part of an email address before the last '@'
# (validate_email_address matches it with re.match). Accepts either a
# dot-atom or a double-quoted string; matching is case-insensitive.
EMAIL_USER_RE = re.compile(
    r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"  # dot-atom
    r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)',  # quoted-string
    re.IGNORECASE
)
# Pattern for the part of an email address after the last '@': one or more
# dot-terminated labels followed by a final label of 2-63 characters that
# does not end in '-'. Matching is case-insensitive.
EMAIL_DOMAIN_RE = re.compile(
    # max length for domain name labels is 63 characters per RFC 1034
    r'((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z',
    re.IGNORECASE
)


def get_posixuser(username, reread=False):
    """
    Look up a single account in /etc/passwd.

    The file is read once and its text kept in the module-level cache;
    later calls reuse the cached text unless ``reread`` is set.

    :param str username: Login name of the account to retrieve
    :param bool reread: Read /etc/passwd again even when a cached copy
        is already held by this module
    :returns: the first /etc/passwd record whose line starts with
        ``username:``; uid and gid are integers unless the field is blank
    :rtype: PosixUser
    :raises ValueError: if the user is not found, or if the uid or gid
        field is neither an integer nor blank
    """
    if reread or _posix_users['passwd'] is None:
        with open('/etc/passwd') as stream:
            _posix_users['passwd'] = stream.read()

    prefix = '{0}:'.format(username)
    record = next(
        (entry for entry in _posix_users['passwd'].splitlines()
         if entry.startswith(prefix)),
        None,
    )
    if record is None:
        raise ValueError('User {} not found in /etc/passwd'.format(username))

    fields = record.split(':')
    # uid (index 2) and gid (index 3) become integers; blank stays ''
    for position in (2, 3):
        try:
            fields[position] = int(fields[position])
        except ValueError:
            if fields[position] != '':
                raise ValueError(
                    'passwd field {0} for {1} must be an int or blank'
                    .format(position, username)
                )
    return PosixUser(*fields)
def get_shadowuser(username, reread=False):
    """
    Look up a single account in /etc/shadow.

    Reading /etc/shadow requires root privileges, so this fails for
    ordinary users. The file text is cached in this module after the
    first read.

    :param str username: Login name of the account to retrieve
    :param bool reread: Read /etc/shadow again even when a cached copy
        is already held by this module
    :returns: the first /etc/shadow record whose line starts with
        ``username:``; lastchange, min, max and warn are integers unless
        the field is blank
    :rtype: ShadowUser
    :raises ValueError: if the user is not found, or if one of the
        numeric fields is neither an integer nor blank
    """
    if reread or _posix_users['shadow'] is None:
        with open('/etc/shadow') as stream:
            _posix_users['shadow'] = stream.read()

    wanted = '{0}:'.format(username)
    matches = (
        entry for entry in _posix_users['shadow'].splitlines()
        if entry.startswith(wanted)
    )
    record = next(matches, None)
    if record is None:
        raise ValueError('User {} not found in /etc/shadow'.format(username))

    fields = record.split(':')
    # lastchange, min, max, warn (indexes 2-5) become integers; blank stays ''
    for position in range(2, 6):
        try:
            fields[position] = int(fields[position])
        except ValueError:
            if fields[position] != '':
                raise ValueError(
                    'Shadow field {0} for {1} must be an int or blank'
                    .format(position, username)
                )
    return ShadowUser(*fields)
def get_posixgroup(group, reread=False):
    """
    Fetch data for a given group from /etc/group.

    The file text is cached in this module after the first read.

    :param str group: Name of group to retrieve
    :param bool reread: Force reading of /etc/group even if it
        has already been read and is cached in this module
    :returns: the first /etc/group record whose line starts with
        ``group:``; gid is an integer unless the field is blank and
        members is a tuple of names (empty when the group has none)
    :rtype: PosixGroup
    :raises ValueError: if the group is not found, or if the gid field
        is neither an integer nor blank
    """
    if _posix_users['group'] is None or reread:
        with open('/etc/group') as stream:
            _posix_users['group'] = stream.read()

    prefix = '{0}:'.format(group)
    groupline = next(
        (line for line in _posix_users['group'].splitlines()
         if line.startswith(prefix)),
        None,
    )
    if groupline is None:
        raise ValueError('Group {} not found in /etc/group'.format(group))

    group_props = groupline.split(':')
    # change gid to an integer; a blank gid is left as ''
    try:
        group_props[2] = int(group_props[2])
    except ValueError:
        if group_props[2] != '':
            # The placeholder must be {0}: only one argument is supplied,
            # so {1} made format() raise IndexError instead of reporting
            # the bad field.
            raise ValueError(
                'Group field 2 for {0} must be an int or blank'.format(group)
            )
    # split the comma-separated members into a tuple
    group_props[3] = tuple(group_props[3].split(',')) if group_props[3] else ()
    return PosixGroup(*group_props)
def utcnow():
    """
    Return the current UTC time as a naive ``datetime``.

    This is a thin wrapper kept for testability. It produces the same
    kind of value as ``datetime.datetime.utcnow()`` (no ``tzinfo``) but
    avoids calling that method, which is deprecated since Python 3.12.

    :returns: current UTC time without timezone information
    :rtype: datetime.datetime
    """
    # Take an aware UTC timestamp, then drop tzinfo so callers keep
    # receiving the naive value they got before.
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
def accre_argparser(command_name, description=None):
    """
    Build an ``argparse.ArgumentParser`` preconfigured for this library.

    The parser gets a ``-v`` / ``--version`` option that reports the
    command name together with the package title and version.

    :param str command_name: Name of the CLI command to be displayed
        in the version message
    :param str description: ArgumentParser help description
    :returns: Customized parser with --version option
    :rtype: ArgumentParser
    """
    parser = argparse.ArgumentParser(description=description)
    version_option = {
        'action': 'version',
        'help': "Print the version of {0}".format(command_name),
        'version': '{0}, {1}, version {2}'.format(
            command_name, __title__, __version__
        ),
    }
    parser.add_argument('-v', '--version', **version_option)
    return parser
def interpret_string_values(mapping):
    """
    Interpret the string values of the given mapping.

    Each string value is split on commas. Every resulting piece that
    can be parsed by ``float`` is converted to a float; other pieces
    are kept as strings. A value without commas collapses back to a
    single item rather than a one-element list. Values that are not
    strings are copied through unchanged.

    :param mapping: dict or mapping to be interpreted
    :returns: dict with interpreted values
    :rtype: dict
    """
    def _coerce(text):
        # float() raises ValueError for text that is not a number;
        # nothing else is expected here, so nothing else is caught.
        try:
            return float(text)
        except ValueError:
            return text

    result = {}
    for key in mapping:
        val = mapping[key]
        if isinstance(val, str):
            parts = [_coerce(part) for part in val.split(',')]
            val = parts[0] if len(parts) == 1 else parts
        result[key] = val
    return result
def validate_email_address(address):
    """
    Check an email address against a subset of the 2017 Django
    validation logic, see
    https://github.com/django/django/blob/d95f1e711b9d1b3e60f7728e9710b8f542cec385/django/core/validators.py#L168-L180

    IP-address domains and internationalized domain names are not
    accepted by this function.

    :param str address: Email address to be validated
    :raises ValueError: if the address is empty, has no ``@``, or its
        user or domain part does not match the expected pattern
    """
    def _reject():
        raise ValueError('{0} is not a valid email address'.format(address))

    if not address or '@' not in address:
        _reject()
    # Split on the LAST '@' so a quoted user part may itself contain '@'
    pieces = address.rsplit('@', 1)
    user, domain = pieces
    if not EMAIL_USER_RE.match(user):
        _reject()
    if not EMAIL_DOMAIN_RE.match(domain):
        _reject()
def get_primary_ip():
    """
    Return the primary IP address (default route) for an
    internal ACCRE node.

    A UDP socket is connected to 10.0.255.255 and the local address
    chosen for that socket is returned. If the socket operations fail,
    the loopback address is returned instead.

    :returns: primary IP address for the server, or ``127.0.0.1`` when
        it cannot be determined
    :rtype: str
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        probe.connect(('10.0.255.255', 1))
        return probe.getsockname()[0]
    except OSError:
        # Socket failures are OSError subclasses. A bare ``except`` here
        # would also swallow KeyboardInterrupt and SystemExit.
        return '127.0.0.1'
    finally:
        probe.close()
def generate_password(separator=' ', count=6):
    """
    Generate a human readable passphrase from the EFF long wordlist, see
    https://www.eff.org/deeplinks/2016/07/new-wordlists-random-passphrases

    The wordlist is loaded from ``data/eff_large_wordlist.txt`` next to
    this module on first use and cached afterwards. Words are drawn with
    ``random.SystemRandom``.

    :param str separator: character(s) to separate individual
        words in the passphrase, defaults to a single space
    :param int count: Number of words to generate, defaults to 6
    :returns: Generated passphrase
    :rtype: str
    """
    global _eff_wordlist
    if _eff_wordlist is None:
        wordlist_path = os.path.join(
            os.path.dirname(__file__), 'data', 'eff_large_wordlist.txt'
        )
        with open(wordlist_path) as stream:
            # The second whitespace-separated column of each line is the word
            _eff_wordlist = [entry.split()[1] for entry in stream]
    chooser = random.SystemRandom()
    words = [chooser.choice(_eff_wordlist) for _ in range(count)]
    return separator.join(words)
# Module-wide switch for ANSI color output; starts enabled
_ansi_term_colors = True


def set_ansi_colors(flag):
    """
    Turn color codes in the ANSI terminal color string functions
    (such as ``accre.util.RedStr``) on or off.

    With the switch on, those functions return strings wrapped in
    color codes; with it off they return the plain string. CLI tools
    may use this to implement a --no-color option. The switch is
    initially True.

    :param bool flag: Turn the ANSI terminal colors on or off
    """
    global _ansi_term_colors
    _ansi_term_colors = flag
def _ansicolorstr(colorcode, value):
    """
    Wrap ``value`` in the ANSI escape sequence for ``colorcode``
    followed by a reset sequence, or return ``value`` untouched when
    the module-level color switch is off.

    :param int colorcode: numeric code placed in the escape sequence
    :param str value: text to colorize
    :returns: colorized (or plain) string
    :rtype: str
    """
    if not _ansi_term_colors:
        return value
    return '\033[{0}m'.format(colorcode) + value + '\033[00m'


# ANSI color string convenience functions, one per color code
RedStr = partial(_ansicolorstr, 91)
GreenStr = partial(_ansicolorstr, 92)
YellowStr = partial(_ansicolorstr, 93)
LightPurpleStr = partial(_ansicolorstr, 94)
PurpleStr = partial(_ansicolorstr, 95)
CyanStr = partial(_ansicolorstr, 96)
LightGrayStr = partial(_ansicolorstr, 97)
# NOTE(review): 98 appears to fall outside the usual 90-97 bright
# foreground range; confirm BlackStr renders as intended.
BlackStr = partial(_ansicolorstr, 98)
def filehash(fpath, algorithm):
    """
    Return the hex digest of the given file calculated using the
    desired algorithm.

    The file is read in large blocks so that it never has to fit in
    memory at once.

    :param str fpath: File path
    :param str algorithm: Hash function
        (md5, sha1, sha224, sha256, sha384, sha512)
    :returns: Hash value as a hexadecimal string
    :rtype: str
    :raises ValueError: if ``algorithm`` is not one of the supported names
    """
    # Map names to constructors and build only the requested hash object.
    # Building every algorithm up front is wasted work, and a failure in
    # an unrelated constructor would break the call.
    constructors = {
        'md5': hashlib.md5,
        'sha1': hashlib.sha1,
        'sha224': hashlib.sha224,
        'sha256': hashlib.sha256,
        'sha384': hashlib.sha384,
        'sha512': hashlib.sha512,
    }
    if algorithm not in constructors:
        raise ValueError('{0} is not a valid hash function'.format(algorithm))
    hashfunc = constructors[algorithm]()

    # Read size is a multiple of the hash's internal block size. The
    # fallback of 64 (md5/sha1 block size) gives a 16MB read.
    blocksize = getattr(hashfunc, 'block_size', 64) * 262144

    with open(fpath, 'rb') as f:
        for block in iter(lambda: f.read(blocksize), b''):
            hashfunc.update(block)
    return hashfunc.hexdigest()
def parse_slurm_cli_limits(limits):
    """
    Parse a comma delimited string of slurm usage limits, fairshare
    and/or QOS into a dictionary with one entry per item given.

    Item names are case-insensitive. ``grpcpus``, ``grpcpurunmins`` and
    ``grpmemory`` are accepted as aliases for ``max_cpu``,
    ``max_runmins`` and ``max_mem`` and are stored under those keys.

    :param str limits: input string of comma separated ``limit=value``
        items
    :returns: dictionary with limits
    :rtype: dict
    :raises ValueError: if an item is not of the form ``limit=value``,
        names an unknown limit, or carries a non-integer value where an
        integer is required
    """
    # accepted name -> (result key, error message when the value must be
    # an integer, or None when the value is kept as a string)
    known = {
        'qos': ('qos', None),
        'fairshare': ('fairshare', 'Fairshare must be an integer'),
        'max_cpu': ('max_cpu', 'Maximum CPU limit must be an integer'),
        'grpcpus': ('max_cpu', 'Maximum CPU limit must be an integer'),
        'max_runmins': (
            'max_runmins', 'Maximum runtime limit must be an integer'
        ),
        'grpcpurunmins': (
            'max_runmins', 'Maximum runtime limit must be an integer'
        ),
        'max_mem': ('max_mem', None),
        'grpmemory': ('max_mem', None),
    }

    parsed = {}
    for entry in limits.split(','):
        try:
            name, setting = entry.split('=')
        except ValueError:
            raise ValueError(
                'Each slurm limit item must be of the form limit=value'
            )
        name = name.lower()
        if name not in known:
            raise ValueError('Invalid limit: {0}'.format(name))
        target, int_error = known[name]
        if int_error is not None:
            try:
                setting = int(setting)
            except ValueError:
                raise ValueError(int_error)
        parsed[target] = setting
    return parsed
def convert_byte_unit(raw, target='mi', ieee=False):
    """
    Convert a string representing a quantity of bytes into a float
    expressed in the target unit, i.e. 2GB --> 2048.0 if the target
    unit is MB.

    Both ``raw`` and ``target`` are case-insensitive and a trailing
    ``b`` is optional. When ``ieee`` is False every prefix is a power
    of 1024 and an ``i`` is ignored. When ``ieee`` is True, prefixes
    without ``i`` are powers of 1000 (MB = 10^6) and prefixes with
    ``i`` are powers of 1024 (MiB = 2^20).

    :param str raw: String containing value of bytes
    :param str target: Target unit to convert into (i.e. B, kB, MiB, GB)
    :param bool ieee: Use strict IEEE definitions for MB, MiB, etc.
    :returns: Value in the specified target unit
    :rtype: float
    :raises ValueError: if the target unit is unknown or ``raw`` cannot
        be parsed as a quantity of bytes
    """
    prefixes = 'kmgtpe'
    binary = {p: 1024 ** n for n, p in enumerate(prefixes, start=1)}
    decimal = {p: 1000 ** n for n, p in enumerate(prefixes, start=1)}

    # Normalize: lowercase, trim, and drop an optional trailing 'b' (bytes)
    quantity = raw.lower().strip()
    if quantity.endswith('b'):
        quantity = quantity[:-1]
    unit = target.lower().strip()
    if unit.endswith('b'):
        unit = unit[:-1]

    # Without ieee everything is binary; with ieee only an explicit 'i'
    # selects the binary table. The 'i' itself is always removed.
    quantity_scale = unit_scale = decimal if ieee else binary
    if quantity.endswith('i'):
        quantity = quantity[:-1]
        if ieee:
            quantity_scale = binary
    if unit.endswith('i'):
        unit = unit[:-1]
        if ieee:
            unit_scale = binary

    if unit and unit not in unit_scale:
        raise ValueError(f'Invalid target unit: {target}')

    try:
        suffix = quantity[-1]
        if suffix in quantity_scale:
            n_bytes = float(quantity[:-1].strip()) * quantity_scale[suffix]
        else:
            n_bytes = float(quantity.strip())
    except Exception:
        raise ValueError(f'Could not parse {raw} as a quantity of bytes')

    if not unit:
        # An empty unit means plain bytes
        return n_bytes
    return n_bytes / unit_scale[unit]
def byte_quantity_isclose(a, b, rel_tol=1e-09, abs_tol=0.0, ieee=False):
    """
    Like ``math.isclose`` but for strings representing quantities of
    bytes, i.e. is 2048MB approximately equal to 2GB.

    Both quantities are converted to plain bytes with
    ``convert_byte_unit`` before being compared. By default MB, MiB and
    friends are all treated as binary powers; if ``ieee`` is true the
    strict definitions are used instead.

    :param str a: Value of bytes to compare
    :param str b: Value of bytes to compare
    :param float rel_tol: the maximum allowed difference between a and b,
        relative to the larger absolute value of a or b.
    :param float abs_tol: the minimum absolute tolerance, useful for
        comparisons near zero
    :param bool ieee: Use strict IEEE definitions for MB, MiB, etc.
    :returns: True if the values are approximately equal
    :rtype: bool
    """
    left, right = (
        convert_byte_unit(quantity, target='b', ieee=ieee)
        for quantity in (a, b)
    )
    return math.isclose(left, right, rel_tol=rel_tol, abs_tol=abs_tol)
def get_slurm_data_time_tag(month=None, year=None):
    """
    Generate the data time tag for the table ACCOUNTS_PARTITION_DATA.

    If either month or year is None the tag is "current", meaning the
    data is for the recent time period. Otherwise the tag is the year
    followed by the month name, e.g. ``2021-March``.

    :param int month: input month for the slurm data, should be
        from 1 to 12
    :param int year: input year for the slurm data, should be
        from 2000 to 2100
    :returns: the time tag
    :rtype: str
    :raises ValueError: if month or year is not an integer in the
        accepted range
    """
    if month is None or year is None:
        return "current"

    # Type and range checks are joined with ``or``. With ``and``,
    # out-of-range integers passed validation and non-integers raised
    # TypeError from the comparison instead of ValueError.
    if not isinstance(month, int) or not 1 <= month <= 12:
        raise ValueError(
            f'Invalid month value in get_slurm_data_time_tag: {month}'
        )
    if not isinstance(year, int) or not 2000 <= year <= 2100:
        raise ValueError(
            f'Invalid year value in get_slurm_data_time_tag: {year}'
        )

    month_name = calendar.month_name[month]
    return f'{year}-{month_name}'