package repo

import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"strings"
	"testing"
	"time"

	"watch-party-backend/internal/core/episode"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)

// TestPGXEpisodeRepo_Delete checks that Delete issues the archive DELETE
// statement with the given id, and that a "DELETE 0" command tag is reported
// as episode.ErrNotFound.
func TestPGXEpisodeRepo_Delete(t *testing.T) {
	t.Run("not found", func(t *testing.T) {
		fp := &fakePool{
			execFn: func(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
				if sql != `DELETE FROM current_archive WHERE id = $1` {
					t.Fatalf("unexpected sql: %s", sql)
				}
				return pgconn.NewCommandTag("DELETE 0"), nil
			},
		}
		repo := &pgxEpisodeRepo{pool: fp}
		if err := repo.Delete(context.Background(), 10); !errors.Is(err, episode.ErrNotFound) {
			t.Fatalf("expected ErrNotFound, got %v", err)
		}
	})

	t.Run("ok", func(t *testing.T) {
		var (
			gotSQL string
			gotID  int64
		)
		fp := &fakePool{
			execFn: func(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
				gotSQL = sql
				// Fail with a readable message instead of panicking on an
				// unchecked type assertion if the arguments are not as expected.
				if len(args) != 1 {
					t.Fatalf("expected 1 arg, got %d: %+v", len(args), args)
				}
				id, ok := args[0].(int64)
				if !ok {
					t.Fatalf("expected int64 id arg, got %T", args[0])
				}
				gotID = id
				return pgconn.NewCommandTag("DELETE 1"), nil
			},
		}
		repo := &pgxEpisodeRepo{pool: fp}
		if err := repo.Delete(context.Background(), 22); err != nil {
			t.Fatalf("unexpected err: %v", err)
		}
		if gotSQL != `DELETE FROM current_archive WHERE id = $1` {
			t.Fatalf("expected archive delete sql, got %s", gotSQL)
		}
		if gotID != 22 {
			t.Fatalf("expected id 22, got %d", gotID)
		}
	})
}

// TestPGXEpisodeRepo_GetCurrent checks that pgx.ErrNoRows is reported as
// episode.ErrNotFound and that a returned row is decoded into the result.
func TestPGXEpisodeRepo_GetCurrent(t *testing.T) {
	t.Run("not found", func(t *testing.T) {
		fp := &fakePool{
			queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row {
				return fakeRow{err: pgx.ErrNoRows}
			},
		}
		repo := &pgxEpisodeRepo{pool: fp}
		_, err := repo.GetCurrent(context.Background())
		if !errors.Is(err, episode.ErrNotFound) {
			t.Fatalf("expected ErrNotFound, got %v", err)
		}
	})

	t.Run("ok", func(t *testing.T) {
		now := time.Now().UTC()
		fp := &fakePool{
			queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row {
				return fakeRow{values: []any{
					3, 5, "Title", "S1", "12:00:00", "00:24:00", true, now,
				}}
			},
		}
		repo := &pgxEpisodeRepo{pool: fp}
		row, err := repo.GetCurrent(context.Background())
		if err != nil {
			t.Fatalf("unexpected err: %v", err)
		}
		if row.Id != 3 || row.EpTitle != "Title" {
			t.Fatalf("bad row: %+v", row)
		}
	})
}

// TestPGXEpisodeRepo_ListAll checks that every row returned by Query is
// decoded, in order.
func TestPGXEpisodeRepo_ListAll(t *testing.T) {
	now := time.Now().UTC()
	fp := &fakePool{
		queryFn: func(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
			return &fakeRows{
				rows: [][]any{
					{1, 1, "Pilot", "S1", "10:00:00", "00:24:00", true, now},
					{2, 2, "Next", "S1", "10:30:00", "00:24:00", false, now},
				},
			}, nil
		},
	}
	repo := &pgxEpisodeRepo{pool: fp}
	items, err := repo.ListAll(context.Background())
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	if len(items) != 2 || items[1].EpNum != 2 {
		t.Fatalf("bad items: %+v", items)
	}
}

// TestPGXEpisodeRepo_Create checks that Create issues an INSERT INTO current
// statement with the five input fields as arguments and decodes the returned
// row.
func TestPGXEpisodeRepo_Create(t *testing.T) {
	now := time.Now().UTC()
	var gotSQL string
	var gotArgs []any
	fp := &fakePool{
		queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row {
			gotSQL = sql
			gotArgs = append([]any(nil), args...)
			return fakeRow{values: []any{
				5, 10, "Title", "S1", "12:00:00", "00:24:00", false, now,
			}}
		},
	}
	repo := &pgxEpisodeRepo{pool: fp}
	row, err := repo.Create(context.Background(), episode.NewShowInput{
		EpNum:          10,
		EpTitle:        "Title",
		SeasonName:     "S1",
		StartTime:      "12:00:00",
		PlaybackLength: "00:24:00",
	})
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	if row.Id != 5 || !strings.Contains(gotSQL, "INSERT INTO current") {
		t.Fatalf("bad insert: id=%d sql=%s", row.Id, gotSQL)
	}
	if len(gotArgs) != 5 || gotArgs[0] != 10 || gotArgs[4] != "00:24:00" {
		t.Fatalf("bad args: %+v", gotArgs)
	}
	if row.CurrentEp {
		t.Fatalf("expected current_ep false")
	}
}

