Files
external-whitelist-auth-gate/external_whitelist_auth_gate.py
Patrick Asmus 8122b5274a Initial
2026-05-11 20:47:25 +02:00

412 lines
13 KiB
Python

#!/usr/bin/env python3
import base64
import hashlib
import ipaddress
import json
import logging
import os
import secrets
import socket
import threading
import time
import urllib.request
from urllib.parse import parse_qs, urlparse
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
# --- Runtime configuration, read once from the environment at import time ---
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    # Unknown level names fall back to INFO instead of failing.
    level=getattr(logging, LOG_LEVEL, logging.INFO),
)
# Remote text file of entries (CIDRs, IPs or hostnames); '#' starts a comment.
WHITELIST_URL = os.getenv("WHITELIST_URL", "https://example.com/network-whitelist.txt")
REFRESH_INTERVAL_SECONDS = int(os.getenv("REFRESH_INTERVAL_SECONDS", "300"))
LISTEN_ADDRESS = os.getenv("LISTEN_ADDRESS", "0.0.0.0")
LISTEN_PORT = int(os.getenv("LISTEN_PORT", "8080"))
BASIC_AUTH_REALM = os.getenv("BASIC_AUTH_REALM", "Protected Area")
# "leftmost" takes the first forwarded address; any other value takes the last one.
CLIENT_IP_STRATEGY = os.getenv("CLIENT_IP_STRATEGY", "rightmost").lower()
# Locally configured networks (comma- or newline-separated) merged into every refresh.
EXTRA_CIDRS = os.getenv("WHITELIST_EXTRA_CIDRS", "")
# Highest numbered BASIC_AUTH_USER_<n> slot that is read; slot 1 is the unsuffixed variable.
MAX_BASIC_AUTH_USERS = 3
class WhitelistStore:
    """Lock-protected holder for the currently active whitelist networks.

    ``refresh_once`` builds a complete new network list without holding the
    lock and swaps it in under ``_lock``. Readers (``contains``, ``matches``,
    ``snapshot``) take the same lock, so they never observe a half-updated
    state.
    """

    def __init__(self):
        self._lock = threading.RLock()
        # Sorted, de-duplicated ipaddress network objects currently allowlisted.
        self._networks = []
        # Remote entries that resolved to at least one network in the last successful refresh.
        self._source_hosts = []
        # One {"entry": str, "networks": [str, ...]} dict per remote entry, resolved or not.
        self._resolved_entries = []
        # Unix time (int seconds) of the last successful refresh; 0 means "never succeeded".
        self._last_refresh = 0
        # Text of the most recent refresh exception; reset to "" after a success.
        self._last_error = ""

    def snapshot(self, include_details=False):
        """Return a JSON-serialisable summary of the store.

        Args:
            include_details: also include every network and the per-entry
                resolution results.
        """
        with self._lock:
            data = {
                "network_count": len(self._networks),
                "source_hosts": list(self._source_hosts),
                "last_refresh": self._last_refresh,
                "last_error": self._last_error,
            }
            if include_details:
                data["networks"] = [str(network) for network in self._networks]
                data["resolved_entries"] = [
                    {
                        "entry": item["entry"],
                        "networks": list(item["networks"]),
                    }
                    for item in self._resolved_entries
                ]
            return data

    def contains(self, ip):
        """Return True if *ip* (an ipaddress address object) is in any whitelisted network."""
        with self._lock:
            return any(ip in network for network in self._networks)

    def matches(self, ip):
        """Return the string form of every whitelisted network containing *ip*."""
        with self._lock:
            return [str(network) for network in self._networks if ip in network]

    def refresh_forever(self):
        """Refresh immediately, then every REFRESH_INTERVAL_SECONDS; never returns."""
        while True:
            self.refresh_once()
            time.sleep(REFRESH_INTERVAL_SECONDS)

    def refresh_once(self):
        """Fetch and resolve the remote whitelist, then install it atomically.

        The extra CIDRs from WHITELIST_EXTRA_CIDRS are merged into every
        successful result. On any failure, the previously installed networks
        are kept and the error text is exposed through ``snapshot()``. If no
        refresh has succeeded yet, the extra CIDRs are installed on their own,
        so locally configured networks apply even while the remote list is
        unreachable.
        """
        try:
            entries = self._load_entries()
            networks = []
            source_hosts = []
            resolved_entries = []
            for entry in entries:
                resolved = sorted(set(resolve_entry(entry)), key=network_sort_key)
                resolved_entries.append({
                    "entry": entry,
                    "networks": [str(network) for network in resolved],
                })
                if resolved:
                    networks.extend(resolved)
                    source_hosts.append(entry)
            networks.extend(parse_extra_cidrs(EXTRA_CIDRS))
            networks = sorted(set(networks), key=network_sort_key)
        except Exception as exc:
            # Top-level boundary of the refresh thread: log and keep serving.
            logging.exception("whitelist refresh failed; keeping previous whitelist")
            self._record_failure(exc)
            return
        with self._lock:
            self._networks = networks
            self._source_hosts = source_hosts
            self._resolved_entries = resolved_entries
            self._last_refresh = int(time.time())
            self._last_error = ""
        logging.info("whitelist refreshed: %s networks from %s source entries", len(networks), len(source_hosts))

    def _record_failure(self, exc):
        """Store the error text; before the first success, install the extra CIDRs alone."""
        fallback = sorted(set(parse_extra_cidrs(EXTRA_CIDRS)), key=network_sort_key)
        with self._lock:
            self._last_error = str(exc)
            if self._last_refresh:
                # A good whitelist already exists: keep it untouched.
                return
            self._networks = fallback
        if fallback:
            logging.warning("remote whitelist unavailable; applying %s extra CIDRs only", len(fallback))

    def _load_entries(self):
        """Download WHITELIST_URL and return its non-empty, comment-stripped lines."""
        with urllib.request.urlopen(WHITELIST_URL, timeout=20) as response:
            raw = response.read()
        return self._parse_entries(raw)

    @staticmethod
    def _parse_entries(raw):
        """Split a whitelist document (bytes) into entries.

        Text after '#' is a comment and blank lines are skipped. The document
        is decoded as UTF-8. A leading byte-order mark is dropped so it does
        not corrupt the first entry.
        """
        body = raw.decode("utf-8-sig")
        stripped = (line.split("#", 1)[0].strip() for line in body.splitlines())
        return [line for line in stripped if line]
