412 lines
13 KiB
Python
412 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
import base64
|
|
import hashlib
|
|
import ipaddress
|
|
import json
|
|
import logging
|
|
import os
|
|
import secrets
|
|
import socket
|
|
import threading
|
|
import time
|
|
import urllib.request
|
|
from urllib.parse import parse_qs, urlparse
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
|
|
|
|
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
|
logging.basicConfig(
|
|
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
|
format="%(asctime)s %(levelname)s %(message)s",
|
|
)
|
|
|
|
def _int_env(name, default):
    """Return environment variable *name* parsed as ``int``.

    *default* is the string used when the variable is unset; a non-integer
    value raises ``ValueError`` at import time, so misconfiguration fails fast.
    """
    return int(os.getenv(name, default))


# Remote text file listing allowlisted IPs, CIDRs and hostnames, one per line.
WHITELIST_URL = os.getenv("WHITELIST_URL", "https://example.com/network-whitelist.txt")
# Seconds to wait between background re-fetches of WHITELIST_URL.
REFRESH_INTERVAL_SECONDS = _int_env("REFRESH_INTERVAL_SECONDS", "300")
# Address and port the HTTP server binds to.
LISTEN_ADDRESS = os.getenv("LISTEN_ADDRESS", "0.0.0.0")
LISTEN_PORT = _int_env("LISTEN_PORT", "8080")
# Realm advertised in the WWW-Authenticate challenge of 401 responses.
BASIC_AUTH_REALM = os.getenv("BASIC_AUTH_REALM", "Protected Area")
# Which forwarded-for entry is the client: "leftmost"; any other value means rightmost.
CLIENT_IP_STRATEGY = os.getenv("CLIENT_IP_STRATEGY", "rightmost").lower()
# Comma- and/or newline-separated CIDRs added on top of the remote list.
EXTRA_CIDRS = os.getenv("WHITELIST_EXTRA_CIDRS", "")
# Highest N probed for BASIC_AUTH_USER_N (the unsuffixed variables are user 1).
MAX_BASIC_AUTH_USERS = 3
|
|
|
|
|
|
class WhitelistStore:
    """Thread-safe container for the current set of allowlisted networks.

    The set is rebuilt by :meth:`refresh_once` from the remote list at
    ``WHITELIST_URL`` (IP literals, CIDRs or hostnames, one per line, ``#``
    starts a comment) plus the CIDRs in ``EXTRA_CIDRS``. Readers
    (:meth:`contains`, :meth:`matches`, :meth:`snapshot`) and the refresher
    synchronise on one re-entrant lock. A failed refresh keeps the previous
    networks and records the error text for ``/status``.
    """

    def __init__(self):
        self._lock = threading.RLock()  # guards every attribute below
        self._networks = []             # sorted, de-duplicated ip_network objects
        self._source_hosts = []         # list entries that produced >= 1 network
        self._resolved_entries = []     # [{"entry": str, "networks": [str, ...]}]
        self._last_refresh = 0          # Unix time of the last successful refresh
        self._last_error = ""           # text of the last failed refresh; "" after a success

    def snapshot(self, include_details=False):
        """Return a JSON-serialisable summary of the whitelist state.

        Always contains ``network_count``, ``source_hosts``, ``last_refresh``
        and ``last_error``. With *include_details* it also lists every network
        and the per-entry resolution results. Lists are copied under the lock
        so callers never share mutable state with the store.
        """
        with self._lock:
            data = {
                "network_count": len(self._networks),
                "source_hosts": list(self._source_hosts),
                "last_refresh": self._last_refresh,
                "last_error": self._last_error,
            }
            if include_details:
                data["networks"] = [str(network) for network in self._networks]
                data["resolved_entries"] = [
                    {"entry": item["entry"], "networks": list(item["networks"])}
                    for item in self._resolved_entries
                ]
        return data

    @staticmethod
    def _normalize_ip(ip):
        """Unwrap an IPv4-mapped IPv6 address (``::ffff:a.b.c.d``) to IPv4.

        ``ipaddress`` never reports an IPv6 address as inside an IPv4 network,
        so without this a client seen as ``::ffff:10.0.0.1`` (e.g. via a
        dual-stack proxy's X-Forwarded-For) would not match ``10.0.0.0/8``.
        """
        mapped = getattr(ip, "ipv4_mapped", None)  # IPv4Address has no such attribute
        return mapped if mapped is not None else ip

    def contains(self, ip):
        """Return True if *ip* (an ``ipaddress`` address) is in any allowlisted network."""
        ip = self._normalize_ip(ip)
        with self._lock:
            return any(ip in network for network in self._networks)

    def matches(self, ip):
        """Return the string form of every allowlisted network containing *ip*."""
        ip = self._normalize_ip(ip)
        with self._lock:
            return [str(network) for network in self._networks if ip in network]

    def refresh_forever(self):
        """Refresh immediately, then every ``REFRESH_INTERVAL_SECONDS``; never returns."""
        # time.sleep() raises on negative values (which would end this thread
        # for good) and 0 would hammer WHITELIST_URL in a tight loop, so clamp.
        interval = max(1, REFRESH_INTERVAL_SECONDS)
        while True:
            self.refresh_once()
            time.sleep(interval)

    def refresh_once(self):
        """Rebuild the whitelist once; on any error keep the previous one.

        The broad ``except`` is deliberate: this runs in the background refresh
        loop, where an escaping exception would end the thread. The failure is
        logged with its traceback and exposed via ``snapshot()["last_error"]``.
        """
        try:
            entries = self._load_entries()
            networks = []
            source_hosts = []
            resolved_entries = []

            for entry in entries:
                resolved = sorted(set(resolve_entry(entry)), key=network_sort_key)
                resolved_entries.append({
                    "entry": entry,
                    "networks": [str(network) for network in resolved],
                })
                # NOTE(review): an entry whose lookup fails contributes nothing,
                # so a transient DNS outage drops those hosts until next refresh.
                if resolved:
                    networks.extend(resolved)
                    source_hosts.append(entry)

            networks.extend(parse_extra_cidrs(EXTRA_CIDRS))
            networks = sorted(set(networks), key=network_sort_key)

            # Swap all state in one critical section so readers never observe
            # a mix of old and new data.
            with self._lock:
                self._networks = networks
                self._source_hosts = source_hosts
                self._resolved_entries = resolved_entries
                self._last_refresh = int(time.time())
                self._last_error = ""

            logging.info("whitelist refreshed: %s networks from %s source entries", len(networks), len(source_hosts))
        except Exception as exc:
            with self._lock:
                self._last_error = str(exc)
            logging.exception("whitelist refresh failed; keeping previous whitelist")

    def _load_entries(self):
        """Download ``WHITELIST_URL`` and return its entries (see :meth:`_parse_entries`).

        Network and HTTP errors propagate to :meth:`refresh_once`.
        """
        with urllib.request.urlopen(WHITELIST_URL, timeout=20) as response:
            raw = response.read()
        return self._parse_entries(raw)

    @staticmethod
    def _parse_entries(raw):
        """Decode a whitelist body (bytes) and return its non-empty entries.

        ``utf-8-sig`` drops a leading byte-order mark, which plain ``utf-8``
        would leave glued to the first entry (``str.strip`` does not remove
        U+FEFF), making that entry unparseable. Text after ``#`` is a comment;
        blank lines are skipped.
        """
        text = raw.decode("utf-8-sig")
        entries = []
        for raw_line in text.splitlines():
            line = raw_line.split("#", 1)[0].strip()
            if line:
                entries.append(line)
        return entries
