#!/usr/bin/python3
###############################################################################
# #
# IPFire.org - A linux based firewall #
# Copyright (C) 2016 Michael Tremer #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
# #
###############################################################################
import argparse
import datetime
import daemon
import filecmp
import functools
import ipaddress
import logging
import logging.handlers
import os
import queue
import re
import signal
import socket
import stat
import subprocess
import sys
import tempfile
import threading
# TTL written into every A/PTR record published for a DHCP lease (see Lease.rrset)
LOCAL_TTL = 60
# Shared logger for the whole daemon; setup_logging() attaches handlers and
# lowers/raises the level according to the command line.
log = logging.getLogger("dhcp")
log.setLevel(logging.DEBUG)
def setup_logging(daemon=True, loglevel=logging.INFO):
    """
    Configure the module logger.

    Messages always go to syslog (facility "daemon"). When running in the
    foreground (daemon=False) they are additionally mirrored to the console.

    Returns the configured logger.
    """
    log.setLevel(loglevel)

    # Syslog receives a "name[pid]: message" prefix
    syslog_handler = logging.handlers.SysLogHandler(address="/dev/log", facility="daemon")
    syslog_handler.setFormatter(logging.Formatter("%(name)s[%(process)d]: %(message)s"))

    handlers = [syslog_handler]

    # In the foreground, also write everything to the console
    if not daemon:
        handlers.append(logging.StreamHandler())

    for handler in handlers:
        log.addHandler(handler)
        handler.setLevel(loglevel)

    return log