def network_sort_key(network):
    """Sort key ordering IPv4 before IPv6, then by network address, then by prefix length."""
    version = network.version
    address_value = int(network.network_address)
    return version, address_value, network.prefixlen
def query_flag(query, name):
    """Return True if the first value of query parameter *name* is a truthy word.

    *query* is a ``parse_qs``-style mapping of name -> list of values; a missing
    parameter counts as False.
    """
    first_value = query.get(name, [""])[0]
    return first_value.lower() in {"1", "true", "yes", "on", "full"}
def parse_extra_cidrs(value):
    """Parse a comma- and/or newline-separated list of networks.

    Host bits are masked (non-strict parsing). Items that are not valid
    networks are logged and skipped.
    """
    parsed = []
    pieces = (piece.strip() for piece in value.replace("\n", ",").split(","))
    for item in filter(None, pieces):
        try:
            network = ipaddress.ip_network(item, strict=False)
        except ValueError:
            logging.warning("ignored invalid extra CIDR: %s", item)
        else:
            parsed.append(network)
    return parsed
def resolve_entry(entry):
    """Turn one whitelist entry into a list of networks.

    Literal IPs and CIDRs are returned directly, with host bits masked.
    Anything else is treated as a hostname and resolved with ``getaddrinfo``.
    Each IPv4/IPv6 answer becomes a single-host network (/32 or /128).

    Lookup failures return an empty list instead of raising, so one bad line
    cannot abort the whole refresh. This includes names the IDNA codec
    rejects, which raise UnicodeError rather than socket.gaierror.
    """
    try:
        return [ipaddress.ip_network(entry, strict=False)]
    except ValueError:
        pass
    try:
        answers = socket.getaddrinfo(entry, None, type=socket.SOCK_STREAM)
    except (OSError, UnicodeError) as exc:
        # socket.gaierror is an OSError subclass; IDNA encoding errors
        # (label > 63 chars, empty label) surface as UnicodeError.
        logging.warning("DNS lookup failed for %s: %s", entry, exc)
        return []
    # dict keeps first-seen order while de-duplicating repeated answers.
    networks = {}
    for family, _, _, _, sockaddr in answers:
        if family not in (socket.AF_INET, socket.AF_INET6):
            continue
        try:
            # A bare address yields a /32 (IPv4) or /128 (IPv6) network.
            network = ipaddress.ip_network(sockaddr[0])
        except ValueError:
            # e.g. scoped IPv6 literals the ipaddress module cannot parse.
            logging.warning("ignored unparsable address %r for %s", sockaddr[0], entry)
            continue
        networks[network] = None
    return list(networks)
