feat!: Migration auf Go-Binary

BREAKING CHANGE: Die alte Shell-Version muss vor der Installation der Go-Version deinstalliert werden.
This commit is contained in:
Patrick Asmus
2026-05-01 00:08:57 +02:00
parent 0d1f7db43b
commit 4f17f7ff81
50 changed files with 8012 additions and 9496 deletions

View File

@@ -0,0 +1,5 @@
package appinfo
var Version = "v1.0.0"
const ProjectURL = "https://tnvs.de/as"

295
internal/config/config.go Normal file
View File

@@ -0,0 +1,295 @@
package config
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
)
type Config struct {
Path string
AdGuardURL string
AdGuardUser string
AdGuardPass string
RateLimitMaxRequests int
RateLimitWindow int
CheckInterval int
APIQueryLimit int
SubdomainFloodEnabled bool
SubdomainFloodMaxUnique int
SubdomainFloodWindow int
DNSFloodWatchlistEnabled bool
DNSFloodWatchlist []string
BanDuration int64
Chain string
BlockedPorts []string
FirewallBackend string
FirewallMode string
DryRun bool
Whitelist []string
LogFile string
LogLevel string
StateDir string
PIDFile string
NotifyEnabled bool
NotifyType string
NotifyWebhook string
NTFYServerURL string
NTFYTopic string
NTFYToken string
NTFYPriority string
ReportEnabled bool
ReportInterval string
ReportTime string
ReportEmailTo string
ReportEmailFrom string
ReportFormat string
ReportMailCmd string
ReportBusiestDayRange int
ExternalWhitelistEnabled bool
ExternalWhitelistURLs []string
ExternalWhitelistInterval int
ExternalWhitelistCacheDir string
ExternalBlocklistEnabled bool
ExternalBlocklistURLs []string
ExternalBlocklistInterval int
ExternalBlocklistCacheDir string
ExternalBlocklistDuration int64
ExternalBlocklistAutoUnban bool
ExternalBlocklistNotify bool
ProgressiveBanEnabled bool
ProgressiveBanMultiplier int
ProgressiveBanMaxLevel int
ProgressiveBanResetAfter int64
AbuseIPDBEnabled bool
AbuseIPDBAPIKey string
AbuseIPDBCategories string
GeoIPEnabled bool
GeoIPMode string
GeoIPCountries []string
GeoIPNotify bool
GeoIPSkipPrivate bool
GeoIPLicenseKey string
GeoIPMMDBPath string
GeoIPCacheTTL int64
GeoIPCheckInterval int
}
func Load(path string) (*Config, error) {
values, err := parseFile(path)
if err != nil {
return nil, err
}
c := &Config{Path: path}
c.AdGuardURL = stringVal(values, "ADGUARD_URL", "")
c.AdGuardUser = stringVal(values, "ADGUARD_USER", "")
c.AdGuardPass = stringVal(values, "ADGUARD_PASS", "")
c.RateLimitMaxRequests = intVal(values, "RATE_LIMIT_MAX_REQUESTS", 30)
c.RateLimitWindow = intVal(values, "RATE_LIMIT_WINDOW", 60)
c.CheckInterval = intVal(values, "CHECK_INTERVAL", 10)
c.APIQueryLimit = intVal(values, "API_QUERY_LIMIT", 500)
c.SubdomainFloodEnabled = boolVal(values, "SUBDOMAIN_FLOOD_ENABLED", true)
c.SubdomainFloodMaxUnique = intVal(values, "SUBDOMAIN_FLOOD_MAX_UNIQUE", 50)
c.SubdomainFloodWindow = intVal(values, "SUBDOMAIN_FLOOD_WINDOW", 60)
c.DNSFloodWatchlistEnabled = boolVal(values, "DNS_FLOOD_WATCHLIST_ENABLED", false)
c.DNSFloodWatchlist = csv(values["DNS_FLOOD_WATCHLIST"])
c.BanDuration = int64(intVal(values, "BAN_DURATION", 3600))
c.Chain = stringVal(values, "IPTABLES_CHAIN", "ADGUARD_SHIELD")
c.BlockedPorts = fields(stringVal(values, "BLOCKED_PORTS", "53 443 853"))
c.FirewallBackend = stringVal(values, "FIREWALL_BACKEND", "ipset")
c.FirewallMode = strings.ToLower(strings.TrimSpace(stringVal(values, "FIREWALL_MODE", "host")))
c.DryRun = boolVal(values, "DRY_RUN", false)
if strings.EqualFold(os.Getenv("DRY_RUN"), "true") || os.Getenv("DRY_RUN") == "1" {
c.DryRun = true
}
c.Whitelist = csv(values["WHITELIST"])
c.LogFile = stringVal(values, "LOG_FILE", "/var/log/adguard-shield.log")
c.LogLevel = stringVal(values, "LOG_LEVEL", "INFO")
c.StateDir = stringVal(values, "STATE_DIR", "/var/lib/adguard-shield")
c.PIDFile = stringVal(values, "PID_FILE", "/var/run/adguard-shield.pid")
c.NotifyEnabled = boolVal(values, "NOTIFY_ENABLED", false)
c.NotifyType = stringVal(values, "NOTIFY_TYPE", "ntfy")
c.NotifyWebhook = stringVal(values, "NOTIFY_WEBHOOK_URL", "")
c.NTFYServerURL = stringVal(values, "NTFY_SERVER_URL", "https://ntfy.sh")
c.NTFYTopic = stringVal(values, "NTFY_TOPIC", "")
c.NTFYToken = stringVal(values, "NTFY_TOKEN", "")
c.NTFYPriority = stringVal(values, "NTFY_PRIORITY", "4")
c.ReportEnabled = boolVal(values, "REPORT_ENABLED", false)
c.ReportInterval = stringVal(values, "REPORT_INTERVAL", "weekly")
c.ReportTime = stringVal(values, "REPORT_TIME", "08:00")
c.ReportEmailTo = stringVal(values, "REPORT_EMAIL_TO", "admin@example.com")
c.ReportEmailFrom = stringVal(values, "REPORT_EMAIL_FROM", "adguard-shield@example.com")
c.ReportFormat = strings.ToLower(stringVal(values, "REPORT_FORMAT", "html"))
c.ReportMailCmd = stringVal(values, "REPORT_MAIL_CMD", "msmtp")
c.ReportBusiestDayRange = intVal(values, "REPORT_BUSIEST_DAY_RANGE", 30)
c.ExternalWhitelistEnabled = boolVal(values, "EXTERNAL_WHITELIST_ENABLED", false)
c.ExternalWhitelistURLs = csv(values["EXTERNAL_WHITELIST_URLS"])
c.ExternalWhitelistInterval = intVal(values, "EXTERNAL_WHITELIST_INTERVAL", 300)
c.ExternalWhitelistCacheDir = stringVal(values, "EXTERNAL_WHITELIST_CACHE_DIR", filepath.Join(c.StateDir, "external-whitelist"))
c.ExternalBlocklistEnabled = boolVal(values, "EXTERNAL_BLOCKLIST_ENABLED", false)
c.ExternalBlocklistURLs = csv(values["EXTERNAL_BLOCKLIST_URLS"])
c.ExternalBlocklistInterval = intVal(values, "EXTERNAL_BLOCKLIST_INTERVAL", 300)
c.ExternalBlocklistCacheDir = stringVal(values, "EXTERNAL_BLOCKLIST_CACHE_DIR", filepath.Join(c.StateDir, "external-blocklist"))
c.ExternalBlocklistDuration = int64(intVal(values, "EXTERNAL_BLOCKLIST_BAN_DURATION", 0))
c.ExternalBlocklistAutoUnban = boolVal(values, "EXTERNAL_BLOCKLIST_AUTO_UNBAN", true)
c.ExternalBlocklistNotify = boolVal(values, "EXTERNAL_BLOCKLIST_NOTIFY", false)
c.ProgressiveBanEnabled = boolVal(values, "PROGRESSIVE_BAN_ENABLED", true)
c.ProgressiveBanMultiplier = intVal(values, "PROGRESSIVE_BAN_MULTIPLIER", 2)
c.ProgressiveBanMaxLevel = intVal(values, "PROGRESSIVE_BAN_MAX_LEVEL", 5)
c.ProgressiveBanResetAfter = int64(intVal(values, "PROGRESSIVE_BAN_RESET_AFTER", 86400))
c.AbuseIPDBEnabled = boolVal(values, "ABUSEIPDB_ENABLED", false)
c.AbuseIPDBAPIKey = stringVal(values, "ABUSEIPDB_API_KEY", "")
c.AbuseIPDBCategories = stringVal(values, "ABUSEIPDB_CATEGORIES", "4")
c.GeoIPEnabled = boolVal(values, "GEOIP_ENABLED", false)
c.GeoIPMode = strings.ToLower(stringVal(values, "GEOIP_MODE", "blocklist"))
c.GeoIPCountries = upperCSV(values["GEOIP_COUNTRIES"])
c.GeoIPNotify = boolVal(values, "GEOIP_NOTIFY", true)
c.GeoIPSkipPrivate = boolVal(values, "GEOIP_SKIP_PRIVATE", true)
c.GeoIPLicenseKey = stringVal(values, "GEOIP_LICENSE_KEY", "")
c.GeoIPMMDBPath = stringVal(values, "GEOIP_MMDB_PATH", "")
c.GeoIPCacheTTL = int64(intVal(values, "GEOIP_CACHE_TTL", 86400))
c.GeoIPCheckInterval = intVal(values, "GEOIP_CHECK_INTERVAL", 0)
return c, nil
}
func DefaultPath() string {
if v := os.Getenv("ADGUARD_SHIELD_CONFIG"); v != "" {
return v
}
if _, err := os.Stat("/opt/adguard-shield/adguard-shield.conf"); err == nil {
return "/opt/adguard-shield/adguard-shield.conf"
}
return filepath.Join(".", "adguard-shield.conf")
}
func (c *Config) DBPath() string { return filepath.Join(c.StateDir, "adguard-shield.db") }
func (c *Config) GeoIPDir(scriptDir string) string { return filepath.Join(scriptDir, "geoip") }
func parseFile(path string) (map[string]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open config %s: %w", path, err)
}
defer f.Close()
out := map[string]string{}
sc := bufio.NewScanner(f)
for sc.Scan() {
line := strings.TrimSpace(sc.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
idx := strings.Index(line, "=")
if idx < 1 {
continue
}
key := strings.TrimSpace(line[:idx])
val := stripInlineComment(strings.TrimSpace(line[idx+1:]))
out[key] = unquote(val)
}
return out, sc.Err()
}
func stripInlineComment(s string) string {
inSingle, inDouble := false, false
for i, r := range s {
switch r {
case '\'':
if !inDouble {
inSingle = !inSingle
}
case '"':
if !inSingle {
inDouble = !inDouble
}
case '#':
if !inSingle && !inDouble {
if i == 0 || s[i-1] == ' ' || s[i-1] == '\t' {
return strings.TrimSpace(s[:i])
}
}
}
}
return strings.TrimSpace(s)
}
func unquote(s string) string {
if len(s) >= 2 {
if (s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'') {
return s[1 : len(s)-1]
}
}
return s
}
func stringVal(m map[string]string, k, def string) string {
if v, ok := m[k]; ok {
return v
}
return def
}
func intVal(m map[string]string, k string, def int) int {
v, ok := m[k]
if !ok || strings.TrimSpace(v) == "" {
return def
}
n, err := strconv.Atoi(strings.TrimSpace(v))
if err != nil {
return def
}
return n
}
func boolVal(m map[string]string, k string, def bool) bool {
v, ok := m[k]
if !ok {
return def
}
switch strings.ToLower(strings.TrimSpace(v)) {
case "true", "1", "yes", "on":
return true
case "false", "0", "no", "off":
return false
default:
return def
}
}
func csv(s string) []string {
var out []string
for _, p := range strings.Split(s, ",") {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
}
func upperCSV(s string) []string {
parts := csv(s)
for i := range parts {
parts[i] = strings.ToUpper(parts[i])
}
return parts
}
func fields(s string) []string {
out := strings.Fields(s)
if len(out) == 0 {
return []string{"53"}
}
return out
}

View File

@@ -0,0 +1,44 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadParsesShellStyleConfig(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "adguard-shield.conf")
err := os.WriteFile(path, []byte(`
ADGUARD_URL="https://dns.example"
ADGUARD_USER="admin"
ADGUARD_PASS='pa#ss'
CHECK_INTERVAL=7
BLOCKED_PORTS="53 443 853"
FIREWALL_BACKEND="ipset"
FIREWALL_MODE="docker-bridge"
GEOIP_ENABLED=true
GEOIP_MODE="allowlist"
GEOIP_COUNTRIES="DE, us"
GEOIP_CACHE_TTL=123
`), 0600)
if err != nil {
t.Fatal(err)
}
c, err := Load(path)
if err != nil {
t.Fatal(err)
}
if c.AdGuardPass != "pa#ss" {
t.Fatalf("quoted # was not preserved: %q", c.AdGuardPass)
}
if c.CheckInterval != 7 || c.FirewallBackend != "ipset" || c.FirewallMode != "docker-bridge" {
t.Fatalf("unexpected config: %+v", c)
}
if got := c.GeoIPCountries; len(got) != 2 || got[0] != "DE" || got[1] != "US" {
t.Fatalf("unexpected countries: %#v", got)
}
if c.GeoIPCacheTTL != 123 {
t.Fatalf("unexpected GeoIP cache ttl: %d", c.GeoIPCacheTTL)
}
}