class UnboundDHCPLeasesBridge(object):
    """
    Receives DHCP lease events on a UNIX stream socket and keeps Unbound's
    local data in sync with the DHCP leases, the fix leases and the static
    hosts list.
    """

    def __init__(self, dhcp_leases_file, fix_leases_file, unbound_leases_file, hosts_file, socket_path):
        self.leases_file = dhcp_leases_file
        self.fix_leases_file = fix_leases_file
        self.hosts_file = hosts_file
        self.socket_path = socket_path
        self.socket = None

        # Store all known leases
        self.leases = set()

        # Static hosts (FQDN -> list of addresses), refreshed by reload().
        # Start empty so that events arriving before the first reload do not
        # fail with an AttributeError in _add_lease().
        self.hosts = {}

        # Create a queue for all received events
        self.queue = queue.Queue()

        # Initialize the worker
        self.worker = Worker(self.queue, callback=self._handle_message)

        # Initialize the watcher
        self.watcher = Watcher(reload=self.reload)

        self.unbound = UnboundConfigWriter(unbound_leases_file)

    def run(self):
        """
        Serves the socket until terminate() closes it, then stops the
        worker and watcher threads.
        """
        log.info("Unbound DHCP Leases Bridge started on %s" % self.leases_file)

        # Open the server socket before launching any threads: if this fails
        # (SystemExit), no non-daemon threads are left behind that would keep
        # the process alive forever.
        self.socket = self._open_socket(self.socket_path)

        # Launch the worker and the watcher
        self.worker.start()
        self.watcher.start()

        try:
            self._serve()
        finally:
            # Terminate the worker
            self.queue.put(None)

            # Terminate the watcher
            self.watcher.terminate()

            # Wait for the worker and watcher to finish
            self.worker.join()
            self.watcher.join()

        log.info("Unbound DHCP Leases Bridge terminated")

    def _serve(self):
        """
        Accepts connections until the server socket is closed.
        """
        while True:
            try:
                conn, peer = self.socket.accept()
            # terminate() closes the socket, which makes accept() fail
            except OSError:
                break

            # The connection is closed no matter how handling ends
            with conn:
                self._handle_connection(conn)

    def _handle_connection(self, conn):
        """
        Reads one message from conn, queues it and answers OK or ERROR.
        """
        try:
            # Receive what the client is sending
            data, ancillary_data, flags, address = conn.recvmsg(4096)

            # Log that we have received some data
            log.debug("Received message of %s byte(s)" % len(data))

            # Decode the data and hand it to the worker
            message = self._decode_message(data)
            self.queue.put(message)

            conn.send(b"OK\n")

        # Send ERROR to the client if something went wrong
        except Exception as e:
            log.error("Could not handle message: %s" % e)

            # Best effort only: the client may already have gone away, and a
            # failing reply must not tear down the accept loop.
            try:
                conn.send(b"ERROR\n")
            except OSError:
                pass

    def _open_socket(self, path):
        """
        Creates a listening UNIX stream socket at path, replacing any stale
        socket file. Exits the program if the socket cannot be bound.
        """
        # Allocate a new socket
        s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)

        # Unlink any old sockets
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass

        # Bind the socket
        try:
            s.bind(path)
        except OSError as e:
            s.close()
            log.error("Could not open socket at %s: %s" % (path, e))
            raise SystemExit(1) from e

        # Listen
        s.listen(128)

        return s

    def _decode_message(self, data):
        """
        Parses KEY=VALUE lines (bytes) into a dict of str.

        Raises UnicodeError for undecodable lines and ValueError for lines
        without a "=" separator.
        """
        message = {}

        for line in data.splitlines():
            # Skip empty lines
            if not line:
                continue

            # Try to decode the line
            try:
                line = line.decode()
            except UnicodeError as e:
                log.error("Could not decode %r: %s" % (line, e))
                raise

            # Split the line at the first "="
            key, sep, value = line.partition("=")

            # Reject lines without a separator
            if not sep:
                raise ValueError("No value given")

            # Store the attributes
            message[key] = value

        return message

    def _handle_message(self, message):
        """
        Applies one decoded event (commit/release/expiry) to the lease set
        and to the running Unbound instance.
        """
        log.debug("Handling message:")
        for key in message:
            log.debug(" %-20s = %s" % (key, message[key]))

        # Extract the event type
        event = message.get("EVENT")

        # Check if event is set
        if not event:
            raise ValueError("The message does not have EVENT set")

        # COMMIT
        elif event == "commit":
            address = message.get("ADDRESS")
            name = message.get("NAME")

            # Find the old lease
            old_lease = self._find_lease(address)

            # Don't update fixed leases as they might clear the hostname
            if old_lease and old_lease.fixed:
                log.debug("Won't update fixed lease %s" % old_lease)
                return

            # Create a new lease
            lease = Lease(address, {
                "client-hostname" : name,
            })

            if old_lease:
                # Can we skip the update?
                if lease.rrset == old_lease.rrset:
                    # Still keep the most recent lease object
                    self._add_lease(lease)

                    log.debug("Won't update %s as nothing has changed" % lease)
                    return

                # Remove the old lease first. This has to happen before the
                # new lease is stored: leases compare equal by IP address, so
                # removing old_lease afterwards would drop the new one.
                self.unbound.remove_lease(old_lease)
                self._remove_lease(old_lease)

            # Store the lease and only publish it if it was accepted
            # (e.g. not shadowed by a static host)
            if self._add_lease(lease):
                self.unbound.apply_lease(lease)

        # RELEASE/EXPIRY
        elif event in ("release", "expiry"):
            address = message.get("ADDRESS")

            # Find the lease
            lease = self._find_lease(address)

            if not lease:
                log.warning("Could not find lease for %s" % address)
                return

            # Remove the lease
            self.unbound.remove_lease(lease)
            self._remove_lease(lease)

        # Raise an error if the event is not supported
        else:
            raise ValueError("Unsupported event: %s" % event)

    def update_dhcp_leases(self):
        """
        Re-reads all lease files and rewrites the Unbound configuration.
        """
        # Drop all known leases
        self.leases.clear()

        # Add all dynamic leases
        for lease in DHCPLeases(self.leases_file):
            self._add_lease(lease)

        # Add all static leases
        for lease in FixLeases(self.fix_leases_file):
            self._add_lease(lease)

        # Dump leases
        if self.leases:
            log.debug("DHCP Leases:")
            for lease in self.leases:
                log.debug(" %s:" % lease.fqdn)
                log.debug(" Start: %s" % lease.time_starts)
                log.debug(" End : %s" % lease.time_ends)
                if lease.has_expired():
                    log.debug(" Expired")

        self.unbound.update_dhcp_leases(
            [lease for lease in self.leases if not lease.has_expired()])

    def _add_lease(self, lease):
        """
        Stores lease, replacing any lease for the same address.

        Returns True if the lease was stored, False if it was skipped.
        """
        # Skip leases without a FQDN
        if not lease.fqdn:
            log.debug("Skipping lease without a FQDN: %s" % lease)
            return False

        # Skip any leases that also are a static host
        elif lease.fqdn in self.hosts:
            log.debug("Skipping lease for which a static host exists: %s" % lease)
            return False

        # Don't add expired leases
        elif lease.has_expired():
            log.debug("Skipping expired lease: %s" % lease)
            return False

        # Remove any previous leases (set.add would keep the old object)
        self._remove_lease(lease)

        # Store the lease
        self.leases.add(lease)

        return True

    def _find_lease(self, ipaddr):
        """
        Returns the lease with the specified IP address, or None.
        """
        if not isinstance(ipaddr, ipaddress.IPv4Address):
            ipaddr = ipaddress.IPv4Address(ipaddr)

        for lease in self.leases:
            if lease.ipaddr == ipaddr:
                return lease

        return None

    def _remove_lease(self, lease):
        """
        Removes lease (or any lease equal to it) if present.
        """
        self.leases.discard(lease)

    def read_static_hosts(self):
        """
        Reads the static hosts file and returns a dict FQDN -> sorted list of
        addresses for all enabled entries.
        """
        log.info("Reading static hosts from %s" % self.hosts_file)

        hosts = {}
        with open(self.hosts_file) as f:
            for line in f:
                line = line.rstrip()

                # Skip blank lines
                if not line:
                    continue

                try:
                    enabled, ipaddr, hostname, domainname, generateptr = line.split(",")
                except ValueError:
                    log.warning("Could not parse line: %s" % line)
                    continue

                # Skip any disabled entries
                if not enabled == "on":
                    continue

                if hostname and domainname:
                    fqdn = "%s.%s" % (hostname, domainname)
                elif hostname:
                    fqdn = hostname
                elif domainname:
                    fqdn = domainname
                else:
                    # Without a name there is nothing to register; falling
                    # through would reuse the previous line's fqdn.
                    log.warning("Could not parse line: %s" % line)
                    continue

                hosts.setdefault(fqdn, []).append(ipaddr)
                hosts[fqdn].sort()

        # Dump everything in the logs
        log.debug("Static hosts:")
        for name in hosts:
            log.debug(" %-20s : %s" % (name, ", ".join(hosts[name])))

        return hosts

    def reload(self, *args, **kwargs):
        """
        Re-reads static hosts and regenerates the Unbound lease data.
        Accepts and ignores arguments so it can be used as a signal handler.
        """
        # Read all static hosts
        self.hosts = self.read_static_hosts()

        # Unconditionally update all leases and reload Unbound
        self.update_dhcp_leases()

    def terminate(self, *args, **kwargs):
        """
        Closes the server socket, which ends the accept loop in run().
        """
        # Close the socket
        if self.socket:
            self.socket.close()
