package grpc

import (
	"context"
	"errors"
	"net"
	"reflect"
	"testing"
	"time"

	hav1 "gitea.nik4nao.com/nik/home-services/gen/ha/v1"
	"gitea.nik4nao.com/nik/home-services/ha-gateway/internal/core/domain"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
	"google.golang.org/grpc/test/bufconn"
)

// mockEntityService is a test double for the service passed to NewEntityGRPC.
// Each method delegates to the matching func field; when that field is nil the
// method returns (nil, nil).
type mockEntityService struct {
	getStateFunc   func(ctx context.Context, id domain.EntityID) (*domain.EntityState, error)
	listStatesFunc func(ctx context.Context, ids []domain.EntityID, domainFilter string) ([]*domain.EntityState, error)
}

// GetState delegates to getStateFunc, or returns (nil, nil) if it is unset.
func (m *mockEntityService) GetState(ctx context.Context, id domain.EntityID) (*domain.EntityState, error) {
	if m.getStateFunc == nil {
		return nil, nil
	}
	return m.getStateFunc(ctx, id)
}

// ListStates delegates to listStatesFunc, or returns (nil, nil) if it is unset.
func (m *mockEntityService) ListStates(ctx context.Context, ids []domain.EntityID, domainFilter string) ([]*domain.EntityState, error) {
	if m.listStatesFunc == nil {
		return nil, nil
	}
	return m.listStatesFunc(ctx, ids, domainFilter)
}

// TestEntityGRPCGetState checks that the GetState RPC forwards the requested
// entity ID to the service and maps the service's result to the expected gRPC
// status code.
func TestEntityGRPCGetState(t *testing.T) {
	now := time.Date(2026, 4, 9, 10, 0, 0, 0, time.UTC)

	tests := []struct {
		name     string
		err      error
		wantCode codes.Code
	}{
		{name: "happy path", wantCode: codes.OK},
		{name: "not found maps to codes.NotFound", err: ErrNotFound, wantCode: codes.NotFound},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var gotID domain.EntityID
			conn := newEntityTestClientConn(t, &mockEntityService{
				getStateFunc: func(ctx context.Context, id domain.EntityID) (*domain.EntityState, error) {
					gotID = id
					if tt.err != nil {
						return nil, tt.err
					}
					return &domain.EntityState{
						EntityID:    "light.kitchen",
						State:       "on",
						Attributes:  map[string]string{"friendly_name": "Kitchen"},
						LastChanged: now,
						LastUpdated: now,
					}, nil
				},
			})

			client := hav1.NewEntityServiceClient(conn)
			resp, err := client.GetState(context.Background(), &hav1.GetStateRequest{EntityId: "light.kitchen"})
			if status.Code(err) != tt.wantCode {
				t.Fatalf("status code = %v, want %v (err = %v)", status.Code(err), tt.wantCode, err)
			}

			// Verify the forwarded ID before branching on the code, so the
			// error case also proves the mock service was actually invoked
			// (i.e. the status did not come from an earlier rejection).
			if gotID != "light.kitchen" {
				t.Fatalf("GetState id = %q, want %q", gotID, "light.kitchen")
			}
			if tt.wantCode != codes.OK {
				return
			}
			if got := resp.GetState().GetEntityId(); got != "light.kitchen" {
				t.Fatalf("response entity id = %q, want %q (state = %#v)", got, "light.kitchen", resp.GetState())
			}
		})
	}
}

// TestEntityGRPCListStates checks that the ListStates RPC forwards the entity
// IDs and domain filter to the service, returns every state it produces, and
// maps a generic service error to codes.Internal.
func TestEntityGRPCListStates(t *testing.T) {
	now := time.Date(2026, 4, 9, 10, 0, 0, 0, time.UTC)

	tests := []struct {
		name     string
		err      error
		wantCode codes.Code
	}{
		{name: "happy path", wantCode: codes.OK},
		{name: "error maps to codes.Internal", err: errors.New("boom"), wantCode: codes.Internal},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var gotIDs []domain.EntityID
			var gotDomain string
			conn := newEntityTestClientConn(t, &mockEntityService{
				listStatesFunc: func(ctx context.Context, ids []domain.EntityID, domainFilter string) ([]*domain.EntityState, error) {
					// Copy so later mutation of ids by the caller cannot
					// affect what the test asserts on.
					gotIDs = append([]domain.EntityID(nil), ids...)
					gotDomain = domainFilter
					if tt.err != nil {
						return nil, tt.err
					}
					return []*domain.EntityState{
						{
							EntityID:    "light.kitchen",
							State:       "on",
							Attributes:  map[string]string{"friendly_name": "Kitchen"},
							LastChanged: now,
							LastUpdated: now,
						},
						{
							EntityID:    "light.desk",
							State:       "off",
							Attributes:  map[string]string{"friendly_name": "Desk"},
							LastChanged: now,
							LastUpdated: now,
						},
					}, nil
				},
			})

			client := hav1.NewEntityServiceClient(conn)
			resp, err := client.ListStates(context.Background(), &hav1.ListStatesRequest{
				EntityIds: []string{"light.kitchen", "light.desk"},
				Domain:    "light",
			})
			if status.Code(err) != tt.wantCode {
				t.Fatalf("status code = %v, want %v (err = %v)", status.Code(err), tt.wantCode, err)
			}

			// Verify the forwarded arguments before branching on the code, so
			// the error case also proves the mock service was actually invoked.
			wantIDs := []domain.EntityID{"light.kitchen", "light.desk"}
			if !reflect.DeepEqual(gotIDs, wantIDs) {
				t.Fatalf("ListStates ids = %#v, want %#v", gotIDs, wantIDs)
			}
			if gotDomain != "light" {
				t.Fatalf("ListStates domain = %q, want %q", gotDomain, "light")
			}
			if tt.wantCode != codes.OK {
				return
			}
			if len(resp.GetStates()) != 2 {
				t.Fatalf("len(states) = %d, want 2", len(resp.GetStates()))
			}
		})
	}
}

// newEntityTestClientConn starts an in-memory (bufconn) gRPC server that
// serves NewEntityGRPC(svc) and returns a client connection to it. The
// connection, server and listener are all torn down through t.Cleanup.
func newEntityTestClientConn(t *testing.T, svc *mockEntityService) *grpc.ClientConn {
	t.Helper()

	lis := bufconn.Listen(testBufSize)
	server := grpc.NewServer()
	hav1.RegisterEntityServiceServer(server, NewEntityGRPC(svc))
	go func() {
		// Serve blocks until server.Stop runs in cleanup. Its error is
		// deliberately ignored: reporting via t from this goroutine could
		// happen after the test has already finished.
		_ = server.Serve(lis)
	}()
	t.Cleanup(func() {
		server.Stop()
		_ = lis.Close()
	})

	// grpc.NewClient resolves targets with the "dns" scheme by default; the
	// "passthrough" scheme hands "bufnet" straight to the context dialer,
	// which ignores the address and dials the in-memory listener.
	conn, err := grpc.NewClient(
		"passthrough:///bufnet",
		grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
			return lis.DialContext(ctx)
		}),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		t.Fatalf("grpc.NewClient() error = %v", err)
	}
	// Cleanups run LIFO, so the client closes before the server stops.
	t.Cleanup(func() { _ = conn.Close() })
	return conn
}