|
|
|
|
|
|
def network_sort_key(network):
    """Sort key for ``ipaddress`` networks.

    Orders IPv4 before IPv6, then by base address, then by prefix length
    (so a wider network sorts before a narrower one at the same address).
    """
    base_address = int(network.network_address)
    return network.version, base_address, network.prefixlen
|
|
|
|
|
|
def query_flag(query, name):
    """Return True when the first value of query parameter *name* is truthy.

    *query* is a ``parse_qs``-style mapping of name -> list of values.
    Truthy values (case-insensitive): 1, true, yes, on, full. A missing
    parameter is False.
    """
    truthy = ("1", "true", "yes", "on", "full")
    first_value = query.get(name, [""])[0]
    return first_value.lower() in truthy
|
|
|
|
|
|
def parse_extra_cidrs(value):
    """Parse a comma- and/or newline-separated list of CIDRs.

    Host bits are accepted (``strict=False``), so ``10.0.0.5/8`` yields
    ``10.0.0.0/8``. Blank items are skipped; invalid items are logged and
    dropped. Networks are returned in input order, duplicates included.
    """
    tokens = (piece.strip() for piece in value.replace("\n", ",").split(","))
    result = []
    for token in filter(None, tokens):
        try:
            parsed = ipaddress.ip_network(token, strict=False)
        except ValueError:
            logging.warning("ignored invalid extra CIDR: %s", token)
            continue
        result.append(parsed)
    return result
