#!/usr/bin/python3
###############################################################################
#                                                                             #
# IPFire.org - A linux based firewall                                         #
# Copyright (C) 2016 Michael Tremer                                           #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the                #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program. If not, see <http://www.gnu.org/licenses/>.        #
#                                                                             #
###############################################################################

"""
Bridge between DHCP leases and Unbound.

Reads dynamic and fixed DHCP leases, writes them as "local-data" records
into an Unbound configuration file and applies single lease events that
are received over a UNIX socket to the running Unbound via unbound-control.
"""

import argparse
import datetime
import daemon
import filecmp
import functools
import ipaddress
import logging
import logging.handlers
import os
import queue
import re
import signal
import socket
import stat
import subprocess
import sys
import tempfile
import threading

# TTL (as written into every generated resource record)
LOCAL_TTL = 60

log = logging.getLogger("dhcp")
log.setLevel(logging.DEBUG)


def setup_logging(daemon=True, loglevel=logging.INFO):
    """
    Configures the "dhcp" logger and returns it.

    Messages are always sent to syslog (/dev/log, facility "daemon").
    If daemon is false, they are additionally written to the console.
    """
    log.setLevel(loglevel)

    # Log to syslog by default
    handler = logging.handlers.SysLogHandler(address="/dev/log", facility="daemon")
    log.addHandler(handler)

    # Format everything
    formatter = logging.Formatter("%(name)s[%(process)d]: %(message)s")
    handler.setFormatter(formatter)

    handler.setLevel(loglevel)

    # If we are running in foreground, we should write everything to the console, too
    if not daemon:
        handler = logging.StreamHandler()
        log.addHandler(handler)

        handler.setLevel(loglevel)

    return log


