327 lines
7.5 KiB
Go
327 lines
7.5 KiB
Go
package storage
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"math"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// Store is a message history kept in a single SQLite database file.
type Store struct {
	// db is the handle returned by sql.Open for path.
	db *sql.DB
	// path is the main database file; totalSize also looks at the
	// "-wal" and "-shm" files next to it.
	path string
}
|
|
|
|
// Message is one row of the messages table as seen by callers of Store.
type Message struct {
	ID, Topic, Payload string
	// QOS is written to the integer qos column.
	QOS byte
	// Retained is written to the retained column as 1 or 0.
	Retained bool
	// Timestamp is written to the ts text column in UTC.
	Timestamp time.Time
	// Size is written to the size column; PurgeOversize filters on it.
	Size int
}
|
|
|
|
// Stats is the summary returned by Store.Stats.
type Stats struct {
	Count int64  `json:"count"` // number of rows in the messages table
	Size  string `json:"size"`  // Bytes rendered by formatBytes
	Bytes int64  `json:"bytes"` // database file plus its -wal and -shm files
}
|
|
|
|
func Open(path string) (*Store, error) {
|
|
dir := filepath.Dir(path)
|
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
return nil, fmt.Errorf("creation dossier sqlite: %w", err)
|
|
}
|
|
|
|
db, err := sql.Open("sqlite", path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ouverture sqlite: %w", err)
|
|
}
|
|
|
|
if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
|
|
return nil, fmt.Errorf("pragma WAL: %w", err)
|
|
}
|
|
if _, err := db.Exec("PRAGMA synchronous=NORMAL;"); err != nil {
|
|
return nil, fmt.Errorf("pragma synchronous: %w", err)
|
|
}
|
|
|
|
store := &Store{db: db, path: path}
|
|
if err := store.initSchema(); err != nil {
|
|
if isCorruptErr(err) {
|
|
_ = db.Close()
|
|
return recoverCorruptDB(path)
|
|
}
|
|
return nil, err
|
|
}
|
|
if err := store.integrityCheck(); err != nil {
|
|
if isCorruptErr(err) {
|
|
_ = db.Close()
|
|
return recoverCorruptDB(path)
|
|
}
|
|
return nil, err
|
|
}
|
|
return store, nil
|
|
}
|
|
|
|
func (s *Store) initSchema() error {
|
|
schema := `
|
|
CREATE TABLE IF NOT EXISTS messages (
|
|
id TEXT PRIMARY KEY,
|
|
topic TEXT NOT NULL,
|
|
payload TEXT NOT NULL,
|
|
qos INTEGER NOT NULL,
|
|
retained INTEGER NOT NULL,
|
|
ts TEXT NOT NULL,
|
|
size INTEGER NOT NULL
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_messages_topic_ts ON messages(topic, ts);
|
|
CREATE INDEX IF NOT EXISTS idx_messages_ts ON messages(ts);
|
|
`
|
|
_, err := s.db.Exec(schema)
|
|
if err != nil {
|
|
return fmt.Errorf("schema sqlite: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) integrityCheck() error {
|
|
var result string
|
|
if err := s.db.QueryRow("PRAGMA integrity_check;").Scan(&result); err != nil {
|
|
return fmt.Errorf("integrity check: %w", err)
|
|
}
|
|
if strings.ToLower(result) != "ok" {
|
|
return fmt.Errorf("integrity check: %s", result)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func recoverCorruptDB(path string) (*Store, error) {
|
|
suffix := time.Now().UTC().Format("20060102-150405")
|
|
backupWithSuffix(path, suffix)
|
|
backupWithSuffix(path+"-wal", suffix)
|
|
backupWithSuffix(path+"-shm", suffix)
|
|
|
|
db, err := sql.Open("sqlite", path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ouverture sqlite apres recovery: %w", err)
|
|
}
|
|
if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
|
|
_ = db.Close()
|
|
return nil, fmt.Errorf("pragma WAL: %w", err)
|
|
}
|
|
if _, err := db.Exec("PRAGMA synchronous=NORMAL;"); err != nil {
|
|
_ = db.Close()
|
|
return nil, fmt.Errorf("pragma synchronous: %w", err)
|
|
}
|
|
store := &Store{db: db, path: path}
|
|
if err := store.initSchema(); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
return store, nil
|
|
}
|
|
|
|
// backupWithSuffix renames path to "<path>.corrupt-<suffix>" when path
// exists. It is best-effort: a missing file is skipped and a rename error
// is discarded.
func backupWithSuffix(path, suffix string) {
	if _, statErr := os.Stat(path); statErr == nil {
		target := path + ".corrupt-" + suffix
		_ = os.Rename(path, target)
	}
}
|
|
|
|
// isCorruptErr reports whether err's message (case-insensitive) looks like
// SQLite database corruption. Open uses it to decide whether to move the
// database aside and start fresh, so it should only match damage, never
// transient errors such as "database is locked". A nil error is never
// corrupt.
func isCorruptErr(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	markers := []string{
		"malformed",              // SQLITE_CORRUPT: "database disk image is malformed"
		"file is not a database", // SQLITE_NOTADB: the file header cannot be read as SQLite
		"btreeinitpage",          // btreeInitPage() diagnostics
		"error code 11",          // result code 11 is SQLITE_CORRUPT
	}
	for _, marker := range markers {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
func (s *Store) InsertMessage(msg Message) error {
|
|
_, err := s.db.Exec(
|
|
`INSERT INTO messages(id, topic, payload, qos, retained, ts, size) VALUES(?,?,?,?,?,?,?)`,
|
|
msg.ID,
|
|
msg.Topic,
|
|
msg.Payload,
|
|
int(msg.QOS),
|
|
boolToInt(msg.Retained),
|
|
msg.Timestamp.UTC().Format(time.RFC3339Nano),
|
|
msg.Size,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("insert message: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) GetHistory(topic string, limit int, from, to string) ([]Message, error) {
|
|
args := []any{topic}
|
|
query := "SELECT id, topic, payload, qos, retained, ts, size FROM messages WHERE topic = ?"
|
|
if from != "" {
|
|
query += " AND ts >= ?"
|
|
args = append(args, from)
|
|
}
|
|
if to != "" {
|
|
query += " AND ts <= ?"
|
|
args = append(args, to)
|
|
}
|
|
query += " ORDER BY ts DESC"
|
|
if limit > 0 {
|
|
query += " LIMIT ?"
|
|
args = append(args, limit)
|
|
}
|
|
|
|
rows, err := s.db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("lecture historique: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []Message
|
|
for rows.Next() {
|
|
var msg Message
|
|
var retained int
|
|
var ts string
|
|
if err := rows.Scan(&msg.ID, &msg.Topic, &msg.Payload, &msg.QOS, &retained, &ts, &msg.Size); err != nil {
|
|
return nil, fmt.Errorf("scan historique: %w", err)
|
|
}
|
|
msg.Retained = retained == 1
|
|
parsed, _ := time.Parse(time.RFC3339Nano, ts)
|
|
msg.Timestamp = parsed
|
|
out = append(out, msg)
|
|
}
|
|
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) ClearTopicHistory(topic string) (int64, error) {
|
|
res, err := s.db.Exec("DELETE FROM messages WHERE topic = ?", topic)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("suppression topic: %w", err)
|
|
}
|
|
affected, _ := res.RowsAffected()
|
|
return affected, nil
|
|
}
|
|
|
|
func (s *Store) ClearAllHistory() (int64, error) {
|
|
res, err := s.db.Exec("DELETE FROM messages")
|
|
if err != nil {
|
|
return 0, fmt.Errorf("suppression db: %w", err)
|
|
}
|
|
affected, _ := res.RowsAffected()
|
|
if err := s.Compact(); err != nil {
|
|
return affected, err
|
|
}
|
|
return affected, nil
|
|
}
|
|
|
|
func (s *Store) PurgeOversize(maxBytes int) (int64, error) {
|
|
res, err := s.db.Exec("DELETE FROM messages WHERE size >= ?", maxBytes)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("purge oversize: %w", err)
|
|
}
|
|
affected, _ := res.RowsAffected()
|
|
return affected, nil
|
|
}
|
|
|
|
func (s *Store) DeleteOldestFraction(fraction float64) (int64, error) {
|
|
if fraction <= 0 {
|
|
return 0, nil
|
|
}
|
|
var count int64
|
|
if err := s.db.QueryRow("SELECT COUNT(*) FROM messages").Scan(&count); err != nil {
|
|
return 0, fmt.Errorf("count messages: %w", err)
|
|
}
|
|
limit := int64(math.Round(float64(count) * fraction))
|
|
if limit <= 0 {
|
|
return 0, nil
|
|
}
|
|
res, err := s.db.Exec("DELETE FROM messages WHERE id IN (SELECT id FROM messages ORDER BY ts ASC LIMIT ?)", limit)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("purge oldest fraction: %w", err)
|
|
}
|
|
affected, _ := res.RowsAffected()
|
|
return affected, nil
|
|
}
|
|
|
|
func (s *Store) PurgeBefore(cutoff time.Time) (int64, error) {
|
|
res, err := s.db.Exec("DELETE FROM messages WHERE ts < ?", cutoff.UTC().Format(time.RFC3339Nano))
|
|
if err != nil {
|
|
return 0, fmt.Errorf("purge ttl: %w", err)
|
|
}
|
|
affected, _ := res.RowsAffected()
|
|
return affected, nil
|
|
}
|
|
|
|
func (s *Store) Stats() (Stats, error) {
|
|
var count int64
|
|
if err := s.db.QueryRow("SELECT COUNT(*) FROM messages").Scan(&count); err != nil {
|
|
return Stats{}, fmt.Errorf("stats count: %w", err)
|
|
}
|
|
sizeBytes := s.totalSize()
|
|
return Stats{
|
|
Count: count,
|
|
Size: formatBytes(sizeBytes),
|
|
Bytes: sizeBytes,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Store) totalSize() int64 {
|
|
if s.path == "" {
|
|
return 0
|
|
}
|
|
total := fileSize(s.path)
|
|
total += fileSize(s.path + "-wal")
|
|
total += fileSize(s.path + "-shm")
|
|
return total
|
|
}
|
|
|
|
func (s *Store) Compact() error {
|
|
if _, err := s.db.Exec("PRAGMA wal_checkpoint(TRUNCATE);"); err != nil {
|
|
return fmt.Errorf("checkpoint wal: %w", err)
|
|
}
|
|
if _, err := s.db.Exec("VACUUM;"); err != nil {
|
|
return fmt.Errorf("vacuum: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// fileSize returns the size in bytes reported by os.Stat for path, or 0 if
// the file cannot be stat'ed (for example, it does not exist).
func fileSize(path string) int64 {
	if info, statErr := os.Stat(path); statErr == nil {
		return info.Size()
	}
	return 0
}
|
|
|
|
// formatBytes renders size in 1024-based units. Values below 1024 print as
// "N B", KB and MB use one decimal, and anything of at least 1024 MB is
// shown in GB with two decimals.
func formatBytes(size int64) string {
	if size < 1024 {
		return fmt.Sprintf("%d B", size)
	}
	value := float64(size)
	for _, unit := range []string{"KB", "MB"} {
		value /= 1024
		if value < 1024 {
			return fmt.Sprintf("%.1f %s", value, unit)
		}
	}
	return fmt.Sprintf("%.2f GB", value/1024)
}
|
|
|
|
// boolToInt maps true to 1 and false to 0 for storage in an INTEGER column.
func boolToInt(val bool) int {
	result := 0
	if val {
		result = 1
	}
	return result
}
|