2025-11-04 21:22:13 +09:00

184 lines
4.5 KiB
Go

package migrate
import (
	"context"
	"errors"
	"fmt"
	"io/fs"
	"log"
	"net/url"
	"sort"
	"strings"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"

	mig "watch-party-backend/db/migration"
	"watch-party-backend/internal/config"
)
// Run applies all pending migrations.
//
// It first connects to the "postgres" database on the same server (DSN
// derived from cfg.DB.DSN() by dsnWithDB) and creates cfg.DB.Name if it is
// missing. It then connects to cfg.DB.DSN() and applies every embedded *.sql
// file whose version (see versionOf) is not yet recorded in
// schema_migrations, in lexical filename order, one transaction per file.
// Two files that map to the same version are rejected before anything runs.
func Run(ctx context.Context, cfg config.Config) error {
	appDSN := cfg.DB.DSN()
	dbName := cfg.DB.Name

	// 1) Ensure the target database exists. The admin connection is scoped to
	// this step so it is closed before the migrations start.
	if err := ensureTargetDatabase(ctx, dsnWithDB(appDSN, "postgres"), dbName); err != nil {
		return err
	}
	log.Printf("database %q is present", dbName)

	// 2) Connect to the target database and apply pending migrations.
	return applyPendingMigrations(ctx, appDSN)
}

// ensureTargetDatabase connects to adminDSN and creates dbName if it does not
// exist. The connect attempt is bounded by a 10s timeout.
func ensureTargetDatabase(ctx context.Context, adminDSN, dbName string) error {
	connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	adminConn, err := pgx.Connect(connectCtx, adminDSN)
	if err != nil {
		return fmt.Errorf("connect admin DB: %w", err)
	}
	defer adminConn.Close(ctx)

	if err := createDatabaseIfNotExists(ctx, adminConn, dbName); err != nil {
		return fmt.Errorf("create database: %w", err)
	}
	return nil
}

// applyPendingMigrations connects to appDSN (10s connect timeout) and applies
// every embedded migration whose version schema_migrations does not list yet.
func applyPendingMigrations(ctx context.Context, appDSN string) error {
	connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	appConn, err := pgx.Connect(connectCtx, appDSN)
	if err != nil {
		return fmt.Errorf("connect app DB: %w", err)
	}
	defer appConn.Close(ctx)

	if err := ensureSchemaMigrations(ctx, appConn); err != nil {
		return fmt.Errorf("ensure schema_migrations: %w", err)
	}
	applied, err := fetchAppliedVersions(ctx, appConn)
	if err != nil {
		return fmt.Errorf("read applied versions: %w", err)
	}
	files, err := fs.Glob(mig.FS, "*.sql")
	if err != nil {
		return fmt.Errorf("glob migrations: %w", err)
	}
	sort.Strings(files)

	// The version is the schema_migrations key: if two files shared one, the
	// second would be silently skipped once the first had been recorded.
	seen := make(map[string]string, len(files))
	for _, name := range files {
		version := versionOf(name)
		if prev, dup := seen[version]; dup {
			return fmt.Errorf("duplicate migration version %q: %s and %s", version, prev, name)
		}
		seen[version] = name
	}

	for _, name := range files {
		version := versionOf(name)
		if applied[version] {
			log.Printf("skip %s (already applied)", version)
			continue
		}
		sqlBytes, err := mig.FS.ReadFile(name)
		if err != nil {
			return fmt.Errorf("read %s: %w", name, err)
		}
		log.Printf("applying %s ...", version)
		if err := applyMigration(ctx, appConn, version, string(sqlBytes)); err != nil {
			return fmt.Errorf("apply %s failed: %w", version, err)
		}
		log.Printf("applied %s", version)
	}
	return nil
}
// dsnWithDB returns a copy of the Postgres DSN raw that targets database
// dbName instead of the database raw names.
//
// Two DSN forms are rewritten:
//   - URL form (postgres:// or postgresql://): the path segment is replaced
//     and any dbname query parameter is dropped so it cannot override the
//     path. Example:
//     postgres://user:pass@host:5432/appdb?sslmode=disable
//     -> postgres://user:pass@host:5432/postgres?sslmode=disable
//   - Simple keyword/value form ("host=h user=u dbname=appdb"): the dbname
//     entry is replaced, or appended when absent.
//
// Anything else (unparsable or non-postgres URLs, keyword/value strings that
// use quoting or spaces around '=') is returned unchanged; pgx will report
// any problem at connect time.
func dsnWithDB(raw string, dbName string) string {
	if u, err := url.Parse(raw); err == nil && (u.Scheme == "postgres" || u.Scheme == "postgresql") {
		u.Path = "/" + dbName
		u.RawPath = ""
		if q := u.Query(); q.Has("dbname") {
			q.Del("dbname")
			u.RawQuery = q.Encode()
		}
		return u.String()
	}
	if strings.Contains(raw, "://") {
		// Some other or malformed URL: leave it for pgx to reject.
		return raw
	}
	return keywordDSNWithDB(raw, dbName)
}