class Watcher(threading.Thread):
    """
    Watches if Unbound is still running.

    Whenever a (new) Unbound process is found, the reload callback is
    called so that all leases are pushed to it again.
    """

    def __init__(self, reload, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Callback that regenerates and reloads the Unbound configuration
        self.reload = reload

        # Set to true if this thread should be terminated
        self._terminated = threading.Event()

    def run(self):
        log.debug("Watcher launched")

        pidfd = None

        try:
            while True:
                # One iteration takes 30 seconds unless we don't know the process
                # when we try to find it once a second.
                if self._terminated.wait(1 if pidfd is None else 30):
                    break

                # Fetch a PIDFD for Unbound
                if pidfd is None:
                    pidfd = self._get_pidfd()

                    # If we could not acquire a PIDFD, we will try again soon...
                    if pidfd is None:
                        log.warning("Cannot find Unbound...")
                        continue

                    # Since Unbound has been restarted, we need to reload it all.
                    # A failing reload must not kill this thread.
                    try:
                        self.reload()
                    except Exception as e:
                        log.error("Could not reload: %s" % e, exc_info=True)

                log.debug("Checking if Unbound is still alive...")

                # Send signal 0: only checks that the process still exists
                try:
                    signal.pidfd_send_signal(pidfd, 0)

                # If the process has died, we land here and will have to wait until Unbound
                # has come back and reload it...
                except ProcessLookupError:
                    log.error("Unbound has died")

                    # Release and reset the PIDFD
                    os.close(pidfd)
                    pidfd = None

                else:
                    log.debug("Unbound is alive")

        finally:
            # Don't leak the PIDFD when the thread ends
            if pidfd is not None:
                os.close(pidfd)

        log.debug("Watcher terminated")

    def terminate(self):
        """
        Called to signal this thread to terminate
        """
        self._terminated.set()

    def _get_pidfd(self):
        """
        Returns a PIDFD for unbound if it is running, otherwise None.
        """
        # Try to find the PID
        pid = pidof("unbound")
        if not pid:
            return None

        log.debug("Unbound is running as PID %s" % pid)

        # Open a PIDFD; the process may have exited since pidof() ran
        try:
            pidfd = os.pidfd_open(pid)
        except OSError as e:
            log.debug("Could not open PIDFD for PID %s: %s" % (pid, e))
            return None

        log.debug("Acquired PIDFD %s for PID %s" % (pidfd, pid))

        return pidfd
class Worker(threading.Thread):
    """
    Background thread that consumes messages from a queue and passes each
    one to a callback.

    Putting None on the queue makes the thread exit. Exceptions raised by
    the callback are logged and do not stop the thread.
    """

    def __init__(self, queue, callback):
        super().__init__()

        self.queue = queue
        self.callback = callback

    def run(self):
        log.debug("Worker %s launched" % self.native_id)

        # iter() with a sentinel stops as soon as get() yields None
        for message in iter(self.queue.get, None):
            try:
                self.callback(message)
            except Exception as e:
                log.error("Callback failed: %s" % e, exc_info=True)

        log.debug("Worker %s terminated" % self.native_id)
class DHCPLeases(object):
    """
    Parses an ISC dhcpd leases file into Lease objects.

    Only blocks that carry a hardware address and whose binding state is
    active (or unspecified) are kept.
    """

    # One "lease <ip> { ... }" block; the closing brace sits at the start of a line
    regex_leaseblock = re.compile(r"lease (?P<ipaddr>\d+\.\d+\.\d+\.\d+) {(?P<config>[\s\S]+?)\n}")

    def __init__(self, path):
        self.path = path

        self._leases = self._parse()

    def __iter__(self):
        return iter(self._leases)

    def _parse(self):
        log.info("Reading DHCP leases from %s" % self.path)

        leases = []

        with open(self.path) as f:
            # Read entire leases file
            data = f.read()

        for match in self.regex_leaseblock.finditer(data):
            ipaddr = match.group("ipaddr")
            properties = self._parse_block(match.group("config"))

            if not self._is_usable(properties):
                continue

            leases.append(Lease(ipaddr, properties))

        return leases

    @staticmethod
    def _is_usable(properties):
        """
        Returns True if a parsed lease block should be turned into a Lease.
        """
        # Skip any abandoned leases
        if "hardware" not in properties:
            return False

        # Skip inactive leases ("binding state free", "backup", ...)
        return properties.get("binding", "state active") == "state active"

    def _parse_block(self, block):
        """
        Turns the statements of one lease block into a dict mapping the first
        word of each statement to the rest of it (options and sets skipped).
        """
        properties = {}

        for line in block.splitlines():
            # Only complete statements end with ";" (this also skips empty lines)
            if not line.endswith(";"):
                continue

            # Remove the trailing ; and any leading whitespace
            line = line[:-1].lstrip()

            # We skip all options and sets
            if line.startswith(("option", "set")):
                continue

            # Split by first space; single-word statements get an empty value
            key, _, val = line.partition(" ")
            properties[key] = val

        return properties
class FixLeases(object):
    """
    Reads statically assigned ("fix") leases: one comma-separated entry
    per line with seven fields, of which the MAC address, IP address,
    enabled flag and hostname are used.
    """

    def __init__(self, path):
        self.path = path

        self._leases = self._parse()

    def __iter__(self):
        return iter(self._leases)

    def _parse(self):
        log.info("Reading fix leases from %s" % self.path)

        # Every fixed lease starts "now" and never ends
        started = datetime.datetime.utcnow().strftime("%w %Y/%m/%d %H:%M:%S")

        result = []

        with open(self.path) as f:
            for raw in f:
                entry = raw.rstrip()

                try:
                    hwaddr, ipaddr, enabled, _a, _b, _c, hostname = entry.split(",")
                except ValueError:
                    log.warning("Could not parse line: %s" % entry)
                    continue

                # Only enabled entries become leases
                if enabled != "on":
                    continue

                properties = {
                    "binding" : "state active",
                    "client-hostname" : hostname,
                    "starts" : started,
                    "ends" : "never",
                }
                result.append(Lease(ipaddr, properties, fixed=True))

        return result
class Lease(object):
    """
    A single DHCP lease: an IPv4 address plus the properties parsed from a
    leases file (or synthesised for fix leases and runtime events).

    Leases compare equal and hash by IP address only.
    """

    def __init__(self, ipaddr, properties, fixed=False):
        if not isinstance(ipaddr, ipaddress.IPv4Address):
            ipaddr = ipaddress.IPv4Address(ipaddr)

        self.ipaddr = ipaddr
        self._properties = properties

        # True for leases read from the fix leases file
        self.fixed = fixed

    def __repr__(self):
        return "<%s for %s (%s)>" % (self.__class__.__name__, self.ipaddr, self.hostname)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.ipaddr == other.ipaddr

        return NotImplemented

    def __gt__(self, other):
        if isinstance(other, self.__class__):
            if not self.ipaddr == other.ipaddr:
                return NotImplemented

            return self.time_starts > other.time_starts

        return NotImplemented

    def __hash__(self):
        return hash(self.ipaddr)

    @property
    def hostname(self):
        """
        The client hostname without quotes, or None if it is missing or not
        a valid single DNS label.
        """
        hostname = self._properties.get("client-hostname")

        if hostname is None:
            return None

        # Remove any ""
        hostname = hostname.replace("\"", "")

        # Only return valid hostnames. fullmatch() is used because "$" would
        # also accept a trailing newline.
        if re.fullmatch(r"[A-Z0-9\-]{1,63}", hostname, re.I):
            return hostname

        return None

    @property
    def domain(self):
        """
        The DHCP domain of the GREEN/BLUE zone containing this address, or
        the host's domain (default "localdomain") if no zone matches.
        """
        # Load ethernet settings
        ethernet_settings = self.read_settings("/var/ipfire/ethernet/settings")

        # Load DHCP settings
        dhcp_settings = self.read_settings("/var/ipfire/dhcp/settings")

        for zone in ("GREEN", "BLUE"):
            if not dhcp_settings.get("ENABLE_%s" % zone) == "on":
                continue

            netaddr = ethernet_settings.get("%s_NETADDRESS" % zone)
            submask = ethernet_settings.get("%s_NETMASK" % zone)

            # Ignore zones whose network is not (properly) configured
            if not netaddr or not submask:
                continue

            try:
                subnet = ipaddress.ip_network("%s/%s" % (netaddr, submask), strict=False)
            except ValueError:
                continue

            if self.ipaddr in subnet:
                return dhcp_settings.get("DOMAIN_NAME_%s" % zone)

        # Load main settings
        settings = self.read_settings("/var/ipfire/main/settings")

        # Fall back to the host domain if no match could be found
        return settings.get("DOMAINNAME", "localdomain")

    @staticmethod
    @functools.cache
    def read_settings(filename):
        """
        Parses a KEY=VALUE settings file into a dict. Results are cached per
        filename for the lifetime of the process.
        """
        settings = {}

        with open(filename) as f:
            for line in f:
                # Remove line-breaks
                line = line.rstrip()

                # Ignore anything that is not KEY=VALUE (e.g. blank lines)
                k, sep, v = line.partition("=")
                if not sep:
                    continue

                settings[k] = v

        return settings

    @property
    def fqdn(self):
        """
        "<hostname>.<domain>", or None if there is no valid hostname.
        """
        hostname = self.hostname

        if hostname:
            return "%s.%s" % (hostname, self.domain)

        return None

    @staticmethod
    def _parse_time(s):
        return datetime.datetime.strptime(s, "%w %Y/%m/%d %H:%M:%S")

    @property
    def time_starts(self):
        starts = self._properties.get("starts")

        if starts:
            return self._parse_time(starts)

        return None

    @property
    def time_ends(self):
        ends = self._properties.get("ends")

        if not ends or ends == "never":
            return None

        return self._parse_time(ends)

    def has_expired(self):
        """
        Returns True if the current (naive UTC) time lies outside the lease's
        validity window. Leases without a start time never expire.
        """
        starts = self.time_starts
        if not starts:
            return False

        now = datetime.datetime.utcnow()

        ends = self.time_ends
        if not ends:
            return starts > now

        return not starts < now < ends

    @property
    def rrset(self):
        """
        The forward (A) and reverse (PTR) records for this lease, as tuples
        of (name, ttl, "IN <type>", content); empty without a valid FQDN.
        """
        fqdn = self.fqdn

        # If the lease does not have a valid FQDN, we cannot create any RRs
        if fqdn is None:
            return []

        return [
            # Forward record
            (fqdn, "%s" % LOCAL_TTL, "IN A", "%s" % self.ipaddr),

            # Reverse record
            (self.ipaddr.reverse_pointer, "%s" % LOCAL_TTL,
                "IN PTR", fqdn),
        ]
class UnboundConfigWriter(object):
    """
    Writes the DHCP leases configuration file for Unbound and applies
    single lease changes at runtime through unbound-control.
    """

    def __init__(self, path):
        self.path = path

    def update_dhcp_leases(self, leases):
        """
        Writes all leases and reloads Unbound if the file changed.
        """
        # Write out all leases
        if self.write_dhcp_leases(leases):
            log.debug("Reloading Unbound...")

            # Reload the configuration without dropping the cache
            self._control("reload_keep_cache")

    def write_dhcp_leases(self, leases):
        """
        Writes the local-data file for leases. Returns True if the file on
        disk changed, False if it already had identical content.
        """
        log.debug("Writing DHCP leases...")

        if self._replace_if_changed(self._render(leases)):
            return True

        log.debug("The generated leases file has not changed")
        return False

    @staticmethod
    def _render(leases):
        """
        Returns the local-data lines for leases, ordered by IP address.
        """
        lines = []

        for lease in sorted(leases, key=lambda x: x.ipaddr):
            for rr in lease.rrset:
                lines.append("local-data: \"%s\"\n" % " ".join(rr))

        return "".join(lines)

    def _replace_if_changed(self, content):
        """
        Atomically replaces self.path with content unless it is identical.
        Returns True if the file was replaced.
        """
        # Create the temporary file next to the destination so that the final
        # rename stays on one filesystem and is atomic: readers see either the
        # old or the new file, never a missing one.
        directory = os.path.dirname(os.path.abspath(self.path))
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".dhcp-leases-")

        try:
            with os.fdopen(fd, "w") as f:
                f.write(content)

            # Compare if the new leases file has changed from the previous version
            try:
                if filecmp.cmp(tmp_path, self.path, shallow=False):
                    return False

            # If the previous file did not exist, just keep falling through
            except FileNotFoundError:
                pass

            # Make file readable for everyone
            os.chmod(tmp_path, stat.S_IRUSR|stat.S_IWUSR|stat.S_IRGRP|stat.S_IROTH)

            # Move the file to its destination
            os.replace(tmp_path, self.path)
            tmp_path = None

            return True

        finally:
            # Clean up the temporary file unless it has been moved into place
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except FileNotFoundError:
                    pass

    def _control(self, *args):
        """
        Runs unbound-control with args; raises CalledProcessError on failure.
        """
        command = ["unbound-control"]
        command.extend(args)

        # Log what we are doing
        log.debug("Running %s" % " ".join(command))

        try:
            # Capture stderr as well so that failures can be logged
            subprocess.check_output(command, stderr=subprocess.STDOUT)

        # Log any errors
        except subprocess.CalledProcessError as e:
            log.critical("Could not run %s, error code: %s: %s" % (
                " ".join(command), e.returncode, e.output))

            raise

    def apply_lease(self, lease):
        """
        This method takes a lease and updates Unbound at runtime.
        """
        log.debug("Applying lease %s" % lease)

        for rr in lease.rrset:
            log.debug("Adding new record %s" % " ".join(rr))

            self._control("local_data", *rr)

    def remove_lease(self, lease):
        """
        This method takes a lease and removes it from Unbound at runtime.
        """
        log.debug("Removing lease %s" % lease)

        for name, *_ in lease.rrset:
            log.debug("Removing records for %s" % name)

            self._control("local_data_remove", name)