class UnboundDHCPLeasesBridge(object):
    """
    Accepts lease events on a UNIX socket and keeps Unbound in sync
    with the known DHCP leases.
    """

    def __init__(self, dhcp_leases_file, fix_leases_file, unbound_leases_file, hosts_file, socket_path):
        self.leases_file = dhcp_leases_file
        self.fix_leases_file = fix_leases_file
        self.hosts_file = hosts_file
        self.socket_path = socket_path

        self.socket = None

        # Static hosts (FQDN -> list of addresses). This is populated by reload(),
        # but must exist before so that messages which arrive before the first
        # reload do not fail with an AttributeError in _add_lease().
        self.hosts = {}

        # Store all known leases
        self.leases = set()

        # Create a queue for all received events
        self.queue = queue.Queue()

        # Initialize the worker
        self.worker = Worker(self.queue, callback=self._handle_message)

        # Initialize the watcher
        self.watcher = Watcher(reload=self.reload)

        self.unbound = UnboundConfigWriter(unbound_leases_file)

    def run(self):
        """
        Runs the main loop until the server socket has been closed
        (see terminate()) and then shuts down the worker and the watcher.
        """
        log.info("Unbound DHCP Leases Bridge started on %s" % self.leases_file)

        # Open the server socket first. If this fails, SystemExit is raised and
        # no (non-daemon) threads have been launched yet which would otherwise
        # keep the process alive forever.
        self.socket = self._open_socket(self.socket_path)

        # Launch the worker
        self.worker.start()

        # Launch the watcher
        self.watcher.start()

        while True:
            # Accept any incoming connections
            # (we land in the except branch once terminate() has closed the socket)
            try:
                conn, peer = self.socket.accept()
            except OSError:
                break

            try:
                # Receive what the client is sending
                data, ancillary_data, flags, address = conn.recvmsg(4096)

                # Log that we have received some data
                log.debug("Received message of %s byte(s)" % len(data))

                # Decode the data
                message = self._decode_message(data)

                # Add the message to the queue
                self.queue.put(message)

                conn.send(b"OK\n")

            # Send ERROR to the client if something went wrong
            except Exception as e:
                log.error("Could not handle message: %s" % e)

                # The client might be gone already. That must not
                # bring down the whole daemon.
                try:
                    conn.send(b"ERROR\n")
                except OSError as send_error:
                    log.debug("Could not send error to client: %s" % send_error)

                continue

            # Close the connection
            finally:
                conn.close()

        # Terminate the worker
        self.queue.put(None)

        # Terminate the watcher
        self.watcher.terminate()

        # Wait for the worker and watcher to finish
        self.worker.join()
        self.watcher.join()

        log.info("Unbound DHCP Leases Bridge terminated")

    def _open_socket(self, path):
        """
        Creates a listening UNIX stream socket at path and returns it.

        Raises SystemExit if the socket could not be bound.
        """
        # Allocate a new socket
        s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)

        # Unlink any old sockets
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass

        # Bind the socket (to the path that we have just cleaned up)
        try:
            s.bind(path)
        except OSError as e:
            log.error("Could not open socket at %s: %s" % (path, e))

            raise SystemExit(1) from e

        # Listen
        s.listen(128)

        return s

    def _decode_message(self, data):
        """
        Decodes the received bytes into a dictionary.

        Every non-empty line must have the format KEY=VALUE.
        Raises UnicodeError if a line cannot be decoded and
        ValueError if a line does not contain "=".
        """
        message = {}

        for line in data.splitlines():
            # Skip empty lines
            if not line:
                continue

            # Try to decode the line
            try:
                line = line.decode()
            except UnicodeError as e:
                log.error("Could not decode %r: %s" % (line, e))
                raise

            # Split the line
            key, separator, value = line.partition("=")

            # Fail if the line does not have a value
            if not separator:
                raise ValueError("No value given")

            # Store the attributes
            message[key] = value

        return message

    def _handle_message(self, message):
        """
        Handles one decoded message (called by the worker).

        Supported values for EVENT are "commit", "release" and "expiry".
        Raises ValueError if EVENT is missing or not supported.
        """
        log.debug("Handling message:")
        for key in message:
            log.debug(" %-20s = %s" % (key, message[key]))

        # Extract the event type
        event = message.get("EVENT")

        # Check if event is set
        if not event:
            raise ValueError("The message does not have EVENT set")

        # COMMIT
        elif event == "commit":
            address = message.get("ADDRESS")
            name = message.get("NAME")

            # Find the old lease
            old_lease = self._find_lease(address)

            # Don't update fixed leases as they might clear the hostname
            if old_lease and old_lease.fixed:
                log.debug("Won't update fixed lease %s" % old_lease)
                return

            # Create a new lease
            lease = Lease(address, {
                "client-hostname" : name,
            })

            # Leases are hashed and compared by their IP address only. Therefore
            # the old lease has to be dropped from the set BEFORE the new one is
            # added, because removing it afterwards would remove the new lease.
            if old_lease:
                self._remove_lease(old_lease)

            self._add_lease(lease)

            # Can we skip the update?
            if old_lease:
                if lease.rrset == old_lease.rrset:
                    log.debug("Won't update %s as nothing has changed" % lease)
                    return

                # Remove the old lease first
                self.unbound.remove_lease(old_lease)

            # Apply the lease
            self.unbound.apply_lease(lease)

        # RELEASE/EXPIRY
        elif event in ("release", "expiry"):
            address = message.get("ADDRESS")

            # Find the lease
            lease = self._find_lease(address)

            if not lease:
                log.warning("Could not find lease for %s" % address)
                return

            # Remove the lease
            self.unbound.remove_lease(lease)
            self._remove_lease(lease)

        # Raise an error if the event is not supported
        else:
            raise ValueError("Unsupported event: %s" % event)

    def update_dhcp_leases(self):
        """
        Re-reads all dynamic and fixed leases from disk and
        passes all leases that have not expired on to Unbound.
        """
        # Drop all known leases
        self.leases.clear()

        # Add all dynamic leases
        for lease in DHCPLeases(self.leases_file):
            self._add_lease(lease)

        # Add all static leases
        for lease in FixLeases(self.fix_leases_file):
            self._add_lease(lease)

        # Dump leases
        if self.leases:
            log.debug("DHCP Leases:")
            for lease in self.leases:
                log.debug(" %s:" % lease.fqdn)
                log.debug(" Start: %s" % lease.time_starts)
                log.debug(" End : %s" % lease.time_ends)
                if lease.has_expired():
                    log.debug(" Expired")

        self.unbound.update_dhcp_leases([l for l in self.leases if not l.has_expired()])

    def _add_lease(self, lease):
        """
        Stores the lease unless it has no FQDN, a static host
        with the same FQDN exists, or it has expired.
        """
        # Skip leases without a FQDN
        if not lease.fqdn:
            log.debug("Skipping lease without a FQDN: %s" % lease)
            return

        # Skip any leases that also are a static host
        elif lease.fqdn in self.hosts:
            log.debug("Skipping lease for which a static host exists: %s" % lease)
            return

        # Don't add expired leases
        elif lease.has_expired():
            log.debug("Skipping expired lease: %s" % lease)
            return

        # Remove any previous leases (i.e. the one for the same IP address)
        self._remove_lease(lease)

        # Store the lease
        self.leases.add(lease)

    def _find_lease(self, ipaddr):
        """
        Returns the lease with the specified IP address
        or None if there is no such lease.
        """
        if not isinstance(ipaddr, ipaddress.IPv4Address):
            ipaddr = ipaddress.IPv4Address(ipaddr)

        for lease in self.leases:
            if lease.ipaddr == ipaddr:
                return lease

        return None

    def _remove_lease(self, lease):
        """
        Removes the lease from the set of known leases (if it is in there).
        """
        # discard() does not raise if the lease is unknown
        self.leases.discard(lease)

    def read_static_hosts(self):
        """
        Reads the static hosts file and returns a dictionary which maps
        the name of every enabled host to a sorted list of its addresses.
        """
        log.info("Reading static hosts from %s" % self.hosts_file)

        hosts = {}
        with open(self.hosts_file) as f:
            for line in f:
                line = line.rstrip()

                try:
                    enabled, ipaddr, hostname, domainname, generateptr = line.split(",")
                except ValueError:
                    log.warning("Could not parse line: %s" % line)
                    continue

                # Skip any disabled entries
                if not enabled == "on":
                    continue

                if hostname and domainname:
                    fqdn = "%s.%s" % (hostname, domainname)
                elif hostname:
                    fqdn = hostname
                elif domainname:
                    fqdn = domainname

                # There is nothing we could store this entry as
                else:
                    log.warning("Could not parse line: %s" % line)
                    continue

                addresses = hosts.setdefault(fqdn, [])
                addresses.append(ipaddr)
                addresses.sort()

        # Dump everything in the logs
        log.debug("Static hosts:")
        for name in hosts:
            log.debug(" %-20s : %s" % (name, ", ".join(hosts[name])))

        return hosts

    def reload(self, *args, **kwargs):
        """
        Re-reads everything from disk. All arguments are ignored so that
        this can be used as a signal handler as well as a plain callback.
        """
        # Read all static hosts
        self.hosts = self.read_static_hosts()

        # Unconditionally update all leases and reload Unbound
        self.update_dhcp_leases()

    def terminate(self, *args, **kwargs):
        """
        Closes the server socket which makes run() leave its main loop.
        All arguments are ignored so that this can be used as a signal handler.
        """
        # Close the socket
        if self.socket:
            self.socket.close()


