184 lines
4.5 KiB
Go
184 lines
4.5 KiB
Go
package migrate
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"io/fs"
	"log"
	"net/url"
	"sort"
	"strings"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"

	mig "watch-party-backend/db/migration"
	"watch-party-backend/internal/config"
)
|
|
|
|
// Run applies all pending migrations.
|
|
// Uses cfg.DB.DSN() for the app database and derives an admin DSN (db=postgres) from it.
|
|
func Run(ctx context.Context, cfg config.Config) error {
|
|
appDSN := cfg.DB.DSN()
|
|
adminDSN := dsnWithDB(appDSN, "postgres")
|
|
dbName := cfg.DB.Name
|
|
|
|
// 1) Connect to admin DB to ensure target DB exists
|
|
adminCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
defer cancel()
|
|
|
|
adminConn, err := pgx.Connect(adminCtx, adminDSN)
|
|
if err != nil {
|
|
return fmt.Errorf("connect admin DB: %w", err)
|
|
}
|
|
defer adminConn.Close(ctx)
|
|
|
|
if err := createDatabaseIfNotExists(ctx, adminConn, dbName); err != nil {
|
|
return fmt.Errorf("create database: %w", err)
|
|
}
|
|
log.Printf("database %q is present", dbName)
|
|
|
|
// 2) Connect to target DB and run migrations
|
|
appCtx, cancel2 := context.WithTimeout(ctx, 10*time.Second)
|
|
defer cancel2()
|
|
|
|
appConn, err := pgx.Connect(appCtx, appDSN)
|
|
if err != nil {
|
|
return fmt.Errorf("connect app DB: %w", err)
|
|
}
|
|
defer appConn.Close(ctx)
|
|
|
|
if err := ensureSchemaMigrations(ctx, appConn); err != nil {
|
|
return fmt.Errorf("ensure schema_migrations: %w", err)
|
|
}
|
|
|
|
applied, err := fetchAppliedVersions(ctx, appConn)
|
|
if err != nil {
|
|
return fmt.Errorf("read applied versions: %w", err)
|
|
}
|
|
|
|
files, err := fs.Glob(mig.FS, "*.sql")
|
|
if err != nil {
|
|
return fmt.Errorf("glob migrations: %w", err)
|
|
}
|
|
sort.Strings(files)
|
|
|
|
for _, name := range files {
|
|
version := versionOf(name)
|
|
if applied[version] {
|
|
log.Printf("skip %s (already applied)", version)
|
|
continue
|
|
}
|
|
|
|
sqlBytes, err := mig.FS.ReadFile(name)
|
|
if err != nil {
|
|
return fmt.Errorf("read %s: %w", name, err)
|
|
}
|
|
|
|
log.Printf("applying %s ...", version)
|
|
if err := applyMigration(ctx, appConn, version, string(sqlBytes)); err != nil {
|
|
return fmt.Errorf("apply %s failed: %w", version, err)
|
|
}
|
|
log.Printf("applied %s", version)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// dsnWithDB returns a copy of the Postgres connection string raw that targets
// database dbName instead of whichever database raw names.
//
// Both connection-string forms accepted by pgx are handled:
//
//   - URL form (prefix postgres:// or postgresql://). The path segment is
//     replaced. Any dbname/database query parameter is removed, because pgx
//     would let it override the path.
//     Example: postgres://user:pass@host:5432/appdb?sslmode=disable
//     -> postgres://user:pass@host:5432/postgres?sslmode=disable
//   - Keyword/value form (host=... user=... dbname=...). A quoted
//     dbname='...' setting is appended. pgx keeps the last occurrence of a
//     keyword, so it overrides any earlier dbname.
//
// If a URL-form DSN is malformed, raw is returned unchanged and pgx reports
// the problem at connect time.
func dsnWithDB(raw string, dbName string) string {
	if strings.HasPrefix(raw, "postgres://") || strings.HasPrefix(raw, "postgresql://") {
		u, err := url.Parse(raw)
		if err != nil {
			return raw
		}
		u.Path = "/" + dbName
		u.RawPath = "" // let String() re-escape the new path
		// Only re-encode the query when needed; Encode() reorders parameters.
		if q := u.Query(); q.Has("dbname") || q.Has("database") {
			q.Del("dbname")
			q.Del("database")
			u.RawQuery = q.Encode()
		}
		return u.String()
	}

	// Keyword/value form: single-quote the value, escaping \ and '.
	setting := "dbname='" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(dbName) + "'"
	if strings.TrimSpace(raw) == "" {
		return setting
	}
	return raw + " " + setting
}
|
|
|
|
func createDatabaseIfNotExists(ctx context.Context, admin *pgx.Conn, dbName string) error {
|
|
var exists bool
|
|
if err := admin.QueryRow(ctx,
|
|
"SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname=$1)", dbName).
|
|
Scan(&exists); err != nil {
|
|
return fmt.Errorf("check db exists: %w", err)
|
|
}
|
|
if exists {
|
|
return nil
|
|
}
|
|
// Quote identifier safely via quote_ident
|
|
_, err := admin.Exec(ctx, `
|
|
DO $$
|
|
BEGIN
|
|
IF NOT EXISTS (SELECT FROM pg_database WHERE datname = $1) THEN
|
|
EXECUTE 'CREATE DATABASE ' || quote_ident($1);
|
|
END IF;
|
|
END$$;`, dbName)
|
|
if err != nil {
|
|
// 42P04: duplicate_database
|
|
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "42P04" {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("create database: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ensureSchemaMigrations(ctx context.Context, conn *pgx.Conn) error {
|
|
_, err := conn.Exec(ctx, `
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
version TEXT PRIMARY KEY,
|
|
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
)`)
|
|
return err
|
|
}
|
|
|
|
func fetchAppliedVersions(ctx context.Context, conn *pgx.Conn) (map[string]bool, error) {
|
|
rows, err := conn.Query(ctx, `SELECT version FROM schema_migrations`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
out := make(map[string]bool)
|
|
for rows.Next() {
|
|
var v string
|
|
if err := rows.Scan(&v); err != nil {
|
|
return nil, err
|
|
}
|
|
out[v] = true
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
func applyMigration(ctx context.Context, conn *pgx.Conn, version, sqlText string) error {
|
|
tx, err := conn.Begin(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = tx.Rollback(ctx) }()
|
|
|
|
if _, err := tx.Exec(ctx, sqlText); err != nil {
|
|
return err
|
|
}
|
|
if _, err := tx.Exec(ctx, `INSERT INTO schema_migrations(version) VALUES ($1)`, version); err != nil {
|
|
return err
|
|
}
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
// versionOf derives a migration version from a slash-separated file path.
//
// It looks only at the final path element:
//   - If an underscore appears after the first character and before the first
//     dot, the text before it is the version ("0001_init.sql" -> "0001").
//   - Otherwise, if there is a dot after the first character, the text before
//     it is the version ("0003.sql" -> "0003").
//   - Otherwise the whole element is returned unchanged.
func versionOf(path string) string {
	name := path
	if slash := strings.LastIndexByte(name, '/'); slash >= 0 {
		name = name[slash+1:]
	}

	dot := strings.IndexByte(name, '.')
	switch us := strings.IndexByte(name, '_'); {
	case us > 0 && us < dot:
		return name[:us]
	case dot > 0:
		return name[:dot]
	default:
		return name
	}
}
|