299 lines
11 KiB
Go
299 lines
11 KiB
Go
// Keywarden - Centralized SSH Key Management and Deployment
|
|
// Copyright (C) 2026 Patrick Asmus (scriptos)
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
package database
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
)
|
|
|
|
// BackupData contains all exportable data from the database
|
|
type BackupData struct {
|
|
Version string `json:"version"`
|
|
CreatedAt string `json:"created_at"`
|
|
Users []map[string]interface{} `json:"users"`
|
|
SSHKeys []map[string]interface{} `json:"ssh_keys"`
|
|
Servers []map[string]interface{} `json:"servers"`
|
|
ServerGroups []map[string]interface{} `json:"server_groups"`
|
|
GroupMembers []map[string]interface{} `json:"server_group_members"`
|
|
KeyDeployments []map[string]interface{} `json:"key_deployments"`
|
|
AuditLog []map[string]interface{} `json:"audit_log"`
|
|
Settings []map[string]interface{} `json:"settings"`
|
|
AccessAssign []map[string]interface{} `json:"access_assignments"`
|
|
CronJobs []map[string]interface{} `json:"cron_jobs"`
|
|
}
|
|
|
|
// ExportAll exports all database tables to a BackupData struct
|
|
func (d *DB) ExportAll() (*BackupData, error) {
|
|
backup := &BackupData{
|
|
Version: "1",
|
|
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
|
}
|
|
|
|
tables := []struct {
|
|
query string
|
|
dest *[]map[string]interface{}
|
|
}{
|
|
{`SELECT id, username, email, password_hash, role, mfa_enabled, mfa_secret, theme, email_notify_login, avatar_base64, must_change_password, failed_login_attempts, locked_until, last_login_at, created_at, updated_at FROM users ORDER BY id`, &backup.Users},
|
|
{`SELECT id, user_id, name, key_type, bits, fingerprint, public_key, private_key_enc, passphrase_enc, created_at FROM ssh_keys ORDER BY id`, &backup.SSHKeys},
|
|
{`SELECT id, user_id, name, hostname, port, username, description, created_at, updated_at FROM servers ORDER BY id`, &backup.Servers},
|
|
{`SELECT id, user_id, name, description, created_at, updated_at FROM server_groups ORDER BY id`, &backup.ServerGroups},
|
|
{`SELECT id, group_id, server_id FROM server_group_members ORDER BY id`, &backup.GroupMembers},
|
|
{`SELECT id, ssh_key_id, server_id, deployed_at, status, message, key_name FROM key_deployments ORDER BY id`, &backup.KeyDeployments},
|
|
{`SELECT id, user_id, action, details, ip_address, created_at FROM audit_log ORDER BY id`, &backup.AuditLog},
|
|
{`SELECT key, value, updated_at FROM settings ORDER BY key`, &backup.Settings},
|
|
{`SELECT id, user_id, ssh_key_id, server_id, group_id, system_user, desired_state, sudo, create_user, initial_password, status, last_sync_at, created_at, updated_at FROM access_assignments ORDER BY id`, &backup.AccessAssign},
|
|
{`SELECT id, user_id, name, ssh_key_id, server_id, group_id, schedule, scheduled_at, next_run, last_run, remove_after_min, status, message, timezone, time_of_day, day_of_week, day_of_month, minute_of_hour, target_user_id, system_user, sudo, create_user, initial_password, expiry_action, created_at FROM cron_jobs ORDER BY id`, &backup.CronJobs},
|
|
}
|
|
|
|
for _, t := range tables {
|
|
rows, err := d.Query(t.query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query table: %w", err)
|
|
}
|
|
data, err := rowsToMaps(rows)
|
|
rows.Close()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read rows: %w", err)
|
|
}
|
|
*t.dest = data
|
|
}
|
|
|
|
return backup, nil
|
|
}
|
|
|
|
// ImportAll restores all database tables from a BackupData struct.
|
|
// It clears existing data and replaces it with the backup data.
|
|
func (d *DB) ImportAll(backup *BackupData) error {
|
|
tx, err := d.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Disable foreign key checks during import
|
|
if _, err := tx.Exec(`PRAGMA foreign_keys = OFF`); err != nil {
|
|
return fmt.Errorf("failed to disable foreign keys: %w", err)
|
|
}
|
|
|
|
// Clear all tables in dependency order
|
|
clearOrder := []string{
|
|
"cron_jobs",
|
|
"access_assignments",
|
|
"key_deployments",
|
|
"server_group_members",
|
|
"server_groups",
|
|
"ssh_keys",
|
|
"servers",
|
|
"audit_log",
|
|
"settings",
|
|
"users",
|
|
}
|
|
for _, table := range clearOrder {
|
|
if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", table)); err != nil {
|
|
return fmt.Errorf("failed to clear table %s: %w", table, err)
|
|
}
|
|
}
|
|
|
|
// Import tables in reverse dependency order (parents first)
|
|
importOrder := []struct {
|
|
table string
|
|
columns []string
|
|
data []map[string]interface{}
|
|
}{
|
|
{"users", []string{"id", "username", "email", "password_hash", "role", "mfa_enabled", "mfa_secret", "theme", "email_notify_login", "avatar_base64", "must_change_password", "failed_login_attempts", "locked_until", "last_login_at", "created_at", "updated_at"}, backup.Users},
|
|
{"ssh_keys", []string{"id", "user_id", "name", "key_type", "bits", "fingerprint", "public_key", "private_key_enc", "passphrase_enc", "created_at"}, backup.SSHKeys},
|
|
{"servers", []string{"id", "user_id", "name", "hostname", "port", "username", "description", "created_at", "updated_at"}, backup.Servers},
|
|
{"server_groups", []string{"id", "user_id", "name", "description", "created_at", "updated_at"}, backup.ServerGroups},
|
|
{"server_group_members", []string{"id", "group_id", "server_id"}, backup.GroupMembers},
|
|
{"key_deployments", []string{"id", "ssh_key_id", "server_id", "deployed_at", "status", "message", "key_name"}, backup.KeyDeployments},
|
|
{"audit_log", []string{"id", "user_id", "action", "details", "ip_address", "created_at"}, backup.AuditLog},
|
|
{"settings", []string{"key", "value", "updated_at"}, backup.Settings},
|
|
{"access_assignments", []string{"id", "user_id", "ssh_key_id", "server_id", "group_id", "system_user", "desired_state", "sudo", "create_user", "initial_password", "status", "last_sync_at", "created_at", "updated_at"}, backup.AccessAssign},
|
|
{"cron_jobs", []string{"id", "user_id", "name", "ssh_key_id", "server_id", "group_id", "schedule", "scheduled_at", "next_run", "last_run", "remove_after_min", "status", "message", "timezone", "time_of_day", "day_of_week", "day_of_month", "minute_of_hour", "target_user_id", "system_user", "sudo", "create_user", "initial_password", "expiry_action", "created_at"}, backup.CronJobs},
|
|
}
|
|
|
|
for _, imp := range importOrder {
|
|
if len(imp.data) == 0 {
|
|
continue
|
|
}
|
|
|
|
// Build INSERT statement
|
|
placeholders := ""
|
|
for i := range imp.columns {
|
|
if i > 0 {
|
|
placeholders += ", "
|
|
}
|
|
placeholders += "?"
|
|
}
|
|
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", imp.table, joinColumns(imp.columns), placeholders)
|
|
|
|
stmt, err := tx.Prepare(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare insert for %s: %w", imp.table, err)
|
|
}
|
|
|
|
for _, row := range imp.data {
|
|
args := make([]interface{}, len(imp.columns))
|
|
for i, col := range imp.columns {
|
|
v := row[col]
|
|
// For columns with NOT NULL DEFAULT '', treat nil as empty string
|
|
if v == nil && col == "key_name" {
|
|
v = ""
|
|
}
|
|
args[i] = v
|
|
}
|
|
if _, err := stmt.Exec(args...); err != nil {
|
|
stmt.Close()
|
|
return fmt.Errorf("failed to insert into %s: %w", imp.table, err)
|
|
}
|
|
}
|
|
stmt.Close()
|
|
}
|
|
|
|
// Re-enable foreign key checks
|
|
if _, err := tx.Exec(`PRAGMA foreign_keys = ON`); err != nil {
|
|
return fmt.Errorf("failed to re-enable foreign keys: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// EncryptBackup encrypts JSON backup data with AES-256-GCM using the given password
|
|
func EncryptBackup(data []byte, password string) ([]byte, error) {
|
|
key := sha256.Sum256([]byte(password))
|
|
block, err := aes.NewCipher(key[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
|
}
|
|
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
nonce := make([]byte, gcm.NonceSize())
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
|
}
|
|
|
|
// Prepend a magic header so we can identify backup files
|
|
magic := []byte("KWBAK1") // Keywarden Backup v1
|
|
ciphertext := gcm.Seal(nonce, nonce, data, nil)
|
|
result := append(magic, ciphertext...)
|
|
return result, nil
|
|
}
|
|
|
|
// DecryptBackup decrypts an AES-256-GCM encrypted backup with the given password
|
|
func DecryptBackup(encrypted []byte, password string) ([]byte, error) {
|
|
// Check magic header
|
|
magic := []byte("KWBAK1")
|
|
if len(encrypted) < len(magic) {
|
|
return nil, fmt.Errorf("invalid backup file: too short")
|
|
}
|
|
if string(encrypted[:len(magic)]) != string(magic) {
|
|
return nil, fmt.Errorf("invalid backup file: wrong format")
|
|
}
|
|
encrypted = encrypted[len(magic):]
|
|
|
|
key := sha256.Sum256([]byte(password))
|
|
block, err := aes.NewCipher(key[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
|
}
|
|
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
nonceSize := gcm.NonceSize()
|
|
if len(encrypted) < nonceSize {
|
|
return nil, fmt.Errorf("invalid backup file: ciphertext too short")
|
|
}
|
|
|
|
nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:]
|
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decryption failed: wrong password or corrupted file")
|
|
}
|
|
|
|
return plaintext, nil
|
|
}
|
|
|
|
// rowsToMaps converts sql.Rows to a slice of maps
|
|
func rowsToMaps(rows *sql.Rows) ([]map[string]interface{}, error) {
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var result []map[string]interface{}
|
|
for rows.Next() {
|
|
values := make([]interface{}, len(columns))
|
|
valuePtrs := make([]interface{}, len(columns))
|
|
for i := range values {
|
|
valuePtrs[i] = &values[i]
|
|
}
|
|
|
|
if err := rows.Scan(valuePtrs...); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
row := make(map[string]interface{})
|
|
for i, col := range columns {
|
|
val := values[i]
|
|
// Convert byte slices to strings for JSON serialization
|
|
if b, ok := val.([]byte); ok {
|
|
row[col] = string(b)
|
|
} else {
|
|
row[col] = val
|
|
}
|
|
}
|
|
result = append(result, row)
|
|
}
|
|
|
|
if result == nil {
|
|
result = []map[string]interface{}{}
|
|
}
|
|
|
|
return result, rows.Err()
|
|
}
|
|
|
|
// joinColumns joins column names with commas
|
|
func joinColumns(cols []string) string {
|
|
result := ""
|
|
for i, col := range cols {
|
|
if i > 0 {
|
|
result += ", "
|
|
}
|
|
result += col
|
|
}
|
|
return result
|
|
}
|
|
|
|
// ParseBackupJSON parses decrypted JSON data into a BackupData struct
|
|
func ParseBackupJSON(data []byte) (*BackupData, error) {
|
|
var backup BackupData
|
|
if err := json.Unmarshal(data, &backup); err != nil {
|
|
return nil, fmt.Errorf("failed to parse backup data: %w", err)
|
|
}
|
|
if backup.Version == "" {
|
|
return nil, fmt.Errorf("invalid backup: missing version")
|
|
}
|
|
return &backup, nil
|
|
}
|