From 6ac23c24ee75c4e312928f48dad338294e7a36e9 Mon Sep 17 00:00:00 2001 From: Nik Afiq Date: Wed, 3 Dec 2025 20:34:08 +0900 Subject: [PATCH] Add tests for episode repository methods and refactor service interface --- backend/internal/http/handlers_test.go | 96 ++++++++- backend/internal/repo/episode_repo.go | 10 +- backend/internal/repo/episode_repo_test.go | 221 +++++++++++++++++++++ 3 files changed, 320 insertions(+), 7 deletions(-) create mode 100644 backend/internal/repo/episode_repo_test.go diff --git a/backend/internal/http/handlers_test.go b/backend/internal/http/handlers_test.go index 6ee5e80..f534532 100644 --- a/backend/internal/http/handlers_test.go +++ b/backend/internal/http/handlers_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" @@ -24,9 +25,13 @@ type fakeSvc struct { setErr error moveRes episode.MoveResult moveErr error + listRes []episode.Episode + listErr error + deleteErr error lastSetID int64 lastTime string lastMove []int64 + lastDelID int64 } func (f *fakeSvc) GetCurrent(ctx context.Context) (episode.Episode, error) { @@ -41,15 +46,14 @@ func (f *fakeSvc) MoveToArchive(ctx context.Context, ids []int64) (episode.MoveR return f.moveRes, f.moveErr } func (f *fakeSvc) ListAll(ctx context.Context) ([]episode.Episode, error) { - if f.moveErr != nil { - return nil, f.moveErr + if f.listRes == nil && f.listErr == nil { + return []episode.Episode{{Id: 10, EpTitle: "X"}}, nil } - return []episode.Episode{ - {Id: 10, EpTitle: "X"}, - }, nil + return f.listRes, f.listErr } func (f *fakeSvc) Delete(ctx context.Context, id int64) error { - return nil + f.lastDelID = id + return f.deleteErr } // ---- helpers ---- @@ -199,3 +203,83 @@ func TestPostArchive_SingleAndMultiple_OK(t *testing.T) { t.Fatalf("missing counters in response: %+v", body) } } + +func TestListShows_OK(t *testing.T) { + svc := &fakeSvc{ + listRes: []episode.Episode{ + {Id: 1, EpTitle: "Pilot"}, + {Id: 2, EpTitle: "Next"}, + }, + } + r := newRouterWithSvc(svc) + req := httptest.NewRequest(http.MethodGet, "/api/v1/shows", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("got %d: %s", w.Code, w.Body.String()) + } + var got []episode.Episode + if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { + t.Fatalf("json: %v", err) + } + if len(got) != 2 || got[0].EpTitle != "Pilot" { + t.Fatalf("unexpected list: %+v", got) + } +} + +func TestListShows_Error(t *testing.T) { + svc := &fakeSvc{listErr: errors.New("query failed")} + r := newRouterWithSvc(svc) + req := httptest.NewRequest(http.MethodGet, "/api/v1/shows", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", w.Code) + } +} + +func TestDeleteShows_InvalidID(t *testing.T) { + r := newRouterWithSvc(&fakeSvc{}) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/shows?id=abc", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } +} + +func TestDeleteShows_NotFound(t *testing.T) { + svc := &fakeSvc{deleteErr: episode.ErrNotFound} + r := newRouterWithSvc(svc) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/shows?id=99", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d", w.Code) + } +} + +func TestDeleteShows_OtherError(t *testing.T) { + svc := &fakeSvc{deleteErr: errors.New("db down")} + r := newRouterWithSvc(svc) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/shows?id=77", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", w.Code) + } +} + +func TestDeleteShows_OK(t *testing.T) { + svc := &fakeSvc{} + r := newRouterWithSvc(svc) + req := httptest.NewRequest(http.MethodDelete, "/api/v1/shows?id=11", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusNoContent { + t.Fatalf("expected 204, got %d", w.Code) + } + if svc.lastDelID != 11 { + t.Fatalf("expected delete id 11, got %d", svc.lastDelID) + } +} diff --git a/backend/internal/repo/episode_repo.go b/backend/internal/repo/episode_repo.go index 846efe4..f00f45f 100644 --- a/backend/internal/repo/episode_repo.go +++ b/backend/internal/repo/episode_repo.go @@ -7,11 +7,19 @@ import ( "watch-party-backend/internal/core/episode" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" ) +type pgxPool interface { + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) + Begin(ctx context.Context) (pgx.Tx, error) +} + type pgxEpisodeRepo struct { - pool *pgxpool.Pool + pool pgxPool } func NewEpisodeRepo(pool *pgxpool.Pool) episode.Repository { diff --git a/backend/internal/repo/episode_repo_test.go b/backend/internal/repo/episode_repo_test.go new file mode 100644 index 0000000..9fc125a --- /dev/null +++ b/backend/internal/repo/episode_repo_test.go @@ -0,0 +1,221 @@ +package repo + +import ( + "context" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "watch-party-backend/internal/core/episode" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +func TestPGXEpisodeRepo_Delete(t *testing.T) { + t.Run("not found", func(t *testing.T) { + fp := &fakePool{ + execFn: func(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + if sql != `DELETE FROM current WHERE id = $1` { + t.Fatalf("unexpected sql: %s", sql) + } + return pgconn.NewCommandTag("DELETE 0"), nil + }, + } + repo := &pgxEpisodeRepo{pool: fp} + if err := repo.Delete(context.Background(), 10); !errors.Is(err, episode.ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) + + t.Run("ok", func(t *testing.T) { + var gotID int64 + fp := &fakePool{ + execFn: func(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + gotID = args[0].(int64) + return pgconn.NewCommandTag("DELETE 1"), nil + }, + } + repo := &pgxEpisodeRepo{pool: fp} + if err := repo.Delete(context.Background(), 22); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if gotID != 22 { + t.Fatalf("expected id 22, got %d", gotID) + } + }) +} + +func TestPGXEpisodeRepo_GetCurrent(t *testing.T) { + t.Run("not found", func(t *testing.T) { + fp := &fakePool{ + queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { + return fakeRow{err: pgx.ErrNoRows} + }, + } + repo := &pgxEpisodeRepo{pool: fp} + _, err := repo.GetCurrent(context.Background()) + if !errors.Is(err, episode.ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) + + t.Run("ok", func(t *testing.T) { + now := time.Now().UTC() + fp := &fakePool{ + queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { + return fakeRow{values: []any{ + 3, + 5, + "Title", + "S1", + "12:00:00", + "00:24:00", + now, + }} + }, + } + repo := &pgxEpisodeRepo{pool: fp} + row, err := repo.GetCurrent(context.Background()) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if row.Id != 3 || row.EpTitle != "Title" { + t.Fatalf("bad row: %+v", row) + } + }) +} + +func TestPGXEpisodeRepo_ListAll(t *testing.T) { + now := time.Now().UTC() + fp := &fakePool{ + queryFn: func(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + return &fakeRows{ + rows: [][]any{ + {1, 1, "Pilot", "S1", "10:00:00", "00:24:00", now}, + {2, 2, "Next", "S1", "10:30:00", "00:24:00", now}, + }, + }, nil + }, + } + repo := &pgxEpisodeRepo{pool: fp} + items, err := repo.ListAll(context.Background()) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if len(items) != 2 || items[1].EpNum != 2 { + t.Fatalf("bad items: %+v", items) + } +} + +// --- fakes --- + +type fakePool struct { + queryFn func(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + queryRowFn func(ctx context.Context, sql string, args ...any) pgx.Row + execFn func(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) +} + +func (f *fakePool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + if f.queryFn == nil { + return nil, fmt.Errorf("unexpected Query call") + } + return f.queryFn(ctx, sql, args...) +} + +func (f *fakePool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + if f.queryRowFn == nil { + return fakeRow{err: fmt.Errorf("unexpected QueryRow call")} + } + return f.queryRowFn(ctx, sql, args...) +} + +func (f *fakePool) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + if f.execFn == nil { + return pgconn.CommandTag{}, fmt.Errorf("unexpected Exec call") + } + return f.execFn(ctx, sql, args...) +} + +func (f *fakePool) Begin(ctx context.Context) (pgx.Tx, error) { + return nil, fmt.Errorf("Begin not implemented") +} + +type fakeRow struct { + values []any + err error +} + +func (r fakeRow) Scan(dest ...any) error { + if r.err != nil { + return r.err + } + return assignValues(dest, r.values) +} + +type fakeRows struct { + rows [][]any + idx int + err error +} + +func (r *fakeRows) Close() {} + +func (r *fakeRows) Err() error { return r.err } + +func (r *fakeRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } + +func (r *fakeRows) FieldDescriptions() []pgconn.FieldDescription { return nil } + +func (r *fakeRows) Next() bool { + if r.err != nil { + return false + } + if r.idx >= len(r.rows) { + return false + } + r.idx++ + return true +} + +func (r *fakeRows) Scan(dest ...any) error { + if r.idx == 0 { + return fmt.Errorf("Scan without Next") + } + return assignValues(dest, r.rows[r.idx-1]) +} + +func (r *fakeRows) Values() ([]any, error) { + if r.idx == 0 || r.idx > len(r.rows) { + return nil, fmt.Errorf("no row") + } + return append([]any(nil), r.rows[r.idx-1]...), nil +} + +func (r *fakeRows) RawValues() [][]byte { return nil } + +func (r *fakeRows) Conn() *pgx.Conn { return nil } + +func assignValues(dest []any, values []any) error { + if len(dest) != len(values) { + return fmt.Errorf("dest len %d != values len %d", len(dest), len(values)) + } + for i := range dest { + if dest[i] == nil { + continue + } + rv := reflect.ValueOf(dest[i]) + if rv.Kind() != reflect.Pointer { + return fmt.Errorf("dest %d not pointer", i) + } + rv = rv.Elem() + val := reflect.ValueOf(values[i]) + if !val.Type().AssignableTo(rv.Type()) { + return fmt.Errorf("value %d type %s not assignable to %s", i, val.Type(), rv.Type()) + } + rv.Set(val) + } + return nil +}