package migrate import ( "context" "fmt" "io/fs" "log" "net/url" "sort" "strings" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" mig "watch-party-backend/db/migration" "watch-party-backend/internal/config" ) // Run applies all pending migrations. // Uses cfg.DB.DSN() for the app database and derives an admin DSN (db=postgres) from it. func Run(ctx context.Context, cfg config.Config) error { appDSN := cfg.DB.DSN() adminDSN := dsnWithDB(appDSN, "postgres") dbName := cfg.DB.Name // 1) Connect to admin DB to ensure target DB exists adminCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() adminConn, err := pgx.Connect(adminCtx, adminDSN) if err != nil { return fmt.Errorf("connect admin DB: %w", err) } defer adminConn.Close(ctx) if err := createDatabaseIfNotExists(ctx, adminConn, dbName); err != nil { return fmt.Errorf("create database: %w", err) } log.Printf("database %q is present", dbName) // 2) Connect to target DB and run migrations appCtx, cancel2 := context.WithTimeout(ctx, 10*time.Second) defer cancel2() appConn, err := pgx.Connect(appCtx, appDSN) if err != nil { return fmt.Errorf("connect app DB: %w", err) } defer appConn.Close(ctx) if err := ensureSchemaMigrations(ctx, appConn); err != nil { return fmt.Errorf("ensure schema_migrations: %w", err) } applied, err := fetchAppliedVersions(ctx, appConn) if err != nil { return fmt.Errorf("read applied versions: %w", err) } files, err := fs.Glob(mig.FS, "*.sql") if err != nil { return fmt.Errorf("glob migrations: %w", err) } sort.Strings(files) for _, name := range files { version := versionOf(name) if applied[version] { log.Printf("skip %s (already applied)", version) continue } sqlBytes, err := mig.FS.ReadFile(name) if err != nil { return fmt.Errorf("read %s: %w", name, err) } log.Printf("applying %s ...", version) if err := applyMigration(ctx, appConn, version, string(sqlBytes)); err != nil { return fmt.Errorf("apply %s failed: %w", version, err) } log.Printf("applied %s", version) } return nil } // dsnWithDB parses a Postgres URL DSN and replaces the DB path segment. // Example: postgres://user:pass@host:5432/appdb?sslmode=disable -> .../postgres?sslmode=disable func dsnWithDB(raw string, dbName string) string { u, err := url.Parse(raw) if err != nil { // if malformed, just return raw; pgx will error at connect time return raw } u.Path = "/" + dbName return u.String() } func createDatabaseIfNotExists(ctx context.Context, admin *pgx.Conn, dbName string) error { var exists bool if err := admin.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname=$1)", dbName). Scan(&exists); err != nil { return fmt.Errorf("check db exists: %w", err) } if exists { return nil } // Quote identifier safely via quote_ident _, err := admin.Exec(ctx, ` DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_database WHERE datname = $1) THEN EXECUTE 'CREATE DATABASE ' || quote_ident($1); END IF; END$$;`, dbName) if err != nil { // 42P04: duplicate_database if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "42P04" { return nil } return fmt.Errorf("create database: %w", err) } return nil } func ensureSchemaMigrations(ctx context.Context, conn *pgx.Conn) error { _, err := conn.Exec(ctx, ` CREATE TABLE IF NOT EXISTS schema_migrations ( version TEXT PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW() )`) return err } func fetchAppliedVersions(ctx context.Context, conn *pgx.Conn) (map[string]bool, error) { rows, err := conn.Query(ctx, `SELECT version FROM schema_migrations`) if err != nil { return nil, err } defer rows.Close() out := make(map[string]bool) for rows.Next() { var v string if err := rows.Scan(&v); err != nil { return nil, err } out[v] = true } return out, rows.Err() } func applyMigration(ctx context.Context, conn *pgx.Conn, version, sqlText string) error { tx, err := conn.Begin(ctx) if err != nil { return err } defer func() { _ = tx.Rollback(ctx) }() if _, err := tx.Exec(ctx, sqlText); err != nil { return err } if _, err := tx.Exec(ctx, `INSERT INTO schema_migrations(version) VALUES ($1)`, version); err != nil { return err } return tx.Commit(ctx) } func versionOf(path string) string { base := path[strings.LastIndex(path, "/")+1:] // "0001_init.sql" -> "0001" dot := strings.Index(base, ".") if underscore := strings.Index(base, "_"); underscore > 0 && underscore < dot { return base[:underscore] } if dot > 0 { return base[:dot] } return base }