class Watcher(threading.Thread):
    """
    Watches if Unbound is still running.
    """
    def __init__(self, reload, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.reload = reload

        # Set to true if this thread should be terminated
        self._terminated = threading.Event()

    def run(self):
        log.debug("Watcher launched")

        pidfd = None

        try:
            while True:
                # One iteration takes 30 seconds unless we don't know the process
                # when we try to find it once a second.
                if self._terminated.wait(30 if pidfd is not None else 1):
                    break

                # Fetch a PIDFD for Unbound
                if pidfd is None:
                    pidfd = self._get_pidfd()

                    # If we could not acquire a PIDFD, we will try again soon...
                    if pidfd is None:
                        log.warning("Cannot find Unbound...")
                        continue

                    # Since Unbound has been restarted, we need to reload it all...
                    # A failure is logged, but must not end this thread as
                    # nobody would notice any later restarts of Unbound.
                    try:
                        self.reload()
                    except Exception as e:
                        log.error("Reload failed: %s" % e, exc_info=True)

                log.debug("Checking if Unbound is still alive...")

                # Send the process a signal
                # (SIG_DFL has the value 0 which does not deliver anything)
                try:
                    signal.pidfd_send_signal(pidfd, signal.SIG_DFL)

                # If the process has died, we land here and will have to wait until Unbound
                # has come back and reload it...
                except ProcessLookupError:
                    log.error("Unbound has died")

                    # Release and reset the PIDFD
                    os.close(pidfd)
                    pidfd = None

                else:
                    log.debug("Unbound is alive")

        # Don't leak the file descriptor when leaving
        finally:
            if pidfd is not None:
                os.close(pidfd)

        log.debug("Watcher terminated")

    def terminate(self):
        """
        Called to signal this thread to terminate
        """
        self._terminated.set()

    def _get_pidfd(self):
        """
        Returns a PIDFD for unbound if it is running, otherwise None.
        """
        # Try to find the PID
        pid = pidof("unbound")

        if pid:
            log.debug("Unbound is running as PID %s" % pid)

            # Open a PIDFD
            try:
                pidfd = os.pidfd_open(pid)

            # The process could have gone away in the meantime
            except ProcessLookupError:
                return None

            log.debug("Acquired PIDFD %s for PID %s" % (pidfd, pid))

            return pidfd

        return None


class Worker(threading.Thread):
    """
    The worker is launched in a separate thread
    which allows us to perform some tasks asynchronously.
    """
    def __init__(self, queue, callback):
        super().__init__()

        self.queue = queue
        self.callback = callback

    def run(self):
        log.debug("Worker %s launched" % self.native_id)

        while True:
            message = self.queue.get()

            # If the message is None, we have to quit
            if message is None:
                break

            # Call the callback
            try:
                self.callback(message)
            except Exception as e:
                log.error("Callback failed: %s" % e, exc_info=True)

        log.debug("Worker %s terminated" % self.native_id)


class DHCPLeases(object):
    """
    Iterable over all leases that could be read from a DHCP leases file.
    """
    # Matches one "lease <address> { ... }" block
    regex_leaseblock = re.compile(
        r"lease (?P<ipaddr>\d+\.\d+\.\d+\.\d+) \{(?P<config>[\s\S]+?)\n\}")

    def __init__(self, path):
        self.path = path

        self._leases = self._parse()

    def __iter__(self):
        return iter(self._leases)

    def _parse(self):
        """
        Reads the leases file and returns a list of Lease objects in file order.
        """
        log.info("Reading DHCP leases from %s" % self.path)

        leases = []

        with open(self.path) as f:
            # Read entire leases file
            data = f.read()

        for match in self.regex_leaseblock.finditer(data):
            block = match.groupdict()

            ipaddr = block.get("ipaddr")
            config = block.get("config")

            properties = self._parse_block(config)

            # Skip any abandoned leases
            if "hardware" not in properties:
                continue

            # Skip inactive leases
            # (leases that do not carry a binding state are considered active)
            elif properties.get("binding", "state active") != "state active":
                continue

            lease = Lease(ipaddr, properties)
            leases.append(lease)

        return leases

    def _parse_block(self, block):
        """
        Parses the body of one lease block and returns its statements as a
        dictionary (first word -> rest of the statement, without the ";").
        """
        properties = {}

        for line in block.splitlines():
            if not line:
                continue

            # Remove trailing ; from line
            if line.endswith(";"):
                line = line[:-1]

            # Invalid line if it doesn't end with ;
            else:
                continue

            # Remove any leading whitespace
            line = line.lstrip()

            # We skip all options and sets
            if line.startswith(("option", "set")):
                continue

            # Split by first space. Statements that consist of a single
            # word only (e.g. "abandoned") are stored with an empty value.
            key, _, val = line.partition(" ")
            properties[key] = val

        return properties


class FixLeases(object):
    """
    Iterable over all enabled leases of the fix leases file.
    """
    def __init__(self, path):
        self.path = path

        self._leases = self._parse()

    def __iter__(self):
        return iter(self._leases)

    def _parse(self):
        """
        Reads the fix leases file and returns a list of Lease objects
        which start now, never end and are marked as fixed.
        """
        log.info("Reading fix leases from %s" % self.path)

        now = datetime.datetime.utcnow()

        leases = []
        with open(self.path) as f:
            for line in f:
                line = line.rstrip()

                try:
                    hwaddr, ipaddr, enabled, a, b, c, hostname = line.split(",")
                except ValueError:
                    log.warning("Could not parse line: %s" % line)
                    continue

                # Skip any disabled leases
                if not enabled == "on":
                    continue

                # One invalid IP address must not make us lose all other leases
                try:
                    l = Lease(ipaddr, {
                        "binding" : "state active",
                        "client-hostname" : hostname,
                        "starts" : now.strftime("%w %Y/%m/%d %H:%M:%S"),
                        "ends" : "never",
                    }, fixed=True)
                except ValueError:
                    log.warning("Could not parse line: %s" % line)
                    continue

                leases.append(l)

        return leases


class Lease(object):
    """
    A single lease. Leases are compared and hashed by their
    IP address only, i.e. a set holds one lease per address.
    """
    def __init__(self, ipaddr, properties, fixed=False):
        if not isinstance(ipaddr, ipaddress.IPv4Address):
            ipaddr = ipaddress.IPv4Address(ipaddr)

        self.ipaddr = ipaddr
        self._properties = properties
        self.fixed = fixed

    def __repr__(self):
        return "<%s for %s (%s)>" % (self.__class__.__name__, self.ipaddr, self.hostname)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.ipaddr == other.ipaddr

        return NotImplemented

    def __gt__(self, other):
        if isinstance(other, self.__class__):
            if not self.ipaddr == other.ipaddr:
                return NotImplemented

            return self.time_starts > other.time_starts

        return NotImplemented

    def __hash__(self):
        return hash(self.ipaddr)

    @property
    def hostname(self):
        """
        The hostname sent by the client or None if it is not set or invalid.
        """
        hostname = self._properties.get("client-hostname")

        if hostname is None:
            return None

        # Remove any ""
        hostname = hostname.replace("\"", "")

        # Only return valid hostnames
        m = re.match(r"^[A-Z0-9\-]{1,63}$", hostname, re.I)
        if m:
            return hostname

        return None

    @property
    def domain(self):
        """
        The domain name configured for the zone (GREEN/BLUE) this lease
        belongs to, falling back to the domain name of the host.
        """
        # Load ethernet settings
        ethernet_settings = self.read_settings("/var/ipfire/ethernet/settings")

        # Load DHCP settings
        dhcp_settings = self.read_settings("/var/ipfire/dhcp/settings")

        subnets = {}
        for zone in ("GREEN", "BLUE"):
            if not dhcp_settings.get("ENABLE_%s" % zone) == "on":
                continue

            netaddr = ethernet_settings.get("%s_NETADDRESS" % zone)
            submask = ethernet_settings.get("%s_NETMASK" % zone)

            subnet = ipaddress.ip_network("%s/%s" % (netaddr, submask))
            domain = dhcp_settings.get("DOMAIN_NAME_%s" % zone)

            subnets[subnet] = domain

        address = ipaddress.ip_address(self.ipaddr)

        for subnet in subnets:
            if address in subnet:
                return subnets[subnet]

        # Load main settings
        settings = self.read_settings("/var/ipfire/main/settings")

        # Fall back to the host domain if no match could be found
        return settings.get("DOMAINNAME", "localdomain")

    @staticmethod
    @functools.cache
    def read_settings(filename):
        """
        Reads a settings file with one KEY=VALUE pair per line into a dictionary.

        The result is cached, i.e. every file is only read once.
        """
        settings = {}

        with open(filename) as f:
            for line in f:
                # Remove line-breaks
                line = line.rstrip()

                k, separator, v = line.partition("=")

                # Skip anything that is not an assignment (e.g. empty lines)
                if not separator:
                    continue

                settings[k] = v

        return settings

    @property
    def fqdn(self):
        """
        The hostname joined with the domain or None if there is no valid hostname.
        """
        if self.hostname:
            return "%s.%s" % (self.hostname, self.domain)

        return None

    @staticmethod
    def _parse_time(s):
        return datetime.datetime.strptime(s, "%w %Y/%m/%d %H:%M:%S")

    @property
    def time_starts(self):
        starts = self._properties.get("starts")

        if starts:
            return self._parse_time(starts)

        return None

    @property
    def time_ends(self):
        ends = self._properties.get("ends")

        if not ends or ends == "never":
            return None

        return self._parse_time(ends)

    def has_expired(self):
        """
        Returns True if the current time is outside of the lifetime of this
        lease. Returns None if the lease has no start time.
        """
        starts = self.time_starts
        if not starts:
            return None

        now = datetime.datetime.utcnow()

        # Leases that never end are only invalid if they have not started, yet
        ends = self.time_ends
        if not ends:
            return starts > now

        return not starts < now < ends

    @property
    def rrset(self):
        """
        A list of (name, ttl, type, content) tuples for this lease.
        """
        fqdn = self.fqdn

        # If the lease does not have a valid FQDN, we cannot create any RRs
        if fqdn is None:
            return []

        return [
            # Forward record
            (fqdn, "%s" % LOCAL_TTL, "IN A", "%s" % self.ipaddr),

            # Reverse record
            (self.ipaddr.reverse_pointer, "%s" % LOCAL_TTL, "IN PTR", fqdn),
        ]


class UnboundConfigWriter(object):
    """
    Writes leases into the Unbound configuration file at path
    and changes the running Unbound using unbound-control.
    """
    def __init__(self, path):
        self.path = path

    def update_dhcp_leases(self, leases):
        """
        Writes all leases and reloads Unbound if the file has changed.
        """
        # Write out all leases
        if self.write_dhcp_leases(leases):
            log.debug("Reloading Unbound...")

            # Reload the configuration without dropping the cache
            self._control("reload_keep_cache")

    def write_dhcp_leases(self, leases):
        """
        Writes the records of all leases (sorted by IP address) to the
        configuration file.

        Returns False if the file on disk has the same content already
        and True if it has been replaced.
        """
        log.debug("Writing DHCP leases...")

        # Create the temporary file right next to its destination, because only
        # then it can be moved into place atomically (and on any filesystem)
        directory = os.path.dirname(self.path) or "."
        prefix = ".%s." % os.path.basename(self.path)

        tmp = tempfile.NamedTemporaryFile(mode="w", dir=directory, prefix=prefix, delete=False)
        installed = False

        try:
            with tmp as f:
                for l in sorted(leases, key=lambda x: x.ipaddr):
                    for rr in l.rrset:
                        f.write("local-data: \"%s\"\n" % " ".join(rr))

                # Flush the file
                f.flush()

                # Compare if the new leases file has changed from the previous version
                try:
                    if filecmp.cmp(f.name, self.path, shallow=False):
                        log.debug("The generated leases file has not changed")

                        return False

                # If the previous file did not exist, just keep falling through
                except FileNotFoundError:
                    pass

                # Make file readable for everyone
                os.fchmod(f.fileno(), stat.S_IRUSR|stat.S_IWUSR|stat.S_IRGRP|stat.S_IROTH)

                # Move the file to its destination (replacing any old file)
                os.replace(f.name, self.path)
                installed = True

                return True

        # Remove the temporary file if it has not become the new leases file
        finally:
            if not installed:
                try:
                    os.unlink(tmp.name)
                except FileNotFoundError:
                    pass

    def _control(self, *args):
        """
        Runs unbound-control with the given arguments.

        Raises subprocess.CalledProcessError if the command has failed.
        """
        command = ["unbound-control"]
        command.extend(args)

        # Log what we are doing
        log.debug("Running %s" % " ".join(command))

        try:
            subprocess.check_output(command)

        # Log any errors
        except subprocess.CalledProcessError as e:
            log.critical("Could not run %s, error code: %s: %s" % (
                " ".join(command), e.returncode, e.output))

            raise

    def apply_lease(self, lease):
        """
        This method takes a lease and updates Unbound at runtime.
        """
        log.debug("Applying lease %s" % lease)

        for rr in lease.rrset:
            log.debug("Adding new record %s" % " ".join(rr))

            self._control("local_data", *rr)

    def remove_lease(self, lease):
        """
        This method takes a lease and removes it from Unbound at runtime.
        """
        log.debug("Removing lease %s" % lease)

        for name, ttl, rrtype, content in lease.rrset:
            log.debug("Removing records for %s" % name)

            self._control("local_data_remove", name)


def pidof(program):
    """
    Returns the first PID of the given program
    or None if pidof could not find any.
    """
    try:
        output = subprocess.check_output(["pidof", program])
    except subprocess.CalledProcessError:
        return None

    # Convert to string
    output = output.decode()

    # Return the first PID
    for pid in output.split():
        try:
            pid = int(pid)
        except ValueError:
            continue

        return pid

    return None


def main():
    """
    Parses the command line and runs the bridge.
    """
    parser = argparse.ArgumentParser(description="Bridge for DHCP Leases and Unbound DNS")

    # Daemon Stuff
    parser.add_argument("--daemon", "-d", action="store_true",
        help="Launch as daemon in background")
    parser.add_argument("--verbose", "-v", action="count", help="Be more verbose")

    # Paths
    parser.add_argument("--dhcp-leases", default="/var/state/dhcp/dhcpd.leases",
        metavar="PATH", help="Path to the DHCPd leases file")
    parser.add_argument("--unbound-leases", default="/etc/unbound/dhcp-leases.conf",
        metavar="PATH", help="Path to the unbound configuration file")
    parser.add_argument("--fix-leases", default="/var/ipfire/dhcp/fixleases",
        metavar="PATH", help="Path to the fix leases file")
    parser.add_argument("--hosts", default="/var/ipfire/main/hosts",
        metavar="PATH", help="Path to static hosts file")
    parser.add_argument("--socket-path", default="/var/run/unbound-dhcp-leases-bridge.sock",
        metavar="PATH", help="Socket Path",
    )

    # Parse command line arguments
    args = parser.parse_args()

    # Setup logging
    loglevel = logging.WARN

    if args.verbose:
        if args.verbose == 1:
            loglevel = logging.INFO
        elif args.verbose >= 2:
            loglevel = logging.DEBUG

    bridge = UnboundDHCPLeasesBridge(args.dhcp_leases, args.fix_leases,
        args.unbound_leases, args.hosts, socket_path=args.socket_path)

    with daemon.DaemonContext(
        detach_process=args.daemon,
        stderr=None if args.daemon else sys.stderr,
        signal_map = {
            signal.SIGHUP : bridge.reload,
            signal.SIGINT : bridge.terminate,
            signal.SIGTERM : bridge.terminate,
        },
    ):
        setup_logging(daemon=args.daemon, loglevel=loglevel)

        bridge.run()


if __name__ == "__main__":
    main()