1221
internal/daemon/daemon.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,365 @@
package daemon
import (
"context"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"adguard-shield/internal/config"
"adguard-shield/internal/db"
"adguard-shield/internal/firewall"
)
func TestParseListEntry(t *testing.T) {
cases := map[string]string{
"1.2.3.4 # comment": "1.2.3.4",
"0.0.0.0 bad.example": "bad.example",
"2001:db8::/32": "2001:db8::/32",
}
for input, want := range cases {
got := parseListEntry(input)
if len(got) != 1 || got[0] != want {
t.Fatalf("%q -> %#v, want %q", input, got, want)
}
}
if got := parseListEntry("http://example.invalid/list"); got != nil {
t.Fatalf("URL should be rejected: %#v", got)
}
}
func TestNotificationFormatting(t *testing.T) {
d := &Daemon{Config: &config.Config{
RateLimitWindow: 60,
SubdomainFloodWindow: 120,
ProgressiveBanMaxLevel: 3,
}}
b := db.Ban{
IP: "203.0.113.7",
Domain: "abb.com",
Count: 110,
Duration: 3600,
OffenseLevel: 1,
Reason: "rate-limit",
Protocol: "dns",
Source: "monitor",
}
if got, want := d.displayBanReason(b), "110x abb.com in 60s via DNS, Rate-Limit"; got != want {
t.Fatalf("reason = %q, want %q", got, want)
}
if got, want := d.displayBanDuration(b), "1h 0m [Stufe 1/3]"; got != want {
t.Fatalf("duration = %q, want %q", got, want)
}
b.Permanent = true
b.Duration = 0
b.OffenseLevel = 3
if got, want := d.displayBanDuration(b), "PERMANENT [Stufe 3/3]"; got != want {
t.Fatalf("permanent duration = %q, want %q", got, want)
}
}
func TestNTFYNotificationTitleDoesNotDuplicateShieldTag(t *testing.T) {
d := &Daemon{Config: &config.Config{
NotifyType: "ntfy",
NTFYServerURL: "https://ntfy.example",
NTFYTopic: "adguard-shield",
NTFYPriority: "4",
}}
req, err := d.notificationRequest(context.Background(), "🛡️ AdGuard Shield", "test", db.Ban{IP: "203.0.113.7", Reason: "rate-limit", Source: "monitor"})
if err != nil {
t.Fatal(err)
}
if req == nil {
t.Fatal("request must be created")
}
if got, want := req.Header.Get("Title"), "🛡️ AdGuard Shield"; got != want {
t.Fatalf("title = %q, want %q", got, want)
}
if got := req.Header.Get("Tags"); strings.Contains(got, "shield") {
t.Fatalf("tags must not duplicate title shield emoji: %q", got)
}
}
func TestNotificationRequestsForWebhookProviders(t *testing.T) {
cases := []struct {
name string
notifyType string
wantType string
wantPayload []string
}{
{
name: "discord",
notifyType: "discord",
wantType: "application/json",
wantPayload: []string{`"content":"title\n\nmessage"`},
},
{
name: "slack",
notifyType: "slack",
wantType: "application/json",
wantPayload: []string{`"text":"title\n\nmessage"`},
},
{
name: "generic",
notifyType: "generic",
wantType: "application/json",
wantPayload: []string{`"action":"unban"`, `"client":"203.0.113.7"`, `"message":"message"`},
},
{
name: "gotify",
notifyType: "gotify",
wantType: "application/x-www-form-urlencoded",
wantPayload: []string{`title=title`, `message=message`, `priority=5`},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
d := &Daemon{Config: &config.Config{
NotifyType: tc.notifyType,
NotifyWebhook: "https://hooks.example/notify",
}}
req, err := d.notificationRequest(context.Background(), "title", "message", db.Ban{IP: "203.0.113.7", Reason: "manual"})
if err != nil {
t.Fatal(err)
}
if req == nil {
t.Fatal("request must be created")
}
if req.Method != http.MethodPost {
t.Fatalf("method = %s, want POST", req.Method)
}
if got := req.Header.Get("Content-Type"); got != tc.wantType {
t.Fatalf("content type = %q, want %q", got, tc.wantType)
}
body, err := io.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
payload := string(body)
for _, want := range tc.wantPayload {
if !strings.Contains(payload, want) {
t.Fatalf("payload %q does not contain %q", payload, want)
}
}
})
}
}
func TestServiceNotificationsSendStartAndStopOnce(t *testing.T) {
requests := make(chan string, 4)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
requests <- string(body)
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
d := &Daemon{
Config: &config.Config{NotifyEnabled: true, NotifyType: "generic", NotifyWebhook: srv.URL},
Client: srv.Client(),
}
d.NotifyServiceStart(context.Background())
d.NotifyServiceStart(context.Background())
d.NotifyServiceStop(context.Background())
d.NotifyServiceStop(context.Background())
var payloads []string
for len(payloads) < 2 {
select {
case payload := <-requests:
payloads = append(payloads, payload)
case <-time.After(4 * time.Second):
t.Fatalf("service notifications sent %d payloads, want 2", len(payloads))
}
}
if !strings.Contains(payloads[0], `"action":"service_start"`) || !strings.Contains(payloads[0], "gestartet") {
t.Fatalf("unexpected service start payload: %s", payloads[0])
}
if !strings.Contains(payloads[1], `"action":"service_stop"`) || !strings.Contains(payloads[1], "gestoppt") {
t.Fatalf("unexpected service stop payload: %s", payloads[1])
}
select {
case payload := <-requests:
t.Fatalf("duplicate service notification sent: %s", payload)
case <-time.After(150 * time.Millisecond):
}
}
func TestUnbanSendsNotificationForMonitorBan(t *testing.T) {
requests := make(chan string, 1)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
requests <- string(body)
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
store, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
defer store.Close()
if err := store.InsertBan(db.Ban{IP: "127.0.0.1", Reason: "rate-limit", Source: "monitor"}); err != nil {
t.Fatal(err)
}
d := &Daemon{
Config: &config.Config{NotifyEnabled: true, NotifyType: "generic", NotifyWebhook: srv.URL},
Store: store,
FW: firewall.New(firewall.OSExecutor{}, "ADGUARD_SHIELD", []string{"53"}, "host", true),
Client: srv.Client(),
}
if err := d.Unban(context.Background(), "127.0.0.1", "manual"); err != nil {
t.Fatal(err)
}
select {
case payload := <-requests:
if !strings.Contains(payload, `"action":"unban"`) || !strings.Contains(payload, "AdGuard Shield Freigabe") {
t.Fatalf("unexpected payload: %s", payload)
}
case <-time.After(4 * time.Second):
t.Fatal("unban notification was not sent")
}
}
func TestUnbanStillSendsExternalBlocklistNotificationWhenBanNotificationsDisabled(t *testing.T) {
requests := make(chan string, 1)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
requests <- string(body)
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
store, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
defer store.Close()
if err := store.InsertBan(db.Ban{IP: "127.0.0.1", Reason: "external-blocklist", Source: "external-blocklist"}); err != nil {
t.Fatal(err)
}
d := &Daemon{
Config: &config.Config{NotifyEnabled: true, NotifyType: "generic", NotifyWebhook: srv.URL, ExternalBlocklistNotify: false},
Store: store,
FW: firewall.New(firewall.OSExecutor{}, "ADGUARD_SHIELD", []string{"53"}, "host", true),
Client: srv.Client(),
}
if err := d.Unban(context.Background(), "127.0.0.1", "external-blocklist-removed"); err != nil {
t.Fatal(err)
}
select {
case payload := <-requests:
if !strings.Contains(payload, `"action":"unban"`) || !strings.Contains(payload, "AdGuard Shield Freigabe") {
t.Fatalf("unexpected payload: %s", payload)
}
case <-time.After(4 * time.Second):
t.Fatal("external blocklist unban notification was not sent")
}
}
func TestUnbanQuietSkipsIndividualNotificationAndBulkSummarySendsOnce(t *testing.T) {
requests := make(chan string, 2)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
requests <- string(body)
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
store, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
defer store.Close()
if err := store.InsertBan(db.Ban{IP: "127.0.0.1", Reason: "rate-limit", Source: "monitor"}); err != nil {
t.Fatal(err)
}
d := &Daemon{
Config: &config.Config{NotifyEnabled: true, NotifyType: "generic", NotifyWebhook: srv.URL},
Store: store,
FW: firewall.New(firewall.OSExecutor{}, "ADGUARD_SHIELD", []string{"53"}, "host", true),
Client: srv.Client(),
}
if err := d.UnbanQuiet(context.Background(), "127.0.0.1", "manual-flush"); err != nil {
t.Fatal(err)
}
select {
case payload := <-requests:
t.Fatalf("quiet unban sent individual notification: %s", payload)
case <-time.After(150 * time.Millisecond):
}
d.NotifyBulkUnban(context.Background(), "manual-flush", 1)
select {
case payload := <-requests:
if !strings.Contains(payload, `"action":"manual-flush"`) || !strings.Contains(payload, "Bulk-Freigabe") || !strings.Contains(payload, "Freigegebene IPs: 1") {
t.Fatalf("unexpected payload: %s", payload)
}
case <-time.After(4 * time.Second):
t.Fatal("bulk unban notification was not sent")
}
}
func TestAbuseReportingScope(t *testing.T) {
d := &Daemon{Config: &config.Config{AbuseIPDBEnabled: true, AbuseIPDBAPIKey: "key"}}
if !d.shouldReportAbuseIPDB(db.Ban{Permanent: true, Source: "monitor"}) {
t.Fatal("monitor permanent ban should be reported")
}
if d.shouldReportAbuseIPDB(db.Ban{Permanent: true, Source: "geoip"}) {
t.Fatal("geoip ban must not be reported")
}
if d.shouldReportAbuseIPDB(db.Ban{Permanent: false, Source: "monitor"}) {
t.Fatal("temporary ban must not be reported")
}
d.Config.RateLimitWindow = 60
got := d.abuseIPDBComment(db.Ban{Count: 110, Domain: "abb.com", Reason: "rate-limit"})
want := "DNS flooding on our DNS server: 110x abb.com in 60s. Banned by Adguard Shield 🔗 https://tnvs.de/as"
if got != want {
t.Fatalf("comment = %q, want %q", got, want)
}
}
func TestAbuseIPDBCheckURL(t *testing.T) {
if got := abuseIPDBCheckURL("65.185.189.75"); !strings.Contains(got, "https://www.abuseipdb.com/check/65.185.189.75") {
t.Fatalf("unexpected AbuseIPDB url: %s", got)
}
}
func TestBaseDomain(t *testing.T) {
if got := baseDomain("a.b.example.com"); got != "example.com" {
t.Fatalf("unexpected base domain: %s", got)
}
if got := baseDomain("a.b.example.co.uk"); got != "example.co.uk" {
t.Fatalf("unexpected multipart base domain: %s", got)
}
}
func TestDryRunDoesNotInsertActiveBan(t *testing.T) {
store, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
defer store.Close()
d := &Daemon{
Config: &config.Config{DryRun: true, BanDuration: 60},
Store: store,
FW: firewall.New(firewall.OSExecutor{}, "ADGUARD_SHIELD", []string{"53"}, "host", true),
wl: map[string]bool{},
}
if err := d.Ban(context.Background(), "1.2.3.4", "example.com", 99, "dns", "rate-limit", "monitor", "", false); err != nil {
t.Fatal(err)
}
ok, err := store.BanExists("1.2.3.4")
if err != nil {
t.Fatal(err)
}
if ok {
t.Fatal("dry-run must not create an active ban")
}
}

384
internal/daemon/live.go Normal file
View File

