- Updated the LLMClient interface to support model-specific generation and model listing.
- Integrated the model store and validator into the command application for managing AI models.
- Implemented commands for setting, getting, and listing the active AI model in Discord.
- Enhanced AI query handling to use the selected model and return model information in responses.
- Added a caching mechanism for model validation to improve performance.
- Introduced gRPC methods for listing available AI models in the ai-gateway.
- Updated protobuf definitions to include model-related fields and messages.
- Added tests for the model store and validator.
217 lines
7.1 KiB
Go
217 lines
7.1 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/core/domain"
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/core/ports/driven"
|
|
)
|
|
|
|
// fakeLLM is a scriptable stand-in for the LLM client used by QueryApp.
// Each method delegates to the matching function field, so a test can
// control responses and inspect the model/prompt it receives.
type fakeLLM struct {
	// generate backs Generate; when nil, Generate returns an error.
	generate func(context.Context, string, string) (string, error)
	// listModels backs ListModels; when nil, ListModels returns (nil, nil).
	listModels func(context.Context) ([]string, error)
}

// Generate forwards to f.generate. If no generate func was configured it
// returns an error instead of panicking on a nil function call, matching the
// nil-tolerant behaviour of ListModels.
func (f *fakeLLM) Generate(ctx context.Context, model, prompt string) (string, error) {
	if f.generate == nil {
		return "", errors.New("fakeLLM: generate not configured")
	}
	return f.generate(ctx, model, prompt)
}