// --- fakes ---

// Compile-time checks that the row fakes satisfy the pgx interfaces they are
// returned as above.
var (
	_ pgx.Row  = fakeRow{}
	_ pgx.Rows = (*fakeRows)(nil)
)

// fakePool is a test double for the repo's pool dependency. Each method
// delegates to the matching function field; a nil field makes the call
// return an "unexpected ... call" error so unplanned calls are visible.
type fakePool struct {
	queryFn    func(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
	queryRowFn func(ctx context.Context, sql string, args ...any) pgx.Row
	execFn     func(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
}

// Query delegates to queryFn, or errors if it is unset.
func (f *fakePool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
	if f.queryFn == nil {
		return nil, fmt.Errorf("unexpected Query call")
	}
	return f.queryFn(ctx, sql, args...)
}

// QueryRow delegates to queryRowFn, or returns a row whose Scan errors if it
// is unset.
func (f *fakePool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
	if f.queryRowFn == nil {
		return fakeRow{err: fmt.Errorf("unexpected QueryRow call")}
	}
	return f.queryRowFn(ctx, sql, args...)
}

// Exec delegates to execFn, or errors if it is unset.
func (f *fakePool) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
	if f.execFn == nil {
		return pgconn.CommandTag{}, fmt.Errorf("unexpected Exec call")
	}
	return f.execFn(ctx, sql, args...)
}

// Begin is not supported by the fake and always returns an error.
func (f *fakePool) Begin(ctx context.Context) (pgx.Tx, error) {
	return nil, fmt.Errorf("Begin not implemented")
}

// fakeRow is a single-row pgx.Row: Scan returns err when set, otherwise it
// copies values into the destinations via assignValues.
type fakeRow struct {
	values []any
	err    error
}

func (r fakeRow) Scan(dest ...any) error {
	if r.err != nil {
		return r.err
	}
	return assignValues(dest, r.values)
}

// fakeRows is an in-memory pgx.Rows over rows. idx counts rows consumed by
// Next, so the current row is rows[idx-1]. A non-nil err stops iteration and
// is reported by Err.
type fakeRows struct {
	rows [][]any
	idx  int
	err  error
}

func (r *fakeRows) Close()                                       {}
func (r *fakeRows) Err() error                                   { return r.err }
func (r *fakeRows) CommandTag() pgconn.CommandTag                { return pgconn.CommandTag{} }
func (r *fakeRows) FieldDescriptions() []pgconn.FieldDescription { return nil }

func (r *fakeRows) Next() bool {
	if r.err != nil {
		return false
	}
	if r.idx >= len(r.rows) {
		return false
	}
	r.idx++
	return true
}

func (r *fakeRows) Scan(dest ...any) error {
	if r.idx == 0 {
		return fmt.Errorf("Scan without Next")
	}
	return assignValues(dest, r.rows[r.idx-1])
}

func (r *fakeRows) Values() ([]any, error) {
	if r.idx == 0 || r.idx > len(r.rows) {
		return nil, fmt.Errorf("no row")
	}
	return append([]any(nil), r.rows[r.idx-1]...), nil
}

func (r *fakeRows) RawValues() [][]byte { return nil }
func (r *fakeRows) Conn() *pgx.Conn     { return nil }

// assignValues copies values into the Scan destinations positionally using
// reflection. Rules:
//   - A nil destination skips its column, following pgx's convention.
//   - Every other destination must be a non-nil pointer.
//   - A nil value (a simulated SQL NULL) zeroes a nilable target (pointer,
//     interface, slice, map) and is an error for any other target.
//   - Non-nil values must be assignable to the target without conversion, so
//     a fake whose column types do not match surfaces as an error.
func assignValues(dest []any, values []any) error {
	if len(dest) != len(values) {
		return fmt.Errorf("dest len %d != values len %d", len(dest), len(values))
	}
	for i := range dest {
		if dest[i] == nil {
			continue
		}
		rv := reflect.ValueOf(dest[i])
		if rv.Kind() != reflect.Pointer {
			return fmt.Errorf("dest %d not pointer", i)
		}
		// A typed nil pointer would make Elem invalid and panic below.
		if rv.IsNil() {
			return fmt.Errorf("dest %d is a nil pointer", i)
		}
		target := rv.Elem()
		if values[i] == nil {
			// reflect.ValueOf(nil) has no Type, so handle NULL explicitly.
			switch target.Kind() {
			case reflect.Pointer, reflect.Interface, reflect.Slice, reflect.Map:
				target.Set(reflect.Zero(target.Type()))
				continue
			}
			return fmt.Errorf("value %d is nil, not assignable to %s", i, target.Type())
		}
		val := reflect.ValueOf(values[i])
		if !val.Type().AssignableTo(target.Type()) {
			return fmt.Errorf("value %d type %s not assignable to %s", i, val.Type(), target.Type())
		}
		target.Set(val)
	}
	return nil
}