diff options
Diffstat (limited to 'roles/os_firewall/library')
-rw-r--r-- | roles/os_firewall/library/os_firewall_manage_iptables.py | 62 |
1 files changed, 40 insertions, 22 deletions
diff --git a/roles/os_firewall/library/os_firewall_manage_iptables.py b/roles/os_firewall/library/os_firewall_manage_iptables.py index fef710055..6a018d022 100644 --- a/roles/os_firewall/library/os_firewall_manage_iptables.py +++ b/roles/os_firewall/library/os_firewall_manage_iptables.py @@ -51,11 +51,13 @@ class IpTablesCreateJumpRuleError(IpTablesError): # exception was thrown later. for example, when the chain is created # successfully, but the add/remove rule fails. class IpTablesManager: - def __init__(self, module, ip_version, check_mode, chain): + def __init__(self, module): self.module = module - self.ip_version = ip_version - self.check_mode = check_mode - self.chain = chain + self.ip_version = module.params['ip_version'] + self.check_mode = module.check_mode + self.chain = module.params['chain'] + self.create_jump_rule = module.params['create_jump_rule'] + self.jump_rule_chain = module.params['jump_rule_chain'] self.cmd = self.gen_cmd() self.save_cmd = self.gen_save_cmd() self.output = [] @@ -70,13 +72,16 @@ class IpTablesManager: msg="Failed to save iptables rules", cmd=e.cmd, exit_code=e.returncode, output=e.output) + def verify_chain(self): + if not self.chain_exists(): + self.create_chain() + if self.create_jump_rule and not self.jump_rule_exists(): + self.create_jump() + def add_rule(self, port, proto): rule = self.gen_rule(port, proto) if not self.rule_exists(rule): - if not self.chain_exists(): - self.create_chain() - if not self.jump_rule_exists(): - self.create_jump_rule() + self.verify_chain() if self.check_mode: self.changed = True @@ -121,13 +126,13 @@ class IpTablesManager: return [self.chain, '-p', proto, '-m', 'state', '--state', 'NEW', '-m', proto, '--dport', str(port), '-j', 'ACCEPT'] - def create_jump_rule(self): + def create_jump(self): if self.check_mode: self.changed = True self.output.append("Create jump rule for chain %s" % self.chain) else: try: - cmd = self.cmd + ['-L', 'INPUT', '--line-numbers'] + cmd = self.cmd + ['-L', self.jump_rule_chain, '--line-numbers'] output = check_output(cmd, stderr=subprocess.STDOUT) # break the input rules into rows and columns @@ -144,11 +149,11 @@ class IpTablesManager: continue last_rule_target = rule[1] - # Raise an exception if we do not find a valid INPUT rule + # Raise an exception if we do not find a valid rule if not last_rule_num or not last_rule_target: raise IpTablesCreateJumpRuleError( chain=self.chain, - msg="Failed to find existing INPUT rules", + msg="Failed to find existing %s rules" % self.jump_rule_chain, cmd=None, exit_code=None, output=None) # Naively assume that if the last row is a REJECT rule, then @@ -156,19 +161,20 @@ class IpTablesManager: # assume that we can just append the rule. if last_rule_target == 'REJECT': # insert rule - cmd = self.cmd + ['-I', 'INPUT', str(last_rule_num)] + cmd = self.cmd + ['-I', self.jump_rule_chain, str(last_rule_num)] else: # append rule - cmd = self.cmd + ['-A', 'INPUT'] + cmd = self.cmd + ['-A', self.jump_rule_chain] cmd += ['-j', self.chain] output = check_output(cmd, stderr=subprocess.STDOUT) changed = True self.output.append(output) + self.save() except subprocess.CalledProcessError as e: if '--line-numbers' in e.cmd: raise IpTablesCreateJumpRuleError( chain=self.chain, - msg="Failed to query existing INPUT rules to " + msg="Failed to query existing %s rules to " % self.jump_rule_chain + "determine jump rule location", cmd=e.cmd, exit_code=e.returncode, output=e.output) @@ -192,6 +198,7 @@ class IpTablesManager: self.changed = True self.output.append("Successfully created chain %s" % self.chain) + self.save() except subprocess.CalledProcessError as e: raise IpTablesCreateChainError( chain=self.chain, @@ -200,7 +207,7 @@ class IpTablesManager: ) def jump_rule_exists(self): - cmd = self.cmd + ['-C', 'INPUT', '-j', self.chain] + cmd = self.cmd + ['-C', self.jump_rule_chain, '-j', self.chain] return True if subprocess.call(cmd) == 0 else False def chain_exists(self): @@ -220,9 +227,12 @@ def main(): module = AnsibleModule( argument_spec=dict( name=dict(required=True), - action=dict(required=True, choices=['add', 'remove']), - protocol=dict(required=True, choices=['tcp', 'udp']), - port=dict(required=True, type='int'), + action=dict(required=True, choices=['add', 'remove', 'verify_chain']), + chain=dict(required=False, default='OS_FIREWALL_ALLOW'), + create_jump_rule=dict(required=False, type='bool', default=True), + jump_rule_chain=dict(required=False, default='INPUT'), + protocol=dict(required=False, choices=['tcp', 'udp']), + port=dict(required=False, type='int'), ip_version=dict(required=False, default='ipv4', choices=['ipv4', 'ipv6']), ), @@ -232,16 +242,24 @@ def main(): action = module.params['action'] protocol = module.params['protocol'] port = module.params['port'] - ip_version = module.params['ip_version'] - chain = 'OS_FIREWALL_ALLOW' - iptables_manager = IpTablesManager(module, ip_version, module.check_mode, chain) + if action in ['add', 'remove']: + if not protocol: + error = "protocol is required when action is %s" % action + module.fail_json(msg=error) + if not port: + error = "port is required when action is %s" % action + module.fail_json(msg=error) + + iptables_manager = IpTablesManager(module) try: if action == 'add': iptables_manager.add_rule(port, protocol) elif action == 'remove': iptables_manager.remove_rule(port, protocol) + elif action == 'verify_chain': + iptables_manager.verify_chain() except IpTablesError as e: module.fail_json(msg=e.msg) |