// ListModels forwards to f.listModels, or reports no models and no error
// when it is unset.
func (f *fakeLLM) ListModels(ctx context.Context) ([]string, error) {
	if f.listModels == nil {
		return nil, nil
	}
	return f.listModels(ctx)
}
|
|
|
|
type fakeHA struct {
|
|
lights []driven.Light
|
|
listErr error
|
|
turnOnErr error
|
|
turnOffErr error
|
|
lastTurnOnID string
|
|
lastTurnOffID string
|
|
lastTurnParams map[string]string
|
|
listCalls int
|
|
}
|
|
|
|
func (f *fakeHA) TurnOnLight(ctx context.Context, entity string, params map[string]string) error {
|
|
f.lastTurnOnID = entity
|
|
f.lastTurnParams = params
|
|
return f.turnOnErr
|
|
}
|
|
|
|
func (f *fakeHA) TurnOffLight(ctx context.Context, entity string) error {
|
|
f.lastTurnOffID = entity
|
|
return f.turnOffErr
|
|
}
|
|
|
|
func (f *fakeHA) ListLights(ctx context.Context) ([]driven.Light, error) {
|
|
f.listCalls++
|
|
if f.listErr != nil {
|
|
return nil, f.listErr
|
|
}
|
|
return append([]driven.Light(nil), f.lights...), nil
|
|
}
|
|
|
|
func TestQueryAppTurnOnLight(t *testing.T) {
|
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
|
cache := domain.NewLightCache(time.Hour, ha.ListLights)
|
|
app := NewQueryApp(&fakeLLM{
|
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
|
if model != "llama3:latest" {
|
|
t.Fatalf("Generate() model = %q", model)
|
|
}
|
|
return `{"intent":"turn_on_light","entity":"Kitchen","params":{"brightness":"80"},"reply":"Turning on Kitchen."}`, nil
|
|
},
|
|
}, ha, cache, "llama3:latest", slog.Default())
|
|
|
|
got, err := app.Query(context.Background(), "turn on kitchen", "")
|
|
if err != nil {
|
|
t.Fatalf("Query() error = %v", err)
|
|
}
|
|
if got.Intent != domain.IntentTurnOnLight || !got.ActionTaken || got.Reply != "Turning on Kitchen." {
|
|
t.Fatalf("Query() = %+v", got)
|
|
}
|
|
if ha.lastTurnOnID != "light.kitchen" {
|
|
t.Fatalf("TurnOnLight entity = %q", ha.lastTurnOnID)
|
|
}
|
|
if !reflect.DeepEqual(ha.lastTurnParams, map[string]string{"brightness": "80"}) {
|
|
t.Fatalf("TurnOnLight params = %#v", ha.lastTurnParams)
|
|
}
|
|
}
|
|
|
|
func TestQueryAppInvalidJSON(t *testing.T) {
|
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
|
app := NewQueryApp(&fakeLLM{
|
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
|
return `not-json`, nil
|
|
},
|
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
|
|
|
got, err := app.Query(context.Background(), "turn on kitchen", "")
|
|
if err != nil {
|
|
t.Fatalf("Query() error = %v", err)
|
|
}
|
|
if got.Reply != "I didn't understand that." || got.ActionTaken {
|
|
t.Fatalf("Query() = %+v", got)
|
|
}
|
|
if ha.lastTurnOnID != "" {
|
|
t.Fatalf("expected no HA call, got %q", ha.lastTurnOnID)
|
|
}
|
|
}
|
|
|
|
func TestQueryAppIntentNone(t *testing.T) {
|
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
|
app := NewQueryApp(&fakeLLM{
|
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
|
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
|
|
},
|
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
|
|
|
got, err := app.Query(context.Background(), "hello", "")
|
|
if err != nil {
|
|
t.Fatalf("Query() error = %v", err)
|
|
}
|
|
if got.Reply != "Hello there." || got.ActionTaken {
|
|
t.Fatalf("Query() = %+v", got)
|
|
}
|
|
}
|
|
|
|
func TestQueryAppHAFailure(t *testing.T) {
|
|
ha := &fakeHA{
|
|
lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}},
|
|
turnOnErr: errors.New("boom"),
|
|
}
|
|
app := NewQueryApp(&fakeLLM{
|
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
|
return `{"intent":"turn_on_light","entity":"light.kitchen","params":{},"reply":"Turning on Kitchen."}`, nil
|
|
},
|
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
|
|
|
got, err := app.Query(context.Background(), "turn on kitchen", "")
|
|
if err != nil {
|
|
t.Fatalf("Query() error = %v", err)
|
|
}
|
|
if got.Reply != "I couldn't reach Home Assistant right now." || got.ActionTaken {
|
|
t.Fatalf("Query() = %+v", got)
|
|
}
|
|
}
|
|
|
|
func TestQueryAppListLights(t *testing.T) {
|
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "on"}}}
|
|
app := NewQueryApp(&fakeLLM{
|
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
|
return `{"intent":"list_lights","entity":"","params":{},"reply":""}`, nil
|
|
},
|
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
|
|
|
got, err := app.Query(context.Background(), "what lights exist", "")
|
|
if err != nil {
|
|
t.Fatalf("Query() error = %v", err)
|
|
}
|
|
want := "Known lights:\n- Kitchen (light.kitchen) [on]"
|
|
if got.Reply != want || got.ActionTaken {
|
|
t.Fatalf("Query() = %+v", got)
|
|
}
|
|
}
|
|
|
|
func TestLightCacheRefreshAfterTTL(t *testing.T) {
|
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
|
cache := domain.NewLightCache(10*time.Millisecond, ha.ListLights)
|
|
|
|
if _, err := cache.Get(context.Background()); err != nil {
|
|
t.Fatalf("Get() error = %v", err)
|
|
}
|
|
time.Sleep(20 * time.Millisecond)
|
|
if _, err := cache.Get(context.Background()); err != nil {
|
|
t.Fatalf("Get() error = %v", err)
|
|
}
|
|
if ha.listCalls < 2 {
|
|
t.Fatalf("ListLights calls = %d, want at least 2", ha.listCalls)
|
|
}
|
|
}
|
|
|
|
func TestQueryAppExplicitModel(t *testing.T) {
|
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
|
app := NewQueryApp(&fakeLLM{
|
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
|
if model != "qwen3:latest" {
|
|
t.Fatalf("Generate() model = %q", model)
|
|
}
|
|
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
|
|
},
|
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
|
|
|
got, err := app.Query(context.Background(), "hello", "qwen3:latest")
|
|
if err != nil {
|
|
t.Fatalf("Query() error = %v", err)
|
|
}
|
|
if got.ModelUsed != "qwen3:latest" {
|
|
t.Fatalf("ModelUsed = %q", got.ModelUsed)
|
|
}
|
|
}
|
|
|
|
func TestQueryAppListModels(t *testing.T) {
|
|
app := NewQueryApp(&fakeLLM{
|
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
|
return "", nil
|
|
},
|
|
listModels: func(ctx context.Context) ([]string, error) {
|
|
return []string{"llama3:latest", "qwen3:latest"}, nil
|
|
},
|
|
}, &fakeHA{}, domain.NewLightCache(time.Hour, func(context.Context) ([]driven.Light, error) { return nil, nil }), "llama3:latest", slog.Default())
|
|
|
|
got, err := app.ListModels(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("ListModels() error = %v", err)
|
|
}
|
|
if !reflect.DeepEqual(got, []string{"llama3:latest", "qwen3:latest"}) {
|
|
t.Fatalf("ListModels() = %#v", got)
|
|
}
|
|
}
|