feat!: Migration auf Go-Binary
BREAKING CHANGE: Die alte Shell-Version muss vor der Installation der Go-Version deinstalliert werden.
This commit is contained in:
5
internal/appinfo/appinfo.go
Normal file
5
internal/appinfo/appinfo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package appinfo
|
||||
|
||||
var Version = "v1.0.0"
|
||||
|
||||
const ProjectURL = "https://tnvs.de/as"
|
||||
295
internal/config/config.go
Normal file
295
internal/config/config.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Path string
|
||||
|
||||
AdGuardURL string
|
||||
AdGuardUser string
|
||||
AdGuardPass string
|
||||
|
||||
RateLimitMaxRequests int
|
||||
RateLimitWindow int
|
||||
CheckInterval int
|
||||
APIQueryLimit int
|
||||
|
||||
SubdomainFloodEnabled bool
|
||||
SubdomainFloodMaxUnique int
|
||||
SubdomainFloodWindow int
|
||||
|
||||
DNSFloodWatchlistEnabled bool
|
||||
DNSFloodWatchlist []string
|
||||
|
||||
BanDuration int64
|
||||
Chain string
|
||||
BlockedPorts []string
|
||||
FirewallBackend string
|
||||
FirewallMode string
|
||||
DryRun bool
|
||||
|
||||
Whitelist []string
|
||||
|
||||
LogFile string
|
||||
LogLevel string
|
||||
StateDir string
|
||||
PIDFile string
|
||||
|
||||
NotifyEnabled bool
|
||||
NotifyType string
|
||||
NotifyWebhook string
|
||||
NTFYServerURL string
|
||||
NTFYTopic string
|
||||
NTFYToken string
|
||||
NTFYPriority string
|
||||
|
||||
ReportEnabled bool
|
||||
ReportInterval string
|
||||
ReportTime string
|
||||
ReportEmailTo string
|
||||
ReportEmailFrom string
|
||||
ReportFormat string
|
||||
ReportMailCmd string
|
||||
ReportBusiestDayRange int
|
||||
|
||||
ExternalWhitelistEnabled bool
|
||||
ExternalWhitelistURLs []string
|
||||
ExternalWhitelistInterval int
|
||||
ExternalWhitelistCacheDir string
|
||||
|
||||
ExternalBlocklistEnabled bool
|
||||
ExternalBlocklistURLs []string
|
||||
ExternalBlocklistInterval int
|
||||
ExternalBlocklistCacheDir string
|
||||
ExternalBlocklistDuration int64
|
||||
ExternalBlocklistAutoUnban bool
|
||||
ExternalBlocklistNotify bool
|
||||
|
||||
ProgressiveBanEnabled bool
|
||||
ProgressiveBanMultiplier int
|
||||
ProgressiveBanMaxLevel int
|
||||
ProgressiveBanResetAfter int64
|
||||
|
||||
AbuseIPDBEnabled bool
|
||||
AbuseIPDBAPIKey string
|
||||
AbuseIPDBCategories string
|
||||
|
||||
GeoIPEnabled bool
|
||||
GeoIPMode string
|
||||
GeoIPCountries []string
|
||||
GeoIPNotify bool
|
||||
GeoIPSkipPrivate bool
|
||||
GeoIPLicenseKey string
|
||||
GeoIPMMDBPath string
|
||||
GeoIPCacheTTL int64
|
||||
GeoIPCheckInterval int
|
||||
}
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
values, err := parseFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &Config{Path: path}
|
||||
c.AdGuardURL = stringVal(values, "ADGUARD_URL", "")
|
||||
c.AdGuardUser = stringVal(values, "ADGUARD_USER", "")
|
||||
c.AdGuardPass = stringVal(values, "ADGUARD_PASS", "")
|
||||
c.RateLimitMaxRequests = intVal(values, "RATE_LIMIT_MAX_REQUESTS", 30)
|
||||
c.RateLimitWindow = intVal(values, "RATE_LIMIT_WINDOW", 60)
|
||||
c.CheckInterval = intVal(values, "CHECK_INTERVAL", 10)
|
||||
c.APIQueryLimit = intVal(values, "API_QUERY_LIMIT", 500)
|
||||
c.SubdomainFloodEnabled = boolVal(values, "SUBDOMAIN_FLOOD_ENABLED", true)
|
||||
c.SubdomainFloodMaxUnique = intVal(values, "SUBDOMAIN_FLOOD_MAX_UNIQUE", 50)
|
||||
c.SubdomainFloodWindow = intVal(values, "SUBDOMAIN_FLOOD_WINDOW", 60)
|
||||
c.DNSFloodWatchlistEnabled = boolVal(values, "DNS_FLOOD_WATCHLIST_ENABLED", false)
|
||||
c.DNSFloodWatchlist = csv(values["DNS_FLOOD_WATCHLIST"])
|
||||
c.BanDuration = int64(intVal(values, "BAN_DURATION", 3600))
|
||||
c.Chain = stringVal(values, "IPTABLES_CHAIN", "ADGUARD_SHIELD")
|
||||
c.BlockedPorts = fields(stringVal(values, "BLOCKED_PORTS", "53 443 853"))
|
||||
c.FirewallBackend = stringVal(values, "FIREWALL_BACKEND", "ipset")
|
||||
c.FirewallMode = strings.ToLower(strings.TrimSpace(stringVal(values, "FIREWALL_MODE", "host")))
|
||||
c.DryRun = boolVal(values, "DRY_RUN", false)
|
||||
if strings.EqualFold(os.Getenv("DRY_RUN"), "true") || os.Getenv("DRY_RUN") == "1" {
|
||||
c.DryRun = true
|
||||
}
|
||||
c.Whitelist = csv(values["WHITELIST"])
|
||||
c.LogFile = stringVal(values, "LOG_FILE", "/var/log/adguard-shield.log")
|
||||
c.LogLevel = stringVal(values, "LOG_LEVEL", "INFO")
|
||||
c.StateDir = stringVal(values, "STATE_DIR", "/var/lib/adguard-shield")
|
||||
c.PIDFile = stringVal(values, "PID_FILE", "/var/run/adguard-shield.pid")
|
||||
c.NotifyEnabled = boolVal(values, "NOTIFY_ENABLED", false)
|
||||
c.NotifyType = stringVal(values, "NOTIFY_TYPE", "ntfy")
|
||||
c.NotifyWebhook = stringVal(values, "NOTIFY_WEBHOOK_URL", "")
|
||||
c.NTFYServerURL = stringVal(values, "NTFY_SERVER_URL", "https://ntfy.sh")
|
||||
c.NTFYTopic = stringVal(values, "NTFY_TOPIC", "")
|
||||
c.NTFYToken = stringVal(values, "NTFY_TOKEN", "")
|
||||
c.NTFYPriority = stringVal(values, "NTFY_PRIORITY", "4")
|
||||
c.ReportEnabled = boolVal(values, "REPORT_ENABLED", false)
|
||||
c.ReportInterval = stringVal(values, "REPORT_INTERVAL", "weekly")
|
||||
c.ReportTime = stringVal(values, "REPORT_TIME", "08:00")
|
||||
c.ReportEmailTo = stringVal(values, "REPORT_EMAIL_TO", "admin@example.com")
|
||||
c.ReportEmailFrom = stringVal(values, "REPORT_EMAIL_FROM", "adguard-shield@example.com")
|
||||
c.ReportFormat = strings.ToLower(stringVal(values, "REPORT_FORMAT", "html"))
|
||||
c.ReportMailCmd = stringVal(values, "REPORT_MAIL_CMD", "msmtp")
|
||||
c.ReportBusiestDayRange = intVal(values, "REPORT_BUSIEST_DAY_RANGE", 30)
|
||||
c.ExternalWhitelistEnabled = boolVal(values, "EXTERNAL_WHITELIST_ENABLED", false)
|
||||
c.ExternalWhitelistURLs = csv(values["EXTERNAL_WHITELIST_URLS"])
|
||||
c.ExternalWhitelistInterval = intVal(values, "EXTERNAL_WHITELIST_INTERVAL", 300)
|
||||
c.ExternalWhitelistCacheDir = stringVal(values, "EXTERNAL_WHITELIST_CACHE_DIR", filepath.Join(c.StateDir, "external-whitelist"))
|
||||
c.ExternalBlocklistEnabled = boolVal(values, "EXTERNAL_BLOCKLIST_ENABLED", false)
|
||||
c.ExternalBlocklistURLs = csv(values["EXTERNAL_BLOCKLIST_URLS"])
|
||||
c.ExternalBlocklistInterval = intVal(values, "EXTERNAL_BLOCKLIST_INTERVAL", 300)
|
||||
c.ExternalBlocklistCacheDir = stringVal(values, "EXTERNAL_BLOCKLIST_CACHE_DIR", filepath.Join(c.StateDir, "external-blocklist"))
|
||||
c.ExternalBlocklistDuration = int64(intVal(values, "EXTERNAL_BLOCKLIST_BAN_DURATION", 0))
|
||||
c.ExternalBlocklistAutoUnban = boolVal(values, "EXTERNAL_BLOCKLIST_AUTO_UNBAN", true)
|
||||
c.ExternalBlocklistNotify = boolVal(values, "EXTERNAL_BLOCKLIST_NOTIFY", false)
|
||||
c.ProgressiveBanEnabled = boolVal(values, "PROGRESSIVE_BAN_ENABLED", true)
|
||||
c.ProgressiveBanMultiplier = intVal(values, "PROGRESSIVE_BAN_MULTIPLIER", 2)
|
||||
c.ProgressiveBanMaxLevel = intVal(values, "PROGRESSIVE_BAN_MAX_LEVEL", 5)
|
||||
c.ProgressiveBanResetAfter = int64(intVal(values, "PROGRESSIVE_BAN_RESET_AFTER", 86400))
|
||||
c.AbuseIPDBEnabled = boolVal(values, "ABUSEIPDB_ENABLED", false)
|
||||
c.AbuseIPDBAPIKey = stringVal(values, "ABUSEIPDB_API_KEY", "")
|
||||
c.AbuseIPDBCategories = stringVal(values, "ABUSEIPDB_CATEGORIES", "4")
|
||||
c.GeoIPEnabled = boolVal(values, "GEOIP_ENABLED", false)
|
||||
c.GeoIPMode = strings.ToLower(stringVal(values, "GEOIP_MODE", "blocklist"))
|
||||
c.GeoIPCountries = upperCSV(values["GEOIP_COUNTRIES"])
|
||||
c.GeoIPNotify = boolVal(values, "GEOIP_NOTIFY", true)
|
||||
c.GeoIPSkipPrivate = boolVal(values, "GEOIP_SKIP_PRIVATE", true)
|
||||
c.GeoIPLicenseKey = stringVal(values, "GEOIP_LICENSE_KEY", "")
|
||||
c.GeoIPMMDBPath = stringVal(values, "GEOIP_MMDB_PATH", "")
|
||||
c.GeoIPCacheTTL = int64(intVal(values, "GEOIP_CACHE_TTL", 86400))
|
||||
c.GeoIPCheckInterval = intVal(values, "GEOIP_CHECK_INTERVAL", 0)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func DefaultPath() string {
|
||||
if v := os.Getenv("ADGUARD_SHIELD_CONFIG"); v != "" {
|
||||
return v
|
||||
}
|
||||
if _, err := os.Stat("/opt/adguard-shield/adguard-shield.conf"); err == nil {
|
||||
return "/opt/adguard-shield/adguard-shield.conf"
|
||||
}
|
||||
return filepath.Join(".", "adguard-shield.conf")
|
||||
}
|
||||
|
||||
func (c *Config) DBPath() string { return filepath.Join(c.StateDir, "adguard-shield.db") }
|
||||
func (c *Config) GeoIPDir(scriptDir string) string { return filepath.Join(scriptDir, "geoip") }
|
||||
|
||||
func parseFile(path string) (map[string]string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open config %s: %w", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
out := map[string]string{}
|
||||
sc := bufio.NewScanner(f)
|
||||
for sc.Scan() {
|
||||
line := strings.TrimSpace(sc.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
idx := strings.Index(line, "=")
|
||||
if idx < 1 {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(line[:idx])
|
||||
val := stripInlineComment(strings.TrimSpace(line[idx+1:]))
|
||||
out[key] = unquote(val)
|
||||
}
|
||||
return out, sc.Err()
|
||||
}
|
||||
|
||||
func stripInlineComment(s string) string {
|
||||
inSingle, inDouble := false, false
|
||||
for i, r := range s {
|
||||
switch r {
|
||||
case '\'':
|
||||
if !inDouble {
|
||||
inSingle = !inSingle
|
||||
}
|
||||
case '"':
|
||||
if !inSingle {
|
||||
inDouble = !inDouble
|
||||
}
|
||||
case '#':
|
||||
if !inSingle && !inDouble {
|
||||
if i == 0 || s[i-1] == ' ' || s[i-1] == '\t' {
|
||||
return strings.TrimSpace(s[:i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func unquote(s string) string {
|
||||
if len(s) >= 2 {
|
||||
if (s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'') {
|
||||
return s[1 : len(s)-1]
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func stringVal(m map[string]string, k, def string) string {
|
||||
if v, ok := m[k]; ok {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
func intVal(m map[string]string, k string, def int) int {
|
||||
v, ok := m[k]
|
||||
if !ok || strings.TrimSpace(v) == "" {
|
||||
return def
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimSpace(v))
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return n
|
||||
}
|
||||
func boolVal(m map[string]string, k string, def bool) bool {
|
||||
v, ok := m[k]
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "true", "1", "yes", "on":
|
||||
return true
|
||||
case "false", "0", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return def
|
||||
}
|
||||
}
|
||||
func csv(s string) []string {
|
||||
var out []string
|
||||
for _, p := range strings.Split(s, ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
func upperCSV(s string) []string {
|
||||
parts := csv(s)
|
||||
for i := range parts {
|
||||
parts[i] = strings.ToUpper(parts[i])
|
||||
}
|
||||
return parts
|
||||
}
|
||||
func fields(s string) []string {
|
||||
out := strings.Fields(s)
|
||||
if len(out) == 0 {
|
||||
return []string{"53"}
|
||||
}
|
||||
return out
|
||||
}
|
||||
44
internal/config/config_test.go
Normal file
44
internal/config/config_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadParsesShellStyleConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "adguard-shield.conf")
|
||||
err := os.WriteFile(path, []byte(`
|
||||
ADGUARD_URL="https://dns.example"
|
||||
ADGUARD_USER="admin"
|
||||
ADGUARD_PASS='pa#ss'
|
||||
CHECK_INTERVAL=7
|
||||
BLOCKED_PORTS="53 443 853"
|
||||
FIREWALL_BACKEND="ipset"
|
||||
FIREWALL_MODE="docker-bridge"
|
||||
GEOIP_ENABLED=true
|
||||
GEOIP_MODE="allowlist"
|
||||
GEOIP_COUNTRIES="DE, us"
|
||||
GEOIP_CACHE_TTL=123
|
||||
`), 0600)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.AdGuardPass != "pa#ss" {
|
||||
t.Fatalf("quoted # was not preserved: %q", c.AdGuardPass)
|
||||
}
|
||||
if c.CheckInterval != 7 || c.FirewallBackend != "ipset" || c.FirewallMode != "docker-bridge" {
|
||||
t.Fatalf("unexpected config: %+v", c)
|
||||
}
|
||||
if got := c.GeoIPCountries; len(got) != 2 || got[0] != "DE" || got[1] != "US" {
|
||||
t.Fatalf("unexpected countries: %#v", got)
|
||||
}
|
||||
if c.GeoIPCacheTTL != 123 {
|
||||
t.Fatalf("unexpected GeoIP cache ttl: %d", c.GeoIPCacheTTL)
|
||||
}
|
||||
}
|
||||
1221
internal/daemon/daemon.go
Normal file
1221
internal/daemon/daemon.go
Normal file
File diff suppressed because it is too large
Load Diff
365
internal/daemon/daemon_test.go
Normal file
365
internal/daemon/daemon_test.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"adguard-shield/internal/config"
|
||||
"adguard-shield/internal/db"
|
||||
"adguard-shield/internal/firewall"
|
||||
)
|
||||
|
||||
func TestParseListEntry(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"1.2.3.4 # comment": "1.2.3.4",
|
||||
"0.0.0.0 bad.example": "bad.example",
|
||||
"2001:db8::/32": "2001:db8::/32",
|
||||
}
|
||||
for input, want := range cases {
|
||||
got := parseListEntry(input)
|
||||
if len(got) != 1 || got[0] != want {
|
||||
t.Fatalf("%q -> %#v, want %q", input, got, want)
|
||||
}
|
||||
}
|
||||
if got := parseListEntry("http://example.invalid/list"); got != nil {
|
||||
t.Fatalf("URL should be rejected: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationFormatting(t *testing.T) {
|
||||
d := &Daemon{Config: &config.Config{
|
||||
RateLimitWindow: 60,
|
||||
SubdomainFloodWindow: 120,
|
||||
ProgressiveBanMaxLevel: 3,
|
||||
}}
|
||||
b := db.Ban{
|
||||
IP: "203.0.113.7",
|
||||
Domain: "abb.com",
|
||||
Count: 110,
|
||||
Duration: 3600,
|
||||
OffenseLevel: 1,
|
||||
Reason: "rate-limit",
|
||||
Protocol: "dns",
|
||||
Source: "monitor",
|
||||
}
|
||||
if got, want := d.displayBanReason(b), "110x abb.com in 60s via DNS, Rate-Limit"; got != want {
|
||||
t.Fatalf("reason = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := d.displayBanDuration(b), "1h 0m [Stufe 1/3]"; got != want {
|
||||
t.Fatalf("duration = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
b.Permanent = true
|
||||
b.Duration = 0
|
||||
b.OffenseLevel = 3
|
||||
if got, want := d.displayBanDuration(b), "PERMANENT [Stufe 3/3]"; got != want {
|
||||
t.Fatalf("permanent duration = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNTFYNotificationTitleDoesNotDuplicateShieldTag(t *testing.T) {
|
||||
d := &Daemon{Config: &config.Config{
|
||||
NotifyType: "ntfy",
|
||||
NTFYServerURL: "https://ntfy.example",
|
||||
NTFYTopic: "adguard-shield",
|
||||
NTFYPriority: "4",
|
||||
}}
|
||||
req, err := d.notificationRequest(context.Background(), "🛡️ AdGuard Shield", "test", db.Ban{IP: "203.0.113.7", Reason: "rate-limit", Source: "monitor"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if req == nil {
|
||||
t.Fatal("request must be created")
|
||||
}
|
||||
if got, want := req.Header.Get("Title"), "🛡️ AdGuard Shield"; got != want {
|
||||
t.Fatalf("title = %q, want %q", got, want)
|
||||
}
|
||||
if got := req.Header.Get("Tags"); strings.Contains(got, "shield") {
|
||||
t.Fatalf("tags must not duplicate title shield emoji: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationRequestsForWebhookProviders(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
notifyType string
|
||||
wantType string
|
||||
wantPayload []string
|
||||
}{
|
||||
{
|
||||
name: "discord",
|
||||
notifyType: "discord",
|
||||
wantType: "application/json",
|
||||
wantPayload: []string{`"content":"title\n\nmessage"`},
|
||||
},
|
||||
{
|
||||
name: "slack",
|
||||
notifyType: "slack",
|
||||
wantType: "application/json",
|
||||
wantPayload: []string{`"text":"title\n\nmessage"`},
|
||||
},
|
||||
{
|
||||
name: "generic",
|
||||
notifyType: "generic",
|
||||
wantType: "application/json",
|
||||
wantPayload: []string{`"action":"unban"`, `"client":"203.0.113.7"`, `"message":"message"`},
|
||||
},
|
||||
{
|
||||
name: "gotify",
|
||||
notifyType: "gotify",
|
||||
wantType: "application/x-www-form-urlencoded",
|
||||
wantPayload: []string{`title=title`, `message=message`, `priority=5`},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
d := &Daemon{Config: &config.Config{
|
||||
NotifyType: tc.notifyType,
|
||||
NotifyWebhook: "https://hooks.example/notify",
|
||||
}}
|
||||
req, err := d.notificationRequest(context.Background(), "title", "message", db.Ban{IP: "203.0.113.7", Reason: "manual"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if req == nil {
|
||||
t.Fatal("request must be created")
|
||||
}
|
||||
if req.Method != http.MethodPost {
|
||||
t.Fatalf("method = %s, want POST", req.Method)
|
||||
}
|
||||
if got := req.Header.Get("Content-Type"); got != tc.wantType {
|
||||
t.Fatalf("content type = %q, want %q", got, tc.wantType)
|
||||
}
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
payload := string(body)
|
||||
for _, want := range tc.wantPayload {
|
||||
if !strings.Contains(payload, want) {
|
||||
t.Fatalf("payload %q does not contain %q", payload, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceNotificationsSendStartAndStopOnce(t *testing.T) {
|
||||
requests := make(chan string, 4)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
requests <- string(body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
d := &Daemon{
|
||||
Config: &config.Config{NotifyEnabled: true, NotifyType: "generic", NotifyWebhook: srv.URL},
|
||||
Client: srv.Client(),
|
||||
}
|
||||
d.NotifyServiceStart(context.Background())
|
||||
d.NotifyServiceStart(context.Background())
|
||||
d.NotifyServiceStop(context.Background())
|
||||
d.NotifyServiceStop(context.Background())
|
||||
|
||||
var payloads []string
|
||||
for len(payloads) < 2 {
|
||||
select {
|
||||
case payload := <-requests:
|
||||
payloads = append(payloads, payload)
|
||||
case <-time.After(4 * time.Second):
|
||||
t.Fatalf("service notifications sent %d payloads, want 2", len(payloads))
|
||||
}
|
||||
}
|
||||
if !strings.Contains(payloads[0], `"action":"service_start"`) || !strings.Contains(payloads[0], "gestartet") {
|
||||
t.Fatalf("unexpected service start payload: %s", payloads[0])
|
||||
}
|
||||
if !strings.Contains(payloads[1], `"action":"service_stop"`) || !strings.Contains(payloads[1], "gestoppt") {
|
||||
t.Fatalf("unexpected service stop payload: %s", payloads[1])
|
||||
}
|
||||
select {
|
||||
case payload := <-requests:
|
||||
t.Fatalf("duplicate service notification sent: %s", payload)
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnbanSendsNotificationForMonitorBan(t *testing.T) {
|
||||
requests := make(chan string, 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
requests <- string(body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
store, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
if err := store.InsertBan(db.Ban{IP: "127.0.0.1", Reason: "rate-limit", Source: "monitor"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
d := &Daemon{
|
||||
Config: &config.Config{NotifyEnabled: true, NotifyType: "generic", NotifyWebhook: srv.URL},
|
||||
Store: store,
|
||||
FW: firewall.New(firewall.OSExecutor{}, "ADGUARD_SHIELD", []string{"53"}, "host", true),
|
||||
Client: srv.Client(),
|
||||
}
|
||||
if err := d.Unban(context.Background(), "127.0.0.1", "manual"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
select {
|
||||
case payload := <-requests:
|
||||
if !strings.Contains(payload, `"action":"unban"`) || !strings.Contains(payload, "AdGuard Shield Freigabe") {
|
||||
t.Fatalf("unexpected payload: %s", payload)
|
||||
}
|
||||
case <-time.After(4 * time.Second):
|
||||
t.Fatal("unban notification was not sent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnbanStillSendsExternalBlocklistNotificationWhenBanNotificationsDisabled(t *testing.T) {
|
||||
requests := make(chan string, 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
requests <- string(body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
store, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
if err := store.InsertBan(db.Ban{IP: "127.0.0.1", Reason: "external-blocklist", Source: "external-blocklist"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
d := &Daemon{
|
||||
Config: &config.Config{NotifyEnabled: true, NotifyType: "generic", NotifyWebhook: srv.URL, ExternalBlocklistNotify: false},
|
||||
Store: store,
|
||||
FW: firewall.New(firewall.OSExecutor{}, "ADGUARD_SHIELD", []string{"53"}, "host", true),
|
||||
Client: srv.Client(),
|
||||
}
|
||||
if err := d.Unban(context.Background(), "127.0.0.1", "external-blocklist-removed"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
select {
|
||||
case payload := <-requests:
|
||||
if !strings.Contains(payload, `"action":"unban"`) || !strings.Contains(payload, "AdGuard Shield Freigabe") {
|
||||
t.Fatalf("unexpected payload: %s", payload)
|
||||
}
|
||||
case <-time.After(4 * time.Second):
|
||||
t.Fatal("external blocklist unban notification was not sent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnbanQuietSkipsIndividualNotificationAndBulkSummarySendsOnce(t *testing.T) {
|
||||
requests := make(chan string, 2)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
requests <- string(body)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
store, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
if err := store.InsertBan(db.Ban{IP: "127.0.0.1", Reason: "rate-limit", Source: "monitor"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
d := &Daemon{
|
||||
Config: &config.Config{NotifyEnabled: true, NotifyType: "generic", NotifyWebhook: srv.URL},
|
||||
Store: store,
|
||||
FW: firewall.New(firewall.OSExecutor{}, "ADGUARD_SHIELD", []string{"53"}, "host", true),
|
||||
Client: srv.Client(),
|
||||
}
|
||||
if err := d.UnbanQuiet(context.Background(), "127.0.0.1", "manual-flush"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
select {
|
||||
case payload := <-requests:
|
||||
t.Fatalf("quiet unban sent individual notification: %s", payload)
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
|
||||
d.NotifyBulkUnban(context.Background(), "manual-flush", 1)
|
||||
select {
|
||||
case payload := <-requests:
|
||||
if !strings.Contains(payload, `"action":"manual-flush"`) || !strings.Contains(payload, "Bulk-Freigabe") || !strings.Contains(payload, "Freigegebene IPs: 1") {
|
||||
t.Fatalf("unexpected payload: %s", payload)
|
||||
}
|
||||
case <-time.After(4 * time.Second):
|
||||
t.Fatal("bulk unban notification was not sent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAbuseReportingScope(t *testing.T) {
|
||||
d := &Daemon{Config: &config.Config{AbuseIPDBEnabled: true, AbuseIPDBAPIKey: "key"}}
|
||||
if !d.shouldReportAbuseIPDB(db.Ban{Permanent: true, Source: "monitor"}) {
|
||||
t.Fatal("monitor permanent ban should be reported")
|
||||
}
|
||||
if d.shouldReportAbuseIPDB(db.Ban{Permanent: true, Source: "geoip"}) {
|
||||
t.Fatal("geoip ban must not be reported")
|
||||
}
|
||||
if d.shouldReportAbuseIPDB(db.Ban{Permanent: false, Source: "monitor"}) {
|
||||
t.Fatal("temporary ban must not be reported")
|
||||
}
|
||||
|
||||
d.Config.RateLimitWindow = 60
|
||||
got := d.abuseIPDBComment(db.Ban{Count: 110, Domain: "abb.com", Reason: "rate-limit"})
|
||||
want := "DNS flooding on our DNS server: 110x abb.com in 60s. Banned by Adguard Shield 🔗 https://tnvs.de/as"
|
||||
if got != want {
|
||||
t.Fatalf("comment = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAbuseIPDBCheckURL(t *testing.T) {
|
||||
if got := abuseIPDBCheckURL("65.185.189.75"); !strings.Contains(got, "https://www.abuseipdb.com/check/65.185.189.75") {
|
||||
t.Fatalf("unexpected AbuseIPDB url: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseDomain(t *testing.T) {
|
||||
if got := baseDomain("a.b.example.com"); got != "example.com" {
|
||||
t.Fatalf("unexpected base domain: %s", got)
|
||||
}
|
||||
if got := baseDomain("a.b.example.co.uk"); got != "example.co.uk" {
|
||||
t.Fatalf("unexpected multipart base domain: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDryRunDoesNotInsertActiveBan(t *testing.T) {
|
||||
store, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
d := &Daemon{
|
||||
Config: &config.Config{DryRun: true, BanDuration: 60},
|
||||
Store: store,
|
||||
FW: firewall.New(firewall.OSExecutor{}, "ADGUARD_SHIELD", []string{"53"}, "host", true),
|
||||
wl: map[string]bool{},
|
||||
}
|
||||
if err := d.Ban(context.Background(), "1.2.3.4", "example.com", 99, "dns", "rate-limit", "monitor", "", false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ok, err := store.BanExists("1.2.3.4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ok {
|
||||
t.Fatal("dry-run must not create an active ban")
|
||||
}
|
||||
}
|
||||
384
internal/daemon/live.go
Normal file
384
internal/daemon/live.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"adguard-shield/internal/db"
|
||||
"adguard-shield/internal/syslog"
|
||||
)
|
||||
|
||||
type LiveOptions struct {
|
||||
Interval time.Duration
|
||||
Top int
|
||||
Recent int
|
||||
LogLevel string
|
||||
Once bool
|
||||
}
|
||||
|
||||
type liveSnapshot struct {
|
||||
At time.Time
|
||||
APIEntries int
|
||||
Window int
|
||||
Limit int
|
||||
Events []queryEvent
|
||||
TopPairs []liveCount
|
||||
SubdomainGroups []liveCount
|
||||
ActiveBans []db.Ban
|
||||
Offenses int
|
||||
ExpiredOffenses int
|
||||
WhitelistCount int
|
||||
BlocklistBans int
|
||||
SystemLogs []string
|
||||
}
|
||||
|
||||
type liveCount struct {
|
||||
Client string
|
||||
Domain string
|
||||
Count int
|
||||
Protocol string
|
||||
}
|
||||
|
||||
func (d *Daemon) Live(ctx context.Context, w io.Writer, opts LiveOptions) error {
|
||||
if opts.Interval <= 0 {
|
||||
opts.Interval = time.Duration(d.Config.CheckInterval) * time.Second
|
||||
}
|
||||
if opts.Interval <= 0 {
|
||||
opts.Interval = 2 * time.Second
|
||||
}
|
||||
if opts.Top <= 0 {
|
||||
opts.Top = 10
|
||||
}
|
||||
if opts.Recent <= 0 {
|
||||
opts.Recent = 12
|
||||
}
|
||||
if strings.TrimSpace(opts.LogLevel) == "" {
|
||||
opts.LogLevel = "INFO"
|
||||
}
|
||||
|
||||
for {
|
||||
snap, err := d.liveSnapshot(ctx, opts)
|
||||
renderLive(w, d, snap, err, opts)
|
||||
if opts.Once {
|
||||
return err
|
||||
}
|
||||
timer := time.NewTimer(opts.Interval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) liveSnapshot(ctx context.Context, opts LiveOptions) (liveSnapshot, error) {
|
||||
snap := liveSnapshot{
|
||||
At: time.Now(),
|
||||
Window: d.Config.RateLimitWindow,
|
||||
Limit: d.Config.RateLimitMaxRequests,
|
||||
}
|
||||
items, err := d.FetchQueryLog(ctx)
|
||||
if err != nil {
|
||||
return snap, err
|
||||
}
|
||||
snap.APIEntries = len(items)
|
||||
events := dedupeEvents(d.toEvents(items))
|
||||
sort.Slice(events, func(i, j int) bool { return events[i].At.After(events[j].At) })
|
||||
if len(events) > opts.Recent {
|
||||
snap.Events = append([]queryEvent(nil), events[:opts.Recent]...)
|
||||
} else {
|
||||
snap.Events = append([]queryEvent(nil), events...)
|
||||
}
|
||||
snap.TopPairs = topQueryPairs(events, d.Config.RateLimitWindow, opts.Top)
|
||||
snap.SubdomainGroups = topSubdomainGroups(events, d.Config.SubdomainFloodWindow, opts.Top)
|
||||
|
||||
if bans, err := d.Store.ActiveBans(); err == nil {
|
||||
snap.ActiveBans = bans
|
||||
}
|
||||
if n, err := d.Store.CountOffenses(); err == nil {
|
||||
snap.Offenses = n
|
||||
}
|
||||
if n, err := d.Store.CountExpiredOffenses(d.Config.ProgressiveBanResetAfter); err == nil {
|
||||
snap.ExpiredOffenses = n
|
||||
}
|
||||
if wl, err := d.Store.AllWhitelist(); err == nil {
|
||||
snap.WhitelistCount = len(wl)
|
||||
}
|
||||
if n, err := d.Store.CountBySource("external-blocklist"); err == nil {
|
||||
snap.BlocklistBans = n
|
||||
}
|
||||
snap.SystemLogs = RecentLogLines(d.Config.LogFile, opts.LogLevel, opts.Recent)
|
||||
return snap, nil
|
||||
}
|
||||
|
||||
func dedupeEvents(events []queryEvent) []queryEvent {
|
||||
seen := map[string]bool{}
|
||||
out := make([]queryEvent, 0, len(events))
|
||||
for _, ev := range events {
|
||||
key := ev.At.Format(time.RFC3339Nano) + "|" + ev.Client + "|" + ev.Domain + "|" + ev.Protocol
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
out = append(out, ev)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func topQueryPairs(events []queryEvent, window, limit int) []liveCount {
|
||||
cut := time.Now().Add(-time.Duration(window) * time.Second)
|
||||
counts := map[string]*liveCount{}
|
||||
protos := map[string]map[string]bool{}
|
||||
for _, ev := range events {
|
||||
if ev.At.Before(cut) {
|
||||
continue
|
||||
}
|
||||
key := ev.Client + "|" + ev.Domain
|
||||
if counts[key] == nil {
|
||||
counts[key] = &liveCount{Client: ev.Client, Domain: ev.Domain}
|
||||
protos[key] = map[string]bool{}
|
||||
}
|
||||
counts[key].Count++
|
||||
protos[key][formatProtocol(ev.Protocol)] = true
|
||||
}
|
||||
out := make([]liveCount, 0, len(counts))
|
||||
for key, item := range counts {
|
||||
item.Protocol = strings.Join(sortedKeys(protos[key]), ",")
|
||||
out = append(out, *item)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
if out[i].Count == out[j].Count {
|
||||
return out[i].Client+"|"+out[i].Domain < out[j].Client+"|"+out[j].Domain
|
||||
}
|
||||
return out[i].Count > out[j].Count
|
||||
})
|
||||
if limit > 0 && len(out) > limit {
|
||||
return out[:limit]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func topSubdomainGroups(events []queryEvent, window, limit int) []liveCount {
|
||||
cut := time.Now().Add(-time.Duration(window) * time.Second)
|
||||
sets := map[string]map[string]bool{}
|
||||
protos := map[string]map[string]bool{}
|
||||
for _, ev := range events {
|
||||
if ev.At.Before(cut) {
|
||||
continue
|
||||
}
|
||||
base := baseDomain(ev.Domain)
|
||||
if base == "" || base == ev.Domain {
|
||||
continue
|
||||
}
|
||||
key := ev.Client + "|" + base
|
||||
if sets[key] == nil {
|
||||
sets[key] = map[string]bool{}
|
||||
protos[key] = map[string]bool{}
|
||||
}
|
||||
sets[key][ev.Domain] = true
|
||||
protos[key][formatProtocol(ev.Protocol)] = true
|
||||
}
|
||||
out := make([]liveCount, 0, len(sets))
|
||||
for key, set := range sets {
|
||||
client, domain, _ := strings.Cut(key, "|")
|
||||
out = append(out, liveCount{
|
||||
Client: client,
|
||||
Domain: domain,
|
||||
Count: len(set),
|
||||
Protocol: strings.Join(sortedKeys(protos[key]), ","),
|
||||
})
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
if out[i].Count == out[j].Count {
|
||||
return out[i].Client+"|"+out[i].Domain < out[j].Client+"|"+out[j].Domain
|
||||
}
|
||||
return out[i].Count > out[j].Count
|
||||
})
|
||||
if limit > 0 && len(out) > limit {
|
||||
return out[:limit]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func renderLive(w io.Writer, d *Daemon, snap liveSnapshot, snapErr error, opts LiveOptions) {
|
||||
fmt.Fprint(w, "\033[H\033[2J")
|
||||
fmt.Fprintf(w, "AdGuard Shield Live | %s | Strg+C beendet\n", snap.At.Format("2006-01-02 15:04:05"))
|
||||
fmt.Fprintln(w, strings.Repeat("=", 92))
|
||||
fmt.Fprintf(w, "Config: %s | API: %s | Log: %s (ab %s)\n", d.Config.Path, d.Config.AdGuardURL, d.Config.LogFile, strings.ToUpper(opts.LogLevel))
|
||||
if snapErr != nil {
|
||||
fmt.Fprintf(w, "\nFEHLER: Live-Snapshot konnte nicht geladen werden: %v\n", snapErr)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\nWorker und Module\n")
|
||||
fmt.Fprintf(w, " Query-Poller: alle %ds | API-Eintraege: %d | Zeitfenster: %ds | Limit: %d\n", d.Config.CheckInterval, snap.APIEntries, snap.Window, snap.Limit)
|
||||
fmt.Fprintf(w, " GeoIP: %s | Modus: %s | Laender: %s\n", enabled(d.Config.GeoIPEnabled), d.Config.GeoIPMode, listOrDash(d.Config.GeoIPCountries))
|
||||
fmt.Fprintf(w, " Externe Blocklist: %s | Intervall: %ds | URLs: %d | aktive Sperren: %d\n", enabled(d.Config.ExternalBlocklistEnabled), d.Config.ExternalBlocklistInterval, len(d.Config.ExternalBlocklistURLs), snap.BlocklistBans)
|
||||
fmt.Fprintf(w, " Externe Whitelist: %s | Intervall: %ds | URLs: %d | aufgeloeste IPs: %d\n", enabled(d.Config.ExternalWhitelistEnabled), d.Config.ExternalWhitelistInterval, len(d.Config.ExternalWhitelistURLs), snap.WhitelistCount)
|
||||
fmt.Fprintf(w, " Offense-Cleanup: %s | Zaehler: %d | davon abgelaufen: %d\n", enabled(d.Config.ProgressiveBanEnabled), snap.Offenses, snap.ExpiredOffenses)
|
||||
|
||||
fmt.Fprintf(w, "\nTop Client/Domain im Rate-Limit-Fenster\n")
|
||||
if len(snap.TopPairs) == 0 {
|
||||
fmt.Fprintln(w, " Keine Anfragen im aktuellen Zeitfenster.")
|
||||
} else {
|
||||
for _, item := range snap.TopPairs {
|
||||
fmt.Fprintf(w, " %5s %-39s %-34s %s\n", fmt.Sprintf("%d/%d", item.Count, snap.Limit), trim(item.Client, 39), trim(item.Domain, 34), item.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
if d.Config.SubdomainFloodEnabled {
|
||||
fmt.Fprintf(w, "\nSubdomain-Flood-Kandidaten\n")
|
||||
if len(snap.SubdomainGroups) == 0 {
|
||||
fmt.Fprintln(w, " Keine Subdomain-Gruppen im aktuellen Zeitfenster.")
|
||||
} else {
|
||||
for _, item := range snap.SubdomainGroups {
|
||||
fmt.Fprintf(w, " %5s %-39s %-34s %s\n", fmt.Sprintf("%d/%d", item.Count, d.Config.SubdomainFloodMaxUnique), trim(item.Client, 39), trim(item.Domain, 34), item.Protocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\nLetzte Queries\n")
|
||||
if len(snap.Events) == 0 {
|
||||
fmt.Fprintln(w, " Keine Querylog-Eintraege gefunden.")
|
||||
} else {
|
||||
for _, ev := range snap.Events {
|
||||
fmt.Fprintf(w, " %s %-39s %-8s %s\n", ev.At.Local().Format("15:04:05"), trim(ev.Client, 39), formatProtocol(ev.Protocol), trim(ev.Domain, 44))
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "\nAktive Sperren\n")
|
||||
if len(snap.ActiveBans) == 0 {
|
||||
fmt.Fprintln(w, " Keine aktiven Sperren.")
|
||||
} else {
|
||||
maxBans := opts.Top
|
||||
if len(snap.ActiveBans) < maxBans {
|
||||
maxBans = len(snap.ActiveBans)
|
||||
}
|
||||
for _, b := range snap.ActiveBans[:maxBans] {
|
||||
fmt.Fprintf(w, " %-39s %-20s %-18s %s\n", trim(b.IP, 39), trim(b.Source, 20), trim(b.Reason, 18), banUntil(b))
|
||||
}
|
||||
if len(snap.ActiveBans) > maxBans {
|
||||
fmt.Fprintf(w, " ... %d weitere\n", len(snap.ActiveBans)-maxBans)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.ToLower(opts.LogLevel) != "off" {
|
||||
fmt.Fprintf(w, "\nSystemereignisse\n")
|
||||
if len(snap.SystemLogs) == 0 {
|
||||
fmt.Fprintln(w, " Keine passenden Logeintraege.")
|
||||
} else {
|
||||
for _, line := range snap.SystemLogs {
|
||||
fmt.Fprintf(w, " %s\n", trim(line, 88))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RecentLogLines(path, minLevel string, limit int) []string {
|
||||
if strings.EqualFold(strings.TrimSpace(minLevel), "off") || path == "" || limit <= 0 {
|
||||
return nil
|
||||
}
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer f.Close()
|
||||
min := syslog.ParseLevel(minLevel, syslog.Info)
|
||||
ring := make([]string, limit)
|
||||
count := 0
|
||||
sc := bufio.NewScanner(f)
|
||||
sc.Buffer(make([]byte, 1024), 1024*1024)
|
||||
for sc.Scan() {
|
||||
line := sc.Text()
|
||||
if logLineLevel(line) < min {
|
||||
continue
|
||||
}
|
||||
ring[count%limit] = line
|
||||
count++
|
||||
}
|
||||
n := count
|
||||
if n > limit {
|
||||
n = limit
|
||||
}
|
||||
out := make([]string, 0, n)
|
||||
start := count - n
|
||||
for i := 0; i < n; i++ {
|
||||
out = append(out, ring[(start+i)%limit])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func logLineLevel(line string) syslog.Level {
|
||||
for _, level := range []syslog.Level{syslog.Error, syslog.Warn, syslog.Info, syslog.Debug} {
|
||||
if strings.Contains(line, "["+syslog.LevelName(level)+"]") {
|
||||
return level
|
||||
}
|
||||
}
|
||||
return syslog.Info
|
||||
}
|
||||
|
||||
func sortedKeys(m map[string]bool) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
if k != "" {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func formatProtocol(proto string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(proto)) {
|
||||
case "doh":
|
||||
return "DoH"
|
||||
case "dot":
|
||||
return "DoT"
|
||||
case "doq":
|
||||
return "DoQ"
|
||||
case "dnscrypt":
|
||||
return "DNSCrypt"
|
||||
case "", "dns":
|
||||
return "DNS"
|
||||
default:
|
||||
return proto
|
||||
}
|
||||
}
|
||||
|
||||
func enabled(ok bool) string {
|
||||
if ok {
|
||||
return "aktiv"
|
||||
}
|
||||
return "inaktiv"
|
||||
}
|
||||
|
||||
func listOrDash(items []string) string {
|
||||
if len(items) == 0 {
|
||||
return "-"
|
||||
}
|
||||
return strings.Join(items, ",")
|
||||
}
|
||||
|
||||
func trim(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
if max <= 1 {
|
||||
return s[:max]
|
||||
}
|
||||
return s[:max-1] + "~"
|
||||
}
|
||||
|
||||
func banUntil(b db.Ban) string {
|
||||
if b.Permanent || b.BanUntil == 0 {
|
||||
return "permanent"
|
||||
}
|
||||
return time.Unix(b.BanUntil, 0).Format("2006-01-02 15:04:05")
|
||||
}
|
||||
408
internal/db/db.go
Normal file
408
internal/db/db.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type Store struct{ DB *sql.DB }
|
||||
|
||||
type Ban struct {
|
||||
IP string
|
||||
Domain string
|
||||
Count int
|
||||
BanUntil int64
|
||||
Duration int64
|
||||
OffenseLevel int
|
||||
Permanent bool
|
||||
Reason string
|
||||
Protocol string
|
||||
Source string
|
||||
GeoIPCountry string
|
||||
GeoIPMode string
|
||||
}
|
||||
|
||||
type ReportStats struct {
|
||||
Since int64
|
||||
Until int64
|
||||
TotalBans int
|
||||
TotalUnbans int
|
||||
ActiveBans int
|
||||
TopClients []ReportCount
|
||||
Reasons []ReportCount
|
||||
Sources []ReportCount
|
||||
RecentEvents []string
|
||||
}
|
||||
|
||||
type ReportCount struct {
|
||||
Name string
|
||||
Count int
|
||||
}
|
||||
|
||||
func Open(path string) (*Store, error) {
|
||||
db, err := sql.Open("sqlite", path+"?_pragma=busy_timeout(5000)&_pragma=journal_mode(WAL)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &Store{DB: db}
|
||||
if err := s.Init(); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error { return s.DB.Close() }
|
||||
|
||||
func (s *Store) Init() error {
|
||||
schema := `
|
||||
PRAGMA journal_mode=WAL;
|
||||
PRAGMA busy_timeout=5000;
|
||||
CREATE TABLE IF NOT EXISTS schema_version (version INTEGER PRIMARY KEY, applied_at TEXT DEFAULT (datetime('now', 'localtime')));
|
||||
CREATE TABLE IF NOT EXISTS active_bans (
|
||||
client_ip TEXT PRIMARY KEY, domain TEXT, count INTEGER, ban_time TEXT,
|
||||
ban_until_epoch INTEGER DEFAULT 0, ban_duration INTEGER DEFAULT 0, offense_level INTEGER DEFAULT 0,
|
||||
is_permanent INTEGER DEFAULT 0, reason TEXT DEFAULT 'rate-limit', protocol TEXT DEFAULT 'DNS',
|
||||
source TEXT DEFAULT 'monitor', geoip_country TEXT, geoip_mode TEXT, created_at TEXT DEFAULT (datetime('now', 'localtime')));
|
||||
CREATE TABLE IF NOT EXISTS offense_tracking (
|
||||
client_ip TEXT PRIMARY KEY, offense_level INTEGER DEFAULT 0, last_offense_epoch INTEGER,
|
||||
last_offense TEXT, first_offense TEXT, created_at TEXT DEFAULT (datetime('now', 'localtime')),
|
||||
updated_at TEXT DEFAULT (datetime('now', 'localtime')));
|
||||
CREATE TABLE IF NOT EXISTS ban_history (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT, timestamp_epoch INTEGER NOT NULL, timestamp_text TEXT NOT NULL,
|
||||
action TEXT NOT NULL, client_ip TEXT NOT NULL, domain TEXT, count TEXT, duration TEXT, protocol TEXT, reason TEXT);
|
||||
CREATE TABLE IF NOT EXISTS whitelist_cache (ip_address TEXT PRIMARY KEY, source TEXT, resolved_at TEXT DEFAULT (datetime('now', 'localtime')));
|
||||
CREATE TABLE IF NOT EXISTS geoip_cache (ip TEXT PRIMARY KEY, country_code TEXT NOT NULL, looked_up_at_epoch INTEGER NOT NULL, db_mtime INTEGER DEFAULT 0);
|
||||
CREATE INDEX IF NOT EXISTS idx_bans_until ON active_bans(ban_until_epoch);
|
||||
CREATE INDEX IF NOT EXISTS idx_bans_source ON active_bans(source);
|
||||
CREATE INDEX IF NOT EXISTS idx_bans_reason ON active_bans(reason);
|
||||
CREATE INDEX IF NOT EXISTS idx_history_timestamp ON ban_history(timestamp_epoch);
|
||||
CREATE INDEX IF NOT EXISTS idx_history_action ON ban_history(action);
|
||||
CREATE INDEX IF NOT EXISTS idx_history_ip ON ban_history(client_ip);
|
||||
CREATE INDEX IF NOT EXISTS idx_offenses_last ON offense_tracking(last_offense_epoch);
|
||||
CREATE INDEX IF NOT EXISTS idx_geoip_cache_age ON geoip_cache(looked_up_at_epoch);
|
||||
INSERT OR IGNORE INTO schema_version (version) VALUES (1);`
|
||||
_, err := s.DB.Exec(schema)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) BanExists(ip string) (bool, error) {
|
||||
var one int
|
||||
err := s.DB.QueryRow(`SELECT 1 FROM active_bans WHERE client_ip=? LIMIT 1`, ip).Scan(&one)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
}
|
||||
|
||||
func (s *Store) InsertBan(b Ban) error {
|
||||
now := time.Now()
|
||||
perm := 0
|
||||
if b.Permanent {
|
||||
perm = 1
|
||||
}
|
||||
_, err := s.DB.Exec(`INSERT OR REPLACE INTO active_bans
|
||||
(client_ip, domain, count, ban_time, ban_until_epoch, ban_duration, offense_level, is_permanent, reason, protocol, source, geoip_country, geoip_mode)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
b.IP, b.Domain, b.Count, now.Format("2006-01-02 15:04:05"), b.BanUntil, b.Duration, b.OffenseLevel, perm,
|
||||
b.Reason, b.Protocol, b.Source, b.GeoIPCountry, b.GeoIPMode)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) DeleteBan(ip string) error {
|
||||
_, err := s.DB.Exec(`DELETE FROM active_bans WHERE client_ip=?`, ip)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) ActiveBans() ([]Ban, error) {
|
||||
rows, err := s.DB.Query(`SELECT client_ip, COALESCE(domain,''), COALESCE(count,0), COALESCE(ban_until_epoch,0),
|
||||
COALESCE(ban_duration,0), COALESCE(offense_level,0), COALESCE(is_permanent,0), COALESCE(reason,''), COALESCE(protocol,''),
|
||||
COALESCE(source,''), COALESCE(geoip_country,''), COALESCE(geoip_mode,'') FROM active_bans ORDER BY created_at DESC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []Ban
|
||||
for rows.Next() {
|
||||
var b Ban
|
||||
var perm int
|
||||
if err := rows.Scan(&b.IP, &b.Domain, &b.Count, &b.BanUntil, &b.Duration, &b.OffenseLevel, &perm, &b.Reason, &b.Protocol, &b.Source, &b.GeoIPCountry, &b.GeoIPMode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Permanent = perm == 1
|
||||
out = append(out, b)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) BansBySource(source string) ([]Ban, error) {
|
||||
rows, err := s.DB.Query(`SELECT client_ip, COALESCE(domain,''), COALESCE(count,0), COALESCE(ban_until_epoch,0),
|
||||
COALESCE(ban_duration,0), COALESCE(offense_level,0), COALESCE(is_permanent,0), COALESCE(reason,''), COALESCE(protocol,''),
|
||||
COALESCE(source,''), COALESCE(geoip_country,''), COALESCE(geoip_mode,'') FROM active_bans WHERE source=?`, source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []Ban
|
||||
for rows.Next() {
|
||||
var b Ban
|
||||
var perm int
|
||||
if err := rows.Scan(&b.IP, &b.Domain, &b.Count, &b.BanUntil, &b.Duration, &b.OffenseLevel, &perm, &b.Reason, &b.Protocol, &b.Source, &b.GeoIPCountry, &b.GeoIPMode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Permanent = perm == 1
|
||||
out = append(out, b)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) BansByReason(reason string) ([]Ban, error) {
|
||||
rows, err := s.DB.Query(`SELECT client_ip, COALESCE(domain,''), COALESCE(count,0), COALESCE(ban_until_epoch,0),
|
||||
COALESCE(ban_duration,0), COALESCE(offense_level,0), COALESCE(is_permanent,0), COALESCE(reason,''), COALESCE(protocol,''),
|
||||
COALESCE(source,''), COALESCE(geoip_country,''), COALESCE(geoip_mode,'') FROM active_bans WHERE reason=?`, reason)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []Ban
|
||||
for rows.Next() {
|
||||
var b Ban
|
||||
var perm int
|
||||
if err := rows.Scan(&b.IP, &b.Domain, &b.Count, &b.BanUntil, &b.Duration, &b.OffenseLevel, &perm, &b.Reason, &b.Protocol, &b.Source, &b.GeoIPCountry, &b.GeoIPMode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Permanent = perm == 1
|
||||
out = append(out, b)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) CountBySource(source string) (int, error) {
|
||||
var count int
|
||||
err := s.DB.QueryRow(`SELECT COUNT(*) FROM active_bans WHERE source=?`, source).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *Store) ExpiredBans(now int64) ([]string, error) {
|
||||
rows, err := s.DB.Query(`SELECT client_ip FROM active_bans WHERE ban_until_epoch > 0 AND is_permanent = 0 AND ban_until_epoch <= ?`, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var ips []string
|
||||
for rows.Next() {
|
||||
var ip string
|
||||
if err := rows.Scan(&ip); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
return ips, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) History(action, ip, domain, count, duration, protocol, reason string) error {
|
||||
now := time.Now()
|
||||
_, err := s.DB.Exec(`INSERT INTO ban_history (timestamp_epoch, timestamp_text, action, client_ip, domain, count, duration, protocol, reason)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, now.Unix(), now.Format("2006-01-02 15:04:05"), action, ip, domain, count, duration, protocol, reason)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) RecentHistory(limit int) ([]string, error) {
|
||||
rows, err := s.DB.Query(`SELECT timestamp_text, action, client_ip, COALESCE(domain,''), COALESCE(count,''), COALESCE(duration,''), COALESCE(protocol,''), COALESCE(reason,'')
|
||||
FROM ban_history ORDER BY id DESC LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []string
|
||||
for rows.Next() {
|
||||
var ts, action, ip, domain, count, duration, proto, reason string
|
||||
if err := rows.Scan(&ts, &action, &ip, &domain, &count, &duration, &proto, &reason); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, fmt.Sprintf("%s | %s | %s | %s | %s | %s | %s | %s", ts, action, ip, domain, count, duration, proto, reason))
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) WhitelistContains(ip string) (bool, error) {
|
||||
var one int
|
||||
err := s.DB.QueryRow(`SELECT 1 FROM whitelist_cache WHERE ip_address=? LIMIT 1`, ip).Scan(&one)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
}
|
||||
|
||||
func (s *Store) ReplaceWhitelist(ips []string, source string) error {
|
||||
tx, err := s.DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(`DELETE FROM whitelist_cache WHERE source=? OR source IS NULL`, source); err != nil {
|
||||
return err
|
||||
}
|
||||
stmt, err := tx.Prepare(`INSERT OR IGNORE INTO whitelist_cache (ip_address, source) VALUES (?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
for _, ip := range ips {
|
||||
if _, err := stmt.Exec(ip, source); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *Store) AllWhitelist() (map[string]bool, error) {
|
||||
rows, err := s.DB.Query(`SELECT ip_address FROM whitelist_cache`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
out := map[string]bool{}
|
||||
for rows.Next() {
|
||||
var ip string
|
||||
if err := rows.Scan(&ip); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out[ip] = true
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) IncrementOffense(ip string, resetAfter int64) (int, error) {
|
||||
now := time.Now()
|
||||
var level int
|
||||
var last int64
|
||||
var first string
|
||||
err := s.DB.QueryRow(`SELECT offense_level, COALESCE(last_offense_epoch,0), COALESCE(first_offense,'') FROM offense_tracking WHERE client_ip=?`, ip).Scan(&level, &last, &first)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return 0, err
|
||||
}
|
||||
if err == sql.ErrNoRows || (last > 0 && now.Unix()-last > resetAfter) {
|
||||
level = 0
|
||||
first = now.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
level++
|
||||
_, err = s.DB.Exec(`INSERT OR REPLACE INTO offense_tracking (client_ip, offense_level, last_offense_epoch, last_offense, first_offense, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, ip, level, now.Unix(), now.Format("2006-01-02 15:04:05"), first, now.Format("2006-01-02 15:04:05"))
|
||||
return level, err
|
||||
}
|
||||
|
||||
func (s *Store) ResetOffense(ip string) error {
|
||||
if ip == "" {
|
||||
_, err := s.DB.Exec(`DELETE FROM offense_tracking`)
|
||||
return err
|
||||
}
|
||||
_, err := s.DB.Exec(`DELETE FROM offense_tracking WHERE client_ip=?`, ip)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CleanupOffenses(resetAfter int64) (int64, error) {
|
||||
cutoff := time.Now().Unix() - resetAfter
|
||||
res, err := s.DB.Exec(`DELETE FROM offense_tracking WHERE last_offense_epoch <= ?`, cutoff)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (s *Store) CountOffenses() (int, error) {
|
||||
var count int
|
||||
err := s.DB.QueryRow(`SELECT COUNT(*) FROM offense_tracking`).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *Store) CountExpiredOffenses(resetAfter int64) (int, error) {
|
||||
var count int
|
||||
cutoff := time.Now().Unix() - resetAfter
|
||||
err := s.DB.QueryRow(`SELECT COUNT(*) FROM offense_tracking WHERE last_offense_epoch <= ?`, cutoff).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *Store) LoadGeoIPCache(ttl, dbMtime int64) (map[string]string, error) {
|
||||
rows, err := s.DB.Query(`SELECT ip, country_code FROM geoip_cache WHERE looked_up_at_epoch >= ? AND (db_mtime=? OR db_mtime=0)`, time.Now().Unix()-ttl, dbMtime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
out := map[string]string{}
|
||||
for rows.Next() {
|
||||
var ip, cc string
|
||||
if err := rows.Scan(&ip, &cc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out[ip] = cc
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) UpsertGeoIP(ip, country string, dbMtime int64) error {
|
||||
_, err := s.DB.Exec(`INSERT OR REPLACE INTO geoip_cache (ip, country_code, looked_up_at_epoch, db_mtime) VALUES (?, ?, ?, ?)`, ip, country, time.Now().Unix(), dbMtime)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) ClearGeoIPCache() (int64, error) {
|
||||
res, err := s.DB.Exec(`DELETE FROM geoip_cache`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (s *Store) ReportStats(since, until int64, limit int) (ReportStats, error) {
|
||||
st := ReportStats{Since: since, Until: until}
|
||||
if err := s.DB.QueryRow(`SELECT COUNT(*) FROM ban_history WHERE action='BAN' AND timestamp_epoch BETWEEN ? AND ?`, since, until).Scan(&st.TotalBans); err != nil {
|
||||
return st, err
|
||||
}
|
||||
if err := s.DB.QueryRow(`SELECT COUNT(*) FROM ban_history WHERE action='UNBAN' AND timestamp_epoch BETWEEN ? AND ?`, since, until).Scan(&st.TotalUnbans); err != nil {
|
||||
return st, err
|
||||
}
|
||||
if err := s.DB.QueryRow(`SELECT COUNT(*) FROM active_bans`).Scan(&st.ActiveBans); err != nil {
|
||||
return st, err
|
||||
}
|
||||
var err error
|
||||
st.TopClients, err = s.reportCounts(`SELECT client_ip, COUNT(*) FROM ban_history WHERE action='BAN' AND timestamp_epoch BETWEEN ? AND ? GROUP BY client_ip ORDER BY COUNT(*) DESC, client_ip LIMIT ?`, since, until, limit)
|
||||
if err != nil {
|
||||
return st, err
|
||||
}
|
||||
st.Reasons, err = s.reportCounts(`SELECT COALESCE(NULLIF(reason,''), 'unknown'), COUNT(*) FROM ban_history WHERE action='BAN' AND timestamp_epoch BETWEEN ? AND ? GROUP BY COALESCE(NULLIF(reason,''), 'unknown') ORDER BY COUNT(*) DESC LIMIT ?`, since, until, limit)
|
||||
if err != nil {
|
||||
return st, err
|
||||
}
|
||||
st.Sources, err = s.reportCounts(`SELECT COALESCE(NULLIF(source,''), 'unknown'), COUNT(*) FROM active_bans GROUP BY COALESCE(NULLIF(source,''), 'unknown') ORDER BY COUNT(*) DESC LIMIT ?`, 0, 0, limit)
|
||||
if err != nil {
|
||||
return st, err
|
||||
}
|
||||
st.RecentEvents, err = s.RecentHistory(limit)
|
||||
return st, err
|
||||
}
|
||||
|
||||
func (s *Store) reportCounts(query string, since, until int64, limit int) ([]ReportCount, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if since == 0 && until == 0 {
|
||||
rows, err = s.DB.Query(query, limit)
|
||||
} else {
|
||||
rows, err = s.DB.Query(query, since, until, limit)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []ReportCount
|
||||
for rows.Next() {
|
||||
var item ReportCount
|
||||
if err := rows.Scan(&item.Name, &item.Count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
31
internal/db/db_test.go
Normal file
31
internal/db/db_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStoreBanAndGeoIPCache(t *testing.T) {
|
||||
s, err := Open(filepath.Join(t.TempDir(), "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer s.Close()
|
||||
if err := s.InsertBan(Ban{IP: "1.2.3.4", Domain: "example.com", Permanent: true, Reason: "geoip", Source: "geoip", GeoIPCountry: "CN"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ok, err := s.BanExists("1.2.3.4")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("ban not found: %v %v", ok, err)
|
||||
}
|
||||
if err := s.UpsertGeoIP("1.2.3.4", "CN", 123); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cache, err := s.LoadGeoIPCache(86400, 123)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cache["1.2.3.4"] != "CN" {
|
||||
t.Fatalf("unexpected cache: %#v", cache)
|
||||
}
|
||||
}
|
||||
203
internal/firewall/firewall.go
Normal file
203
internal/firewall/firewall.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Executor interface {
|
||||
Run(ctx context.Context, name string, args ...string) error
|
||||
}
|
||||
|
||||
type OSExecutor struct{}
|
||||
|
||||
func (OSExecutor) Run(ctx context.Context, name string, args ...string) error {
|
||||
return exec.CommandContext(ctx, name, args...).Run()
|
||||
}
|
||||
|
||||
type Firewall struct {
|
||||
Exec Executor
|
||||
Chain string
|
||||
Ports []string
|
||||
Mode string
|
||||
DryRun bool
|
||||
Set4 string
|
||||
Set6 string
|
||||
}
|
||||
|
||||
func New(exec Executor, chain string, ports []string, mode string, dry bool) *Firewall {
|
||||
return &Firewall{Exec: exec, Chain: chain, Ports: ports, Mode: normalizeMode(mode), DryRun: dry, Set4: "adguard_shield_v4", Set6: "adguard_shield_v6"}
|
||||
}
|
||||
|
||||
func (f *Firewall) Setup(ctx context.Context) error {
|
||||
if f.DryRun {
|
||||
return nil
|
||||
}
|
||||
if len(f.hooks("iptables")) == 0 {
|
||||
return fmt.Errorf("unsupported firewall mode %q", f.Mode)
|
||||
}
|
||||
_ = f.Exec.Run(ctx, "ipset", "create", f.Set4, "hash:net", "family", "inet", "timeout", "0", "-exist")
|
||||
_ = f.Exec.Run(ctx, "ipset", "create", f.Set6, "hash:net", "family", "inet6", "timeout", "0", "-exist")
|
||||
_ = f.Exec.Run(ctx, "iptables", "-N", f.Chain)
|
||||
_ = f.Exec.Run(ctx, "ip6tables", "-N", f.Chain)
|
||||
if err := ensureSetDrop(ctx, f.Exec, "iptables", f.Chain, f.Set4); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureSetDrop(ctx, f.Exec, "ip6tables", f.Chain, f.Set6); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f.ensureHooks(ctx, "iptables"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f.ensureHooks(ctx, "ip6tables"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureRule(ctx context.Context, ex Executor, bin string, args ...string) bool {
|
||||
return ex.Run(ctx, bin, args...) == nil
|
||||
}
|
||||
|
||||
func ensureSetDrop(ctx context.Context, ex Executor, bin, chain, set string) error {
|
||||
check := []string{"-C", chain, "-m", "set", "--match-set", set, "src", "-j", "DROP"}
|
||||
if ex.Run(ctx, bin, check...) == nil {
|
||||
return nil
|
||||
}
|
||||
return ex.Run(ctx, bin, "-I", chain, "-m", "set", "--match-set", set, "src", "-j", "DROP")
|
||||
}
|
||||
|
||||
type hook struct {
|
||||
Chain string
|
||||
OptionalMissing bool
|
||||
}
|
||||
|
||||
func normalizeMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case "", "host", "classic", "native", "docker-host":
|
||||
return "host"
|
||||
case "docker", "docker-bridge", "docker-published", "published":
|
||||
return "docker-bridge"
|
||||
case "hybrid", "both":
|
||||
return "hybrid"
|
||||
default:
|
||||
return strings.ToLower(strings.TrimSpace(mode))
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) hooks(bin string) []hook {
|
||||
docker := hook{Chain: "DOCKER-USER", OptionalMissing: bin == "ip6tables"}
|
||||
switch f.Mode {
|
||||
case "host":
|
||||
return []hook{{Chain: "INPUT"}}
|
||||
case "docker-bridge":
|
||||
return []hook{docker}
|
||||
case "hybrid":
|
||||
return []hook{{Chain: "INPUT"}, docker}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) ensureHooks(ctx context.Context, bin string) error {
|
||||
for _, h := range f.hooks(bin) {
|
||||
if !chainExists(ctx, f.Exec, bin, h.Chain) {
|
||||
if h.OptionalMissing {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("%s chain %s not found", bin, h.Chain)
|
||||
}
|
||||
for _, p := range f.Ports {
|
||||
for _, proto := range []string{"tcp", "udp"} {
|
||||
check := []string{"-C", h.Chain, "-p", proto, "--dport", p, "-j", f.Chain}
|
||||
if ensureRule(ctx, f.Exec, bin, check...) {
|
||||
continue
|
||||
}
|
||||
_ = f.Exec.Run(ctx, bin, "-I", h.Chain, "-p", proto, "--dport", p, "-j", f.Chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func chainExists(ctx context.Context, ex Executor, bin, chain string) bool {
|
||||
return ex.Run(ctx, bin, "-n", "-L", chain) == nil
|
||||
}
|
||||
|
||||
func (f *Firewall) Add(ctx context.Context, ip string, timeout int64) error {
|
||||
if f.DryRun {
|
||||
return nil
|
||||
}
|
||||
set, err := f.setFor(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
args := []string{"add", set, ip, "-exist"}
|
||||
if timeout > 0 {
|
||||
args = append(args, "timeout", strconv.FormatInt(timeout, 10))
|
||||
}
|
||||
return f.Exec.Run(ctx, "ipset", args...)
|
||||
}
|
||||
|
||||
func (f *Firewall) Del(ctx context.Context, ip string) error {
|
||||
if f.DryRun {
|
||||
return nil
|
||||
}
|
||||
set, err := f.setFor(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = f.Exec.Run(ctx, "ipset", "del", set, ip)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) Flush(ctx context.Context) error {
|
||||
if f.DryRun {
|
||||
return nil
|
||||
}
|
||||
_ = f.Exec.Run(ctx, "ipset", "flush", f.Set4)
|
||||
_ = f.Exec.Run(ctx, "ipset", "flush", f.Set6)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) Remove(ctx context.Context) error {
|
||||
if f.DryRun {
|
||||
return nil
|
||||
}
|
||||
for _, p := range f.Ports {
|
||||
for _, proto := range []string{"tcp", "udp"} {
|
||||
for _, parent := range []string{"INPUT", "DOCKER-USER"} {
|
||||
_ = f.Exec.Run(ctx, "iptables", "-D", parent, "-p", proto, "--dport", p, "-j", f.Chain)
|
||||
_ = f.Exec.Run(ctx, "ip6tables", "-D", parent, "-p", proto, "--dport", p, "-j", f.Chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = f.Exec.Run(ctx, "iptables", "-F", f.Chain)
|
||||
_ = f.Exec.Run(ctx, "ip6tables", "-F", f.Chain)
|
||||
_ = f.Exec.Run(ctx, "iptables", "-X", f.Chain)
|
||||
_ = f.Exec.Run(ctx, "ip6tables", "-X", f.Chain)
|
||||
_ = f.Exec.Run(ctx, "ipset", "destroy", f.Set4)
|
||||
_ = f.Exec.Run(ctx, "ipset", "destroy", f.Set6)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) setFor(s string) (string, error) {
|
||||
if p, err := netip.ParsePrefix(s); err == nil {
|
||||
if p.Addr().Is4() {
|
||||
return f.Set4, nil
|
||||
}
|
||||
return f.Set6, nil
|
||||
}
|
||||
a, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid IP/prefix %q", s)
|
||||
}
|
||||
if a.Is4() {
|
||||
return f.Set4, nil
|
||||
}
|
||||
return f.Set6, nil
|
||||
}
|
||||
142
internal/firewall/firewall_test.go
Normal file
142
internal/firewall/firewall_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeExec struct {
|
||||
calls []string
|
||||
failChecks bool
|
||||
missing map[string]bool
|
||||
}
|
||||
|
||||
func (f *fakeExec) Run(_ context.Context, name string, args ...string) error {
|
||||
call := name + " " + strings.Join(args, " ")
|
||||
f.calls = append(f.calls, call)
|
||||
if f.missing != nil && f.missing[call] {
|
||||
return errFake
|
||||
}
|
||||
if f.failChecks && len(args) > 0 && args[0] == "-C" {
|
||||
return errFake
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeErr string
|
||||
|
||||
func (e fakeErr) Error() string { return string(e) }
|
||||
|
||||
var errFake = fakeErr("missing")
|
||||
|
||||
func TestFirewallSetupCreatesSetsAndRules(t *testing.T) {
|
||||
ex := &fakeExec{failChecks: true}
|
||||
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "host", false)
|
||||
if err := fw.Setup(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
joined := strings.Join(ex.calls, "\n")
|
||||
for _, want := range []string{
|
||||
"ipset create adguard_shield_v4 hash:net family inet timeout 0 -exist",
|
||||
"iptables -I INPUT -p tcp --dport 53 -j ADGUARD_SHIELD",
|
||||
"iptables -I ADGUARD_SHIELD -m set --match-set adguard_shield_v4 src -j DROP",
|
||||
"ip6tables -I ADGUARD_SHIELD -m set --match-set adguard_shield_v6 src -j DROP",
|
||||
} {
|
||||
if !strings.Contains(joined, want) {
|
||||
t.Fatalf("missing call %q in:\n%s", want, joined)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewallSetupUsesDockerUserForBridgeMode(t *testing.T) {
|
||||
ex := &fakeExec{failChecks: true}
|
||||
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "docker-bridge", false)
|
||||
if err := fw.Setup(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
joined := strings.Join(ex.calls, "\n")
|
||||
if !strings.Contains(joined, "iptables -I DOCKER-USER -p udp --dport 53 -j ADGUARD_SHIELD") {
|
||||
t.Fatalf("missing docker hook in:\n%s", joined)
|
||||
}
|
||||
if strings.Contains(joined, "iptables -I INPUT -p udp --dport 53 -j ADGUARD_SHIELD") {
|
||||
t.Fatalf("unexpected INPUT hook in docker-bridge mode:\n%s", joined)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewallSetupHybridUsesInputAndDockerUser(t *testing.T) {
|
||||
ex := &fakeExec{failChecks: true}
|
||||
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "hybrid", false)
|
||||
if err := fw.Setup(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
joined := strings.Join(ex.calls, "\n")
|
||||
for _, want := range []string{
|
||||
"iptables -I INPUT -p tcp --dport 53 -j ADGUARD_SHIELD",
|
||||
"iptables -I DOCKER-USER -p tcp --dport 53 -j ADGUARD_SHIELD",
|
||||
} {
|
||||
if !strings.Contains(joined, want) {
|
||||
t.Fatalf("missing call %q in:\n%s", want, joined)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewallSetupRequiresDockerUserForIPv4BridgeMode(t *testing.T) {
|
||||
ex := &fakeExec{missing: map[string]bool{"iptables -n -L DOCKER-USER": true}}
|
||||
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "docker-bridge", false)
|
||||
if err := fw.Setup(context.Background()); err == nil || !strings.Contains(err.Error(), "DOCKER-USER") {
|
||||
t.Fatalf("expected DOCKER-USER error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewallSetupSkipsMissingIPv6DockerUser(t *testing.T) {
|
||||
ex := &fakeExec{
|
||||
failChecks: true,
|
||||
missing: map[string]bool{"ip6tables -n -L DOCKER-USER": true},
|
||||
}
|
||||
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "docker-bridge", false)
|
||||
if err := fw.Setup(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
joined := strings.Join(ex.calls, "\n")
|
||||
if strings.Contains(joined, "ip6tables -I DOCKER-USER") {
|
||||
t.Fatalf("unexpected IPv6 docker hook with missing DOCKER-USER:\n%s", joined)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewallSetupRejectsUnknownMode(t *testing.T) {
|
||||
fw := New(&fakeExec{}, "ADGUARD_SHIELD", []string{"53"}, "surprise", false)
|
||||
err := fw.Setup(context.Background())
|
||||
if err == nil || !strings.Contains(err.Error(), "unsupported firewall mode") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewallAddChoosesFamily(t *testing.T) {
|
||||
ex := &fakeExec{}
|
||||
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "host", false)
|
||||
if err := fw.Add(context.Background(), "2001:db8::1", 30); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := strings.Join(ex.calls, "\n")
|
||||
if !strings.Contains(got, "ipset add adguard_shield_v6 2001:db8::1 -exist timeout 30") {
|
||||
t.Fatalf("unexpected calls:\n%s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewallRemoveDeletesAllKnownHooks(t *testing.T) {
|
||||
ex := &fakeExec{}
|
||||
fw := New(ex, "ADGUARD_SHIELD", []string{"53"}, "host", false)
|
||||
if err := fw.Remove(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
joined := strings.Join(ex.calls, "\n")
|
||||
for _, want := range []string{
|
||||
"iptables -D INPUT -p tcp --dport 53 -j ADGUARD_SHIELD",
|
||||
"iptables -D DOCKER-USER -p tcp --dport 53 -j ADGUARD_SHIELD",
|
||||
} {
|
||||
if !strings.Contains(joined, want) {
|
||||
t.Fatalf("missing cleanup call %q in:\n%s", want, joined)
|
||||
}
|
||||
}
|
||||
}
|
||||
245
internal/geoip/geoip.go
Normal file
245
internal/geoip/geoip.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package geoip
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
LoadGeoIPCache(ttl, dbMtime int64) (map[string]string, error)
|
||||
UpsertGeoIP(ip, country string, dbMtime int64) error
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
DBPath string
|
||||
effectivePath string
|
||||
LicenseKey string
|
||||
Dir string
|
||||
TTL int64
|
||||
Store Store
|
||||
reader *maxminddb.Reader
|
||||
cache map[string]string
|
||||
mtime int64
|
||||
}
|
||||
|
||||
func New(dbPath, licenseKey, dir string, ttl int64, store Store) *Resolver {
|
||||
return &Resolver{DBPath: dbPath, LicenseKey: licenseKey, Dir: dir, TTL: ttl, Store: store, cache: map[string]string{}}
|
||||
}
|
||||
|
||||
func (r *Resolver) Open(ctx context.Context) error {
|
||||
path := r.DBPath
|
||||
if path == "" && r.LicenseKey != "" {
|
||||
var err error
|
||||
path, err = r.ensureAutoDB(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
r.effectivePath = path
|
||||
st, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reader, err := maxminddb.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.reader = reader
|
||||
r.mtime = st.ModTime().Unix()
|
||||
if r.Store != nil {
|
||||
if c, err := r.Store.LoadGeoIPCache(r.TTL, r.mtime); err == nil {
|
||||
r.cache = c
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Resolver) Close() error {
|
||||
if r.reader != nil {
|
||||
return r.reader.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Resolver) Lookup(ip string) (string, error) {
|
||||
if v, ok := r.cache[ip]; ok {
|
||||
return v, nil
|
||||
}
|
||||
if r.reader == nil {
|
||||
return r.lookupLegacy(ip)
|
||||
}
|
||||
parsed := net.ParseIP(ip)
|
||||
if parsed == nil {
|
||||
return "", fmt.Errorf("invalid IP %q", ip)
|
||||
}
|
||||
var rec struct {
|
||||
Country struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
} `maxminddb:"country"`
|
||||
RegisteredCountry struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
} `maxminddb:"registered_country"`
|
||||
}
|
||||
if err := r.reader.Lookup(parsed, &rec); err != nil {
|
||||
return "", err
|
||||
}
|
||||
cc := strings.ToUpper(rec.Country.ISOCode)
|
||||
if cc == "" {
|
||||
cc = strings.ToUpper(rec.RegisteredCountry.ISOCode)
|
||||
}
|
||||
if cc != "" {
|
||||
r.cache[ip] = cc
|
||||
if r.Store != nil {
|
||||
_ = r.Store.UpsertGeoIP(ip, cc, r.mtime)
|
||||
}
|
||||
}
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) lookupLegacy(ip string) (string, error) {
|
||||
if strings.Contains(ip, ":") {
|
||||
if cc, err := runGeoIPCommand("geoiplookup6", ip); err == nil && cc != "" {
|
||||
return cc, nil
|
||||
}
|
||||
} else {
|
||||
if cc, err := runGeoIPCommand("geoiplookup", ip); err == nil && cc != "" {
|
||||
return cc, nil
|
||||
}
|
||||
}
|
||||
if r.effectivePath != "" {
|
||||
if cc, err := runGeoIPCommand("mmdblookup", "--file", r.effectivePath, "--ip", ip, "country", "iso_code"); err == nil && cc != "" {
|
||||
return cc, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no GeoIP result for %s", ip)
|
||||
}
|
||||
|
||||
func runGeoIPCommand(name string, args ...string) (string, error) {
|
||||
if _, err := exec.LookPath(name); err != nil {
|
||||
return "", err
|
||||
}
|
||||
out, err := exec.Command(name, args...).CombinedOutput()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
re := regexp.MustCompile(`\b[A-Z]{2}\b`)
|
||||
matches := re.FindAllString(string(out), -1)
|
||||
for _, m := range matches {
|
||||
if m != "IP" {
|
||||
return strings.ToUpper(m), nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func ShouldBlock(country, mode string, countries []string) bool {
|
||||
if country == "" || len(countries) == 0 {
|
||||
return false
|
||||
}
|
||||
found := false
|
||||
country = strings.ToUpper(country)
|
||||
for _, c := range countries {
|
||||
if strings.ToUpper(strings.TrimSpace(c)) == country {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if strings.ToLower(mode) == "allowlist" {
|
||||
return !found
|
||||
}
|
||||
return found
|
||||
}
|
||||
|
||||
func IsPrivateIP(s string) bool {
|
||||
if p, err := netip.ParsePrefix(s); err == nil {
|
||||
return isPrivateAddr(p.Addr())
|
||||
}
|
||||
a, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return isPrivateAddr(a)
|
||||
}
|
||||
|
||||
func isPrivateAddr(a netip.Addr) bool {
|
||||
return a.IsPrivate() || a.IsLoopback() || a.IsLinkLocalUnicast() || a.IsUnspecified() ||
|
||||
(a.Is4() && strings.HasPrefix(a.String(), "100.") && isCGNAT(a))
|
||||
}
|
||||
|
||||
func isCGNAT(a netip.Addr) bool {
|
||||
p := a.As4()
|
||||
return p[0] == 100 && p[1] >= 64 && p[1] <= 127
|
||||
}
|
||||
|
||||
func (r *Resolver) ensureAutoDB(ctx context.Context) (string, error) {
|
||||
if err := os.MkdirAll(r.Dir, 0755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
dst := filepath.Join(r.Dir, "GeoLite2-Country.mmdb")
|
||||
if st, err := os.Stat(dst); err == nil && time.Since(st.ModTime()) < 24*time.Hour {
|
||||
return dst, nil
|
||||
}
|
||||
url := "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key=" + r.LicenseKey + "&suffix=tar.gz"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("MaxMind download failed: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
gzr, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer gzr.Close()
|
||||
tr := tar.NewReader(gzr)
|
||||
tmp := dst + ".tmp"
|
||||
for {
|
||||
h, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if h.FileInfo().IsDir() || filepath.Base(h.Name) != "GeoLite2-Country.mmdb" {
|
||||
continue
|
||||
}
|
||||
f, err := os.Create(tmp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
_, copyErr := io.Copy(f, tr)
|
||||
closeErr := f.Close()
|
||||
if copyErr != nil {
|
||||
return "", copyErr
|
||||
}
|
||||
if closeErr != nil {
|
||||
return "", closeErr
|
||||
}
|
||||
return dst, os.Rename(tmp, dst)
|
||||
}
|
||||
return "", fmt.Errorf("GeoLite2-Country.mmdb not found in archive")
|
||||
}
|
||||
30
internal/geoip/geoip_test.go
Normal file
30
internal/geoip/geoip_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package geoip
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShouldBlockModes(t *testing.T) {
|
||||
countries := []string{"CN", "RU"}
|
||||
if !ShouldBlock("cn", "blocklist", countries) {
|
||||
t.Fatal("blocklist should block listed country")
|
||||
}
|
||||
if ShouldBlock("DE", "blocklist", countries) {
|
||||
t.Fatal("blocklist should allow unlisted country")
|
||||
}
|
||||
if ShouldBlock("CN", "allowlist", countries) {
|
||||
t.Fatal("allowlist should allow listed country")
|
||||
}
|
||||
if !ShouldBlock("DE", "allowlist", countries) {
|
||||
t.Fatal("allowlist should block unlisted country")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
for _, ip := range []string{"127.0.0.1", "192.168.1.10", "10.1.2.3", "100.64.0.1", "::1", "fd00::1"} {
|
||||
if !IsPrivateIP(ip) {
|
||||
t.Fatalf("%s should be private", ip)
|
||||
}
|
||||
}
|
||||
if IsPrivateIP("8.8.8.8") {
|
||||
t.Fatal("8.8.8.8 should be public")
|
||||
}
|
||||
}
|
||||
642
internal/installer/installer.go
Normal file
642
internal/installer/installer.go
Normal file
@@ -0,0 +1,642 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInstallDir = "/opt/adguard-shield"
|
||||
DefaultStateDir = "/var/lib/adguard-shield"
|
||||
DefaultLogFile = "/var/log/adguard-shield.log"
|
||||
ServiceName = "adguard-shield.service"
|
||||
ServicePath = "/etc/systemd/system/adguard-shield.service"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
InstallDir string
|
||||
ConfigSource string
|
||||
Enable bool
|
||||
SkipDeps bool
|
||||
KeepConfig bool
|
||||
}
|
||||
|
||||
type Status struct {
|
||||
InstallDir string
|
||||
BinaryPath string
|
||||
ConfigPath string
|
||||
BinaryExists bool
|
||||
ConfigExists bool
|
||||
ServiceExists bool
|
||||
ServiceEnabled bool
|
||||
ServiceActive bool
|
||||
Version string
|
||||
LegacyFindings []string
|
||||
}
|
||||
|
||||
type LegacyError struct {
|
||||
Findings []string
|
||||
}
|
||||
|
||||
func (e *LegacyError) Error() string {
|
||||
return "scriptbasierte AdGuard-Shield-Installation gefunden"
|
||||
}
|
||||
|
||||
func DefaultOptions() Options {
|
||||
return Options{InstallDir: DefaultInstallDir, Enable: true}
|
||||
}
|
||||
|
||||
func Install(opts Options) error {
|
||||
opts = normalize(opts)
|
||||
fmt.Println("AdGuard Shield Go-Installation")
|
||||
fmt.Printf("Installationspfad: %s\n", opts.InstallDir)
|
||||
fmt.Println("1/8 Pruefe Betriebssystem und root-Rechte ...")
|
||||
if err := requireLinuxRoot(); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("2/8 Pruefe auf scriptbasierte Altinstallation ...")
|
||||
if findings := DetectLegacy(opts.InstallDir); len(findings) > 0 {
|
||||
return &LegacyError{Findings: findings}
|
||||
}
|
||||
if !opts.SkipDeps {
|
||||
fmt.Println("3/8 Pruefe System-Abhaengigkeiten ...")
|
||||
if err := ensureDependencies(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
fmt.Println("3/8 System-Abhaengigkeiten uebersprungen (--skip-deps)")
|
||||
}
|
||||
fmt.Println("4/8 Erstelle Verzeichnisse ...")
|
||||
if err := os.MkdirAll(opts.InstallDir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(DefaultStateDir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(opts.InstallDir, "geoip"), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("5/8 Installiere Binary ...")
|
||||
if err := copySelf(filepath.Join(opts.InstallDir, "adguard-shield")); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("6/8 Installiere oder migriere Konfiguration ...")
|
||||
if err := ensureConfig(opts); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("7/8 Schreibe systemd-Service ...")
|
||||
if err := writeService(opts.InstallDir); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("8/8 Aktualisiere systemd ...")
|
||||
_ = run("systemctl", "daemon-reload")
|
||||
if opts.Enable {
|
||||
fmt.Println("Aktiviere Autostart ...")
|
||||
if err := run("systemctl", "enable", ServiceName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if askStartService() {
|
||||
fmt.Println("Starte Service neu ...")
|
||||
if err := run("systemctl", "restart", ServiceName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
fmt.Println("Installation fertig.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func Update(opts Options) error {
|
||||
opts = normalize(opts)
|
||||
if err := requireLinuxRoot(); err != nil {
|
||||
return err
|
||||
}
|
||||
if findings := DetectLegacy(opts.InstallDir); len(findings) > 0 {
|
||||
return &LegacyError{Findings: findings}
|
||||
}
|
||||
return Install(opts)
|
||||
}
|
||||
|
||||
func Uninstall(opts Options) error {
|
||||
opts = normalize(opts)
|
||||
if err := requireLinuxRoot(); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = run("systemctl", "stop", ServiceName)
|
||||
_ = run("systemctl", "disable", ServiceName)
|
||||
if _, err := os.Stat(filepath.Join(opts.InstallDir, "adguard-shield")); err == nil {
|
||||
_ = run(filepath.Join(opts.InstallDir, "adguard-shield"), "-config", filepath.Join(opts.InstallDir, "adguard-shield.conf"), "firewall-remove")
|
||||
}
|
||||
_ = os.Remove(ServicePath)
|
||||
_ = run("systemctl", "daemon-reload")
|
||||
if opts.KeepConfig {
|
||||
for _, p := range []string{
|
||||
filepath.Join(opts.InstallDir, "adguard-shield"),
|
||||
filepath.Join(opts.InstallDir, "adguard-shield.conf.old"),
|
||||
} {
|
||||
_ = os.Remove(p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
_ = os.RemoveAll(opts.InstallDir)
|
||||
_ = os.RemoveAll(DefaultStateDir)
|
||||
_ = os.Remove(DefaultLogFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetStatus(installDir string) Status {
|
||||
if installDir == "" {
|
||||
installDir = DefaultInstallDir
|
||||
}
|
||||
bin := filepath.Join(installDir, "adguard-shield")
|
||||
conf := filepath.Join(installDir, "adguard-shield.conf")
|
||||
st := Status{
|
||||
InstallDir: installDir,
|
||||
BinaryPath: bin,
|
||||
ConfigPath: conf,
|
||||
BinaryExists: fileExists(bin),
|
||||
ConfigExists: fileExists(conf),
|
||||
ServiceExists: fileExists(ServicePath),
|
||||
LegacyFindings: DetectLegacy(installDir),
|
||||
}
|
||||
if st.BinaryExists {
|
||||
if out, err := exec.Command(bin, "version").Output(); err == nil {
|
||||
st.Version = strings.TrimSpace(string(out))
|
||||
}
|
||||
}
|
||||
st.ServiceEnabled = commandOK("systemctl", "is-enabled", "adguard-shield")
|
||||
st.ServiceActive = commandOK("systemctl", "is-active", "adguard-shield")
|
||||
return st
|
||||
}
|
||||
|
||||
func DetectLegacy(installDir string) []string {
|
||||
if installDir == "" {
|
||||
installDir = DefaultInstallDir
|
||||
}
|
||||
var findings []string
|
||||
for _, p := range []string{
|
||||
"adguard-shield.sh",
|
||||
"iptables-helper.sh",
|
||||
"db.sh",
|
||||
"external-blocklist-worker.sh",
|
||||
"external-whitelist-worker.sh",
|
||||
"geoip-worker.sh",
|
||||
"offense-cleanup-worker.sh",
|
||||
"report-generator.sh",
|
||||
"unban-expired.sh",
|
||||
"adguard-shield-watchdog.sh",
|
||||
} {
|
||||
full := filepath.Join(installDir, p)
|
||||
if fileExists(full) {
|
||||
findings = append(findings, full)
|
||||
}
|
||||
}
|
||||
for _, p := range []string{
|
||||
"/etc/systemd/system/adguard-shield-watchdog.service",
|
||||
"/etc/systemd/system/adguard-shield-watchdog.timer",
|
||||
} {
|
||||
if fileExists(p) {
|
||||
findings = append(findings, p)
|
||||
}
|
||||
}
|
||||
if b, err := os.ReadFile(ServicePath); err == nil {
|
||||
s := string(b)
|
||||
if strings.Contains(s, ".sh") || strings.Contains(s, "/bin/bash") || strings.Contains(s, "adguard-shield-watchdog") {
|
||||
findings = append(findings, ServicePath+" verweist auf Shell/Watchdog")
|
||||
}
|
||||
}
|
||||
sort.Strings(findings)
|
||||
return findings
|
||||
}
|
||||
|
||||
func FormatLegacyMessage(err *LegacyError, installDir string) string {
|
||||
if installDir == "" {
|
||||
installDir = DefaultInstallDir
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("Die scriptbasierte Installation ist noch vorhanden und muss zuerst deinstalliert werden.\n\n")
|
||||
b.WriteString("Gefunden:\n")
|
||||
for _, f := range err.Findings {
|
||||
b.WriteString(" - ")
|
||||
b.WriteString(f)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString("\nKonfiguration uebernehmen:\n")
|
||||
b.WriteString(" 1. Backup behalten: ")
|
||||
b.WriteString(filepath.Join(installDir, "adguard-shield.conf"))
|
||||
b.WriteByte('\n')
|
||||
b.WriteString(" 2. Alte Shell-Version mit deren uninstall.sh entfernen und die Konfiguration behalten.\n")
|
||||
b.WriteString(" 3. Danach dieses Binary erneut ausfuehren: adguard-shield install\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func PrintStatus(st Status) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("AdGuard Shield Installationsstatus\n")
|
||||
b.WriteString(fmt.Sprintf("Installationspfad: %s\n", st.InstallDir))
|
||||
b.WriteString(fmt.Sprintf("Binary: %s\n", yesNo(st.BinaryExists)))
|
||||
if st.Version != "" {
|
||||
b.WriteString(fmt.Sprintf("Version: %s\n", st.Version))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("Konfiguration: %s\n", yesNo(st.ConfigExists)))
|
||||
b.WriteString(fmt.Sprintf("systemd Service: %s\n", yesNo(st.ServiceExists)))
|
||||
b.WriteString(fmt.Sprintf("Autostart: %s\n", yesNo(st.ServiceEnabled)))
|
||||
b.WriteString(fmt.Sprintf("Service aktiv: %s\n", yesNo(st.ServiceActive)))
|
||||
if len(st.LegacyFindings) > 0 {
|
||||
b.WriteString("\nScriptbasierte Altinstallation/Altartefakte gefunden:\n")
|
||||
for _, f := range st.LegacyFindings {
|
||||
b.WriteString(" - ")
|
||||
b.WriteString(f)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func normalize(opts Options) Options {
|
||||
if opts.InstallDir == "" {
|
||||
opts.InstallDir = DefaultInstallDir
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func askStartService() bool {
|
||||
fmt.Print("AdGuard Shield jetzt (neu) starten? [j/N] ")
|
||||
line, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
||||
if err != nil && len(line) == 0 {
|
||||
fmt.Println("Keine Eingabe gelesen, Service wird nicht gestartet.")
|
||||
return false
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(line)) {
|
||||
case "j", "ja", "y", "yes":
|
||||
return true
|
||||
default:
|
||||
fmt.Println("Service wird nicht gestartet.")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func requireLinuxRoot() error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return fmt.Errorf("Installation ist nur auf Linux-Servern unterstuetzt")
|
||||
}
|
||||
if os.Geteuid() != 0 {
|
||||
return fmt.Errorf("Installation muss als root ausgefuehrt werden")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureDependencies() error {
|
||||
missing := missingCommands("iptables", "ip6tables", "ipset", "systemctl")
|
||||
if len(missing) == 0 {
|
||||
fmt.Println(" Alle benoetigten Befehle sind vorhanden.")
|
||||
return nil
|
||||
}
|
||||
fmt.Printf(" Fehlende Befehle: %s\n", strings.Join(missing, ", "))
|
||||
if _, err := exec.LookPath("apt-get"); err != nil {
|
||||
return fmt.Errorf("fehlende Abhaengigkeiten (%s), apt-get nicht gefunden", strings.Join(missing, ", "))
|
||||
}
|
||||
pkgs := map[string]bool{"iptables": false, "ipset": false, "systemd": false, "ca-certificates": false}
|
||||
for _, m := range missing {
|
||||
switch m {
|
||||
case "iptables", "ip6tables":
|
||||
pkgs["iptables"] = true
|
||||
case "ipset":
|
||||
pkgs["ipset"] = true
|
||||
case "systemctl":
|
||||
pkgs["systemd"] = true
|
||||
}
|
||||
}
|
||||
var install []string
|
||||
for p, needed := range pkgs {
|
||||
if needed || p == "ca-certificates" {
|
||||
install = append(install, p)
|
||||
}
|
||||
}
|
||||
sort.Strings(install)
|
||||
fmt.Printf(" Installiere Pakete via apt-get: %s\n", strings.Join(install, ", "))
|
||||
fmt.Println(" apt-get update ...")
|
||||
if err := runStreaming("apt-get", "update"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println(" apt-get install ...")
|
||||
args := append([]string{"install", "-y", "-qq"}, install...)
|
||||
return runStreaming("apt-get", args...)
|
||||
}
|
||||
|
||||
func missingCommands(names ...string) []string {
|
||||
var missing []string
|
||||
for _, name := range names {
|
||||
if _, err := exec.LookPath(name); err != nil {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
return missing
|
||||
}
|
||||
|
||||
func copySelf(dst string) error {
|
||||
src, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sameFile(src, dst) {
|
||||
return os.Chmod(dst, 0755)
|
||||
}
|
||||
in, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer in.Close()
|
||||
tmp := dst + ".tmp"
|
||||
out, err := os.OpenFile(tmp, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(out, in); err != nil {
|
||||
_ = out.Close()
|
||||
return err
|
||||
}
|
||||
if err := out.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Chmod(tmp, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmp, dst)
|
||||
}
|
||||
|
||||
func sameFile(a, b string) bool {
|
||||
aa, errA := filepath.Abs(a)
|
||||
bb, errB := filepath.Abs(b)
|
||||
if errA == nil && errB == nil && aa == bb {
|
||||
return true
|
||||
}
|
||||
ai, errA := os.Stat(a)
|
||||
bi, errB := os.Stat(b)
|
||||
return errA == nil && errB == nil && os.SameFile(ai, bi)
|
||||
}
|
||||
|
||||
func ensureConfig(opts Options) error {
|
||||
target := filepath.Join(opts.InstallDir, "adguard-shield.conf")
|
||||
defaults := []byte(defaultConfig)
|
||||
if opts.ConfigSource != "" {
|
||||
b, err := os.ReadFile(opts.ConfigSource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defaults = b
|
||||
}
|
||||
if !fileExists(target) {
|
||||
if err := os.WriteFile(target, defaults, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
current, err := os.ReadFile(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
merged, changed := mergeConfig(current, []byte(defaultConfig))
|
||||
if !changed {
|
||||
return os.Chmod(target, 0600)
|
||||
}
|
||||
if err := os.WriteFile(target+".old", current, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(target, merged, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeConfig(current, defaults []byte) ([]byte, bool) {
|
||||
existing := configKeys(current)
|
||||
var add [][]byte
|
||||
for _, block := range configBlocks(defaults) {
|
||||
key := blockKey(block)
|
||||
if key == "" || existing[key] {
|
||||
continue
|
||||
}
|
||||
add = append(add, block)
|
||||
}
|
||||
if len(add) == 0 {
|
||||
return current, false
|
||||
}
|
||||
out := bytes.TrimRight(current, "\r\n")
|
||||
out = append(out, '\n', '\n')
|
||||
out = append(out, []byte("# Neue Parameter aus der Go-Version\n")...)
|
||||
for _, block := range add {
|
||||
out = append(out, bytes.Trim(block, "\r\n")...)
|
||||
out = append(out, '\n')
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
func configKeys(data []byte) map[string]bool {
|
||||
keys := map[string]bool{}
|
||||
for _, line := range bytes.Split(data, []byte{'\n'}) {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
if i := bytes.IndexByte(line, '='); i > 0 {
|
||||
keys[string(bytes.TrimSpace(line[:i]))] = true
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func configBlocks(data []byte) [][]byte {
|
||||
lines := bytes.Split(data, []byte{'\n'})
|
||||
var blocks [][]byte
|
||||
var comments [][]byte
|
||||
for _, line := range lines {
|
||||
trim := bytes.TrimSpace(line)
|
||||
if len(trim) == 0 || trim[0] == '#' {
|
||||
comments = append(comments, append([]byte(nil), line...))
|
||||
continue
|
||||
}
|
||||
block := bytes.Join(append(comments, line), []byte{'\n'})
|
||||
blocks = append(blocks, block)
|
||||
comments = nil
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
func blockKey(block []byte) string {
|
||||
for _, line := range bytes.Split(block, []byte{'\n'}) {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
if i := bytes.IndexByte(line, '='); i > 0 {
|
||||
return string(bytes.TrimSpace(line[:i]))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func writeService(installDir string) error {
|
||||
service := fmt.Sprintf(`[Unit]
|
||||
Description=AdGuard Shield - Go DNS Rate-Limit Monitor
|
||||
After=network.target AdGuardHome.service
|
||||
Wants=AdGuardHome.service
|
||||
StartLimitBurst=5
|
||||
StartLimitIntervalSec=300
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart=%s/adguard-shield -config %s/adguard-shield.conf run
|
||||
ExecReload=/bin/kill -HUP $MAINPID
|
||||
Restart=on-failure
|
||||
RestartSec=30
|
||||
ProtectSystem=full
|
||||
ReadWritePaths=/var/log /var/lib/adguard-shield /var/run %s/geoip
|
||||
ProtectHome=true
|
||||
NoNewPrivileges=false
|
||||
PrivateTmp=true
|
||||
AmbientCapabilities=CAP_NET_ADMIN CAP_NET_RAW
|
||||
CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW CAP_DAC_OVERRIDE CAP_DAC_READ_SEARCH CAP_FOWNER CAP_KILL CAP_SETUID CAP_SETGID CAP_CHOWN
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
SyslogIdentifier=adguard-shield
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
`, installDir, installDir, installDir)
|
||||
return os.WriteFile(ServicePath, []byte(service), 0644)
|
||||
}
|
||||
|
||||
func run(name string, args ...string) error {
|
||||
cmd := exec.Command(name, args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s %s: %w\n%s", name, strings.Join(args, " "), err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runStreaming(name string, args ...string) error {
|
||||
cmd := exec.Command(name, args...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("%s %s: %w", name, strings.Join(args, " "), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func commandOK(name string, args ...string) bool {
|
||||
return exec.Command(name, args...).Run() == nil
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func yesNo(ok bool) string {
|
||||
if ok {
|
||||
return "ja"
|
||||
}
|
||||
return "nein"
|
||||
}
|
||||
|
||||
func IsLegacyError(err error) (*LegacyError, bool) {
|
||||
var le *LegacyError
|
||||
if errors.As(err, &le) {
|
||||
return le, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
const defaultConfig = `# AdGuard Shield Konfiguration
|
||||
|
||||
ADGUARD_URL="https://dns1.domain.com"
|
||||
ADGUARD_USER="admin"
|
||||
ADGUARD_PASS='changeme'
|
||||
|
||||
RATE_LIMIT_MAX_REQUESTS=30
|
||||
RATE_LIMIT_WINDOW=60
|
||||
CHECK_INTERVAL=10
|
||||
API_QUERY_LIMIT=500
|
||||
|
||||
SUBDOMAIN_FLOOD_ENABLED=true
|
||||
SUBDOMAIN_FLOOD_MAX_UNIQUE=50
|
||||
SUBDOMAIN_FLOOD_WINDOW=60
|
||||
|
||||
DNS_FLOOD_WATCHLIST_ENABLED=false
|
||||
DNS_FLOOD_WATCHLIST=""
|
||||
|
||||
BAN_DURATION=3600
|
||||
IPTABLES_CHAIN="ADGUARD_SHIELD"
|
||||
BLOCKED_PORTS="53 443 853"
|
||||
FIREWALL_BACKEND="ipset"
|
||||
FIREWALL_MODE="host"
|
||||
DRY_RUN=false
|
||||
|
||||
WHITELIST="127.0.0.1,::1"
|
||||
|
||||
LOG_FILE="/var/log/adguard-shield.log"
|
||||
LOG_LEVEL="INFO"
|
||||
STATE_DIR="/var/lib/adguard-shield"
|
||||
PID_FILE="/var/run/adguard-shield.pid"
|
||||
|
||||
NOTIFY_ENABLED=false
|
||||
NOTIFY_TYPE="ntfy"
|
||||
NOTIFY_WEBHOOK_URL=""
|
||||
NTFY_SERVER_URL="https://ntfy.sh"
|
||||
NTFY_TOPIC=""
|
||||
NTFY_TOKEN=""
|
||||
NTFY_PRIORITY="4"
|
||||
|
||||
REPORT_ENABLED=false
|
||||
REPORT_INTERVAL="weekly"
|
||||
REPORT_TIME="08:00"
|
||||
REPORT_EMAIL_TO="admin@example.com"
|
||||
REPORT_EMAIL_FROM="adguard-shield@example.com"
|
||||
REPORT_FORMAT="html"
|
||||
REPORT_MAIL_CMD="msmtp"
|
||||
REPORT_BUSIEST_DAY_RANGE=30
|
||||
|
||||
EXTERNAL_WHITELIST_ENABLED=false
|
||||
EXTERNAL_WHITELIST_URLS=""
|
||||
EXTERNAL_WHITELIST_INTERVAL=300
|
||||
EXTERNAL_WHITELIST_CACHE_DIR="/var/lib/adguard-shield/external-whitelist"
|
||||
|
||||
EXTERNAL_BLOCKLIST_ENABLED=false
|
||||
EXTERNAL_BLOCKLIST_URLS=""
|
||||
EXTERNAL_BLOCKLIST_INTERVAL=300
|
||||
EXTERNAL_BLOCKLIST_BAN_DURATION=0
|
||||
EXTERNAL_BLOCKLIST_AUTO_UNBAN=true
|
||||
EXTERNAL_BLOCKLIST_NOTIFY=false
|
||||
EXTERNAL_BLOCKLIST_CACHE_DIR="/var/lib/adguard-shield/external-blocklist"
|
||||
|
||||
PROGRESSIVE_BAN_ENABLED=true
|
||||
PROGRESSIVE_BAN_MULTIPLIER=2
|
||||
PROGRESSIVE_BAN_MAX_LEVEL=5
|
||||
PROGRESSIVE_BAN_RESET_AFTER=86400
|
||||
|
||||
ABUSEIPDB_ENABLED=false
|
||||
ABUSEIPDB_API_KEY=""
|
||||
ABUSEIPDB_CATEGORIES="4"
|
||||
|
||||
GEOIP_ENABLED=false
|
||||
GEOIP_MODE="blocklist"
|
||||
GEOIP_COUNTRIES=""
|
||||
GEOIP_CHECK_INTERVAL=0
|
||||
GEOIP_NOTIFY=true
|
||||
GEOIP_SKIP_PRIVATE=true
|
||||
GEOIP_LICENSE_KEY=""
|
||||
GEOIP_MMDB_PATH=""
|
||||
GEOIP_CACHE_TTL=86400
|
||||
`
|
||||
242
internal/report/report.go
Normal file
242
internal/report/report.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"html"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"adguard-shield/internal/config"
|
||||
"adguard-shield/internal/db"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
ReportStats(since, until int64, limit int) (db.ReportStats, error)
|
||||
}
|
||||
|
||||
const cronPath = "/etc/cron.d/adguard-shield-report"
|
||||
|
||||
func Status(c *config.Config) string {
|
||||
cron := "nicht installiert"
|
||||
if _, err := os.Stat(cronPath); err == nil {
|
||||
cron = "installiert (" + cronPath + ")"
|
||||
}
|
||||
return fmt.Sprintf(`E-Mail Report
|
||||
Aktiv: %v
|
||||
Intervall: %s
|
||||
Zeit: %s
|
||||
Empfaenger: %s
|
||||
Absender: %s
|
||||
Format: %s
|
||||
Mail-Befehl: %s
|
||||
Cron: %s
|
||||
`, c.ReportEnabled, c.ReportInterval, c.ReportTime, c.ReportEmailTo, c.ReportEmailFrom, c.ReportFormat, c.ReportMailCmd, cron)
|
||||
}
|
||||
|
||||
func Generate(c *config.Config, st Store, format string) (string, error) {
|
||||
if format == "" {
|
||||
format = c.ReportFormat
|
||||
}
|
||||
since, until := window(c.ReportInterval)
|
||||
stats, err := st.ReportStats(since, until, 20)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.EqualFold(format, "html") {
|
||||
return renderHTML(c, stats), nil
|
||||
}
|
||||
return renderText(c, stats), nil
|
||||
}
|
||||
|
||||
func Send(ctx context.Context, c *config.Config, st Store) error {
|
||||
body, err := Generate(c, st, c.ReportFormat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sendMail(ctx, c, "AdGuard Shield Report", body)
|
||||
}
|
||||
|
||||
func SendTest(ctx context.Context, c *config.Config) error {
|
||||
body := fmt.Sprintf("AdGuard Shield Test-Mail\n\nHostname: %s\nZeitpunkt: %s\nEmpfaenger: %s\nAbsender: %s\n", hostname(), time.Now().Format("2006-01-02 15:04:05"), c.ReportEmailTo, c.ReportEmailFrom)
|
||||
if strings.EqualFold(c.ReportFormat, "html") {
|
||||
body = "<!doctype html><html><body><h1>AdGuard Shield Test-Mail</h1><p>Hostname: " + html.EscapeString(hostname()) + "</p><p>Zeitpunkt: " + html.EscapeString(time.Now().Format("2006-01-02 15:04:05")) + "</p></body></html>"
|
||||
}
|
||||
return sendMail(ctx, c, "AdGuard Shield Test-Mail", body)
|
||||
}
|
||||
|
||||
func InstallCron(binary, configPath string, c *config.Config) error {
|
||||
minute, hour, err := parseReportTime(c.ReportTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
schedule := cronSchedule(c.ReportInterval, minute, hour)
|
||||
if binary == "" {
|
||||
binary = "/opt/adguard-shield/adguard-shield"
|
||||
}
|
||||
if configPath == "" {
|
||||
configPath = "/opt/adguard-shield/adguard-shield.conf"
|
||||
}
|
||||
line := fmt.Sprintf("SHELL=/bin/sh\nPATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin\n%s root %s -config %s report-send\n", schedule, binary, configPath)
|
||||
return os.WriteFile(cronPath, []byte(line), 0644)
|
||||
}
|
||||
|
||||
func RemoveCron() error {
|
||||
if err := os.Remove(cronPath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendMail(ctx context.Context, c *config.Config, subject, body string) error {
|
||||
if c.ReportEmailTo == "" {
|
||||
return fmt.Errorf("REPORT_EMAIL_TO ist leer")
|
||||
}
|
||||
if c.ReportMailCmd == "" {
|
||||
return fmt.Errorf("REPORT_MAIL_CMD ist leer")
|
||||
}
|
||||
contentType := "text/plain; charset=utf-8"
|
||||
if strings.EqualFold(c.ReportFormat, "html") {
|
||||
contentType = "text/html; charset=utf-8"
|
||||
}
|
||||
msg := "From: " + c.ReportEmailFrom + "\n" +
|
||||
"To: " + c.ReportEmailTo + "\n" +
|
||||
"Subject: " + subject + "\n" +
|
||||
"Content-Type: " + contentType + "\n\n" + body
|
||||
parts := strings.Fields(c.ReportMailCmd)
|
||||
if len(parts) == 0 {
|
||||
return fmt.Errorf("REPORT_MAIL_CMD ist leer")
|
||||
}
|
||||
args := append(parts[1:], "-t")
|
||||
cmd := exec.CommandContext(ctx, parts[0], args...)
|
||||
cmd.Stdin = strings.NewReader(msg)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func parseReportTime(value string) (string, string, error) {
|
||||
parts := strings.Split(value, ":")
|
||||
if len(parts) != 2 {
|
||||
return "", "", fmt.Errorf("REPORT_TIME muss HH:MM sein")
|
||||
}
|
||||
hour, err := strconv.Atoi(parts[0])
|
||||
if err != nil || hour < 0 || hour > 23 {
|
||||
return "", "", fmt.Errorf("REPORT_TIME hat ungueltige Stunde")
|
||||
}
|
||||
minute, err := strconv.Atoi(parts[1])
|
||||
if err != nil || minute < 0 || minute > 59 {
|
||||
return "", "", fmt.Errorf("REPORT_TIME hat ungueltige Minute")
|
||||
}
|
||||
return strconv.Itoa(minute), strconv.Itoa(hour), nil
|
||||
}
|
||||
|
||||
func cronSchedule(interval, minute, hour string) string {
|
||||
switch strings.ToLower(interval) {
|
||||
case "daily":
|
||||
return fmt.Sprintf("%s %s * * *", minute, hour)
|
||||
case "biweekly":
|
||||
return fmt.Sprintf("%s %s 1,15 * *", minute, hour)
|
||||
case "monthly":
|
||||
return fmt.Sprintf("%s %s 1 * *", minute, hour)
|
||||
default:
|
||||
return fmt.Sprintf("%s %s * * 1", minute, hour)
|
||||
}
|
||||
}
|
||||
|
||||
func window(interval string) (int64, int64) {
|
||||
now := time.Now()
|
||||
days := 7
|
||||
switch strings.ToLower(interval) {
|
||||
case "daily":
|
||||
days = 1
|
||||
case "biweekly":
|
||||
days = 14
|
||||
case "monthly":
|
||||
days = 30
|
||||
}
|
||||
return now.AddDate(0, 0, -days).Unix(), now.Unix()
|
||||
}
|
||||
|
||||
func renderText(c *config.Config, st db.ReportStats) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("AdGuard Shield Report\n")
|
||||
b.WriteString("Zeitraum: " + formatTime(st.Since) + " bis " + formatTime(st.Until) + "\n\n")
|
||||
b.WriteString("Bans: " + strconv.Itoa(st.TotalBans) + "\n")
|
||||
b.WriteString("Unbans: " + strconv.Itoa(st.TotalUnbans) + "\n")
|
||||
b.WriteString("Aktive Sperren: " + strconv.Itoa(st.ActiveBans) + "\n\n")
|
||||
writeCountsText(&b, "Top Clients", st.TopClients)
|
||||
writeCountsText(&b, "Gruende", st.Reasons)
|
||||
writeCountsText(&b, "Aktive Quellen", st.Sources)
|
||||
if len(st.RecentEvents) > 0 {
|
||||
b.WriteString("Letzte Ereignisse\n")
|
||||
for _, e := range st.RecentEvents {
|
||||
b.WriteString("- " + e + "\n")
|
||||
}
|
||||
}
|
||||
_ = c
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func renderHTML(c *config.Config, st db.ReportStats) string {
|
||||
var b bytes.Buffer
|
||||
b.WriteString("<!doctype html><html><head><meta charset=\"utf-8\"><title>AdGuard Shield Report</title>")
|
||||
b.WriteString("<style>body{font-family:Arial,sans-serif;color:#1f2937}table{border-collapse:collapse;margin:12px 0}td,th{border:1px solid #d1d5db;padding:6px 9px;text-align:left}th{background:#f3f4f6}</style>")
|
||||
b.WriteString("</head><body>")
|
||||
b.WriteString("<h1>AdGuard Shield Report</h1>")
|
||||
b.WriteString("<p>Zeitraum: " + html.EscapeString(formatTime(st.Since)) + " bis " + html.EscapeString(formatTime(st.Until)) + "</p>")
|
||||
b.WriteString("<ul><li>Bans: " + strconv.Itoa(st.TotalBans) + "</li><li>Unbans: " + strconv.Itoa(st.TotalUnbans) + "</li><li>Aktive Sperren: " + strconv.Itoa(st.ActiveBans) + "</li></ul>")
|
||||
writeCountsHTML(&b, "Top Clients", st.TopClients)
|
||||
writeCountsHTML(&b, "Gruende", st.Reasons)
|
||||
writeCountsHTML(&b, "Aktive Quellen", st.Sources)
|
||||
if len(st.RecentEvents) > 0 {
|
||||
b.WriteString("<h2>Letzte Ereignisse</h2><table><tr><th>Ereignis</th></tr>")
|
||||
for _, e := range st.RecentEvents {
|
||||
b.WriteString("<tr><td>" + html.EscapeString(e) + "</td></tr>")
|
||||
}
|
||||
b.WriteString("</table>")
|
||||
}
|
||||
b.WriteString("</body></html>")
|
||||
_ = c
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func writeCountsText(b *strings.Builder, title string, rows []db.ReportCount) {
|
||||
b.WriteString(title + "\n")
|
||||
if len(rows) == 0 {
|
||||
b.WriteString("- keine Daten\n\n")
|
||||
return
|
||||
}
|
||||
for _, r := range rows {
|
||||
b.WriteString("- " + r.Name + ": " + strconv.Itoa(r.Count) + "\n")
|
||||
}
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
|
||||
func writeCountsHTML(b *bytes.Buffer, title string, rows []db.ReportCount) {
|
||||
b.WriteString("<h2>" + html.EscapeString(title) + "</h2><table><tr><th>Name</th><th>Anzahl</th></tr>")
|
||||
if len(rows) == 0 {
|
||||
b.WriteString("<tr><td colspan=\"2\">keine Daten</td></tr>")
|
||||
}
|
||||
for _, r := range rows {
|
||||
b.WriteString("<tr><td>" + html.EscapeString(r.Name) + "</td><td>" + strconv.Itoa(r.Count) + "</td></tr>")
|
||||
}
|
||||
b.WriteString("</table>")
|
||||
}
|
||||
|
||||
func formatTime(epoch int64) string {
|
||||
return time.Unix(epoch, 0).Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
func hostname() string {
|
||||
name, err := os.Hostname()
|
||||
if err != nil || name == "" {
|
||||
return filepath.Base(os.Args[0])
|
||||
}
|
||||
return name
|
||||
}
|
||||
82
internal/syslog/syslog.go
Normal file
82
internal/syslog/syslog.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package syslog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Level int
|
||||
|
||||
const (
|
||||
Debug Level = iota
|
||||
Info
|
||||
Warn
|
||||
Error
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
mu sync.Mutex
|
||||
min Level
|
||||
log *log.Logger
|
||||
}
|
||||
|
||||
func New(w io.Writer, min string) *Logger {
|
||||
return &Logger{
|
||||
min: ParseLevel(min, Info),
|
||||
log: log.New(w, "", log.LstdFlags),
|
||||
}
|
||||
}
|
||||
|
||||
func ParseLevel(s string, fallback Level) Level {
|
||||
switch strings.ToUpper(strings.TrimSpace(s)) {
|
||||
case "DEBUG":
|
||||
return Debug
|
||||
case "INFO", "":
|
||||
return Info
|
||||
case "WARN", "WARNING":
|
||||
return Warn
|
||||
case "ERROR", "ERR":
|
||||
return Error
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
func LevelName(l Level) string {
|
||||
switch l {
|
||||
case Debug:
|
||||
return "DEBUG"
|
||||
case Info:
|
||||
return "INFO"
|
||||
case Warn:
|
||||
return "WARN"
|
||||
case Error:
|
||||
return "ERROR"
|
||||
default:
|
||||
return "INFO"
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Enabled(level Level) bool {
|
||||
if l == nil {
|
||||
return false
|
||||
}
|
||||
return level >= l.min
|
||||
}
|
||||
|
||||
func (l *Logger) Logf(level Level, format string, args ...any) {
|
||||
if !l.Enabled(level) {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.log.Printf("[%s] [ADGUARD-SHIELDD] %s", LevelName(level), fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l *Logger) Debugf(format string, args ...any) { l.Logf(Debug, format, args...) }
|
||||
func (l *Logger) Infof(format string, args ...any) { l.Logf(Info, format, args...) }
|
||||
func (l *Logger) Warnf(format string, args ...any) { l.Logf(Warn, format, args...) }
|
||||
func (l *Logger) Errorf(format string, args ...any) { l.Logf(Error, format, args...) }
|
||||
Reference in New Issue
Block a user