Files
keywarden/internal/security/security_test.go

351 lines
10 KiB
Go

// Keywarden - Centralized SSH Key Management and Deployment
// Copyright (C) 2026 Patrick Asmus (scriptos)
// SPDX-License-Identifier: AGPL-3.0-or-later
package security
import (
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)
// ---------- CSRF Middleware ----------
// TestCSRFMiddleware_SetsTokenCookie verifies that a GET request passing through
// the CSRF middleware is issued a "_csrf" cookie whose value is 64 characters
// long and which is marked SameSite=Strict.
func TestCSRFMiddleware_SetsTokenCookie(t *testing.T) {
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := CSRFMiddleware(false)(next)

	resp := httptest.NewRecorder()
	mw.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/", nil))

	// If several cookies named _csrf were set, the last one wins.
	var token *http.Cookie
	for _, ck := range resp.Result().Cookies() {
		if ck.Name == "_csrf" {
			token = ck
		}
	}

	switch {
	case token == nil:
		t.Fatal("expected _csrf cookie to be set on GET request")
	case len(token.Value) != 64:
		t.Fatalf("expected 64-char hex token, got %d chars", len(token.Value))
	case token.SameSite != http.SameSiteStrictMode:
		t.Fatal("expected SameSite=Strict on CSRF cookie")
	}
}
// TestCSRFMiddleware_BlocksPOSTWithoutToken verifies that a POST carrying only
// the _csrf cookie, with no token echoed back in a form field or header, is
// rejected with 403 Forbidden.
func TestCSRFMiddleware_BlocksPOSTWithoutToken(t *testing.T) {
	mw := CSRFMiddleware(false)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	post := httptest.NewRequest(http.MethodPost, "/action", nil)
	// The cookie alone must not be sufficient to pass the check.
	post.AddCookie(&http.Cookie{Name: "_csrf", Value: strings.Repeat("a", 64)})

	resp := httptest.NewRecorder()
	mw.ServeHTTP(resp, post)

	if got := resp.Code; got != http.StatusForbidden {
		t.Fatalf("expected 403 Forbidden for POST without matching token, got %d", got)
	}
}
// TestCSRFMiddleware_AllowsPOSTWithValidToken verifies that a form-encoded POST
// whose "_csrf" field equals the "_csrf" cookie is passed through to the
// wrapped handler.
func TestCSRFMiddleware_AllowsPOSTWithValidToken(t *testing.T) {
	tok := strings.Repeat("ab", 32) // 64 hex characters
	mw := CSRFMiddleware(false)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	form := strings.NewReader("_csrf=" + tok)
	post := httptest.NewRequest(http.MethodPost, "/action", form)
	post.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	post.AddCookie(&http.Cookie{Name: "_csrf", Value: tok})

	resp := httptest.NewRecorder()
	mw.ServeHTTP(resp, post)

	if got := resp.Code; got != http.StatusOK {
		t.Fatalf("expected 200 OK for POST with valid CSRF token, got %d", got)
	}
}
// TestCSRFMiddleware_AllowsGETWithoutToken verifies that a GET request without
// any CSRF token still reaches the wrapped handler.
func TestCSRFMiddleware_AllowsGETWithoutToken(t *testing.T) {
	mw := CSRFMiddleware(false)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	resp := httptest.NewRecorder()
	mw.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/page", nil))

	if code := resp.Code; code != http.StatusOK {
		t.Fatalf("expected 200 OK for GET without CSRF token, got %d", code)
	}
}
// TestCSRFMiddleware_AcceptsHeaderToken verifies that a POST with no form body
// passes the CSRF check when the X-CSRF-Token header matches the _csrf cookie.
func TestCSRFMiddleware_AcceptsHeaderToken(t *testing.T) {
	tok := strings.Repeat("cd", 32) // 64 hex characters
	mw := CSRFMiddleware(false)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	post := httptest.NewRequest(http.MethodPost, "/api/action", nil)
	post.AddCookie(&http.Cookie{Name: "_csrf", Value: tok})
	post.Header.Set("X-CSRF-Token", tok)

	resp := httptest.NewRecorder()
	mw.ServeHTTP(resp, post)

	if got := resp.Code; got != http.StatusOK {
		t.Fatalf("expected 200 OK for POST with X-CSRF-Token header, got %d", got)
	}
}
// ---------- Security Headers Middleware ----------
// TestHeadersMiddleware_SetsSecurityHeaders verifies the fixed hardening headers
// and that the Content-Security-Policy is present and contains the
// frame-ancestors 'none' and form-action 'self' directives.
func TestHeadersMiddleware_SetsSecurityHeaders(t *testing.T) {
	mw := HeadersMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	resp := httptest.NewRecorder()
	mw.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/dashboard", nil))
	hdr := resp.Header()

	wantHeaders := []struct{ name, value string }{
		{"X-Frame-Options", "DENY"},
		{"X-Content-Type-Options", "nosniff"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
	}
	for _, h := range wantHeaders {
		if got := hdr.Get(h.name); got != h.value {
			t.Errorf("header %s: got %q, want %q", h.name, got, h.value)
		}
	}

	csp := hdr.Get("Content-Security-Policy")
	if csp == "" {
		t.Fatal("expected Content-Security-Policy header to be set")
	}
	for _, directive := range []string{"frame-ancestors 'none'", "form-action 'self'"} {
		if !strings.Contains(csp, directive) {
			t.Error("CSP should contain " + directive)
		}
	}
}
// TestHeadersMiddleware_SetsCacheControlForNonStatic verifies that a regular
// (non-static) page response carries "no-store" in its Cache-Control header.
func TestHeadersMiddleware_SetsCacheControlForNonStatic(t *testing.T) {
	mw := HeadersMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	resp := httptest.NewRecorder()
	mw.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/settings", nil))

	if cc := resp.Header().Get("Cache-Control"); !strings.Contains(cc, "no-store") {
		t.Errorf("expected no-store in Cache-Control for non-static page, got %q", cc)
	}
}
// ---------- Rate Limit Middleware ----------
// TestRateLimitMiddleware_BlocksAfterLimit verifies that, with a limit of 3, the
// first three POST /login requests from one client address succeed and the
// fourth is answered with 429 Too Many Requests.
func TestRateLimitMiddleware_BlocksAfterLimit(t *testing.T) {
	const limit = 3
	mw := RateLimitMiddleware(limit)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// login sends one POST /login from the same client address and returns
	// the resulting status code.
	login := func() int {
		req := httptest.NewRequest(http.MethodPost, "/login", nil)
		req.RemoteAddr = "192.0.2.1:12345"
		resp := httptest.NewRecorder()
		mw.ServeHTTP(resp, req)
		return resp.Code
	}

	for n := 1; n <= limit; n++ {
		if code := login(); code != http.StatusOK {
			t.Fatalf("request %d: expected 200, got %d", n, code)
		}
	}

	// One request past the limit must be rejected.
	if code := login(); code != http.StatusTooManyRequests {
		t.Fatalf("expected 429 Too Many Requests, got %d", code)
	}
}
// TestRateLimitMiddleware_DisabledWhenZero verifies that a limit of 0 disables
// rate limiting: 100 consecutive POST /login requests from one client all pass.
func TestRateLimitMiddleware_DisabledWhenZero(t *testing.T) {
	mw := RateLimitMiddleware(0)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	for n := 1; n <= 100; n++ {
		req := httptest.NewRequest(http.MethodPost, "/login", nil)
		req.RemoteAddr = "192.0.2.1:12345"
		resp := httptest.NewRecorder()
		mw.ServeHTTP(resp, req)

		if got := resp.Code; got != http.StatusOK {
			t.Fatalf("request %d: expected 200 (rate limiting disabled), got %d", n, got)
		}
	}
}
// TestRateLimitMiddleware_AllowsGETLogin verifies that GET /login is not
// subject to the login rate limit. With a limit of 1, it first confirms that
// the client's POST allowance is genuinely used up (a second POST gets 429),
// and only then checks that a GET from the same client still passes.
func TestRateLimitMiddleware_AllowsGETLogin(t *testing.T) {
	handler := RateLimitMiddleware(1)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// send issues one request to /login from a fixed client address and
	// returns the resulting status code.
	send := func(method string) int {
		req := httptest.NewRequest(method, "/login", nil)
		req.RemoteAddr = "192.0.2.1:12345"
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		return rec.Code
	}

	// Exhaust the POST limit and prove it is exhausted. Without this, a passing
	// GET below says nothing about GET being exempt from an active limit.
	if code := send(http.MethodPost); code != http.StatusOK {
		t.Fatalf("first POST /login: expected 200, got %d", code)
	}
	if code := send(http.MethodPost); code != http.StatusTooManyRequests {
		t.Fatalf("second POST /login: expected 429 once the limit is exhausted, got %d", code)
	}

	// GET should still work for the same, now rate-limited, client.
	if code := send(http.MethodGet); code != http.StatusOK {
		t.Fatalf("expected GET /login to pass rate limit, got %d", code)
	}
}
// TestRateLimitMiddleware_SeparatesIPs verifies that rate-limit budgets are kept
// per client IP. With a limit of 1, it first confirms that IP 1 is actually
// blocked on its second POST, then checks that a different IP is still allowed.
func TestRateLimitMiddleware_SeparatesIPs(t *testing.T) {
	handler := RateLimitMiddleware(1)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// post issues one POST /login from remoteAddr and returns the status code.
	post := func(remoteAddr string) int {
		req := httptest.NewRequest(http.MethodPost, "/login", nil)
		req.RemoteAddr = remoteAddr
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		return rec.Code
	}

	// Exhaust the limit for IP 1 and confirm it is really exhausted; otherwise
	// a limiter that never blocks would also make this test pass.
	if code := post("192.0.2.1:12345"); code != http.StatusOK {
		t.Fatalf("first request from IP 1: expected 200, got %d", code)
	}
	if code := post("192.0.2.1:12345"); code != http.StatusTooManyRequests {
		t.Fatalf("second request from IP 1: expected 429 once the limit is exhausted, got %d", code)
	}

	// IP 2 should still be allowed.
	if code := post("192.0.2.2:12345"); code != http.StatusOK {
		t.Fatalf("expected different IP to be allowed, got %d", code)
	}
}
// ---------- Size Limit Middleware ----------
// TestSizeLimitMiddleware_BlocksOversizedBody verifies that, with a 10-byte
// limit, a 100-byte body is rejected while a 5-byte body is accepted. The
// positive case ensures the middleware is not simply rejecting every request.
//
// The wrapped handler drains the body with io.ReadAll rather than a single
// Read. The io.Reader contract allows a short read with a nil error, and a lone
// Read may report io.EOF as an error. Either way, one Read cannot reliably tell
// "body too large" apart from "body fully read".
func TestSizeLimitMiddleware_BlocksOversizedBody(t *testing.T) {
	handler := SizeLimitMiddleware(10)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if _, err := io.ReadAll(r.Body); err != nil {
			http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))

	// send POSTs a body of the given size to /upload and returns the status code.
	send := func(size int) int {
		body := strings.NewReader(strings.Repeat("x", size))
		req := httptest.NewRequest(http.MethodPost, "/upload", body)
		req.Header.Set("Content-Type", "application/octet-stream")
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		return rec.Code
	}

	if code := send(100); code == http.StatusOK {
		t.Fatal("expected request with body > 10 bytes to be rejected")
	}
	if code := send(5); code != http.StatusOK {
		t.Fatalf("expected request with body <= 10 bytes to be accepted, got %d", code)
	}
}
// TestSizeLimitMiddleware_DisabledWhenZero verifies that a limit of 0 disables
// size limiting: a 1000-byte body must reach the handler and be readable in
// full. The handler actually drains the body, because a limit enforced by
// wrapping the body reader only shows up when the body is read.
func TestSizeLimitMiddleware_DisabledWhenZero(t *testing.T) {
	const size = 1000
	read := 0 // bytes the handler managed to read; stays 0 if it never runs
	handler := SizeLimitMiddleware(0)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		data, err := io.ReadAll(r.Body)
		read = len(data)
		if err != nil {
			http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))

	body := strings.NewReader(strings.Repeat("x", size))
	req := httptest.NewRequest(http.MethodPost, "/upload", body)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200 with size limit disabled, got %d", rec.Code)
	}
	// A recorder reports 200 even if nothing was written, so also confirm the
	// handler really received the whole body.
	if read != size {
		t.Fatalf("expected handler to read all %d body bytes, got %d", size, read)
	}
}
// ---------- Proxy / ClientIP ----------
// TestClientIP_RemoteAddrFallback verifies that, after Init("") and with no
// forwarding headers on the request, ClientIP returns the host part of
// RemoteAddr.
func TestClientIP_RemoteAddrFallback(t *testing.T) {
	// Init configures proxy trust for ClientIP; "" means no trusted proxies.
	Init("")

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.RemoteAddr = "10.0.0.1:54321"

	if got := ClientIP(r); got != "10.0.0.1" {
		t.Fatalf("expected 10.0.0.1, got %s", got)
	}
}
// TestClientIP_XForwardedFor_Legacy verifies that, after Init(""), ClientIP
// honours X-Forwarded-For and returns its leftmost entry.
func TestClientIP_XForwardedFor_Legacy(t *testing.T) {
	// Legacy mode: with no trusted proxies configured, headers are trusted.
	Init("")

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.RemoteAddr = "10.0.0.1:54321"
	r.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.1")

	if got := ClientIP(r); got != "203.0.113.50" {
		t.Fatalf("expected leftmost XFF IP 203.0.113.50, got %s", got)
	}
}
// TestClientIP_TrustedProxies verifies that, with 10.0.0.0/8 trusted, ClientIP
// skips trusted hops from the right of X-Forwarded-For and returns the
// rightmost untrusted address.
func TestClientIP_TrustedProxies(t *testing.T) {
	Init("10.0.0.0/8")

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.RemoteAddr = "10.0.0.1:54321"
	// 10.0.0.2 is inside the trusted range, so 203.0.113.50 is the client.
	r.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")

	if got := ClientIP(r); got != "203.0.113.50" {
		t.Fatalf("expected rightmost untrusted IP 203.0.113.50, got %s", got)
	}
}
// TestClientIP_UntrustedPeerIgnoresHeaders verifies that when the direct peer is
// outside the trusted proxy range, X-Forwarded-For is ignored and the peer's
// own address is returned.
func TestClientIP_UntrustedPeerIgnoresHeaders(t *testing.T) {
	Init("10.0.0.0/8")

	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.RemoteAddr = "203.0.113.99:54321" // outside 10.0.0.0/8, so untrusted
	r.Header.Set("X-Forwarded-For", "1.2.3.4")

	if got := ClientIP(r); got != "203.0.113.99" {
		t.Fatalf("expected direct peer IP when not trusted, got %s", got)
	}
}
// ---------- isStaticAsset ----------
// TestIsStaticAsset checks isStaticAsset against a table of request paths. The
// table covers a /static/ file, an /avatar/ image, regular pages, the empty
// path and a short non-asset path.
func TestIsStaticAsset(t *testing.T) {
	cases := []struct {
		in       string
		expected bool
	}{
		{in: "/static/css/style.css", expected: true},
		{in: "/avatar/1.png", expected: true},
		{in: "/dashboard", expected: false},
		{in: "/login", expected: false},
		{in: "", expected: false},
		{in: "/short", expected: false},
	}

	for _, c := range cases {
		if got := isStaticAsset(c.in); got != c.expected {
			t.Errorf("isStaticAsset(%q) = %v, want %v", c.in, got, c.expected)
		}
	}
}