From f502b09c103b5d8681854b7ab6a3c9655311f73b Mon Sep 17 00:00:00 2001 From: Rodolfo Carvalho Date: Mon, 6 Feb 2017 17:06:18 +0100 Subject: Refactor code to access values from task_vars --- .../openshift_checks/__init__.py | 18 ++++++++++++++++++ .../openshift_checks/mixins.py | 7 ++----- .../openshift_checks/package_availability.py | 10 +++------- .../openshift_checks/package_version.py | 7 ++----- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/roles/openshift_health_checker/openshift_checks/__init__.py b/roles/openshift_health_checker/openshift_checks/__init__.py index ff99e7b4c..c31242624 100644 --- a/roles/openshift_health_checker/openshift_checks/__init__.py +++ b/roles/openshift_health_checker/openshift_checks/__init__.py @@ -5,8 +5,10 @@ Health checks for OpenShift clusters. import os from abc import ABCMeta, abstractmethod, abstractproperty from importlib import import_module +import operator import six +from six.moves import reduce class OpenShiftCheckException(Exception): @@ -54,6 +56,22 @@ class OpenShiftCheck(object): yield subclass +def get_var(task_vars, *keys, **kwargs): + """Helper function to get deeply nested values from task_vars. + + Ansible task_vars structures are Python dicts, often mapping strings to + other dicts. This helper makes it easier to get a nested value, raising + OpenShiftCheckException when a key is not found. + """ + try: + value = reduce(operator.getitem, keys, task_vars) + except (KeyError, TypeError): + if "default" in kwargs: + return kwargs["default"] + raise OpenShiftCheckException("'{}' is undefined".format(".".join(map(str, keys)))) + return value + + # Dynamically import all submodules for the side effect of loading checks. EXCLUDES = ( diff --git a/roles/openshift_health_checker/openshift_checks/mixins.py b/roles/openshift_health_checker/openshift_checks/mixins.py index 4e0415944..4029fba62 100644 --- a/roles/openshift_health_checker/openshift_checks/mixins.py +++ b/roles/openshift_health_checker/openshift_checks/mixins.py @@ -1,5 +1,5 @@ # pylint: disable=missing-docstring -from openshift_checks import OpenShiftCheckException +from openshift_checks import get_var class NotContainerized(object): @@ -18,7 +18,4 @@ class NotContainerized(object): @staticmethod def is_containerized(task_vars): - try: - return task_vars["openshift"]["common"]["is_containerized"] - except (KeyError, TypeError): - raise OpenShiftCheckException("'openshift.common.is_containerized' is undefined") + return get_var(task_vars, "openshift", "common", "is_containerized") diff --git a/roles/openshift_health_checker/openshift_checks/package_availability.py b/roles/openshift_health_checker/openshift_checks/package_availability.py index 31277a3b9..8faeef5ee 100644 --- a/roles/openshift_health_checker/openshift_checks/package_availability.py +++ b/roles/openshift_health_checker/openshift_checks/package_availability.py @@ -1,5 +1,5 @@ # pylint: disable=missing-docstring -from openshift_checks import OpenShiftCheck, OpenShiftCheckException +from openshift_checks import OpenShiftCheck, get_var from openshift_checks.mixins import NotContainerized @@ -10,12 +10,8 @@ class PackageAvailability(NotContainerized, OpenShiftCheck): tags = ["preflight"] def run(self, tmp, task_vars): - try: - rpm_prefix = task_vars["openshift"]["common"]["service_type"] - except (KeyError, TypeError): - raise OpenShiftCheckException("'openshift.common.service_type' is undefined") - - group_names = task_vars.get("group_names", []) + rpm_prefix = get_var(task_vars, "openshift", "common", "service_type") + group_names = get_var(task_vars, "group_names", default=[]) packages = set() diff --git a/roles/openshift_health_checker/openshift_checks/package_version.py b/roles/openshift_health_checker/openshift_checks/package_version.py index 9394466f2..b31b4d401 100644 --- a/roles/openshift_health_checker/openshift_checks/package_version.py +++ b/roles/openshift_health_checker/openshift_checks/package_version.py @@ -1,5 +1,5 @@ # pylint: disable=missing-docstring -from openshift_checks import OpenShiftCheck, OpenShiftCheckException +from openshift_checks import OpenShiftCheck, get_var from openshift_checks.mixins import NotContainerized @@ -17,10 +17,7 @@ class PackageVersion(NotContainerized, OpenShiftCheck): ) def run(self, tmp, task_vars): - try: - openshift_release = task_vars["openshift_release"] - except (KeyError, TypeError): - raise OpenShiftCheckException("'openshift_release' is undefined") + openshift_release = get_var(task_vars, "openshift_release") args = {"version": openshift_release} return self.module_executor("aos_version", args, tmp, task_vars) -- cgit v1.2.3