351 lines
10 KiB
Go
351 lines
10 KiB
Go
// Keywarden - Centralized SSH Key Management and Deployment
|
|
// Copyright (C) 2026 Patrick Asmus (scriptos)
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
package security
|
|
|
|
import (
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)
|
|
|
|
// ---------- CSRF Middleware ----------
|
|
|
|
func TestCSRFMiddleware_SetsTokenCookie(t *testing.T) {
|
|
handler := CSRFMiddleware(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
cookies := rec.Result().Cookies()
|
|
var csrfCookie *http.Cookie
|
|
for _, c := range cookies {
|
|
if c.Name == "_csrf" {
|
|
csrfCookie = c
|
|
}
|
|
}
|
|
if csrfCookie == nil {
|
|
t.Fatal("expected _csrf cookie to be set on GET request")
|
|
}
|
|
if len(csrfCookie.Value) != 64 {
|
|
t.Fatalf("expected 64-char hex token, got %d chars", len(csrfCookie.Value))
|
|
}
|
|
if csrfCookie.SameSite != http.SameSiteStrictMode {
|
|
t.Fatal("expected SameSite=Strict on CSRF cookie")
|
|
}
|
|
}
|
|
|
|
func TestCSRFMiddleware_BlocksPOSTWithoutToken(t *testing.T) {
|
|
handler := CSRFMiddleware(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/action", nil)
|
|
req.AddCookie(&http.Cookie{Name: "_csrf", Value: strings.Repeat("a", 64)})
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Fatalf("expected 403 Forbidden for POST without matching token, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestCSRFMiddleware_AllowsPOSTWithValidToken(t *testing.T) {
|
|
token := strings.Repeat("ab", 32) // 64-char hex
|
|
handler := CSRFMiddleware(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
body := strings.NewReader("_csrf=" + token)
|
|
req := httptest.NewRequest(http.MethodPost, "/action", body)
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.AddCookie(&http.Cookie{Name: "_csrf", Value: token})
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 OK for POST with valid CSRF token, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestCSRFMiddleware_AllowsGETWithoutToken(t *testing.T) {
|
|
handler := CSRFMiddleware(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/page", nil)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 OK for GET without CSRF token, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestCSRFMiddleware_AcceptsHeaderToken(t *testing.T) {
|
|
token := strings.Repeat("cd", 32) // 64-char hex
|
|
handler := CSRFMiddleware(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/action", nil)
|
|
req.Header.Set("X-CSRF-Token", token)
|
|
req.AddCookie(&http.Cookie{Name: "_csrf", Value: token})
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 OK for POST with X-CSRF-Token header, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// ---------- Security Headers Middleware ----------
|
|
|
|
func TestHeadersMiddleware_SetsSecurityHeaders(t *testing.T) {
|
|
handler := HeadersMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
expected := map[string]string{
|
|
"X-Frame-Options": "DENY",
|
|
"X-Content-Type-Options": "nosniff",
|
|
"Referrer-Policy": "strict-origin-when-cross-origin",
|
|
}
|
|
for header, want := range expected {
|
|
got := rec.Header().Get(header)
|
|
if got != want {
|
|
t.Errorf("header %s: got %q, want %q", header, got, want)
|
|
}
|
|
}
|
|
|
|
csp := rec.Header().Get("Content-Security-Policy")
|
|
if csp == "" {
|
|
t.Fatal("expected Content-Security-Policy header to be set")
|
|
}
|
|
if !strings.Contains(csp, "frame-ancestors 'none'") {
|
|
t.Error("CSP should contain frame-ancestors 'none'")
|
|
}
|
|
if !strings.Contains(csp, "form-action 'self'") {
|
|
t.Error("CSP should contain form-action 'self'")
|
|
}
|
|
}
|
|
|
|
func TestHeadersMiddleware_SetsCacheControlForNonStatic(t *testing.T) {
|
|
handler := HeadersMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/settings", nil)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
cc := rec.Header().Get("Cache-Control")
|
|
if !strings.Contains(cc, "no-store") {
|
|
t.Errorf("expected no-store in Cache-Control for non-static page, got %q", cc)
|
|
}
|
|
}
|
|
|
|
// ---------- Rate Limit Middleware ----------
|
|
|
|
func TestRateLimitMiddleware_BlocksAfterLimit(t *testing.T) {
|
|
handler := RateLimitMiddleware(3)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
for i := 0; i < 3; i++ {
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "192.0.2.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("request %d: expected 200, got %d", i+1, rec.Code)
|
|
}
|
|
}
|
|
|
|
// 4th request should be blocked
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "192.0.2.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("expected 429 Too Many Requests, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware_DisabledWhenZero(t *testing.T) {
|
|
handler := RateLimitMiddleware(0)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
for i := 0; i < 100; i++ {
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "192.0.2.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("request %d: expected 200 (rate limiting disabled), got %d", i+1, rec.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware_AllowsGETLogin(t *testing.T) {
|
|
handler := RateLimitMiddleware(1)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Exhaust POST limit
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "192.0.2.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
// GET should still work
|
|
req = httptest.NewRequest(http.MethodGet, "/login", nil)
|
|
req.RemoteAddr = "192.0.2.1:12345"
|
|
rec = httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected GET /login to pass rate limit, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware_SeparatesIPs(t *testing.T) {
|
|
handler := RateLimitMiddleware(1)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Exhaust limit for IP 1
|
|
req := httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "192.0.2.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
// IP 2 should still be allowed
|
|
req = httptest.NewRequest(http.MethodPost, "/login", nil)
|
|
req.RemoteAddr = "192.0.2.2:12345"
|
|
rec = httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected different IP to be allowed, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// ---------- Size Limit Middleware ----------
|
|
|
|
func TestSizeLimitMiddleware_BlocksOversizedBody(t *testing.T) {
|
|
handler := SizeLimitMiddleware(10)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
buf := make([]byte, 1024)
|
|
_, err := r.Body.Read(buf)
|
|
if err != nil {
|
|
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
body := strings.NewReader(strings.Repeat("x", 100))
|
|
req := httptest.NewRequest(http.MethodPost, "/upload", body)
|
|
req.Header.Set("Content-Type", "application/octet-stream")
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code == http.StatusOK {
|
|
t.Fatal("expected request with body > 10 bytes to be rejected")
|
|
}
|
|
}
|
|
|
|
func TestSizeLimitMiddleware_DisabledWhenZero(t *testing.T) {
|
|
handler := SizeLimitMiddleware(0)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
body := strings.NewReader(strings.Repeat("x", 1000))
|
|
req := httptest.NewRequest(http.MethodPost, "/upload", body)
|
|
rec := httptest.NewRecorder()
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 with size limit disabled, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
// ---------- Proxy / ClientIP ----------
|
|
|
|
func TestClientIP_RemoteAddrFallback(t *testing.T) {
|
|
Init("") // no trusted proxies
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:54321"
|
|
|
|
ip := ClientIP(req)
|
|
if ip != "10.0.0.1" {
|
|
t.Fatalf("expected 10.0.0.1, got %s", ip)
|
|
}
|
|
}
|
|
|
|
func TestClientIP_XForwardedFor_Legacy(t *testing.T) {
|
|
Init("") // legacy mode, trusts headers
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:54321"
|
|
req.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.1")
|
|
|
|
ip := ClientIP(req)
|
|
if ip != "203.0.113.50" {
|
|
t.Fatalf("expected leftmost XFF IP 203.0.113.50, got %s", ip)
|
|
}
|
|
}
|
|
|
|
func TestClientIP_TrustedProxies(t *testing.T) {
|
|
Init("10.0.0.0/8")
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "10.0.0.1:54321"
|
|
req.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
|
|
|
|
ip := ClientIP(req)
|
|
if ip != "203.0.113.50" {
|
|
t.Fatalf("expected rightmost untrusted IP 203.0.113.50, got %s", ip)
|
|
}
|
|
}
|
|
|
|
func TestClientIP_UntrustedPeerIgnoresHeaders(t *testing.T) {
|
|
Init("10.0.0.0/8")
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.RemoteAddr = "203.0.113.99:54321" // not in trusted range
|
|
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
|
|
|
ip := ClientIP(req)
|
|
if ip != "203.0.113.99" {
|
|
t.Fatalf("expected direct peer IP when not trusted, got %s", ip)
|
|
}
|
|
}
|
|
|
|
// ---------- isStaticAsset ----------
|
|
|
|
func TestIsStaticAsset(t *testing.T) {
|
|
tests := []struct {
|
|
path string
|
|
want bool
|
|
}{
|
|
{"/static/css/style.css", true},
|
|
{"/avatar/1.png", true},
|
|
{"/dashboard", false},
|
|
{"/login", false},
|
|
{"", false},
|
|
{"/short", false},
|
|
}
|
|
for _, tt := range tests {
|
|
got := isStaticAsset(tt.path)
|
|
if got != tt.want {
|
|
t.Errorf("isStaticAsset(%q) = %v, want %v", tt.path, got, tt.want)
|
|
}
|
|
}
|
|
}
|