diff options
Diffstat (limited to 'roles/lib_utils')
25 files changed, 5033 insertions, 0 deletions
diff --git a/roles/lib_utils/action_plugins/generate_pv_pvcs_list.py b/roles/lib_utils/action_plugins/generate_pv_pvcs_list.py new file mode 100644 index 000000000..eb13a58ba --- /dev/null +++ b/roles/lib_utils/action_plugins/generate_pv_pvcs_list.py @@ -0,0 +1,157 @@ +""" +Ansible action plugin to generate pv and pvc dictionaries lists +""" + +from ansible.plugins.action import ActionBase +from ansible import errors + + +class ActionModule(ActionBase): + """Action plugin to execute health checks.""" + + def get_templated(self, var_to_template): + """Return a properly templated ansible variable""" + return self._templar.template(self.task_vars.get(var_to_template)) + + def build_common(self, varname=None): + """Retrieve common variables for each pv and pvc type""" + volume = self.get_templated(str(varname) + '_volume_name') + size = self.get_templated(str(varname) + '_volume_size') + labels = self.task_vars.get(str(varname) + '_labels') + if labels: + labels = self._templar.template(labels) + else: + labels = dict() + access_modes = self.get_templated(str(varname) + '_access_modes') + return (volume, size, labels, access_modes) + + def build_pv_nfs(self, varname=None): + """Build pv dictionary for nfs storage type""" + host = self.task_vars.get(str(varname) + '_host') + if host: + self._templar.template(host) + elif host is None: + groups = self.task_vars.get('groups') + default_group_name = self.get_templated('openshift_persistent_volumes_default_nfs_group') + if groups and default_group_name and default_group_name in groups and len(groups[default_group_name]) > 0: + host = groups['oo_nfs_to_config'][0] + else: + raise errors.AnsibleModuleError("|failed no storage host detected") + volume, size, labels, access_modes = self.build_common(varname=varname) + directory = self.get_templated(str(varname) + '_nfs_directory') + path = directory + '/' + volume + return dict( + name="{0}-volume".format(volume), + capacity=size, + labels=labels, + access_modes=access_modes, + storage=dict( + nfs=dict( + server=host, + path=path))) + + def build_pv_openstack(self, varname=None): + """Build pv dictionary for openstack storage type""" + volume, size, labels, access_modes = self.build_common(varname=varname) + filesystem = self.get_templated(str(varname) + '_openstack_filesystem') + volume_id = self.get_templated(str(varname) + '_openstack_volumeID') + return dict( + name="{0}-volume".format(volume), + capacity=size, + labels=labels, + access_modes=access_modes, + storage=dict( + cinder=dict( + fsType=filesystem, + volumeID=volume_id))) + + def build_pv_glusterfs(self, varname=None): + """Build pv dictionary for glusterfs storage type""" + volume, size, labels, access_modes = self.build_common(varname=varname) + endpoints = self.get_templated(str(varname) + '_glusterfs_endpoints') + path = self.get_templated(str(varname) + '_glusterfs_path') + read_only = self.get_templated(str(varname) + '_glusterfs_readOnly') + return dict( + name="{0}-volume".format(volume), + capacity=size, + labels=labels, + access_modes=access_modes, + storage=dict( + glusterfs=dict( + endpoints=endpoints, + path=path, + readOnly=read_only))) + + def build_pv_dict(self, varname=None): + """Check for the existence of PV variables""" + kind = self.task_vars.get(str(varname) + '_kind') + if kind: + kind = self._templar.template(kind) + create_pv = self.task_vars.get(str(varname) + '_create_pv') + if create_pv and self._templar.template(create_pv): + if kind == 'nfs': + return self.build_pv_nfs(varname=varname) + + elif kind == 'openstack': + return self.build_pv_openstack(varname=varname) + + elif kind == 'glusterfs': + return self.build_pv_glusterfs(varname=varname) + + elif not (kind == 'object' or kind == 'dynamic'): + msg = "|failed invalid storage kind '{0}' for component '{1}'".format( + kind, + varname) + raise errors.AnsibleModuleError(msg) + return None + + def build_pvc_dict(self, varname=None): + """Check for the existence of PVC variables""" + kind = self.task_vars.get(str(varname) + '_kind') + if kind: + kind = self._templar.template(kind) + create_pv = self.task_vars.get(str(varname) + '_create_pv') + if create_pv: + create_pv = self._templar.template(create_pv) + create_pvc = self.task_vars.get(str(varname) + '_create_pvc') + if create_pvc: + create_pvc = self._templar.template(create_pvc) + if kind != 'object' and create_pv and create_pvc: + volume, size, _, access_modes = self.build_common(varname=varname) + return dict( + name="{0}-claim".format(volume), + capacity=size, + access_modes=access_modes) + return None + + def run(self, tmp=None, task_vars=None): + """Run generate_pv_pvcs_list action plugin""" + result = super(ActionModule, self).run(tmp, task_vars) + # Ignore settting self.task_vars outside of init. + # pylint: disable=W0201 + self.task_vars = task_vars or {} + + result["changed"] = False + result["failed"] = False + result["msg"] = "persistent_volumes list and persistent_volume_claims list created" + vars_to_check = ['openshift_hosted_registry_storage', + 'openshift_hosted_router_storage', + 'openshift_hosted_etcd_storage', + 'openshift_logging_storage', + 'openshift_loggingops_storage', + 'openshift_metrics_storage', + 'openshift_prometheus_storage', + 'openshift_prometheus_alertmanager_storage', + 'openshift_prometheus_alertbuffer_storage'] + persistent_volumes = [] + persistent_volume_claims = [] + for varname in vars_to_check: + pv_dict = self.build_pv_dict(varname) + if pv_dict: + persistent_volumes.append(pv_dict) + pvc_dict = self.build_pvc_dict(varname) + if pvc_dict: + persistent_volume_claims.append(pvc_dict) + result["persistent_volumes"] = persistent_volumes + result["persistent_volume_claims"] = persistent_volume_claims + return result diff --git a/roles/lib_utils/action_plugins/sanity_checks.py b/roles/lib_utils/action_plugins/sanity_checks.py new file mode 100644 index 000000000..09ce55e8f --- /dev/null +++ b/roles/lib_utils/action_plugins/sanity_checks.py @@ -0,0 +1,181 @@ +""" +Ansible action plugin to ensure inventory variables are set +appropriately and no conflicting options have been provided. +""" +import re + +from ansible.plugins.action import ActionBase +from ansible import errors + +# Valid values for openshift_deployment_type +VALID_DEPLOYMENT_TYPES = ('origin', 'openshift-enterprise') + +# Tuple of variable names and default values if undefined. +NET_PLUGIN_LIST = (('openshift_use_openshift_sdn', True), + ('openshift_use_flannel', False), + ('openshift_use_nuage', False), + ('openshift_use_contiv', False), + ('openshift_use_calico', False)) + +ENTERPRISE_TAG_REGEX_ERROR = """openshift_image_tag must be in the format +v#.#[.#[.#]]. Examples: v1.2, v3.4.1, v3.5.1.3, +v3.5.1.3.4, v1.2-1, v1.2.3-4, v1.2.3-4.5, v1.2.3-4.5.6 +You specified openshift_image_tag={}""" + +ORIGIN_TAG_REGEX_ERROR = """openshift_image_tag must be in the format +v#.#.#[-optional.#]. Examples: v1.2.3, v3.5.1-alpha.1 +You specified openshift_image_tag={}""" + +ORIGIN_TAG_REGEX = {'re': '(^v?\\d+\\.\\d+\\.\\d+(-[\\w\\-\\.]*)?$)', + 'error_msg': ORIGIN_TAG_REGEX_ERROR} +ENTERPRISE_TAG_REGEX = {'re': '(^v\\d+\\.\\d+(\\.\\d+)*(-\\d+(\\.\\d+)*)?$)', + 'error_msg': ENTERPRISE_TAG_REGEX_ERROR} +IMAGE_TAG_REGEX = {'origin': ORIGIN_TAG_REGEX, + 'openshift-enterprise': ENTERPRISE_TAG_REGEX} + +CONTAINERIZED_NO_TAG_ERROR_MSG = """To install a containerized Origin release, +you must set openshift_release or openshift_image_tag in your inventory to +specify which version of the OpenShift component images to use. +(Suggestion: add openshift_release="x.y" to inventory.)""" + + +def to_bool(var_to_check): + """Determine a boolean value given the multiple + ways bools can be specified in ansible.""" + # http://yaml.org/type/bool.html + yes_list = (True, 1, "True", "1", "true", "TRUE", + "Yes", "yes", "Y", "y", "YES", + "on", "ON", "On") + return var_to_check in yes_list + + +class ActionModule(ActionBase): + """Action plugin to execute sanity checks.""" + def template_var(self, hostvars, host, varname): + """Retrieve a variable from hostvars and template it. + If undefined, return None type.""" + res = hostvars[host].get(varname) + if res is None: + return None + return self._templar.template(res) + + def check_openshift_deployment_type(self, hostvars, host): + """Ensure a valid openshift_deployment_type is set""" + openshift_deployment_type = self.template_var(hostvars, host, + 'openshift_deployment_type') + if openshift_deployment_type not in VALID_DEPLOYMENT_TYPES: + type_strings = ", ".join(VALID_DEPLOYMENT_TYPES) + msg = "openshift_deployment_type must be defined and one of {}".format(type_strings) + raise errors.AnsibleModuleError(msg) + return openshift_deployment_type + + def check_python_version(self, hostvars, host, distro): + """Ensure python version is 3 for Fedora and python 2 for others""" + ansible_python = self.template_var(hostvars, host, 'ansible_python') + if distro == "Fedora": + if ansible_python['version']['major'] != 3: + msg = "openshift-ansible requires Python 3 for {};".format(distro) + msg += " For information on enabling Python 3 with Ansible," + msg += " see https://docs.ansible.com/ansible/python_3_support.html" + raise errors.AnsibleModuleError(msg) + else: + if ansible_python['version']['major'] != 2: + msg = "openshift-ansible requires Python 2 for {};".format(distro) + + def check_image_tag_format(self, hostvars, host, openshift_deployment_type): + """Ensure openshift_image_tag is formatted correctly""" + openshift_image_tag = self.template_var(hostvars, host, 'openshift_image_tag') + if not openshift_image_tag or openshift_image_tag == 'latest': + return None + regex_to_match = IMAGE_TAG_REGEX[openshift_deployment_type]['re'] + res = re.match(regex_to_match, str(openshift_image_tag)) + if res is None: + msg = IMAGE_TAG_REGEX[openshift_deployment_type]['error_msg'] + msg = msg.format(str(openshift_image_tag)) + raise errors.AnsibleModuleError(msg) + + def no_origin_image_version(self, hostvars, host, openshift_deployment_type): + """Ensure we can determine what image version to use with origin + fail when: + - openshift_is_containerized + - openshift_deployment_type == 'origin' + - openshift_release is not defined + - openshift_image_tag is not defined""" + if not openshift_deployment_type == 'origin': + return None + oic = self.template_var(hostvars, host, 'openshift_is_containerized') + if not to_bool(oic): + return None + orelease = self.template_var(hostvars, host, 'openshift_release') + oitag = self.template_var(hostvars, host, 'openshift_image_tag') + if not orelease and not oitag: + raise errors.AnsibleModuleError(CONTAINERIZED_NO_TAG_ERROR_MSG) + + def network_plugin_check(self, hostvars, host): + """Ensure only one type of network plugin is enabled""" + res = [] + # Loop through each possible network plugin boolean, determine the + # actual boolean value, and append results into a list. + for plugin, default_val in NET_PLUGIN_LIST: + res_temp = self.template_var(hostvars, host, plugin) + if res_temp is None: + res_temp = default_val + res.append(to_bool(res_temp)) + + if sum(res) != 1: + plugin_str = list(zip([x[0] for x in NET_PLUGIN_LIST], res)) + + msg = "Host Checked: {} Only one of must be true. Found: {}".format(host, plugin_str) + raise errors.AnsibleModuleError(msg) + + def check_hostname_vars(self, hostvars, host): + """Checks to ensure openshift_hostname + and openshift_public_hostname + conform to the proper length of 63 characters or less""" + for varname in ('openshift_public_hostname', 'openshift_hostname'): + var_value = self.template_var(hostvars, host, varname) + if var_value and len(var_value) > 63: + msg = '{} must be 63 characters or less'.format(varname) + raise errors.AnsibleModuleError(msg) + + def run_checks(self, hostvars, host): + """Execute the hostvars validations against host""" + distro = self.template_var(hostvars, host, 'ansible_distribution') + odt = self.check_openshift_deployment_type(hostvars, host) + self.check_python_version(hostvars, host, distro) + self.check_image_tag_format(hostvars, host, odt) + self.no_origin_image_version(hostvars, host, odt) + self.network_plugin_check(hostvars, host) + self.check_hostname_vars(hostvars, host) + + def run(self, tmp=None, task_vars=None): + result = super(ActionModule, self).run(tmp, task_vars) + + # self.task_vars holds all in-scope variables. + # Ignore settting self.task_vars outside of init. + # pylint: disable=W0201 + self.task_vars = task_vars or {} + + # self._task.args holds task parameters. + # check_hosts is a parameter to this plugin, and should provide + # a list of hosts. + check_hosts = self._task.args.get('check_hosts') + if not check_hosts: + msg = "check_hosts is required" + raise errors.AnsibleModuleError(msg) + + # We need to access each host's variables + hostvars = self.task_vars.get('hostvars') + if not hostvars: + msg = hostvars + raise errors.AnsibleModuleError(msg) + + # We loop through each host in the provided list check_hosts + for host in check_hosts: + self.run_checks(hostvars, host) + + result["changed"] = False + result["failed"] = False + result["msg"] = "Sanity Checks passed" + + return result diff --git a/roles/lib_utils/callback_plugins/aa_version_requirement.py b/roles/lib_utils/callback_plugins/aa_version_requirement.py new file mode 100644 index 000000000..1093acdae --- /dev/null +++ b/roles/lib_utils/callback_plugins/aa_version_requirement.py @@ -0,0 +1,60 @@ +#!/usr/bin/python + +""" +This callback plugin verifies the required minimum version of Ansible +is installed for proper operation of the OpenShift Ansible Installer. +The plugin is named with leading `aa_` to ensure this plugin is loaded +first (alphanumerically) by Ansible. +""" +import sys +from ansible import __version__ + +if __version__ < '2.0': + # pylint: disable=import-error,no-name-in-module + # Disabled because pylint warns when Ansible v2 is installed + from ansible.callbacks import display as pre2_display + CallbackBase = object + + def display(*args, **kwargs): + """Set up display function for pre Ansible v2""" + pre2_display(*args, **kwargs) +else: + from ansible.plugins.callback import CallbackBase + from ansible.utils.display import Display + + def display(*args, **kwargs): + """Set up display function for Ansible v2""" + display_instance = Display() + display_instance.display(*args, **kwargs) + + +# Set to minimum required Ansible version +REQUIRED_VERSION = '2.4.1.0' +DESCRIPTION = "Supported versions: %s or newer" % REQUIRED_VERSION + + +def version_requirement(version): + """Test for minimum required version""" + return version >= REQUIRED_VERSION + + +class CallbackModule(CallbackBase): + """ + Ansible callback plugin + """ + + CALLBACK_VERSION = 1.0 + CALLBACK_NAME = 'version_requirement' + + def __init__(self): + """ + Version verification is performed in __init__ to catch the + requirement early in the execution of Ansible and fail gracefully + """ + super(CallbackModule, self).__init__() + + if not version_requirement(__version__): + display( + 'FATAL: Current Ansible version (%s) is not supported. %s' + % (__version__, DESCRIPTION), color='red') + sys.exit(1) diff --git a/roles/lib_utils/callback_plugins/openshift_quick_installer.py b/roles/lib_utils/callback_plugins/openshift_quick_installer.py new file mode 100644 index 000000000..365e2443d --- /dev/null +++ b/roles/lib_utils/callback_plugins/openshift_quick_installer.py @@ -0,0 +1,360 @@ +# pylint: disable=invalid-name,protected-access,import-error,line-too-long,attribute-defined-outside-init + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +"""This file is a stdout callback plugin for the OpenShift Quick +Installer. The purpose of this callback plugin is to reduce the amount +of produced output for customers and enable simpler progress checking. + +What's different: + +* Playbook progress is expressed as: Play <current_play>/<total_plays> (Play Name) + Ex: Play 3/30 (Initialize Megafrobber) + +* The Tasks and Handlers in each play (and included roles) are printed + as a series of .'s following the play progress line. + +* Many of these methods include copy and paste code from the upstream + default.py callback. We do that to give us control over the stdout + output while allowing Ansible to handle the file logging + normally. The biggest changes here are that we are manually setting + `log_only` to True in the Display.display method and we redefine the + Display.banner method locally so we can set log_only on that call as + well. + +""" + +from __future__ import (absolute_import, print_function) +import sys +from ansible import constants as C +from ansible.plugins.callback import CallbackBase +from ansible.utils.color import colorize, hostcolor + + +class CallbackModule(CallbackBase): + + """ + Ansible callback plugin + """ + CALLBACK_VERSION = 2.2 + CALLBACK_TYPE = 'stdout' + CALLBACK_NAME = 'openshift_quick_installer' + CALLBACK_NEEDS_WHITELIST = False + plays_count = 0 + plays_total_ran = 0 + + def __init__(self): + """Constructor, ensure standard self.*s are set""" + self._play = None + self._last_task_banner = None + super(CallbackModule, self).__init__() + + def banner(self, msg, color=None): + '''Prints a header-looking line with stars taking up to 80 columns + of width (3 columns, minimum) + + Overrides the upstream banner method so that display is called + with log_only=True + ''' + msg = msg.strip() + star_len = (79 - len(msg)) + if star_len < 0: + star_len = 3 + stars = "*" * star_len + self._display.display("\n%s %s" % (msg, stars), color=color, log_only=True) + + def _print_task_banner(self, task): + """Imported from the upstream 'default' callback""" + # args can be specified as no_log in several places: in the task or in + # the argument spec. We can check whether the task is no_log but the + # argument spec can't be because that is only run on the target + # machine and we haven't run it thereyet at this time. + # + # So we give people a config option to affect display of the args so + # that they can secure this if they feel that their stdout is insecure + # (shoulder surfing, logging stdout straight to a file, etc). + args = '' + if not task.no_log and C.DISPLAY_ARGS_TO_STDOUT: + args = ', '.join('%s=%s' % a for a in task.args.items()) + args = ' %s' % args + + self.banner(u"TASK [%s%s]" % (task.get_name().strip(), args)) + if self._display.verbosity >= 2: + path = task.get_path() + if path: + self._display.display(u"task path: %s" % path, color=C.COLOR_DEBUG, log_only=True) + + self._last_task_banner = task._uuid + + def v2_playbook_on_start(self, playbook): + """This is basically the start of it all""" + self.plays_count = len(playbook.get_plays()) + self.plays_total_ran = 0 + + if self._display.verbosity > 1: + from os.path import basename + self.banner("PLAYBOOK: %s" % basename(playbook._file_name)) + + def v2_playbook_on_play_start(self, play): + """Each play calls this once before running any tasks + +We could print the number of tasks here as well by using +`play.get_tasks()` but that is not accurate when a play includes a +role. Only the tasks directly assigned to a play are exposed in the +`play` object. + """ + self.plays_total_ran += 1 + print("") + print("Play %s/%s (%s)" % (self.plays_total_ran, self.plays_count, play.get_name())) + + name = play.get_name().strip() + if not name: + msg = "PLAY" + else: + msg = "PLAY [%s]" % name + + self._play = play + + self.banner(msg) + + # pylint: disable=unused-argument,no-self-use + def v2_playbook_on_task_start(self, task, is_conditional): + """This prints out the task header. For example: + +TASK [openshift_facts : Ensure PyYaml is installed] ***... + +Rather than print out all that for every task, we print a dot +character to indicate a task has been started. + """ + sys.stdout.write('.') + + args = '' + # args can be specified as no_log in several places: in the task or in + # the argument spec. We can check whether the task is no_log but the + # argument spec can't be because that is only run on the target + # machine and we haven't run it thereyet at this time. + # + # So we give people a config option to affect display of the args so + # that they can secure this if they feel that their stdout is insecure + # (shoulder surfing, logging stdout straight to a file, etc). + if not task.no_log and C.DISPLAY_ARGS_TO_STDOUT: + args = ', '.join(('%s=%s' % a for a in task.args.items())) + args = ' %s' % args + self.banner("TASK [%s%s]" % (task.get_name().strip(), args)) + if self._display.verbosity >= 2: + path = task.get_path() + if path: + self._display.display("task path: %s" % path, color=C.COLOR_DEBUG, log_only=True) + + # pylint: disable=unused-argument,no-self-use + def v2_playbook_on_handler_task_start(self, task): + """Print out task header for handlers + +Rather than print out a header for every handler, we print a dot +character to indicate a handler task has been started. +""" + sys.stdout.write('.') + + self.banner("RUNNING HANDLER [%s]" % task.get_name().strip()) + + # pylint: disable=unused-argument,no-self-use + def v2_playbook_on_cleanup_task_start(self, task): + """Print out a task header for cleanup tasks + +Rather than print out a header for every handler, we print a dot +character to indicate a handler task has been started. +""" + sys.stdout.write('.') + + self.banner("CLEANUP TASK [%s]" % task.get_name().strip()) + + def v2_playbook_on_include(self, included_file): + """Print out paths to statically included files""" + msg = 'included: %s for %s' % (included_file._filename, ", ".join([h.name for h in included_file._hosts])) + self._display.display(msg, color=C.COLOR_SKIP, log_only=True) + + def v2_runner_on_ok(self, result): + """This prints out task results in a fancy format + +The only thing we change here is adding `log_only=True` to the +.display() call + """ + delegated_vars = result._result.get('_ansible_delegated_vars', None) + self._clean_results(result._result, result._task.action) + if result._task.action in ('include', 'import_role'): + return + elif result._result.get('changed', False): + if delegated_vars: + msg = "changed: [%s -> %s]" % (result._host.get_name(), delegated_vars['ansible_host']) + else: + msg = "changed: [%s]" % result._host.get_name() + color = C.COLOR_CHANGED + else: + if delegated_vars: + msg = "ok: [%s -> %s]" % (result._host.get_name(), delegated_vars['ansible_host']) + else: + msg = "ok: [%s]" % result._host.get_name() + color = C.COLOR_OK + + if result._task.loop and 'results' in result._result: + self._process_items(result) + else: + + if (self._display.verbosity > 0 or '_ansible_verbose_always' in result._result) and '_ansible_verbose_override' not in result._result: + msg += " => %s" % (self._dump_results(result._result),) + self._display.display(msg, color=color, log_only=True) + + self._handle_warnings(result._result) + + def v2_runner_item_on_ok(self, result): + """Print out task results for items you're iterating over""" + delegated_vars = result._result.get('_ansible_delegated_vars', None) + if result._task.action in ('include', 'import_role'): + return + elif result._result.get('changed', False): + msg = 'changed' + color = C.COLOR_CHANGED + else: + msg = 'ok' + color = C.COLOR_OK + + if delegated_vars: + msg += ": [%s -> %s]" % (result._host.get_name(), delegated_vars['ansible_host']) + else: + msg += ": [%s]" % result._host.get_name() + + msg += " => (item=%s)" % (self._get_item(result._result),) + + if (self._display.verbosity > 0 or '_ansible_verbose_always' in result._result) and '_ansible_verbose_override' not in result._result: + msg += " => %s" % self._dump_results(result._result) + self._display.display(msg, color=color, log_only=True) + + def v2_runner_item_on_skipped(self, result): + """Print out task results when an item is skipped""" + if C.DISPLAY_SKIPPED_HOSTS: + msg = "skipping: [%s] => (item=%s) " % (result._host.get_name(), self._get_item(result._result)) + if (self._display.verbosity > 0 or '_ansible_verbose_always' in result._result) and '_ansible_verbose_override' not in result._result: + msg += " => %s" % self._dump_results(result._result) + self._display.display(msg, color=C.COLOR_SKIP, log_only=True) + + def v2_runner_on_skipped(self, result): + """Print out task results when a task (or something else?) is skipped""" + if C.DISPLAY_SKIPPED_HOSTS: + if result._task.loop and 'results' in result._result: + self._process_items(result) + else: + msg = "skipping: [%s]" % result._host.get_name() + if (self._display.verbosity > 0 or '_ansible_verbose_always' in result._result) and '_ansible_verbose_override' not in result._result: + msg += " => %s" % self._dump_results(result._result) + self._display.display(msg, color=C.COLOR_SKIP, log_only=True) + + def v2_playbook_on_notify(self, res, handler): + """What happens when a task result is 'changed' and the task has a +'notify' list attached. + """ + self._display.display("skipping: no hosts matched", color=C.COLOR_SKIP, log_only=True) + + ###################################################################### + # So we can bubble up errors to the top + def v2_runner_on_failed(self, result, ignore_errors=False): + """I guess this is when an entire task has failed?""" + + if self._play.strategy == 'free' and self._last_task_banner != result._task._uuid: + self._print_task_banner(result._task) + + delegated_vars = result._result.get('_ansible_delegated_vars', None) + if 'exception' in result._result: + if self._display.verbosity < 3: + # extract just the actual error message from the exception text + error = result._result['exception'].strip().split('\n')[-1] + msg = "An exception occurred during task execution. To see the full traceback, use -vvv. The error was: %s" % error + else: + msg = "An exception occurred during task execution. The full traceback is:\n" + result._result['exception'] + + self._display.display(msg, color=C.COLOR_ERROR) + + if result._task.loop and 'results' in result._result: + self._process_items(result) + + else: + if delegated_vars: + self._display.display("fatal: [%s -> %s]: FAILED! => %s" % (result._host.get_name(), delegated_vars['ansible_host'], self._dump_results(result._result)), color=C.COLOR_ERROR) + else: + self._display.display("fatal: [%s]: FAILED! => %s" % (result._host.get_name(), self._dump_results(result._result)), color=C.COLOR_ERROR) + + if ignore_errors: + self._display.display("...ignoring", color=C.COLOR_SKIP) + + def v2_runner_item_on_failed(self, result): + """When an item in a task fails.""" + delegated_vars = result._result.get('_ansible_delegated_vars', None) + if 'exception' in result._result: + if self._display.verbosity < 3: + # extract just the actual error message from the exception text + error = result._result['exception'].strip().split('\n')[-1] + msg = "An exception occurred during task execution. To see the full traceback, use -vvv. The error was: %s" % error + else: + msg = "An exception occurred during task execution. The full traceback is:\n" + result._result['exception'] + + self._display.display(msg, color=C.COLOR_ERROR) + + msg = "failed: " + if delegated_vars: + msg += "[%s -> %s]" % (result._host.get_name(), delegated_vars['ansible_host']) + else: + msg += "[%s]" % (result._host.get_name()) + + self._display.display(msg + " (item=%s) => %s" % (self._get_item(result._result), self._dump_results(result._result)), color=C.COLOR_ERROR) + self._handle_warnings(result._result) + + ###################################################################### + def v2_playbook_on_stats(self, stats): + """Print the final playbook run stats""" + self._display.display("", screen_only=True) + self.banner("PLAY RECAP") + + hosts = sorted(stats.processed.keys()) + for h in hosts: + t = stats.summarize(h) + + self._display.display( + u"%s : %s %s %s %s" % ( + hostcolor(h, t), + colorize(u'ok', t['ok'], C.COLOR_OK), + colorize(u'changed', t['changed'], C.COLOR_CHANGED), + colorize(u'unreachable', t['unreachable'], C.COLOR_UNREACHABLE), + colorize(u'failed', t['failures'], C.COLOR_ERROR)), + screen_only=True + ) + + self._display.display( + u"%s : %s %s %s %s" % ( + hostcolor(h, t, False), + colorize(u'ok', t['ok'], None), + colorize(u'changed', t['changed'], None), + colorize(u'unreachable', t['unreachable'], None), + colorize(u'failed', t['failures'], None)), + log_only=True + ) + + self._display.display("", screen_only=True) + self._display.display("", screen_only=True) + + # Some plays are conditional and won't run (such as load + # balancers) if they aren't required. Sometimes plays are + # conditionally included later in the run. Let the user know + # about this to avoid potential confusion. + if self.plays_total_ran != self.plays_count: + print("Installation Complete: Note: Play count is only an estimate, some plays may have been skipped or dynamically added") + self._display.display("", screen_only=True) diff --git a/roles/lib_utils/filter_plugins/oo_cert_expiry.py b/roles/lib_utils/filter_plugins/oo_cert_expiry.py new file mode 100644 index 000000000..58b228fee --- /dev/null +++ b/roles/lib_utils/filter_plugins/oo_cert_expiry.py @@ -0,0 +1,66 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +""" +Custom filters for use in openshift-ansible +""" + + +# Disabling too-many-public-methods, since filter methods are necessarily +# public +# pylint: disable=too-many-public-methods +class FilterModule(object): + """ Custom ansible filters """ + + @staticmethod + def oo_cert_expiry_results_to_json(hostvars, play_hosts): + """Takes results (`hostvars`) from the openshift_cert_expiry role +check and serializes them into proper machine-readable JSON +output. This filter parameter **MUST** be the playbook `hostvars` +variable. The `play_hosts` parameter is so we know what to loop over +when we're extrating the values. + +Returns: + +Results are collected into two top-level keys under the `json_results` +dict: + +* `json_results.data` [dict] - Each individual host check result, keys are hostnames +* `json_results.summary` [dict] - Summary of number of `warning` and `expired` +certificates + +Example playbook usage: + + - name: Generate expiration results JSON + run_once: yes + delegate_to: localhost + when: openshift_certificate_expiry_save_json_results|bool + copy: + content: "{{ hostvars|oo_cert_expiry_results_to_json() }}" + dest: "{{ openshift_certificate_expiry_json_results_path }}" + + """ + json_result = { + 'data': {}, + 'summary': {}, + } + + for host in play_hosts: + json_result['data'][host] = hostvars[host]['check_results']['check_results'] + + total_warnings = sum([hostvars[h]['check_results']['summary']['warning'] for h in play_hosts]) + total_expired = sum([hostvars[h]['check_results']['summary']['expired'] for h in play_hosts]) + total_ok = sum([hostvars[h]['check_results']['summary']['ok'] for h in play_hosts]) + total_total = sum([hostvars[h]['check_results']['summary']['total'] for h in play_hosts]) + + json_result['summary']['warning'] = total_warnings + json_result['summary']['expired'] = total_expired + json_result['summary']['ok'] = total_ok + json_result['summary']['total'] = total_total + + return json_result + + def filters(self): + """ returns a mapping of filters to methods """ + return { + "oo_cert_expiry_results_to_json": self.oo_cert_expiry_results_to_json, + } diff --git a/roles/lib_utils/filter_plugins/oo_filters.py b/roles/lib_utils/filter_plugins/oo_filters.py new file mode 100644 index 000000000..9f73510c4 --- /dev/null +++ b/roles/lib_utils/filter_plugins/oo_filters.py @@ -0,0 +1,627 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# pylint: disable=too-many-lines +""" +Custom filters for use in openshift-ansible +""" +import os +import pdb +import random +import re + +from base64 import b64encode +from collections import Mapping +# pylint no-name-in-module and import-error disabled here because pylint +# fails to properly detect the packages when installed in a virtualenv +from distutils.util import strtobool # pylint:disable=no-name-in-module,import-error +from operator import itemgetter + +import yaml + +from ansible import errors +from ansible.parsing.yaml.dumper import AnsibleDumper + +# pylint: disable=import-error,no-name-in-module +from ansible.module_utils.six import string_types, u +# pylint: disable=import-error,no-name-in-module +from ansible.module_utils.six.moves.urllib.parse import urlparse + +HAS_OPENSSL = False +try: + import OpenSSL.crypto + HAS_OPENSSL = True +except ImportError: + pass + + +# pylint: disable=C0103 + +def lib_utils_oo_pdb(arg): + """ This pops you into a pdb instance where arg is the data passed in + from the filter. + Ex: "{{ hostvars | lib_utils_oo_pdb }}" + """ + pdb.set_trace() + return arg + + +def get_attr(data, attribute=None): + """ This looks up dictionary attributes of the form a.b.c and returns + the value. + + If the key isn't present, None is returned. + Ex: data = {'a': {'b': {'c': 5}}} + attribute = "a.b.c" + returns 5 + """ + if not attribute: + raise errors.AnsibleFilterError("|failed expects attribute to be set") + + ptr = data + for attr in attribute.split('.'): + if attr in ptr: + ptr = ptr[attr] + else: + ptr = None + break + + return ptr + + +def oo_flatten(data): + """ This filter plugin will flatten a list of lists + """ + if not isinstance(data, list): + raise errors.AnsibleFilterError("|failed expects to flatten a List") + + return [item for sublist in data for item in sublist] + + +def lib_utils_oo_collect(data_list, attribute=None, filters=None): + """ This takes a list of dict and collects all attributes specified into a + list. If filter is specified then we will include all items that + match _ALL_ of filters. If a dict entry is missing the key in a + filter it will be excluded from the match. + Ex: data_list = [ {'a':1, 'b':5, 'z': 'z'}, # True, return + {'a':2, 'z': 'z'}, # True, return + {'a':3, 'z': 'z'}, # True, return + {'a':4, 'z': 'b'}, # FAILED, obj['z'] != obj['z'] + ] + attribute = 'a' + filters = {'z': 'z'} + returns [1, 2, 3] + + This also deals with lists of lists with dict as elements. + Ex: data_list = [ + [ {'a':1, 'b':5, 'z': 'z'}, # True, return + {'a':2, 'b':6, 'z': 'z'} # True, return + ], + [ {'a':3, 'z': 'z'}, # True, return + {'a':4, 'z': 'b'} # FAILED, obj['z'] != obj['z'] + ], + {'a':5, 'z': 'z'}, # True, return + ] + attribute = 'a' + filters = {'z': 'z'} + returns [1, 2, 3, 5] + """ + if not isinstance(data_list, list): + raise errors.AnsibleFilterError("lib_utils_oo_collect expects to filter on a List") + + if not attribute: + raise errors.AnsibleFilterError("lib_utils_oo_collect expects attribute to be set") + + data = [] + retval = [] + + for item in data_list: + if isinstance(item, list): + retval.extend(lib_utils_oo_collect(item, attribute, filters)) + else: + data.append(item) + + if filters is not None: + if not isinstance(filters, dict): + raise errors.AnsibleFilterError( + "lib_utils_oo_collect expects filter to be a dict") + retval.extend([get_attr(d, attribute) for d in data if ( + all([d.get(key, None) == filters[key] for key in filters]))]) + else: + retval.extend([get_attr(d, attribute) for d in data]) + + retval = [val for val in retval if val is not None] + + return retval + + +def lib_utils_oo_select_keys_from_list(data, keys): + """ This returns a list, which contains the value portions for the keys + Ex: data = { 'a':1, 'b':2, 'c':3 } + keys = ['a', 'c'] + returns [1, 3] + """ + + if not isinstance(data, list): + raise errors.AnsibleFilterError("|lib_utils_oo_select_keys_from_list failed expects to filter on a list") + + if not isinstance(keys, list): + raise errors.AnsibleFilterError("|lib_utils_oo_select_keys_from_list failed expects first param is a list") + + # Gather up the values for the list of keys passed in + retval = [lib_utils_oo_select_keys(item, keys) for item in data] + + return oo_flatten(retval) + + +def lib_utils_oo_select_keys(data, keys): + """ This returns a list, which contains the value portions for the keys + Ex: data = { 'a':1, 'b':2, 'c':3 } + keys = ['a', 'c'] + returns [1, 3] + """ + + if not isinstance(data, Mapping): + raise errors.AnsibleFilterError("|lib_utils_oo_select_keys failed expects to filter on a dict or object") + + if not isinstance(keys, list): + raise errors.AnsibleFilterError("|lib_utils_oo_select_keys failed expects first param is a list") + + # Gather up the values for the list of keys passed in + retval = [data[key] for key in keys if key in data] + + return retval + + +def lib_utils_oo_prepend_strings_in_list(data, prepend): + """ This takes a list of strings and prepends a string to each item in the + list + Ex: data = ['cart', 'tree'] + prepend = 'apple-' + returns ['apple-cart', 'apple-tree'] + """ + if not isinstance(data, list): + raise errors.AnsibleFilterError("|failed expects first param is a list") + if not all(isinstance(x, string_types) for x in data): + raise errors.AnsibleFilterError("|failed expects first param is a list" + " of strings") + retval = [prepend + s for s in data] + return retval + + +def lib_utils_oo_dict_to_list_of_dict(data, key_title='key', value_title='value'): + """Take a dict and arrange them as a list of dicts + + Input data: + {'region': 'infra', 'test_k': 'test_v'} + + Return data: + [{'key': 'region', 'value': 'infra'}, {'key': 'test_k', 'value': 'test_v'}] + + Written for use of the oc_label module + """ + if not isinstance(data, dict): + # pylint: disable=line-too-long + raise errors.AnsibleFilterError("|failed expects first param is a dict. Got %s. Type: %s" % (str(data), str(type(data)))) + + rval = [] + for label in data.items(): + rval.append({key_title: label[0], value_title: label[1]}) + + return rval + + +def oo_ami_selector(data, image_name): + """ This takes a list of amis and an image name and attempts to return + the latest ami. + """ + if not isinstance(data, list): + raise errors.AnsibleFilterError("|failed expects first param is a list") + + if not data: + return None + else: + if image_name is None or not image_name.endswith('_*'): + ami = sorted(data, key=itemgetter('name'), reverse=True)[0] + return ami['ami_id'] + else: + ami_info = [(ami, ami['name'].split('_')[-1]) for ami in data] + ami = sorted(ami_info, key=itemgetter(1), reverse=True)[0][0] + return ami['ami_id'] + + +def lib_utils_oo_split(string, separator=','): + """ This splits the input string into a list. If the input string is + already a list we will return it as is. + """ + if isinstance(string, list): + return string + return string.split(separator) + + +def lib_utils_oo_dict_to_keqv_list(data): + """Take a dict and return a list of k=v pairs + + Input data: + {'a': 1, 'b': 2} + + Return data: + ['a=1', 'b=2'] + """ + return ['='.join(str(e) for e in x) for x in data.items()] + + +def lib_utils_oo_list_to_dict(lst, separator='='): + """ This converts a list of ["k=v"] to a dictionary {k: v}. + """ + kvs = [i.split(separator) for i in lst] + return {k: v for k, v in kvs} + + +def haproxy_backend_masters(hosts, port): + """ This takes an array of dicts and returns an array of dicts + to be used as a backend for the haproxy role + """ + servers = [] + for idx, host_info in enumerate(hosts): + server = dict(name="master%s" % idx) + server_ip = host_info['openshift']['common']['ip'] + server['address'] = "%s:%s" % (server_ip, port) + server['opts'] = 'check' + servers.append(server) + return servers + + +# pylint: disable=too-many-branches +def lib_utils_oo_parse_named_certificates(certificates, named_certs_dir, internal_hostnames): + """ Parses names from list of certificate hashes. + + Ex: certificates = [{ "certfile": "/root/custom1.crt", + "keyfile": "/root/custom1.key", + "cafile": "/root/custom-ca1.crt" }, + { "certfile": "custom2.crt", + "keyfile": "custom2.key", + "cafile": "custom-ca2.crt" }] + + returns [{ "certfile": "/etc/origin/master/named_certificates/custom1.crt", + "keyfile": "/etc/origin/master/named_certificates/custom1.key", + "cafile": "/etc/origin/master/named_certificates/custom-ca1.crt", + "names": [ "public-master-host.com", + "other-master-host.com" ] }, + { "certfile": "/etc/origin/master/named_certificates/custom2.crt", + "keyfile": "/etc/origin/master/named_certificates/custom2.key", + "cafile": "/etc/origin/master/named_certificates/custom-ca-2.crt", + "names": [ "some-hostname.com" ] }] + """ + if not isinstance(named_certs_dir, string_types): + raise errors.AnsibleFilterError("|failed expects named_certs_dir is str or unicode") + + if not isinstance(internal_hostnames, list): + raise errors.AnsibleFilterError("|failed expects internal_hostnames is list") + + if not HAS_OPENSSL: + raise errors.AnsibleFilterError("|missing OpenSSL python bindings") + + for certificate in certificates: + if 'names' in certificate.keys(): + continue + else: + certificate['names'] = [] + + if not os.path.isfile(certificate['certfile']) or not os.path.isfile(certificate['keyfile']): + raise errors.AnsibleFilterError("|certificate and/or key does not exist '%s', '%s'" % + (certificate['certfile'], certificate['keyfile'])) + + try: + st_cert = open(certificate['certfile'], 'rt').read() + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, st_cert) + certificate['names'].append(str(cert.get_subject().commonName.decode())) + for i in range(cert.get_extension_count()): + if cert.get_extension(i).get_short_name() == 'subjectAltName': + for name in str(cert.get_extension(i)).replace('DNS:', '').split(', '): + certificate['names'].append(name) + except Exception: + raise errors.AnsibleFilterError(("|failed to parse certificate '%s', " % certificate['certfile'] + + "please specify certificate names in host inventory")) + + certificate['names'] = list(set(certificate['names'])) + if 'cafile' not in certificate: + certificate['names'] = [name for name in certificate['names'] if name not in internal_hostnames] + if not certificate['names']: + raise errors.AnsibleFilterError(("|failed to parse certificate '%s' or " % certificate['certfile'] + + "detected a collision with internal hostname, please specify " + + "certificate names in host inventory")) + + for certificate in certificates: + # Update paths for configuration + certificate['certfile'] = os.path.join(named_certs_dir, os.path.basename(certificate['certfile'])) + certificate['keyfile'] = os.path.join(named_certs_dir, os.path.basename(certificate['keyfile'])) + if 'cafile' in certificate: + certificate['cafile'] = os.path.join(named_certs_dir, os.path.basename(certificate['cafile'])) + return certificates + + +def lib_utils_oo_generate_secret(num_bytes): + """ generate a session secret """ + + if not isinstance(num_bytes, int): + raise errors.AnsibleFilterError("|failed expects num_bytes is int") + + return b64encode(os.urandom(num_bytes)).decode('utf-8') + + +def lib_utils_to_padded_yaml(data, level=0, indent=2, **kw): + """ returns a yaml snippet padded to match the indent level you specify """ + if data in [None, ""]: + return "" + + try: + transformed = u(yaml.dump(data, indent=indent, allow_unicode=True, + default_flow_style=False, + Dumper=AnsibleDumper, **kw)) + padded = "\n".join([" " * level * indent + line for line in transformed.splitlines()]) + return "\n{0}".format(padded) + except Exception as my_e: + raise errors.AnsibleFilterError('Failed to convert: %s' % my_e) + + +def lib_utils_oo_pods_match_component(pods, deployment_type, component): + """ Filters a list of Pods and returns the ones matching the deployment_type and component + """ + if not isinstance(pods, list): + raise errors.AnsibleFilterError("failed expects to filter on a list") + if not isinstance(deployment_type, string_types): + raise errors.AnsibleFilterError("failed expects deployment_type to be a string") + if not isinstance(component, string_types): + raise errors.AnsibleFilterError("failed expects component to be a string") + + image_prefix = 'openshift/origin-' + if deployment_type == 'openshift-enterprise': + image_prefix = 'openshift3/ose-' + + matching_pods = [] + image_regex = image_prefix + component + r'.*' + for pod in pods: + for container in pod['spec']['containers']: + if re.search(image_regex, container['image']): + matching_pods.append(pod) + break # stop here, don't add a pod more than once + + return matching_pods + + +def lib_utils_oo_image_tag_to_rpm_version(version, include_dash=False): + """ Convert an image tag string to an RPM version if necessary + Empty strings and strings that are already in rpm version format + are ignored. Also remove non semantic version components. + + Ex. v3.2.0.10 -> -3.2.0.10 + v1.2.0-rc1 -> -1.2.0 + """ + if not isinstance(version, string_types): + raise errors.AnsibleFilterError("|failed expects a string or unicode") + if version.startswith("v"): + version = version[1:] + # Strip release from requested version, we no longer support this. + version = version.split('-')[0] + + if include_dash and version and not version.startswith("-"): + version = "-" + version + + return version + + +def lib_utils_oo_hostname_from_url(url): + """ Returns the hostname contained in a URL + + Ex: https://ose3-master.example.com/v1/api -> ose3-master.example.com + """ + if not isinstance(url, string_types): + raise errors.AnsibleFilterError("|failed expects a string or unicode") + parse_result = urlparse(url) + if parse_result.netloc != '': + return parse_result.netloc + else: + # netloc wasn't parsed, assume url was missing scheme and path + return parse_result.path + + +# pylint: disable=invalid-name, unused-argument +def lib_utils_oo_loadbalancer_frontends( + api_port, servers_hostvars, use_nuage=False, nuage_rest_port=None): + """TODO: Document me.""" + loadbalancer_frontends = [{'name': 'atomic-openshift-api', + 'mode': 'tcp', + 'options': ['tcplog'], + 'binds': ["*:{0}".format(api_port)], + 'default_backend': 'atomic-openshift-api'}] + if bool(strtobool(str(use_nuage))) and nuage_rest_port is not None: + loadbalancer_frontends.append({'name': 'nuage-monitor', + 'mode': 'tcp', + 'options': ['tcplog'], + 'binds': ["*:{0}".format(nuage_rest_port)], + 'default_backend': 'nuage-monitor'}) + return loadbalancer_frontends + + +# pylint: disable=invalid-name +def lib_utils_oo_loadbalancer_backends( + api_port, servers_hostvars, use_nuage=False, nuage_rest_port=None): + """TODO: Document me.""" + loadbalancer_backends = [{'name': 'atomic-openshift-api', + 'mode': 'tcp', + 'option': 'tcplog', + 'balance': 'source', + 'servers': haproxy_backend_masters(servers_hostvars, api_port)}] + if bool(strtobool(str(use_nuage))) and nuage_rest_port is not None: + # pylint: disable=line-too-long + loadbalancer_backends.append({'name': 'nuage-monitor', + 'mode': 'tcp', + 'option': 'tcplog', + 'balance': 'source', + 'servers': haproxy_backend_masters(servers_hostvars, nuage_rest_port)}) + return loadbalancer_backends + + +def lib_utils_oo_chomp_commit_offset(version): + """Chomp any "+git.foo" commit offset string from the given `version` + and return the modified version string. + +Ex: +- chomp_commit_offset(None) => None +- chomp_commit_offset(1337) => "1337" +- chomp_commit_offset("v3.4.0.15+git.derp") => "v3.4.0.15" +- chomp_commit_offset("v3.4.0.15") => "v3.4.0.15" +- chomp_commit_offset("v1.3.0+52492b4") => "v1.3.0" + """ + if version is None: + return version + else: + # Stringify, just in case it's a Number type. Split by '+' and + # return the first split. No concerns about strings without a + # '+', .split() returns an array of the original string. + return str(version).split('+')[0] + + +def lib_utils_oo_random_word(length, source='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'): + """Generates a random string of given length from a set of alphanumeric characters. + The default source uses [a-z][A-Z][0-9] + Ex: + - lib_utils_oo_random_word(3) => aB9 + - lib_utils_oo_random_word(4, source='012') => 0123 + """ + return ''.join(random.choice(source) for i in range(length)) + + +def lib_utils_oo_contains_rule(source, apiGroups, resources, verbs): + '''Return true if the specified rule is contained within the provided source''' + + rules = source['rules'] + + if rules: + for rule in rules: + if set(rule['apiGroups']) == set(apiGroups): + if set(rule['resources']) == set(resources): + if set(rule['verbs']) == set(verbs): + return True + + return False + + +def lib_utils_oo_selector_to_string_list(user_dict): + """Convert a dict of selectors to a key=value list of strings + +Given input of {'region': 'infra', 'zone': 'primary'} returns a list +of items as ['region=infra', 'zone=primary'] + """ + selectors = [] + for key in user_dict: + selectors.append("{}={}".format(key, user_dict[key])) + return selectors + + +def lib_utils_oo_filter_sa_secrets(sa_secrets, secret_hint='-token-'): + """Parse the Service Account Secrets list, `sa_secrets`, (as from +oc_serviceaccount_secret:state=list) and return the name of the secret +containing the `secret_hint` string. For example, by default this will +return the name of the secret holding the SA bearer token. + +Only provide the 'results' object to this filter. This filter expects +to receive a list like this: + + [ + { + "name": "management-admin-dockercfg-p31s2" + }, + { + "name": "management-admin-token-bnqsh" + } + ] + + +Returns: + +* `secret_name` [string] - The name of the secret matching the + `secret_hint` parameter. By default this is the secret holding the + SA's bearer token. + +Example playbook usage: + +Register a return value from oc_serviceaccount_secret with and pass +that result to this filter plugin. + + - name: Get all SA Secrets + oc_serviceaccount_secret: + state: list + service_account: management-admin + namespace: management-infra + register: sa + + - name: Save the SA bearer token secret name + set_fact: + management_token: "{{ sa.results | lib_utils_oo_filter_sa_secrets }}" + + - name: Get the SA bearer token value + oc_secret: + state: list + name: "{{ management_token }}" + namespace: management-infra + decode: true + register: sa_secret + + - name: Print the bearer token value + debug: + var: sa_secret.results.decoded.token + + """ + secret_name = None + + for secret in sa_secrets: + # each secret is a hash + if secret['name'].find(secret_hint) == -1: + continue + else: + secret_name = secret['name'] + break + + return secret_name + + +def map_from_pairs(source, delim="="): + ''' Returns a dict given the source and delim delimited ''' + if source == '': + return dict() + + return dict(item.split(delim) for item in source.split(",")) + + +class FilterModule(object): + """ Custom ansible filter mapping """ + + # pylint: disable=no-self-use, too-few-public-methods + def filters(self): + """ returns a mapping of filters to methods """ + return { + "lib_utils_oo_select_keys": lib_utils_oo_select_keys, + "lib_utils_oo_select_keys_from_list": lib_utils_oo_select_keys_from_list, + "lib_utils_oo_chomp_commit_offset": lib_utils_oo_chomp_commit_offset, + "lib_utils_oo_collect": lib_utils_oo_collect, + "lib_utils_oo_pdb": lib_utils_oo_pdb, + "lib_utils_oo_prepend_strings_in_list": lib_utils_oo_prepend_strings_in_list, + "lib_utils_oo_dict_to_list_of_dict": lib_utils_oo_dict_to_list_of_dict, + "lib_utils_oo_split": lib_utils_oo_split, + "lib_utils_oo_dict_to_keqv_list": lib_utils_oo_dict_to_keqv_list, + "lib_utils_oo_list_to_dict": lib_utils_oo_list_to_dict, + "lib_utils_oo_parse_named_certificates": lib_utils_oo_parse_named_certificates, + "lib_utils_oo_generate_secret": lib_utils_oo_generate_secret, + "lib_utils_oo_pods_match_component": lib_utils_oo_pods_match_component, + "lib_utils_oo_image_tag_to_rpm_version": lib_utils_oo_image_tag_to_rpm_version, + "lib_utils_oo_hostname_from_url": lib_utils_oo_hostname_from_url, + "lib_utils_oo_loadbalancer_frontends": lib_utils_oo_loadbalancer_frontends, + "lib_utils_oo_loadbalancer_backends": lib_utils_oo_loadbalancer_backends, + "lib_utils_to_padded_yaml": lib_utils_to_padded_yaml, + "lib_utils_oo_random_word": lib_utils_oo_random_word, + "lib_utils_oo_contains_rule": lib_utils_oo_contains_rule, + "lib_utils_oo_selector_to_string_list": lib_utils_oo_selector_to_string_list, + "lib_utils_oo_filter_sa_secrets": lib_utils_oo_filter_sa_secrets, + "map_from_pairs": map_from_pairs + } diff --git a/roles/lib_utils/filter_plugins/openshift_aws_filters.py b/roles/lib_utils/filter_plugins/openshift_aws_filters.py new file mode 100644 index 000000000..dfcb11da3 --- /dev/null +++ b/roles/lib_utils/filter_plugins/openshift_aws_filters.py @@ -0,0 +1,74 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +''' +Custom filters for use in openshift_aws +''' + +from ansible import errors + + +class FilterModule(object): + ''' Custom ansible filters for use by openshift_aws role''' + + @staticmethod + def scale_groups_serial(scale_group_info, upgrade=False): + ''' This function will determine what the deployment serial should be and return it + + Search through the tags and find the deployment_serial tag. Once found, + determine if an increment is needed during an upgrade. + if upgrade is true then increment the serial and return it + else return the serial + ''' + if scale_group_info == []: + return 1 + + scale_group_info = scale_group_info[0] + + if not isinstance(scale_group_info, dict): + raise errors.AnsibleFilterError("|filter plugin failed: Expected scale_group_info to be a dict") + + serial = None + + for tag in scale_group_info['tags']: + if tag['key'] == 'deployment_serial': + serial = int(tag['value']) + if upgrade: + serial += 1 + break + else: + raise errors.AnsibleFilterError("|filter plugin failed: deployment_serial tag was not found") + + return serial + + @staticmethod + def scale_groups_match_capacity(scale_group_info): + ''' This function will verify that the scale group instance count matches + the scale group desired capacity + + ''' + for scale_group in scale_group_info: + if scale_group['desired_capacity'] != len(scale_group['instances']): + return False + + return True + + @staticmethod + def build_instance_tags(clusterid): + ''' This function will return a dictionary of the instance tags. + + The main desire to have this inside of a filter_plugin is that we + need to build the following key. + + {"kubernetes.io/cluster/{{ openshift_aws_clusterid }}": "{{ openshift_aws_clusterid}}"} + + ''' + tags = {'clusterid': clusterid, + 'kubernetes.io/cluster/{}'.format(clusterid): clusterid} + + return tags + + def filters(self): + ''' returns a mapping of filters to methods ''' + return {'build_instance_tags': self.build_instance_tags, + 'scale_groups_match_capacity': self.scale_groups_match_capacity, + 'scale_groups_serial': self.scale_groups_serial} diff --git a/roles/lib_utils/filter_plugins/openshift_hosted_filters.py b/roles/lib_utils/filter_plugins/openshift_hosted_filters.py new file mode 100644 index 000000000..003ce5f9e --- /dev/null +++ b/roles/lib_utils/filter_plugins/openshift_hosted_filters.py @@ -0,0 +1,42 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +''' +Custom filters for use in openshift_hosted +''' + + +class FilterModule(object): + ''' Custom ansible filters for use by openshift_hosted role''' + + @staticmethod + def get_router_replicas(replicas=None, router_nodes=None): + ''' This function will return the number of replicas + based on the results from the defined + openshift_hosted_router_replicas OR + the query from oc_obj on openshift nodes with a selector OR + default to 1 + + ''' + # We always use what they've specified if they've specified a value + if replicas is not None: + return replicas + + replicas = 1 + + # Ignore boolean expression limit of 5. + # pylint: disable=too-many-boolean-expressions + if (isinstance(router_nodes, dict) and + 'results' in router_nodes and + 'results' in router_nodes['results'] and + isinstance(router_nodes['results']['results'], list) and + len(router_nodes['results']['results']) > 0 and + 'items' in router_nodes['results']['results'][0]): + + if len(router_nodes['results']['results'][0]['items']) > 0: + replicas = len(router_nodes['results']['results'][0]['items']) + + return replicas + + def filters(self): + ''' returns a mapping of filters to methods ''' + return {'get_router_replicas': self.get_router_replicas} diff --git a/roles/lib_utils/filter_plugins/openshift_master.py b/roles/lib_utils/filter_plugins/openshift_master.py new file mode 100644 index 000000000..e67b19c28 --- /dev/null +++ b/roles/lib_utils/filter_plugins/openshift_master.py @@ -0,0 +1,528 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +''' +Custom filters for use in openshift-master +''' +import copy +import sys + +from ansible import errors +from ansible.parsing.yaml.dumper import AnsibleDumper +from ansible.plugins.filter.core import to_bool as ansible_bool + +from ansible.module_utils.six import string_types, u + +import yaml + + +class IdentityProviderBase(object): + """ IdentityProviderBase + + Attributes: + name (str): Identity provider Name + login (bool): Is this identity provider a login provider? + challenge (bool): Is this identity provider a challenge provider? + provider (dict): Provider specific config + _idp (dict): internal copy of the IDP dict passed in + _required (list): List of lists of strings for required attributes + _optional (list): List of lists of strings for optional attributes + _allow_additional (bool): Does this provider support attributes + not in _required and _optional + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + # disabling this check since the number of instance attributes are + # necessary for this class + # pylint: disable=too-many-instance-attributes + def __init__(self, api_version, idp): + if api_version not in ['v1']: + raise errors.AnsibleFilterError("|failed api version {0} unknown".format(api_version)) + + self._idp = copy.deepcopy(idp) + + if 'name' not in self._idp: + raise errors.AnsibleFilterError("|failed identity provider missing a name") + + if 'kind' not in self._idp: + raise errors.AnsibleFilterError("|failed identity provider missing a kind") + + self.name = self._idp.pop('name') + self.login = ansible_bool(self._idp.pop('login', False)) + self.challenge = ansible_bool(self._idp.pop('challenge', False)) + self.provider = dict(apiVersion=api_version, kind=self._idp.pop('kind')) + + mm_keys = ('mappingMethod', 'mapping_method') + mapping_method = None + for key in mm_keys: + if key in self._idp: + mapping_method = self._idp.pop(key) + if mapping_method is None: + mapping_method = self.get_default('mappingMethod') + self.mapping_method = mapping_method + + valid_mapping_methods = ['add', 'claim', 'generate', 'lookup'] + if self.mapping_method not in valid_mapping_methods: + raise errors.AnsibleFilterError("|failed unknown mapping method " + "for provider {0}".format(self.__class__.__name__)) + self._required = [] + self._optional = [] + self._allow_additional = True + + @staticmethod + def validate_idp_list(idp_list): + ''' validates a list of idps ''' + names = [x.name for x in idp_list] + if len(set(names)) != len(names): + raise errors.AnsibleFilterError("|failed more than one provider configured with the same name") + + for idp in idp_list: + idp.validate() + + def validate(self): + ''' validate an instance of this idp class ''' + pass + + @staticmethod + def get_default(key): + ''' get a default value for a given key ''' + if key == 'mappingMethod': + return 'claim' + else: + return None + + def set_provider_item(self, items, required=False): + ''' set a provider item based on the list of item names provided. ''' + for item in items: + provider_key = items[0] + if item in self._idp: + self.provider[provider_key] = self._idp.pop(item) + break + else: + default = self.get_default(provider_key) + if default is not None: + self.provider[provider_key] = default + elif required: + raise errors.AnsibleFilterError("|failed provider {0} missing " + "required key {1}".format(self.__class__.__name__, provider_key)) + + def set_provider_items(self): + ''' set the provider items for this idp ''' + for items in self._required: + self.set_provider_item(items, True) + for items in self._optional: + self.set_provider_item(items) + if self._allow_additional: + for key in self._idp.keys(): + self.set_provider_item([key]) + else: + if len(self._idp) > 0: + raise errors.AnsibleFilterError("|failed provider {0} " + "contains unknown keys " + "{1}".format(self.__class__.__name__, ', '.join(self._idp.keys()))) + + def to_dict(self): + ''' translate this idp to a dictionary ''' + return dict(name=self.name, challenge=self.challenge, + login=self.login, mappingMethod=self.mapping_method, + provider=self.provider) + + +class LDAPPasswordIdentityProvider(IdentityProviderBase): + """ LDAPPasswordIdentityProvider + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + super(LDAPPasswordIdentityProvider, self).__init__(api_version, idp) + self._allow_additional = False + self._required += [['attributes'], ['url'], ['insecure']] + self._optional += [['ca'], + ['bindDN', 'bind_dn'], + ['bindPassword', 'bind_password']] + + self._idp['insecure'] = ansible_bool(self._idp.pop('insecure', False)) + + if 'attributes' in self._idp and 'preferred_username' in self._idp['attributes']: + pref_user = self._idp['attributes'].pop('preferred_username') + self._idp['attributes']['preferredUsername'] = pref_user + + def validate(self): + ''' validate this idp instance ''' + if not isinstance(self.provider['attributes'], dict): + raise errors.AnsibleFilterError("|failed attributes for provider " + "{0} must be a dictionary".format(self.__class__.__name__)) + + attrs = ['id', 'email', 'name', 'preferredUsername'] + for attr in attrs: + if attr in self.provider['attributes'] and not isinstance(self.provider['attributes'][attr], list): + raise errors.AnsibleFilterError("|failed {0} attribute for " + "provider {1} must be a list".format(attr, self.__class__.__name__)) + + unknown_attrs = set(self.provider['attributes'].keys()) - set(attrs) + if len(unknown_attrs) > 0: + raise errors.AnsibleFilterError("|failed provider {0} has unknown " + "attributes: {1}".format(self.__class__.__name__, ', '.join(unknown_attrs))) + + +class KeystonePasswordIdentityProvider(IdentityProviderBase): + """ KeystoneIdentityProvider + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + super(KeystonePasswordIdentityProvider, self).__init__(api_version, idp) + self._allow_additional = False + self._required += [['url'], ['domainName', 'domain_name']] + self._optional += [['ca'], ['certFile', 'cert_file'], ['keyFile', 'key_file']] + + +class RequestHeaderIdentityProvider(IdentityProviderBase): + """ RequestHeaderIdentityProvider + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + super(RequestHeaderIdentityProvider, self).__init__(api_version, idp) + self._allow_additional = False + self._required += [['headers']] + self._optional += [['challengeURL', 'challenge_url'], + ['loginURL', 'login_url'], + ['clientCA', 'client_ca'], + ['clientCommonNames', 'client_common_names'], + ['emailHeaders', 'email_headers'], + ['nameHeaders', 'name_headers'], + ['preferredUsernameHeaders', 'preferred_username_headers']] + + def validate(self): + ''' validate this idp instance ''' + if not isinstance(self.provider['headers'], list): + raise errors.AnsibleFilterError("|failed headers for provider {0} " + "must be a list".format(self.__class__.__name__)) + + +class AllowAllPasswordIdentityProvider(IdentityProviderBase): + """ AllowAllPasswordIdentityProvider + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + super(AllowAllPasswordIdentityProvider, self).__init__(api_version, idp) + self._allow_additional = False + + +class DenyAllPasswordIdentityProvider(IdentityProviderBase): + """ DenyAllPasswordIdentityProvider + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + super(DenyAllPasswordIdentityProvider, self).__init__(api_version, idp) + self._allow_additional = False + + +class HTPasswdPasswordIdentityProvider(IdentityProviderBase): + """ HTPasswdPasswordIdentity + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + super(HTPasswdPasswordIdentityProvider, self).__init__(api_version, idp) + self._allow_additional = False + self._required += [['file', 'filename', 'fileName', 'file_name']] + + @staticmethod + def get_default(key): + if key == 'file': + return '/etc/origin/htpasswd' + else: + return IdentityProviderBase.get_default(key) + + +class BasicAuthPasswordIdentityProvider(IdentityProviderBase): + """ BasicAuthPasswordIdentityProvider + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + super(BasicAuthPasswordIdentityProvider, self).__init__(api_version, idp) + self._allow_additional = False + self._required += [['url']] + self._optional += [['ca'], ['certFile', 'cert_file'], ['keyFile', 'key_file']] + + +class IdentityProviderOauthBase(IdentityProviderBase): + """ IdentityProviderOauthBase + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + super(IdentityProviderOauthBase, self).__init__(api_version, idp) + self._allow_additional = False + self._required += [['clientID', 'client_id'], ['clientSecret', 'client_secret']] + + def validate(self): + ''' validate an instance of this idp class ''' + pass + + +class OpenIDIdentityProvider(IdentityProviderOauthBase): + """ OpenIDIdentityProvider + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + IdentityProviderOauthBase.__init__(self, api_version, idp) + self._required += [['claims'], ['urls']] + self._optional += [['ca'], + ['extraScopes'], + ['extraAuthorizeParameters']] + if 'claims' in self._idp and 'preferred_username' in self._idp['claims']: + pref_user = self._idp['claims'].pop('preferred_username') + self._idp['claims']['preferredUsername'] = pref_user + if 'urls' in self._idp and 'user_info' in self._idp['urls']: + user_info = self._idp['urls'].pop('user_info') + self._idp['urls']['userInfo'] = user_info + if 'extra_scopes' in self._idp: + self._idp['extraScopes'] = self._idp.pop('extra_scopes') + if 'extra_authorize_parameters' in self._idp: + self._idp['extraAuthorizeParameters'] = self._idp.pop('extra_authorize_parameters') + + def validate(self): + ''' validate this idp instance ''' + if not isinstance(self.provider['claims'], dict): + raise errors.AnsibleFilterError("|failed claims for provider {0} " + "must be a dictionary".format(self.__class__.__name__)) + + for var, var_type in (('extraScopes', list), ('extraAuthorizeParameters', dict)): + if var in self.provider and not isinstance(self.provider[var], var_type): + raise errors.AnsibleFilterError("|failed {1} for provider " + "{0} must be a {2}".format(self.__class__.__name__, + var, + var_type.__class__.__name__)) + + required_claims = ['id'] + optional_claims = ['email', 'name', 'preferredUsername'] + all_claims = required_claims + optional_claims + + for claim in required_claims: + if claim in required_claims and claim not in self.provider['claims']: + raise errors.AnsibleFilterError("|failed {0} claim missing " + "for provider {1}".format(claim, self.__class__.__name__)) + + for claim in all_claims: + if claim in self.provider['claims'] and not isinstance(self.provider['claims'][claim], list): + raise errors.AnsibleFilterError("|failed {0} claims for " + "provider {1} must be a list".format(claim, self.__class__.__name__)) + + unknown_claims = set(self.provider['claims'].keys()) - set(all_claims) + if len(unknown_claims) > 0: + raise errors.AnsibleFilterError("|failed provider {0} has unknown " + "claims: {1}".format(self.__class__.__name__, ', '.join(unknown_claims))) + + if not isinstance(self.provider['urls'], dict): + raise errors.AnsibleFilterError("|failed urls for provider {0} " + "must be a dictionary".format(self.__class__.__name__)) + + required_urls = ['authorize', 'token'] + optional_urls = ['userInfo'] + all_urls = required_urls + optional_urls + + for url in required_urls: + if url not in self.provider['urls']: + raise errors.AnsibleFilterError("|failed {0} url missing for " + "provider {1}".format(url, self.__class__.__name__)) + + unknown_urls = set(self.provider['urls'].keys()) - set(all_urls) + if len(unknown_urls) > 0: + raise errors.AnsibleFilterError("|failed provider {0} has unknown " + "urls: {1}".format(self.__class__.__name__, ', '.join(unknown_urls))) + + +class GoogleIdentityProvider(IdentityProviderOauthBase): + """ GoogleIdentityProvider + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + IdentityProviderOauthBase.__init__(self, api_version, idp) + self._optional += [['hostedDomain', 'hosted_domain']] + + def validate(self): + ''' validate this idp instance ''' + if self.challenge: + raise errors.AnsibleFilterError("|failed provider {0} does not " + "allow challenge authentication".format(self.__class__.__name__)) + + +class GitHubIdentityProvider(IdentityProviderOauthBase): + """ GitHubIdentityProvider + + Attributes: + + Args: + api_version(str): OpenShift config version + idp (dict): idp config dict + + Raises: + AnsibleFilterError: + """ + def __init__(self, api_version, idp): + IdentityProviderOauthBase.__init__(self, api_version, idp) + self._optional += [['organizations'], + ['teams']] + + def validate(self): + ''' validate this idp instance ''' + if self.challenge: + raise errors.AnsibleFilterError("|failed provider {0} does not " + "allow challenge authentication".format(self.__class__.__name__)) + + +class FilterModule(object): + ''' Custom ansible filters for use by the openshift_master role''' + + @staticmethod + def translate_idps(idps, api_version): + ''' Translates a list of dictionaries into a valid identityProviders config ''' + idp_list = [] + + if not isinstance(idps, list): + raise errors.AnsibleFilterError("|failed expects to filter on a list of identity providers") + for idp in idps: + if not isinstance(idp, dict): + raise errors.AnsibleFilterError("|failed identity providers must be a list of dictionaries") + + cur_module = sys.modules[__name__] + idp_class = getattr(cur_module, idp['kind'], None) + idp_inst = idp_class(api_version, idp) if idp_class is not None else IdentityProviderBase(api_version, idp) + idp_inst.set_provider_items() + idp_list.append(idp_inst) + + IdentityProviderBase.validate_idp_list(idp_list) + return u(yaml.dump([idp.to_dict() for idp in idp_list], + allow_unicode=True, + default_flow_style=False, + width=float("inf"), + Dumper=AnsibleDumper)) + + @staticmethod + def certificates_to_synchronize(hostvars, include_keys=True, include_ca=True): + ''' Return certificates to synchronize based on facts. ''' + if not issubclass(type(hostvars), dict): + raise errors.AnsibleFilterError("|failed expects hostvars is a dict") + certs = ['admin.crt', + 'admin.key', + 'admin.kubeconfig', + 'master.kubelet-client.crt', + 'master.kubelet-client.key', + 'master.proxy-client.crt', + 'master.proxy-client.key', + 'service-signer.crt', + 'service-signer.key'] + if bool(include_ca): + certs += ['ca.crt', 'ca.key', 'ca-bundle.crt', 'client-ca-bundle.crt'] + if bool(include_keys): + certs += ['serviceaccounts.private.key', + 'serviceaccounts.public.key'] + return certs + + @staticmethod + def oo_htpasswd_users_from_file(file_contents): + ''' return a dictionary of htpasswd users from htpasswd file contents ''' + htpasswd_entries = {} + if not isinstance(file_contents, string_types): + raise errors.AnsibleFilterError("failed, expects to filter on a string") + for line in file_contents.splitlines(): + user = None + passwd = None + if len(line) == 0: + continue + if ':' in line: + user, passwd = line.split(':', 1) + + if user is None or len(user) == 0 or passwd is None or len(passwd) == 0: + error_msg = "failed, expects each line to be a colon separated string representing the user and passwd" + raise errors.AnsibleFilterError(error_msg) + htpasswd_entries[user] = passwd + return htpasswd_entries + + def filters(self): + ''' returns a mapping of filters to methods ''' + return {"translate_idps": self.translate_idps, + "certificates_to_synchronize": self.certificates_to_synchronize, + "oo_htpasswd_users_from_file": self.oo_htpasswd_users_from_file} diff --git a/roles/lib_utils/library/delegated_serial_command.py b/roles/lib_utils/library/delegated_serial_command.py new file mode 100755 index 000000000..0cab1ca88 --- /dev/null +++ b/roles/lib_utils/library/delegated_serial_command.py @@ -0,0 +1,274 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>, and others +# (c) 2016, Andrew Butcher <abutcher@redhat.com> +# +# This module is derrived from the Ansible command module. +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + + +# pylint: disable=unused-wildcard-import,wildcard-import,unused-import,redefined-builtin + +''' delegated_serial_command ''' + +import datetime +import errno +import glob +import shlex +import os +import fcntl +import time + +DOCUMENTATION = ''' +--- +module: delegated_serial_command +short_description: Executes a command on a remote node +version_added: historical +description: + - The M(command) module takes the command name followed by a list + of space-delimited arguments. + - The given command will be executed on all selected nodes. It + will not be processed through the shell, so variables like + C($HOME) and operations like C("<"), C(">"), C("|"), and C("&") + will not work (use the M(shell) module if you need these + features). + - Creates and maintains a lockfile such that this module will + wait for other invocations to proceed. +options: + command: + description: + - the command to run + required: true + default: null + creates: + description: + - a filename or (since 2.0) glob pattern, when it already + exists, this step will B(not) be run. + required: no + default: null + removes: + description: + - a filename or (since 2.0) glob pattern, when it does not + exist, this step will B(not) be run. + version_added: "0.8" + required: no + default: null + chdir: + description: + - cd into this directory before running the command + version_added: "0.6" + required: false + default: null + executable: + description: + - change the shell used to execute the command. Should be an + absolute path to the executable. + required: false + default: null + version_added: "0.9" + warn: + version_added: "1.8" + default: yes + description: + - if command warnings are on in ansible.cfg, do not warn about + this particular line if set to no/false. + required: false + lockfile: + default: yes + description: + - the lockfile that will be created + timeout: + default: yes + description: + - time in milliseconds to wait to obtain the lock +notes: + - If you want to run a command through the shell (say you are using C(<), + C(>), C(|), etc), you actually want the M(shell) module instead. The + M(command) module is much more secure as it's not affected by the user's + environment. + - " C(creates), C(removes), and C(chdir) can be specified after + the command. For instance, if you only want to run a command if + a certain file does not exist, use this." +author: + - Ansible Core Team + - Michael DeHaan + - Andrew Butcher +''' + +EXAMPLES = ''' +# Example from Ansible Playbooks. +- delegated_serial_command: + command: /sbin/shutdown -t now + +# Run the command if the specified file does not exist. +- delegated_serial_command: + command: /usr/bin/make_database.sh arg1 arg2 + creates: /path/to/database +''' + +# Dict of options and their defaults +OPTIONS = {'chdir': None, + 'creates': None, + 'command': None, + 'executable': None, + 'NO_LOG': None, + 'removes': None, + 'warn': True, + 'lockfile': None, + 'timeout': None} + + +def check_command(commandline): + ''' Check provided command ''' + arguments = {'chown': 'owner', 'chmod': 'mode', 'chgrp': 'group', + 'ln': 'state=link', 'mkdir': 'state=directory', + 'rmdir': 'state=absent', 'rm': 'state=absent', 'touch': 'state=touch'} + commands = {'git': 'git', 'hg': 'hg', 'curl': 'get_url or uri', 'wget': 'get_url or uri', + 'svn': 'subversion', 'service': 'service', + 'mount': 'mount', 'rpm': 'yum, dnf or zypper', 'yum': 'yum', 'apt-get': 'apt', + 'tar': 'unarchive', 'unzip': 'unarchive', 'sed': 'template or lineinfile', + 'rsync': 'synchronize', 'dnf': 'dnf', 'zypper': 'zypper'} + become = ['sudo', 'su', 'pbrun', 'pfexec', 'runas'] + warnings = list() + command = os.path.basename(commandline.split()[0]) + # pylint: disable=line-too-long + if command in arguments: + warnings.append("Consider using file module with {0} rather than running {1}".format(arguments[command], command)) + if command in commands: + warnings.append("Consider using {0} module rather than running {1}".format(commands[command], command)) + if command in become: + warnings.append( + "Consider using 'become', 'become_method', and 'become_user' rather than running {0}".format(command,)) + return warnings + + +# pylint: disable=too-many-statements,too-many-branches,too-many-locals +def main(): + ''' Main module function ''' + module = AnsibleModule( # noqa: F405 + argument_spec=dict( + _uses_shell=dict(type='bool', default=False), + command=dict(required=True), + chdir=dict(), + executable=dict(), + creates=dict(), + removes=dict(), + warn=dict(type='bool', default=True), + lockfile=dict(default='/tmp/delegated_serial_command.lock'), + timeout=dict(type='int', default=30) + ) + ) + + shell = module.params['_uses_shell'] + chdir = module.params['chdir'] + executable = module.params['executable'] + command = module.params['command'] + creates = module.params['creates'] + removes = module.params['removes'] + warn = module.params['warn'] + lockfile = module.params['lockfile'] + timeout = module.params['timeout'] + + if command.strip() == '': + module.fail_json(rc=256, msg="no command given") + + iterated = 0 + lockfd = open(lockfile, 'w+') + while iterated < timeout: + try: + fcntl.flock(lockfd, fcntl.LOCK_EX | fcntl.LOCK_NB) + break + # pylint: disable=invalid-name + except IOError as e: + if e.errno != errno.EAGAIN: + module.fail_json(msg="I/O Error {0}: {1}".format(e.errno, e.strerror)) + else: + iterated += 1 + time.sleep(0.1) + + if chdir: + chdir = os.path.abspath(os.path.expanduser(chdir)) + os.chdir(chdir) + + if creates: + # do not run the command if the line contains creates=filename + # and the filename already exists. This allows idempotence + # of command executions. + path = os.path.expanduser(creates) + if glob.glob(path): + module.exit_json( + cmd=command, + stdout="skipped, since %s exists" % path, + changed=False, + stderr=False, + rc=0 + ) + + if removes: + # do not run the command if the line contains removes=filename + # and the filename does not exist. This allows idempotence + # of command executions. + path = os.path.expanduser(removes) + if not glob.glob(path): + module.exit_json( + cmd=command, + stdout="skipped, since %s does not exist" % path, + changed=False, + stderr=False, + rc=0 + ) + + warnings = list() + if warn: + warnings = check_command(command) + + if not shell: + command = shlex.split(command) + startd = datetime.datetime.now() + + # pylint: disable=invalid-name + rc, out, err = module.run_command(command, executable=executable, use_unsafe_shell=shell) + + fcntl.flock(lockfd, fcntl.LOCK_UN) + lockfd.close() + + endd = datetime.datetime.now() + delta = endd - startd + + if out is None: + out = '' + if err is None: + err = '' + + module.exit_json( + cmd=command, + stdout=out.rstrip("\r\n"), + stderr=err.rstrip("\r\n"), + rc=rc, + start=str(startd), + end=str(endd), + delta=str(delta), + changed=True, + warnings=warnings, + iterated=iterated + ) + + +# import module snippets +# pylint: disable=wrong-import-position +from ansible.module_utils.basic import * # noqa: F402,F403 + +main() diff --git a/roles/lib_utils/library/kubeclient_ca.py b/roles/lib_utils/library/kubeclient_ca.py new file mode 100644 index 000000000..a89a5574f --- /dev/null +++ b/roles/lib_utils/library/kubeclient_ca.py @@ -0,0 +1,88 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +''' kubeclient_ca ansible module ''' + +import base64 +import yaml +from ansible.module_utils.basic import AnsibleModule + + +DOCUMENTATION = ''' +--- +module: kubeclient_ca +short_description: Modify kubeclient certificate-authority-data +author: Andrew Butcher +requirements: [ ] +''' +EXAMPLES = ''' +- kubeclient_ca: + client_path: /etc/origin/master/admin.kubeconfig + ca_path: /etc/origin/master/ca-bundle.crt + +- slurp: + src: /etc/origin/master/ca-bundle.crt + register: ca_data +- kubeclient_ca: + client_path: /etc/origin/master/admin.kubeconfig + ca_data: "{{ ca_data.content }}" +''' + + +def main(): + ''' Modify kubeconfig located at `client_path`, setting the + certificate authority data to specified `ca_data` or contents of + `ca_path`. + ''' + + module = AnsibleModule( # noqa: F405 + argument_spec=dict( + client_path=dict(required=True), + ca_data=dict(required=False, default=None), + ca_path=dict(required=False, default=None), + backup=dict(required=False, default=True, type='bool'), + ), + supports_check_mode=True, + mutually_exclusive=[['ca_data', 'ca_path']], + required_one_of=[['ca_data', 'ca_path']] + ) + + client_path = module.params['client_path'] + ca_data = module.params['ca_data'] + ca_path = module.params['ca_path'] + backup = module.params['backup'] + + try: + with open(client_path) as client_config_file: + client_config_data = yaml.safe_load(client_config_file.read()) + + if ca_data is None: + with open(ca_path) as ca_file: + ca_data = base64.standard_b64encode(ca_file.read()) + + changes = [] + # Naively update the CA information for each cluster in the + # kubeconfig. + for cluster in client_config_data['clusters']: + if cluster['cluster']['certificate-authority-data'] != ca_data: + cluster['cluster']['certificate-authority-data'] = ca_data + changes.append(cluster['name']) + + if not module.check_mode: + if len(changes) > 0 and backup: + module.backup_local(client_path) + + with open(client_path, 'w') as client_config_file: + client_config_string = yaml.dump(client_config_data, default_flow_style=False) + client_config_string = client_config_string.replace('\'\'', '""') + client_config_file.write(client_config_string) + + return module.exit_json(changed=(len(changes) > 0)) + + # ignore broad-except error to avoid stack trace to ansible user + # pylint: disable=broad-except + except Exception as error: + return module.fail_json(msg=str(error)) + + +if __name__ == '__main__': + main() diff --git a/roles/lib_utils/library/modify_yaml.py b/roles/lib_utils/library/modify_yaml.py new file mode 100644 index 000000000..9b8f9ba33 --- /dev/null +++ b/roles/lib_utils/library/modify_yaml.py @@ -0,0 +1,117 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +''' modify_yaml ansible module ''' + +import yaml + +# ignore pylint errors related to the module_utils import +# pylint: disable=redefined-builtin, unused-wildcard-import, wildcard-import +from ansible.module_utils.basic import * # noqa: F402,F403 + + +DOCUMENTATION = ''' +--- +module: modify_yaml +short_description: Modify yaml key value pairs +author: Andrew Butcher +requirements: [ ] +''' +EXAMPLES = ''' +- modify_yaml: + dest: /etc/origin/master/master-config.yaml + yaml_key: 'kubernetesMasterConfig.masterCount' + yaml_value: 2 +''' + + +def set_key(yaml_data, yaml_key, yaml_value): + ''' Updates a parsed yaml structure setting a key to a value. + + :param yaml_data: yaml structure to modify. + :type yaml_data: dict + :param yaml_key: Key to modify. + :type yaml_key: mixed + :param yaml_value: Value use for yaml_key. + :type yaml_value: mixed + :returns: Changes to the yaml_data structure + :rtype: dict(tuple()) + ''' + changes = [] + ptr = yaml_data + final_key = yaml_key.split('.')[-1] + for key in yaml_key.split('.'): + # Key isn't present and we're not on the final key. Set to empty dictionary. + if key not in ptr and key != final_key: + ptr[key] = {} + ptr = ptr[key] + # Current key is the final key. Update value. + elif key == final_key: + if (key in ptr and module.safe_eval(ptr[key]) != yaml_value) or (key not in ptr): # noqa: F405 + ptr[key] = yaml_value + changes.append((yaml_key, yaml_value)) + else: + # Next value is None and we're not on the final key. + # Turn value into an empty dictionary. + if ptr[key] is None and key != final_key: + ptr[key] = {} + ptr = ptr[key] + return changes + + +def main(): + ''' Modify key (supplied in jinja2 dot notation) in yaml file, setting + the key to the desired value. + ''' + + # disabling pylint errors for global-variable-undefined and invalid-name + # for 'global module' usage, since it is required to use ansible_facts + # pylint: disable=global-variable-undefined, invalid-name, + # redefined-outer-name + global module + + module = AnsibleModule( # noqa: F405 + argument_spec=dict( + dest=dict(required=True), + yaml_key=dict(required=True), + yaml_value=dict(required=True), + backup=dict(required=False, default=True, type='bool'), + ), + supports_check_mode=True, + ) + + dest = module.params['dest'] + yaml_key = module.params['yaml_key'] + yaml_value = module.safe_eval(module.params['yaml_value']) + backup = module.params['backup'] + + # Represent null values as an empty string. + # pylint: disable=missing-docstring, unused-argument + def none_representer(dumper, data): + return yaml.ScalarNode(tag=u'tag:yaml.org,2002:null', value=u'') + + yaml.add_representer(type(None), none_representer) + + try: + with open(dest) as yaml_file: + yaml_data = yaml.safe_load(yaml_file.read()) + + changes = set_key(yaml_data, yaml_key, yaml_value) + + if len(changes) > 0: + if backup: + module.backup_local(dest) + with open(dest, 'w') as yaml_file: + yaml_string = yaml.dump(yaml_data, default_flow_style=False) + yaml_string = yaml_string.replace('\'\'', '""') + yaml_file.write(yaml_string) + + return module.exit_json(changed=(len(changes) > 0), changes=changes) + + # ignore broad-except error to avoid stack trace to ansible user + # pylint: disable=broad-except + except Exception as error: + return module.fail_json(msg=str(error)) + + +if __name__ == '__main__': + main() diff --git a/roles/lib_utils/library/openshift_cert_expiry.py b/roles/lib_utils/library/openshift_cert_expiry.py new file mode 100644 index 000000000..e355266b0 --- /dev/null +++ b/roles/lib_utils/library/openshift_cert_expiry.py @@ -0,0 +1,839 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# pylint: disable=line-too-long,invalid-name + +"""For details on this module see DOCUMENTATION (below)""" + +import base64 +import datetime +import io +import os +import subprocess +import yaml + +# pylint import-error disabled because pylint cannot find the package +# when installed in a virtualenv +from ansible.module_utils.six.moves import configparser # pylint: disable=import-error +from ansible.module_utils.basic import AnsibleModule + +try: + # You can comment this import out and include a 'pass' in this + # block if you're manually testing this module on a NON-ATOMIC + # HOST (or any host that just doesn't have PyOpenSSL + # available). That will force the `load_and_handle_cert` function + # to use the Fake OpenSSL classes. + import OpenSSL.crypto + HAS_OPENSSL = True +except ImportError: + # Some platforms (such as RHEL Atomic) may not have the Python + # OpenSSL library installed. In this case we will use a manual + # work-around to parse each certificate. + # + # Check for 'OpenSSL.crypto' in `sys.modules` later. + HAS_OPENSSL = False + +DOCUMENTATION = ''' +--- +module: openshift_cert_expiry +short_description: Check OpenShift Container Platform (OCP) and Kube certificate expirations on a cluster +description: + - The M(openshift_cert_expiry) module has two basic functions: to flag certificates which will expire in a set window of time from now, and to notify you about certificates which have already expired. + - When the module finishes, a summary of the examination is returned. Each certificate in the summary has a C(health) key with a value of one of the following: + - C(ok) - not expired, and outside of the expiration C(warning_days) window. + - C(warning) - not expired, but will expire between now and the C(warning_days) window. + - C(expired) - an expired certificate. + - Certificate flagging follow this logic: + - If the expiration date is before now then the certificate is classified as C(expired). + - The certificates time to live (expiration date - now) is calculated, if that time window is less than C(warning_days) the certificate is classified as C(warning). + - All other conditions are classified as C(ok). + - The following keys are ALSO present in the certificate summary: + - C(cert_cn) - The common name of the certificate (additional CNs present in SAN extensions are omitted) + - C(days_remaining) - The number of days until the certificate expires. + - C(expiry) - The date the certificate expires on. + - C(path) - The full path to the certificate on the examined host. +version_added: "1.0" +options: + config_base: + description: + - Base path to OCP system settings. + required: false + default: /etc/origin + warning_days: + description: + - Flag certificates which will expire in C(warning_days) days from now. + required: false + default: 30 + show_all: + description: + - Enable this option to show analysis of ALL certificates examined by this module. + - By default only certificates which have expired, or will expire within the C(warning_days) window will be reported. + required: false + default: false + +author: "Tim Bielawa (@tbielawa) <tbielawa@redhat.com>" +''' + +EXAMPLES = ''' +# Default invocation, only notify about expired certificates or certificates which will expire within 30 days from now +- openshift_cert_expiry: + +# Expand the warning window to show certificates expiring within a year from now +- openshift_cert_expiry: warning_days=365 + +# Show expired, soon to expire (now + 30 days), and all other certificates examined +- openshift_cert_expiry: show_all=true +''' + + +class FakeOpenSSLCertificate(object): + """This provides a rough mock of what you get from +`OpenSSL.crypto.load_certificate()`. This is a work-around for +platforms missing the Python OpenSSL library. + """ + def __init__(self, cert_string): + """`cert_string` is a certificate in the form you get from running a +.crt through 'openssl x509 -in CERT.cert -text'""" + self.cert_string = cert_string + self.serial = None + self.subject = None + self.extensions = [] + self.not_after = None + self._parse_cert() + + def _parse_cert(self): + """Manually parse the certificate line by line""" + self.extensions = [] + + PARSING_ALT_NAMES = False + PARSING_HEX_SERIAL = False + for line in self.cert_string.split('\n'): + l = line.strip() + if PARSING_ALT_NAMES: + # We're parsing a 'Subject Alternative Name' line + self.extensions.append( + FakeOpenSSLCertificateSANExtension(l)) + + PARSING_ALT_NAMES = False + continue + + if PARSING_HEX_SERIAL: + # Hex serials arrive colon-delimited + serial_raw = l.replace(':', '') + # Convert to decimal + self.serial = int('0x' + serial_raw, base=16) + PARSING_HEX_SERIAL = False + continue + + # parse out the bits that we can + if l.startswith('Serial Number:'): + # Decimal format: + # Serial Number: 11 (0xb) + # => 11 + # Hex Format (large serials): + # Serial Number: + # 0a:de:eb:24:04:75:ab:56:39:14:e9:5a:22:e2:85:bf + # => 14449739080294792594019643629255165375 + if l.endswith(':'): + PARSING_HEX_SERIAL = True + continue + self.serial = int(l.split()[-2]) + + elif l.startswith('Not After :'): + # Not After : Feb 7 18:19:35 2019 GMT + # => strptime(str, '%b %d %H:%M:%S %Y %Z') + # => strftime('%Y%m%d%H%M%SZ') + # => 20190207181935Z + not_after_raw = l.partition(' : ')[-1] + # Last item: ('Not After', ' : ', 'Feb 7 18:19:35 2019 GMT') + not_after_parsed = datetime.datetime.strptime(not_after_raw, '%b %d %H:%M:%S %Y %Z') + self.not_after = not_after_parsed.strftime('%Y%m%d%H%M%SZ') + + elif l.startswith('X509v3 Subject Alternative Name:'): + PARSING_ALT_NAMES = True + continue + + elif l.startswith('Subject:'): + # O = system:nodes, CN = system:node:m01.example.com + self.subject = FakeOpenSSLCertificateSubjects(l.partition(': ')[-1]) + + def get_serial_number(self): + """Return the serial number of the cert""" + return self.serial + + def get_subject(self): + """Subjects must implement get_components() and return dicts or +tuples. An 'openssl x509 -in CERT.cert -text' with 'Subject': + + Subject: Subject: O=system:nodes, CN=system:node:m01.example.com + +might return: [('O=system', 'nodes'), ('CN=system', 'node:m01.example.com')] + """ + return self.subject + + def get_extension(self, i): + """Extensions must implement get_short_name() and return the string +'subjectAltName'""" + return self.extensions[i] + + def get_extension_count(self): + """ get_extension_count """ + return len(self.extensions) + + def get_notAfter(self): + """Returns a date stamp as a string in the form +'20180922170439Z'. strptime the result with format param: +'%Y%m%d%H%M%SZ'.""" + return self.not_after + + +class FakeOpenSSLCertificateSANExtension(object): # pylint: disable=too-few-public-methods + """Mocks what happens when `get_extension` is called on a certificate +object""" + + def __init__(self, san_string): + """With `san_string` as you get from: + + $ openssl x509 -in certificate.crt -text + """ + self.san_string = san_string + self.short_name = 'subjectAltName' + + def get_short_name(self): + """Return the 'type' of this extension. It's always the same though +because we only care about subjectAltName's""" + return self.short_name + + def __str__(self): + """Return this extension and the value as a simple string""" + return self.san_string + + +# pylint: disable=too-few-public-methods +class FakeOpenSSLCertificateSubjects(object): + """Mocks what happens when `get_subject` is called on a certificate +object""" + + def __init__(self, subject_string): + """With `subject_string` as you get from: + + $ openssl x509 -in certificate.crt -text + """ + self.subjects = [] + for s in subject_string.split(', '): + name, _, value = s.partition(' = ') + self.subjects.append((name, value)) + + def get_components(self): + """Returns a list of tuples""" + return self.subjects + + +###################################################################### +def filter_paths(path_list): + """`path_list` - A list of file paths to check. Only files which exist +will be returned + """ + return [p for p in path_list if os.path.exists(os.path.realpath(p))] + + +# pylint: disable=too-many-locals,too-many-branches +# +# TODO: Break this function down into smaller chunks +def load_and_handle_cert(cert_string, now, base64decode=False, ans_module=None): + """Load a certificate, split off the good parts, and return some +useful data + +Params: + +- `cert_string` (string) - a certificate loaded into a string object +- `now` (datetime) - a datetime object of the time to calculate the certificate 'time_remaining' against +- `base64decode` (bool) - run base64.b64decode() on the input +- `ans_module` (AnsibleModule) - The AnsibleModule object for this module (so we can raise errors) + +Returns: +A tuple of the form: + (cert_subject, cert_expiry_date, time_remaining, cert_serial_number) + """ + if base64decode: + _cert_string = base64.b64decode(cert_string).decode('utf-8') + else: + _cert_string = cert_string + + # Disable this. We 'redefine' the type because we are working + # around a missing library on the target host. + # + # pylint: disable=redefined-variable-type + if HAS_OPENSSL: + # No work-around required + cert_loaded = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, _cert_string) + else: + # Missing library, work-around required. Run the 'openssl' + # command on it to decode it + cmd = 'openssl x509 -text' + try: + openssl_proc = subprocess.Popen(cmd.split(), + stdout=subprocess.PIPE, + stdin=subprocess.PIPE) + except OSError: + ans_module.fail_json(msg="Error: The 'OpenSSL' python library and CLI command were not found on the target host. Unable to parse any certificates. This host will not be included in generated reports.") + else: + openssl_decoded = openssl_proc.communicate(_cert_string.encode('utf-8'))[0].decode('utf-8') + cert_loaded = FakeOpenSSLCertificate(openssl_decoded) + + ###################################################################### + # Read all possible names from the cert + cert_subjects = [] + for name, value in cert_loaded.get_subject().get_components(): + if isinstance(name, bytes) or isinstance(value, bytes): + name = name.decode('utf-8') + value = value.decode('utf-8') + cert_subjects.append('{}:{}'.format(name, value)) + + # To read SANs from a cert we must read the subjectAltName + # extension from the X509 Object. What makes this more difficult + # is that pyOpenSSL does not give extensions as an iterable + san = None + for i in range(cert_loaded.get_extension_count()): + ext = cert_loaded.get_extension(i) + if ext.get_short_name() == 'subjectAltName': + san = ext + + if san is not None: + # The X509Extension object for subjectAltName prints as a + # string with the alt names separated by a comma and a + # space. Split the string by ', ' and then add our new names + # to the list of existing names + cert_subjects.extend(str(san).split(', ')) + + cert_subject = ', '.join(cert_subjects) + ###################################################################### + + # Grab the expiration date + not_after = cert_loaded.get_notAfter() + # example get_notAfter() => 20180922170439Z + if isinstance(not_after, bytes): + not_after = not_after.decode('utf-8') + + cert_expiry_date = datetime.datetime.strptime( + not_after, + '%Y%m%d%H%M%SZ') + + time_remaining = cert_expiry_date - now + + return (cert_subject, cert_expiry_date, time_remaining, cert_loaded.get_serial_number()) + + +def classify_cert(cert_meta, now, time_remaining, expire_window, cert_list): + """Given metadata about a certificate under examination, classify it + into one of three categories, 'ok', 'warning', and 'expired'. + +Params: + +- `cert_meta` dict - A dict with certificate metadata. Required fields + include: 'cert_cn', 'path', 'expiry', 'days_remaining', 'health'. +- `now` (datetime) - a datetime object of the time to calculate the certificate 'time_remaining' against +- `time_remaining` (datetime.timedelta) - a timedelta for how long until the cert expires +- `expire_window` (datetime.timedelta) - a timedelta for how long the warning window is +- `cert_list` list - A list to shove the classified cert into + +Return: +- `cert_list` - The updated list of classified certificates + """ + expiry_str = str(cert_meta['expiry']) + # Categorization + if cert_meta['expiry'] < now: + # This already expired, must NOTIFY + cert_meta['health'] = 'expired' + elif time_remaining < expire_window: + # WARN about this upcoming expirations + cert_meta['health'] = 'warning' + else: + # Not expired or about to expire + cert_meta['health'] = 'ok' + + cert_meta['expiry'] = expiry_str + cert_meta['serial_hex'] = hex(int(cert_meta['serial'])) + cert_list.append(cert_meta) + return cert_list + + +def tabulate_summary(certificates, kubeconfigs, etcd_certs, router_certs, registry_certs): + """Calculate the summary text for when the module finishes +running. This includes counts of each classification and what have +you. + +Params: + +- `certificates` (list of dicts) - Processed `expire_check_result` + dicts with filled in `health` keys for system certificates. +- `kubeconfigs` - as above for kubeconfigs +- `etcd_certs` - as above for etcd certs + +Return: + +- `summary_results` (dict) - Counts of each cert type classification + and total items examined. + """ + items = certificates + kubeconfigs + etcd_certs + router_certs + registry_certs + + summary_results = { + 'system_certificates': len(certificates), + 'kubeconfig_certificates': len(kubeconfigs), + 'etcd_certificates': len(etcd_certs), + 'router_certs': len(router_certs), + 'registry_certs': len(registry_certs), + 'total': len(items), + 'ok': 0, + 'warning': 0, + 'expired': 0 + } + + summary_results['expired'] = len([c for c in items if c['health'] == 'expired']) + summary_results['warning'] = len([c for c in items if c['health'] == 'warning']) + summary_results['ok'] = len([c for c in items if c['health'] == 'ok']) + + return summary_results + + +###################################################################### +# This is our module MAIN function after all, so there's bound to be a +# lot of code bundled up into one block +# +# Reason: These checks are disabled because the issue was introduced +# during a period where the pylint checks weren't enabled for this file +# Status: temporarily disabled pending future refactoring +# pylint: disable=too-many-locals,too-many-statements,too-many-branches +def main(): + """This module examines certificates (in various forms) which compose +an OpenShift Container Platform cluster + """ + + module = AnsibleModule( + argument_spec=dict( + config_base=dict( + required=False, + default="/etc/origin", + type='str'), + warning_days=dict( + required=False, + default=30, + type='int'), + show_all=dict( + required=False, + default=False, + type='bool') + ), + supports_check_mode=True, + ) + + # Basic scaffolding for OpenShift specific certs + openshift_base_config_path = os.path.realpath(module.params['config_base']) + openshift_master_config_path = os.path.join(openshift_base_config_path, + "master", "master-config.yaml") + openshift_node_config_path = os.path.join(openshift_base_config_path, + "node", "node-config.yaml") + openshift_cert_check_paths = [ + openshift_master_config_path, + openshift_node_config_path, + ] + + # Paths for Kubeconfigs. Additional kubeconfigs are conditionally + # checked later in the code + master_kube_configs = ['admin', 'openshift-master', + 'openshift-node', 'openshift-router', + 'openshift-registry'] + + kubeconfig_paths = [] + for m_kube_config in master_kube_configs: + kubeconfig_paths.append( + os.path.join(openshift_base_config_path, "master", m_kube_config + ".kubeconfig") + ) + + # Validate some paths we have the ability to do ahead of time + openshift_cert_check_paths = filter_paths(openshift_cert_check_paths) + kubeconfig_paths = filter_paths(kubeconfig_paths) + + # etcd, where do you hide your certs? Used when parsing etcd.conf + etcd_cert_params = [ + "ETCD_CA_FILE", + "ETCD_CERT_FILE", + "ETCD_PEER_CA_FILE", + "ETCD_PEER_CERT_FILE", + ] + + # Expiry checking stuff + now = datetime.datetime.now() + # todo, catch exception for invalid input and return a fail_json + warning_days = int(module.params['warning_days']) + expire_window = datetime.timedelta(days=warning_days) + + # Module stuff + # + # The results of our cert checking to return from the task call + check_results = {} + check_results['meta'] = {} + check_results['meta']['warning_days'] = warning_days + check_results['meta']['checked_at_time'] = str(now) + check_results['meta']['warn_before_date'] = str(now + expire_window) + check_results['meta']['show_all'] = str(module.params['show_all']) + # All the analyzed certs accumulate here + ocp_certs = [] + + ###################################################################### + # Sure, why not? Let's enable check mode. + if module.check_mode: + check_results['ocp_certs'] = [] + module.exit_json( + check_results=check_results, + msg="Checked 0 total certificates. Expired/Warning/OK: 0/0/0. Warning window: %s days" % module.params['warning_days'], + rc=0, + changed=False + ) + + ###################################################################### + # Check for OpenShift Container Platform specific certs + ###################################################################### + for os_cert in filter_paths(openshift_cert_check_paths): + # Open up that config file and locate the cert and CA + with io.open(os_cert, 'r', encoding='utf-8') as fp: + cert_meta = {} + cfg = yaml.load(fp) + # cert files are specified in parsed `fp` as relative to the path + # of the original config file. 'master-config.yaml' with certFile + # = 'foo.crt' implies that 'foo.crt' is in the same + # directory. certFile = '../foo.crt' is in the parent directory. + cfg_path = os.path.dirname(fp.name) + cert_meta['certFile'] = os.path.join(cfg_path, cfg['servingInfo']['certFile']) + cert_meta['clientCA'] = os.path.join(cfg_path, cfg['servingInfo']['clientCA']) + + ###################################################################### + # Load the certificate and the CA, parse their expiration dates into + # datetime objects so we can manipulate them later + for v in cert_meta.values(): + with io.open(v, 'r', encoding='utf-8') as fp: + cert = fp.read() + (cert_subject, + cert_expiry_date, + time_remaining, + cert_serial) = load_and_handle_cert(cert, now, ans_module=module) + + expire_check_result = { + 'cert_cn': cert_subject, + 'path': fp.name, + 'expiry': cert_expiry_date, + 'days_remaining': time_remaining.days, + 'health': None, + 'serial': cert_serial + } + + classify_cert(expire_check_result, now, time_remaining, expire_window, ocp_certs) + + ###################################################################### + # /Check for OpenShift Container Platform specific certs + ###################################################################### + + ###################################################################### + # Check service Kubeconfigs + ###################################################################### + kubeconfigs = [] + + # There may be additional kubeconfigs to check, but their naming + # is less predictable than the ones we've already assembled. + + try: + # Try to read the standard 'node-config.yaml' file to check if + # this host is a node. + with io.open(openshift_node_config_path, 'r', encoding='utf-8') as fp: + cfg = yaml.load(fp) + + # OK, the config file exists, therefore this is a + # node. Nodes have their own kubeconfig files to + # communicate with the master API. Let's read the relative + # path to that file from the node config. + node_masterKubeConfig = cfg['masterKubeConfig'] + # As before, the path to the 'masterKubeConfig' file is + # relative to `fp` + cfg_path = os.path.dirname(fp.name) + node_kubeconfig = os.path.join(cfg_path, node_masterKubeConfig) + + with io.open(node_kubeconfig, 'r', encoding='utf8') as fp: + # Read in the nodes kubeconfig file and grab the good stuff + cfg = yaml.load(fp) + + c = cfg['users'][0]['user']['client-certificate-data'] + (cert_subject, + cert_expiry_date, + time_remaining, + cert_serial) = load_and_handle_cert(c, now, base64decode=True, ans_module=module) + + expire_check_result = { + 'cert_cn': cert_subject, + 'path': fp.name, + 'expiry': cert_expiry_date, + 'days_remaining': time_remaining.days, + 'health': None, + 'serial': cert_serial + } + + classify_cert(expire_check_result, now, time_remaining, expire_window, kubeconfigs) + except IOError: + # This is not a node + pass + + for kube in filter_paths(kubeconfig_paths): + with io.open(kube, 'r', encoding='utf-8') as fp: + # TODO: Maybe consider catching exceptions here? + cfg = yaml.load(fp) + + # Per conversation, "the kubeconfigs you care about: + # admin, router, registry should all be single + # value". Following that advice we only grab the data for + # the user at index 0 in the 'users' list. There should + # not be more than one user. + c = cfg['users'][0]['user']['client-certificate-data'] + (cert_subject, + cert_expiry_date, + time_remaining, + cert_serial) = load_and_handle_cert(c, now, base64decode=True, ans_module=module) + + expire_check_result = { + 'cert_cn': cert_subject, + 'path': fp.name, + 'expiry': cert_expiry_date, + 'days_remaining': time_remaining.days, + 'health': None, + 'serial': cert_serial + } + + classify_cert(expire_check_result, now, time_remaining, expire_window, kubeconfigs) + + ###################################################################### + # /Check service Kubeconfigs + ###################################################################### + + ###################################################################### + # Check etcd certs + # + # Two things to check: 'external' etcd, and embedded etcd. + ###################################################################### + # FIRST: The 'external' etcd + # + # Some values may be duplicated, make this a set for now so we + # unique them all + etcd_certs_to_check = set([]) + etcd_certs = [] + etcd_cert_params.append('dne') + try: + with io.open('/etc/etcd/etcd.conf', 'r', encoding='utf-8') as fp: + # Add dummy header section. + config = io.StringIO() + config.write(u'[ETCD]\n') + config.write(fp.read().replace('%', '%%')) + config.seek(0, os.SEEK_SET) + + etcd_config = configparser.ConfigParser() + etcd_config.readfp(config) + + for param in etcd_cert_params: + try: + etcd_certs_to_check.add(etcd_config.get('ETCD', param)) + except configparser.NoOptionError: + # That parameter does not exist, oh well... + pass + except IOError: + # No etcd to see here, move along + pass + + for etcd_cert in filter_paths(etcd_certs_to_check): + with io.open(etcd_cert, 'r', encoding='utf-8') as fp: + c = fp.read() + (cert_subject, + cert_expiry_date, + time_remaining, + cert_serial) = load_and_handle_cert(c, now, ans_module=module) + + expire_check_result = { + 'cert_cn': cert_subject, + 'path': fp.name, + 'expiry': cert_expiry_date, + 'days_remaining': time_remaining.days, + 'health': None, + 'serial': cert_serial + } + + classify_cert(expire_check_result, now, time_remaining, expire_window, etcd_certs) + + ###################################################################### + # Now the embedded etcd + ###################################################################### + try: + with io.open('/etc/origin/master/master-config.yaml', 'r', encoding='utf-8') as fp: + cfg = yaml.load(fp) + except IOError: + # Not present + pass + else: + if cfg.get('etcdConfig', {}).get('servingInfo', {}).get('certFile', None) is not None: + # This is embedded + etcd_crt_name = cfg['etcdConfig']['servingInfo']['certFile'] + else: + # Not embedded + etcd_crt_name = None + + if etcd_crt_name is not None: + # etcd_crt_name is relative to the location of the + # master-config.yaml file + cfg_path = os.path.dirname(fp.name) + etcd_cert = os.path.join(cfg_path, etcd_crt_name) + with open(etcd_cert, 'r') as etcd_fp: + (cert_subject, + cert_expiry_date, + time_remaining, + cert_serial) = load_and_handle_cert(etcd_fp.read(), now, ans_module=module) + + expire_check_result = { + 'cert_cn': cert_subject, + 'path': etcd_fp.name, + 'expiry': cert_expiry_date, + 'days_remaining': time_remaining.days, + 'health': None, + 'serial': cert_serial + } + + classify_cert(expire_check_result, now, time_remaining, expire_window, etcd_certs) + + ###################################################################### + # /Check etcd certs + ###################################################################### + + ###################################################################### + # Check router/registry certs + # + # These are saved as secrets in etcd. That means that we can not + # simply read a file to grab the data. Instead we're going to + # subprocess out to the 'oc get' command. On non-masters this + # command will fail, that is expected so we catch that exception. + ###################################################################### + router_certs = [] + registry_certs = [] + + ###################################################################### + # First the router certs + try: + router_secrets_raw = subprocess.Popen('oc get -n default secret router-certs -o yaml'.split(), + stdout=subprocess.PIPE) + router_ds = yaml.load(router_secrets_raw.communicate()[0]) + router_c = router_ds['data']['tls.crt'] + router_path = router_ds['metadata']['selfLink'] + except TypeError: + # YAML couldn't load the result, this is not a master + pass + except OSError: + # The OC command doesn't exist here. Move along. + pass + else: + (cert_subject, + cert_expiry_date, + time_remaining, + cert_serial) = load_and_handle_cert(router_c, now, base64decode=True, ans_module=module) + + expire_check_result = { + 'cert_cn': cert_subject, + 'path': router_path, + 'expiry': cert_expiry_date, + 'days_remaining': time_remaining.days, + 'health': None, + 'serial': cert_serial + } + + classify_cert(expire_check_result, now, time_remaining, expire_window, router_certs) + + ###################################################################### + # Now for registry + try: + registry_secrets_raw = subprocess.Popen('oc get -n default secret registry-certificates -o yaml'.split(), + stdout=subprocess.PIPE) + registry_ds = yaml.load(registry_secrets_raw.communicate()[0]) + registry_c = registry_ds['data']['registry.crt'] + registry_path = registry_ds['metadata']['selfLink'] + except TypeError: + # YAML couldn't load the result, this is not a master + pass + except OSError: + # The OC command doesn't exist here. Move along. + pass + else: + (cert_subject, + cert_expiry_date, + time_remaining, + cert_serial) = load_and_handle_cert(registry_c, now, base64decode=True, ans_module=module) + + expire_check_result = { + 'cert_cn': cert_subject, + 'path': registry_path, + 'expiry': cert_expiry_date, + 'days_remaining': time_remaining.days, + 'health': None, + 'serial': cert_serial + } + + classify_cert(expire_check_result, now, time_remaining, expire_window, registry_certs) + + ###################################################################### + # /Check router/registry certs + ###################################################################### + + res = tabulate_summary(ocp_certs, kubeconfigs, etcd_certs, router_certs, registry_certs) + + msg = "Checked {count} total certificates. Expired/Warning/OK: {exp}/{warn}/{ok}. Warning window: {window} days".format( + count=res['total'], + exp=res['expired'], + warn=res['warning'], + ok=res['ok'], + window=int(module.params['warning_days']), + ) + + # By default we only return detailed information about expired or + # warning certificates. If show_all is true then we will print all + # the certificates examined. + if not module.params['show_all']: + check_results['ocp_certs'] = [crt for crt in ocp_certs if crt['health'] in ['expired', 'warning']] + check_results['kubeconfigs'] = [crt for crt in kubeconfigs if crt['health'] in ['expired', 'warning']] + check_results['etcd'] = [crt for crt in etcd_certs if crt['health'] in ['expired', 'warning']] + check_results['registry'] = [crt for crt in registry_certs if crt['health'] in ['expired', 'warning']] + check_results['router'] = [crt for crt in router_certs if crt['health'] in ['expired', 'warning']] + else: + check_results['ocp_certs'] = ocp_certs + check_results['kubeconfigs'] = kubeconfigs + check_results['etcd'] = etcd_certs + check_results['registry'] = registry_certs + check_results['router'] = router_certs + + # Sort the final results to report in order of ascending safety + # time. That is to say, the certificates which will expire sooner + # will be at the front of the list and certificates which will + # expire later are at the end. Router and registry certs should be + # limited to just 1 result, so don't bother sorting those. + def cert_key(item): + ''' return the days_remaining key ''' + return item['days_remaining'] + + check_results['ocp_certs'] = sorted(check_results['ocp_certs'], key=cert_key) + check_results['kubeconfigs'] = sorted(check_results['kubeconfigs'], key=cert_key) + check_results['etcd'] = sorted(check_results['etcd'], key=cert_key) + + # This module will never change anything, but we might want to + # change the return code parameter if there is some catastrophic + # error we noticed earlier + module.exit_json( + check_results=check_results, + summary=res, + msg=msg, + rc=0, + changed=False + ) + + +if __name__ == '__main__': + main() diff --git a/roles/lib_utils/library/openshift_container_binary_sync.py b/roles/lib_utils/library/openshift_container_binary_sync.py new file mode 100644 index 000000000..440b8ec28 --- /dev/null +++ b/roles/lib_utils/library/openshift_container_binary_sync.py @@ -0,0 +1,205 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# pylint: disable=missing-docstring,invalid-name + +import random +import tempfile +import shutil +import os.path + +# pylint: disable=redefined-builtin,wildcard-import,unused-wildcard-import +from ansible.module_utils.basic import * # noqa: F403 + + +DOCUMENTATION = ''' +--- +module: openshift_container_binary_sync +short_description: Copies OpenShift binaries out of the given image tag to host system. +''' + + +class BinarySyncError(Exception): + def __init__(self, msg): + super(BinarySyncError, self).__init__(msg) + self.msg = msg + + +# pylint: disable=too-few-public-methods,too-many-instance-attributes +class BinarySyncer(object): + """ + Syncs the openshift, oc, and kubectl binaries/symlinks out of + a container onto the host system. + """ + + def __init__(self, module, image, tag, backend): + self.module = module + self.changed = False + self.output = [] + self.bin_dir = '/usr/local/bin' + self._image = image + self.tag = tag + self.backend = backend + self.temp_dir = None # TBD + + def sync(self): + if self.backend == 'atomic': + return self._sync_atomic() + + return self._sync_docker() + + def _sync_atomic(self): + self.temp_dir = tempfile.mkdtemp() + temp_dir_mount = tempfile.mkdtemp() + try: + image_spec = '%s:%s' % (self.image, self.tag) + rc, stdout, stderr = self.module.run_command(['atomic', 'mount', + '--storage', "ostree", + image_spec, temp_dir_mount]) + if rc: + raise BinarySyncError("Error mounting image. stdout=%s, stderr=%s" % + (stdout, stderr)) + for i in ["openshift", "oc"]: + src_file = os.path.join(temp_dir_mount, "usr/bin", i) + shutil.copy(src_file, self.temp_dir) + + self._sync_binaries() + finally: + self.module.run_command(['atomic', 'umount', temp_dir_mount]) + shutil.rmtree(temp_dir_mount) + shutil.rmtree(self.temp_dir) + + def _sync_docker(self): + container_name = "openshift-cli-%s" % random.randint(1, 100000) + rc, stdout, stderr = self.module.run_command(['docker', 'create', '--name', + container_name, '%s:%s' % (self.image, self.tag)]) + if rc: + raise BinarySyncError("Error creating temporary docker container. stdout=%s, stderr=%s" % + (stdout, stderr)) + self.output.append(stdout) + try: + self.temp_dir = tempfile.mkdtemp() + self.output.append("Using temp dir: %s" % self.temp_dir) + + rc, stdout, stderr = self.module.run_command(['docker', 'cp', "%s:/usr/bin/openshift" % container_name, + self.temp_dir]) + if rc: + raise BinarySyncError("Error copying file from docker container: stdout=%s, stderr=%s" % + (stdout, stderr)) + + rc, stdout, stderr = self.module.run_command(['docker', 'cp', "%s:/usr/bin/oc" % container_name, + self.temp_dir]) + if rc: + raise BinarySyncError("Error copying file from docker container: stdout=%s, stderr=%s" % + (stdout, stderr)) + + self._sync_binaries() + finally: + shutil.rmtree(self.temp_dir) + self.module.run_command(['docker', 'rm', container_name]) + + def _sync_binaries(self): + self._sync_binary('openshift') + + # In older versions, oc was a symlink to openshift: + if os.path.islink(os.path.join(self.temp_dir, 'oc')): + self._sync_symlink('oc', 'openshift') + else: + self._sync_binary('oc') + + # Ensure correct symlinks created: + self._sync_symlink('kubectl', 'openshift') + + # Remove old oadm binary + if os.path.exists(os.path.join(self.bin_dir, 'oadm')): + os.remove(os.path.join(self.bin_dir, 'oadm')) + + def _sync_symlink(self, binary_name, link_to): + """ Ensure the given binary name exists and links to the expected binary. """ + + # The symlink we are creating: + link_path = os.path.join(self.bin_dir, binary_name) + + # The expected file we should be linking to: + link_dest = os.path.join(self.bin_dir, link_to) + + if not os.path.exists(link_path) or \ + not os.path.islink(link_path) or \ + os.path.realpath(link_path) != os.path.realpath(link_dest): + if os.path.exists(link_path): + os.remove(link_path) + os.symlink(link_to, os.path.join(self.bin_dir, binary_name)) + self.output.append("Symlinked %s to %s." % (link_path, link_dest)) + self.changed = True + + def _sync_binary(self, binary_name): + src_path = os.path.join(self.temp_dir, binary_name) + dest_path = os.path.join(self.bin_dir, binary_name) + incoming_checksum = self.module.run_command(['sha256sum', src_path])[1] + if not os.path.exists(dest_path) or self.module.run_command(['sha256sum', dest_path])[1] != incoming_checksum: + + # See: https://github.com/openshift/openshift-ansible/issues/4965 + if os.path.islink(dest_path): + os.unlink(dest_path) + self.output.append('Removed old symlink {} before copying binary.'.format(dest_path)) + shutil.move(src_path, dest_path) + self.output.append("Moved %s to %s." % (src_path, dest_path)) + self.changed = True + + @property + def raw_image(self): + """ + Returns the image as it was originally passed in to the instance. + + .. note:: + This image string will only work directly with the atomic command. + + :returns: The original image passed in. + :rtype: str + """ + return self._image + + @property + def image(self): + """ + Returns the image without atomic prefixes used to map to skopeo args. + + :returns: The image string without prefixes + :rtype: str + """ + image = self._image + for remove in ('oci:', 'http:', 'https:'): + if image.startswith(remove): + image = image.replace(remove, '') + return image + + +def main(): + module = AnsibleModule( # noqa: F405 + argument_spec=dict( + image=dict(required=True), + tag=dict(required=True), + backend=dict(required=True), + ), + supports_check_mode=True + ) + + image = module.params['image'] + tag = module.params['tag'] + backend = module.params['backend'] + + if backend not in ["docker", "atomic"]: + module.fail_json(msg="unknown backend") + + binary_syncer = BinarySyncer(module, image, tag, backend) + + try: + binary_syncer.sync() + except BinarySyncError as ex: + module.fail_json(msg=ex.msg) + + return module.exit_json(changed=binary_syncer.changed, + output=binary_syncer.output) + + +if __name__ == '__main__': + main() diff --git a/roles/lib_utils/library/os_firewall_manage_iptables.py b/roles/lib_utils/library/os_firewall_manage_iptables.py new file mode 100644 index 000000000..aeee3ede8 --- /dev/null +++ b/roles/lib_utils/library/os_firewall_manage_iptables.py @@ -0,0 +1,283 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# pylint: disable=fixme, missing-docstring +import subprocess + +DOCUMENTATION = ''' +--- +module: os_firewall_manage_iptables +short_description: This module manages iptables rules for a given chain +author: Jason DeTiberus +requirements: [ ] +''' +EXAMPLES = ''' +''' + + +class IpTablesError(Exception): + def __init__(self, msg, cmd, exit_code, output): + super(IpTablesError, self).__init__(msg) + self.msg = msg + self.cmd = cmd + self.exit_code = exit_code + self.output = output + + +class IpTablesAddRuleError(IpTablesError): + pass + + +class IpTablesRemoveRuleError(IpTablesError): + def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name + super(IpTablesRemoveRuleError, self).__init__(msg, cmd, exit_code, + output) + self.chain = chain + + +class IpTablesSaveError(IpTablesError): + pass + + +class IpTablesCreateChainError(IpTablesError): + def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name + super(IpTablesCreateChainError, self).__init__(msg, cmd, exit_code, + output) + self.chain = chain + + +class IpTablesCreateJumpRuleError(IpTablesError): + def __init__(self, chain, msg, cmd, exit_code, output): # pylint: disable=too-many-arguments, line-too-long, redefined-outer-name + super(IpTablesCreateJumpRuleError, self).__init__(msg, cmd, exit_code, + output) + self.chain = chain + + +# TODO: implement rollbacks for any events that were successful and an +# exception was thrown later. For example, when the chain is created +# successfully, but the add/remove rule fails. +class IpTablesManager(object): # pylint: disable=too-many-instance-attributes + def __init__(self, module): + self.module = module + self.ip_version = module.params['ip_version'] + self.check_mode = module.check_mode + self.chain = module.params['chain'] + self.create_jump_rule = module.params['create_jump_rule'] + self.jump_rule_chain = module.params['jump_rule_chain'] + self.cmd = self.gen_cmd() + self.save_cmd = self.gen_save_cmd() + self.output = [] + self.changed = False + + def save(self): + try: + self.output.append(subprocess.check_output(self.save_cmd, stderr=subprocess.STDOUT)) + except subprocess.CalledProcessError as ex: + raise IpTablesSaveError( + msg="Failed to save iptables rules", + cmd=ex.cmd, exit_code=ex.returncode, output=ex.output) + + def verify_chain(self): + if not self.chain_exists(): + self.create_chain() + if self.create_jump_rule and not self.jump_rule_exists(): + self.create_jump() + + def add_rule(self, port, proto): + rule = self.gen_rule(port, proto) + if not self.rule_exists(rule): + self.verify_chain() + + if self.check_mode: + self.changed = True + self.output.append("Create rule for %s %s" % (proto, port)) + else: + cmd = self.cmd + ['-A'] + rule + try: + self.output.append(subprocess.check_output(cmd)) + self.changed = True + self.save() + except subprocess.CalledProcessError as ex: + raise IpTablesCreateChainError( + chain=self.chain, + msg="Failed to create rule for " + "%s %s" % (proto, port), + cmd=ex.cmd, exit_code=ex.returncode, + output=ex.output) + + def remove_rule(self, port, proto): + rule = self.gen_rule(port, proto) + if self.rule_exists(rule): + if self.check_mode: + self.changed = True + self.output.append("Remove rule for %s %s" % (proto, port)) + else: + cmd = self.cmd + ['-D'] + rule + try: + self.output.append(subprocess.check_output(cmd)) + self.changed = True + self.save() + except subprocess.CalledProcessError as ex: + raise IpTablesRemoveRuleError( + chain=self.chain, + msg="Failed to remove rule for %s %s" % (proto, port), + cmd=ex.cmd, exit_code=ex.returncode, output=ex.output) + + def rule_exists(self, rule): + check_cmd = self.cmd + ['-C'] + rule + return True if subprocess.call(check_cmd) == 0 else False + + @staticmethod + def port_as_argument(port): + if isinstance(port, int): + return str(port) + if isinstance(port, basestring): # noqa: F405 + return port.replace('-', ":") + return port + + def gen_rule(self, port, proto): + return [self.chain, '-p', proto, '-m', 'state', '--state', 'NEW', + '-m', proto, '--dport', IpTablesManager.port_as_argument(port), '-j', 'ACCEPT'] + + def create_jump(self): + if self.check_mode: + self.changed = True + self.output.append("Create jump rule for chain %s" % self.chain) + else: + try: + cmd = self.cmd + ['-L', self.jump_rule_chain, '--line-numbers'] + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + # break the input rules into rows and columns + input_rules = [s.split() for s in to_native(output).split('\n')] + + # Find the last numbered rule + last_rule_num = None + last_rule_target = None + for rule in input_rules[:-1]: + if rule: + try: + last_rule_num = int(rule[0]) + except ValueError: + continue + last_rule_target = rule[1] + + # Naively assume that if the last row is a REJECT or DROP rule, + # then we can insert our rule right before it, otherwise we + # assume that we can just append the rule. + if (last_rule_num and last_rule_target and last_rule_target in ['REJECT', 'DROP']): + # insert rule + cmd = self.cmd + ['-I', self.jump_rule_chain, + str(last_rule_num)] + else: + # append rule + cmd = self.cmd + ['-A', self.jump_rule_chain] + cmd += ['-j', self.chain] + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT) + self.changed = True + self.output.append(output) + self.save() + except subprocess.CalledProcessError as ex: + if '--line-numbers' in ex.cmd: + raise IpTablesCreateJumpRuleError( + chain=self.chain, + msg=("Failed to query existing " + + self.jump_rule_chain + + " rules to determine jump rule location"), + cmd=ex.cmd, exit_code=ex.returncode, + output=ex.output) + else: + raise IpTablesCreateJumpRuleError( + chain=self.chain, + msg=("Failed to create jump rule for chain " + + self.chain), + cmd=ex.cmd, exit_code=ex.returncode, + output=ex.output) + + def create_chain(self): + if self.check_mode: + self.changed = True + self.output.append("Create chain %s" % self.chain) + else: + try: + cmd = self.cmd + ['-N', self.chain] + self.output.append(subprocess.check_output(cmd, stderr=subprocess.STDOUT)) + self.changed = True + self.output.append("Successfully created chain %s" % + self.chain) + self.save() + except subprocess.CalledProcessError as ex: + raise IpTablesCreateChainError( + chain=self.chain, + msg="Failed to create chain: %s" % self.chain, + cmd=ex.cmd, exit_code=ex.returncode, output=ex.output + ) + + def jump_rule_exists(self): + cmd = self.cmd + ['-C', self.jump_rule_chain, '-j', self.chain] + return True if subprocess.call(cmd) == 0 else False + + def chain_exists(self): + cmd = self.cmd + ['-L', self.chain] + return True if subprocess.call(cmd) == 0 else False + + def gen_cmd(self): + cmd = 'iptables' if self.ip_version == 'ipv4' else 'ip6tables' + # Include -w (wait for xtables lock) in default arguments. + default_args = ['-w'] + return ["/usr/sbin/%s" % cmd] + default_args + + def gen_save_cmd(self): # pylint: disable=no-self-use + return ['/usr/libexec/iptables/iptables.init', 'save'] + + +def main(): + module = AnsibleModule( # noqa: F405 + argument_spec=dict( + name=dict(required=True), + action=dict(required=True, choices=['add', 'remove', + 'verify_chain']), + chain=dict(required=False, default='OS_FIREWALL_ALLOW'), + create_jump_rule=dict(required=False, type='bool', default=True), + jump_rule_chain=dict(required=False, default='INPUT'), + protocol=dict(required=False, choices=['tcp', 'udp']), + port=dict(required=False, type='str'), + ip_version=dict(required=False, default='ipv4', + choices=['ipv4', 'ipv6']), + ), + supports_check_mode=True + ) + + action = module.params['action'] + protocol = module.params['protocol'] + port = module.params['port'] + + if action in ['add', 'remove']: + if not protocol: + error = "protocol is required when action is %s" % action + module.fail_json(msg=error) + if not port: + error = "port is required when action is %s" % action + module.fail_json(msg=error) + + iptables_manager = IpTablesManager(module) + + try: + if action == 'add': + iptables_manager.add_rule(port, protocol) + elif action == 'remove': + iptables_manager.remove_rule(port, protocol) + elif action == 'verify_chain': + iptables_manager.verify_chain() + except IpTablesError as ex: + module.fail_json(msg=ex.msg) + + return module.exit_json(changed=iptables_manager.changed, + output=iptables_manager.output) + + +# pylint: disable=redefined-builtin, unused-wildcard-import, wildcard-import, wrong-import-position +# import module snippets +from ansible.module_utils.basic import * # noqa: F403,E402 +from ansible.module_utils._text import to_native # noqa: E402 +if __name__ == '__main__': + main() diff --git a/roles/lib_utils/library/rpm_q.py b/roles/lib_utils/library/rpm_q.py new file mode 100644 index 000000000..3dec50fc2 --- /dev/null +++ b/roles/lib_utils/library/rpm_q.py @@ -0,0 +1,72 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2015, Tobias Florek <tob@butter.sh> +# Licensed under the terms of the MIT License +""" +An ansible module to query the RPM database. For use, when yum/dnf are not +available. +""" + +# pylint: disable=redefined-builtin,wildcard-import,unused-wildcard-import +from ansible.module_utils.basic import * # noqa: F403 + +DOCUMENTATION = """ +--- +module: rpm_q +short_description: Query the RPM database +author: Tobias Florek +options: + name: + description: + - The name of the package to query + required: true + state: + description: + - Whether the package is supposed to be installed or not + choices: [present, absent] + default: present +""" + +EXAMPLES = """ +- rpm_q: name=ansible state=present +- rpm_q: name=ansible state=absent +""" + +RPM_BINARY = '/bin/rpm' + + +def main(): + """ + Checks rpm -q for the named package and returns the installed packages + or None if not installed. + """ + module = AnsibleModule( # noqa: F405 + argument_spec=dict( + name=dict(required=True), + state=dict(default='present', choices=['present', 'absent']) + ), + supports_check_mode=True + ) + + name = module.params['name'] + state = module.params['state'] + + # pylint: disable=invalid-name + rc, out, err = module.run_command([RPM_BINARY, '-q', name]) + + installed = out.rstrip('\n').split('\n') + + if rc != 0: + if state == 'present': + module.fail_json(msg="%s is not installed" % name, stdout=out, stderr=err, rc=rc) + else: + module.exit_json(changed=False) + elif state == 'present': + module.exit_json(changed=False, installed_versions=installed) + else: + module.fail_json(msg="%s is installed", installed_versions=installed) + + +if __name__ == '__main__': + main() diff --git a/roles/lib_utils/lookup_plugins/openshift_master_facts_default_predicates.py b/roles/lib_utils/lookup_plugins/openshift_master_facts_default_predicates.py new file mode 100644 index 000000000..4858c5ec6 --- /dev/null +++ b/roles/lib_utils/lookup_plugins/openshift_master_facts_default_predicates.py @@ -0,0 +1,143 @@ +# pylint: disable=missing-docstring + +import re +from ansible.errors import AnsibleError +from ansible.plugins.lookup import LookupBase + + +class LookupModule(LookupBase): + # pylint: disable=too-many-branches,too-many-statements,too-many-arguments + + def run(self, terms, variables=None, regions_enabled=True, short_version=None, + deployment_type=None, **kwargs): + + predicates = [] + + if short_version is None or deployment_type is None: + if 'openshift' not in variables: + raise AnsibleError("This lookup module requires openshift_facts to be run prior to use") + + if deployment_type is None: + if 'common' not in variables['openshift'] or 'deployment_type' not in variables['openshift']['common']: + raise AnsibleError("This lookup module requires that the deployment_type be set") + + deployment_type = variables['openshift']['common']['deployment_type'] + + if short_version is None: + if 'short_version' in variables['openshift']['common']: + short_version = variables['openshift']['common']['short_version'] + elif 'openshift_release' in variables: + release = variables['openshift_release'] + if release.startswith('v'): + short_version = release[1:] + else: + short_version = release + short_version = '.'.join(short_version.split('.')[0:2]) + elif 'openshift_version' in variables: + version = variables['openshift_version'] + short_version = '.'.join(version.split('.')[0:2]) + else: + # pylint: disable=line-too-long + raise AnsibleError("Either OpenShift needs to be installed or openshift_release needs to be specified") + if deployment_type == 'origin': + if short_version not in ['1.1', '1.2', '1.3', '1.4', '1.5', '3.6', '3.7', '3.8', '3.9', 'latest']: + raise AnsibleError("Unknown short_version %s" % short_version) + elif deployment_type == 'openshift-enterprise': + if short_version not in ['3.1', '3.2', '3.3', '3.4', '3.5', '3.6', '3.7', '3.8', '3.9', 'latest']: + raise AnsibleError("Unknown short_version %s" % short_version) + else: + raise AnsibleError("Unknown deployment_type %s" % deployment_type) + + if deployment_type == 'origin': + # convert short_version to enterprise short_version + short_version = re.sub('^1.', '3.', short_version) + + if short_version == 'latest': + short_version = '3.9' + + # Predicates ordered according to OpenShift Origin source: + # origin/vendor/k8s.io/kubernetes/plugin/pkg/scheduler/algorithmprovider/defaults/defaults.go + + if short_version == '3.1': + predicates.extend([ + {'name': 'PodFitsHostPorts'}, + {'name': 'PodFitsResources'}, + {'name': 'NoDiskConflict'}, + {'name': 'MatchNodeSelector'}, + ]) + + if short_version == '3.2': + predicates.extend([ + {'name': 'PodFitsHostPorts'}, + {'name': 'PodFitsResources'}, + {'name': 'NoDiskConflict'}, + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MatchNodeSelector'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'} + ]) + + if short_version == '3.3': + predicates.extend([ + {'name': 'NoDiskConflict'}, + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'}, + {'name': 'GeneralPredicates'}, + {'name': 'PodToleratesNodeTaints'}, + {'name': 'CheckNodeMemoryPressure'} + ]) + + if short_version == '3.4': + predicates.extend([ + {'name': 'NoDiskConflict'}, + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'}, + {'name': 'GeneralPredicates'}, + {'name': 'PodToleratesNodeTaints'}, + {'name': 'CheckNodeMemoryPressure'}, + {'name': 'CheckNodeDiskPressure'}, + {'name': 'MatchInterPodAffinity'} + ]) + + if short_version in ['3.5', '3.6']: + predicates.extend([ + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'}, + {'name': 'MatchInterPodAffinity'}, + {'name': 'NoDiskConflict'}, + {'name': 'GeneralPredicates'}, + {'name': 'PodToleratesNodeTaints'}, + {'name': 'CheckNodeMemoryPressure'}, + {'name': 'CheckNodeDiskPressure'}, + ]) + + if short_version in ['3.7', '3.8', '3.9']: + predicates.extend([ + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'}, + {'name': 'MaxAzureDiskVolumeCount'}, + {'name': 'MatchInterPodAffinity'}, + {'name': 'NoDiskConflict'}, + {'name': 'GeneralPredicates'}, + {'name': 'PodToleratesNodeTaints'}, + {'name': 'CheckNodeMemoryPressure'}, + {'name': 'CheckNodeDiskPressure'}, + {'name': 'NoVolumeNodeConflict'}, + ]) + + if regions_enabled: + region_predicate = { + 'name': 'Region', + 'argument': { + 'serviceAffinity': { + 'labels': ['region'] + } + } + } + predicates.append(region_predicate) + + return predicates diff --git a/roles/lib_utils/lookup_plugins/openshift_master_facts_default_priorities.py b/roles/lib_utils/lookup_plugins/openshift_master_facts_default_priorities.py new file mode 100644 index 000000000..18e1b2e0c --- /dev/null +++ b/roles/lib_utils/lookup_plugins/openshift_master_facts_default_priorities.py @@ -0,0 +1,117 @@ +# pylint: disable=missing-docstring + +import re +from ansible.errors import AnsibleError +from ansible.plugins.lookup import LookupBase + + +class LookupModule(LookupBase): + # pylint: disable=too-many-branches,too-many-statements,too-many-arguments + + def run(self, terms, variables=None, zones_enabled=True, short_version=None, + deployment_type=None, **kwargs): + + priorities = [] + + if short_version is None or deployment_type is None: + if 'openshift' not in variables: + raise AnsibleError("This lookup module requires openshift_facts to be run prior to use") + + if deployment_type is None: + if 'common' not in variables['openshift'] or 'deployment_type' not in variables['openshift']['common']: + raise AnsibleError("This lookup module requires that the deployment_type be set") + + deployment_type = variables['openshift']['common']['deployment_type'] + + if short_version is None: + if 'short_version' in variables['openshift']['common']: + short_version = variables['openshift']['common']['short_version'] + elif 'openshift_release' in variables: + release = variables['openshift_release'] + if release.startswith('v'): + short_version = release[1:] + else: + short_version = release + short_version = '.'.join(short_version.split('.')[0:2]) + elif 'openshift_version' in variables: + version = variables['openshift_version'] + short_version = '.'.join(version.split('.')[0:2]) + else: + # pylint: disable=line-too-long + raise AnsibleError("Either OpenShift needs to be installed or openshift_release needs to be specified") + + if deployment_type == 'origin': + if short_version not in ['1.1', '1.2', '1.3', '1.4', '1.5', '3.6', '3.7', '3.8', '3.9', 'latest']: + raise AnsibleError("Unknown short_version %s" % short_version) + elif deployment_type == 'openshift-enterprise': + if short_version not in ['3.1', '3.2', '3.3', '3.4', '3.5', '3.6', '3.7', '3.8', '3.9', 'latest']: + raise AnsibleError("Unknown short_version %s" % short_version) + else: + raise AnsibleError("Unknown deployment_type %s" % deployment_type) + + if deployment_type == 'origin': + # convert short_version to origin short_version + short_version = re.sub('^1.', '3.', short_version) + + if short_version == 'latest': + short_version = '3.9' + + if short_version == '3.1': + priorities.extend([ + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'SelectorSpreadPriority', 'weight': 1} + ]) + + if short_version == '3.2': + priorities.extend([ + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'SelectorSpreadPriority', 'weight': 1}, + {'name': 'NodeAffinityPriority', 'weight': 1} + ]) + + if short_version == '3.3': + priorities.extend([ + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'SelectorSpreadPriority', 'weight': 1}, + {'name': 'NodeAffinityPriority', 'weight': 1}, + {'name': 'TaintTolerationPriority', 'weight': 1} + ]) + + if short_version == '3.4': + priorities.extend([ + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'SelectorSpreadPriority', 'weight': 1}, + {'name': 'NodePreferAvoidPodsPriority', 'weight': 10000}, + {'name': 'NodeAffinityPriority', 'weight': 1}, + {'name': 'TaintTolerationPriority', 'weight': 1}, + {'name': 'InterPodAffinityPriority', 'weight': 1} + ]) + + if short_version in ['3.5', '3.6', '3.7', '3.8', '3.9']: + priorities.extend([ + {'name': 'SelectorSpreadPriority', 'weight': 1}, + {'name': 'InterPodAffinityPriority', 'weight': 1}, + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'NodePreferAvoidPodsPriority', 'weight': 10000}, + {'name': 'NodeAffinityPriority', 'weight': 1}, + {'name': 'TaintTolerationPriority', 'weight': 1} + ]) + + if zones_enabled: + zone_priority = { + 'name': 'Zone', + 'argument': { + 'serviceAntiAffinity': { + 'label': 'zone' + } + }, + 'weight': 2 + } + priorities.append(zone_priority) + + return priorities diff --git a/roles/lib_utils/test/conftest.py b/roles/lib_utils/test/conftest.py new file mode 100644 index 000000000..aabdd4fa1 --- /dev/null +++ b/roles/lib_utils/test/conftest.py @@ -0,0 +1,172 @@ +# pylint: disable=missing-docstring,invalid-name,redefined-outer-name +import os +import pytest +import sys + +from OpenSSL import crypto + +sys.path.insert(1, os.path.join(os.path.dirname(__file__), os.pardir, "lookup_plugins")) + +from openshift_master_facts_default_predicates import LookupModule as PredicatesLookupModule # noqa: E402 +from openshift_master_facts_default_priorities import LookupModule as PrioritiesLookupModule # noqa: E402 + +# Parameter list for valid_cert fixture +VALID_CERTIFICATE_PARAMS = [ + { + 'short_name': 'client', + 'cn': 'client.example.com', + 'serial': 4, + 'uses': b'clientAuth', + 'dns': [], + 'ip': [], + }, + { + 'short_name': 'server', + 'cn': 'server.example.com', + 'serial': 5, + 'uses': b'serverAuth', + 'dns': ['kubernetes', 'openshift'], + 'ip': ['10.0.0.1', '192.168.0.1'] + }, + { + 'short_name': 'combined', + 'cn': 'combined.example.com', + # Verify that HUGE serials parse correctly. + # Frobs PARSING_HEX_SERIAL in _parse_cert + # See https://bugzilla.redhat.com/show_bug.cgi?id=1464240 + 'serial': 14449739080294792594019643629255165375, + 'uses': b'clientAuth, serverAuth', + 'dns': ['etcd'], + 'ip': ['10.0.0.2', '192.168.0.2'] + } +] + +# Extract the short_name from VALID_CERTIFICATE_PARAMS to provide +# friendly naming for the valid_cert fixture +VALID_CERTIFICATE_IDS = [param['short_name'] for param in VALID_CERTIFICATE_PARAMS] + + +@pytest.fixture(scope='session') +def ca(tmpdir_factory): + ca_dir = tmpdir_factory.mktemp('ca') + + key = crypto.PKey() + key.generate_key(crypto.TYPE_RSA, 2048) + + cert = crypto.X509() + cert.set_version(3) + cert.set_serial_number(1) + cert.get_subject().commonName = 'test-signer' + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(24 * 60 * 60) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(key) + cert.add_extensions([ + crypto.X509Extension(b'basicConstraints', True, b'CA:TRUE, pathlen:0'), + crypto.X509Extension(b'keyUsage', True, + b'digitalSignature, keyEncipherment, keyCertSign, cRLSign'), + crypto.X509Extension(b'subjectKeyIdentifier', False, b'hash', subject=cert) + ]) + cert.add_extensions([ + crypto.X509Extension(b'authorityKeyIdentifier', False, b'keyid:always', issuer=cert) + ]) + cert.sign(key, 'sha256') + + return { + 'dir': ca_dir, + 'key': key, + 'cert': cert, + } + + +@pytest.fixture(scope='session', + ids=VALID_CERTIFICATE_IDS, + params=VALID_CERTIFICATE_PARAMS) +def valid_cert(request, ca): + common_name = request.param['cn'] + + key = crypto.PKey() + key.generate_key(crypto.TYPE_RSA, 2048) + + cert = crypto.X509() + cert.set_serial_number(request.param['serial']) + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(24 * 60 * 60) + cert.set_issuer(ca['cert'].get_subject()) + cert.set_pubkey(key) + cert.set_version(3) + cert.get_subject().commonName = common_name + cert.add_extensions([ + crypto.X509Extension(b'basicConstraints', True, b'CA:FALSE'), + crypto.X509Extension(b'keyUsage', True, b'digitalSignature, keyEncipherment'), + crypto.X509Extension(b'extendedKeyUsage', False, request.param['uses']), + ]) + + if request.param['dns'] or request.param['ip']: + san_list = ['DNS:{}'.format(common_name)] + san_list.extend(['DNS:{}'.format(x) for x in request.param['dns']]) + san_list.extend(['IP:{}'.format(x) for x in request.param['ip']]) + + cert.add_extensions([ + crypto.X509Extension(b'subjectAltName', False, ', '.join(san_list).encode('utf8')) + ]) + cert.sign(ca['key'], 'sha256') + + cert_contents = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) + cert_file = ca['dir'].join('{}.crt'.format(common_name)) + cert_file.write_binary(cert_contents) + + return { + 'common_name': common_name, + 'serial': request.param['serial'], + 'dns': request.param['dns'], + 'ip': request.param['ip'], + 'uses': request.param['uses'], + 'cert_file': cert_file, + 'cert': cert + } + + +@pytest.fixture() +def predicates_lookup(): + return PredicatesLookupModule() + + +@pytest.fixture() +def priorities_lookup(): + return PrioritiesLookupModule() + + +@pytest.fixture() +def facts(): + return { + 'openshift': { + 'common': {} + } + } + + +@pytest.fixture(params=[True, False]) +def regions_enabled(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def zones_enabled(request): + return request.param + + +def v_prefix(release): + """Prefix a release number with 'v'.""" + return "v" + release + + +def minor(release): + """Add a suffix to release, making 'X.Y' become 'X.Y.Z'.""" + return release + ".1" + + +@pytest.fixture(params=[str, v_prefix, minor]) +def release_mod(request): + """Modifies a release string to alternative valid values.""" + return request.param diff --git a/roles/lib_utils/test/openshift_master_facts_bad_input_tests.py b/roles/lib_utils/test/openshift_master_facts_bad_input_tests.py new file mode 100644 index 000000000..e8da1e04a --- /dev/null +++ b/roles/lib_utils/test/openshift_master_facts_bad_input_tests.py @@ -0,0 +1,57 @@ +import copy +import os +import sys + +from ansible.errors import AnsibleError +import pytest + +sys.path.insert(1, os.path.join(os.path.dirname(__file__), os.pardir, "lookup_plugins")) + +from openshift_master_facts_default_predicates import LookupModule # noqa: E402 + + +class TestOpenShiftMasterFactsBadInput(object): + lookup = LookupModule() + default_facts = { + 'openshift': { + 'common': {} + } + } + + def test_missing_openshift_facts(self): + with pytest.raises(AnsibleError): + facts = {} + self.lookup.run(None, variables=facts) + + def test_missing_deployment_type(self): + with pytest.raises(AnsibleError): + facts = copy.deepcopy(self.default_facts) + facts['openshift']['common']['short_version'] = '10.10' + self.lookup.run(None, variables=facts) + + def test_missing_short_version_and_missing_openshift_release(self): + with pytest.raises(AnsibleError): + facts = copy.deepcopy(self.default_facts) + facts['openshift']['common']['deployment_type'] = 'origin' + self.lookup.run(None, variables=facts) + + def test_unknown_deployment_types(self): + with pytest.raises(AnsibleError): + facts = copy.deepcopy(self.default_facts) + facts['openshift']['common']['short_version'] = '1.1' + facts['openshift']['common']['deployment_type'] = 'bogus' + self.lookup.run(None, variables=facts) + + def test_unknown_origin_version(self): + with pytest.raises(AnsibleError): + facts = copy.deepcopy(self.default_facts) + facts['openshift']['common']['short_version'] = '0.1' + facts['openshift']['common']['deployment_type'] = 'origin' + self.lookup.run(None, variables=facts) + + def test_unknown_ocp_version(self): + with pytest.raises(AnsibleError): + facts = copy.deepcopy(self.default_facts) + facts['openshift']['common']['short_version'] = '0.1' + facts['openshift']['common']['deployment_type'] = 'openshift-enterprise' + self.lookup.run(None, variables=facts) diff --git a/roles/lib_utils/test/openshift_master_facts_conftest.py b/roles/lib_utils/test/openshift_master_facts_conftest.py new file mode 100644 index 000000000..140cced73 --- /dev/null +++ b/roles/lib_utils/test/openshift_master_facts_conftest.py @@ -0,0 +1,54 @@ +import os +import sys + +import pytest + +sys.path.insert(1, os.path.join(os.path.dirname(__file__), os.pardir, "lookup_plugins")) + +from openshift_master_facts_default_predicates import LookupModule as PredicatesLookupModule # noqa: E402 +from openshift_master_facts_default_priorities import LookupModule as PrioritiesLookupModule # noqa: E402 + + +@pytest.fixture() +def predicates_lookup(): + return PredicatesLookupModule() + + +@pytest.fixture() +def priorities_lookup(): + return PrioritiesLookupModule() + + +@pytest.fixture() +def facts(): + return { + 'openshift': { + 'common': {} + } + } + + +@pytest.fixture(params=[True, False]) +def regions_enabled(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def zones_enabled(request): + return request.param + + +def v_prefix(release): + """Prefix a release number with 'v'.""" + return "v" + release + + +def minor(release): + """Add a suffix to release, making 'X.Y' become 'X.Y.Z'.""" + return release + ".1" + + +@pytest.fixture(params=[str, v_prefix, minor]) +def release_mod(request): + """Modifies a release string to alternative valid values.""" + return request.param diff --git a/roles/lib_utils/test/openshift_master_facts_default_predicates_tests.py b/roles/lib_utils/test/openshift_master_facts_default_predicates_tests.py new file mode 100644 index 000000000..11aad9f03 --- /dev/null +++ b/roles/lib_utils/test/openshift_master_facts_default_predicates_tests.py @@ -0,0 +1,193 @@ +import pytest + + +# Predicates ordered according to OpenShift Origin source: +# origin/vendor/k8s.io/kubernetes/plugin/pkg/scheduler/algorithmprovider/defaults/defaults.go + +DEFAULT_PREDICATES_1_1 = [ + {'name': 'PodFitsHostPorts'}, + {'name': 'PodFitsResources'}, + {'name': 'NoDiskConflict'}, + {'name': 'MatchNodeSelector'}, +] + +DEFAULT_PREDICATES_1_2 = [ + {'name': 'PodFitsHostPorts'}, + {'name': 'PodFitsResources'}, + {'name': 'NoDiskConflict'}, + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MatchNodeSelector'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'} +] + +DEFAULT_PREDICATES_1_3 = [ + {'name': 'NoDiskConflict'}, + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'}, + {'name': 'GeneralPredicates'}, + {'name': 'PodToleratesNodeTaints'}, + {'name': 'CheckNodeMemoryPressure'} +] + +DEFAULT_PREDICATES_1_4 = [ + {'name': 'NoDiskConflict'}, + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'}, + {'name': 'GeneralPredicates'}, + {'name': 'PodToleratesNodeTaints'}, + {'name': 'CheckNodeMemoryPressure'}, + {'name': 'CheckNodeDiskPressure'}, + {'name': 'MatchInterPodAffinity'} +] + +DEFAULT_PREDICATES_1_5 = [ + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'}, + {'name': 'MatchInterPodAffinity'}, + {'name': 'NoDiskConflict'}, + {'name': 'GeneralPredicates'}, + {'name': 'PodToleratesNodeTaints'}, + {'name': 'CheckNodeMemoryPressure'}, + {'name': 'CheckNodeDiskPressure'}, +] + +DEFAULT_PREDICATES_3_6 = DEFAULT_PREDICATES_1_5 + +DEFAULT_PREDICATES_3_7 = [ + {'name': 'NoVolumeZoneConflict'}, + {'name': 'MaxEBSVolumeCount'}, + {'name': 'MaxGCEPDVolumeCount'}, + {'name': 'MaxAzureDiskVolumeCount'}, + {'name': 'MatchInterPodAffinity'}, + {'name': 'NoDiskConflict'}, + {'name': 'GeneralPredicates'}, + {'name': 'PodToleratesNodeTaints'}, + {'name': 'CheckNodeMemoryPressure'}, + {'name': 'CheckNodeDiskPressure'}, + {'name': 'NoVolumeNodeConflict'}, +] + +DEFAULT_PREDICATES_3_9 = DEFAULT_PREDICATES_3_8 = DEFAULT_PREDICATES_3_7 + +REGION_PREDICATE = { + 'name': 'Region', + 'argument': { + 'serviceAffinity': { + 'labels': ['region'] + } + } +} + +TEST_VARS = [ + ('1.1', 'origin', DEFAULT_PREDICATES_1_1), + ('3.1', 'openshift-enterprise', DEFAULT_PREDICATES_1_1), + ('1.2', 'origin', DEFAULT_PREDICATES_1_2), + ('3.2', 'openshift-enterprise', DEFAULT_PREDICATES_1_2), + ('1.3', 'origin', DEFAULT_PREDICATES_1_3), + ('3.3', 'openshift-enterprise', DEFAULT_PREDICATES_1_3), + ('1.4', 'origin', DEFAULT_PREDICATES_1_4), + ('3.4', 'openshift-enterprise', DEFAULT_PREDICATES_1_4), + ('1.5', 'origin', DEFAULT_PREDICATES_1_5), + ('3.5', 'openshift-enterprise', DEFAULT_PREDICATES_1_5), + ('3.6', 'origin', DEFAULT_PREDICATES_3_6), + ('3.6', 'openshift-enterprise', DEFAULT_PREDICATES_3_6), + ('3.7', 'origin', DEFAULT_PREDICATES_3_7), + ('3.7', 'openshift-enterprise', DEFAULT_PREDICATES_3_7), + ('3.8', 'origin', DEFAULT_PREDICATES_3_8), + ('3.8', 'openshift-enterprise', DEFAULT_PREDICATES_3_8), + ('3.9', 'origin', DEFAULT_PREDICATES_3_9), + ('3.9', 'openshift-enterprise', DEFAULT_PREDICATES_3_9), +] + + +def assert_ok(predicates_lookup, default_predicates, regions_enabled, **kwargs): + results = predicates_lookup.run(None, regions_enabled=regions_enabled, **kwargs) + if regions_enabled: + assert results == default_predicates + [REGION_PREDICATE] + else: + assert results == default_predicates + + +def test_openshift_version(predicates_lookup, openshift_version_fixture, regions_enabled): + facts, default_predicates = openshift_version_fixture + assert_ok(predicates_lookup, default_predicates, variables=facts, regions_enabled=regions_enabled) + + +@pytest.fixture(params=TEST_VARS) +def openshift_version_fixture(request, facts): + version, deployment_type, default_predicates = request.param + version += '.1' + facts['openshift_version'] = version + facts['openshift']['common']['deployment_type'] = deployment_type + return facts, default_predicates + + +def test_openshift_release(predicates_lookup, openshift_release_fixture, regions_enabled): + facts, default_predicates = openshift_release_fixture + assert_ok(predicates_lookup, default_predicates, variables=facts, regions_enabled=regions_enabled) + + +@pytest.fixture(params=TEST_VARS) +def openshift_release_fixture(request, facts, release_mod): + release, deployment_type, default_predicates = request.param + facts['openshift_release'] = release_mod(release) + facts['openshift']['common']['deployment_type'] = deployment_type + return facts, default_predicates + + +def test_short_version(predicates_lookup, short_version_fixture, regions_enabled): + facts, default_predicates = short_version_fixture + assert_ok(predicates_lookup, default_predicates, variables=facts, regions_enabled=regions_enabled) + + +@pytest.fixture(params=TEST_VARS) +def short_version_fixture(request, facts): + short_version, deployment_type, default_predicates = request.param + facts['openshift']['common']['short_version'] = short_version + facts['openshift']['common']['deployment_type'] = deployment_type + return facts, default_predicates + + +def test_short_version_kwarg(predicates_lookup, short_version_kwarg_fixture, regions_enabled): + facts, short_version, default_predicates = short_version_kwarg_fixture + assert_ok( + predicates_lookup, default_predicates, variables=facts, + regions_enabled=regions_enabled, short_version=short_version) + + +@pytest.fixture(params=TEST_VARS) +def short_version_kwarg_fixture(request, facts): + short_version, deployment_type, default_predicates = request.param + facts['openshift']['common']['deployment_type'] = deployment_type + return facts, short_version, default_predicates + + +def test_deployment_type_kwarg(predicates_lookup, deployment_type_kwarg_fixture, regions_enabled): + facts, deployment_type, default_predicates = deployment_type_kwarg_fixture + assert_ok( + predicates_lookup, default_predicates, variables=facts, + regions_enabled=regions_enabled, deployment_type=deployment_type) + + +@pytest.fixture(params=TEST_VARS) +def deployment_type_kwarg_fixture(request, facts): + short_version, deployment_type, default_predicates = request.param + facts['openshift']['common']['short_version'] = short_version + return facts, deployment_type, default_predicates + + +def test_short_version_deployment_type_kwargs( + predicates_lookup, short_version_deployment_type_kwargs_fixture, regions_enabled): + short_version, deployment_type, default_predicates = short_version_deployment_type_kwargs_fixture + assert_ok( + predicates_lookup, default_predicates, regions_enabled=regions_enabled, + short_version=short_version, deployment_type=deployment_type) + + +@pytest.fixture(params=TEST_VARS) +def short_version_deployment_type_kwargs_fixture(request): + return request.param diff --git a/roles/lib_utils/test/openshift_master_facts_default_priorities_tests.py b/roles/lib_utils/test/openshift_master_facts_default_priorities_tests.py new file mode 100644 index 000000000..527fc9ff4 --- /dev/null +++ b/roles/lib_utils/test/openshift_master_facts_default_priorities_tests.py @@ -0,0 +1,167 @@ +import pytest + + +DEFAULT_PRIORITIES_1_1 = [ + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'SelectorSpreadPriority', 'weight': 1} +] + +DEFAULT_PRIORITIES_1_2 = [ + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'SelectorSpreadPriority', 'weight': 1}, + {'name': 'NodeAffinityPriority', 'weight': 1} +] + +DEFAULT_PRIORITIES_1_3 = [ + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'SelectorSpreadPriority', 'weight': 1}, + {'name': 'NodeAffinityPriority', 'weight': 1}, + {'name': 'TaintTolerationPriority', 'weight': 1} +] + +DEFAULT_PRIORITIES_1_4 = [ + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'SelectorSpreadPriority', 'weight': 1}, + {'name': 'NodePreferAvoidPodsPriority', 'weight': 10000}, + {'name': 'NodeAffinityPriority', 'weight': 1}, + {'name': 'TaintTolerationPriority', 'weight': 1}, + {'name': 'InterPodAffinityPriority', 'weight': 1} +] + +DEFAULT_PRIORITIES_1_5 = [ + {'name': 'SelectorSpreadPriority', 'weight': 1}, + {'name': 'InterPodAffinityPriority', 'weight': 1}, + {'name': 'LeastRequestedPriority', 'weight': 1}, + {'name': 'BalancedResourceAllocation', 'weight': 1}, + {'name': 'NodePreferAvoidPodsPriority', 'weight': 10000}, + {'name': 'NodeAffinityPriority', 'weight': 1}, + {'name': 'TaintTolerationPriority', 'weight': 1} +] + +DEFAULT_PRIORITIES_3_6 = DEFAULT_PRIORITIES_1_5 + +DEFAULT_PRIORITIES_3_9 = DEFAULT_PRIORITIES_3_8 = DEFAULT_PRIORITIES_3_7 = DEFAULT_PRIORITIES_3_6 + +ZONE_PRIORITY = { + 'name': 'Zone', + 'argument': { + 'serviceAntiAffinity': { + 'label': 'zone' + } + }, + 'weight': 2 +} + +TEST_VARS = [ + ('1.1', 'origin', DEFAULT_PRIORITIES_1_1), + ('3.1', 'openshift-enterprise', DEFAULT_PRIORITIES_1_1), + ('1.2', 'origin', DEFAULT_PRIORITIES_1_2), + ('3.2', 'openshift-enterprise', DEFAULT_PRIORITIES_1_2), + ('1.3', 'origin', DEFAULT_PRIORITIES_1_3), + ('3.3', 'openshift-enterprise', DEFAULT_PRIORITIES_1_3), + ('1.4', 'origin', DEFAULT_PRIORITIES_1_4), + ('3.4', 'openshift-enterprise', DEFAULT_PRIORITIES_1_4), + ('1.5', 'origin', DEFAULT_PRIORITIES_1_5), + ('3.5', 'openshift-enterprise', DEFAULT_PRIORITIES_1_5), + ('3.6', 'origin', DEFAULT_PRIORITIES_3_6), + ('3.6', 'openshift-enterprise', DEFAULT_PRIORITIES_3_6), + ('3.7', 'origin', DEFAULT_PRIORITIES_3_7), + ('3.7', 'openshift-enterprise', DEFAULT_PRIORITIES_3_7), + ('3.8', 'origin', DEFAULT_PRIORITIES_3_8), + ('3.8', 'openshift-enterprise', DEFAULT_PRIORITIES_3_8), + ('3.9', 'origin', DEFAULT_PRIORITIES_3_9), + ('3.9', 'openshift-enterprise', DEFAULT_PRIORITIES_3_9), +] + + +def assert_ok(priorities_lookup, default_priorities, zones_enabled, **kwargs): + results = priorities_lookup.run(None, zones_enabled=zones_enabled, **kwargs) + if zones_enabled: + assert results == default_priorities + [ZONE_PRIORITY] + else: + assert results == default_priorities + + +def test_openshift_version(priorities_lookup, openshift_version_fixture, zones_enabled): + facts, default_priorities = openshift_version_fixture + assert_ok(priorities_lookup, default_priorities, variables=facts, zones_enabled=zones_enabled) + + +@pytest.fixture(params=TEST_VARS) +def openshift_version_fixture(request, facts): + version, deployment_type, default_priorities = request.param + version += '.1' + facts['openshift_version'] = version + facts['openshift']['common']['deployment_type'] = deployment_type + return facts, default_priorities + + +def test_openshift_release(priorities_lookup, openshift_release_fixture, zones_enabled): + facts, default_priorities = openshift_release_fixture + assert_ok(priorities_lookup, default_priorities, variables=facts, zones_enabled=zones_enabled) + + +@pytest.fixture(params=TEST_VARS) +def openshift_release_fixture(request, facts, release_mod): + release, deployment_type, default_priorities = request.param + facts['openshift_release'] = release_mod(release) + facts['openshift']['common']['deployment_type'] = deployment_type + return facts, default_priorities + + +def test_short_version(priorities_lookup, short_version_fixture, zones_enabled): + facts, default_priorities = short_version_fixture + assert_ok(priorities_lookup, default_priorities, variables=facts, zones_enabled=zones_enabled) + + +@pytest.fixture(params=TEST_VARS) +def short_version_fixture(request, facts): + short_version, deployment_type, default_priorities = request.param + facts['openshift']['common']['short_version'] = short_version + facts['openshift']['common']['deployment_type'] = deployment_type + return facts, default_priorities + + +def test_short_version_kwarg(priorities_lookup, short_version_kwarg_fixture, zones_enabled): + facts, short_version, default_priorities = short_version_kwarg_fixture + assert_ok( + priorities_lookup, default_priorities, variables=facts, + zones_enabled=zones_enabled, short_version=short_version) + + +@pytest.fixture(params=TEST_VARS) +def short_version_kwarg_fixture(request, facts): + short_version, deployment_type, default_priorities = request.param + facts['openshift']['common']['deployment_type'] = deployment_type + return facts, short_version, default_priorities + + +def test_deployment_type_kwarg(priorities_lookup, deployment_type_kwarg_fixture, zones_enabled): + facts, deployment_type, default_priorities = deployment_type_kwarg_fixture + assert_ok( + priorities_lookup, default_priorities, variables=facts, + zones_enabled=zones_enabled, deployment_type=deployment_type) + + +@pytest.fixture(params=TEST_VARS) +def deployment_type_kwarg_fixture(request, facts): + short_version, deployment_type, default_priorities = request.param + facts['openshift']['common']['short_version'] = short_version + return facts, deployment_type, default_priorities + + +def test_short_version_deployment_type_kwargs( + priorities_lookup, short_version_deployment_type_kwargs_fixture, zones_enabled): + short_version, deployment_type, default_priorities = short_version_deployment_type_kwargs_fixture + assert_ok( + priorities_lookup, default_priorities, zones_enabled=zones_enabled, + short_version=short_version, deployment_type=deployment_type) + + +@pytest.fixture(params=TEST_VARS) +def short_version_deployment_type_kwargs_fixture(request): + return request.param diff --git a/roles/lib_utils/test/test_fakeopensslclasses.py b/roles/lib_utils/test/test_fakeopensslclasses.py new file mode 100644 index 000000000..8a521a765 --- /dev/null +++ b/roles/lib_utils/test/test_fakeopensslclasses.py @@ -0,0 +1,90 @@ +''' + Unit tests for the FakeOpenSSL classes +''' +import os +import subprocess +import sys + +import pytest + +MODULE_PATH = os.path.realpath(os.path.join(__file__, os.pardir, os.pardir, 'library')) +sys.path.insert(1, MODULE_PATH) + +# pylint: disable=import-error,wrong-import-position,missing-docstring +# pylint: disable=invalid-name,redefined-outer-name +from openshift_cert_expiry import FakeOpenSSLCertificate # noqa: E402 + + +@pytest.fixture(scope='module') +def fake_valid_cert(valid_cert): + cmd = ['openssl', 'x509', '-in', str(valid_cert['cert_file']), '-text', + '-nameopt', 'oneline'] + cert = subprocess.check_output(cmd) + return FakeOpenSSLCertificate(cert.decode('utf8')) + + +def test_not_after(valid_cert, fake_valid_cert): + ''' Validate value returned back from get_notAfter() ''' + real_cert = valid_cert['cert'] + + # Internal representation of pyOpenSSL is bytes, while FakeOpenSSLCertificate + # is text, so decode the result from pyOpenSSL prior to comparing + assert real_cert.get_notAfter().decode('utf8') == fake_valid_cert.get_notAfter() + + +def test_serial(valid_cert, fake_valid_cert): + ''' Validate value returned back form get_serialnumber() ''' + real_cert = valid_cert['cert'] + assert real_cert.get_serial_number() == fake_valid_cert.get_serial_number() + + +def test_get_subject(valid_cert, fake_valid_cert): + ''' Validate the certificate subject ''' + + # Gather the subject components and create a list of colon separated strings. + # Since the internal representation of pyOpenSSL uses bytes, we need to decode + # the results before comparing. + c_subjects = valid_cert['cert'].get_subject().get_components() + c_subj = ', '.join(['{}:{}'.format(x.decode('utf8'), y.decode('utf8')) for x, y in c_subjects]) + f_subjects = fake_valid_cert.get_subject().get_components() + f_subj = ', '.join(['{}:{}'.format(x, y) for x, y in f_subjects]) + assert c_subj == f_subj + + +def get_san_extension(cert): + # Internal representation of pyOpenSSL is bytes, while FakeOpenSSLCertificate + # is text, so we need to set the value to search for accordingly. + if isinstance(cert, FakeOpenSSLCertificate): + san_short_name = 'subjectAltName' + else: + san_short_name = b'subjectAltName' + + for i in range(cert.get_extension_count()): + ext = cert.get_extension(i) + if ext.get_short_name() == san_short_name: + # return the string representation to compare the actual SAN + # values instead of the data types + return str(ext) + + return None + + +def test_subject_alt_names(valid_cert, fake_valid_cert): + real_cert = valid_cert['cert'] + + san = get_san_extension(real_cert) + f_san = get_san_extension(fake_valid_cert) + + assert san == f_san + + # If there are either dns or ip sans defined, verify common_name present + if valid_cert['ip'] or valid_cert['dns']: + assert 'DNS:' + valid_cert['common_name'] in f_san + + # Verify all ip sans are present + for ip in valid_cert['ip']: + assert 'IP Address:' + ip in f_san + + # Verify all dns sans are present + for name in valid_cert['dns']: + assert 'DNS:' + name in f_san diff --git a/roles/lib_utils/test/test_load_and_handle_cert.py b/roles/lib_utils/test/test_load_and_handle_cert.py new file mode 100644 index 000000000..98792e2ee --- /dev/null +++ b/roles/lib_utils/test/test_load_and_handle_cert.py @@ -0,0 +1,67 @@ +''' + Unit tests for the load_and_handle_cert method +''' +import datetime +import os +import sys + +import pytest + +MODULE_PATH = os.path.realpath(os.path.join(__file__, os.pardir, os.pardir, 'library')) +sys.path.insert(1, MODULE_PATH) + +# pylint: disable=import-error,wrong-import-position,missing-docstring +# pylint: disable=invalid-name,redefined-outer-name +import openshift_cert_expiry # noqa: E402 + +# TODO: More testing on the results of the load_and_handle_cert function +# could be implemented here as well, such as verifying subjects +# match up. + + +@pytest.fixture(params=['OpenSSLCertificate', 'FakeOpenSSLCertificate']) +def loaded_cert(request, valid_cert): + """ parameterized fixture to provide load_and_handle_cert results + for both OpenSSL and FakeOpenSSL parsed certificates + """ + now = datetime.datetime.now() + + openshift_cert_expiry.HAS_OPENSSL = request.param == 'OpenSSLCertificate' + + # valid_cert['cert_file'] is a `py.path.LocalPath` object and + # provides a read_text() method for reading the file contents. + cert_string = valid_cert['cert_file'].read_text('utf8') + + (subject, + expiry_date, + time_remaining, + serial) = openshift_cert_expiry.load_and_handle_cert(cert_string, now) + + return { + 'now': now, + 'subject': subject, + 'expiry_date': expiry_date, + 'time_remaining': time_remaining, + 'serial': serial, + } + + +def test_serial(loaded_cert, valid_cert): + """Params: + + * `loaded_cert` comes from the `loaded_cert` fixture in this file + * `valid_cert` comes from the 'valid_cert' fixture in conftest.py + """ + valid_cert_serial = valid_cert['cert'].get_serial_number() + assert loaded_cert['serial'] == valid_cert_serial + + +def test_expiry(loaded_cert): + """Params: + + * `loaded_cert` comes from the `loaded_cert` fixture in this file + """ + expiry_date = loaded_cert['expiry_date'] + time_remaining = loaded_cert['time_remaining'] + now = loaded_cert['now'] + assert expiry_date == now + time_remaining |