// keywordDSNWithDB sets dbname in a simple keyword/value DSN, replacing any
// dbname/database entries or appending one when none exists. It returns raw
// unchanged when raw uses quotes or backslash escapes, contains a token that
// is not key=value (e.g. spaces around '='), or when dbName is empty or would
// itself need quoting.
func keywordDSNWithDB(raw, dbName string) string {
	if dbName == "" || strings.ContainsAny(dbName, " \t\r\n\v\f'\\") || strings.ContainsAny(raw, `'\`) {
		return raw
	}
	fields := strings.Fields(raw)
	replaced := false
	for i, field := range fields {
		key, _, ok := strings.Cut(field, "=")
		if !ok || key == "" {
			return raw
		}
		if key == "dbname" || key == "database" {
			fields[i] = "dbname=" + dbName
			replaced = true
		}
	}
	if !replaced {
		fields = append(fields, "dbname="+dbName)
	}
	return strings.Join(fields, " ")
}
// createDatabaseIfNotExists creates database dbName through the admin
// connection unless pg_database already has a row with that datname.
//
// CREATE DATABASE cannot run inside a transaction block or a function (so not
// from a DO block), and the database name cannot be a bind parameter. The
// name is therefore quoted client-side with pgx.Identifier.Sanitize and the
// statement is sent on its own. If another process creates the database
// between the check and the CREATE, Postgres reports 42P04
// (duplicate_database), which is treated as success.
func createDatabaseIfNotExists(ctx context.Context, admin *pgx.Conn, dbName string) error {
	var exists bool
	if err := admin.QueryRow(ctx,
		"SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname=$1)", dbName).
		Scan(&exists); err != nil {
		return fmt.Errorf("check db exists: %w", err)
	}
	if exists {
		return nil
	}

	if _, err := admin.Exec(ctx, "CREATE DATABASE "+pgx.Identifier{dbName}.Sanitize()); err != nil {
		// 42P04: duplicate_database
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == "42P04" {
			return nil
		}
		return fmt.Errorf("create database: %w", err)
	}
	return nil
}
// ensureSchemaMigrations creates the schema_migrations bookkeeping table if it
// is not there yet. Each row holds one applied version (primary key) and the
// timestamp at which it was recorded.
func ensureSchemaMigrations(ctx context.Context, conn *pgx.Conn) error {
	const createTable = `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)`
	if _, execErr := conn.Exec(ctx, createTable); execErr != nil {
		return execErr
	}
	return nil
}
// fetchAppliedVersions reads every version stored in schema_migrations and
// returns them as a set: each recorded version maps to true. A scan failure
// aborts with a nil map; an iteration error from rows.Err is returned
// alongside whatever versions were collected so far.
func fetchAppliedVersions(ctx context.Context, conn *pgx.Conn) (map[string]bool, error) {
	rows, queryErr := conn.Query(ctx, `SELECT version FROM schema_migrations`)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	versions := map[string]bool{}
	var version string
	for rows.Next() {
		if scanErr := rows.Scan(&version); scanErr != nil {
			return nil, scanErr
		}
		versions[version] = true
	}
	iterErr := rows.Err()
	return versions, iterErr
}
// applyMigration executes sqlText and inserts version into schema_migrations
// within a single transaction. Either both succeed and are committed, or the
// deferred rollback discards the partial work.
//
// Returned errors name the failing step (begin, exec, record, commit) and wrap
// the underlying pgx error.
func applyMigration(ctx context.Context, conn *pgx.Conn, version, sqlText string) error {
	tx, err := conn.Begin(ctx)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// After a successful Commit, Rollback only reports that the tx is already
	// closed, so its error is deliberately ignored; on every failure path it
	// discards the uncommitted migration.
	defer func() { _ = tx.Rollback(ctx) }()

	if _, err := tx.Exec(ctx, sqlText); err != nil {
		return fmt.Errorf("exec migration: %w", err)
	}
	if _, err := tx.Exec(ctx, `INSERT INTO schema_migrations(version) VALUES ($1)`, version); err != nil {
		return fmt.Errorf("record version: %w", err)
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}
// versionOf derives a migration version from a slash-separated file path.
//
// It takes the base name and trims everything from the first '.', then keeps
// only the part before the first '_' when that part is non-empty:
// "0001_init.sql" -> "0001", "0003.sql" -> "0003", "0004_init" -> "0004".
// Names starting with '.' have no stem and are returned whole.
func versionOf(path string) string {
	base := path[strings.LastIndex(path, "/")+1:]
	dot := strings.Index(base, ".")
	if dot == 0 {
		return base
	}
	stem := base
	if dot > 0 {
		stem = base[:dot]
	}
	if prefix, _, found := strings.Cut(stem, "_"); found && prefix != "" {
		return prefix
	}
	return stem
}