package storage import ( "database/sql" "fmt" "math" "os" "path/filepath" "strings" "time" _ "modernc.org/sqlite" ) type Store struct { db *sql.DB path string } type Message struct { ID string Topic string Payload string QOS byte Retained bool Timestamp time.Time Size int } type Stats struct { Count int64 `json:"count"` Size string `json:"size"` Bytes int64 `json:"bytes"` } func Open(path string) (*Store, error) { dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0o755); err != nil { return nil, fmt.Errorf("creation dossier sqlite: %w", err) } db, err := sql.Open("sqlite", path) if err != nil { return nil, fmt.Errorf("ouverture sqlite: %w", err) } if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil { return nil, fmt.Errorf("pragma WAL: %w", err) } if _, err := db.Exec("PRAGMA synchronous=NORMAL;"); err != nil { return nil, fmt.Errorf("pragma synchronous: %w", err) } store := &Store{db: db, path: path} if err := store.initSchema(); err != nil { if isCorruptErr(err) { _ = db.Close() return recoverCorruptDB(path) } return nil, err } if err := store.integrityCheck(); err != nil { if isCorruptErr(err) { _ = db.Close() return recoverCorruptDB(path) } return nil, err } return store, nil } func (s *Store) initSchema() error { schema := ` CREATE TABLE IF NOT EXISTS messages ( id TEXT PRIMARY KEY, topic TEXT NOT NULL, payload TEXT NOT NULL, qos INTEGER NOT NULL, retained INTEGER NOT NULL, ts TEXT NOT NULL, size INTEGER NOT NULL ); CREATE INDEX IF NOT EXISTS idx_messages_topic_ts ON messages(topic, ts); CREATE INDEX IF NOT EXISTS idx_messages_ts ON messages(ts); ` _, err := s.db.Exec(schema) if err != nil { return fmt.Errorf("schema sqlite: %w", err) } return nil } func (s *Store) integrityCheck() error { var result string if err := s.db.QueryRow("PRAGMA integrity_check;").Scan(&result); err != nil { return fmt.Errorf("integrity check: %w", err) } if strings.ToLower(result) != "ok" { return fmt.Errorf("integrity check: %s", result) } return nil } func recoverCorruptDB(path string) (*Store, error) { suffix := time.Now().UTC().Format("20060102-150405") backupWithSuffix(path, suffix) backupWithSuffix(path+"-wal", suffix) backupWithSuffix(path+"-shm", suffix) db, err := sql.Open("sqlite", path) if err != nil { return nil, fmt.Errorf("ouverture sqlite apres recovery: %w", err) } if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil { _ = db.Close() return nil, fmt.Errorf("pragma WAL: %w", err) } if _, err := db.Exec("PRAGMA synchronous=NORMAL;"); err != nil { _ = db.Close() return nil, fmt.Errorf("pragma synchronous: %w", err) } store := &Store{db: db, path: path} if err := store.initSchema(); err != nil { _ = db.Close() return nil, err } return store, nil } func backupWithSuffix(path, suffix string) { if _, err := os.Stat(path); err != nil { return } _ = os.Rename(path, fmt.Sprintf("%s.corrupt-%s", path, suffix)) } func isCorruptErr(err error) bool { if err == nil { return false } msg := strings.ToLower(err.Error()) return strings.Contains(msg, "database disk image is malformed") || strings.Contains(msg, "malformed") || strings.Contains(msg, "btreeinitpage") || strings.Contains(msg, "error code 11") } func (s *Store) InsertMessage(msg Message) error { _, err := s.db.Exec( `INSERT INTO messages(id, topic, payload, qos, retained, ts, size) VALUES(?,?,?,?,?,?,?)`, msg.ID, msg.Topic, msg.Payload, int(msg.QOS), boolToInt(msg.Retained), msg.Timestamp.UTC().Format(time.RFC3339Nano), msg.Size, ) if err != nil { return fmt.Errorf("insert message: %w", err) } return nil } func (s *Store) GetHistory(topic string, limit int, from, to string) ([]Message, error) { args := []any{topic} query := "SELECT id, topic, payload, qos, retained, ts, size FROM messages WHERE topic = ?" if from != "" { query += " AND ts >= ?" args = append(args, from) } if to != "" { query += " AND ts <= ?" args = append(args, to) } query += " ORDER BY ts DESC" if limit > 0 { query += " LIMIT ?" args = append(args, limit) } rows, err := s.db.Query(query, args...) if err != nil { return nil, fmt.Errorf("lecture historique: %w", err) } defer rows.Close() var out []Message for rows.Next() { var msg Message var retained int var ts string if err := rows.Scan(&msg.ID, &msg.Topic, &msg.Payload, &msg.QOS, &retained, &ts, &msg.Size); err != nil { return nil, fmt.Errorf("scan historique: %w", err) } msg.Retained = retained == 1 parsed, _ := time.Parse(time.RFC3339Nano, ts) msg.Timestamp = parsed out = append(out, msg) } return out, nil } func (s *Store) ClearTopicHistory(topic string) (int64, error) { res, err := s.db.Exec("DELETE FROM messages WHERE topic = ?", topic) if err != nil { return 0, fmt.Errorf("suppression topic: %w", err) } affected, _ := res.RowsAffected() return affected, nil } func (s *Store) ClearAllHistory() (int64, error) { res, err := s.db.Exec("DELETE FROM messages") if err != nil { return 0, fmt.Errorf("suppression db: %w", err) } affected, _ := res.RowsAffected() if err := s.Compact(); err != nil { return affected, err } return affected, nil } func (s *Store) PurgeOversize(maxBytes int) (int64, error) { res, err := s.db.Exec("DELETE FROM messages WHERE size >= ?", maxBytes) if err != nil { return 0, fmt.Errorf("purge oversize: %w", err) } affected, _ := res.RowsAffected() return affected, nil } func (s *Store) DeleteOldestFraction(fraction float64) (int64, error) { if fraction <= 0 { return 0, nil } var count int64 if err := s.db.QueryRow("SELECT COUNT(*) FROM messages").Scan(&count); err != nil { return 0, fmt.Errorf("count messages: %w", err) } limit := int64(math.Round(float64(count) * fraction)) if limit <= 0 { return 0, nil } res, err := s.db.Exec("DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY ts ASC LIMIT ?)", limit) if err != nil { return 0, fmt.Errorf("purge oldest fraction: %w", err) } affected, _ := res.RowsAffected() return affected, nil } func (s *Store) PurgeBefore(cutoff time.Time) (int64, error) { res, err := s.db.Exec("DELETE FROM messages WHERE ts < ?", cutoff.UTC().Format(time.RFC3339Nano)) if err != nil { return 0, fmt.Errorf("purge ttl: %w", err) } affected, _ := res.RowsAffected() return affected, nil } func (s *Store) Stats() (Stats, error) { var count int64 if err := s.db.QueryRow("SELECT COUNT(*) FROM messages").Scan(&count); err != nil { return Stats{}, fmt.Errorf("stats count: %w", err) } sizeBytes := s.totalSize() return Stats{ Count: count, Size: formatBytes(sizeBytes), Bytes: sizeBytes, }, nil } func (s *Store) totalSize() int64 { if s.path == "" { return 0 } total := fileSize(s.path) total += fileSize(s.path + "-wal") total += fileSize(s.path + "-shm") return total } func (s *Store) Compact() error { if _, err := s.db.Exec("PRAGMA wal_checkpoint(TRUNCATE);"); err != nil { return fmt.Errorf("checkpoint wal: %w", err) } if _, err := s.db.Exec("VACUUM;"); err != nil { return fmt.Errorf("vacuum: %w", err) } return nil } func fileSize(path string) int64 { info, err := os.Stat(path) if err != nil { return 0 } return info.Size() } func formatBytes(size int64) string { if size < 1024 { return fmt.Sprintf("%d B", size) } kb := float64(size) / 1024 if kb < 1024 { return fmt.Sprintf("%.1f KB", kb) } mb := kb / 1024 if mb < 1024 { return fmt.Sprintf("%.1f MB", mb) } gb := mb / 1024 return fmt.Sprintf("%.2f GB", gb) } func boolToInt(val bool) int { if val { return 1 } return 0 }