def request_client_ip(handler):
    """Determine the client address of *handler*'s request.

    Precedence:
    1. the comma-separated header named by CLIENT_IP_HEADER (if set and present),
       otherwise X-Forwarded-For. From the valid addresses, the first is taken
       when CLIENT_IP_STRATEGY is "leftmost", the last one otherwise.
    2. X-Real-IP, if it holds a valid address.
    3. the socket peer address.
    Invalid header values are logged and ignored.
    """
    header_name = os.getenv("CLIENT_IP_HEADER", "").strip()
    forwarded = handler.headers.get(header_name) if header_name else ""
    if not forwarded:
        forwarded = handler.headers.get("X-Forwarded-For", "")

    addresses = []
    for token in (piece.strip() for piece in forwarded.split(",")):
        if token == "":
            continue
        try:
            addresses.append(ipaddress.ip_address(token))
        except ValueError:
            logging.warning("ignored invalid forwarded IP: %s", token)

    if addresses:
        if CLIENT_IP_STRATEGY == "leftmost":
            return addresses[0]
        return addresses[-1]

    fallback = handler.headers.get("X-Real-IP", "").strip()
    if fallback:
        try:
            return ipaddress.ip_address(fallback)
        except ValueError:
            logging.warning("ignored invalid X-Real-IP: %s", fallback)

    peer_host = handler.client_address[0]
    return ipaddress.ip_address(peer_host)
def authorization_scheme(header_value):
    """Return the lower-cased scheme token of an Authorization header value.

    Returns "none" when the value is missing, empty, or whitespace only.
    """
    # str.split() yields [] for whitespace-only input, so guard before indexing.
    tokens = (header_value or "").split(None, 1)
    return tokens[0].lower() if tokens else "none"
def configured_basic_auth_users():
    """Collect the BasicAuth users defined in the environment.

    The unsuffixed BASIC_AUTH_USER / BASIC_AUTH_PASSWORD /
    BASIC_AUTH_PASSWORD_SHA256 variables define the first user. The variants
    suffixed _2 .. _<MAX_BASIC_AUTH_USERS> define more. A slot with an empty
    username is skipped, as is a numbered slot repeating the unsuffixed
    username. Each user is a dict with "username", "password" and
    "password_sha256" keys (empty string when unset).
    """
    legacy = {
        "username": os.getenv("BASIC_AUTH_USER", "").strip(),
        "password": os.getenv("BASIC_AUTH_PASSWORD", ""),
        "password_sha256": os.getenv("BASIC_AUTH_PASSWORD_SHA256", ""),
    }
    users = [legacy] if legacy["username"] else []
    for index in range(2, MAX_BASIC_AUTH_USERS + 1):
        candidate = {
            "username": os.getenv(f"BASIC_AUTH_USER_{index}", "").strip(),
            "password": os.getenv(f"BASIC_AUTH_PASSWORD_{index}", ""),
            "password_sha256": os.getenv(f"BASIC_AUTH_PASSWORD_SHA256_{index}", ""),
        }
        name = candidate["username"]
        if name and name != legacy["username"]:
            users.append(candidate)
    return users