@@ -0,0 +1,384 @@
package daemon
import (
"bufio"
"context"
"fmt"
"io"
"os"
"sort"
"strings"
"time"
"adguard-shield/internal/db"
"adguard-shield/internal/syslog"
)
type LiveOptions struct {
Interval time.Duration
Top int
Recent int
LogLevel string
Once bool
}
type liveSnapshot struct {
At time.Time
APIEntries int
Window int
Limit int
Events []queryEvent
TopPairs []liveCount
SubdomainGroups []liveCount
ActiveBans []db.Ban
Offenses int
ExpiredOffenses int
WhitelistCount int
BlocklistBans int
SystemLogs []string
}
type liveCount struct {
Client string
Domain string
Count int
Protocol string
}
func (d *Daemon) Live(ctx context.Context, w io.Writer, opts LiveOptions) error {
if opts.Interval <= 0 {
opts.Interval = time.Duration(d.Config.CheckInterval) * time.Second
}
if opts.Interval <= 0 {
opts.Interval = 2 * time.Second
}
if opts.Top <= 0 {
opts.Top = 10
}
if opts.Recent <= 0 {
opts.Recent = 12
}
if strings.TrimSpace(opts.LogLevel) == "" {
opts.LogLevel = "INFO"
}
for {
snap, err := d.liveSnapshot(ctx, opts)
renderLive(w, d, snap, err, opts)
if opts.Once {
return err
}
timer := time.NewTimer(opts.Interval)
select {
case <-ctx.Done():
timer.Stop()
return ctx.Err()
case <-timer.C:
}
}
}
func (d *Daemon) liveSnapshot(ctx context.Context, opts LiveOptions) (liveSnapshot, error) {
snap := liveSnapshot{
At: time.Now(),
Window: d.Config.RateLimitWindow,
Limit: d.Config.RateLimitMaxRequests,
}
items, err := d.FetchQueryLog(ctx)
if err != nil {
return snap, err
}
snap.APIEntries = len(items)
events := dedupeEvents(d.toEvents(items))
sort.Slice(events, func(i, j int) bool { return events[i].At.After(events[j].At) })
if len(events) > opts.Recent {
snap.Events = append([]queryEvent(nil), events[:opts.Recent]...)
} else {
snap.Events = append([]queryEvent(nil), events...)
}
snap.TopPairs = topQueryPairs(events, d.Config.RateLimitWindow, opts.Top)
snap.SubdomainGroups = topSubdomainGroups(events, d.Config.SubdomainFloodWindow, opts.Top)
if bans, err := d.Store.ActiveBans(); err == nil {
snap.ActiveBans = bans
}
if n, err := d.Store.CountOffenses(); err == nil {
snap.Offenses = n
}
if n, err := d.Store.CountExpiredOffenses(d.Config.ProgressiveBanResetAfter); err == nil {
snap.ExpiredOffenses = n
}
if wl, err := d.Store.AllWhitelist(); err == nil {
snap.WhitelistCount = len(wl)
}
if n, err := d.Store.CountBySource("external-blocklist"); err == nil {
snap.BlocklistBans = n
}
snap.SystemLogs = RecentLogLines(d.Config.LogFile, opts.LogLevel, opts.Recent)
return snap, nil
}
func dedupeEvents(events []queryEvent) []queryEvent {
seen := map[string]bool{}
out := make([]queryEvent, 0, len(events))
for _, ev := range events {
key := ev.At.Format(time.RFC3339Nano) + "|" + ev.Client + "|" + ev.Domain + "|" + ev.Protocol
if seen[key] {
continue
}
seen[key] = true
out = append(out, ev)
}
return out
}
func topQueryPairs(events []queryEvent, window, limit int) []liveCount {
cut := time.Now().Add(-time.Duration(window) * time.Second)
counts := map[string]*liveCount{}
protos := map[string]map[string]bool{}
for _, ev := range events {
if ev.At.Before(cut) {
continue
}
key := ev.Client + "|" + ev.Domain
if counts[key] == nil {
counts[key] = &liveCount{Client: ev.Client, Domain: ev.Domain}
protos[key] = map[string]bool{}
}
counts[key].Count++
protos[key][formatProtocol(ev.Protocol)] = true
}
out := make([]liveCount, 0, len(counts))
for key, item := range counts {
item.Protocol = strings.Join(sortedKeys(protos[key]), ",")
out = append(out, *item)
}
sort.Slice(out, func(i, j int) bool {
if out[i].Count == out[j].Count {
return out[i].Client+"|"+out[i].Domain < out[j].Client+"|"+out[j].Domain
}
return out[i].Count > out[j].Count
})
if limit > 0 && len(out) > limit {
return out[:limit]
}
return out
}
func topSubdomainGroups(events []queryEvent, window, limit int) []liveCount {
cut := time.Now().Add(-time.Duration(window) * time.Second)
sets := map[string]map[string]bool{}
protos := map[string]map[string]bool{}
for _, ev := range events {
if ev.At.Before(cut) {
continue
}
base := baseDomain(ev.Domain)
if base == "" || base == ev.Domain {
continue
}
key := ev.Client + "|" + base
if sets[key] == nil {
sets[key] = map[string]bool{}
protos[key] = map[string]bool{}
}
sets[key][ev.Domain] = true
protos[key][formatProtocol(ev.Protocol)] = true
}
out := make([]liveCount, 0, len(sets))
for key, set := range sets {
client, domain, _ := strings.Cut(key, "|")
out = append(out, liveCount{
Client: client,
Domain: domain,
Count: len(set),
Protocol: strings.Join(sortedKeys(protos[key]), ","),
})
}
sort.Slice(out, func(i, j int) bool {
if out[i].Count == out[j].Count {
return out[i].Client+"|"+out[i].Domain < out[j].Client+"|"+out[j].Domain
}
return out[i].Count > out[j].Count
})
if limit > 0 && len(out) > limit {
return out[:limit]
}
return out
}
func renderLive(w io.Writer, d *Daemon, snap liveSnapshot, snapErr error, opts LiveOptions) {
fmt.Fprint(w, "\033[H\033[2J")
fmt.Fprintf(w, "AdGuard Shield Live | %s | Strg+C beendet\n", snap.At.Format("2006-01-02 15:04:05"))
fmt.Fprintln(w, strings.Repeat("=", 92))
fmt.Fprintf(w, "Config: %s | API: %s | Log: %s (ab %s)\n", d.Config.Path, d.Config.AdGuardURL, d.Config.LogFile, strings.ToUpper(opts.LogLevel))
if snapErr != nil {
fmt.Fprintf(w, "\nFEHLER: Live-Snapshot konnte nicht geladen werden: %v\n", snapErr)
return
}
fmt.Fprintf(w, "\nWorker und Module\n")
fmt.Fprintf(w, " Query-Poller: alle %ds | API-Eintraege: %d | Zeitfenster: %ds | Limit: %d\n", d.Config.CheckInterval, snap.APIEntries, snap.Window, snap.Limit)
fmt.Fprintf(w, " GeoIP: %s | Modus: %s | Laender: %s\n", enabled(d.Config.GeoIPEnabled), d.Config.GeoIPMode, listOrDash(d.Config.GeoIPCountries))
fmt.Fprintf(w, " Externe Blocklist: %s | Intervall: %ds | URLs: %d | aktive Sperren: %d\n", enabled(d.Config.ExternalBlocklistEnabled), d.Config.ExternalBlocklistInterval, len(d.Config.ExternalBlocklistURLs), snap.BlocklistBans)
fmt.Fprintf(w, " Externe Whitelist: %s | Intervall: %ds | URLs: %d | aufgeloeste IPs: %d\n", enabled(d.Config.ExternalWhitelistEnabled), d.Config.ExternalWhitelistInterval, len(d.Config.ExternalWhitelistURLs), snap.WhitelistCount)
fmt.Fprintf(w, " Offense-Cleanup: %s | Zaehler: %d | davon abgelaufen: %d\n", enabled(d.Config.ProgressiveBanEnabled), snap.Offenses, snap.ExpiredOffenses)
fmt.Fprintf(w, "\nTop Client/Domain im Rate-Limit-Fenster\n")
if len(snap.TopPairs) == 0 {
fmt.Fprintln(w, " Keine Anfragen im aktuellen Zeitfenster.")
} else {
for _, item := range snap.TopPairs {
fmt.Fprintf(w, " %5s %-39s %-34s %s\n", fmt.Sprintf("%d/%d", item.Count, snap.Limit), trim(item.Client, 39), trim(item.Domain, 34), item.Protocol)
}
}
if d.Config.SubdomainFloodEnabled {
fmt.Fprintf(w, "\nSubdomain-Flood-Kandidaten\n")
if len(snap.SubdomainGroups) == 0 {
fmt.Fprintln(w, " Keine Subdomain-Gruppen im aktuellen Zeitfenster.")
} else {
for _, item := range snap.SubdomainGroups {
fmt.Fprintf(w, " %5s %-39s %-34s %s\n", fmt.Sprintf("%d/%d", item.Count, d.Config.SubdomainFloodMaxUnique), trim(item.Client, 39), trim(item.Domain, 34), item.Protocol)
}
}
}
fmt.Fprintf(w, "\nLetzte Queries\n")
if len(snap.Events) == 0 {
fmt.Fprintln(w, " Keine Querylog-Eintraege gefunden.")
} else {
for _, ev := range snap.Events {
fmt.Fprintf(w, " %s %-39s %-8s %s\n", ev.At.Local().Format("15:04:05"), trim(ev.Client, 39), formatProtocol(ev.Protocol), trim(ev.Domain, 44))
}
}
fmt.Fprintf(w, "\nAktive Sperren\n")
if len(snap.ActiveBans) == 0 {
fmt.Fprintln(w, " Keine aktiven Sperren.")
} else {
maxBans := opts.Top
if len(snap.ActiveBans) < maxBans {
maxBans = len(snap.ActiveBans)
}
for _, b := range snap.ActiveBans[:maxBans] {
fmt.Fprintf(w, " %-39s %-20s %-18s %s\n", trim(b.IP, 39), trim(b.Source, 20), trim(b.Reason, 18), banUntil(b))
}
if len(snap.ActiveBans) > maxBans {
fmt.Fprintf(w, " ... %d weitere\n", len(snap.ActiveBans)-maxBans)
}
}
if strings.ToLower(opts.LogLevel) != "off" {
fmt.Fprintf(w, "\nSystemereignisse\n")
if len(snap.SystemLogs) == 0 {
fmt.Fprintln(w, " Keine passenden Logeintraege.")
} else {
for _, line := range snap.SystemLogs {
fmt.Fprintf(w, " %s\n", trim(line, 88))
}
}
}
}
func RecentLogLines(path, minLevel string, limit int) []string {
if strings.EqualFold(strings.TrimSpace(minLevel), "off") || path == "" || limit <= 0 {
return nil
}
f, err := os.Open(path)
if err != nil {
return nil
}
defer f.Close()
min := syslog.ParseLevel(minLevel, syslog.Info)
ring := make([]string, limit)
count := 0
sc := bufio.NewScanner(f)
sc.Buffer(make([]byte, 1024), 1024*1024)
for sc.Scan() {
line := sc.Text()
if logLineLevel(line) < min {
continue
}
ring[count%limit] = line
count++
}
n := count
if n > limit {
n = limit
}
out := make([]string, 0, n)
start := count - n
for i := 0; i < n; i++ {
out = append(out, ring[(start+i)%limit])
}
return out
}
func logLineLevel(line string) syslog.Level {
for _, level := range []syslog.Level{syslog.Error, syslog.Warn, syslog.Info, syslog.Debug} {
if strings.Contains(line, "["+syslog.LevelName(level)+"]") {
return level
}
}
return syslog.Info
}
func sortedKeys(m map[string]bool) []string {
keys := make([]string, 0, len(m))
for k := range m {
if k != "" {
keys = append(keys, k)
}
}
sort.Strings(keys)
return keys
}
func formatProtocol(proto string) string {
switch strings.ToLower(strings.TrimSpace(proto)) {
case "doh":
return "DoH"
case "dot":
return "DoT"
case "doq":
return "DoQ"
case "dnscrypt":
return "DNSCrypt"
case "", "dns":
return "DNS"
default:
return proto
}
}
func enabled(ok bool) string {
if ok {
return "aktiv"
}
return "inaktiv"
}
func listOrDash(items []string) string {
if len(items) == 0 {
return "-"
}
return strings.Join(items, ",")
}
func trim(s string, max int) string {
if len(s) <= max {
return s
}
if max <= 1 {
return s[:max]
}
return s[:max-1] + "~"
}
func banUntil(b db.Ban) string {
if b.Permanent || b.BanUntil == 0 {
return "permanent"
}
return time.Unix(b.BanUntil, 0).Format("2006-01-02 15:04:05")
}

408
internal/db/db.go Normal file
View File