|
|
|
|
|
|
def resolve_entry(entry):
    """Turn one whitelist entry into a list of ``ipaddress`` networks.

    An IP or CIDR literal (host bits allowed) becomes a single network.
    Anything else is treated as a hostname and resolved with ``getaddrinfo``;
    each IPv4/IPv6 answer becomes a /32 or /128 host network. Lookup failures
    are logged and yield ``[]``, so one bad entry cannot abort the whole
    refresh. The order of resolved networks is unspecified (callers sort).
    """
    try:
        return [ipaddress.ip_network(entry, strict=False)]
    except ValueError:
        pass  # not an IP/CIDR literal; fall through to DNS

    try:
        answers = socket.getaddrinfo(entry, None, type=socket.SOCK_STREAM)
    except (OSError, UnicodeError) as exc:
        # socket.gaierror (an OSError) covers unknown names; UnicodeError comes
        # from IDNA-encoding malformed names such as "a..b" or labels > 63 chars.
        logging.warning("DNS lookup failed for %s: %s", entry, exc)
        return []

    host_prefix = {socket.AF_INET: 32, socket.AF_INET6: 128}
    networks = set()
    for family, _, _, _, sockaddr in answers:
        prefix = host_prefix.get(family)
        if prefix is None:
            continue  # only IPv4/IPv6 answers are meaningful here
        try:
            ip = ipaddress.ip_address(sockaddr[0])
        except ValueError:
            logging.warning("ignored unparseable address %r for %s", sockaddr[0], entry)
            continue
        networks.add(ipaddress.ip_network(f"{ip}/{prefix}", strict=False))

    return list(networks)
|
|
|
|
|
|
def request_client_ip(handler):
    """Determine the client IP address for *handler*'s request.

    Precedence:

    1. The header named by the ``CLIENT_IP_HEADER`` env var (if set and
       non-empty in the request), otherwise ``X-Forwarded-For``. This is a
       comma-separated list; invalid items are logged and skipped, and
       ``CLIENT_IP_STRATEGY`` selects the leftmost or (default) rightmost
       valid item.
    2. ``X-Real-IP``, if present and parseable.
    3. The TCP peer address.

    NOTE(review): forwarded headers are trusted as sent, so this is only safe
    when the service is reachable exclusively through a proxy that sets them.
    """
    headers = handler.headers
    custom_header = os.getenv("CLIENT_IP_HEADER", "").strip()
    forwarded = (headers.get(custom_header) if custom_header else "") or headers.get("X-Forwarded-For", "")

    valid_ips = []
    for token in (chunk.strip() for chunk in forwarded.split(",")):
        if not token:
            continue
        try:
            valid_ips.append(ipaddress.ip_address(token))
        except ValueError:
            logging.warning("ignored invalid forwarded IP: %s", token)

    if valid_ips:
        position = 0 if CLIENT_IP_STRATEGY == "leftmost" else -1
        return valid_ips[position]

    real_ip_text = headers.get("X-Real-IP", "").strip()
    if real_ip_text:
        try:
            return ipaddress.ip_address(real_ip_text)
        except ValueError:
            logging.warning("ignored invalid X-Real-IP: %s", real_ip_text)

    peer_host = handler.client_address[0]
    return ipaddress.ip_address(peer_host)