def _log_safe(value):
    """Escape non-printable characters (CR, LF, ESC, ...) so a logged value stays on one line."""
    if value.isprintable():
        return value
    return "".join(
        char if char.isprintable() else char.encode("unicode_escape").decode("ascii")
        for char in value
    )


def request_log_context(handler):
    """Collect request details for the /auth decision log lines.

    Header values come from the request and may contain control characters.
    They are escaped so they cannot forge additional log records.
    """
    headers = handler.headers
    return {
        "host": _log_safe(headers.get("X-Forwarded-Host", "")),
        "uri": _log_safe(headers.get("X-Forwarded-Uri", "")),
        "xff": _log_safe(headers.get("X-Forwarded-For", "")),
        "xreal": _log_safe(headers.get("X-Real-IP", "")),
        "remote": handler.client_address[0],
        "auth_scheme": _log_safe(authorization_scheme(headers.get("Authorization", ""))),
    }
def basic_auth_valid(header_value, users=None):
    """Check an Authorization header against the configured BasicAuth users.

    Args:
        header_value: raw Authorization header value ("" when absent).
        users: optional list of user dicts with "username", "password" and
            "password_sha256" keys; defaults to configured_basic_auth_users().

    Returns:
        The matching username on success, otherwise False.
    """
    if users is None:
        users = configured_basic_auth_users()
    if not users:
        logging.error("no BasicAuth users are configured")
        return False
    if not header_value or header_value[:6].lower() != "basic ":
        return False
    try:
        decoded = base64.b64decode(header_value[6:].strip(), validate=True).decode("utf-8")
        username, password = decoded.split(":", 1)
    except ValueError:
        # binascii.Error, UnicodeDecodeError and a missing ":" are all ValueError subclasses.
        return False

    def as_bytes(text):
        # secrets.compare_digest rejects non-ASCII str, so compare UTF-8 bytes.
        # surrogateescape keeps undecodable environment bytes representable.
        return text.encode("utf-8", "surrogateescape")

    username_bytes = as_bytes(username)
    password_bytes = as_bytes(password)
    password_digest = hashlib.sha256(password_bytes).hexdigest().encode("ascii")
    for user in users:
        if not user["password"] and not user["password_sha256"]:
            logging.error("BasicAuth user %s has no password configured", user["username"])
            continue
        user_ok = secrets.compare_digest(username_bytes, as_bytes(user["username"]))
        if user["password_sha256"]:
            # Hex digests are compared case-insensitively; stray whitespace from env files is ignored.
            expected = as_bytes(user["password_sha256"].strip().lower())
            password_ok = secrets.compare_digest(password_digest, expected)
        else:
            password_ok = secrets.compare_digest(password_bytes, as_bytes(user["password"]))
        if user_ok and password_ok:
            return user["username"]
    return False