def pidof(program):
    """
    Returns the first PID of the given program as an int.

    Returns None if the program is not running, if pidof prints nothing
    usable, or if the pidof utility itself cannot be executed.
    """
    try:
        output = subprocess.check_output(["pidof", program])

    # pidof exits non-zero when nothing matches; OSError covers a missing
    # pidof binary, which would otherwise crash the caller's thread.
    except (subprocess.CalledProcessError, OSError):
        return None

    # Return the first PID that parses as an integer
    for pid in output.decode().split():
        try:
            return int(pid)
        except ValueError:
            continue

    return None
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Bridge for DHCP Leases and Unbound DNS")

    # Daemon Stuff
    parser.add_argument("--daemon", "-d", action="store_true",
        help="Launch as daemon in background")
    parser.add_argument("--verbose", "-v", action="count", help="Be more verbose")

    # Paths: (option, default, help text); all share metavar PATH
    path_options = (
        ("--dhcp-leases", "/var/state/dhcp/dhcpd.leases",
            "Path to the DHCPd leases file"),
        ("--unbound-leases", "/etc/unbound/dhcp-leases.conf",
            "Path to the unbound configuration file"),
        ("--fix-leases", "/var/ipfire/dhcp/fixleases",
            "Path to the fix leases file"),
        ("--hosts", "/var/ipfire/main/hosts",
            "Path to static hosts file"),
        ("--socket-path", "/var/run/unbound-dhcp-leases-bridge.sock",
            "Socket Path"),
    )
    for option, default, description in path_options:
        parser.add_argument(option, default=default, metavar="PATH", help=description)

    # Parse command line arguments
    args = parser.parse_args()

    # Map the number of -v flags to a log level
    verbosity = args.verbose or 0
    if verbosity >= 2:
        loglevel = logging.DEBUG
    elif verbosity == 1:
        loglevel = logging.INFO
    else:
        loglevel = logging.WARN

    bridge = UnboundDHCPLeasesBridge(args.dhcp_leases, args.fix_leases,
        args.unbound_leases, args.hosts, socket_path=args.socket_path)

    # SIGHUP re-reads everything, SIGINT/SIGTERM stop the accept loop
    handlers = {
        signal.SIGHUP : bridge.reload,
        signal.SIGINT : bridge.terminate,
        signal.SIGTERM : bridge.terminate,
    }

    # NOTE(review): python-daemon's DaemonContext defaults to umask=0, which
    # presumably leaves the bridge socket world-writable — confirm whether
    # that is intended.
    context = daemon.DaemonContext(
        detach_process=args.daemon,
        stderr=None if args.daemon else sys.stderr,
        signal_map=handlers,
    )

    with context:
        setup_logging(daemon=args.daemon, loglevel=loglevel)

        bridge.run()