@@ -0,0 +1,408 @@
package db
import (
"database/sql"
"fmt"
"time"
_ "modernc.org/sqlite"
)
type Store struct{ DB *sql.DB }
type Ban struct {
IP string
Domain string
Count int
BanUntil int64
Duration int64
OffenseLevel int
Permanent bool
Reason string
Protocol string
Source string
GeoIPCountry string
GeoIPMode string
}
type ReportStats struct {
Since int64
Until int64
TotalBans int
TotalUnbans int
ActiveBans int
TopClients []ReportCount
Reasons []ReportCount
Sources []ReportCount
RecentEvents []string
}
type ReportCount struct {
Name string
Count int
}
func Open(path string) (*Store, error) {
db, err := sql.Open("sqlite", path+"?_pragma=busy_timeout(5000)&_pragma=journal_mode(WAL)")
if err != nil {
return nil, err
}
s := &Store{DB: db}
if err := s.Init(); err != nil {
db.Close()
return nil, err
}
return s, nil
}
func (s *Store) Close() error { return s.DB.Close() }
func (s *Store) Init() error {
schema := `
PRAGMA journal_mode=WAL;
PRAGMA busy_timeout=5000;
CREATE TABLE IF NOT EXISTS schema_version (version INTEGER PRIMARY KEY, applied_at TEXT DEFAULT (datetime('now', 'localtime')));
CREATE TABLE IF NOT EXISTS active_bans (
client_ip TEXT PRIMARY KEY, domain TEXT, count INTEGER, ban_time TEXT,
ban_until_epoch INTEGER DEFAULT 0, ban_duration INTEGER DEFAULT 0, offense_level INTEGER DEFAULT 0,
is_permanent INTEGER DEFAULT 0, reason TEXT DEFAULT 'rate-limit', protocol TEXT DEFAULT 'DNS',
source TEXT DEFAULT 'monitor', geoip_country TEXT, geoip_mode TEXT, created_at TEXT DEFAULT (datetime('now', 'localtime')));
CREATE TABLE IF NOT EXISTS offense_tracking (
client_ip TEXT PRIMARY KEY, offense_level INTEGER DEFAULT 0, last_offense_epoch INTEGER,
last_offense TEXT, first_offense TEXT, created_at TEXT DEFAULT (datetime('now', 'localtime')),
updated_at TEXT DEFAULT (datetime('now', 'localtime')));
CREATE TABLE IF NOT EXISTS ban_history (
id INTEGER PRIMARY KEY AUTOINCREMENT, timestamp_epoch INTEGER NOT NULL, timestamp_text TEXT NOT NULL,
action TEXT NOT NULL, client_ip TEXT NOT NULL, domain TEXT, count TEXT, duration TEXT, protocol TEXT, reason TEXT);
CREATE TABLE IF NOT EXISTS whitelist_cache (ip_address TEXT PRIMARY KEY, source TEXT, resolved_at TEXT DEFAULT (datetime('now', 'localtime')));
CREATE TABLE IF NOT EXISTS geoip_cache (ip TEXT PRIMARY KEY, country_code TEXT NOT NULL, looked_up_at_epoch INTEGER NOT NULL, db_mtime INTEGER DEFAULT 0);
CREATE INDEX IF NOT EXISTS idx_bans_until ON active_bans(ban_until_epoch);
CREATE INDEX IF NOT EXISTS idx_bans_source ON active_bans(source);
CREATE INDEX IF NOT EXISTS idx_bans_reason ON active_bans(reason);
CREATE INDEX IF NOT EXISTS idx_history_timestamp ON ban_history(timestamp_epoch);
CREATE INDEX IF NOT EXISTS idx_history_action ON ban_history(action);
CREATE INDEX IF NOT EXISTS idx_history_ip ON ban_history(client_ip);
CREATE INDEX IF NOT EXISTS idx_offenses_last ON offense_tracking(last_offense_epoch);
CREATE INDEX IF NOT EXISTS idx_geoip_cache_age ON geoip_cache(looked_up_at_epoch);
INSERT OR IGNORE INTO schema_version (version) VALUES (1);`
_, err := s.DB.Exec(schema)
return err
}
func (s *Store) BanExists(ip string) (bool, error) {
var one int
err := s.DB.QueryRow(`SELECT 1 FROM active_bans WHERE client_ip=? LIMIT 1`, ip).Scan(&one)
if err == sql.ErrNoRows {
return false, nil
}
return err == nil, err
}
func (s *Store) InsertBan(b Ban) error {
now := time.Now()
perm := 0
if b.Permanent {
perm = 1
}
_, err := s.DB.Exec(`INSERT OR REPLACE INTO active_bans
(client_ip, domain, count, ban_time, ban_until_epoch, ban_duration, offense_level, is_permanent, reason, protocol, source, geoip_country, geoip_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
b.IP, b.Domain, b.Count, now.Format("2006-01-02 15:04:05"), b.BanUntil, b.Duration, b.OffenseLevel, perm,
b.Reason, b.Protocol, b.Source, b.GeoIPCountry, b.GeoIPMode)
return err
}
func (s *Store) DeleteBan(ip string) error {
_, err := s.DB.Exec(`DELETE FROM active_bans WHERE client_ip=?`, ip)
return err
}
func (s *Store) ActiveBans() ([]Ban, error) {
rows, err := s.DB.Query(`SELECT client_ip, COALESCE(domain,''), COALESCE(count,0), COALESCE(ban_until_epoch,0),
COALESCE(ban_duration,0), COALESCE(offense_level,0), COALESCE(is_permanent,0), COALESCE(reason,''), COALESCE(protocol,''),
COALESCE(source,''), COALESCE(geoip_country,''), COALESCE(geoip_mode,'') FROM active_bans ORDER BY created_at DESC`)
if err != nil {
return nil, err
}
defer rows.Close()
var out []Ban
for rows.Next() {
var b Ban
var perm int
if err := rows.Scan(&b.IP, &b.Domain, &b.Count, &b.BanUntil, &b.Duration, &b.OffenseLevel, &perm, &b.Reason, &b.Protocol, &b.Source, &b.GeoIPCountry, &b.GeoIPMode); err != nil {
return nil, err
}
b.Permanent = perm == 1
out = append(out, b)
}
return out, rows.Err()
}
func (s *Store) BansBySource(source string) ([]Ban, error) {
rows, err := s.DB.Query(`SELECT client_ip, COALESCE(domain,''), COALESCE(count,0), COALESCE(ban_until_epoch,0),
COALESCE(ban_duration,0), COALESCE(offense_level,0), COALESCE(is_permanent,0), COALESCE(reason,''), COALESCE(protocol,''),
COALESCE(source,''), COALESCE(geoip_country,''), COALESCE(geoip_mode,'') FROM active_bans WHERE source=?`, source)
if err != nil {
return nil, err
}
defer rows.Close()
var out []Ban
for rows.Next() {
var b Ban
var perm int
if err := rows.Scan(&b.IP, &b.Domain, &b.Count, &b.BanUntil, &b.Duration, &b.OffenseLevel, &perm, &b.Reason, &b.Protocol, &b.Source, &b.GeoIPCountry, &b.GeoIPMode); err != nil {
return nil, err
}
b.Permanent = perm == 1
out = append(out, b)
}
return out, rows.Err()
}
func (s *Store) BansByReason(reason string) ([]Ban, error) {
rows, err := s.DB.Query(`SELECT client_ip, COALESCE(domain,''), COALESCE(count,0), COALESCE(ban_until_epoch,0),
COALESCE(ban_duration,0), COALESCE(offense_level,0), COALESCE(is_permanent,0), COALESCE(reason,''), COALESCE(protocol,''),
COALESCE(source,''), COALESCE(geoip_country,''), COALESCE(geoip_mode,'') FROM active_bans WHERE reason=?`, reason)
if err != nil {
return nil, err
}
defer rows.Close()
var out []Ban
for rows.Next() {
var b Ban
var perm int
if err := rows.Scan(&b.IP, &b.Domain, &b.Count, &b.BanUntil, &b.Duration, &b.OffenseLevel, &perm, &b.Reason, &b.Protocol, &b.Source, &b.GeoIPCountry, &b.GeoIPMode); err != nil {
return nil, err
}
b.Permanent = perm == 1
out = append(out, b)
}
return out, rows.Err()
}
func (s *Store) CountBySource(source string) (int, error) {
var count int
err := s.DB.QueryRow(`SELECT COUNT(*) FROM active_bans WHERE source=?`, source).Scan(&count)
return count, err
}
func (s *Store) ExpiredBans(now int64) ([]string, error) {
rows, err := s.DB.Query(`SELECT client_ip FROM active_bans WHERE ban_until_epoch > 0 AND is_permanent = 0 AND ban_until_epoch <= ?`, now)
if err != nil {
return nil, err
}
defer rows.Close()
var ips []string
for rows.Next() {
var ip string
if err := rows.Scan(&ip); err != nil {
return nil, err
}
ips = append(ips, ip)
}
return ips, rows.Err()
}
func (s *Store) History(action, ip, domain, count, duration, protocol, reason string) error {
now := time.Now()
_, err := s.DB.Exec(`INSERT INTO ban_history (timestamp_epoch, timestamp_text, action, client_ip, domain, count, duration, protocol, reason)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, now.Unix(), now.Format("2006-01-02 15:04:05"), action, ip, domain, count, duration, protocol, reason)
return err
}
func (s *Store) RecentHistory(limit int) ([]string, error) {
rows, err := s.DB.Query(`SELECT timestamp_text, action, client_ip, COALESCE(domain,''), COALESCE(count,''), COALESCE(duration,''), COALESCE(protocol,''), COALESCE(reason,'')
FROM ban_history ORDER BY id DESC LIMIT ?`, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var out []string
for rows.Next() {
var ts, action, ip, domain, count, duration, proto, reason string
if err := rows.Scan(&ts, &action, &ip, &domain, &count, &duration, &proto, &reason); err != nil {
return nil, err
}
out = append(out, fmt.Sprintf("%s | %s | %s | %s | %s | %s | %s | %s", ts, action, ip, domain, count, duration, proto, reason))
}
return out, rows.Err()
}
func (s *Store) WhitelistContains(ip string) (bool, error) {
var one int
err := s.DB.QueryRow(`SELECT 1 FROM whitelist_cache WHERE ip_address=? LIMIT 1`, ip).Scan(&one)
if err == sql.ErrNoRows {
return false, nil
}
return err == nil, err
}
func (s *Store) ReplaceWhitelist(ips []string, source string) error {
tx, err := s.DB.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(`DELETE FROM whitelist_cache WHERE source=? OR source IS NULL`, source); err != nil {
return err
}
stmt, err := tx.Prepare(`INSERT OR IGNORE INTO whitelist_cache (ip_address, source) VALUES (?, ?)`)
if err != nil {
return err
}
defer stmt.Close()
for _, ip := range ips {
if _, err := stmt.Exec(ip, source); err != nil {
return err
}
}
return tx.Commit()
}
func (s *Store) AllWhitelist() (map[string]bool, error) {
rows, err := s.DB.Query(`SELECT ip_address FROM whitelist_cache`)
if err != nil {
return nil, err
}
defer rows.Close()
out := map[string]bool{}
for rows.Next() {
var ip string
if err := rows.Scan(&ip); err != nil {
return nil, err
}
out[ip] = true
}
return out, rows.Err()
}
func (s *Store) IncrementOffense(ip string, resetAfter int64) (int, error) {
now := time.Now()
var level int
var last int64
var first string
err := s.DB.QueryRow(`SELECT offense_level, COALESCE(last_offense_epoch,0), COALESCE(first_offense,'') FROM offense_tracking WHERE client_ip=?`, ip).Scan(&level, &last, &first)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
if err == sql.ErrNoRows || (last > 0 && now.Unix()-last > resetAfter) {
level = 0
first = now.Format("2006-01-02 15:04:05")
}
level++
_, err = s.DB.Exec(`INSERT OR REPLACE INTO offense_tracking (client_ip, offense_level, last_offense_epoch, last_offense, first_offense, updated_at)
VALUES (?, ?, ?, ?, ?, ?)`, ip, level, now.Unix(), now.Format("2006-01-02 15:04:05"), first, now.Format("2006-01-02 15:04:05"))
return level, err
}
func (s *Store) ResetOffense(ip string) error {
if ip == "" {
_, err := s.DB.Exec(`DELETE FROM offense_tracking`)
return err
}
_, err := s.DB.Exec(`DELETE FROM offense_tracking WHERE client_ip=?`, ip)
return err
}
func (s *Store) CleanupOffenses(resetAfter int64) (int64, error) {
cutoff := time.Now().Unix() - resetAfter
res, err := s.DB.Exec(`DELETE FROM offense_tracking WHERE last_offense_epoch <= ?`, cutoff)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (s *Store) CountOffenses() (int, error) {
var count int
err := s.DB.QueryRow(`SELECT COUNT(*) FROM offense_tracking`).Scan(&count)
return count, err
}
func (s *Store) CountExpiredOffenses(resetAfter int64) (int, error) {
var count int
cutoff := time.Now().Unix() - resetAfter
err := s.DB.QueryRow(`SELECT COUNT(*) FROM offense_tracking WHERE last_offense_epoch <= ?`, cutoff).Scan(&count)
return count, err
}
func (s *Store) LoadGeoIPCache(ttl, dbMtime int64) (map[string]string, error) {
rows, err := s.DB.Query(`SELECT ip, country_code FROM geoip_cache WHERE looked_up_at_epoch >= ? AND (db_mtime=? OR db_mtime=0)`, time.Now().Unix()-ttl, dbMtime)
if err != nil {
return nil, err
}
defer rows.Close()
out := map[string]string{}
for rows.Next() {
var ip, cc string
if err := rows.Scan(&ip, &cc); err != nil {
return nil, err
}
out[ip] = cc
}
return out, rows.Err()
}
func (s *Store) UpsertGeoIP(ip, country string, dbMtime int64) error {
_, err := s.DB.Exec(`INSERT OR REPLACE INTO geoip_cache (ip, country_code, looked_up_at_epoch, db_mtime) VALUES (?, ?, ?, ?)`, ip, country, time.Now().Unix(), dbMtime)
return err
}
func (s *Store) ClearGeoIPCache() (int64, error) {
res, err := s.DB.Exec(`DELETE FROM geoip_cache`)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (s *Store) ReportStats(since, until int64, limit int) (ReportStats, error) {
st := ReportStats{Since: since, Until: until}
if err := s.DB.QueryRow(`SELECT COUNT(*) FROM ban_history WHERE action='BAN' AND timestamp_epoch BETWEEN ? AND ?`, since, until).Scan(&st.TotalBans); err != nil {
return st, err
}
if err := s.DB.QueryRow(`SELECT COUNT(*) FROM ban_history WHERE action='UNBAN' AND timestamp_epoch BETWEEN ? AND ?`, since, until).Scan(&st.TotalUnbans); err != nil {
return st, err
}
if err := s.DB.QueryRow(`SELECT COUNT(*) FROM active_bans`).Scan(&st.ActiveBans); err != nil {
return st, err
}
var err error
st.TopClients, err = s.reportCounts(`SELECT client_ip, COUNT(*) FROM ban_history WHERE action='BAN' AND timestamp_epoch BETWEEN ? AND ? GROUP BY client_ip ORDER BY COUNT(*) DESC, client_ip LIMIT ?`, since, until, limit)
if err != nil {
return st, err
}
st.Reasons, err = s.reportCounts(`SELECT COALESCE(NULLIF(reason,''), 'unknown'), COUNT(*) FROM ban_history WHERE action='BAN' AND timestamp_epoch BETWEEN ? AND ? GROUP BY COALESCE(NULLIF(reason,''), 'unknown') ORDER BY COUNT(*) DESC LIMIT ?`, since, until, limit)
if err != nil {
return st, err
}
st.Sources, err = s.reportCounts(`SELECT COALESCE(NULLIF(source,''), 'unknown'), COUNT(*) FROM active_bans GROUP BY COALESCE(NULLIF(source,''), 'unknown') ORDER BY COUNT(*) DESC LIMIT ?`, 0, 0, limit)
if err != nil {
return st, err
}
st.RecentEvents, err = s.RecentHistory(limit)
return st, err
}
func (s *Store) reportCounts(query string, since, until int64, limit int) ([]ReportCount, error) {
var rows *sql.Rows
var err error
if since == 0 && until == 0 {
rows, err = s.DB.Query(query, limit)
} else {
rows, err = s.DB.Query(query, since, until, limit)
}
if err != nil {
return nil, err
}
defer rows.Close()
var out []ReportCount
for rows.Next() {
var item ReportCount
if err := rows.Scan(&item.Name, &item.Count); err != nil {
return nil, err
}
out = append(out, item)
}
return out, rows.Err()
}

31
internal/db/db_test.go Normal file
View File

@@ -0,0 +1,31 @@
package db
import (
"path/filepath"
"testing"
)
func TestStoreBanAndGeoIPCache(t *testing.T) {
s, err := Open(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
defer s.Close()
if err := s.InsertBan(Ban{IP: "1.2.3.4", Domain: "example.com", Permanent: true, Reason: "geoip", Source: "geoip", GeoIPCountry: "CN"}); err != nil {
t.Fatal(err)
}
ok, err := s.BanExists("1.2.3.4")
if err != nil || !ok {
t.Fatalf("ban not found: %v %v", ok, err)
}
if err := s.UpsertGeoIP("1.2.3.4", "CN", 123); err != nil {
t.Fatal(err)
}
cache, err := s.LoadGeoIPCache(86400, 123)
if err != nil {
t.Fatal(err)
}
if cache["1.2.3.4"] != "CN" {
t.Fatalf("unexpected cache: %#v", cache)
}
}

View File

@@ -0,0 +1,203 @@
package firewall
import (
"context"
"fmt"
"net/netip"
"os/exec"
"strconv"
"strings"
)
type Executor interface {
Run(ctx context.Context, name string, args ...string) error
}
type OSExecutor struct{}
func (OSExecutor) Run(ctx context.Context, name string, args ...string) error {
return exec.CommandContext(ctx, name, args...).Run()
}
type Firewall struct {
Exec Executor
Chain string
Ports []string
Mode string
DryRun bool
Set4 string
Set6 string
}
func New(exec Executor, chain string, ports []string, mode string, dry bool) *Firewall {
return &Firewall{Exec: exec, Chain: chain, Ports: ports, Mode: normalizeMode(mode), DryRun: dry, Set4: "adguard_shield_v4", Set6: "adguard_shield_v6"}
}
func (f *Firewall) Setup(ctx context.Context) error {
if f.DryRun {
return nil
}
if len(f.hooks("iptables")) == 0 {
return fmt.Errorf("unsupported firewall mode %q", f.Mode)
}
_ = f.Exec.Run(ctx, "ipset", "create", f.Set4, "hash:net", "family", "inet", "timeout", "0", "-exist")
_ = f.Exec.Run(ctx, "ipset", "create", f.Set6, "hash:net", "family", "inet6", "timeout", "0", "-exist")
_ = f.Exec.Run(ctx, "iptables", "-N", f.Chain)
_ = f.Exec.Run(ctx, "ip6tables", "-N", f.Chain)
if err := ensureSetDrop(ctx, f.Exec, "iptables", f.Chain, f.Set4); err != nil {
return err
}
if err := ensureSetDrop(ctx, f.Exec, "ip6tables", f.Chain, f.Set6); err != nil {
return err
}
if err := f.ensureHooks(ctx, "iptables"); err != nil {
return err
}
if err := f.ensureHooks(ctx, "ip6tables"); err != nil {
return err
}
return nil
}
func ensureRule(ctx context.Context, ex Executor, bin string, args ...string) bool {
return ex.Run(ctx, bin, args...) == nil
}
func ensureSetDrop(ctx context.Context, ex Executor, bin, chain, set string) error {
check := []string{"-C", chain, "-m", "set", "--match-set", set, "src", "-j", "DROP"}
if ex.Run(ctx, bin, check...) == nil {
return nil
}
return ex.Run(ctx, bin, "-I", chain, "-m", "set", "--match-set", set, "src", "-j", "DROP")
}
type hook struct {
Chain string
OptionalMissing bool
}
func normalizeMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "", "host", "classic", "native", "docker-host":
return "host"
case "docker", "docker-bridge", "docker-published", "published":
return "docker-bridge"
case "hybrid", "both":
return "hybrid"
default:
return strings.ToLower(strings.TrimSpace(mode))
}
}
func (f *Firewall) hooks(bin string) []hook {
docker := hook{Chain: "DOCKER-USER", OptionalMissing: bin == "ip6tables"}
switch f.Mode {
case "host":
return []hook{{Chain: "INPUT"}}
case "docker-bridge":
return []hook{docker}
case "hybrid":
return []hook{{Chain: "INPUT"}, docker}
default:
return nil
}
}
func (f *Firewall) ensureHooks(ctx context.Context, bin string) error {
for _, h := range f.hooks(bin) {
if !chainExists(ctx, f.Exec, bin, h.Chain) {
if h.OptionalMissing {
continue
}
return fmt.Errorf("%s chain %s not found", bin, h.Chain)
}
for _, p := range f.Ports {
for _, proto := range []string{"tcp", "udp"} {
check := []string{"-C", h.Chain, "-p", proto, "--dport", p, "-j", f.Chain}
if ensureRule(ctx, f.Exec, bin, check...) {
continue
}
_ = f.Exec.Run(ctx, bin, "-I", h.Chain, "-p", proto, "--dport", p, "-j", f.Chain)
}
}
}
return nil
}
func chainExists(ctx context.Context, ex Executor, bin, chain string) bool {
return ex.Run(ctx, bin, "-n", "-L", chain) == nil
}
func (f *Firewall) Add(ctx context.Context, ip string, timeout int64) error {
if f.DryRun {
return nil
}
set, err := f.setFor(ip)
if err != nil {
return err
}
args := []string{"add", set, ip, "-exist"}
if timeout > 0 {
args = append(args, "timeout", strconv.FormatInt(timeout, 10))
}
return f.Exec.Run(ctx, "ipset", args...)
}
func (f *Firewall) Del(ctx context.Context, ip string) error {
if f.DryRun {
return nil
}
set, err := f.setFor(ip)
if err != nil {
return err
}
_ = f.Exec.Run(ctx, "ipset", "del", set, ip)
return nil
}
func (f *Firewall) Flush(ctx context.Context) error {
if f.DryRun {
return nil
}
_ = f.Exec.Run(ctx, "ipset", "flush", f.Set4)
_ = f.Exec.Run(ctx, "ipset", "flush", f.Set6)
return nil
}
func (f *Firewall) Remove(ctx context.Context) error {
if f.DryRun {
return nil
}
for _, p := range f.Ports {
for _, proto := range []string{"tcp", "udp"} {
for _, parent := range []string{"INPUT", "DOCKER-USER"} {
_ = f.Exec.Run(ctx, "iptables", "-D", parent, "-p", proto, "--dport", p, "-j", f.Chain)
_ = f.Exec.Run(ctx, "ip6tables", "-D", parent, "-p", proto, "--dport", p, "-j", f.Chain)
}
}
}
_ = f.Exec.Run(ctx, "iptables", "-F", f.Chain)
_ = f.Exec.Run(ctx, "ip6tables", "-F", f.Chain)
_ = f.Exec.Run(ctx, "iptables", "-X", f.Chain)
_ = f.Exec.Run(ctx, "ip6tables", "-X", f.Chain)
_ = f.Exec.Run(ctx, "ipset", "destroy", f.Set4)
_ = f.Exec.Run(ctx, "ipset", "destroy", f.Set6)
return nil
}
func (f *Firewall) setFor(s string) (string, error) {
if p, err := netip.ParsePrefix(s); err == nil {
if p.Addr().Is4() {
return f.Set4, nil
}
return f.Set6, nil
}
a, err := netip.ParseAddr(s)
if err != nil {
return "", fmt.Errorf("invalid IP/prefix %q", s)
}
if a.Is4() {
return f.Set4, nil
}
return f.Set6, nil
}

View File

@@ -0,0 +1,142 @@
package firewall
import (
"context"
"strings"
"testing"
)
type fakeExec struct {
calls []string
failChecks bool
missing map[string]bool
}
func (f *fakeExec) Run(_ context.Context, name string, args ...string) error {
call := name + " " + strings.Join(args, " ")
f.calls = append(f.calls, call)
if f.missing != nil && f.missing[call] {
return errFake
}
if f.failChecks && len(args) > 0 && args[0] == "-C" {
return errFake
}
return nil
}
type fakeErr string
func (e fakeErr) Error() string { return string(e) }
var errFake = fakeErr("missing")
func TestFirewallSetupCreatesSetsAndRules(t *testing.T) {
ex := &fakeExec{failChecks: true}
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "host", false)
if err := fw.Setup(context.Background()); err != nil {
t.Fatal(err)
}
joined := strings.Join(ex.calls, "\n")
for _, want := range []string{
"ipset create adguard_shield_v4 hash:net family inet timeout 0 -exist",
"iptables -I INPUT -p tcp --dport 53 -j ADGUARD_SHIELD",
"iptables -I ADGUARD_SHIELD -m set --match-set adguard_shield_v4 src -j DROP",
"ip6tables -I ADGUARD_SHIELD -m set --match-set adguard_shield_v6 src -j DROP",
} {
if !strings.Contains(joined, want) {
t.Fatalf("missing call %q in:\n%s", want, joined)
}
}
}
func TestFirewallSetupUsesDockerUserForBridgeMode(t *testing.T) {
ex := &fakeExec{failChecks: true}
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "docker-bridge", false)
if err := fw.Setup(context.Background()); err != nil {
t.Fatal(err)
}
joined := strings.Join(ex.calls, "\n")
if !strings.Contains(joined, "iptables -I DOCKER-USER -p udp --dport 53 -j ADGUARD_SHIELD") {
t.Fatalf("missing docker hook in:\n%s", joined)
}
if strings.Contains(joined, "iptables -I INPUT -p udp --dport 53 -j ADGUARD_SHIELD") {
t.Fatalf("unexpected INPUT hook in docker-bridge mode:\n%s", joined)
}
}
func TestFirewallSetupHybridUsesInputAndDockerUser(t *testing.T) {
ex := &fakeExec{failChecks: true}
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "hybrid", false)
if err := fw.Setup(context.Background()); err != nil {
t.Fatal(err)
}
joined := strings.Join(ex.calls, "\n")
for _, want := range []string{
"iptables -I INPUT -p tcp --dport 53 -j ADGUARD_SHIELD",
"iptables -I DOCKER-USER -p tcp --dport 53 -j ADGUARD_SHIELD",
} {
if !strings.Contains(joined, want) {
t.Fatalf("missing call %q in:\n%s", want, joined)
}
}
}
func TestFirewallSetupRequiresDockerUserForIPv4BridgeMode(t *testing.T) {
ex := &fakeExec{missing: map[string]bool{"iptables -n -L DOCKER-USER": true}}
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "docker-bridge", false)
if err := fw.Setup(context.Background()); err == nil || !strings.Contains(err.Error(), "DOCKER-USER") {
t.Fatalf("expected DOCKER-USER error, got %v", err)
}
}
func TestFirewallSetupSkipsMissingIPv6DockerUser(t *testing.T) {
ex := &fakeExec{
failChecks: true,
missing: map[string]bool{"ip6tables -n -L DOCKER-USER": true},
}
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "docker-bridge", false)
if err := fw.Setup(context.Background()); err != nil {
t.Fatal(err)
}
joined := strings.Join(ex.calls, "\n")
if strings.Contains(joined, "ip6tables -I DOCKER-USER") {
t.Fatalf("unexpected IPv6 docker hook with missing DOCKER-USER:\n%s", joined)
}
}
func TestFirewallSetupRejectsUnknownMode(t *testing.T) {
fw := New(&fakeExec{}, "ADGUARD_SHIELD", []string{"53"}, "surprise", false)
err := fw.Setup(context.Background())
if err == nil || !strings.Contains(err.Error(), "unsupported firewall mode") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestFirewallAddChoosesFamily(t *testing.T) {
ex := &fakeExec{}
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "host", false)
if err := fw.Add(context.Background(), "2001:db8::1", 30); err != nil {
t.Fatal(err)
}
got := strings.Join(ex.calls, "\n")
if !strings.Contains(got, "ipset add adguard_shield_v6 2001:db8::1 -exist timeout 30") {
t.Fatalf("unexpected calls:\n%s", got)
}
}
func TestFirewallRemoveDeletesAllKnownHooks(t *testing.T) {
ex := &fakeExec{}
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "host", false)
if err := fw.Remove(context.Background()); err != nil {
t.Fatal(err)
}
joined := strings.Join(ex.calls, "\n")
for _, want := range []string{
"iptables -D INPUT -p tcp --dport 53 -j ADGUARD_SHIELD",
"iptables -D DOCKER-USER -p tcp --dport 53 -j ADGUARD_SHIELD",
} {
if !strings.Contains(joined, want) {
t.Fatalf("missing cleanup call %q in:\n%s", want, joined)
}
}
}