|
|
|
|
|
|
def authorization_scheme(header_value):
    """Return the lower-cased scheme token of an ``Authorization`` header value.

    Returns ``"none"`` when the value is missing, empty, or whitespace-only.
    The header parser strips only spaces and tabs, so a value such as
    ``"\\x0b"`` can reach this function; ``str.split()`` then yields no tokens,
    and indexing them blindly would raise ``IndexError`` while building log
    context. Only the scheme is returned, never the credentials.
    """
    tokens = (header_value or "").split(None, 1)
    return tokens[0].lower() if tokens else "none"
|
|
|
|
|
|
def configured_basic_auth_users():
    """Read the BasicAuth users configured in the environment.

    User 1 comes from ``BASIC_AUTH_USER`` / ``BASIC_AUTH_PASSWORD`` /
    ``BASIC_AUTH_PASSWORD_SHA256``; users 2..``MAX_BASIC_AUTH_USERS`` come from
    the same names suffixed with ``_N``. Users with an empty (after strip)
    username are omitted, and a suffixed user whose name equals user 1's is
    skipped. Returns a list of dicts with ``username``, ``password`` and
    ``password_sha256`` keys; password values are not validated here.
    """
    def read_user(user_var, password_var, sha256_var):
        # One user's settings exactly as found in the environment.
        return {
            "username": os.getenv(user_var, "").strip(),
            "password": os.getenv(password_var, ""),
            "password_sha256": os.getenv(sha256_var, ""),
        }

    primary = read_user("BASIC_AUTH_USER", "BASIC_AUTH_PASSWORD", "BASIC_AUTH_PASSWORD_SHA256")
    result = [primary] if primary["username"] else []

    for n in range(2, MAX_BASIC_AUTH_USERS + 1):
        extra = read_user(
            f"BASIC_AUTH_USER_{n}",
            f"BASIC_AUTH_PASSWORD_{n}",
            f"BASIC_AUTH_PASSWORD_SHA256_{n}",
        )
        if extra["username"] and extra["username"] != primary["username"]:
            result.append(extra)

    return result
|
|
|
|
|
|
def request_log_context(handler):
    """Gather the request details that are included in /auth decision log lines.

    Header-derived values default to ``""`` when the header is absent.
    ``remote`` is the TCP peer address, and ``auth_scheme`` is only the scheme
    token as returned by ``authorization_scheme`` (not the credentials).
    """
    header = handler.headers.get
    forwarded_host = header("X-Forwarded-Host", "")
    forwarded_uri = header("X-Forwarded-Uri", "")
    forwarded_for = header("X-Forwarded-For", "")
    real_ip = header("X-Real-IP", "")
    peer = handler.client_address[0]
    scheme = authorization_scheme(header("Authorization", ""))
    return dict(
        host=forwarded_host,
        uri=forwarded_uri,
        xff=forwarded_for,
        xreal=real_ip,
        remote=peer,
        auth_scheme=scheme,
    )
