4272 lines
135 KiB
Go
4272 lines
135 KiB
Go
// Keywarden - Centralized SSH Key Management and Deployment
|
||
// Copyright (C) 2026 Patrick Asmus (scriptos)
|
||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
||
package handlers
|
||
|
||
import (
|
||
"bytes"
|
||
"crypto/hmac"
|
||
"crypto/rand"
|
||
"crypto/sha1"
|
||
"embed"
|
||
"encoding/base32"
|
||
"encoding/base64"
|
||
"encoding/binary"
|
||
"encoding/json"
|
||
"fmt"
|
||
"html/template"
|
||
"image"
|
||
_ "image/jpeg"
|
||
_ "image/png"
|
||
"io"
|
||
"io/fs"
|
||
"math"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
"path/filepath"
|
||
"runtime"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
_ "golang.org/x/image/webp"
|
||
|
||
"git.techniverse.net/scriptos/keywarden/internal/audit"
|
||
"git.techniverse.net/scriptos/keywarden/internal/auth"
|
||
"git.techniverse.net/scriptos/keywarden/internal/cron"
|
||
"git.techniverse.net/scriptos/keywarden/internal/database"
|
||
"git.techniverse.net/scriptos/keywarden/internal/deploy"
|
||
"git.techniverse.net/scriptos/keywarden/internal/keys"
|
||
"git.techniverse.net/scriptos/keywarden/internal/logging"
|
||
"git.techniverse.net/scriptos/keywarden/internal/mail"
|
||
"git.techniverse.net/scriptos/keywarden/internal/models"
|
||
"git.techniverse.net/scriptos/keywarden/internal/security"
|
||
"git.techniverse.net/scriptos/keywarden/internal/servers"
|
||
"git.techniverse.net/scriptos/keywarden/internal/updater"
|
||
"git.techniverse.net/scriptos/keywarden/internal/worker"
|
||
)
|
||
|
||
// sessionData is the in-memory record kept for one logged-in browser session.
// The inactivity timeout is evaluated against it.
type sessionData struct {
	// UserID is the account the session belongs to.
	UserID int64
	// LastActive is the time of the most recent authenticated request.
	LastActive time.Time
	// MFASetupRequired is set at login when an admin enforces MFA but the
	// user has not configured it yet.
	MFASetupRequired bool
}
|
||
|
||
// Handler holds all dependencies for HTTP handlers
|
||
type Handler struct {
|
||
auth *auth.Service
|
||
keys *keys.Service
|
||
servers *servers.Service
|
||
deploy *deploy.Service
|
||
audit *audit.Service
|
||
cron *cron.Service
|
||
worker *worker.Service
|
||
mail *mail.Service
|
||
updater *updater.Service
|
||
db *database.DB // direct database access for backup/restore
|
||
templates map[string]*template.Template
|
||
sessions map[string]*sessionData // cookie -> session data with timeout tracking
|
||
mu sync.RWMutex // protects sessions and pending maps
|
||
pending map[string]int64 // pending MFA sessions: cookie -> userID
|
||
staticFS http.Handler // serves embedded static assets
|
||
dataDir string // persistent data directory for avatars etc.
|
||
secureCookies bool // set Secure flag on cookies (HTTPS mode)
|
||
baseURL string // external base URL for links in emails
|
||
}
|
||
|
||
// Flash is a one-off notification rendered on a page.
type Flash struct {
	// Type is the severity: "success", "danger" or "warning".
	Type string
	// Message is the text shown to the user.
	Message string
}
|
||
|
||
// PageData is passed to every template
|
||
type PageData struct {
|
||
Title string
|
||
Active string
|
||
User interface{}
|
||
Flash *Flash
|
||
Data interface{}
|
||
|
||
// Dashboard specific
|
||
KeyCount int
|
||
ServerCount int
|
||
DeployCount int
|
||
UserCount int
|
||
GroupCount int
|
||
AssignmentCount int
|
||
RecentKeys interface{}
|
||
RecentDeploys interface{}
|
||
RecentAudit []audit.AuditEntry
|
||
UserRole string
|
||
|
||
// Keys
|
||
Keys interface{}
|
||
Key interface{}
|
||
|
||
// Servers
|
||
Servers interface{}
|
||
Server interface{}
|
||
|
||
// Server Groups
|
||
Groups interface{}
|
||
Group interface{}
|
||
GroupServers interface{}
|
||
AllServers interface{}
|
||
|
||
// Deploy
|
||
Deployments interface{}
|
||
|
||
// User management
|
||
Users []models.User
|
||
EditUser *models.User
|
||
|
||
// Settings
|
||
Settings map[string]string
|
||
|
||
// MFA
|
||
MFASecret string
|
||
MFAUri string
|
||
|
||
// Admin Settings: user list
|
||
AdminUsers []AdminUserInfo
|
||
|
||
// Audit Log
|
||
AuditEntries []audit.AuditEntry
|
||
AuditTotal int
|
||
AuditPage int
|
||
AuditTotalPages int
|
||
AuditPrevPage int
|
||
AuditNextPage int
|
||
AuditIsAdmin bool
|
||
AuditFilterUser bool
|
||
|
||
// Cron Jobs
|
||
CronJobs []models.CronJobDisplay
|
||
CronJob *models.CronJob
|
||
CronCount int
|
||
DaysOfMonth []int
|
||
|
||
// Access Assignments
|
||
Assignments []models.AccessAssignmentDisplay
|
||
Assignment *models.AccessAssignment
|
||
AssignAllUsers []models.User
|
||
AssignAllKeys []models.SSHKey
|
||
AssignAllHosts []models.Server
|
||
AssignAllGroups []models.ServerGroupWithCount
|
||
|
||
// Error (login page)
|
||
Error string
|
||
MFAPending bool
|
||
MFAToken string
|
||
|
||
// Email
|
||
EmailEnabled bool
|
||
|
||
// Password Policy
|
||
PasswordPolicy *models.PasswordPolicy
|
||
|
||
// MFA enforcement
|
||
MFARequired bool
|
||
|
||
// System Master Key
|
||
MasterKeyPublic string
|
||
MasterKeyFingerprint string
|
||
|
||
// System Information
|
||
SystemInfo *SystemInfo
|
||
|
||
// Key Enforcement
|
||
EnforcementStatus map[string]string
|
||
|
||
// Initial Owner protection
|
||
IsInitialOwner bool
|
||
InitialOwnerID int64
|
||
}
|
||
|
||
// SystemInfo is a snapshot of runtime and host details for display.
// Every value is already formatted for presentation.
type SystemInfo struct {
	GoVersion    string
	OS           string
	Arch         string
	NumCPU       int
	NumGoroutine int
	MemAlloc     string
	MemSys       string
	// Runtime names the execution environment, e.g. "Docker" or "Native".
	Runtime  string
	Hostname string
	Uptime   string
	Timezone string
}
|
||
|
||
// AdminUserInfo is the reduced view of a user (ID, name and role) listed on
// the admin settings page.
type AdminUserInfo struct {
	ID       int64
	Username string
	Role     string
}
|
||
|
||
// GroupOption is one selectable server group in the server add/edit forms.
type GroupOption struct {
	ID          int64
	Name        string
	Description string
	// Selected marks the option as chosen.
	Selected bool
}
|
||
|
||
// startTime is captured once at package initialization and serves as the
// reference point for all uptime calculations.
var startTime = time.Now()
|
||
|
||
// daysOfMonth builds the list 1, 2, ..., 31 used to populate day-of-month
// dropdowns in templates.
func daysOfMonth() []int {
	const lastDay = 31
	list := make([]int, 0, lastDay)
	for day := 1; day <= lastDay; day++ {
		list = append(list, day)
	}
	return list
}
|
||
|
||
// formatBytes renders a byte count as a human-readable string using
// 1024-based units. Values below 1024 are printed as whole bytes ("512 B");
// larger values get one decimal place ("1.5 KB").
//
// The unit table runs through "EB", which covers the whole uint64 range
// (about 16 EB), so the unit lookup can never index out of bounds.
func formatBytes(b uint64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}

	units := []string{"KB", "MB", "GB", "TB", "PB", "EB"}

	// Find the largest power of 1024 that does not exceed b.
	div, exp := uint64(unit), 0
	for n := b / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %s", float64(b)/float64(div), units[exp])
}
|
||
|
||
// formatUptime describes the time elapsed since start in a compact form:
// "Xd Yh Zm" once at least a day has passed, "Yh Zm" once at least an hour
// has passed, and "Zm" otherwise.
func formatUptime(start time.Time) string {
	elapsed := time.Since(start)
	wholeHours := int(elapsed.Hours())

	days := wholeHours / 24
	hours := wholeHours % 24
	minutes := int(elapsed.Minutes()) % 60

	switch {
	case days > 0:
		return fmt.Sprintf("%dd %dh %dm", days, hours, minutes)
	case hours > 0:
		return fmt.Sprintf("%dh %dm", hours, minutes)
	default:
		return fmt.Sprintf("%dm", minutes)
	}
}
|
||
|
||
// New creates a new Handler
|
||
func New(authSvc *auth.Service, keysSvc *keys.Service, serversSvc *servers.Service, deploySvc *deploy.Service, auditSvc *audit.Service, cronSvc *cron.Service, workerSvc *worker.Service, mailSvc *mail.Service, db *database.DB, templateFS embed.FS, staticFS embed.FS, dataDir string, secureCookies bool, baseURL string, updaterSvc *updater.Service) *Handler {
|
||
// Create sub-FS so /static/css/... maps to static/css/... in embed
|
||
staticSub, err := fs.Sub(staticFS, "static")
|
||
if err != nil {
|
||
logging.Fatal("Failed to create static sub-FS: %v", err)
|
||
}
|
||
|
||
// Ensure avatars directory exists
|
||
avatarsDir := filepath.Join(dataDir, "avatars")
|
||
if err := os.MkdirAll(avatarsDir, 0700); err != nil {
|
||
logging.Warn("Failed to create avatars directory %s: %v", avatarsDir, err)
|
||
}
|
||
|
||
// Ensure branding directory exists
|
||
brandingDir := filepath.Join(dataDir, "branding")
|
||
if err := os.MkdirAll(brandingDir, 0700); err != nil {
|
||
logging.Warn("Failed to create branding directory %s: %v", brandingDir, err)
|
||
}
|
||
|
||
h := &Handler{
|
||
auth: authSvc,
|
||
keys: keysSvc,
|
||
servers: serversSvc,
|
||
deploy: deploySvc,
|
||
audit: auditSvc,
|
||
cron: cronSvc,
|
||
worker: workerSvc,
|
||
mail: mailSvc,
|
||
updater: updaterSvc,
|
||
db: db,
|
||
sessions: make(map[string]*sessionData),
|
||
pending: make(map[string]int64),
|
||
staticFS: http.StripPrefix("/static/", http.FileServer(http.FS(staticSub))),
|
||
dataDir: dataDir,
|
||
secureCookies: secureCookies,
|
||
baseURL: baseURL,
|
||
}
|
||
|
||
h.loadTemplates(templateFS)
|
||
|
||
// Migrate any legacy base64 avatars to file-based storage
|
||
h.migrateAvatarsToFiles()
|
||
|
||
return h
|
||
}
|
||
|
||
func (h *Handler) loadTemplates(templateFS embed.FS) {
|
||
h.templates = make(map[string]*template.Template)
|
||
|
||
// Template functions available in all templates
|
||
funcMap := template.FuncMap{
|
||
"appName": func() string {
|
||
name, _ := h.auth.GetSetting("app_name")
|
||
if name == "" {
|
||
return "Keywarden"
|
||
}
|
||
return name
|
||
},
|
||
"appVersion": func() string {
|
||
return h.updater.CurrentVersion()
|
||
},
|
||
"updateAvailable": func() bool {
|
||
return h.updater.HasUpdate()
|
||
},
|
||
"latestVersion": func() string {
|
||
return h.updater.LatestVersion()
|
||
},
|
||
"releaseURL": func() string {
|
||
return h.updater.ReleaseURL()
|
||
},
|
||
"releasesPageURL": func() string {
|
||
return updater.ReleasesPageURL
|
||
},
|
||
"loginBgImage": func() string {
|
||
bgPath := filepath.Join(h.dataDir, "branding", "login_bg")
|
||
if _, err := os.Stat(bgPath); err == nil {
|
||
return "/branding/login-bg"
|
||
}
|
||
return ""
|
||
},
|
||
"loginTextColor": func() string {
|
||
c, _ := h.auth.GetSetting("login_text_color")
|
||
if c == "" {
|
||
return "light"
|
||
}
|
||
return c
|
||
},
|
||
// formatTime converts a time.Time to the app timezone and formats as "2006-01-02 15:04"
|
||
"formatTime": func(v interface{}) string {
|
||
switch t := v.(type) {
|
||
case time.Time:
|
||
return t.Local().Format("2006-01-02 15:04")
|
||
case *time.Time:
|
||
if t != nil {
|
||
return t.Local().Format("2006-01-02 15:04")
|
||
}
|
||
}
|
||
return ""
|
||
},
|
||
// formatDateTime converts a time.Time to the app timezone and formats as "2006-01-02 15:04:05"
|
||
"formatDateTime": func(v interface{}) string {
|
||
switch t := v.(type) {
|
||
case time.Time:
|
||
return t.Local().Format("2006-01-02 15:04:05")
|
||
case *time.Time:
|
||
if t != nil {
|
||
return t.Local().Format("2006-01-02 15:04:05")
|
||
}
|
||
}
|
||
return ""
|
||
},
|
||
// formatDateTimeLocal converts a time.Time to the app timezone and formats for HTML datetime-local inputs
|
||
"formatDateTimeLocal": func(v interface{}) string {
|
||
switch t := v.(type) {
|
||
case time.Time:
|
||
return t.Local().Format("2006-01-02T15:04")
|
||
case *time.Time:
|
||
if t != nil {
|
||
return t.Local().Format("2006-01-02T15:04")
|
||
}
|
||
}
|
||
return ""
|
||
},
|
||
}
|
||
|
||
baseLayout, err := fs.ReadFile(templateFS, "templates/layout/base.html")
|
||
if err != nil {
|
||
logging.Fatal("Failed to read base layout: %v", err)
|
||
}
|
||
|
||
pages := []string{
|
||
"dashboard", "keys", "keys_generate", "keys_import", "servers", "servers_add", "servers_edit",
|
||
"server_groups", "server_groups_add", "server_groups_edit",
|
||
"deploy", "audit", "users", "users_add", "users_edit", "settings", "mfa_setup",
|
||
"admin_settings", "system_info",
|
||
"cron", "cron_add", "cron_edit",
|
||
"assignments", "assignments_add", "assignments_edit",
|
||
}
|
||
for _, page := range pages {
|
||
pageContent, err := fs.ReadFile(templateFS, "templates/"+page+".html")
|
||
if err != nil {
|
||
logging.Fatal("Failed to read template %s: %v", page, err)
|
||
}
|
||
tmpl, err := template.New("base").Funcs(funcMap).Parse(string(baseLayout))
|
||
if err != nil {
|
||
logging.Fatal("Failed to parse base for %s: %v", page, err)
|
||
}
|
||
tmpl, err = tmpl.Parse(string(pageContent))
|
||
if err != nil {
|
||
logging.Fatal("Failed to parse page %s: %v", page, err)
|
||
}
|
||
h.templates[page] = tmpl
|
||
}
|
||
|
||
// Login has its own layout
|
||
loginContent, err := fs.ReadFile(templateFS, "templates/login.html")
|
||
if err != nil {
|
||
logging.Fatal("Failed to read login template: %v", err)
|
||
}
|
||
loginTmpl, err := template.New("login").Funcs(funcMap).Parse(string(loginContent))
|
||
if err != nil {
|
||
logging.Fatal("Failed to parse login: %v", err)
|
||
}
|
||
h.templates["login"] = loginTmpl
|
||
|
||
// Force password change page has its own layout (standalone, no sidebar)
|
||
fpcContent, err := fs.ReadFile(templateFS, "templates/force_password_change.html")
|
||
if err != nil {
|
||
logging.Fatal("Failed to read force_password_change template: %v", err)
|
||
}
|
||
fpcTmpl, err := template.New("force_password_change").Funcs(funcMap).Parse(string(fpcContent))
|
||
if err != nil {
|
||
logging.Fatal("Failed to parse force_password_change: %v", err)
|
||
}
|
||
h.templates["force_password_change"] = fpcTmpl
|
||
|
||
// MFA required page has its own layout (standalone, no sidebar)
|
||
mfaReqContent, err := fs.ReadFile(templateFS, "templates/mfa_required.html")
|
||
if err != nil {
|
||
logging.Fatal("Failed to read mfa_required template: %v", err)
|
||
}
|
||
mfaReqTmpl, err := template.New("mfa_required").Funcs(funcMap).Parse(string(mfaReqContent))
|
||
if err != nil {
|
||
logging.Fatal("Failed to parse mfa_required: %v", err)
|
||
}
|
||
h.templates["mfa_required"] = mfaReqTmpl
|
||
|
||
// Invitation acceptance page has its own layout (standalone, no sidebar)
|
||
inviteContent, err := fs.ReadFile(templateFS, "templates/invite_accept.html")
|
||
if err != nil {
|
||
logging.Fatal("Failed to read invite_accept template: %v", err)
|
||
}
|
||
inviteTmpl, err := template.New("invite_accept").Funcs(funcMap).Parse(string(inviteContent))
|
||
if err != nil {
|
||
logging.Fatal("Failed to parse invite_accept: %v", err)
|
||
}
|
||
h.templates["invite_accept"] = inviteTmpl
|
||
}
|
||
|
||
// RegisterRoutes sets up all HTTP routes
|
||
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||
// Static assets (CSS, JS, fonts) – served with long cache headers
|
||
mux.HandleFunc("/static/", h.handleStatic)
|
||
|
||
// Public routes
|
||
mux.HandleFunc("/branding/login-bg", h.handleLoginBgServe)
|
||
mux.HandleFunc("/login", h.handleLogin)
|
||
mux.HandleFunc("/login/mfa", h.handleLoginMFA)
|
||
mux.HandleFunc("/logout", h.handleLogout)
|
||
mux.HandleFunc("/invite/", h.handleInviteAccept)
|
||
|
||
// MFA enforcement (requires auth but shown without sidebar)
|
||
mux.HandleFunc("/mfa/setup", h.requireAuth(h.handleMFAEnforce))
|
||
|
||
// Protected routes (all authenticated users)
|
||
mux.HandleFunc("/", h.requireAuth(h.handleRoot))
|
||
mux.HandleFunc("/dashboard", h.requireAuth(h.handleDashboard))
|
||
mux.HandleFunc("/password/change", h.requireAuth(h.handleForcePasswordChange))
|
||
mux.HandleFunc("/keys", h.requireAuth(h.handleKeys))
|
||
mux.HandleFunc("/keys/generate", h.requireAuth(h.handleKeysGenerate))
|
||
mux.HandleFunc("/keys/import", h.requireAuth(h.handleKeysImport))
|
||
mux.HandleFunc("/keys/", h.requireAuth(h.handleKeyAction))
|
||
mux.HandleFunc("/settings", h.requireAuth(h.handleSettings))
|
||
mux.HandleFunc("/settings/theme", h.requireAuth(h.handleThemeChange))
|
||
mux.HandleFunc("/settings/mfa/setup", h.requireAuth(h.handleMFASetup))
|
||
mux.HandleFunc("/settings/mfa/disable", h.requireAuth(h.handleMFADisable))
|
||
mux.HandleFunc("/settings/email/notify", h.requireAuth(h.handleEmailNotifyToggle))
|
||
mux.HandleFunc("/settings/avatar", h.requireAuth(h.handleAvatarUpload))
|
||
mux.HandleFunc("/avatar/", h.requireAuth(h.handleAvatarServe))
|
||
mux.HandleFunc("/audit", h.requireAuth(h.handleAudit))
|
||
mux.HandleFunc("/my/access", h.requireAuth(h.handleMyAssignments))
|
||
|
||
// Admin-only routes (admin + owner)
|
||
mux.HandleFunc("/servers", h.requireAdmin(h.handleServers))
|
||
mux.HandleFunc("/servers/add", h.requireAdmin(h.handleServersAdd))
|
||
mux.HandleFunc("/servers/test", h.requireAdmin(h.handleServerTest))
|
||
mux.HandleFunc("/servers/test-auth", h.requireAdmin(h.handleServerTestAuth))
|
||
mux.HandleFunc("/servers/", h.requireAdmin(h.handleServerAction))
|
||
mux.HandleFunc("/groups", h.requireAdmin(h.handleServerGroups))
|
||
mux.HandleFunc("/groups/add", h.requireAdmin(h.handleServerGroupsAdd))
|
||
mux.HandleFunc("/groups/", h.requireAdmin(h.handleServerGroupAction))
|
||
mux.HandleFunc("/deploy", h.requireAdmin(h.handleDeploy))
|
||
mux.HandleFunc("/deploy/group", h.requireAdmin(h.handleDeployGroup))
|
||
mux.HandleFunc("/cron", h.requireAdmin(h.handleCron))
|
||
mux.HandleFunc("/cron/add", h.requireAdmin(h.handleCronAdd))
|
||
mux.HandleFunc("/cron/", h.requireAdmin(h.handleCronAction))
|
||
mux.HandleFunc("/users", h.requireAdmin(h.handleUsers))
|
||
mux.HandleFunc("/users/add", h.requireAdmin(h.handleUsersAdd))
|
||
mux.HandleFunc("/users/", h.requireAdmin(h.handleUserAction))
|
||
mux.HandleFunc("/assignments", h.requireAdmin(h.handleAssignments))
|
||
mux.HandleFunc("/assignments/add", h.requireAdmin(h.handleAssignmentsAdd))
|
||
mux.HandleFunc("/assignments/", h.requireAdmin(h.handleAssignmentAction))
|
||
mux.HandleFunc("/system", h.requireAdmin(h.handleSystemInfo))
|
||
mux.HandleFunc("/admin/settings/email/test", h.requireOwner(h.handleAdminEmailTest))
|
||
|
||
// API endpoints (JSON)
|
||
mux.HandleFunc("/api/health", h.handleAPIHealth)
|
||
mux.HandleFunc("/api/cron/keys", h.requireAdmin(h.handleAPICronKeys))
|
||
|
||
// Owner-only routes
|
||
mux.HandleFunc("/admin/settings", h.requireOwner(h.handleAdminSettings))
|
||
mux.HandleFunc("/admin/branding/upload", h.requireOwner(h.handleLoginBrandingUpload))
|
||
mux.HandleFunc("/admin/branding/remove-bg", h.requireOwner(h.handleLoginBrandingRemoveBg))
|
||
mux.HandleFunc("/admin/masterkey/regenerate", h.requireOwner(h.handleMasterKeyRegenerate))
|
||
mux.HandleFunc("/admin/backup/export", h.requireOwner(h.handleBackupExport))
|
||
mux.HandleFunc("/admin/backup/import", h.requireOwner(h.handleBackupImport))
|
||
mux.HandleFunc("/admin/enforcement/run", h.requireOwner(h.handleEnforcementRunNow))
|
||
}
|
||
|
||
// handleAPIHealth returns a JSON health status (no auth required).
|
||
// Used by Docker HEALTHCHECK and external monitoring.
|
||
func (h *Handler) handleAPIHealth(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// Check database connectivity
|
||
dbOK := true
|
||
if err := h.db.Ping(); err != nil {
|
||
dbOK = false
|
||
}
|
||
|
||
status := "healthy"
|
||
httpCode := http.StatusOK
|
||
if !dbOK {
|
||
status = "unhealthy"
|
||
httpCode = http.StatusServiceUnavailable
|
||
}
|
||
|
||
uptime := time.Since(startTime)
|
||
|
||
result := map[string]interface{}{
|
||
"status": status,
|
||
"uptime": formatUptime(startTime),
|
||
"uptime_seconds": int(uptime.Seconds()),
|
||
"checks": map[string]interface{}{
|
||
"database": map[string]interface{}{
|
||
"status": boolToStatus(dbOK),
|
||
},
|
||
},
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(httpCode)
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// boolToStatus maps a check result to its textual form:
// true becomes "ok", false becomes "fail".
func boolToStatus(ok bool) string {
	if !ok {
		return "fail"
	}
	return "ok"
}
|
||
|
||
// getSessionTimeout returns the configured session timeout duration.
|
||
// Falls back to 60 minutes if not configured or invalid.
|
||
func (h *Handler) getSessionTimeout() time.Duration {
|
||
val, err := h.auth.GetSetting("session_timeout")
|
||
if err != nil || val == "" {
|
||
return 60 * time.Minute
|
||
}
|
||
minutes, err := strconv.Atoi(val)
|
||
if err != nil || minutes < 1 {
|
||
return 60 * time.Minute
|
||
}
|
||
return time.Duration(minutes) * time.Minute
|
||
}
|
||
|
||
// StartSessionCleanup starts a background goroutine that periodically
|
||
// removes expired sessions from the in-memory map.
|
||
func (h *Handler) StartSessionCleanup() {
|
||
go func() {
|
||
ticker := time.NewTicker(1 * time.Minute)
|
||
defer ticker.Stop()
|
||
for range ticker.C {
|
||
timeout := h.getSessionTimeout()
|
||
now := time.Now()
|
||
h.mu.Lock()
|
||
for token, sess := range h.sessions {
|
||
if now.Sub(sess.LastActive) > timeout {
|
||
logging.Debug("Session expired for user ID %d (inactive for %v)", sess.UserID, now.Sub(sess.LastActive).Round(time.Second))
|
||
delete(h.sessions, token)
|
||
}
|
||
}
|
||
h.mu.Unlock()
|
||
}
|
||
}()
|
||
}
|
||
|
||
// Middleware: require authentication
|
||
func (h *Handler) requireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
cookie, err := r.Cookie("keywarden_session")
|
||
if err != nil || cookie.Value == "" {
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
h.mu.RLock()
|
||
sess, ok := h.sessions[cookie.Value]
|
||
h.mu.RUnlock()
|
||
if !ok {
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Check session timeout (inactivity based)
|
||
timeout := h.getSessionTimeout()
|
||
if time.Since(sess.LastActive) > timeout {
|
||
h.mu.Lock()
|
||
delete(h.sessions, cookie.Value)
|
||
h.mu.Unlock()
|
||
logging.Info("Session expired for user ID %d due to inactivity (%v timeout)", sess.UserID, timeout)
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "keywarden_session",
|
||
Value: "",
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
Secure: h.secureCookies,
|
||
SameSite: http.SameSiteStrictMode,
|
||
MaxAge: -1,
|
||
})
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Update last activity (sliding window)
|
||
h.mu.Lock()
|
||
sess.LastActive = time.Now()
|
||
h.mu.Unlock()
|
||
|
||
// Refresh cookie expiry (sliding cookie) so the browser keeps
|
||
// the cookie alive as long as the user is active.
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "keywarden_session",
|
||
Value: cookie.Value,
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
Secure: h.secureCookies,
|
||
SameSite: http.SameSiteStrictMode,
|
||
MaxAge: int(timeout.Seconds()),
|
||
})
|
||
|
||
user, err := h.auth.GetUserByID(sess.UserID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Store user in request context via header (simple approach)
|
||
r.Header.Set("X-User-ID", strconv.FormatInt(user.ID, 10))
|
||
r.Header.Set("X-User-Name", user.Username)
|
||
r.Header.Set("X-User-Role", user.Role)
|
||
|
||
// Force password change: redirect to /password/change unless already there
|
||
if user.MustChangePassword && r.URL.Path != "/password/change" {
|
||
http.Redirect(w, r, "/password/change", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// MFA enforcement: only redirect if the session was flagged at login time
|
||
// This ensures already-logged-in users are not disrupted; enforcement
|
||
// takes effect on the next login.
|
||
if !user.MustChangePassword && sess.MFASetupRequired {
|
||
// If the user has since enabled MFA, clear the flag
|
||
if user.MFAEnabled {
|
||
h.mu.Lock()
|
||
sess.MFASetupRequired = false
|
||
h.mu.Unlock()
|
||
} else if r.URL.Path != "/mfa/setup" &&
|
||
!(user.Role == "owner" && strings.HasPrefix(r.URL.Path, "/admin/")) {
|
||
http.Redirect(w, r, "/mfa/setup", http.StatusSeeOther)
|
||
return
|
||
}
|
||
}
|
||
|
||
next(w, r)
|
||
}
|
||
}
|
||
|
||
// handleStatic serves embedded static assets with cache headers
|
||
func (h *Handler) handleStatic(w http.ResponseWriter, r *http.Request) {
|
||
// Cache static assets for 1 year (immutable – version pinned in filenames)
|
||
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
||
h.staticFS.ServeHTTP(w, r)
|
||
}
|
||
|
||
// Middleware: require admin or owner role
|
||
func (h *Handler) requireAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||
return h.requireAuth(func(w http.ResponseWriter, r *http.Request) {
|
||
role := r.Header.Get("X-User-Role")
|
||
if role != "admin" && role != "owner" {
|
||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||
return
|
||
}
|
||
next(w, r)
|
||
})
|
||
}
|
||
|
||
// Middleware: require owner role exclusively
|
||
func (h *Handler) requireOwner(next http.HandlerFunc) http.HandlerFunc {
|
||
return h.requireAuth(func(w http.ResponseWriter, r *http.Request) {
|
||
role := r.Header.Get("X-User-Role")
|
||
if role != "owner" {
|
||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||
return
|
||
}
|
||
next(w, r)
|
||
})
|
||
}
|
||
|
||
// isAdmin reports whether role carries admin privileges, i.e. whether it is
// "admin" or "owner".
func isAdmin(role string) bool {
	switch role {
	case "admin", "owner":
		return true
	default:
		return false
	}
}
|
||
|
||
// isOwner reports whether role is exactly "owner".
func isOwner(role string) bool {
	const ownerRole = "owner"
	return role == ownerRole
}
|
||
|
||
// getInitialOwnerID returns the user ID of the initial owner (0 if not set)
|
||
func (h *Handler) getInitialOwnerID() int64 {
|
||
return h.auth.GetInitialOwnerID()
|
||
}
|
||
|
||
func (h *Handler) getUserID(r *http.Request) int64 {
|
||
id, _ := strconv.ParseInt(r.Header.Get("X-User-ID"), 10, 64)
|
||
return id
|
||
}
|
||
|
||
// clientIP delegates to the security package for trusted-proxy-aware IP extraction.
|
||
func clientIP(r *http.Request) string {
|
||
return security.ClientIP(r)
|
||
}
|
||
|
||
// GetUserName returns the username from the request (for request logging middleware).
|
||
// Returns empty string if no user is authenticated.
|
||
func (h *Handler) GetUserName(r *http.Request) string {
|
||
return r.Header.Get("X-User-Name")
|
||
}
|
||
|
||
// --- Route Handlers ---
|
||
|
||
func (h *Handler) handleRoot(w http.ResponseWriter, r *http.Request) {
|
||
if r.URL.Path != "/" {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method == http.MethodGet {
|
||
h.templates["login"].Execute(w, &PageData{})
|
||
return
|
||
}
|
||
|
||
username := r.FormValue("username")
|
||
password := r.FormValue("password")
|
||
|
||
user, err := h.auth.Login(username, password)
|
||
if err != nil {
|
||
if err == auth.ErrAccountLocked {
|
||
logging.Warn("Login blocked for locked account '%s' from IP %s", username, clientIP(r))
|
||
h.audit.Log(0, audit.ActionLoginFailed, fmt.Sprintf("Login attempt for locked account: %s", username), clientIP(r))
|
||
h.templates["login"].Execute(w, &PageData{Error: "Account is temporarily locked due to too many failed attempts. Please try again later."})
|
||
return
|
||
}
|
||
logging.Warn("Login failed for user '%s' from IP %s: %v", username, clientIP(r), err)
|
||
h.audit.Log(0, audit.ActionLoginFailed, fmt.Sprintf("Failed login attempt for user: %s", username), clientIP(r))
|
||
|
||
// Record failed login and potentially lock the account
|
||
h.auth.RecordFailedLogin(username)
|
||
|
||
h.templates["login"].Execute(w, &PageData{Error: "Invalid username or password"})
|
||
return
|
||
}
|
||
|
||
// Successful login – reset lockout counter and track last login
|
||
h.auth.ResetFailedLogins(user.ID)
|
||
h.auth.UpdateLastLogin(user.ID)
|
||
|
||
// Check if MFA is enabled
|
||
if user.MFAEnabled && user.MFASecret != "" {
|
||
// Create a pending MFA session
|
||
logging.Info("Login for user '%s' from IP %s – MFA verification pending", user.Username, clientIP(r))
|
||
token := generateSessionID()
|
||
h.mu.Lock()
|
||
h.pending[token] = user.ID
|
||
h.mu.Unlock()
|
||
h.templates["login"].Execute(w, &PageData{MFAPending: true, MFAToken: token})
|
||
return
|
||
}
|
||
|
||
// Create session directly
|
||
sessionID := generateSessionID()
|
||
timeout := h.getSessionTimeout()
|
||
|
||
// Check if MFA is required by admin but user hasn't set it up
|
||
mfaRequired, _ := h.auth.GetSetting("mfa_required")
|
||
needsMFASetup := mfaRequired == "true" && !user.MFAEnabled
|
||
|
||
h.mu.Lock()
|
||
h.sessions[sessionID] = &sessionData{UserID: user.ID, LastActive: time.Now(), MFASetupRequired: needsMFASetup}
|
||
h.mu.Unlock()
|
||
|
||
logging.Info("Login successful for user '%s' from IP %s", user.Username, clientIP(r))
|
||
h.audit.Log(user.ID, audit.ActionLoginSuccess, fmt.Sprintf("User %s logged in", user.Username), clientIP(r))
|
||
|
||
// Send login notification email (async)
|
||
h.sendLoginNotification(user, r)
|
||
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "keywarden_session",
|
||
Value: sessionID,
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
Secure: h.secureCookies,
|
||
SameSite: http.SameSiteStrictMode,
|
||
MaxAge: int(timeout.Seconds()),
|
||
})
|
||
|
||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleLoginMFA(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
token := r.FormValue("mfa_token")
|
||
code := r.FormValue("mfa_code")
|
||
|
||
h.mu.RLock()
|
||
userID, ok := h.pending[token]
|
||
h.mu.RUnlock()
|
||
if !ok {
|
||
h.templates["login"].Execute(w, &PageData{Error: "MFA session expired. Please login again."})
|
||
return
|
||
}
|
||
|
||
user, err := h.auth.GetUserByID(userID)
|
||
if err != nil {
|
||
h.templates["login"].Execute(w, &PageData{Error: "User not found."})
|
||
return
|
||
}
|
||
|
||
// Validate TOTP code
|
||
if !validateTOTP(user.MFASecret, code) {
|
||
logging.Warn("MFA verification failed for user '%s' from IP %s", user.Username, clientIP(r))
|
||
h.audit.Log(user.ID, audit.ActionMFAFailed, fmt.Sprintf("Failed MFA attempt for user: %s", user.Username), clientIP(r))
|
||
h.templates["login"].Execute(w, &PageData{MFAPending: true, MFAToken: token, Error: "Invalid MFA code. Please try again."})
|
||
return
|
||
}
|
||
|
||
// MFA verified, create session
|
||
h.mu.Lock()
|
||
delete(h.pending, token)
|
||
h.mu.Unlock()
|
||
|
||
// Track last login after MFA verification
|
||
h.auth.UpdateLastLogin(user.ID)
|
||
|
||
sessionID := generateSessionID()
|
||
timeout := h.getSessionTimeout()
|
||
h.mu.Lock()
|
||
h.sessions[sessionID] = &sessionData{UserID: user.ID, LastActive: time.Now()}
|
||
h.mu.Unlock()
|
||
|
||
logging.Info("Login successful for user '%s' from IP %s (MFA verified)", user.Username, clientIP(r))
|
||
h.audit.Log(user.ID, audit.ActionLoginSuccess, fmt.Sprintf("User %s logged in (MFA verified)", user.Username), clientIP(r))
|
||
|
||
// Send login notification email (async)
|
||
h.sendLoginNotification(user, r)
|
||
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "keywarden_session",
|
||
Value: sessionID,
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
Secure: h.secureCookies,
|
||
SameSite: http.SameSiteStrictMode,
|
||
MaxAge: int(timeout.Seconds()),
|
||
})
|
||
|
||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||
cookie, err := r.Cookie("keywarden_session")
|
||
if err == nil {
|
||
h.mu.Lock()
|
||
if sess, ok := h.sessions[cookie.Value]; ok {
|
||
if u, uErr := h.auth.GetUserByID(sess.UserID); uErr == nil {
|
||
logging.Info("User '%s' logged out from IP %s", u.Username, clientIP(r))
|
||
}
|
||
h.audit.Log(sess.UserID, audit.ActionLogout, "User logged out", clientIP(r))
|
||
}
|
||
delete(h.sessions, cookie.Value)
|
||
h.mu.Unlock()
|
||
}
|
||
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "keywarden_session",
|
||
Value: "",
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
Secure: h.secureCookies,
|
||
SameSite: http.SameSiteStrictMode,
|
||
MaxAge: -1,
|
||
})
|
||
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
role := r.Header.Get("X-User-Role")
|
||
deployments, _ := h.deploy.GetDeployments(userID)
|
||
cronCount := h.cron.CountByUser(userID)
|
||
|
||
var keyCount, serverCount, groupCount, assignmentCount, userCount int
|
||
var recentKeys interface{}
|
||
var recentAudit []audit.AuditEntry
|
||
|
||
if isAdmin(role) {
|
||
allKeys, _ := h.keys.GetAllKeys()
|
||
allServers, _ := h.servers.GetAllServers()
|
||
allGroups, _ := h.servers.GetAllGroups()
|
||
allAssignments, _ := h.servers.GetAllAssignments()
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
keyCount = len(allKeys)
|
||
serverCount = len(allServers)
|
||
groupCount = len(allGroups)
|
||
assignmentCount = len(allAssignments)
|
||
userCount = len(allUsers)
|
||
recentKeys = allKeys
|
||
entries, _, _ := h.audit.GetAll(1, 5)
|
||
recentAudit = entries
|
||
} else {
|
||
keyList, _ := h.keys.GetKeysByUser(userID)
|
||
userServers, _ := h.servers.GetByUser(userID)
|
||
userGroups, _ := h.servers.GetGroupsByUser(userID)
|
||
userAssignments, _ := h.servers.GetAssignmentsByUser(userID)
|
||
keyCount = len(keyList)
|
||
serverCount = len(userServers)
|
||
groupCount = len(userGroups)
|
||
assignmentCount = len(userAssignments)
|
||
recentKeys = keyList
|
||
entries, _, _ := h.audit.GetByUser(userID, 1, 5)
|
||
recentAudit = entries
|
||
}
|
||
|
||
data := &PageData{
|
||
Title: "Dashboard",
|
||
Active: "dashboard",
|
||
User: user,
|
||
UserRole: role,
|
||
KeyCount: keyCount,
|
||
ServerCount: serverCount,
|
||
DeployCount: len(deployments),
|
||
CronCount: cronCount,
|
||
UserCount: userCount,
|
||
GroupCount: groupCount,
|
||
AssignmentCount: assignmentCount,
|
||
RecentKeys: recentKeys,
|
||
RecentDeploys: deployments,
|
||
RecentAudit: recentAudit,
|
||
}
|
||
|
||
h.templates["dashboard"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
func (h *Handler) handleKeys(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
role := r.Header.Get("X-User-Role")
|
||
|
||
data := &PageData{
|
||
Title: "SSH Keys",
|
||
Active: "keys",
|
||
User: user,
|
||
}
|
||
|
||
// Admin/Owner see all keys with owner info; User sees only own keys
|
||
if isAdmin(role) {
|
||
allKeys, _ := h.keys.GetAllKeysWithOwner()
|
||
data.Data = allKeys
|
||
data.Keys = allKeys
|
||
} else {
|
||
keyList, _ := h.keys.GetKeysByUser(userID)
|
||
data.Keys = keyList
|
||
}
|
||
|
||
h.templates["keys"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
func (h *Handler) handleKeysGenerate(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
role := r.Header.Get("X-User-Role")
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "Generate SSH Key",
|
||
Active: "keys",
|
||
User: user,
|
||
}
|
||
// Admin/Owner can generate keys for other users
|
||
if isAdmin(role) {
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
data.Users = allUsers
|
||
}
|
||
h.templates["keys_generate"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
name := r.FormValue("name")
|
||
keyType := r.FormValue("key_type")
|
||
comment := r.FormValue("comment")
|
||
bits := 4096
|
||
if b := r.FormValue("bits"); b != "" {
|
||
bits, _ = strconv.Atoi(b)
|
||
}
|
||
|
||
// Admin/Owner can generate keys for a selected user
|
||
targetUserID := userID
|
||
if isAdmin(role) {
|
||
if tid := r.FormValue("target_user_id"); tid != "" {
|
||
if parsed, err := strconv.ParseInt(tid, 10, 64); err == nil && parsed > 0 {
|
||
targetUserID = parsed
|
||
}
|
||
}
|
||
}
|
||
|
||
_, err := h.keys.GenerateKey(targetUserID, name, keyType, bits, comment)
|
||
if err != nil {
|
||
logging.Warn("Key generation failed for user %d: %v", targetUserID, err)
|
||
data := &PageData{
|
||
Title: "Generate SSH Key",
|
||
Active: "keys",
|
||
User: user,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to generate key: " + err.Error()},
|
||
}
|
||
if isAdmin(role) {
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
data.Users = allUsers
|
||
}
|
||
h.templates["keys_generate"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
if targetUserID != userID {
|
||
targetUser, _ := h.auth.GetUserByID(targetUserID)
|
||
targetName := fmt.Sprintf("user_id=%d", targetUserID)
|
||
if targetUser != nil {
|
||
targetName = targetUser.Username
|
||
}
|
||
logging.Info("SSH key generated: type=%s name='%s' bits=%d for user '%s' by admin user_id=%d", keyType, name, bits, targetName, userID)
|
||
h.audit.Log(userID, audit.ActionKeyGenerated, fmt.Sprintf("Generated %s key: %s (%d bits) for user %s", keyType, name, bits, targetName), clientIP(r))
|
||
} else {
|
||
logging.Info("SSH key generated: type=%s name='%s' bits=%d user_id=%d", keyType, name, bits, userID)
|
||
h.audit.Log(userID, audit.ActionKeyGenerated, fmt.Sprintf("Generated %s key: %s (%d bits)", keyType, name, bits), clientIP(r))
|
||
}
|
||
http.Redirect(w, r, "/keys", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleKeysImport(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
role := r.Header.Get("X-User-Role")
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "Import SSH Key",
|
||
Active: "keys",
|
||
User: user,
|
||
}
|
||
if isAdmin(role) {
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
data.Users = allUsers
|
||
}
|
||
h.templates["keys_import"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
name := r.FormValue("name")
|
||
privateKey := r.FormValue("private_key")
|
||
|
||
if name == "" || privateKey == "" {
|
||
data := &PageData{
|
||
Title: "Import SSH Key",
|
||
Active: "keys",
|
||
User: user,
|
||
Flash: &Flash{Type: "danger", Message: "Name and private key are required."},
|
||
}
|
||
if isAdmin(role) {
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
data.Users = allUsers
|
||
}
|
||
h.templates["keys_import"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// Admin/Owner can import keys for a selected user
|
||
targetUserID := userID
|
||
if isAdmin(role) {
|
||
if tid := r.FormValue("target_user_id"); tid != "" {
|
||
if parsed, err := strconv.ParseInt(tid, 10, 64); err == nil && parsed > 0 {
|
||
targetUserID = parsed
|
||
}
|
||
}
|
||
}
|
||
|
||
_, err := h.keys.ImportKey(targetUserID, name, []byte(privateKey))
|
||
if err != nil {
|
||
logging.Warn("Key import failed for user %d: %v", targetUserID, err)
|
||
data := &PageData{
|
||
Title: "Import SSH Key",
|
||
Active: "keys",
|
||
User: user,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to import key: " + err.Error()},
|
||
}
|
||
if isAdmin(role) {
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
data.Users = allUsers
|
||
}
|
||
h.templates["keys_import"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
if targetUserID != userID {
|
||
targetUser, _ := h.auth.GetUserByID(targetUserID)
|
||
targetName := fmt.Sprintf("user_id=%d", targetUserID)
|
||
if targetUser != nil {
|
||
targetName = targetUser.Username
|
||
}
|
||
logging.Info("SSH key imported: name='%s' for user '%s' by admin user_id=%d", name, targetName, userID)
|
||
h.audit.Log(userID, audit.ActionKeyImported, fmt.Sprintf("Imported key: %s for user %s", name, targetName), clientIP(r))
|
||
} else {
|
||
logging.Info("SSH key imported: name='%s' user_id=%d", name, userID)
|
||
h.audit.Log(userID, audit.ActionKeyImported, fmt.Sprintf("Imported key: %s", name), clientIP(r))
|
||
}
|
||
http.Redirect(w, r, "/keys", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleKeyAction(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
role := r.Header.Get("X-User-Role")
|
||
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||
if len(parts) < 3 {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
keyID, err := strconv.ParseInt(parts[1], 10, 64)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
action := parts[2]
|
||
|
||
switch action {
|
||
case "view":
|
||
// Admin/Owner can view any public key; User can only view own keys
|
||
if isAdmin(role) {
|
||
key, err := h.keys.GetKeyByIDGlobal(keyID)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
w.Header().Set("Content-Type", "text/plain")
|
||
w.Write([]byte(key.PublicKey))
|
||
} else {
|
||
key, err := h.keys.GetKeyByID(keyID, userID)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
w.Header().Set("Content-Type", "text/plain")
|
||
w.Write([]byte(key.PublicKey))
|
||
}
|
||
|
||
case "download":
|
||
// Private key download: only the key owner can download their own private key
|
||
key, err := h.keys.GetKeyByID(keyID, userID)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
logging.Info("Private key downloaded: name='%s' id=%d user_id=%d", key.Name, key.ID, userID)
|
||
h.audit.Log(userID, audit.ActionKeyDownload, fmt.Sprintf("Downloaded private key: %s (ID %d)", key.Name, key.ID), clientIP(r))
|
||
w.Header().Set("Content-Disposition", "attachment; filename="+key.Name+"_private.pem")
|
||
w.Header().Set("Content-Type", "application/x-pem-file")
|
||
w.Write([]byte(key.PrivateKeyEnc))
|
||
|
||
case "delete":
|
||
if r.Method == http.MethodPost {
|
||
logging.Info("SSH key deleted: id=%d by_user_id=%d", keyID, userID)
|
||
h.audit.Log(userID, audit.ActionKeyDeleted, fmt.Sprintf("Deleted SSH key ID %d", keyID), clientIP(r))
|
||
// Admin/Owner can delete any key; User can only delete own keys
|
||
if isAdmin(role) {
|
||
h.keys.DeleteKeyGlobal(keyID)
|
||
} else {
|
||
h.keys.DeleteKey(keyID, userID)
|
||
}
|
||
}
|
||
http.Redirect(w, r, "/keys", http.StatusSeeOther)
|
||
|
||
default:
|
||
http.NotFound(w, r)
|
||
}
|
||
}
|
||
|
||
func (h *Handler) handleServers(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
// Admin/Owner see all servers globally
|
||
serverList, _ := h.servers.GetAllServers()
|
||
|
||
data := &PageData{
|
||
Title: "Hosts",
|
||
Active: "servers",
|
||
User: user,
|
||
Servers: serverList,
|
||
}
|
||
h.templates["servers"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
func (h *Handler) handleServersAdd(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
groups, _ := h.servers.GetAllGroups()
|
||
|
||
// Build group options
|
||
groupOptions := make([]GroupOption, len(groups))
|
||
for i, g := range groups {
|
||
groupOptions[i] = GroupOption{
|
||
ID: g.ID,
|
||
Name: g.Name,
|
||
Description: g.Description,
|
||
}
|
||
}
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "Add Host",
|
||
Active: "servers",
|
||
User: user,
|
||
Data: groupOptions,
|
||
}
|
||
h.templates["servers_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
name := r.FormValue("name")
|
||
hostname := r.FormValue("hostname")
|
||
port, _ := strconv.Atoi(r.FormValue("port"))
|
||
username := r.FormValue("username")
|
||
description := r.FormValue("description")
|
||
|
||
srv, err := h.servers.Create(userID, name, hostname, port, username, description)
|
||
if err != nil {
|
||
data := &PageData{
|
||
Title: "Add Host",
|
||
Active: "servers",
|
||
User: user,
|
||
Data: groupOptions,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to add host: " + err.Error()},
|
||
}
|
||
h.templates["servers_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// Assign server to selected groups
|
||
groupIDStrs := r.Form["group_ids"]
|
||
if len(groupIDStrs) > 0 {
|
||
var groupIDs []int64
|
||
for _, gidStr := range groupIDStrs {
|
||
gid, err := strconv.ParseInt(gidStr, 10, 64)
|
||
if err == nil {
|
||
groupIDs = append(groupIDs, gid)
|
||
}
|
||
}
|
||
h.servers.SetServerGroupsGlobal(srv.ID, groupIDs)
|
||
}
|
||
|
||
h.audit.Log(userID, audit.ActionServerAdded, fmt.Sprintf("Added server: %s (%s:%d)", name, hostname, port), clientIP(r))
|
||
http.Redirect(w, r, "/servers", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleServerAction(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||
if len(parts) < 3 {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
serverID, err := strconv.ParseInt(parts[1], 10, 64)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
action := parts[2]
|
||
|
||
switch action {
|
||
case "delete":
|
||
if r.Method == http.MethodPost {
|
||
h.audit.Log(userID, audit.ActionServerDeleted, fmt.Sprintf("Deleted server ID %d", serverID), clientIP(r))
|
||
h.servers.DeleteGlobal(serverID)
|
||
}
|
||
http.Redirect(w, r, "/servers", http.StatusSeeOther)
|
||
case "edit":
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
srv, err := h.servers.GetByIDGlobal(serverID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/servers", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
groups, _ := h.servers.GetAllGroups()
|
||
currentGroupIDs, _ := h.servers.GetGroupIDsForServerGlobal(serverID)
|
||
currentGroupMap := make(map[int64]bool)
|
||
for _, gid := range currentGroupIDs {
|
||
currentGroupMap[gid] = true
|
||
}
|
||
|
||
groupOptions := make([]GroupOption, len(groups))
|
||
for i, g := range groups {
|
||
groupOptions[i] = GroupOption{
|
||
ID: g.ID,
|
||
Name: g.Name,
|
||
Description: g.Description,
|
||
Selected: currentGroupMap[g.ID],
|
||
}
|
||
}
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "Edit Host",
|
||
Active: "servers",
|
||
User: user,
|
||
Server: srv,
|
||
Data: groupOptions,
|
||
}
|
||
h.templates["servers_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// POST: update server
|
||
name := r.FormValue("name")
|
||
hostname := r.FormValue("hostname")
|
||
port, _ := strconv.Atoi(r.FormValue("port"))
|
||
username := r.FormValue("username")
|
||
description := r.FormValue("description")
|
||
|
||
if err := h.servers.UpdateGlobal(serverID, name, hostname, port, username, description); err != nil {
|
||
data := &PageData{
|
||
Title: "Edit Host",
|
||
Active: "servers",
|
||
User: user,
|
||
Server: srv,
|
||
Data: groupOptions,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to update host: " + err.Error()},
|
||
}
|
||
h.templates["servers_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// Update group assignments
|
||
groupIDStrs := r.Form["group_ids"]
|
||
var groupIDs []int64
|
||
for _, gidStr := range groupIDStrs {
|
||
gid, err := strconv.ParseInt(gidStr, 10, 64)
|
||
if err == nil {
|
||
groupIDs = append(groupIDs, gid)
|
||
}
|
||
}
|
||
h.servers.SetServerGroupsGlobal(serverID, groupIDs)
|
||
|
||
h.audit.Log(userID, audit.ActionServerUpdated, fmt.Sprintf("Updated server: %s (%s:%d)", name, hostname, port), clientIP(r))
|
||
http.Redirect(w, r, "/servers", http.StatusSeeOther)
|
||
default:
|
||
http.NotFound(w, r)
|
||
}
|
||
}
|
||
|
||
func (h *Handler) handleServerTest(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
serverID, _ := strconv.ParseInt(r.FormValue("server_id"), 10, 64)
|
||
|
||
server, err := h.servers.GetByIDGlobal(serverID)
|
||
if err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"success": false,
|
||
"message": "Server not found.",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Test TCP connectivity
|
||
err = h.deploy.TestConnection(server.Hostname, server.Port)
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
if err != nil {
|
||
h.audit.Log(userID, audit.ActionServerTest, fmt.Sprintf("Connection test failed for %s:%d: %v", server.Hostname, server.Port, err), clientIP(r))
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"success": false,
|
||
"message": fmt.Sprintf("Connection failed: %v", err),
|
||
})
|
||
} else {
|
||
h.audit.Log(userID, audit.ActionServerTest, fmt.Sprintf("Connection test OK for %s:%d", server.Hostname, server.Port), clientIP(r))
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"success": true,
|
||
"message": fmt.Sprintf("Connection to %s:%d successful (SSH port reachable).", server.Hostname, server.Port),
|
||
})
|
||
}
|
||
}
|
||
|
||
func (h *Handler) handleServerTestAuth(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
serverID, _ := strconv.ParseInt(r.FormValue("server_id"), 10, 64)
|
||
|
||
server, err := h.servers.GetByIDGlobal(serverID)
|
||
if err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"success": false,
|
||
"message": "Server not found.",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Use system master key for auth test
|
||
masterKeyPEM, err := h.keys.GetSystemMasterKeyPrivate()
|
||
if err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"success": false,
|
||
"message": "System master key not available. Check Admin Settings.",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Test actual SSH authentication with system master key
|
||
err = h.deploy.TestSSHAuth(server.Hostname, server.Port, server.Username, masterKeyPEM)
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
if err != nil {
|
||
h.audit.Log(userID, audit.ActionServerAuth, fmt.Sprintf("SSH auth test failed for %s@%s:%d: %v", server.Username, server.Hostname, server.Port, err), clientIP(r))
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"success": false,
|
||
"message": fmt.Sprintf("SSH authentication failed: %v", err),
|
||
})
|
||
} else {
|
||
h.audit.Log(userID, audit.ActionServerAuth, fmt.Sprintf("SSH auth test OK for %s@%s:%d", server.Username, server.Hostname, server.Port), clientIP(r))
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"success": true,
|
||
"message": fmt.Sprintf("SSH login to %s@%s:%d successful!", server.Username, server.Hostname, server.Port),
|
||
})
|
||
}
|
||
}
|
||
|
||
// masterKeyForDeploy returns the system master key as a virtual SSHKey entry for deployment.
|
||
// Returns nil if the master key is not available.
|
||
func (h *Handler) masterKeyForDeploy() *models.SSHKey {
|
||
pub, err := h.keys.GetSystemMasterKeyPublic()
|
||
if err != nil || pub == "" {
|
||
return nil
|
||
}
|
||
fp, _ := h.keys.GetSystemMasterKeyFingerprint()
|
||
return &models.SSHKey{
|
||
ID: -1,
|
||
UserID: 0,
|
||
Name: "[MASTER] System Master Key",
|
||
KeyType: "ed25519",
|
||
PublicKey: pub,
|
||
Fingerprint: fp,
|
||
}
|
||
}
|
||
|
||
// prependMasterKey adds the system master key to the key list if the user is an owner.
|
||
func (h *Handler) prependMasterKey(keyList []models.SSHKey, role string) []models.SSHKey {
|
||
if !isOwner(role) {
|
||
return keyList
|
||
}
|
||
mk := h.masterKeyForDeploy()
|
||
if mk == nil {
|
||
return keyList
|
||
}
|
||
return append([]models.SSHKey{*mk}, keyList...)
|
||
}
|
||
|
||
func (h *Handler) handleDeploy(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
keyList, _ := h.keys.GetAllKeys()
|
||
serverList, _ := h.servers.GetAllServers()
|
||
groups, _ := h.servers.GetAllGroups()
|
||
deployments, _ := h.deploy.GetDeployments(userID)
|
||
|
||
// Owner can deploy the system master key
|
||
keyList = h.prependMasterKey(keyList, user.Role)
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "Deploy Keys",
|
||
Active: "deploy",
|
||
User: user,
|
||
Keys: keyList,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
Deployments: deployments,
|
||
}
|
||
h.templates["deploy"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// Handle POST: deploy key
|
||
keyID, _ := strconv.ParseInt(r.FormValue("key_id"), 10, 64)
|
||
serverID, _ := strconv.ParseInt(r.FormValue("server_id"), 10, 64)
|
||
authMethod := r.FormValue("auth_method")
|
||
|
||
var key *models.SSHKey
|
||
var err error
|
||
|
||
if keyID == -1 && isOwner(user.Role) {
|
||
// Owner deploying the system master key
|
||
key = h.masterKeyForDeploy()
|
||
if key == nil {
|
||
http.Redirect(w, r, "/deploy", http.StatusSeeOther)
|
||
return
|
||
}
|
||
} else {
|
||
key, err = h.keys.GetKeyByID(keyID, userID)
|
||
if err != nil {
|
||
// Try global access for admin/owner deploying other users' keys
|
||
key, err = h.keys.GetKeyByIDGlobal(keyID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/deploy", http.StatusSeeOther)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
server, err := h.servers.GetByIDGlobal(serverID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/deploy", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
switch authMethod {
|
||
case "password":
|
||
password := r.FormValue("password")
|
||
err = h.deploy.DeployKeyWithPassword(key, server, password)
|
||
case "key":
|
||
authKeyID, _ := strconv.ParseInt(r.FormValue("auth_key_id"), 10, 64)
|
||
authKey, kerr := h.keys.GetKeyByID(authKeyID, userID)
|
||
if kerr != nil {
|
||
http.Redirect(w, r, "/deploy", http.StatusSeeOther)
|
||
return
|
||
}
|
||
err = h.deploy.DeployKey(key, server, []byte(authKey.PrivateKeyEnc))
|
||
}
|
||
|
||
if err != nil {
|
||
// Reload with error
|
||
logging.Warn("Deploy failed: key='%s' target=%s@%s:%d error=%v", key.Name, server.Username, server.Hostname, server.Port, err)
|
||
h.audit.Log(userID, audit.ActionDeployFailed, fmt.Sprintf("Deploy key '%s' to %s@%s:%d failed: %v", key.Name, server.Username, server.Hostname, server.Port, err), clientIP(r))
|
||
deployments, _ = h.deploy.GetDeployments(userID)
|
||
data := &PageData{
|
||
Title: "Deploy Keys",
|
||
Active: "deploy",
|
||
User: user,
|
||
Keys: keyList,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
Deployments: deployments,
|
||
Flash: &Flash{Type: "danger", Message: "Deployment failed: " + err.Error()},
|
||
}
|
||
h.templates["deploy"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
logging.Info("Deploy successful: key='%s' target=%s@%s:%d", key.Name, server.Username, server.Hostname, server.Port)
|
||
h.audit.Log(userID, audit.ActionDeploySuccess, fmt.Sprintf("Deployed key '%s' to %s@%s:%d", key.Name, server.Username, server.Hostname, server.Port), clientIP(r))
|
||
deployments, _ = h.deploy.GetDeployments(userID)
|
||
data := &PageData{
|
||
Title: "Deploy Keys",
|
||
Active: "deploy",
|
||
User: user,
|
||
Keys: keyList,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
Deployments: deployments,
|
||
Flash: &Flash{Type: "success", Message: fmt.Sprintf("Key '%s' successfully deployed to %s@%s:%d.", key.Name, server.Username, server.Hostname, server.Port)},
|
||
}
|
||
h.templates["deploy"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
func (h *Handler) handleDeployGroup(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/deploy", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
keyID, _ := strconv.ParseInt(r.FormValue("key_id"), 10, 64)
|
||
groupID, _ := strconv.ParseInt(r.FormValue("group_id"), 10, 64)
|
||
authMethod := r.FormValue("auth_method")
|
||
|
||
var key *models.SSHKey
|
||
var keyErr error
|
||
|
||
if keyID == -1 && isOwner(user.Role) {
|
||
// Owner deploying the system master key
|
||
key = h.masterKeyForDeploy()
|
||
if key == nil {
|
||
http.Redirect(w, r, "/deploy", http.StatusSeeOther)
|
||
return
|
||
}
|
||
} else {
|
||
key, keyErr = h.keys.GetKeyByID(keyID, userID)
|
||
if keyErr != nil {
|
||
// Try global access for admin/owner deploying other users' keys
|
||
key, keyErr = h.keys.GetKeyByIDGlobal(keyID)
|
||
if keyErr != nil {
|
||
http.Redirect(w, r, "/deploy", http.StatusSeeOther)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
group, err := h.servers.GetGroupByIDGlobal(groupID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/deploy", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
members, err := h.servers.GetGroupMembersGlobal(groupID)
|
||
if err != nil || len(members) == 0 {
|
||
keyList, _ := h.keys.GetAllKeys()
|
||
keyList = h.prependMasterKey(keyList, user.Role)
|
||
serverList, _ := h.servers.GetAllServers()
|
||
groups, _ := h.servers.GetAllGroups()
|
||
deployments, _ := h.deploy.GetDeployments(userID)
|
||
data := &PageData{
|
||
Title: "Deploy Keys",
|
||
Active: "deploy",
|
||
User: user,
|
||
Keys: keyList,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
Deployments: deployments,
|
||
Flash: &Flash{Type: "warning", Message: "Group has no members."},
|
||
}
|
||
h.templates["deploy"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
var successCount, failCount int
|
||
for _, server := range members {
|
||
srv := server // capture loop var
|
||
var deployErr error
|
||
switch authMethod {
|
||
case "password":
|
||
password := r.FormValue("password")
|
||
deployErr = h.deploy.DeployKeyWithPassword(key, &srv, password)
|
||
case "key":
|
||
authKeyID, _ := strconv.ParseInt(r.FormValue("auth_key_id"), 10, 64)
|
||
authKey, kerr := h.keys.GetKeyByID(authKeyID, userID)
|
||
if kerr != nil {
|
||
deployErr = fmt.Errorf("auth key not found")
|
||
} else {
|
||
deployErr = h.deploy.DeployKey(key, &srv, []byte(authKey.PrivateKeyEnc))
|
||
}
|
||
}
|
||
if deployErr != nil {
|
||
failCount++
|
||
h.audit.Log(userID, audit.ActionDeployFailed, fmt.Sprintf("Group deploy key '%s' to %s@%s:%d failed: %v", key.Name, srv.Username, srv.Hostname, srv.Port, deployErr), clientIP(r))
|
||
} else {
|
||
successCount++
|
||
h.audit.Log(userID, audit.ActionDeploySuccess, fmt.Sprintf("Group deploy key '%s' to %s@%s:%d", key.Name, srv.Username, srv.Hostname, srv.Port), clientIP(r))
|
||
}
|
||
}
|
||
|
||
h.audit.Log(userID, audit.ActionGroupDeploy, fmt.Sprintf("Group deploy '%s' to group '%s': %d success, %d failed", key.Name, group.Name, successCount, failCount), clientIP(r))
|
||
|
||
flashType := "success"
|
||
flashMsg := fmt.Sprintf("Deployed key to group '%s': %d/%d servers successful.", group.Name, successCount, len(members))
|
||
if failCount > 0 && successCount > 0 {
|
||
flashType = "warning"
|
||
} else if failCount > 0 && successCount == 0 {
|
||
flashType = "danger"
|
||
flashMsg = fmt.Sprintf("Deploy to group '%s' failed on all %d servers.", group.Name, failCount)
|
||
}
|
||
|
||
keyList, _ := h.keys.GetAllKeys()
|
||
keyList = h.prependMasterKey(keyList, user.Role)
|
||
serverList, _ := h.servers.GetAllServers()
|
||
groups, _ := h.servers.GetAllGroups()
|
||
deployments, _ := h.deploy.GetDeployments(userID)
|
||
|
||
data := &PageData{
|
||
Title: "Deploy Keys",
|
||
Active: "deploy",
|
||
User: user,
|
||
Keys: keyList,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
Deployments: deployments,
|
||
Flash: &Flash{Type: flashType, Message: flashMsg},
|
||
}
|
||
h.templates["deploy"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
// --- Server Group Handlers ---
|
||
|
||
func (h *Handler) handleServerGroups(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
groups, _ := h.servers.GetAllGroups()
|
||
|
||
data := &PageData{
|
||
Title: "Groups",
|
||
Active: "groups",
|
||
User: user,
|
||
Groups: groups,
|
||
}
|
||
h.templates["server_groups"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
func (h *Handler) handleServerGroupsAdd(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "Create Group",
|
||
Active: "groups",
|
||
User: user,
|
||
}
|
||
h.templates["server_groups_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
name := r.FormValue("name")
|
||
description := r.FormValue("description")
|
||
|
||
_, err := h.servers.CreateGroup(userID, name, description)
|
||
if err != nil {
|
||
data := &PageData{
|
||
Title: "Create Group",
|
||
Active: "groups",
|
||
User: user,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to create group: " + err.Error()},
|
||
}
|
||
h.templates["server_groups_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
h.audit.Log(userID, audit.ActionGroupCreated, fmt.Sprintf("Created server group: %s", name), clientIP(r))
|
||
http.Redirect(w, r, "/groups", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleServerGroupAction(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||
if len(parts) < 3 {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
groupID, err := strconv.ParseInt(parts[1], 10, 64)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
action := parts[2]
|
||
|
||
switch action {
|
||
case "edit":
|
||
group, err := h.servers.GetGroupByIDGlobal(groupID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/groups", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
if r.Method == http.MethodGet {
|
||
members, _ := h.servers.GetGroupMembersGlobal(groupID)
|
||
allServers, _ := h.servers.GetAllServers()
|
||
|
||
data := &PageData{
|
||
Title: "Edit Group",
|
||
Active: "groups",
|
||
User: user,
|
||
Group: group,
|
||
GroupServers: members,
|
||
AllServers: allServers,
|
||
}
|
||
h.templates["server_groups_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// POST: update group info
|
||
name := r.FormValue("name")
|
||
description := r.FormValue("description")
|
||
if err := h.servers.UpdateGroupGlobal(groupID, name, description); err != nil {
|
||
members, _ := h.servers.GetGroupMembersGlobal(groupID)
|
||
allServers, _ := h.servers.GetAllServers()
|
||
data := &PageData{
|
||
Title: "Edit Group",
|
||
Active: "groups",
|
||
User: user,
|
||
Group: group,
|
||
GroupServers: members,
|
||
AllServers: allServers,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to update group: " + err.Error()},
|
||
}
|
||
h.templates["server_groups_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
h.audit.Log(userID, audit.ActionGroupUpdated, fmt.Sprintf("Updated server group: %s (ID %d)", name, groupID), clientIP(r))
|
||
http.Redirect(w, r, fmt.Sprintf("/groups/%d/edit", groupID), http.StatusSeeOther)
|
||
|
||
case "delete":
|
||
if r.Method == http.MethodPost {
|
||
group, _ := h.servers.GetGroupByIDGlobal(groupID)
|
||
gname := "unknown"
|
||
if group != nil {
|
||
gname = group.Name
|
||
}
|
||
h.audit.Log(userID, audit.ActionGroupDeleted, fmt.Sprintf("Deleted server group: %s (ID %d)", gname, groupID), clientIP(r))
|
||
h.servers.DeleteGroupGlobal(groupID)
|
||
}
|
||
http.Redirect(w, r, "/groups", http.StatusSeeOther)
|
||
|
||
case "add-server":
|
||
if r.Method == http.MethodPost {
|
||
serverID, _ := strconv.ParseInt(r.FormValue("server_id"), 10, 64)
|
||
if err := h.servers.AddServerToGroupGlobal(groupID, serverID); err != nil {
|
||
logging.Error("Failed to add server to group: %v", err)
|
||
} else {
|
||
h.audit.Log(userID, audit.ActionGroupServerAdded, fmt.Sprintf("Added server ID %d to group ID %d", serverID, groupID), clientIP(r))
|
||
}
|
||
}
|
||
http.Redirect(w, r, fmt.Sprintf("/groups/%d/edit", groupID), http.StatusSeeOther)
|
||
|
||
case "remove-server":
|
||
if r.Method == http.MethodPost {
|
||
serverID, _ := strconv.ParseInt(r.FormValue("server_id"), 10, 64)
|
||
if err := h.servers.RemoveServerFromGroupGlobal(groupID, serverID); err != nil {
|
||
logging.Error("Failed to remove server from group: %v", err)
|
||
} else {
|
||
h.audit.Log(userID, audit.ActionGroupServerRemoved, fmt.Sprintf("Removed server ID %d from group ID %d", serverID, groupID), clientIP(r))
|
||
}
|
||
}
|
||
http.Redirect(w, r, fmt.Sprintf("/groups/%d/edit", groupID), http.StatusSeeOther)
|
||
|
||
default:
|
||
http.NotFound(w, r)
|
||
}
|
||
}
|
||
|
||
func (h *Handler) handleAudit(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
perPage := 50
|
||
|
||
filterMine := r.URL.Query().Get("filter") == "mine"
|
||
|
||
var entries []audit.AuditEntry
|
||
var total int
|
||
var err error
|
||
|
||
switch {
|
||
case isOwner(user.Role) && !filterMine:
|
||
// Owner sees everything
|
||
entries, total, err = h.audit.GetAll(page, perPage)
|
||
case isAdmin(user.Role) && !filterMine:
|
||
// Admin sees all except owner entries
|
||
entries, total, err = h.audit.GetAllExceptOwners(page, perPage)
|
||
default:
|
||
// User sees only own entries; also used when admin/owner filters "mine"
|
||
entries, total, err = h.audit.GetByUser(userID, page, perPage)
|
||
if !isAdmin(user.Role) {
|
||
filterMine = true
|
||
}
|
||
}
|
||
|
||
if err != nil {
|
||
logging.Error("Failed to load audit log: %v", err)
|
||
}
|
||
|
||
totalPages := (total + perPage - 1) / perPage
|
||
if totalPages < 1 {
|
||
totalPages = 1
|
||
}
|
||
|
||
prevPage := page - 1
|
||
if prevPage < 1 {
|
||
prevPage = 1
|
||
}
|
||
nextPage := page + 1
|
||
if nextPage > totalPages {
|
||
nextPage = totalPages
|
||
}
|
||
|
||
data := &PageData{
|
||
Title: "Audit Log",
|
||
Active: "audit",
|
||
User: user,
|
||
AuditEntries: entries,
|
||
AuditTotal: total,
|
||
AuditPage: page,
|
||
AuditTotalPages: totalPages,
|
||
AuditPrevPage: prevPage,
|
||
AuditNextPage: nextPage,
|
||
AuditIsAdmin: isAdmin(user.Role),
|
||
AuditFilterUser: filterMine,
|
||
}
|
||
h.templates["audit"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
// --- User Management Handlers (Admin/Owner Only) ---
|
||
|
||
func (h *Handler) handleUsers(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
|
||
// Admins only see users with role "user"; owners see all
|
||
var users []models.User
|
||
if isOwner(user.Role) {
|
||
users = allUsers
|
||
} else {
|
||
for _, u := range allUsers {
|
||
if u.Role == "user" || u.ID == userID {
|
||
users = append(users, u)
|
||
}
|
||
}
|
||
}
|
||
|
||
data := &PageData{
|
||
Title: "User Management",
|
||
Active: "users",
|
||
User: user,
|
||
Users: users,
|
||
InitialOwnerID: h.getInitialOwnerID(),
|
||
}
|
||
h.templates["users"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
func (h *Handler) handleUsersAdd(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
if r.Method == http.MethodGet {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Add User",
|
||
Active: "users",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
EmailEnabled: h.mail.IsEnabled(),
|
||
}
|
||
h.templates["users_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
username := r.FormValue("username")
|
||
email := r.FormValue("email")
|
||
password := r.FormValue("password")
|
||
role := r.FormValue("role")
|
||
mustChangePassword := r.FormValue("must_change_password") == "1"
|
||
sendInvitation := r.FormValue("send_invitation") == "1"
|
||
|
||
// Enforce role restrictions: admin can only create "user" role
|
||
if !isOwner(user.Role) && role != "user" {
|
||
role = "user"
|
||
}
|
||
// Only owner can assign "owner" role
|
||
if role == "owner" && !isOwner(user.Role) {
|
||
role = "user"
|
||
}
|
||
|
||
// If sending invitation, generate a random temporary password
|
||
if sendInvitation {
|
||
randBytes := make([]byte, 24)
|
||
if _, err := rand.Read(randBytes); err != nil {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Add User",
|
||
Active: "users",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
EmailEnabled: h.mail.IsEnabled(),
|
||
Flash: &Flash{Type: "danger", Message: "Failed to generate temporary password."},
|
||
}
|
||
h.templates["users_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
password = base64.URLEncoding.EncodeToString(randBytes)
|
||
mustChangePassword = true
|
||
} else {
|
||
// Validate password against policy (only when manually set)
|
||
if err := h.auth.ValidatePasswordPolicy(password); err != nil {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Add User",
|
||
Active: "users",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
EmailEnabled: h.mail.IsEnabled(),
|
||
Flash: &Flash{Type: "danger", Message: err.Error()},
|
||
}
|
||
h.templates["users_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
}
|
||
|
||
newUser, err := h.auth.Register(username, email, password, role, mustChangePassword)
|
||
if err != nil {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Add User",
|
||
Active: "users",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
EmailEnabled: h.mail.IsEnabled(),
|
||
Flash: &Flash{Type: "danger", Message: "Failed to create user: " + err.Error()},
|
||
}
|
||
h.templates["users_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
logging.Info("User created: username='%s' role='%s' by admin user_id=%d", newUser.Username, role, userID)
|
||
h.audit.Log(userID, audit.ActionUserCreated, fmt.Sprintf("Created user: %s (role: %s)", username, role), clientIP(r))
|
||
|
||
// Send invitation email if requested
|
||
if sendInvitation {
|
||
token, err := h.auth.CreateInvitationToken(newUser.ID, 48*time.Hour)
|
||
if err != nil {
|
||
logging.Error("Failed to create invitation token for user '%s': %v", username, err)
|
||
h.audit.Log(userID, audit.ActionInvitationSendFailed, fmt.Sprintf("Failed to create invitation token for user: %s", username), clientIP(r))
|
||
// User was created but invitation failed – redirect with warning
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Build invite URL – prefer configured BaseURL, fall back to request
|
||
base := h.baseURL
|
||
if base == "" {
|
||
scheme := "http"
|
||
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
|
||
scheme = "https"
|
||
}
|
||
base = fmt.Sprintf("%s://%s", scheme, r.Host)
|
||
logging.Warn("KEYWARDEN_BASE_URL is not set – deriving invite link from request: %s (set KEYWARDEN_BASE_URL for reliable email links)", base)
|
||
}
|
||
inviteURL := fmt.Sprintf("%s/invite/%s", base, token)
|
||
go func() {
|
||
mailErr := h.mail.SendInvitation(email, mail.InvitationData{
|
||
Username: username,
|
||
InviteURL: inviteURL,
|
||
ExpiresIn: "48 hours",
|
||
})
|
||
if mailErr != nil {
|
||
logging.Error("Failed to send invitation email to '%s': %v", email, mailErr)
|
||
h.audit.Log(userID, audit.ActionInvitationSendFailed, fmt.Sprintf("Email delivery failed for user: %s (%s)", username, email), clientIP(r))
|
||
} else {
|
||
logging.Info("Invitation email sent to '%s' for user '%s'", email, username)
|
||
h.audit.Log(userID, audit.ActionInvitationSent, fmt.Sprintf("Invitation sent to %s for user: %s", email, username), clientIP(r))
|
||
}
|
||
}()
|
||
}
|
||
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleUserAction(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||
if len(parts) < 3 {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
targetID, err := strconv.ParseInt(parts[1], 10, 64)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
action := parts[2]
|
||
|
||
// Load target user for permission checks
|
||
targetUser, err := h.auth.GetUserByID(targetID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Admin can only manage users with role "user"; owner/admin management requires owner role
|
||
if !isOwner(user.Role) && targetUser.Role != "user" && targetID != userID {
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
switch action {
|
||
case "edit":
|
||
if r.Method == http.MethodGet {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Edit User",
|
||
Active: "users",
|
||
User: user,
|
||
EditUser: targetUser,
|
||
PasswordPolicy: &policy,
|
||
IsInitialOwner: h.auth.IsInitialOwner(targetID),
|
||
}
|
||
h.templates["users_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// POST: update user
|
||
username := r.FormValue("username")
|
||
email := r.FormValue("email")
|
||
role := r.FormValue("role")
|
||
newPassword := r.FormValue("password")
|
||
forceChange := r.FormValue("must_change_password") == "1"
|
||
|
||
// Initial Owner protection: role must remain "owner"
|
||
if h.auth.IsInitialOwner(targetID) && role != "owner" {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Edit User",
|
||
Active: "users",
|
||
User: user,
|
||
EditUser: targetUser,
|
||
PasswordPolicy: &policy,
|
||
IsInitialOwner: true,
|
||
Flash: &Flash{Type: "danger", Message: "The initial owner role cannot be changed. This account was created during installation and is permanently protected."},
|
||
}
|
||
h.templates["users_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// Enforce role restrictions:
|
||
// - Admin can only assign "user" role
|
||
// - Only owner can assign "admin" or "owner"
|
||
if !isOwner(user.Role) {
|
||
role = "user"
|
||
}
|
||
// Only owner can assign owner role
|
||
if role == "owner" && !isOwner(user.Role) {
|
||
role = "user"
|
||
}
|
||
|
||
// Owner protection: cannot degrade self if last owner
|
||
if targetID == userID && isOwner(user.Role) && role != "owner" {
|
||
ownerCount, _ := h.auth.CountByRole("owner")
|
||
if ownerCount <= 1 {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Edit User",
|
||
Active: "users",
|
||
User: user,
|
||
EditUser: targetUser,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: "Cannot remove the last owner role. At least one owner must exist."},
|
||
}
|
||
h.templates["users_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
}
|
||
|
||
if err := h.auth.UpdateUser(targetID, username, email, role); err != nil {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Edit User",
|
||
Active: "users",
|
||
User: user,
|
||
EditUser: targetUser,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to update user: " + err.Error()},
|
||
}
|
||
h.templates["users_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// Update password if provided
|
||
if newPassword != "" {
|
||
// Validate against password policy
|
||
if err := h.auth.ValidatePasswordPolicy(newPassword); err != nil {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Edit User",
|
||
Active: "users",
|
||
User: user,
|
||
EditUser: targetUser,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: err.Error()},
|
||
}
|
||
h.templates["users_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
if err := h.auth.UpdatePassword(targetID, newPassword); err != nil {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Edit User",
|
||
Active: "users",
|
||
User: user,
|
||
EditUser: targetUser,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: "User updated but password change failed: " + err.Error()},
|
||
}
|
||
h.templates["users_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
logging.Info("Password changed for user '%s' (ID %d) by admin user_id=%d", username, targetID, userID)
|
||
h.audit.Log(userID, audit.ActionPasswordChanged, fmt.Sprintf("Admin changed password for user: %s (ID %d)", username, targetID), clientIP(r))
|
||
}
|
||
|
||
// Update must_change_password flag
|
||
h.auth.SetMustChangePassword(targetID, forceChange)
|
||
|
||
logging.Info("User updated: username='%s' (ID %d) role='%s' by admin user_id=%d", username, targetID, role, userID)
|
||
h.audit.Log(userID, audit.ActionUserUpdated, fmt.Sprintf("Updated user: %s (ID %d, role: %s)", username, targetID, role), clientIP(r))
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
|
||
case "unlock":
|
||
if r.Method == http.MethodPost {
|
||
h.auth.UnlockAccount(targetID)
|
||
logging.Info("Account unlocked for user '%s' (ID %d) by admin user_id=%d", targetUser.Username, targetID, userID)
|
||
h.audit.Log(userID, audit.ActionAccountUnlocked, fmt.Sprintf("Unlocked account: %s (ID %d)", targetUser.Username, targetID), clientIP(r))
|
||
}
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
|
||
case "delete":
|
||
if r.Method == http.MethodPost {
|
||
// Initial Owner protection: cannot be deleted
|
||
if h.auth.IsInitialOwner(targetID) {
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
return
|
||
}
|
||
// Owner protection: cannot self-delete
|
||
if targetID == userID {
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
return
|
||
}
|
||
// Cannot delete an owner if it would leave zero owners
|
||
if targetUser.Role == "owner" {
|
||
ownerCount, _ := h.auth.CountByRole("owner")
|
||
if ownerCount <= 1 {
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
return
|
||
}
|
||
}
|
||
uname := targetUser.Username
|
||
logging.Info("User deleted: username='%s' (ID %d) by admin user_id=%d", uname, targetID, userID)
|
||
h.audit.Log(userID, audit.ActionUserDeleted, fmt.Sprintf("Deleted user: %s (ID %d)", uname, targetID), clientIP(r))
|
||
h.auth.DeleteUser(targetID)
|
||
}
|
||
http.Redirect(w, r, "/users", http.StatusSeeOther)
|
||
|
||
default:
|
||
http.NotFound(w, r)
|
||
}
|
||
}
|
||
|
||
// --- Settings Handlers ---
|
||
|
||
func (h *Handler) handleSettings(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
if r.Method == http.MethodGet {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
mfaRequired, _ := h.auth.GetSetting("mfa_required")
|
||
data := &PageData{
|
||
Title: "Settings",
|
||
Active: "settings",
|
||
User: user,
|
||
EmailEnabled: h.mail.IsEnabled(),
|
||
PasswordPolicy: &policy,
|
||
MFARequired: mfaRequired == "true",
|
||
}
|
||
h.templates["settings"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// Handle personal password change
|
||
currentPass := r.FormValue("current_password")
|
||
newPass := r.FormValue("new_password")
|
||
confirmPass := r.FormValue("confirm_password")
|
||
|
||
if currentPass != "" && newPass != "" {
|
||
if newPass != confirmPass {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Settings",
|
||
Active: "settings",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: "New passwords do not match."},
|
||
}
|
||
h.templates["settings"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// Validate against password policy
|
||
if err := h.auth.ValidatePasswordPolicy(newPass); err != nil {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Settings",
|
||
Active: "settings",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: err.Error()},
|
||
}
|
||
h.templates["settings"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// Verify current password
|
||
_, err := h.auth.Login(user.Username, currentPass)
|
||
if err != nil {
|
||
logging.Warn("Password change failed for user '%s': current password incorrect", user.Username)
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Settings",
|
||
Active: "settings",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: "Current password is incorrect."},
|
||
}
|
||
h.templates["settings"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
if err := h.auth.UpdatePassword(userID, newPass); err != nil {
|
||
policy := h.auth.GetPasswordPolicy()
|
||
data := &PageData{
|
||
Title: "Settings",
|
||
Active: "settings",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to change password: " + err.Error()},
|
||
}
|
||
h.templates["settings"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
logging.Info("Password changed by user '%s' (ID %d)", user.Username, userID)
|
||
h.audit.Log(userID, audit.ActionPasswordChanged, "User changed their password", clientIP(r))
|
||
}
|
||
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
}
|
||
|
||
// handleForcePasswordChange handles the mandatory password change page.
|
||
// Users with must_change_password flag are redirected here and cannot use
|
||
// the application until they set a new password.
|
||
func (h *Handler) handleForcePasswordChange(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
// If password change is not required, redirect to dashboard
|
||
if !user.MustChangePassword {
|
||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
policy := h.auth.GetPasswordPolicy()
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "Change Password",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
}
|
||
h.templates["force_password_change"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
newPass := r.FormValue("new_password")
|
||
confirmPass := r.FormValue("confirm_password")
|
||
|
||
if newPass != confirmPass {
|
||
data := &PageData{
|
||
Title: "Change Password",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: "Passwords do not match."},
|
||
}
|
||
h.templates["force_password_change"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
if err := h.auth.ValidatePasswordPolicy(newPass); err != nil {
|
||
data := &PageData{
|
||
Title: "Change Password",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: err.Error()},
|
||
}
|
||
h.templates["force_password_change"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
if err := h.auth.UpdatePassword(userID, newPass); err != nil {
|
||
data := &PageData{
|
||
Title: "Change Password",
|
||
User: user,
|
||
PasswordPolicy: &policy,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to change password: " + err.Error()},
|
||
}
|
||
h.templates["force_password_change"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
// Clear the flag
|
||
h.auth.SetMustChangePassword(userID, false)
|
||
logging.Info("Forced password change completed for user '%s' (ID %d)", user.Username, userID)
|
||
h.audit.Log(userID, audit.ActionForcePasswordChange, "User changed initial password", clientIP(r))
|
||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||
}
|
||
|
||
// handleInviteAccept handles the public invitation acceptance page.
|
||
// Users reach this via the invitation link in their email.
|
||
// GET: shows the registration form (set password)
|
||
// POST: completes the registration
|
||
func (h *Handler) handleInviteAccept(w http.ResponseWriter, r *http.Request) {
|
||
// Extract token from URL: /invite/{token}
|
||
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||
if len(parts) < 2 || parts[1] == "" {
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
return
|
||
}
|
||
token := parts[1]
|
||
|
||
// Look up the invitation
|
||
inv, err := h.auth.GetInvitationByToken(token)
|
||
if err != nil {
|
||
data := &PageData{
|
||
Title: "Invalid Invitation",
|
||
Error: "This invitation link is invalid or has already been used.",
|
||
}
|
||
h.templates["invite_accept"].Execute(w, data)
|
||
return
|
||
}
|
||
if inv.Used {
|
||
data := &PageData{
|
||
Title: "Invitation Used",
|
||
Error: "This invitation has already been used. Please log in with your credentials.",
|
||
}
|
||
h.templates["invite_accept"].Execute(w, data)
|
||
return
|
||
}
|
||
if time.Now().After(inv.ExpiresAt) {
|
||
data := &PageData{
|
||
Title: "Invitation Expired",
|
||
Error: "This invitation has expired. Please contact your administrator for a new invitation.",
|
||
}
|
||
h.templates["invite_accept"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
// Load the user associated with this invitation
|
||
invitedUser, err := h.auth.GetUserByID(inv.UserID)
|
||
if err != nil {
|
||
data := &PageData{
|
||
Title: "Invalid Invitation",
|
||
Error: "The user associated with this invitation could not be found.",
|
||
}
|
||
h.templates["invite_accept"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
policy := h.auth.GetPasswordPolicy()
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "Complete Registration",
|
||
EditUser: invitedUser,
|
||
PasswordPolicy: &policy,
|
||
Data: token,
|
||
}
|
||
h.templates["invite_accept"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
// POST: complete registration
|
||
newPass := r.FormValue("new_password")
|
||
confirmPass := r.FormValue("confirm_password")
|
||
|
||
if newPass != confirmPass {
|
||
data := &PageData{
|
||
Title: "Complete Registration",
|
||
EditUser: invitedUser,
|
||
PasswordPolicy: &policy,
|
||
Data: token,
|
||
Flash: &Flash{Type: "danger", Message: "Passwords do not match."},
|
||
}
|
||
h.templates["invite_accept"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
if err := h.auth.ValidatePasswordPolicy(newPass); err != nil {
|
||
data := &PageData{
|
||
Title: "Complete Registration",
|
||
EditUser: invitedUser,
|
||
PasswordPolicy: &policy,
|
||
Data: token,
|
||
Flash: &Flash{Type: "danger", Message: err.Error()},
|
||
}
|
||
h.templates["invite_accept"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
completedUser, err := h.auth.CompleteInvitation(token, newPass)
|
||
if err != nil {
|
||
data := &PageData{
|
||
Title: "Complete Registration",
|
||
EditUser: invitedUser,
|
||
PasswordPolicy: &policy,
|
||
Data: token,
|
||
Flash: &Flash{Type: "danger", Message: "Registration failed: " + err.Error()},
|
||
}
|
||
h.templates["invite_accept"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
logging.Info("Invitation accepted: user '%s' (ID %d) completed registration", completedUser.Username, completedUser.ID)
|
||
h.audit.Log(completedUser.ID, audit.ActionInvitationAccepted, fmt.Sprintf("User %s completed registration via invitation", completedUser.Username), clientIP(r))
|
||
|
||
// Redirect to login page with success indication
|
||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||
}
|
||
|
||
// handleThemeChange saves the user's theme preference
|
||
func (h *Handler) handleThemeChange(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
theme := r.FormValue("theme")
|
||
h.auth.UpdateTheme(userID, theme)
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
}
|
||
|
||
// handleAvatarUpload saves the user's profile picture as a file on disk
|
||
func (h *Handler) handleAvatarUpload(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
avatarPath := filepath.Join(h.dataDir, "avatars", fmt.Sprintf("%d", userID))
|
||
|
||
// Check for avatar removal
|
||
if r.FormValue("remove_avatar") == "1" {
|
||
os.Remove(avatarPath)
|
||
h.auth.UpdateAvatar(userID, "")
|
||
h.audit.Log(userID, audit.ActionAvatarChanged, "Removed profile picture", clientIP(r))
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Limit upload to 5MB
|
||
r.Body = http.MaxBytesReader(w, r.Body, 5<<20)
|
||
if err := r.ParseMultipartForm(5 << 20); err != nil {
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
file, header, err := r.FormFile("avatar")
|
||
if err != nil {
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
defer file.Close()
|
||
|
||
// Validate content type
|
||
ct := header.Header.Get("Content-Type")
|
||
if ct != "image/png" && ct != "image/jpeg" && ct != "image/gif" && ct != "image/webp" {
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
data, err := io.ReadAll(file)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Save avatar as file on disk (persistent in Docker volume)
|
||
if err := os.WriteFile(avatarPath, data, 0600); err != nil {
|
||
logging.Warn("Failed to save avatar file for user %d: %v", userID, err)
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Store marker in DB (not the actual image data)
|
||
if err := h.auth.UpdateAvatar(userID, "file"); err != nil {
|
||
logging.Warn("Failed to update avatar marker for user %d: %v", userID, err)
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
h.audit.Log(userID, audit.ActionAvatarChanged, "Updated profile picture", clientIP(r))
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
}
|
||
|
||
// handleAvatarServe serves a user's avatar image from disk
|
||
func (h *Handler) handleAvatarServe(w http.ResponseWriter, r *http.Request) {
|
||
// Extract user ID from URL: /avatar/{id}
|
||
parts := strings.Split(strings.TrimSuffix(r.URL.Path, "/"), "/")
|
||
if len(parts) < 3 || parts[2] == "" {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
targetID, err := strconv.ParseInt(parts[2], 10, 64)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
avatarPath := filepath.Join(h.dataDir, "avatars", fmt.Sprintf("%d", targetID))
|
||
data, err := os.ReadFile(avatarPath)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
contentType := http.DetectContentType(data)
|
||
w.Header().Set("Content-Type", contentType)
|
||
w.Header().Set("Cache-Control", "private, max-age=300")
|
||
w.Write(data)
|
||
}
|
||
|
||
// migrateAvatarsToFiles converts legacy base64 data URI avatars to file-based storage.
|
||
// This ensures backwards compatibility when upgrading existing installations.
|
||
func (h *Handler) migrateAvatarsToFiles() {
|
||
rows, err := h.auth.GetUsersWithLegacyAvatars()
|
||
if err != nil {
|
||
logging.Warn("Could not check for legacy avatars: %v", err)
|
||
return
|
||
}
|
||
|
||
avatarsDir := filepath.Join(h.dataDir, "avatars")
|
||
migrated := 0
|
||
for _, entry := range rows {
|
||
// Parse data URI: "data:image/png;base64,iVBOR..."
|
||
if !strings.HasPrefix(entry.AvatarBase64, "data:") {
|
||
continue
|
||
}
|
||
commaIdx := strings.Index(entry.AvatarBase64, ",")
|
||
if commaIdx < 0 {
|
||
continue
|
||
}
|
||
b64Data := entry.AvatarBase64[commaIdx+1:]
|
||
imgData, err := base64.StdEncoding.DecodeString(b64Data)
|
||
if err != nil {
|
||
logging.Warn("Failed to decode legacy avatar for user %d: %v", entry.ID, err)
|
||
continue
|
||
}
|
||
|
||
avatarPath := filepath.Join(avatarsDir, fmt.Sprintf("%d", entry.ID))
|
||
if err := os.WriteFile(avatarPath, imgData, 0600); err != nil {
|
||
logging.Warn("Failed to write avatar file for user %d: %v", entry.ID, err)
|
||
continue
|
||
}
|
||
|
||
if err := h.auth.UpdateAvatar(entry.ID, "file"); err != nil {
|
||
logging.Warn("Failed to update avatar marker for user %d: %v", entry.ID, err)
|
||
continue
|
||
}
|
||
migrated++
|
||
}
|
||
if migrated > 0 {
|
||
logging.Info("Migrated %d avatar(s) from base64 to file storage", migrated)
|
||
}
|
||
}
|
||
|
||
// --- Access Assignments Handlers ---
|
||
|
||
func (h *Handler) handleAssignments(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
assignments, _ := h.servers.GetAllAssignments()
|
||
|
||
// Decrypt initial passwords for display
|
||
for i := range assignments {
|
||
if assignments[i].InitialPassword != "" {
|
||
if decrypted, err := h.keys.DecryptValue(assignments[i].InitialPassword); err == nil {
|
||
assignments[i].InitialPassword = decrypted
|
||
}
|
||
}
|
||
}
|
||
|
||
data := &PageData{
|
||
Title: "Access Assignments",
|
||
Active: "assignments",
|
||
User: user,
|
||
Assignments: assignments,
|
||
}
|
||
h.templates["assignments"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
func (h *Handler) handleAssignmentsAdd(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
if r.Method == http.MethodGet {
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
allKeys, _ := h.keys.GetAllKeys()
|
||
allServers, _ := h.servers.GetAllServers()
|
||
allGroups, _ := h.servers.GetAllGroups()
|
||
|
||
data := &PageData{
|
||
Title: "Create Assignment",
|
||
Active: "assignments",
|
||
User: user,
|
||
AssignAllUsers: allUsers,
|
||
AssignAllKeys: allKeys,
|
||
AssignAllHosts: allServers,
|
||
AssignAllGroups: allGroups,
|
||
}
|
||
h.templates["assignments_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// POST: create assignment
|
||
targetUserID, _ := strconv.ParseInt(r.FormValue("user_id"), 10, 64)
|
||
sshKeyID, _ := strconv.ParseInt(r.FormValue("ssh_key_id"), 10, 64)
|
||
targetType := r.FormValue("target_type") // "host" or "group"
|
||
var serverID, groupID int64
|
||
if targetType == "host" {
|
||
serverID, _ = strconv.ParseInt(r.FormValue("server_id"), 10, 64)
|
||
} else {
|
||
groupID, _ = strconv.ParseInt(r.FormValue("group_id"), 10, 64)
|
||
}
|
||
systemUser := r.FormValue("system_user")
|
||
desiredState := r.FormValue("desired_state")
|
||
sudo := r.FormValue("sudo") == "on"
|
||
createUser := r.FormValue("create_user") == "on"
|
||
|
||
newAssignment, err := h.servers.CreateAssignment(targetUserID, sshKeyID, serverID, groupID, systemUser, desiredState, sudo, createUser)
|
||
if err != nil {
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
allKeys, _ := h.keys.GetAllKeys()
|
||
allServers, _ := h.servers.GetAllServers()
|
||
allGroups, _ := h.servers.GetAllGroups()
|
||
data := &PageData{
|
||
Title: "Create Assignment",
|
||
Active: "assignments",
|
||
User: user,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to create assignment: " + err.Error()},
|
||
AssignAllUsers: allUsers,
|
||
AssignAllKeys: allKeys,
|
||
AssignAllHosts: allServers,
|
||
AssignAllGroups: allGroups,
|
||
}
|
||
h.templates["assignments_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
targetUser, _ := h.auth.GetUserByID(targetUserID)
|
||
targetName := "unknown"
|
||
if targetUser != nil {
|
||
targetName = targetUser.Username
|
||
}
|
||
h.audit.Log(userID, audit.ActionAssignmentCreated, fmt.Sprintf("Created access assignment for user %s (key ID %d)", targetName, sshKeyID), clientIP(r))
|
||
|
||
// Auto-sync: deploy the key immediately after creating the assignment
|
||
h.syncAssignment(w, r, newAssignment.ID, userID)
|
||
|
||
http.Redirect(w, r, "/assignments", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleAssignmentAction(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||
if len(parts) < 3 {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
assignID, err := strconv.ParseInt(parts[1], 10, 64)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
action := parts[2]
|
||
|
||
switch action {
|
||
case "edit":
|
||
assignment, err := h.servers.GetAssignmentByID(assignID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/assignments", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
if r.Method == http.MethodGet {
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
allKeys, _ := h.keys.GetAllKeys()
|
||
allServers, _ := h.servers.GetAllServers()
|
||
allGroups, _ := h.servers.GetAllGroups()
|
||
|
||
data := &PageData{
|
||
Title: "Edit Assignment",
|
||
Active: "assignments",
|
||
User: user,
|
||
Assignment: assignment,
|
||
AssignAllUsers: allUsers,
|
||
AssignAllKeys: allKeys,
|
||
AssignAllHosts: allServers,
|
||
AssignAllGroups: allGroups,
|
||
}
|
||
h.templates["assignments_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// POST: update assignment
|
||
targetUserID, _ := strconv.ParseInt(r.FormValue("user_id"), 10, 64)
|
||
sshKeyID, _ := strconv.ParseInt(r.FormValue("ssh_key_id"), 10, 64)
|
||
targetType := r.FormValue("target_type")
|
||
var serverID, groupID int64
|
||
if targetType == "host" {
|
||
serverID, _ = strconv.ParseInt(r.FormValue("server_id"), 10, 64)
|
||
} else {
|
||
groupID, _ = strconv.ParseInt(r.FormValue("group_id"), 10, 64)
|
||
}
|
||
systemUser := r.FormValue("system_user")
|
||
desiredState := r.FormValue("desired_state")
|
||
sudo := r.FormValue("sudo") == "on"
|
||
createUser := r.FormValue("create_user") == "on"
|
||
|
||
if err := h.servers.UpdateAssignment(assignID, targetUserID, sshKeyID, serverID, groupID, systemUser, desiredState, sudo, createUser); err != nil {
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
allKeys, _ := h.keys.GetAllKeys()
|
||
allServers, _ := h.servers.GetAllServers()
|
||
allGroups, _ := h.servers.GetAllGroups()
|
||
data := &PageData{
|
||
Title: "Edit Assignment",
|
||
Active: "assignments",
|
||
User: user,
|
||
Assignment: assignment,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to update assignment: " + err.Error()},
|
||
AssignAllUsers: allUsers,
|
||
AssignAllKeys: allKeys,
|
||
AssignAllHosts: allServers,
|
||
AssignAllGroups: allGroups,
|
||
}
|
||
h.templates["assignments_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
h.audit.Log(userID, audit.ActionAssignmentUpdated, fmt.Sprintf("Updated access assignment ID %d", assignID), clientIP(r))
|
||
|
||
// Auto-sync: re-deploy after updating the assignment
|
||
h.syncAssignment(w, r, assignID, userID)
|
||
|
||
http.Redirect(w, r, "/assignments", http.StatusSeeOther)
|
||
|
||
case "delete":
|
||
if r.Method == http.MethodPost {
|
||
// Fetch assignment details before deleting
|
||
assignment, err := h.servers.GetAssignmentByID(assignID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/assignments", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
deleteUser := r.FormValue("delete_user") == "on"
|
||
|
||
// Resolve SSH key and target servers for cleanup
|
||
key, keyErr := h.keys.GetKeyByIDGlobal(assignment.SSHKeyID)
|
||
masterKeyPEM, masterKeyErr := h.keys.GetSystemMasterKeyPrivate()
|
||
|
||
if keyErr == nil && masterKeyErr == nil {
|
||
// Collect target servers
|
||
var targets []models.Server
|
||
if assignment.ServerID > 0 {
|
||
srv, err := h.servers.GetByIDGlobal(assignment.ServerID)
|
||
if err == nil {
|
||
targets = append(targets, *srv)
|
||
}
|
||
} else if assignment.GroupID > 0 {
|
||
members, err := h.servers.GetGroupMembersGlobal(assignment.GroupID)
|
||
if err == nil {
|
||
targets = members
|
||
}
|
||
}
|
||
|
||
for _, server := range targets {
|
||
srv := server
|
||
|
||
if deleteUser && assignment.SystemUser != "" && assignment.SystemUser != "root" {
|
||
// Delete system user (includes key removal, sudo removal)
|
||
if err := h.deploy.RemoveSystemUser(key, &srv, masterKeyPEM, assignment.SystemUser); err != nil {
|
||
logging.Warn("Assignment %d cleanup: failed to delete user '%s' on %s: %v", assignID, assignment.SystemUser, srv.Hostname, err)
|
||
h.audit.Log(userID, audit.ActionAssignmentCleanFailed,
|
||
fmt.Sprintf("Failed to delete system user '%s' on %s: %v", assignment.SystemUser, srv.Hostname, err), clientIP(r))
|
||
} else {
|
||
logging.Info("Assignment %d cleanup: deleted system user '%s' on %s", assignID, assignment.SystemUser, srv.Hostname)
|
||
h.audit.Log(userID, audit.ActionAssignmentUserDeleted,
|
||
fmt.Sprintf("Deleted system user '%s' on server %s (assignment %d)", assignment.SystemUser, srv.Hostname, assignID), clientIP(r))
|
||
}
|
||
} else {
|
||
// Only remove the SSH key from the server
|
||
if err := h.deploy.RemoveKeyFromUser(key, &srv, masterKeyPEM, assignment.SystemUser); err != nil {
|
||
logging.Warn("Assignment %d cleanup: failed to remove key from '%s' on %s: %v", assignID, assignment.SystemUser, srv.Hostname, err)
|
||
h.audit.Log(userID, audit.ActionAssignmentCleanFailed,
|
||
fmt.Sprintf("Failed to remove key from '%s' on %s: %v", assignment.SystemUser, srv.Hostname, err), clientIP(r))
|
||
} else {
|
||
logging.Info("Assignment %d cleanup: removed key from '%s' on %s", assignID, assignment.SystemUser, srv.Hostname)
|
||
h.audit.Log(userID, audit.ActionAssignmentKeyRemoved,
|
||
fmt.Sprintf("Removed SSH key from '%s' on server %s (assignment %d)", assignment.SystemUser, srv.Hostname, assignID), clientIP(r))
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
logging.Warn("Assignment %d cleanup: could not load key or master key, skipping server cleanup (keyErr=%v, masterKeyErr=%v)", assignID, keyErr, masterKeyErr)
|
||
h.audit.Log(userID, audit.ActionAssignmentCleanFailed,
|
||
fmt.Sprintf("Assignment %d cleanup skipped: key or master key unavailable", assignID), clientIP(r))
|
||
}
|
||
|
||
h.audit.Log(userID, audit.ActionAssignmentDeleted, fmt.Sprintf("Deleted access assignment ID %d (delete_user=%v)", assignID, deleteUser), clientIP(r))
|
||
h.servers.DeleteAssignment(assignID)
|
||
}
|
||
http.Redirect(w, r, "/assignments", http.StatusSeeOther)
|
||
|
||
case "sync":
|
||
// Sync an assignment: deploy or remove the key based on desired_state
|
||
if r.Method == http.MethodPost {
|
||
h.syncAssignment(w, r, assignID, userID)
|
||
}
|
||
http.Redirect(w, r, "/assignments", http.StatusSeeOther)
|
||
|
||
default:
|
||
http.NotFound(w, r)
|
||
}
|
||
}
|
||
|
||
// syncAssignment executes the deployment/removal for a single access assignment
|
||
func (h *Handler) syncAssignment(w http.ResponseWriter, r *http.Request, assignID, actingUserID int64) {
|
||
assignment, err := h.servers.GetAssignmentByID(assignID)
|
||
if err != nil {
|
||
logging.Error("Sync assignment %d: not found: %v", assignID, err)
|
||
return
|
||
}
|
||
|
||
// Resolve the SSH key (need private key of the key owner for auth OR use the key's public key)
|
||
// For deployment we need: the public key to deploy and an auth method to connect
|
||
key, err := h.keys.GetKeyByIDGlobal(assignment.SSHKeyID)
|
||
if err != nil {
|
||
h.servers.UpdateAssignmentStatus(assignID, "failed", "SSH key not found")
|
||
h.audit.Log(actingUserID, audit.ActionAssignmentSyncFailed, fmt.Sprintf("Assignment %d sync failed: SSH key not found", assignID), clientIP(r))
|
||
logging.Error("Sync assignment %d: key %d not found: %v", assignID, assignment.SSHKeyID, err)
|
||
return
|
||
}
|
||
|
||
// Collect target servers
|
||
var targets []models.Server
|
||
if assignment.ServerID > 0 {
|
||
srv, err := h.servers.GetByIDGlobal(assignment.ServerID)
|
||
if err != nil {
|
||
h.servers.UpdateAssignmentStatus(assignID, "failed", "Target host not found")
|
||
h.audit.Log(actingUserID, audit.ActionAssignmentSyncFailed, fmt.Sprintf("Assignment %d sync failed: host not found", assignID), clientIP(r))
|
||
return
|
||
}
|
||
targets = append(targets, *srv)
|
||
} else if assignment.GroupID > 0 {
|
||
members, err := h.servers.GetGroupMembersGlobal(assignment.GroupID)
|
||
if err != nil || len(members) == 0 {
|
||
h.servers.UpdateAssignmentStatus(assignID, "failed", "Group has no members or not found")
|
||
h.audit.Log(actingUserID, audit.ActionAssignmentSyncFailed, fmt.Sprintf("Assignment %d sync failed: group empty or not found", assignID), clientIP(r))
|
||
return
|
||
}
|
||
targets = members
|
||
} else {
|
||
h.servers.UpdateAssignmentStatus(assignID, "failed", "No target defined")
|
||
return
|
||
}
|
||
|
||
// For each target, use the system master key for SSH authentication
|
||
masterKeyPEM, masterKeyErr := h.keys.GetSystemMasterKeyPrivate()
|
||
if masterKeyErr != nil {
|
||
h.servers.UpdateAssignmentStatus(assignID, "failed", "System master key not available")
|
||
h.audit.Log(actingUserID, audit.ActionAssignmentSyncFailed, fmt.Sprintf("Assignment %d sync failed: system master key not available", assignID), clientIP(r))
|
||
logging.Error("Sync assignment %d: system master key not available: %v", assignID, masterKeyErr)
|
||
return
|
||
}
|
||
|
||
var successCount, failCount int
|
||
|
||
// Generate initial password if createUser is enabled and no password is stored yet
|
||
var initialPassword string
|
||
if assignment.CreateUser && assignment.InitialPassword == "" {
|
||
initialPassword = generateInitialPassword(10)
|
||
}
|
||
|
||
for _, server := range targets {
|
||
srv := server
|
||
|
||
var deployErr error
|
||
|
||
if assignment.DesiredState == "present" {
|
||
// Deploy key – connect as server admin user with system master key, deploy to systemUser
|
||
deployErr = h.deploy.DeployKeyToUser(key, &srv, masterKeyPEM, assignment.SystemUser, assignment.CreateUser, assignment.Sudo, initialPassword)
|
||
} else {
|
||
// Remove key (desired_state == "absent")
|
||
deployErr = h.deploy.RemoveKeyFromUser(key, &srv, masterKeyPEM, assignment.SystemUser)
|
||
}
|
||
|
||
if deployErr != nil {
|
||
failCount++
|
||
logging.Warn("Sync assignment %d to %s@%s:%d failed: %v", assignID, assignment.SystemUser, srv.Hostname, srv.Port, deployErr)
|
||
} else {
|
||
successCount++
|
||
logging.Info("Sync assignment %d to %s@%s:%d successful", assignID, assignment.SystemUser, srv.Hostname, srv.Port)
|
||
}
|
||
}
|
||
|
||
// Store initial password (encrypted) if it was generated
|
||
if initialPassword != "" && successCount > 0 {
|
||
if encPW, err := h.keys.EncryptValue(initialPassword); err == nil {
|
||
h.servers.UpdateAssignmentInitialPassword(assignID, encPW)
|
||
} else {
|
||
logging.Warn("Failed to encrypt initial password for assignment %d: %v", assignID, err)
|
||
}
|
||
}
|
||
|
||
if failCount == 0 {
|
||
h.servers.UpdateAssignmentStatus(assignID, "synced", "")
|
||
h.audit.Log(actingUserID, audit.ActionAssignmentSynced, fmt.Sprintf("Assignment %d synced: %d/%d targets", assignID, successCount, len(targets)), clientIP(r))
|
||
} else {
|
||
h.servers.UpdateAssignmentStatus(assignID, "failed", fmt.Sprintf("%d/%d targets failed", failCount, len(targets)))
|
||
h.audit.Log(actingUserID, audit.ActionAssignmentSyncFailed, fmt.Sprintf("Assignment %d sync: %d success, %d failed of %d targets", assignID, successCount, failCount, len(targets)), clientIP(r))
|
||
}
|
||
}
|
||
|
||
// handleMyAssignments shows the current user's own access assignments (for User role)
|
||
func (h *Handler) handleMyAssignments(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
assignments, _ := h.servers.GetAssignmentsByUser(userID)
|
||
assignedHosts, _ := h.servers.GetServersByAssignedUser(userID)
|
||
|
||
// Decrypt initial passwords for display
|
||
for i := range assignments {
|
||
if assignments[i].InitialPassword != "" {
|
||
if decrypted, err := h.keys.DecryptValue(assignments[i].InitialPassword); err == nil {
|
||
assignments[i].InitialPassword = decrypted
|
||
}
|
||
}
|
||
}
|
||
|
||
data := &PageData{
|
||
Title: "My Access",
|
||
Active: "my_access",
|
||
User: user,
|
||
Assignments: assignments,
|
||
Servers: assignedHosts,
|
||
}
|
||
h.templates["assignments"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
// --- Admin Settings Handler ---
|
||
|
||
// --- System Information Handler ---
|
||
|
||
func (h *Handler) handleSystemInfo(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
var memStats runtime.MemStats
|
||
runtime.ReadMemStats(&memStats)
|
||
|
||
// Detect runtime environment
|
||
runtimeEnv := "Native"
|
||
if _, err := os.Stat("/.dockerenv"); err == nil {
|
||
runtimeEnv = "Docker"
|
||
}
|
||
|
||
hostname, _ := os.Hostname()
|
||
uptimeStr := formatUptime(startTime)
|
||
|
||
sysInfo := &SystemInfo{
|
||
GoVersion: runtime.Version(),
|
||
OS: runtime.GOOS,
|
||
Arch: runtime.GOARCH,
|
||
NumCPU: runtime.NumCPU(),
|
||
NumGoroutine: runtime.NumGoroutine(),
|
||
MemAlloc: formatBytes(memStats.Alloc),
|
||
MemSys: formatBytes(memStats.Sys),
|
||
Runtime: runtimeEnv,
|
||
Hostname: hostname,
|
||
Uptime: uptimeStr,
|
||
Timezone: time.Local.String(),
|
||
}
|
||
|
||
data := &PageData{
|
||
Title: "System Information",
|
||
Active: "system_info",
|
||
User: user,
|
||
SystemInfo: sysInfo,
|
||
}
|
||
h.templates["system_info"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
// handleLoginBrandingUpload handles background image upload for the login page
|
||
func (h *Handler) handleLoginBrandingUpload(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
|
||
// Limit upload to 5MB
|
||
r.Body = http.MaxBytesReader(w, r.Body, 5<<20)
|
||
if err := r.ParseMultipartForm(5 << 20); err != nil {
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("File too large. Maximum size is 5 MB."), http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
file, header, err := r.FormFile("login_bg")
|
||
if err != nil {
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("No file selected."), http.StatusSeeOther)
|
||
return
|
||
}
|
||
defer file.Close()
|
||
|
||
// Validate content type
|
||
ct := header.Header.Get("Content-Type")
|
||
if ct != "image/png" && ct != "image/jpeg" && ct != "image/webp" {
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("Invalid file type. Only PNG, JPEG and WebP are allowed."), http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
data, err := io.ReadAll(file)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("Failed to read uploaded file."), http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
bgPath := filepath.Join(h.dataDir, "branding", "login_bg")
|
||
if err := os.WriteFile(bgPath, data, 0600); err != nil {
|
||
logging.Warn("Failed to save login background image: %v", err)
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("Failed to save background image."), http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Auto-detect brightness and set text color accordingly
|
||
textColor := analyzeImageBrightness(data)
|
||
if err := h.auth.SetSetting("login_text_color", textColor); err != nil {
|
||
logging.Warn("Failed to save auto-detected text color: %v", err)
|
||
}
|
||
logging.Info("Login background uploaded: auto-detected text color = %s", textColor)
|
||
|
||
h.audit.Log(userID, audit.ActionBrandingChanged, fmt.Sprintf("Login background image uploaded (auto text color: %s)", textColor), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=success&flash_msg="+url.QueryEscape("Background image uploaded successfully."), http.StatusSeeOther)
|
||
}
|
||
|
||
// handleLoginBrandingRemoveBg removes the login page background image
|
||
func (h *Handler) handleLoginBrandingRemoveBg(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
bgPath := filepath.Join(h.dataDir, "branding", "login_bg")
|
||
os.Remove(bgPath)
|
||
// Reset auto-detected text color
|
||
_ = h.auth.SetSetting("login_text_color", "light")
|
||
|
||
h.audit.Log(userID, audit.ActionBrandingChanged, "Login background image removed", clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=success&flash_msg="+url.QueryEscape("Background image removed."), http.StatusSeeOther)
|
||
}
|
||
|
||
// handleLoginBgServe serves the login page background image (public, no auth required)
|
||
func (h *Handler) handleLoginBgServe(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
bgPath := filepath.Join(h.dataDir, "branding", "login_bg")
|
||
data, err := os.ReadFile(bgPath)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
contentType := http.DetectContentType(data)
|
||
w.Header().Set("Content-Type", contentType)
|
||
w.Header().Set("Cache-Control", "public, max-age=3600")
|
||
w.Write(data)
|
||
}
|
||
|
||
// analyzeImageBrightness decodes an image and computes its average perceived
// brightness using the ITU-R BT.709 luminance weights. For performance it
// samples every Nth pixel in both directions, aiming for about 10 000
// samples. It returns "light" if the image is dark (light text needed) or
// "dark" if the image is bright (dark text needed); the cut-over is at 40 %
// of full brightness.
//
// "light" is also returned when the data cannot be decoded by a registered
// image format, when the image contains no pixels, and when the image header
// declares more than maxPixels pixels. The last check runs before the pixel
// data is decoded, so a small file that announces enormous dimensions cannot
// force a huge allocation.
func analyzeImageBrightness(data []byte) string {
	const (
		maxPixels     = 40_000_000 // upper bound on width*height that will be decoded
		targetSamples = 10000      // approximate number of pixels to sample
		darkBelow     = 26214      // 40 % of 65535, the maximum channel value
	)

	// Read only the header first to learn the dimensions.
	cfg, _, err := image.DecodeConfig(bytes.NewReader(data))
	if err != nil || cfg.Width <= 0 || cfg.Height <= 0 {
		return "light"
	}
	if int64(cfg.Width)*int64(cfg.Height) > maxPixels {
		return "light"
	}

	img, _, err := image.Decode(bytes.NewReader(data))
	if err != nil {
		// Cannot decode → assume dark image, use light text.
		return "light"
	}

	bounds := img.Bounds()
	totalPixels := bounds.Dx() * bounds.Dy()

	// Sampling stride: 1 for small images, otherwise chosen so that about
	// targetSamples pixels are visited.
	step := 1
	if totalPixels > targetSamples {
		step = int(math.Sqrt(float64(totalPixels) / targetSamples))
		if step < 1 {
			step = 1
		}
	}

	var sum float64
	var count int
	for y := bounds.Min.Y; y < bounds.Max.Y; y += step {
		for x := bounds.Min.X; x < bounds.Max.X; x += step {
			r, g, b, _ := img.At(x, y).RGBA()
			// ITU-R BT.709 perceived luminance (channel values are 0–65535).
			sum += 0.2126*float64(r) + 0.7152*float64(g) + 0.0722*float64(b)
			count++
		}
	}

	if count == 0 {
		return "light"
	}

	if sum/float64(count) < darkBelow {
		return "light" // dark image → use light/white text
	}
	return "dark" // bright image → use dark text
}
|
||
|
||
func (h *Handler) handleAdminSettings(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
settings, _ := h.auth.GetAllSettings()
|
||
|
||
// Get system master key info
|
||
masterPub, _ := h.keys.GetSystemMasterKeyPublic()
|
||
masterFP, _ := h.keys.GetSystemMasterKeyFingerprint()
|
||
|
||
if r.Method == http.MethodGet {
|
||
adminUsers := h.buildAdminUserList()
|
||
data := &PageData{
|
||
Title: "Admin Settings",
|
||
Active: "admin_settings",
|
||
User: user,
|
||
Settings: settings,
|
||
AdminUsers: adminUsers,
|
||
EmailEnabled: h.mail.IsEnabled(),
|
||
MasterKeyPublic: masterPub,
|
||
MasterKeyFingerprint: masterFP,
|
||
EnforcementStatus: h.worker.GetStatus(),
|
||
}
|
||
|
||
// Check for flash message from query parameters (e.g. after backup restore)
|
||
if flashType := r.URL.Query().Get("flash_type"); flashType != "" {
|
||
if flashMsg := r.URL.Query().Get("flash_msg"); flashMsg != "" {
|
||
data.Flash = &Flash{Type: flashType, Message: flashMsg}
|
||
}
|
||
}
|
||
|
||
h.templates["admin_settings"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// POST: save application settings
|
||
r.ParseForm()
|
||
formType := r.FormValue("form_type")
|
||
var changed []string
|
||
|
||
logging.Info("Admin settings POST: form_type=%s from user_id=%d", formType, userID)
|
||
|
||
switch formType {
|
||
case "security_settings":
|
||
// Collect all settings to save
|
||
batch := make(map[string]string)
|
||
|
||
// Number settings
|
||
for _, key := range []string{"pw_min_length", "lockout_attempts", "lockout_duration"} {
|
||
val := r.FormValue(key)
|
||
if val != "" {
|
||
batch[key] = val
|
||
changed = append(changed, key+"="+val)
|
||
}
|
||
}
|
||
// Boolean settings (checkbox: present = true, absent = false)
|
||
for _, key := range []string{"pw_require_upper", "pw_require_lower", "pw_require_digit", "pw_require_special", "mfa_required"} {
|
||
if _, ok := r.PostForm[key]; ok {
|
||
batch[key] = "true"
|
||
changed = append(changed, key+"=true")
|
||
} else {
|
||
batch[key] = "false"
|
||
changed = append(changed, key+"=false")
|
||
}
|
||
}
|
||
logging.Info("Saving security settings: %v", batch)
|
||
// Save all settings in a single transaction
|
||
if err := h.auth.SetSettingsBatch(batch); err != nil {
|
||
logging.Error("Failed to save security settings: %v", err)
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("Failed to save security settings: "+err.Error()), http.StatusSeeOther)
|
||
return
|
||
}
|
||
logging.Info("Security settings saved successfully")
|
||
if len(changed) > 0 {
|
||
h.audit.Log(userID, audit.ActionPasswordPolicyChanged, fmt.Sprintf("Security settings updated: %s", strings.Join(changed, ", ")), clientIP(r))
|
||
}
|
||
case "enforcement_settings":
|
||
// Key enforcement settings
|
||
batch := make(map[string]string)
|
||
enforceMode := r.FormValue("enforce_mode")
|
||
if enforceMode == "" {
|
||
enforceMode = "disabled"
|
||
}
|
||
batch["enforce_mode"] = enforceMode
|
||
changed = append(changed, "enforce_mode="+enforceMode)
|
||
|
||
enforceInterval := r.FormValue("enforce_interval")
|
||
if enforceInterval == "" {
|
||
enforceInterval = "15"
|
||
}
|
||
batch["enforce_interval"] = enforceInterval
|
||
changed = append(changed, "enforce_interval="+enforceInterval)
|
||
|
||
if err := h.auth.SetSettingsBatch(batch); err != nil {
|
||
logging.Error("Failed to save enforcement settings: %v", err)
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("Failed to save enforcement settings: "+err.Error()), http.StatusSeeOther)
|
||
return
|
||
}
|
||
if len(changed) > 0 {
|
||
h.audit.Log(userID, audit.ActionEnforcementSettings, fmt.Sprintf("Enforcement settings updated: %s", strings.Join(changed, ", ")), clientIP(r))
|
||
}
|
||
default:
|
||
// Application settings (existing behavior)
|
||
batch := make(map[string]string)
|
||
for _, key := range []string{"app_name", "default_key_type", "default_key_bits", "session_timeout"} {
|
||
val := r.FormValue(key)
|
||
if val != "" || key == "app_name" {
|
||
batch[key] = val
|
||
changed = append(changed, key+"="+val)
|
||
}
|
||
}
|
||
if err := h.auth.SetSettingsBatch(batch); err != nil {
|
||
logging.Error("Failed to save application settings: %v", err)
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("Failed to save application settings: "+err.Error()), http.StatusSeeOther)
|
||
return
|
||
}
|
||
if len(changed) > 0 {
|
||
h.audit.Log(userID, audit.ActionSettingsChanged, fmt.Sprintf("Changed settings: %s", strings.Join(changed, ", ")), clientIP(r))
|
||
}
|
||
}
|
||
|
||
http.Redirect(w, r, "/admin/settings?flash_type=success&flash_msg="+url.QueryEscape("Settings saved successfully."), http.StatusSeeOther)
|
||
}
|
||
|
||
// buildAdminUserList creates the user list for admin settings
|
||
func (h *Handler) buildAdminUserList() []AdminUserInfo {
|
||
users, _ := h.auth.GetAllUsers()
|
||
var result []AdminUserInfo
|
||
for _, u := range users {
|
||
info := AdminUserInfo{
|
||
ID: u.ID,
|
||
Username: u.Username,
|
||
Role: u.Role,
|
||
}
|
||
result = append(result, info)
|
||
}
|
||
return result
|
||
}
|
||
|
||
// handleMasterKeyRegenerate regenerates the system master key (owner only, requires password confirmation)
|
||
func (h *Handler) handleMasterKeyRegenerate(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
// Require password confirmation
|
||
password := r.FormValue("confirm_password")
|
||
if password == "" {
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("Master key regeneration failed: no password provided"), http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Verify the owner's password
|
||
_, err := h.auth.Login(user.Username, password)
|
||
if err != nil {
|
||
logging.Warn("Master key regeneration failed: invalid password for user %s", user.Username)
|
||
h.audit.Log(userID, audit.ActionMasterKeyRegenFailed, "Master key regeneration failed: wrong password", clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("Master key regeneration failed: wrong password"), http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Regenerate the system master key
|
||
newPub, err := h.keys.RegenerateSystemMasterKey()
|
||
if err != nil {
|
||
logging.Error("Failed to regenerate system master key: %v", err)
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg="+url.QueryEscape("Master key regeneration failed: "+err.Error()), http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
logging.Info("System master key regenerated by user %s", user.Username)
|
||
h.audit.Log(userID, audit.ActionMasterKeyRegenerated, fmt.Sprintf("System master key regenerated. New public key: %s", newPub[:40]+"..."), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=success&flash_msg="+url.QueryEscape("System master key successfully regenerated."), http.StatusSeeOther)
|
||
}
|
||
|
||
// handleEnforcementRunNow triggers an immediate key enforcement run (owner only)
|
||
func (h *Handler) handleEnforcementRunNow(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
logging.Info("Key enforcement manual run triggered by user_id=%d", userID)
|
||
h.audit.Log(userID, audit.ActionEnforcementRun, "Manual key enforcement run triggered", clientIP(r))
|
||
|
||
h.worker.RunNow()
|
||
|
||
http.Redirect(w, r, "/admin/settings?flash_type=success&flash_msg="+url.QueryEscape("Key enforcement run started. Check the audit log for results."), http.StatusSeeOther)
|
||
}
|
||
|
||
// --- Cron Job Handlers ---
|
||
|
||
// handleAPICronAssignments returns assignments for a given user as JSON (for AJAX).
|
||
// GET /api/cron/keys?user_id=X — returns SSH keys for a user
|
||
func (h *Handler) handleAPICronKeys(w http.ResponseWriter, r *http.Request) {
|
||
userIDStr := r.URL.Query().Get("user_id")
|
||
targetUserID, err := strconv.ParseInt(userIDStr, 10, 64)
|
||
if err != nil || targetUserID <= 0 {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.Write([]byte("[]"))
|
||
return
|
||
}
|
||
|
||
userKeys, err := h.keys.GetKeysByUser(targetUserID)
|
||
if err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.Write([]byte("[]"))
|
||
return
|
||
}
|
||
|
||
type keyJSON struct {
|
||
ID int64 `json:"id"`
|
||
Name string `json:"name"`
|
||
KeyType string `json:"key_type"`
|
||
Fingerprint string `json:"fingerprint"`
|
||
}
|
||
|
||
var result []keyJSON
|
||
for _, k := range userKeys {
|
||
result = append(result, keyJSON{
|
||
ID: k.ID,
|
||
Name: k.Name,
|
||
KeyType: k.KeyType,
|
||
Fingerprint: k.Fingerprint,
|
||
})
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
func (h *Handler) handleCron(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
jobs, _ := h.cron.GetByUser(userID)
|
||
|
||
// Decrypt initial passwords for display
|
||
for i := range jobs {
|
||
if jobs[i].InitialPassword != "" {
|
||
if decrypted, err := h.keys.DecryptValue(jobs[i].InitialPassword); err == nil {
|
||
jobs[i].InitialPassword = decrypted
|
||
}
|
||
}
|
||
}
|
||
|
||
data := &PageData{
|
||
Title: "Temporary Access",
|
||
Active: "cron",
|
||
User: user,
|
||
CronJobs: jobs,
|
||
}
|
||
h.templates["cron"].ExecuteTemplate(w, "base", data)
|
||
}
|
||
|
||
func (h *Handler) handleCronAdd(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
serverList, _ := h.servers.GetAllServers()
|
||
groups, _ := h.servers.GetAllGroups()
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "New Temporary Access",
|
||
Active: "cron",
|
||
User: user,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
DaysOfMonth: daysOfMonth(),
|
||
AssignAllUsers: allUsers,
|
||
}
|
||
h.templates["cron_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// POST: create temporary access job
|
||
name := r.FormValue("name")
|
||
targetUserID, _ := strconv.ParseInt(r.FormValue("target_user_id"), 10, 64)
|
||
keyID, _ := strconv.ParseInt(r.FormValue("key_id"), 10, 64)
|
||
|
||
targetType := r.FormValue("target_type")
|
||
serverID, _ := strconv.ParseInt(r.FormValue("server_id"), 10, 64)
|
||
groupID, _ := strconv.ParseInt(r.FormValue("group_id"), 10, 64)
|
||
if targetType == "group" {
|
||
serverID = 0
|
||
} else {
|
||
groupID = 0
|
||
}
|
||
|
||
systemUser := r.FormValue("system_user")
|
||
sudo := r.FormValue("sudo") == "on"
|
||
createUser := r.FormValue("create_user") == "on"
|
||
initialPassword := r.FormValue("initial_password")
|
||
expiryAction := r.FormValue("expiry_action")
|
||
|
||
schedule := r.FormValue("schedule")
|
||
scheduledAtStr := r.FormValue("scheduled_at")
|
||
removeAfterMin, _ := strconv.Atoi(r.FormValue("remove_after_min"))
|
||
tz := r.FormValue("timezone")
|
||
timeOfDay := r.FormValue("time_of_day")
|
||
dayOfWeek, _ := strconv.Atoi(r.FormValue("day_of_week"))
|
||
dayOfMonth, _ := strconv.Atoi(r.FormValue("day_of_month"))
|
||
minuteOfHour, _ := strconv.Atoi(r.FormValue("minute_of_hour"))
|
||
|
||
if tz == "" {
|
||
tz = "UTC"
|
||
}
|
||
if timeOfDay == "" {
|
||
timeOfDay = "00:00"
|
||
}
|
||
|
||
// Parse scheduled_at in the user's timezone for "once" schedule
|
||
var scheduledAt time.Time
|
||
if schedule == "once" && scheduledAtStr != "" {
|
||
loc, err := time.LoadLocation(tz)
|
||
if err != nil {
|
||
loc = time.UTC
|
||
}
|
||
scheduledAt, err = time.ParseInLocation("2006-01-02T15:04", scheduledAtStr, loc)
|
||
if err != nil {
|
||
data := &PageData{
|
||
Title: "New Temporary Access",
|
||
Active: "cron",
|
||
User: user,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
DaysOfMonth: daysOfMonth(),
|
||
AssignAllUsers: allUsers,
|
||
Flash: &Flash{Type: "danger", Message: "Invalid date format."},
|
||
}
|
||
h.templates["cron_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
} else {
|
||
scheduledAt = time.Now().UTC()
|
||
}
|
||
|
||
job, err := h.cron.Create(userID, name, keyID, serverID, groupID, schedule, scheduledAt, removeAfterMin, tz, timeOfDay, dayOfWeek, dayOfMonth, minuteOfHour, targetUserID, systemUser, sudo, createUser, initialPassword, expiryAction)
|
||
if err != nil {
|
||
data := &PageData{
|
||
Title: "New Temporary Access",
|
||
Active: "cron",
|
||
User: user,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
DaysOfMonth: daysOfMonth(),
|
||
AssignAllUsers: allUsers,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to create job: " + err.Error()},
|
||
}
|
||
h.templates["cron_add"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
h.audit.Log(userID, audit.ActionCronJobCreated, fmt.Sprintf("Created temporary access: %s (ID %d, schedule: %s)", name, job.ID, schedule), clientIP(r))
|
||
http.Redirect(w, r, "/cron", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleCronAction(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||
if len(parts) < 3 {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
jobID, err := strconv.ParseInt(parts[1], 10, 64)
|
||
if err != nil {
|
||
http.NotFound(w, r)
|
||
return
|
||
}
|
||
|
||
action := parts[2]
|
||
|
||
switch action {
|
||
case "edit":
|
||
job, err := h.cron.GetByID(jobID, userID)
|
||
if err != nil {
|
||
http.Redirect(w, r, "/cron", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
serverList, _ := h.servers.GetAllServers()
|
||
groups, _ := h.servers.GetAllGroups()
|
||
allUsers, _ := h.auth.GetAllUsers()
|
||
|
||
if r.Method == http.MethodGet {
|
||
data := &PageData{
|
||
Title: "Edit Temporary Access",
|
||
Active: "cron",
|
||
User: user,
|
||
CronJob: job,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
DaysOfMonth: daysOfMonth(),
|
||
AssignAllUsers: allUsers,
|
||
}
|
||
h.templates["cron_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// POST: update
|
||
name := r.FormValue("name")
|
||
targetUserID, _ := strconv.ParseInt(r.FormValue("target_user_id"), 10, 64)
|
||
keyID, _ := strconv.ParseInt(r.FormValue("key_id"), 10, 64)
|
||
|
||
targetType := r.FormValue("target_type")
|
||
serverID, _ := strconv.ParseInt(r.FormValue("server_id"), 10, 64)
|
||
groupID, _ := strconv.ParseInt(r.FormValue("group_id"), 10, 64)
|
||
if targetType == "group" {
|
||
serverID = 0
|
||
} else {
|
||
groupID = 0
|
||
}
|
||
|
||
systemUser := r.FormValue("system_user")
|
||
sudo := r.FormValue("sudo") == "on"
|
||
createUser := r.FormValue("create_user") == "on"
|
||
initialPassword := r.FormValue("initial_password")
|
||
expiryAction := r.FormValue("expiry_action")
|
||
|
||
schedule := r.FormValue("schedule")
|
||
scheduledAtStr := r.FormValue("scheduled_at")
|
||
removeAfterMin, _ := strconv.Atoi(r.FormValue("remove_after_min"))
|
||
tz := r.FormValue("timezone")
|
||
timeOfDay := r.FormValue("time_of_day")
|
||
dayOfWeek, _ := strconv.Atoi(r.FormValue("day_of_week"))
|
||
dayOfMonth, _ := strconv.Atoi(r.FormValue("day_of_month"))
|
||
minuteOfHour, _ := strconv.Atoi(r.FormValue("minute_of_hour"))
|
||
|
||
if tz == "" {
|
||
tz = "UTC"
|
||
}
|
||
if timeOfDay == "" {
|
||
timeOfDay = "00:00"
|
||
}
|
||
|
||
// Parse scheduled_at in the user's timezone for "once" schedule
|
||
var scheduledAt time.Time
|
||
if schedule == "once" && scheduledAtStr != "" {
|
||
loc, locErr := time.LoadLocation(tz)
|
||
if locErr != nil {
|
||
loc = time.UTC
|
||
}
|
||
var parseErr error
|
||
scheduledAt, parseErr = time.ParseInLocation("2006-01-02T15:04", scheduledAtStr, loc)
|
||
if parseErr != nil {
|
||
data := &PageData{
|
||
Title: "Edit Temporary Access",
|
||
Active: "cron",
|
||
User: user,
|
||
CronJob: job,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
DaysOfMonth: daysOfMonth(),
|
||
AssignAllUsers: allUsers,
|
||
Flash: &Flash{Type: "danger", Message: "Invalid date format."},
|
||
}
|
||
h.templates["cron_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
} else {
|
||
scheduledAt = time.Now().UTC()
|
||
}
|
||
|
||
if err := h.cron.Update(jobID, userID, name, keyID, serverID, groupID, schedule, scheduledAt, removeAfterMin, tz, timeOfDay, dayOfWeek, dayOfMonth, minuteOfHour, targetUserID, systemUser, sudo, createUser, initialPassword, expiryAction); err != nil {
|
||
data := &PageData{
|
||
Title: "Edit Temporary Access",
|
||
Active: "cron",
|
||
User: user,
|
||
CronJob: job,
|
||
Servers: serverList,
|
||
Groups: groups,
|
||
DaysOfMonth: daysOfMonth(),
|
||
AssignAllUsers: allUsers,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to update job: " + err.Error()},
|
||
}
|
||
h.templates["cron_edit"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
h.audit.Log(userID, audit.ActionCronJobUpdated, fmt.Sprintf("Updated temporary access: %s (ID %d)", name, jobID), clientIP(r))
|
||
http.Redirect(w, r, "/cron", http.StatusSeeOther)
|
||
|
||
case "delete":
|
||
if r.Method == http.MethodPost {
|
||
job, _ := h.cron.GetByID(jobID, userID)
|
||
jname := "unknown"
|
||
if job != nil {
|
||
jname = job.Name
|
||
}
|
||
h.audit.Log(userID, audit.ActionCronJobDeleted, fmt.Sprintf("Deleted cron job: %s (ID %d)", jname, jobID), clientIP(r))
|
||
h.cron.Delete(jobID, userID)
|
||
}
|
||
http.Redirect(w, r, "/cron", http.StatusSeeOther)
|
||
|
||
case "toggle":
|
||
if r.Method == http.MethodPost {
|
||
job, _ := h.cron.GetByID(jobID, userID)
|
||
if err := h.cron.TogglePause(jobID, userID); err != nil {
|
||
logging.Error("Failed to toggle cron job: %v", err)
|
||
} else if job != nil {
|
||
if job.Status == "paused" {
|
||
h.audit.Log(userID, audit.ActionCronJobResumed, fmt.Sprintf("Resumed cron job: %s (ID %d)", job.Name, jobID), clientIP(r))
|
||
} else {
|
||
h.audit.Log(userID, audit.ActionCronJobPaused, fmt.Sprintf("Paused cron job: %s (ID %d)", job.Name, jobID), clientIP(r))
|
||
}
|
||
}
|
||
}
|
||
http.Redirect(w, r, "/cron", http.StatusSeeOther)
|
||
|
||
default:
|
||
http.NotFound(w, r)
|
||
}
|
||
}
|
||
|
||
// --- MFA Handlers ---
|
||
|
||
// getSession returns the session data for the current request (requires auth middleware)
|
||
func (h *Handler) getSession(r *http.Request) *sessionData {
|
||
cookie, err := r.Cookie("keywarden_session")
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
h.mu.RLock()
|
||
sess := h.sessions[cookie.Value]
|
||
h.mu.RUnlock()
|
||
return sess
|
||
}
|
||
|
||
// handleMFAEnforce shows the standalone MFA setup page (no sidebar) for
|
||
// users who are required to enable MFA. This page is shown after login.
|
||
func (h *Handler) handleMFAEnforce(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
// If user already has MFA enabled, go to dashboard
|
||
if user.MFAEnabled {
|
||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
if r.Method == http.MethodGet {
|
||
secret := h.auth.GenerateMFASecret()
|
||
uri := fmt.Sprintf("otpauth://totp/Keywarden:%s?secret=%s&issuer=Keywarden&algorithm=SHA1&digits=6&period=30",
|
||
user.Username, secret)
|
||
|
||
data := &PageData{
|
||
Title: "MFA Setup Required",
|
||
MFASecret: secret,
|
||
MFAUri: uri,
|
||
}
|
||
h.templates["mfa_required"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
// POST: verify & enable MFA
|
||
secret := r.FormValue("mfa_secret")
|
||
code := r.FormValue("mfa_code")
|
||
|
||
if !validateTOTP(secret, code) {
|
||
uri := fmt.Sprintf("otpauth://totp/Keywarden:%s?secret=%s&issuer=Keywarden&algorithm=SHA1&digits=6&period=30",
|
||
user.Username, secret)
|
||
data := &PageData{
|
||
Title: "MFA Setup Required",
|
||
MFASecret: secret,
|
||
MFAUri: uri,
|
||
Flash: &Flash{Type: "danger", Message: "Invalid verification code. Please try again."},
|
||
}
|
||
h.templates["mfa_required"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
if err := h.auth.EnableMFA(userID, secret); err != nil {
|
||
data := &PageData{
|
||
Title: "MFA Setup Required",
|
||
Flash: &Flash{Type: "danger", Message: "Failed to enable MFA: " + err.Error()},
|
||
}
|
||
h.templates["mfa_required"].Execute(w, data)
|
||
return
|
||
}
|
||
|
||
// Clear the session flag
|
||
if sess := h.getSession(r); sess != nil {
|
||
h.mu.Lock()
|
||
sess.MFASetupRequired = false
|
||
h.mu.Unlock()
|
||
}
|
||
|
||
logging.Info("MFA enabled for user_id=%d (enforcement)", userID)
|
||
h.audit.Log(userID, audit.ActionMFAEnabled, "MFA enabled (enforcement)", clientIP(r))
|
||
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleMFASetup(w http.ResponseWriter, r *http.Request) {
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
// Check if MFA is enforced by admin
|
||
mfaRequired, _ := h.auth.GetSetting("mfa_required")
|
||
isMFARequired := mfaRequired == "true"
|
||
|
||
if r.Method == http.MethodGet {
|
||
secret := h.auth.GenerateMFASecret()
|
||
uri := fmt.Sprintf("otpauth://totp/Keywarden:%s?secret=%s&issuer=Keywarden&algorithm=SHA1&digits=6&period=30",
|
||
user.Username, secret)
|
||
|
||
data := &PageData{
|
||
Title: "MFA Setup",
|
||
Active: "settings",
|
||
User: user,
|
||
MFASecret: secret,
|
||
MFAUri: uri,
|
||
MFARequired: isMFARequired,
|
||
}
|
||
h.templates["mfa_setup"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
// POST: verify & enable MFA
|
||
secret := r.FormValue("mfa_secret")
|
||
code := r.FormValue("mfa_code")
|
||
|
||
if !validateTOTP(secret, code) {
|
||
uri := fmt.Sprintf("otpauth://totp/Keywarden:%s?secret=%s&issuer=Keywarden&algorithm=SHA1&digits=6&period=30",
|
||
user.Username, secret)
|
||
data := &PageData{
|
||
Title: "MFA Setup",
|
||
Active: "settings",
|
||
User: user,
|
||
MFASecret: secret,
|
||
MFAUri: uri,
|
||
Flash: &Flash{Type: "danger", Message: "Invalid verification code. Please try again."},
|
||
}
|
||
h.templates["mfa_setup"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
if err := h.auth.EnableMFA(userID, secret); err != nil {
|
||
data := &PageData{
|
||
Title: "MFA Setup",
|
||
Active: "settings",
|
||
User: user,
|
||
Flash: &Flash{Type: "danger", Message: "Failed to enable MFA: " + err.Error()},
|
||
}
|
||
h.templates["mfa_setup"].ExecuteTemplate(w, "base", data)
|
||
return
|
||
}
|
||
|
||
logging.Info("MFA enabled for user_id=%d", userID)
|
||
h.audit.Log(userID, audit.ActionMFAEnabled, "MFA enabled", clientIP(r))
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
}
|
||
|
||
func (h *Handler) handleMFADisable(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Prevent disabling MFA when enforcement is active
|
||
mfaRequired, _ := h.auth.GetSetting("mfa_required")
|
||
if mfaRequired == "true" {
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
h.auth.DisableMFA(userID)
|
||
logging.Info("MFA disabled for user_id=%d", userID)
|
||
h.audit.Log(userID, audit.ActionMFADisabled, "MFA disabled", clientIP(r))
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
}
|
||
|
||
// --- TOTP Helpers (RFC 6238, no external dependency) ---
|
||
|
||
// sendLoginNotification sends login notification email asynchronously
|
||
func (h *Handler) sendLoginNotification(user *models.User, r *http.Request) {
|
||
if !h.mail.IsEnabled() {
|
||
logging.Debug("Login notification skipped: email service not enabled")
|
||
return
|
||
}
|
||
if user.Email == "" {
|
||
logging.Debug("Login notification skipped for %s: no email address configured", user.Username)
|
||
return
|
||
}
|
||
if !user.EmailNotifyLogin {
|
||
logging.Debug("Login notification skipped for %s: notifications disabled", user.Username)
|
||
return
|
||
}
|
||
data := mail.LoginNotificationData{
|
||
Username: user.Username,
|
||
IPAddress: clientIP(r),
|
||
Timestamp: time.Now().Format("2006-01-02 15:04:05 MST"),
|
||
UserAgent: r.UserAgent(),
|
||
}
|
||
go func() {
|
||
if err := h.mail.SendLoginNotification(user.Email, data); err != nil {
|
||
logging.Warn("Failed to send login notification to %s: %v", user.Email, err)
|
||
h.audit.Log(user.ID, audit.ActionEmailLoginFailed, fmt.Sprintf("Login notification to %s failed: %v", user.Email, err), data.IPAddress)
|
||
} else {
|
||
logging.Info("Login notification email sent to %s for user '%s'", user.Email, user.Username)
|
||
h.audit.Log(user.ID, audit.ActionEmailLoginSent, fmt.Sprintf("Login notification sent to %s", user.Email), data.IPAddress)
|
||
}
|
||
}()
|
||
}
|
||
|
||
// handleEmailNotifyToggle toggles login email notifications for the current user
|
||
func (h *Handler) handleEmailNotifyToggle(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
// The form uses a hidden field (value="0") + checkbox (value="1").
|
||
// When the checkbox is checked, both values are posted.
|
||
// r.FormValue() only returns the first value (always "0"),
|
||
// so we must check all posted values for the field.
|
||
_ = r.ParseForm()
|
||
enabled := false
|
||
for _, v := range r.PostForm["email_notify_login"] {
|
||
if v == "1" {
|
||
enabled = true
|
||
break
|
||
}
|
||
}
|
||
h.auth.UpdateEmailNotifyLogin(userID, enabled)
|
||
|
||
action := "disabled"
|
||
if enabled {
|
||
action = "enabled"
|
||
}
|
||
h.audit.Log(userID, audit.ActionEmailNotifyChanged, fmt.Sprintf("Login email notifications %s", action), clientIP(r))
|
||
http.Redirect(w, r, "/settings", http.StatusSeeOther)
|
||
}
|
||
|
||
// handleAdminEmailTest sends a test email (admin only)
|
||
func (h *Handler) handleAdminEmailTest(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
user, _ := h.auth.GetUserByID(userID)
|
||
|
||
toEmail := r.FormValue("test_email")
|
||
if toEmail == "" {
|
||
toEmail = user.Email
|
||
}
|
||
|
||
if err := h.mail.SendTestEmail(toEmail); err != nil {
|
||
logging.Warn("SMTP test email to %s failed: %v", toEmail, err)
|
||
h.audit.Log(userID, audit.ActionEmailTestFailed, fmt.Sprintf("SMTP test to %s failed: %v", toEmail, err), clientIP(r))
|
||
} else {
|
||
logging.Info("SMTP test email sent successfully to %s", toEmail)
|
||
h.audit.Log(userID, audit.ActionEmailTestSent, fmt.Sprintf("SMTP test email sent to %s", toEmail), clientIP(r))
|
||
}
|
||
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
}
|
||
|
||
// handleBackupExport creates an encrypted backup of the entire database
|
||
func (h *Handler) handleBackupExport(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
password := r.FormValue("backup_password")
|
||
passwordConfirm := r.FormValue("backup_password_confirm")
|
||
|
||
if password == "" {
|
||
h.audit.Log(userID, audit.ActionBackupExportFailed, "Empty backup password", clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
if password != passwordConfirm {
|
||
h.audit.Log(userID, audit.ActionBackupExportFailed, "Passwords do not match", clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Validate password against the configured password policy
|
||
if err := h.auth.ValidatePasswordPolicy(password); err != nil {
|
||
h.audit.Log(userID, audit.ActionBackupExportFailed, fmt.Sprintf("Backup password does not meet policy: %v", err), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Export all data
|
||
backup, err := h.db.ExportAll()
|
||
if err != nil {
|
||
logging.Error("Backup export failed: %v", err)
|
||
h.audit.Log(userID, audit.ActionBackupExportFailed, fmt.Sprintf("Export failed: %v", err), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Serialize to JSON
|
||
jsonData, err := json.MarshalIndent(backup, "", " ")
|
||
if err != nil {
|
||
logging.Error("Backup JSON marshal failed: %v", err)
|
||
h.audit.Log(userID, audit.ActionBackupExportFailed, fmt.Sprintf("JSON marshal failed: %v", err), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Encrypt with user-provided password
|
||
encrypted, err := database.EncryptBackup(jsonData, password)
|
||
if err != nil {
|
||
logging.Error("Backup encryption failed: %v", err)
|
||
h.audit.Log(userID, audit.ActionBackupExportFailed, fmt.Sprintf("Encryption failed: %v", err), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
h.audit.Log(userID, audit.ActionBackupExported, "Full system backup exported", clientIP(r))
|
||
|
||
// Send as download
|
||
filename := fmt.Sprintf("keywarden-backup-%s.kwbak", time.Now().Format("2006-01-02_150405"))
|
||
w.Header().Set("Content-Type", "application/octet-stream")
|
||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
|
||
w.Header().Set("Content-Length", strconv.Itoa(len(encrypted)))
|
||
w.Write(encrypted)
|
||
}
|
||
|
||
// handleBackupImport restores an encrypted backup
|
||
func (h *Handler) handleBackupImport(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Redirect(w, r, "/admin/settings", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
userID := h.getUserID(r)
|
||
password := r.FormValue("restore_password")
|
||
|
||
if password == "" {
|
||
h.audit.Log(userID, audit.ActionBackupImportFailed, "Empty restore password", clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg=Restore+failed:+no+password+provided", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Parse multipart form (max 100MB)
|
||
if err := r.ParseMultipartForm(100 << 20); err != nil {
|
||
h.audit.Log(userID, audit.ActionBackupImportFailed, fmt.Sprintf("Failed to parse form: %v", err), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg=Restore+failed:+could+not+parse+upload", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
file, header, err := r.FormFile("backup_file")
|
||
if err != nil {
|
||
h.audit.Log(userID, audit.ActionBackupImportFailed, "No backup file provided", clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg=Restore+failed:+no+backup+file+provided", http.StatusSeeOther)
|
||
return
|
||
}
|
||
defer file.Close()
|
||
|
||
// Read file content
|
||
encrypted, err := io.ReadAll(file)
|
||
if err != nil {
|
||
h.audit.Log(userID, audit.ActionBackupImportFailed, fmt.Sprintf("Failed to read file: %v", err), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg=Restore+failed:+could+not+read+backup+file", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Decrypt
|
||
jsonData, err := database.DecryptBackup(encrypted, password)
|
||
if err != nil {
|
||
logging.Warn("Backup import decryption failed: %v", err)
|
||
h.audit.Log(userID, audit.ActionBackupImportFailed, fmt.Sprintf("Decryption failed (wrong password or corrupt file): %s", header.Filename), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg=Restore+failed:+wrong+password+or+corrupt+backup+file", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Parse JSON
|
||
backup, err := database.ParseBackupJSON(jsonData)
|
||
if err != nil {
|
||
logging.Warn("Backup import parse failed: %v", err)
|
||
h.audit.Log(userID, audit.ActionBackupImportFailed, fmt.Sprintf("Invalid backup format: %v", err), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg=Restore+failed:+invalid+backup+format", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Import all data
|
||
if err := h.db.ImportAll(backup); err != nil {
|
||
logging.Error("Backup import failed: %v", err)
|
||
h.audit.Log(userID, audit.ActionBackupImportFailed, fmt.Sprintf("Import failed: %v", err), clientIP(r))
|
||
http.Redirect(w, r, "/admin/settings?flash_type=danger&flash_msg=Restore+failed:+database+import+error", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
logging.Info("Backup successfully imported from %s", header.Filename)
|
||
h.audit.Log(userID, audit.ActionBackupImported, fmt.Sprintf("Full system backup restored from %s (created: %s)", header.Filename, backup.CreatedAt), clientIP(r))
|
||
|
||
http.Redirect(w, r, "/admin/settings?flash_type=success&flash_msg=Backup+successfully+restored+from+"+url.QueryEscape(header.Filename), http.StatusSeeOther)
|
||
}
|
||
|
||
func validateTOTP(secret, code string) bool {
|
||
if secret == "" || code == "" {
|
||
return false
|
||
}
|
||
// Check current time step and +/- 1 for clock skew tolerance
|
||
now := time.Now().Unix()
|
||
for _, offset := range []int64{-1, 0, 1} {
|
||
t := (now / 30) + offset
|
||
if generateTOTP(secret, t) == code {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// generateTOTP computes the 6-digit HMAC-SHA1 one-time password for the
// given counter value. secret is decoded as unpadded base32 after being
// upper-cased; if it cannot be decoded the empty string is returned. The
// result is zero-padded to six digits.
func generateTOTP(secret string, counter int64) string {
	decoder := base32.StdEncoding.WithPadding(base32.NoPadding)
	key, err := decoder.DecodeString(strings.ToUpper(secret))
	if err != nil {
		return ""
	}

	// The counter is hashed as an 8-byte big-endian integer.
	var msg [8]byte
	binary.BigEndian.PutUint64(msg[:], uint64(counter))

	hasher := hmac.New(sha1.New, key)
	hasher.Write(msg[:])
	digest := hasher.Sum(nil)

	// Dynamic truncation: the low nibble of the last digest byte selects a
	// 4-byte window, whose top bit is masked off.
	start := int(digest[len(digest)-1] & 0x0f)
	truncated := binary.BigEndian.Uint32(digest[start:start+4]) & 0x7fffffff

	modulus := uint32(math.Pow10(6))
	return fmt.Sprintf("%06d", truncated%modulus)
}
|
||
|
||
// generateSessionID returns a new session identifier: 32 bytes read from
// crypto/rand, rendered as 64 lowercase hexadecimal characters.
//
// It panics if the random source fails. Continuing would return the hex
// form of a zeroed (or partially filled) buffer, i.e. a guessable session
// ID.
func generateSessionID() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("generateSessionID: reading random bytes: %v", err))
	}
	return fmt.Sprintf("%x", b)
}
|
||
|
||
// generateInitialPassword returns a random password of the given length made
// of lowercase letters, uppercase letters and digits (no special
// characters). A length of zero or less yields the empty string.
//
// Characters are drawn from crypto/rand with rejection sampling: the
// alphabet has 62 symbols, and 256 is not a multiple of 62, so reducing a
// raw byte modulo 62 would make the first 8 symbols more likely than the
// rest. Bytes at or above the largest multiple of 62 are discarded instead.
//
// It panics if the random source fails, since continuing would produce a
// predictable password.
func generateInitialPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// limit is the largest multiple of len(charset) that fits in a byte;
	// only bytes below it map uniformly onto the alphabet.
	const limit = 256 - 256%len(charset)

	if length <= 0 {
		return ""
	}

	out := make([]byte, 0, length)
	// Slightly oversized so one read usually suffices despite rejections.
	buf := make([]byte, length+length/4+8)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("generateInitialPassword: reading random bytes: " + err.Error())
		}
		for _, c := range buf {
			if int(c) >= limit {
				continue
			}
			out = append(out, charset[int(c)%len(charset)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
|