245
internal/geoip/geoip.go Normal file
View File

@@ -0,0 +1,245 @@
package geoip
import (
"archive/tar"
"compress/gzip"
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/oschwald/maxminddb-golang"
)
type Store interface {
LoadGeoIPCache(ttl, dbMtime int64) (map[string]string, error)
UpsertGeoIP(ip, country string, dbMtime int64) error
}
type Resolver struct {
DBPath string
effectivePath string
LicenseKey string
Dir string
TTL int64
Store Store
reader *maxminddb.Reader
cache map[string]string
mtime int64
}
func New(dbPath, licenseKey, dir string, ttl int64, store Store) *Resolver {
return &Resolver{DBPath: dbPath, LicenseKey: licenseKey, Dir: dir, TTL: ttl, Store: store, cache: map[string]string{}}
}
func (r *Resolver) Open(ctx context.Context) error {
path := r.DBPath
if path == "" && r.LicenseKey != "" {
var err error
path, err = r.ensureAutoDB(ctx)
if err != nil {
return err
}
}
if path == "" {
return nil
}
r.effectivePath = path
st, err := os.Stat(path)
if err != nil {
return err
}
reader, err := maxminddb.Open(path)
if err != nil {
return err
}
r.reader = reader
r.mtime = st.ModTime().Unix()
if r.Store != nil {
if c, err := r.Store.LoadGeoIPCache(r.TTL, r.mtime); err == nil {
r.cache = c
}
}
return nil
}
func (r *Resolver) Close() error {
if r.reader != nil {
return r.reader.Close()
}
return nil
}
func (r *Resolver) Lookup(ip string) (string, error) {
if v, ok := r.cache[ip]; ok {
return v, nil
}
if r.reader == nil {
return r.lookupLegacy(ip)
}
parsed := net.ParseIP(ip)
if parsed == nil {
return "", fmt.Errorf("invalid IP %q", ip)
}
var rec struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
RegisteredCountry struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"registered_country"`
}
if err := r.reader.Lookup(parsed, &rec); err != nil {
return "", err
}
cc := strings.ToUpper(rec.Country.ISOCode)
if cc == "" {
cc = strings.ToUpper(rec.RegisteredCountry.ISOCode)
}
if cc != "" {
r.cache[ip] = cc
if r.Store != nil {
_ = r.Store.UpsertGeoIP(ip, cc, r.mtime)
}
}
return cc, nil
}
func (r *Resolver) lookupLegacy(ip string) (string, error) {
if strings.Contains(ip, ":") {
if cc, err := runGeoIPCommand("geoiplookup6", ip); err == nil && cc != "" {
return cc, nil
}
} else {
if cc, err := runGeoIPCommand("geoiplookup", ip); err == nil && cc != "" {
return cc, nil
}
}
if r.effectivePath != "" {
if cc, err := runGeoIPCommand("mmdblookup", "--file", r.effectivePath, "--ip", ip, "country", "iso_code"); err == nil && cc != "" {
return cc, nil
}
}
return "", fmt.Errorf("no GeoIP result for %s", ip)
}
func runGeoIPCommand(name string, args ...string) (string, error) {
if _, err := exec.LookPath(name); err != nil {
return "", err
}
out, err := exec.Command(name, args...).CombinedOutput()
if err != nil {
return "", err
}
re := regexp.MustCompile(`\b[A-Z]{2}\b`)
matches := re.FindAllString(string(out), -1)
for _, m := range matches {
if m != "IP" {
return strings.ToUpper(m), nil
}
}
return "", nil
}
func ShouldBlock(country, mode string, countries []string) bool {
if country == "" || len(countries) == 0 {
return false
}
found := false
country = strings.ToUpper(country)
for _, c := range countries {
if strings.ToUpper(strings.TrimSpace(c)) == country {
found = true
break
}
}
if strings.ToLower(mode) == "allowlist" {
return !found
}
return found
}
func IsPrivateIP(s string) bool {
if p, err := netip.ParsePrefix(s); err == nil {
return isPrivateAddr(p.Addr())
}
a, err := netip.ParseAddr(s)
if err != nil {
return false
}
return isPrivateAddr(a)
}
func isPrivateAddr(a netip.Addr) bool {
return a.IsPrivate() || a.IsLoopback() || a.IsLinkLocalUnicast() || a.IsUnspecified() ||
(a.Is4() && strings.HasPrefix(a.String(), "100.") && isCGNAT(a))
}
func isCGNAT(a netip.Addr) bool {
p := a.As4()
return p[0] == 100 && p[1] >= 64 && p[1] <= 127
}
func (r *Resolver) ensureAutoDB(ctx context.Context) (string, error) {
if err := os.MkdirAll(r.Dir, 0755); err != nil {
return "", err
}
dst := filepath.Join(r.Dir, "GeoLite2-Country.mmdb")
if st, err := os.Stat(dst); err == nil && time.Since(st.ModTime()) < 24*time.Hour {
return dst, nil
}
url := "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key=" + r.LicenseKey + "&suffix=tar.gz"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("MaxMind download failed: HTTP %d", resp.StatusCode)
}
gzr, err := gzip.NewReader(resp.Body)
if err != nil {
return "", err
}
defer gzr.Close()
tr := tar.NewReader(gzr)
tmp := dst + ".tmp"
for {
h, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return "", err
}
if h.FileInfo().IsDir() || filepath.Base(h.Name) != "GeoLite2-Country.mmdb" {
continue
}
f, err := os.Create(tmp)
if err != nil {
return "", err
}
_, copyErr := io.Copy(f, tr)
closeErr := f.Close()
if copyErr != nil {
return "", copyErr
}
if closeErr != nil {
return "", closeErr
}
return dst, os.Rename(tmp, dst)
}
return "", fmt.Errorf("GeoLite2-Country.mmdb not found in archive")
}

View File

@@ -0,0 +1,30 @@
package geoip
import "testing"
func TestShouldBlockModes(t *testing.T) {
countries := []string{"CN", "RU"}
if !ShouldBlock("cn", "blocklist", countries) {
t.Fatal("blocklist should block listed country")
}
if ShouldBlock("DE", "blocklist", countries) {
t.Fatal("blocklist should allow unlisted country")
}
if ShouldBlock("CN", "allowlist", countries) {
t.Fatal("allowlist should allow listed country")
}
if !ShouldBlock("DE", "allowlist", countries) {
t.Fatal("allowlist should block unlisted country")
}
}
func TestIsPrivateIP(t *testing.T) {
for _, ip := range []string{"127.0.0.1", "192.168.1.10", "10.1.2.3", "100.64.0.1", "::1", "fd00::1"} {
if !IsPrivateIP(ip) {
t.Fatalf("%s should be private", ip)
}
}
if IsPrivateIP("8.8.8.8") {
t.Fatal("8.8.8.8 should be public")
}
}

View File

@@ -0,0 +1,642 @@
package installer
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"sort"
"strings"
)
const (
DefaultInstallDir = "/opt/adguard-shield"
DefaultStateDir = "/var/lib/adguard-shield"
DefaultLogFile = "/var/log/adguard-shield.log"
ServiceName = "adguard-shield.service"
ServicePath = "/etc/systemd/system/adguard-shield.service"
)
type Options struct {
InstallDir string
ConfigSource string
Enable bool
SkipDeps bool
KeepConfig bool
}
type Status struct {
InstallDir string
BinaryPath string
ConfigPath string
BinaryExists bool
ConfigExists bool
ServiceExists bool
ServiceEnabled bool
ServiceActive bool
Version string
LegacyFindings []string
}
type LegacyError struct {
Findings []string
}
func (e *LegacyError) Error() string {
return "scriptbasierte AdGuard-Shield-Installation gefunden"
}
func DefaultOptions() Options {
return Options{InstallDir: DefaultInstallDir, Enable: true}
}
func Install(opts Options) error {
opts = normalize(opts)
fmt.Println("AdGuard Shield Go-Installation")
fmt.Printf("Installationspfad: %s\n", opts.InstallDir)
fmt.Println("1/8 Pruefe Betriebssystem und root-Rechte ...")
if err := requireLinuxRoot(); err != nil {
return err
}
fmt.Println("2/8 Pruefe auf scriptbasierte Altinstallation ...")
if findings := DetectLegacy(opts.InstallDir); len(findings) > 0 {
return &LegacyError{Findings: findings}
}
if !opts.SkipDeps {
fmt.Println("3/8 Pruefe System-Abhaengigkeiten ...")
if err := ensureDependencies(); err != nil {
return err
}
} else {
fmt.Println("3/8 System-Abhaengigkeiten uebersprungen (--skip-deps)")
}
fmt.Println("4/8 Erstelle Verzeichnisse ...")
if err := os.MkdirAll(opts.InstallDir, 0755); err != nil {
return err
}
if err := os.MkdirAll(DefaultStateDir, 0755); err != nil {
return err
}
if err := os.MkdirAll(filepath.Join(opts.InstallDir, "geoip"), 0755); err != nil {
return err
}
fmt.Println("5/8 Installiere Binary ...")
if err := copySelf(filepath.Join(opts.InstallDir, "adguard-shield")); err != nil {
return err
}
fmt.Println("6/8 Installiere oder migriere Konfiguration ...")
if err := ensureConfig(opts); err != nil {
return err
}
fmt.Println("7/8 Schreibe systemd-Service ...")
if err := writeService(opts.InstallDir); err != nil {
return err
}
fmt.Println("8/8 Aktualisiere systemd ...")
_ = run("systemctl", "daemon-reload")
if opts.Enable {
fmt.Println("Aktiviere Autostart ...")
if err := run("systemctl", "enable", ServiceName); err != nil {
return err
}
}
if askStartService() {
fmt.Println("Starte Service neu ...")
if err := run("systemctl", "restart", ServiceName); err != nil {
return err
}
}
fmt.Println("Installation fertig.")
return nil
}
func Update(opts Options) error {
opts = normalize(opts)
if err := requireLinuxRoot(); err != nil {
return err
}
if findings := DetectLegacy(opts.InstallDir); len(findings) > 0 {
return &LegacyError{Findings: findings}
}
return Install(opts)
}
func Uninstall(opts Options) error {
opts = normalize(opts)
if err := requireLinuxRoot(); err != nil {
return err
}
_ = run("systemctl", "stop", ServiceName)
_ = run("systemctl", "disable", ServiceName)
if _, err := os.Stat(filepath.Join(opts.InstallDir, "adguard-shield")); err == nil {
_ = run(filepath.Join(opts.InstallDir, "adguard-shield"), "-config", filepath.Join(opts.InstallDir, "adguard-shield.conf"), "firewall-remove")
}
_ = os.Remove(ServicePath)
_ = run("systemctl", "daemon-reload")
if opts.KeepConfig {
for _, p := range []string{
filepath.Join(opts.InstallDir, "adguard-shield"),
filepath.Join(opts.InstallDir, "adguard-shield.conf.old"),
} {
_ = os.Remove(p)
}
return nil
}
_ = os.RemoveAll(opts.InstallDir)
_ = os.RemoveAll(DefaultStateDir)
_ = os.Remove(DefaultLogFile)
return nil
}
func GetStatus(installDir string) Status {
if installDir == "" {
installDir = DefaultInstallDir
}
bin := filepath.Join(installDir, "adguard-shield")
conf := filepath.Join(installDir, "adguard-shield.conf")
st := Status{
InstallDir: installDir,
BinaryPath: bin,
ConfigPath: conf,
BinaryExists: fileExists(bin),
ConfigExists: fileExists(conf),
ServiceExists: fileExists(ServicePath),
LegacyFindings: DetectLegacy(installDir),
}
if st.BinaryExists {
if out, err := exec.Command(bin, "version").Output(); err == nil {
st.Version = strings.TrimSpace(string(out))
}
}
st.ServiceEnabled = commandOK("systemctl", "is-enabled", "adguard-shield")
st.ServiceActive = commandOK("systemctl", "is-active", "adguard-shield")
return st
}
func DetectLegacy(installDir string) []string {
if installDir == "" {
installDir = DefaultInstallDir
}
var findings []string
for _, p := range []string{
"adguard-shield.sh",
"iptables-helper.sh",
"db.sh",
"external-blocklist-worker.sh",
"external-whitelist-worker.sh",
"geoip-worker.sh",
"offense-cleanup-worker.sh",
"report-generator.sh",
"unban-expired.sh",
"adguard-shield-watchdog.sh",
} {
full := filepath.Join(installDir, p)
if fileExists(full) {
findings = append(findings, full)
}
}
for _, p := range []string{
"/etc/systemd/system/adguard-shield-watchdog.service",
"/etc/systemd/system/adguard-shield-watchdog.timer",
} {
if fileExists(p) {
findings = append(findings, p)
}
}
if b, err := os.ReadFile(ServicePath); err == nil {
s := string(b)
if strings.Contains(s, ".sh") || strings.Contains(s, "/bin/bash") || strings.Contains(s, "adguard-shield-watchdog") {
findings = append(findings, ServicePath+" verweist auf Shell/Watchdog")
}
}
sort.Strings(findings)
return findings
}
func FormatLegacyMessage(err *LegacyError, installDir string) string {
if installDir == "" {
installDir = DefaultInstallDir
}
var b strings.Builder
b.WriteString("Die scriptbasierte Installation ist noch vorhanden und muss zuerst deinstalliert werden.\n\n")
b.WriteString("Gefunden:\n")
for _, f := range err.Findings {
b.WriteString(" - ")
b.WriteString(f)
b.WriteByte('\n')
}
b.WriteString("\nKonfiguration uebernehmen:\n")
b.WriteString(" 1. Backup behalten: ")
b.WriteString(filepath.Join(installDir, "adguard-shield.conf"))
b.WriteByte('\n')
b.WriteString(" 2. Alte Shell-Version mit deren uninstall.sh entfernen und die Konfiguration behalten.\n")
b.WriteString(" 3. Danach dieses Binary erneut ausfuehren: adguard-shield install\n")
return b.String()
}
func PrintStatus(st Status) string {
var b strings.Builder
b.WriteString("AdGuard Shield Installationsstatus\n")
b.WriteString(fmt.Sprintf("Installationspfad: %s\n", st.InstallDir))
b.WriteString(fmt.Sprintf("Binary: %s\n", yesNo(st.BinaryExists)))
if st.Version != "" {
b.WriteString(fmt.Sprintf("Version: %s\n", st.Version))
}
b.WriteString(fmt.Sprintf("Konfiguration: %s\n", yesNo(st.ConfigExists)))
b.WriteString(fmt.Sprintf("systemd Service: %s\n", yesNo(st.ServiceExists)))
b.WriteString(fmt.Sprintf("Autostart: %s\n", yesNo(st.ServiceEnabled)))
b.WriteString(fmt.Sprintf("Service aktiv: %s\n", yesNo(st.ServiceActive)))
if len(st.LegacyFindings) > 0 {
b.WriteString("\nScriptbasierte Altinstallation/Altartefakte gefunden:\n")
for _, f := range st.LegacyFindings {
b.WriteString(" - ")
b.WriteString(f)
b.WriteByte('\n')
}
}
return b.String()
}
func normalize(opts Options) Options {
if opts.InstallDir == "" {
opts.InstallDir = DefaultInstallDir
}
return opts
}
func askStartService() bool {
fmt.Print("AdGuard Shield jetzt (neu) starten? [j/N] ")
line, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil && len(line) == 0 {
fmt.Println("Keine Eingabe gelesen, Service wird nicht gestartet.")
return false
}
switch strings.ToLower(strings.TrimSpace(line)) {
case "j", "ja", "y", "yes":
return true
default:
fmt.Println("Service wird nicht gestartet.")
return false
}
}
func requireLinuxRoot() error {
if runtime.GOOS != "linux" {
return fmt.Errorf("Installation ist nur auf Linux-Servern unterstuetzt")
}
if os.Geteuid() != 0 {
return fmt.Errorf("Installation muss als root ausgefuehrt werden")
}
return nil
}
func ensureDependencies() error {
missing := missingCommands("iptables", "ip6tables", "ipset", "systemctl")
if len(missing) == 0 {
fmt.Println(" Alle benoetigten Befehle sind vorhanden.")
return nil
}
fmt.Printf(" Fehlende Befehle: %s\n", strings.Join(missing, ", "))
if _, err := exec.LookPath("apt-get"); err != nil {
return fmt.Errorf("fehlende Abhaengigkeiten (%s), apt-get nicht gefunden", strings.Join(missing, ", "))
}
pkgs := map[string]bool{"iptables": false, "ipset": false, "systemd": false, "ca-certificates": false}
for _, m := range missing {
switch m {
case "iptables", "ip6tables":
pkgs["iptables"] = true
case "ipset":
pkgs["ipset"] = true
case "systemctl":
pkgs["systemd"] = true
}
}
var install []string
for p, needed := range pkgs {
if needed || p == "ca-certificates" {
install = append(install, p)
}
}
sort.Strings(install)
fmt.Printf(" Installiere Pakete via apt-get: %s\n", strings.Join(install, ", "))
fmt.Println(" apt-get update ...")
if err := runStreaming("apt-get", "update"); err != nil {
return err
}
fmt.Println(" apt-get install ...")
args := append([]string{"install", "-y", "-qq"}, install...)
return runStreaming("apt-get", args...)
}
func missingCommands(names ...string) []string {
var missing []string
for _, name := range names {
if _, err := exec.LookPath(name); err != nil {
missing = append(missing, name)
}
}
return missing
}
func copySelf(dst string) error {
src, err := os.Executable()
if err != nil {
return err
}
if sameFile(src, dst) {
return os.Chmod(dst, 0755)
}
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
tmp := dst + ".tmp"
out, err := os.OpenFile(tmp, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0755)
if err != nil {
return err
}
if _, err := io.Copy(out, in); err != nil {
_ = out.Close()
return err
}
if err := out.Close(); err != nil {
return err
}
if err := os.Chmod(tmp, 0755); err != nil {
return err
}
return os.Rename(tmp, dst)
}
func sameFile(a, b string) bool {
aa, errA := filepath.Abs(a)
bb, errB := filepath.Abs(b)
if errA == nil && errB == nil && aa == bb {
return true
}
ai, errA := os.Stat(a)
bi, errB := os.Stat(b)
return errA == nil && errB == nil && os.SameFile(ai, bi)
}
func ensureConfig(opts Options) error {
target := filepath.Join(opts.InstallDir, "adguard-shield.conf")
defaults := []byte(defaultConfig)
if opts.ConfigSource != "" {
b, err := os.ReadFile(opts.ConfigSource)
if err != nil {
return err
}
defaults = b
}
if !fileExists(target) {
if err := os.WriteFile(target, defaults, 0600); err != nil {
return err
}
return nil
}
current, err := os.ReadFile(target)
if err != nil {
return err
}
merged, changed := mergeConfig(current, []byte(defaultConfig))
if !changed {
return os.Chmod(target, 0600)
}
if err := os.WriteFile(target+".old", current, 0600); err != nil {
return err
}
if err := os.WriteFile(target, merged, 0600); err != nil {
return err
}
return nil
}
func mergeConfig(current, defaults []byte) ([]byte, bool) {
existing := configKeys(current)
var add [][]byte
for _, block := range configBlocks(defaults) {
key := blockKey(block)
if key == "" || existing[key] {
continue
}
add = append(add, block)
}
if len(add) == 0 {
return current, false
}
out := bytes.TrimRight(current, "\r\n")
out = append(out, '\n', '\n')
out = append(out, []byte("# Neue Parameter aus der Go-Version\n")...)
for _, block := range add {
out = append(out, bytes.Trim(block, "\r\n")...)
out = append(out, '\n')
}
return out, true
}
func configKeys(data []byte) map[string]bool {
keys := map[string]bool{}
for _, line := range bytes.Split(data, []byte{'\n'}) {
line = bytes.TrimSpace(line)
if len(line) == 0 || line[0] == '#' {
continue
}
if i := bytes.IndexByte(line, '='); i > 0 {
keys[string(bytes.TrimSpace(line[:i]))] = true
}
}
return keys
}
func configBlocks(data []byte) [][]byte {
lines := bytes.Split(data, []byte{'\n'})
var blocks [][]byte
var comments [][]byte
for _, line := range lines {
trim := bytes.TrimSpace(line)
if len(trim) == 0 || trim[0] == '#' {
comments = append(comments, append([]byte(nil), line...))
continue
}
block := bytes.Join(append(comments, line), []byte{'\n'})
blocks = append(blocks, block)
comments = nil
}
return blocks
}
func blockKey(block []byte) string {
for _, line := range bytes.Split(block, []byte{'\n'}) {
line = bytes.TrimSpace(line)
if len(line) == 0 || line[0] == '#' {
continue
}
if i := bytes.IndexByte(line, '='); i > 0 {
return string(bytes.TrimSpace(line[:i]))
}
}
return ""
}
func writeService(installDir string) error {
service := fmt.Sprintf(`[Unit]
Description=AdGuard Shield - Go DNS Rate-Limit Monitor
After=network.target AdGuardHome.service
Wants=AdGuardHome.service
StartLimitBurst=5
StartLimitIntervalSec=300
[Service]
Type=simple
ExecStart=%s/adguard-shield -config %s/adguard-shield.conf run
ExecReload=/bin/kill -HUP $MAINPID
Restart=on-failure
RestartSec=30
ProtectSystem=full
ReadWritePaths=/var/log /var/lib/adguard-shield /var/run %s/geoip
ProtectHome=true
NoNewPrivileges=false
PrivateTmp=true
AmbientCapabilities=CAP_NET_ADMIN CAP_NET_RAW
CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW CAP_DAC_OVERRIDE CAP_DAC_READ_SEARCH CAP_FOWNER CAP_KILL CAP_SETUID CAP_SETGID CAP_CHOWN
StandardOutput=journal
StandardError=journal
SyslogIdentifier=adguard-shield
[Install]
WantedBy=multi-user.target
`, installDir, installDir, installDir)
return os.WriteFile(ServicePath, []byte(service), 0644)
}
func run(name string, args ...string) error {
cmd := exec.Command(name, args...)
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("%s %s: %w\n%s", name, strings.Join(args, " "), err, strings.TrimSpace(string(out)))
}
return nil
}
func runStreaming(name string, args ...string) error {
cmd := exec.Command(name, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("%s %s: %w", name, strings.Join(args, " "), err)
}
return nil
}
func commandOK(name string, args ...string) bool {
return exec.Command(name, args...).Run() == nil
}
func fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}
func yesNo(ok bool) string {
if ok {
return "ja"
}
return "nein"
}
func IsLegacyError(err error) (*LegacyError, bool) {
var le *LegacyError
if errors.As(err, &le) {
return le, true
}
return nil, false
}
const defaultConfig = `# AdGuard Shield Konfiguration
ADGUARD_URL="https://dns1.domain.com"
ADGUARD_USER="admin"
ADGUARD_PASS='changeme'
RATE_LIMIT_MAX_REQUESTS=30
RATE_LIMIT_WINDOW=60
CHECK_INTERVAL=10
API_QUERY_LIMIT=500
SUBDOMAIN_FLOOD_ENABLED=true
SUBDOMAIN_FLOOD_MAX_UNIQUE=50
SUBDOMAIN_FLOOD_WINDOW=60
DNS_FLOOD_WATCHLIST_ENABLED=false
DNS_FLOOD_WATCHLIST=""
BAN_DURATION=3600
IPTABLES_CHAIN="ADGUARD_SHIELD"
BLOCKED_PORTS="53 443 853"
FIREWALL_BACKEND="ipset"
FIREWALL_MODE="host"
DRY_RUN=false
WHITELIST="127.0.0.1,::1"
LOG_FILE="/var/log/adguard-shield.log"
LOG_LEVEL="INFO"
STATE_DIR="/var/lib/adguard-shield"
PID_FILE="/var/run/adguard-shield.pid"
NOTIFY_ENABLED=false
NOTIFY_TYPE="ntfy"
NOTIFY_WEBHOOK_URL=""
NTFY_SERVER_URL="https://ntfy.sh"
NTFY_TOPIC=""
NTFY_TOKEN=""
NTFY_PRIORITY="4"
REPORT_ENABLED=false
REPORT_INTERVAL="weekly"
REPORT_TIME="08:00"
REPORT_EMAIL_TO="admin@example.com"
REPORT_EMAIL_FROM="adguard-shield@example.com"
REPORT_FORMAT="html"
REPORT_MAIL_CMD="msmtp"
REPORT_BUSIEST_DAY_RANGE=30
EXTERNAL_WHITELIST_ENABLED=false
EXTERNAL_WHITELIST_URLS=""
EXTERNAL_WHITELIST_INTERVAL=300
EXTERNAL_WHITELIST_CACHE_DIR="/var/lib/adguard-shield/external-whitelist"
EXTERNAL_BLOCKLIST_ENABLED=false
EXTERNAL_BLOCKLIST_URLS=""
EXTERNAL_BLOCKLIST_INTERVAL=300
EXTERNAL_BLOCKLIST_BAN_DURATION=0
EXTERNAL_BLOCKLIST_AUTO_UNBAN=true
EXTERNAL_BLOCKLIST_NOTIFY=false
EXTERNAL_BLOCKLIST_CACHE_DIR="/var/lib/adguard-shield/external-blocklist"
PROGRESSIVE_BAN_ENABLED=true
PROGRESSIVE_BAN_MULTIPLIER=2
PROGRESSIVE_BAN_MAX_LEVEL=5
PROGRESSIVE_BAN_RESET_AFTER=86400
ABUSEIPDB_ENABLED=false
ABUSEIPDB_API_KEY=""
ABUSEIPDB_CATEGORIES="4"
GEOIP_ENABLED=false
GEOIP_MODE="blocklist"
GEOIP_COUNTRIES=""
GEOIP_CHECK_INTERVAL=0
GEOIP_NOTIFY=true
GEOIP_SKIP_PRIVATE=true
GEOIP_LICENSE_KEY=""
GEOIP_MMDB_PATH=""
GEOIP_CACHE_TTL=86400
`

242
internal/report/report.go Normal file
View File

@@ -0,0 +1,242 @@
package report
import (
"bytes"
"context"
"fmt"
"html"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"time"
"adguard-shield/internal/config"
"adguard-shield/internal/db"
)
type Store interface {
ReportStats(since, until int64, limit int) (db.ReportStats, error)
}
const cronPath = "/etc/cron.d/adguard-shield-report"
func Status(c *config.Config) string {
cron := "nicht installiert"
if _, err := os.Stat(cronPath); err == nil {
cron = "installiert (" + cronPath + ")"
}
return fmt.Sprintf(`E-Mail Report
Aktiv: %v
Intervall: %s
Zeit: %s
Empfaenger: %s
Absender: %s
Format: %s
Mail-Befehl: %s
Cron: %s
`, c.ReportEnabled, c.ReportInterval, c.ReportTime, c.ReportEmailTo, c.ReportEmailFrom, c.ReportFormat, c.ReportMailCmd, cron)
}
func Generate(c *config.Config, st Store, format string) (string, error) {
if format == "" {
format = c.ReportFormat
}
since, until := window(c.ReportInterval)
stats, err := st.ReportStats(since, until, 20)
if err != nil {
return "", err
}
if strings.EqualFold(format, "html") {
return renderHTML(c, stats), nil
}
return renderText(c, stats), nil
}
func Send(ctx context.Context, c *config.Config, st Store) error {
body, err := Generate(c, st, c.ReportFormat)
if err != nil {
return err
}
return sendMail(ctx, c, "AdGuard Shield Report", body)
}
func SendTest(ctx context.Context, c *config.Config) error {
body := fmt.Sprintf("AdGuard Shield Test-Mail\n\nHostname: %s\nZeitpunkt: %s\nEmpfaenger: %s\nAbsender: %s\n", hostname(), time.Now().Format("2006-01-02 15:04:05"), c.ReportEmailTo, c.ReportEmailFrom)
if strings.EqualFold(c.ReportFormat, "html") {
body = "<!doctype html><html><body><h1>AdGuard Shield Test-Mail</h1><p>Hostname: " + html.EscapeString(hostname()) + "</p><p>Zeitpunkt: " + html.EscapeString(time.Now().Format("2006-01-02 15:04:05")) + "</p></body></html>"
}
return sendMail(ctx, c, "AdGuard Shield Test-Mail", body)
}
func InstallCron(binary, configPath string, c *config.Config) error {
minute, hour, err := parseReportTime(c.ReportTime)
if err != nil {
return err
}
schedule := cronSchedule(c.ReportInterval, minute, hour)
if binary == "" {
binary = "/opt/adguard-shield/adguard-shield"
}
if configPath == "" {
configPath = "/opt/adguard-shield/adguard-shield.conf"
}
line := fmt.Sprintf("SHELL=/bin/sh\nPATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin\n%s root %s -config %s report-send\n", schedule, binary, configPath)
return os.WriteFile(cronPath, []byte(line), 0644)
}
func RemoveCron() error {
if err := os.Remove(cronPath); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
func sendMail(ctx context.Context, c *config.Config, subject, body string) error {
if c.ReportEmailTo == "" {
return fmt.Errorf("REPORT_EMAIL_TO ist leer")
}
if c.ReportMailCmd == "" {
return fmt.Errorf("REPORT_MAIL_CMD ist leer")
}
contentType := "text/plain; charset=utf-8"
if strings.EqualFold(c.ReportFormat, "html") {
contentType = "text/html; charset=utf-8"
}
msg := "From: " + c.ReportEmailFrom + "\n" +
"To: " + c.ReportEmailTo + "\n" +
"Subject: " + subject + "\n" +
"Content-Type: " + contentType + "\n\n" + body
parts := strings.Fields(c.ReportMailCmd)
if len(parts) == 0 {
return fmt.Errorf("REPORT_MAIL_CMD ist leer")
}
args := append(parts[1:], "-t")
cmd := exec.CommandContext(ctx, parts[0], args...)
cmd.Stdin = strings.NewReader(msg)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func parseReportTime(value string) (string, string, error) {
parts := strings.Split(value, ":")
if len(parts) != 2 {
return "", "", fmt.Errorf("REPORT_TIME muss HH:MM sein")
}
hour, err := strconv.Atoi(parts[0])
if err != nil || hour < 0 || hour > 23 {
return "", "", fmt.Errorf("REPORT_TIME hat ungueltige Stunde")
}
minute, err := strconv.Atoi(parts[1])
if err != nil || minute < 0 || minute > 59 {
return "", "", fmt.Errorf("REPORT_TIME hat ungueltige Minute")
}
return strconv.Itoa(minute), strconv.Itoa(hour), nil
}
func cronSchedule(interval, minute, hour string) string {
switch strings.ToLower(interval) {
case "daily":
return fmt.Sprintf("%s %s * * *", minute, hour)
case "biweekly":
return fmt.Sprintf("%s %s 1,15 * *", minute, hour)
case "monthly":
return fmt.Sprintf("%s %s 1 * *", minute, hour)
default:
return fmt.Sprintf("%s %s * * 1", minute, hour)
}
}
func window(interval string) (int64, int64) {
now := time.Now()
days := 7
switch strings.ToLower(interval) {
case "daily":
days = 1
case "biweekly":
days = 14
case "monthly":
days = 30
}
return now.AddDate(0, 0, -days).Unix(), now.Unix()
}
func renderText(c *config.Config, st db.ReportStats) string {
var b strings.Builder
b.WriteString("AdGuard Shield Report\n")
b.WriteString("Zeitraum: " + formatTime(st.Since) + " bis " + formatTime(st.Until) + "\n\n")
b.WriteString("Bans: " + strconv.Itoa(st.TotalBans) + "\n")
b.WriteString("Unbans: " + strconv.Itoa(st.TotalUnbans) + "\n")
b.WriteString("Aktive Sperren: " + strconv.Itoa(st.ActiveBans) + "\n\n")
writeCountsText(&b, "Top Clients", st.TopClients)
writeCountsText(&b, "Gruende", st.Reasons)
writeCountsText(&b, "Aktive Quellen", st.Sources)
if len(st.RecentEvents) > 0 {
b.WriteString("Letzte Ereignisse\n")
for _, e := range st.RecentEvents {
b.WriteString("- " + e + "\n")
}
}
_ = c
return b.String()
}
func renderHTML(c *config.Config, st db.ReportStats) string {
var b bytes.Buffer
b.WriteString("<!doctype html><html><head><meta charset=\"utf-8\"><title>AdGuard Shield Report</title>")
b.WriteString("<style>body{font-family:Arial,sans-serif;color:#1f2937}table{border-collapse:collapse;margin:12px 0}td,th{border:1px solid #d1d5db;padding:6px 9px;text-align:left}th{background:#f3f4f6}</style>")
b.WriteString("</head><body>")
b.WriteString("<h1>AdGuard Shield Report</h1>")
b.WriteString("<p>Zeitraum: " + html.EscapeString(formatTime(st.Since)) + " bis " + html.EscapeString(formatTime(st.Until)) + "</p>")
b.WriteString("<ul><li>Bans: " + strconv.Itoa(st.TotalBans) + "</li><li>Unbans: " + strconv.Itoa(st.TotalUnbans) + "</li><li>Aktive Sperren: " + strconv.Itoa(st.ActiveBans) + "</li></ul>")
writeCountsHTML(&b, "Top Clients", st.TopClients)
writeCountsHTML(&b, "Gruende", st.Reasons)
writeCountsHTML(&b, "Aktive Quellen", st.Sources)
if len(st.RecentEvents) > 0 {
b.WriteString("<h2>Letzte Ereignisse</h2><table><tr><th>Ereignis</th></tr>")
for _, e := range st.RecentEvents {
b.WriteString("<tr><td>" + html.EscapeString(e) + "</td></tr>")
}
b.WriteString("</table>")
}
b.WriteString("</body></html>")
_ = c
return b.String()
}
func writeCountsText(b *strings.Builder, title string, rows []db.ReportCount) {
b.WriteString(title + "\n")
if len(rows) == 0 {
b.WriteString("- keine Daten\n\n")
return
}
for _, r := range rows {
b.WriteString("- " + r.Name + ": " + strconv.Itoa(r.Count) + "\n")
}
b.WriteByte('\n')
}
func writeCountsHTML(b *bytes.Buffer, title string, rows []db.ReportCount) {
b.WriteString("<h2>" + html.EscapeString(title) + "</h2><table><tr><th>Name</th><th>Anzahl</th></tr>")
if len(rows) == 0 {
b.WriteString("<tr><td colspan=\"2\">keine Daten</td></tr>")
}
for _, r := range rows {
b.WriteString("<tr><td>" + html.EscapeString(r.Name) + "</td><td>" + strconv.Itoa(r.Count) + "</td></tr>")
}
b.WriteString("</table>")
}
func formatTime(epoch int64) string {
return time.Unix(epoch, 0).Format("2006-01-02 15:04:05")
}
func hostname() string {
name, err := os.Hostname()
if err != nil || name == "" {
return filepath.Base(os.Args[0])
}
return name
}

82
internal/syslog/syslog.go Normal file
View File

@@ -0,0 +1,82 @@
package syslog
import (
"fmt"
"io"
"log"
"strings"
"sync"
)
type Level int
const (
Debug Level = iota
Info
Warn
Error
)
type Logger struct {
mu sync.Mutex
min Level
log *log.Logger
}
func New(w io.Writer, min string) *Logger {
return &Logger{
min: ParseLevel(min, Info),
log: log.New(w, "", log.LstdFlags),
}
}
func ParseLevel(s string, fallback Level) Level {
switch strings.ToUpper(strings.TrimSpace(s)) {
case "DEBUG":
return Debug
case "INFO", "":
return Info
case "WARN", "WARNING":
return Warn
case "ERROR", "ERR":
return Error
default:
return fallback
}
}
func LevelName(l Level) string {
switch l {
case Debug:
return "DEBUG"
case Info:
return "INFO"
case Warn:
return "WARN"
case Error:
return "ERROR"
default:
return "INFO"
}
}
func (l *Logger) Enabled(level Level) bool {
if l == nil {
return false
}
return level >= l.min
}
func (l *Logger) Logf(level Level, format string, args ...any) {
if !l.Enabled(level) {
return
}
l.mu.Lock()
defer l.mu.Unlock()
l.log.Printf("[%s] [ADGUARD-SHIELDD] %s", LevelName(level), fmt.Sprintf(format, args...))
}
func (l *Logger) Debugf(format string, args ...any) { l.Logf(Debug, format, args...) }
func (l *Logger) Infof(format string, args ...any) { l.Logf(Info, format, args...) }
func (l *Logger) Warnf(format string, args ...any) { l.Logf(Warn, format, args...) }
func (l *Logger) Errorf(format string, args ...any) { l.Logf(Error, format, args...) }