|
|
|
|
|
|
def basic_auth_valid(header_value, users=None):
    """Validate an HTTP Basic ``Authorization`` header against configured users.

    Args:
        header_value: raw ``Authorization`` header value (may be empty).
        users: optional list of user dicts with ``username``, ``password`` and
            ``password_sha256`` keys. Defaults to ``configured_basic_auth_users()``.

    Returns:
        The matching username on success, otherwise ``False``.

    A user matches on a plaintext password or, when ``password_sha256`` is
    set, on the hex SHA-256 of the supplied password. The configured digest is
    compared case-insensitively, ignoring surrounding whitespace. Users with
    neither password form configured are skipped with an error log.

    All comparisons use ``secrets.compare_digest`` on UTF-8 bytes.
    ``compare_digest`` only accepts ASCII ``str`` and raises ``TypeError``
    otherwise, so byte comparison keeps non-ASCII credentials from crashing
    the request and lets them authenticate.
    """
    if users is None:
        users = configured_basic_auth_users()
    if not users:
        logging.error("no BasicAuth users are configured")
        return False

    if not header_value or not header_value.lower().startswith("basic "):
        return False

    try:
        decoded = base64.b64decode(header_value[6:].strip(), validate=True).decode("utf-8")
    except ValueError:
        # binascii.Error, UnicodeDecodeError and non-ASCII input all derive from ValueError.
        return False
    username, separator, password = decoded.partition(":")
    if not separator:
        return False  # malformed credentials: no "user:password" separator

    supplied_user = username.encode("utf-8")
    supplied_password = password.encode("utf-8")

    for user in users:
        expected_password = user.get("password") or ""
        expected_sha256 = (user.get("password_sha256") or "").strip().lower()
        if not expected_password and not expected_sha256:
            logging.error("BasicAuth user %s has no password configured", user["username"])
            continue

        # Evaluate both comparisons before combining them so the work done does
        # not depend on whether the username matched.
        user_ok = secrets.compare_digest(supplied_user, user["username"].encode("utf-8"))
        if expected_sha256:
            digest = hashlib.sha256(supplied_password).hexdigest().encode("ascii")
            password_ok = secrets.compare_digest(digest, expected_sha256.encode("utf-8"))
        else:
            password_ok = secrets.compare_digest(supplied_password, expected_password.encode("utf-8"))

        if user_ok and password_ok:
            return user["username"]

    return False