# Single process-wide whitelist: AccessGateHandler reads it per request; main() refreshes it.
STORE = WhitelistStore()
class AccessGateHandler(BaseHTTPRequestHandler):
    """Forward-auth endpoint plus small diagnostic endpoints.

    Routes (prefix match on the URL path, checked in this order):
      /healthz -> 200 {"status": "ok"}
      /status  -> 200 whitelist snapshot (?verbose=1 adds networks and entries)
      /check   -> 200 whether ?ip=<addr> (or the request's client IP) is allowlisted
      /auth    -> 204 if allowlisted or valid BasicAuth, else 401 challenge
      other    -> 404
    """

    server_version = "external-whitelist-auth-gate/1.0"

    def do_GET(self):
        self.handle_request()

    def do_HEAD(self):
        self.handle_request(include_body=False)

    def do_POST(self):
        self.handle_request()

    def handle_request(self, include_body=True):
        """Dispatch the request to the endpoint matching its path."""
        parsed_url = urlparse(self.path)
        path = parsed_url.path
        query = parse_qs(parsed_url.query)
        if path.startswith("/healthz"):
            self.respond(200, {"status": "ok"}, include_body)
        elif path.startswith("/status"):
            self.respond(200, STORE.snapshot(include_details=query_flag(query, "verbose")), include_body)
        elif path.startswith("/check"):
            self._handle_check(query, include_body)
        elif path.startswith("/auth"):
            self._handle_auth(include_body)
        else:
            self.respond(404, {"error": "not found"}, include_body)

    def _handle_check(self, query, include_body):
        """Report whether ?ip=<addr> (or the request's client IP) matches the whitelist."""
        ip_source = "request"
        ip_text = query.get("ip", [""])[0].strip()
        try:
            if ip_text:
                client_ip = ipaddress.ip_address(ip_text)
                ip_source = "query"
            else:
                client_ip = request_client_ip(self)
        except ValueError as exc:
            self.respond(400, {"error": str(exc)}, include_body)
            return
        matches = STORE.matches(client_ip)
        self.respond(200, {
            "allowlisted": bool(matches),
            "ip": str(client_ip),
            "ip_source": ip_source,
            "matched_networks": matches,
        }, include_body)

    def _handle_auth(self, include_body):
        """Answer a forward-auth request: 204 when allowed, 401 BasicAuth challenge otherwise."""
        client_ip = request_client_ip(self)
        log_context = request_log_context(self)
        if STORE.contains(client_ip):
            self._log_decision("allowlisted", client_ip, log_context)
            self.send_response(204)
            self.send_header("X-Access-Gate-Reason", "allowlisted")
            self.end_headers()
            return
        authenticated_user = basic_auth_valid(self.headers.get("Authorization", ""))
        if authenticated_user:
            self._log_decision("basic-auth ok", client_ip, log_context)
            self.send_response(204)
            self.send_header("X-Access-Gate-Reason", "basic-auth")
            self.send_header("X-Access-Gate-User", authenticated_user)
            self.end_headers()
            return
        self._log_decision("auth required", client_ip, log_context)
        body = b"Authentication required\n"
        self.send_response(401)
        self.send_header("WWW-Authenticate", self._basic_challenge())
        self.send_header("Cache-Control", "no-store")
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        # Sent for HEAD as well, so it reflects the body a GET would return.
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if include_body:
            self.wfile.write(body)

    @staticmethod
    def _basic_challenge():
        """Build the WWW-Authenticate value, escaping the realm as a quoted-string."""
        realm = BASIC_AUTH_REALM.replace("\\", "\\\\").replace('"', '\\"')
        return f'Basic realm="{realm}"'

    @staticmethod
    def _log_decision(outcome, client_ip, log_context):
        """Emit one INFO line describing an /auth decision and its request context."""
        logging.info(
            "%s ip=%s host=%s uri=%s xff=%s xreal=%s remote=%s auth_scheme=%s",
            outcome,
            client_ip,
            log_context["host"],
            log_context["uri"],
            log_context["xff"],
            log_context["xreal"],
            log_context["remote"],
            log_context["auth_scheme"],
        )

    def respond(self, status, payload, include_body=True):
        """Send *payload* as a JSON response with an explicit Content-Length."""
        body = json.dumps(payload, sort_keys=True).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if include_body:
            self.wfile.write(body)

    def log_message(self, fmt, *args):
        # Route http.server's per-request access log to DEBUG instead of stderr.
        logging.debug("%s - %s", self.address_string(), fmt % args)
def main():
    """Load the whitelist, keep it refreshed in the background, and serve forever.

    The first refresh runs synchronously, so requests arriving right after
    startup already see the whitelist.
    """
    STORE.refresh_once()

    def refresh_after_interval():
        # The synchronous refresh above already loaded the list; wait one
        # interval so refresh_forever() does not fetch and resolve it again
        # immediately.
        time.sleep(REFRESH_INTERVAL_SECONDS)
        STORE.refresh_forever()

    thread = threading.Thread(target=refresh_after_interval, daemon=True, name="whitelist-refresh")
    thread.start()
    # The context manager closes the listening socket if serve_forever() exits.
    with ThreadingHTTPServer((LISTEN_ADDRESS, LISTEN_PORT), AccessGateHandler) as server:
        logging.info("access gate listening on %s:%s", LISTEN_ADDRESS, LISTEN_PORT)
        server.serve_forever()


if __name__ == "__main__":
    main()