|
|
|
|
|
|
# Module-wide whitelist shared by every request-handler thread. main() fills
# it with an initial refresh_once() call and then keeps it current from a
# background refresh_forever() thread.
STORE = WhitelistStore()
|
|
|
|
|
|
class AccessGateHandler(BaseHTTPRequestHandler):
    """HTTP handler implementing the access-gate endpoints.

    Routes are matched by path prefix:

    * ``/healthz``: liveness probe, always ``200 {"status": "ok"}``.
    * ``/status``: whitelist snapshot; ``?verbose=1`` adds per-network detail.
    * ``/check``: whether ``?ip=`` (or the request's client IP) is
      allowlisted, and which networks match.
    * ``/auth``: the gate decision. Returns ``204`` for allowlisted IPs or
      valid BasicAuth credentials, otherwise ``401`` with a Basic challenge.
    * anything else: ``404``.
    """

    server_version = "external-whitelist-auth-gate/1.0"

    # Shared format of the /auth decision log lines; the first field is the
    # outcome ("allowlisted", "basic-auth ok" or "auth required").
    _DECISION_LOG_FORMAT = "%s ip=%s host=%s uri=%s xff=%s xreal=%s remote=%s auth_scheme=%s"

    def do_GET(self):
        self.handle_request()

    def do_HEAD(self):
        # Same routing as GET, but headers only.
        self.handle_request(include_body=False)

    def do_POST(self):
        # The request body is never read; POST is routed exactly like GET.
        self.handle_request()

    def handle_request(self, include_body=True):
        """Dispatch the request to the endpoint selected by its path prefix."""
        parsed_url = urlparse(self.path)
        path = parsed_url.path
        query = parse_qs(parsed_url.query)

        if path.startswith("/healthz"):
            self.respond(200, {"status": "ok"}, include_body)
            return

        if path.startswith("/status"):
            self.respond(200, STORE.snapshot(include_details=query_flag(query, "verbose")), include_body)
            return

        if path.startswith("/check"):
            self._handle_check(query, include_body)
            return

        if not path.startswith("/auth"):
            self.respond(404, {"error": "not found"}, include_body)
            return

        self._handle_auth(include_body)

    def _handle_check(self, query, include_body):
        """Report whether ``?ip=`` (or the detected client IP) is allowlisted."""
        ip_source = "request"
        ip_text = query.get("ip", [""])[0].strip()
        try:
            if ip_text:
                client_ip = ipaddress.ip_address(ip_text)
                ip_source = "query"
            else:
                client_ip = request_client_ip(self)
        except ValueError as exc:
            self.respond(400, {"error": str(exc)}, include_body)
            return

        matches = STORE.matches(client_ip)
        self.respond(200, {
            "allowlisted": bool(matches),
            "ip": str(client_ip),
            "ip_source": ip_source,
            "matched_networks": matches,
        }, include_body)

    def _handle_auth(self, include_body):
        """Answer a forward-auth request: 204 to allow, 401 to challenge."""
        client_ip = request_client_ip(self)
        log_context = request_log_context(self)

        if STORE.contains(client_ip):
            self._log_decision("allowlisted", client_ip, log_context)
            self.send_response(204)
            self.send_header("X-Access-Gate-Reason", "allowlisted")
            self.end_headers()
            return

        authenticated_user = basic_auth_valid(self.headers.get("Authorization", ""))
        if authenticated_user:
            self._log_decision("basic-auth ok", client_ip, log_context)
            self.send_response(204)
            self.send_header("X-Access-Gate-Reason", "basic-auth")
            self.send_header("X-Access-Gate-User", authenticated_user)
            self.end_headers()
            return

        self._log_decision("auth required", client_ip, log_context)
        body = b"Authentication required\n"
        self.send_response(401)
        self.send_header("WWW-Authenticate", f'Basic realm="{BASIC_AUTH_REALM}"')
        self.send_header("Cache-Control", "no-store")
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        # Declare the body length like respond() does (also for HEAD).
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if include_body:
            self.wfile.write(body)

    def _log_decision(self, outcome, client_ip, log_context):
        """Log one /auth outcome together with the request's forwarding context."""
        logging.info(
            self._DECISION_LOG_FORMAT,
            outcome,
            client_ip,
            log_context["host"],
            log_context["uri"],
            log_context["xff"],
            log_context["xreal"],
            log_context["remote"],
            log_context["auth_scheme"],
        )

    def respond(self, status, payload, include_body=True):
        """Send *payload* as a JSON response with the given HTTP *status*.

        Headers (including Content-Length) are always sent; the body only
        when *include_body* is true (False for HEAD).
        """
        body = json.dumps(payload, sort_keys=True).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if include_body:
            self.wfile.write(body)

    def log_message(self, fmt, *args):
        # Route the default per-request access log to logging at DEBUG level.
        logging.debug("%s - %s", self.address_string(), fmt % args)
|
|
|
|
|
|
def main():
    """Load the whitelist, start periodic refreshes, and serve HTTP until interrupted.

    The synchronous first refresh means the server starts with the remote list
    already loaded (or an empty whitelist if that fetch fails). The refresher
    is a daemon thread, so it does not keep the process alive. The listening
    socket is closed when serving stops.
    """
    STORE.refresh_once()
    # NOTE(review): refresh_forever() refreshes once before its first sleep,
    # so the list is fetched twice at startup.
    refresher = threading.Thread(target=STORE.refresh_forever, name="whitelist-refresh", daemon=True)
    refresher.start()

    with ThreadingHTTPServer((LISTEN_ADDRESS, LISTEN_PORT), AccessGateHandler) as server:
        logging.info("access gate listening on %s:%s", LISTEN_ADDRESS, LISTEN_PORT)
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            logging.info("access gate shutting down")


if __name__ == "__main__":
    main()
|