From ad50d641bd0f852d7c324a36a90dfd88fe7ab06a Mon Sep 17 00:00:00 2001 From: Nik Afiq Date: Tue, 21 Apr 2026 22:52:00 +0900 Subject: [PATCH] feat: enhance AI model management in Discord bot - Updated LLMClient interface to support model-specific generation and model listing. - Integrated model store and validator into the command application for managing AI models. - Implemented commands for setting, getting, and listing active AI models in Discord. - Enhanced AI query handling to utilize the selected model and return model information in responses. - Added caching mechanism for model validation to improve performance. - Introduced gRPC methods for listing available AI models in the ai-gateway. - Updated protobuf definitions to include model-related fields and messages. - Added tests for model store and validator functionalities. --- ai-gateway/cmd/gateway/main.go | 4 +- .../internal/adapters/primary/grpc/server.go | 12 +- .../adapters/secondary/ollama/client.go | 62 ++++++++- .../adapters/secondary/ollama/client_test.go | 81 +++++++++++ ai-gateway/internal/app/query.go | 39 ++++-- ai-gateway/internal/app/query_test.go | 86 +++++++++--- ai-gateway/internal/core/ports/driven/llm.go | 3 +- discord-bot/cmd/bot/main.go | 6 +- .../adapters/primary/discord/handler.go | 85 +++++++++++- .../adapters/primary/discord/register.go | 31 +++++ .../adapters/secondary/aigateway/client.go | 23 +++- discord-bot/internal/app/command.go | 77 ++++++++++- discord-bot/internal/app/command_test.go | 66 +++++++-- discord-bot/internal/core/ports/driven/ai.go | 3 +- discord-bot/internal/modelstore/store.go | 28 ++++ discord-bot/internal/modelstore/store_test.go | 15 ++ .../internal/modelvalidator/validator.go | 86 ++++++++++++ .../internal/modelvalidator/validator_test.go | 70 ++++++++++ gen/ai/v1/ai.pb.go | 130 ++++++++++++++++-- gen/ai/v1/ai_grpc.pb.go | 40 +++++- proto/ai/v1/ai.proto | 9 ++ 21 files changed, 876 insertions(+), 80 deletions(-) create mode 100644 ai-gateway/internal/adapters/secondary/ollama/client_test.go create mode 100644 discord-bot/internal/modelstore/store.go create mode 100644 discord-bot/internal/modelstore/store_test.go create mode 100644 discord-bot/internal/modelvalidator/validator.go create mode 100644 discord-bot/internal/modelvalidator/validator_test.go diff --git a/ai-gateway/cmd/gateway/main.go b/ai-gateway/cmd/gateway/main.go index 357c964..1568370 100644 --- a/ai-gateway/cmd/gateway/main.go +++ b/ai-gateway/cmd/gateway/main.go @@ -67,7 +67,7 @@ func main() { os.Exit(1) } - ollamaClient := ollama.New(cfg.OllamaURL, cfg.OllamaModel, &http.Client{Timeout: cfg.OllamaTimeout}) + ollamaClient := ollama.New(cfg.OllamaURL, &http.Client{Timeout: cfg.OllamaTimeout}) haClient, err := hagateway.New(ctx, cfg.HAGatewayAddr, cfg.TLSDir, cfg.HAGatewayServerName, log) if err != nil { log.Error("ha-gateway client setup failed", "err", err) @@ -80,7 +80,7 @@ func main() { }() lightCache := domain.NewLightCache(cfg.LightCacheTTL, haClient.ListLights) - queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, log) + queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, cfg.OllamaModel, log) serverOpts := []grpc.ServerOption{ grpc.StatsHandler(otelgrpc.NewServerHandler()), diff --git a/ai-gateway/internal/adapters/primary/grpc/server.go b/ai-gateway/internal/adapters/primary/grpc/server.go index 147ce36..0417f2e 100644 --- a/ai-gateway/internal/adapters/primary/grpc/server.go +++ b/ai-gateway/internal/adapters/primary/grpc/server.go @@ -33,7 +33,7 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query ctx = logger.WithLogger(ctx, log) } - result, err := s.app.Query(ctx, req.GetText()) + result, err := s.app.Query(ctx, req.GetText(), req.GetModel()) if err != nil { return nil, status.Errorf(codes.Unavailable, "query failed: %v", err) } @@ -41,5 +41,15 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query Reply: result.Reply, Intent: result.Intent, ActionTaken: result.ActionTaken, + ModelUsed: result.ModelUsed, }, nil } + +// ListModels returns the installed model names from Ollama. +func (s *Server) ListModels(ctx context.Context, _ *aiv1.ListModelsRequest) (*aiv1.ListModelsResponse, error) { + names, err := s.app.ListModels(ctx) + if err != nil { + return nil, status.Errorf(codes.Unavailable, "list models: %v", err) + } + return &aiv1.ListModelsResponse{Names: names}, nil +} diff --git a/ai-gateway/internal/adapters/secondary/ollama/client.go b/ai-gateway/internal/adapters/secondary/ollama/client.go index 35a9b08..25e7ae9 100644 --- a/ai-gateway/internal/adapters/secondary/ollama/client.go +++ b/ai-gateway/internal/adapters/secondary/ollama/client.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) @@ -14,37 +15,49 @@ type generateRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` Stream bool `json:"stream"` + Think *bool `json:"think,omitempty"` } type generateResponse struct { Response string `json:"response"` } +type listModelsResponse struct { + Models []struct { + Name string `json:"name"` + } `json:"models"` +} + // Client implements the LLM driven port with the Ollama generate API. type Client struct { baseURL string - model string http *http.Client } // New constructs an Ollama client with OTel-instrumented transport. -func New(baseURL, model string, httpClient *http.Client) *Client { +func New(baseURL string, httpClient *http.Client) *Client { if httpClient == nil { httpClient = &http.Client{Transport: otelhttp.NewTransport(http.DefaultTransport)} } if httpClient.Transport == nil { httpClient.Transport = otelhttp.NewTransport(http.DefaultTransport) } - return &Client{baseURL: baseURL, model: model, http: httpClient} + return &Client{baseURL: baseURL, http: httpClient} } // Generate sends one non-streaming prompt to Ollama. -func (c *Client) Generate(ctx context.Context, prompt string) (string, error) { - body, err := json.Marshal(generateRequest{ - Model: c.model, +func (c *Client) Generate(ctx context.Context, model, prompt string) (string, error) { + reqBody := generateRequest{ + Model: model, Prompt: prompt, Stream: false, - }) + } + if isThinkingModel(model) { + disabled := false + reqBody.Think = &disabled + } + + body, err := json.Marshal(reqBody) if err != nil { return "", fmt.Errorf("marshal ollama request: %w", err) } @@ -71,3 +84,38 @@ func (c *Client) Generate(ctx context.Context, prompt string) (string, error) { } return out.Response, nil } + +// ListModels returns the installed model names from Ollama. +func (c *Client) ListModels(ctx context.Context) ([]string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/tags", nil) + if err != nil { + return nil, fmt.Errorf("build ollama list models request: %w", err) + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("list ollama models: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("ollama returned status %s", resp.Status) + } + + var out listModelsResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, fmt.Errorf("decode ollama list models response: %w", err) + } + + names := make([]string, 0, len(out.Models)) + for _, model := range out.Models { + if model.Name != "" { + names = append(names, model.Name) + } + } + return names, nil +} + +func isThinkingModel(name string) bool { + return strings.HasPrefix(name, "qwen3") +} diff --git a/ai-gateway/internal/adapters/secondary/ollama/client_test.go b/ai-gateway/internal/adapters/secondary/ollama/client_test.go new file mode 100644 index 0000000..1813bea --- /dev/null +++ b/ai-gateway/internal/adapters/secondary/ollama/client_test.go @@ -0,0 +1,81 @@ +package ollama + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestIsThinkingModel(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {name: "qwen3", want: true}, + {name: "qwen3:4b", want: true}, + {name: "qwen3:latest", want: true}, + {name: "llama3", want: false}, + {name: "mistral", want: false}, + } + for _, tt := range tests { + if got := isThinkingModel(tt.name); got != tt.want { + t.Fatalf("isThinkingModel(%q) = %v, want %v", tt.name, got, tt.want) + } + } +} + +func TestGenerateSetsThinkForQwen3(t *testing.T) { + var body map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("Decode() error = %v", err) + } + _, _ = w.Write([]byte(`{"response":"ok"}`)) + })) + defer srv.Close() + + client := New(srv.URL, srv.Client()) + if _, err := client.Generate(context.Background(), "qwen3:latest", "prompt"); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if body["think"] != false { + t.Fatalf("think = %#v, want false", body["think"]) + } +} + +func TestGenerateOmitsThinkForLlama3(t *testing.T) { + var body map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("Decode() error = %v", err) + } + _, _ = w.Write([]byte(`{"response":"ok"}`)) + })) + defer srv.Close() + + client := New(srv.URL, srv.Client()) + if _, err := client.Generate(context.Background(), "llama3:latest", "prompt"); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if _, ok := body["think"]; ok { + t.Fatalf("unexpected think key = %#v", body["think"]) + } +} + +func TestListModelsReturnsNamesOnly(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"models":[{"name":"llama3:latest"},{"name":"qwen3:latest"}]}`)) + })) + defer srv.Close() + + client := New(srv.URL, srv.Client()) + got, err := client.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels() error = %v", err) + } + if len(got) != 2 || got[0] != "llama3:latest" || got[1] != "qwen3:latest" { + t.Fatalf("ListModels() = %#v", got) + } +} diff --git a/ai-gateway/internal/app/query.go b/ai-gateway/internal/app/query.go index c5ae649..976a6fa 100644 --- a/ai-gateway/internal/app/query.go +++ b/ai-gateway/internal/app/query.go @@ -18,34 +18,40 @@ type QueryResult struct { Reply string Intent string ActionTaken bool + ModelUsed string } // QueryApp orchestrates one AI query request. type QueryApp struct { - llm driven.LLMClient - ha driven.HAClient - cache *domain.LightCache - log *slog.Logger + llm driven.LLMClient + ha driven.HAClient + cache *domain.LightCache + defaultModel string + log *slog.Logger } // NewQueryApp constructs the AI query application service. -func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, log *slog.Logger) *QueryApp { - return &QueryApp{llm: llm, ha: ha, cache: cache, log: log} +func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, defaultModel string, log *slog.Logger) *QueryApp { + return &QueryApp{llm: llm, ha: ha, cache: cache, defaultModel: defaultModel, log: log} } // Query runs the full intent parsing and dispatch flow for one user request. -func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) { +func (a *QueryApp) Query(ctx context.Context, text, model string) (QueryResult, error) { + if model == "" { + model = a.defaultModel + } lights, err := a.cache.Get(ctx) if err != nil { a.log.Error("light cache refresh failed", "err", err) return QueryResult{ Reply: "I couldn't reach Home Assistant right now.", ActionTaken: false, + ModelUsed: model, }, nil } prompt := domain.BuildPrompt(text, promptLightLines(lights)) - raw, err := a.llm.Generate(ctx, prompt) + raw, err := a.llm.Generate(ctx, model, prompt) if err != nil { return QueryResult{}, err } @@ -57,6 +63,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) Reply: "I didn't understand that.", Intent: domain.IntentNone, ActionTaken: false, + ModelUsed: model, }, nil } @@ -64,7 +71,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) case domain.IntentTurnOnLight: entityID, ok := resolveLightEntity(intent.Entity, lights) if !ok { - return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil + return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil } params, err := ParseLightParams(intent.Params) if err != nil { @@ -72,6 +79,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) Reply: "I couldn't understand the light settings.", Intent: intent.Name, ActionTaken: false, + ModelUsed: model, }, nil } if err := a.ha.TurnOnLight(ctx, entityID, params); err != nil { @@ -80,17 +88,19 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) Reply: "I couldn't reach Home Assistant right now.", Intent: intent.Name, ActionTaken: false, + ModelUsed: model, }, nil } return QueryResult{ Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned on `%s`.", displayLightName(entityID, lights))), Intent: intent.Name, ActionTaken: true, + ModelUsed: model, }, nil case domain.IntentTurnOffLight: entityID, ok := resolveLightEntity(intent.Entity, lights) if !ok { - return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil + return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil } if err := a.ha.TurnOffLight(ctx, entityID); err != nil { a.log.Error("turn off light failed", "entity_id", entityID, "err", err) @@ -98,18 +108,21 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) Reply: "I couldn't reach Home Assistant right now.", Intent: intent.Name, ActionTaken: false, + ModelUsed: model, }, nil } return QueryResult{ Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned off `%s`.", displayLightName(entityID, lights))), Intent: intent.Name, ActionTaken: true, + ModelUsed: model, }, nil case domain.IntentListLights: return QueryResult{ Reply: formatLightListReply(lights), Intent: intent.Name, ActionTaken: false, + ModelUsed: model, }, nil case domain.IntentNone: fallthrough @@ -118,10 +131,16 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) Reply: fallbackReply(intent.Reply, "I didn't understand that."), Intent: intent.Name, ActionTaken: false, + ModelUsed: model, }, nil } } +// ListModels returns the currently installed model names. +func (a *QueryApp) ListModels(ctx context.Context) ([]string, error) { + return a.llm.ListModels(ctx) +} + func promptLightLines(lights []driven.Light) []string { lines := make([]string, 0, len(lights)) for _, light := range lights { diff --git a/ai-gateway/internal/app/query_test.go b/ai-gateway/internal/app/query_test.go index f4312dc..21ff6a6 100644 --- a/ai-gateway/internal/app/query_test.go +++ b/ai-gateway/internal/app/query_test.go @@ -13,11 +13,19 @@ import ( ) type fakeLLM struct { - generate func(context.Context, string) (string, error) + generate func(context.Context, string, string) (string, error) + listModels func(context.Context) ([]string, error) } -func (f *fakeLLM) Generate(ctx context.Context, prompt string) (string, error) { - return f.generate(ctx, prompt) +func (f *fakeLLM) Generate(ctx context.Context, model, prompt string) (string, error) { + return f.generate(ctx, model, prompt) +} + +func (f *fakeLLM) ListModels(ctx context.Context) ([]string, error) { + if f.listModels == nil { + return nil, nil + } + return f.listModels(ctx) } type fakeHA struct { @@ -54,12 +62,15 @@ func TestQueryAppTurnOnLight(t *testing.T) { ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}} cache := domain.NewLightCache(time.Hour, ha.ListLights) app := NewQueryApp(&fakeLLM{ - generate: func(ctx context.Context, prompt string) (string, error) { + generate: func(ctx context.Context, model, prompt string) (string, error) { + if model != "llama3:latest" { + t.Fatalf("Generate() model = %q", model) + } return `{"intent":"turn_on_light","entity":"Kitchen","params":{"brightness":"80"},"reply":"Turning on Kitchen."}`, nil }, - }, ha, cache, slog.Default()) + }, ha, cache, "llama3:latest", slog.Default()) - got, err := app.Query(context.Background(), "turn on kitchen") + got, err := app.Query(context.Background(), "turn on kitchen", "") if err != nil { t.Fatalf("Query() error = %v", err) } @@ -77,12 +88,12 @@ func TestQueryAppTurnOnLight(t *testing.T) { func TestQueryAppInvalidJSON(t *testing.T) { ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}} app := NewQueryApp(&fakeLLM{ - generate: func(ctx context.Context, prompt string) (string, error) { + generate: func(ctx context.Context, model, prompt string) (string, error) { return `not-json`, nil }, - }, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default()) + }, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default()) - got, err := app.Query(context.Background(), "turn on kitchen") + got, err := app.Query(context.Background(), "turn on kitchen", "") if err != nil { t.Fatalf("Query() error = %v", err) } @@ -97,12 +108,12 @@ func TestQueryAppInvalidJSON(t *testing.T) { func TestQueryAppIntentNone(t *testing.T) { ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}} app := NewQueryApp(&fakeLLM{ - generate: func(ctx context.Context, prompt string) (string, error) { + generate: func(ctx context.Context, model, prompt string) (string, error) { return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil }, - }, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default()) + }, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default()) - got, err := app.Query(context.Background(), "hello") + got, err := app.Query(context.Background(), "hello", "") if err != nil { t.Fatalf("Query() error = %v", err) } @@ -117,12 +128,12 @@ func TestQueryAppHAFailure(t *testing.T) { turnOnErr: errors.New("boom"), } app := NewQueryApp(&fakeLLM{ - generate: func(ctx context.Context, prompt string) (string, error) { + generate: func(ctx context.Context, model, prompt string) (string, error) { return `{"intent":"turn_on_light","entity":"light.kitchen","params":{},"reply":"Turning on Kitchen."}`, nil }, - }, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default()) + }, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default()) - got, err := app.Query(context.Background(), "turn on kitchen") + got, err := app.Query(context.Background(), "turn on kitchen", "") if err != nil { t.Fatalf("Query() error = %v", err) } @@ -134,12 +145,12 @@ func TestQueryAppHAFailure(t *testing.T) { func TestQueryAppListLights(t *testing.T) { ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "on"}}} app := NewQueryApp(&fakeLLM{ - generate: func(ctx context.Context, prompt string) (string, error) { + generate: func(ctx context.Context, model, prompt string) (string, error) { return `{"intent":"list_lights","entity":"","params":{},"reply":""}`, nil }, - }, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default()) + }, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default()) - got, err := app.Query(context.Background(), "what lights exist") + got, err := app.Query(context.Background(), "what lights exist", "") if err != nil { t.Fatalf("Query() error = %v", err) } @@ -164,3 +175,42 @@ func TestLightCacheRefreshAfterTTL(t *testing.T) { t.Fatalf("ListLights calls = %d, want at least 2", ha.listCalls) } } + +func TestQueryAppExplicitModel(t *testing.T) { + ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}} + app := NewQueryApp(&fakeLLM{ + generate: func(ctx context.Context, model, prompt string) (string, error) { + if model != "qwen3:latest" { + t.Fatalf("Generate() model = %q", model) + } + return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil + }, + }, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default()) + + got, err := app.Query(context.Background(), "hello", "qwen3:latest") + if err != nil { + t.Fatalf("Query() error = %v", err) + } + if got.ModelUsed != "qwen3:latest" { + t.Fatalf("ModelUsed = %q", got.ModelUsed) + } +} + +func TestQueryAppListModels(t *testing.T) { + app := NewQueryApp(&fakeLLM{ + generate: func(ctx context.Context, model, prompt string) (string, error) { + return "", nil + }, + listModels: func(ctx context.Context) ([]string, error) { + return []string{"llama3:latest", "qwen3:latest"}, nil + }, + }, &fakeHA{}, domain.NewLightCache(time.Hour, func(context.Context) ([]driven.Light, error) { return nil, nil }), "llama3:latest", slog.Default()) + + got, err := app.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels() error = %v", err) + } + if !reflect.DeepEqual(got, []string{"llama3:latest", "qwen3:latest"}) { + t.Fatalf("ListModels() = %#v", got) + } +} diff --git a/ai-gateway/internal/core/ports/driven/llm.go b/ai-gateway/internal/core/ports/driven/llm.go index 4a5ad06..1e53ffc 100644 --- a/ai-gateway/internal/core/ports/driven/llm.go +++ b/ai-gateway/internal/core/ports/driven/llm.go @@ -4,5 +4,6 @@ import "context" // LLMClient generates one model response for a prompt. type LLMClient interface { - Generate(ctx context.Context, prompt string) (string, error) + Generate(ctx context.Context, model, prompt string) (string, error) + ListModels(ctx context.Context) ([]string, error) } diff --git a/discord-bot/cmd/bot/main.go b/discord-bot/cmd/bot/main.go index b548d6e..55b959f 100644 --- a/discord-bot/cmd/bot/main.go +++ b/discord-bot/cmd/bot/main.go @@ -17,6 +17,8 @@ import ( "gitea.nik4nao.com/nik/home-services/discord-bot/internal/app" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/config" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/logger" + "gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore" + "gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/telemetry" ) @@ -84,7 +86,9 @@ func main() { } }() - commandApp := app.NewCommandApp(haClient, aiClient) + modelStore := modelstore.New() + validator := modelvalidator.New(aiClient, 30*time.Second) + commandApp := app.NewCommandApp(haClient, aiClient, modelStore, validator) // Discord-specific wiring stays at the edge so the app layer remains transport-agnostic. session, err := discordgo.New("Bot " + cfg.DiscordToken) diff --git a/discord-bot/internal/adapters/primary/discord/handler.go b/discord-bot/internal/adapters/primary/discord/handler.go index f6180ce..5d6bf9e 100644 --- a/discord-bot/internal/adapters/primary/discord/handler.go +++ b/discord-bot/internal/adapters/primary/discord/handler.go @@ -24,8 +24,12 @@ type commandHandler interface { HandleLightToggle(ctx context.Context, entityID string) (string, error) HandleSwitchList(ctx context.Context) (string, error) HandleAIQuery(ctx context.Context, text string) (string, error) + HandleAIModelSet(ctx context.Context, name string) (string, error) + HandleAIModelGet(ctx context.Context) (string, error) + HandleAIModelList(ctx context.Context) (string, error) AutocompleteLights(ctx context.Context) ([]apppkg.Choice, error) AutocompleteSwitches(ctx context.Context) ([]apppkg.Choice, error) + AutocompleteAIModels(ctx context.Context) ([]apppkg.Choice, error) } // Handler adapts Discord interactions to the command application layer. @@ -69,6 +73,9 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses command := data.Name if len(data.Options) > 0 { command += "." + data.Options[0].Name + if len(data.Options[0].Options) > 0 && data.Options[0].Type == discordgo.ApplicationCommandOptionSubCommandGroup { + command += "." + data.Options[0].Options[0].Name + } } user := "" if i.Member != nil && i.Member.User != nil { @@ -90,7 +97,18 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses } sub := data.Options[0] - switch data.Name + "." + sub.Name { + commandPath := data.Name + "." + sub.Name + target := sub + if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup { + if len(sub.Options) == 0 { + h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("missing grouped subcommand")) + return + } + target = sub.Options[0] + commandPath += "." + target.Name + } + + switch commandPath { case "light.list": msg, err := h.app.HandleLightList(ctx) if err != nil { @@ -145,10 +163,38 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses ) return } - msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(sub, "text")) + msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(target, "text")) + h.followup(ctx, s, i.Interaction, msg, true, start, err) + case "ai.model.set": + if err := h.deferResponse(s, i.Interaction, true); err != nil { + log.Error("discord response failed", + "duration_ms", time.Since(start).Milliseconds(), + "error", err.Error(), + ) + return + } + msg, err := h.app.HandleAIModelSet(ctx, requiredStringOption(target, "name")) + h.followup(ctx, s, i.Interaction, msg, true, start, err) + case "ai.model.get": + msg, err := h.app.HandleAIModelGet(ctx) + if err != nil { + h.respondError(ctx, s, i.Interaction, true, start, err) + return + } + h.respondMessage(ctx, s, i.Interaction, msg, true) + log.Info("command handled", "duration_ms", time.Since(start).Milliseconds()) + case "ai.model.list": + if err := h.deferResponse(s, i.Interaction, true); err != nil { + log.Error("discord response failed", + "duration_ms", time.Since(start).Milliseconds(), + "error", err.Error(), + ) + return + } + msg, err := h.app.HandleAIModelList(ctx) h.followup(ctx, s, i.Interaction, msg, true, start, err) default: - h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s.%s", data.Name, sub.Name)) + h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s", commandPath)) } } @@ -167,6 +213,10 @@ func (h *Handler) handleAutocomplete(ctx context.Context, s *discordgo.Session, choices, err = h.app.AutocompleteLights(ctx) case "switch": choices, err = h.app.AutocompleteSwitches(ctx) + case "ai": + if focusedOptionName(data) == "name" { + choices, err = h.app.AutocompleteAIModels(ctx) + } default: choices = nil } @@ -299,6 +349,15 @@ func optionalUint32Option(sub *discordgo.ApplicationCommandInteractionDataOption func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string { for _, sub := range data.Options { + if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup { + for _, nested := range sub.Options { + for _, opt := range nested.Options { + if opt.Focused { + return opt.StringValue() + } + } + } + } for _, opt := range sub.Options { if opt.Focused { return opt.StringValue() @@ -307,3 +366,23 @@ func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string } return "" } + +func focusedOptionName(data discordgo.ApplicationCommandInteractionData) string { + for _, sub := range data.Options { + if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup { + for _, nested := range sub.Options { + for _, opt := range nested.Options { + if opt.Focused { + return opt.Name + } + } + } + } + for _, opt := range sub.Options { + if opt.Focused { + return opt.Name + } + } + } + return "" +} diff --git a/discord-bot/internal/adapters/primary/discord/register.go b/discord-bot/internal/adapters/primary/discord/register.go index 3cbe804..dc9f6a1 100644 --- a/discord-bot/internal/adapters/primary/discord/register.go +++ b/discord-bot/internal/adapters/primary/discord/register.go @@ -98,6 +98,37 @@ func RegisterCommands(s *discordgo.Session, guildID string) error { }, }, }, + { + Type: discordgo.ApplicationCommandOptionSubCommandGroup, + Name: "model", + Description: "Manage the active AI model", + Options: []*discordgo.ApplicationCommandOption{ + { + Type: discordgo.ApplicationCommandOptionSubCommand, + Name: "set", + Description: "Set the active model", + Options: []*discordgo.ApplicationCommandOption{ + { + Type: discordgo.ApplicationCommandOptionString, + Name: "name", + Description: "Model name", + Required: true, + Autocomplete: true, + }, + }, + }, + { + Type: discordgo.ApplicationCommandOptionSubCommand, + Name: "get", + Description: "Show the active model", + }, + { + Type: discordgo.ApplicationCommandOptionSubCommand, + Name: "list", + Description: "List available models", + }, + }, + }, }, }, } diff --git a/discord-bot/internal/adapters/secondary/aigateway/client.go b/discord-bot/internal/adapters/secondary/aigateway/client.go index 4e19516..1999fa9 100644 --- a/discord-bot/internal/adapters/secondary/aigateway/client.go +++ b/discord-bot/internal/adapters/secondary/aigateway/client.go @@ -61,22 +61,39 @@ func (c *Client) Close() error { } // Query forwards one free-form request to ai-gateway. -func (c *Client) Query(ctx context.Context, text string) (string, error) { +func (c *Client) Query(ctx context.Context, text, model string) (string, string, error) { start := time.Now() log := logger.FromContext(ctx).With("grpc.method", "AIService/Query") resp, err := c.client.Query(ctx, &aiv1.QueryRequest{ Text: text, Source: "discord-bot", + Model: model, }) if err != nil { log.Error("grpc call failed", "duration_ms", time.Since(start).Milliseconds(), "error", err.Error(), ) - return "", fmt.Errorf("query ai-gateway: %w", err) + return "", "", fmt.Errorf("query ai-gateway: %w", err) } log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds()) - return resp.GetReply(), nil + return resp.GetReply(), resp.GetModelUsed(), nil +} + +// ListModels returns the installed model names from ai-gateway. +func (c *Client) ListModels(ctx context.Context) ([]string, error) { + start := time.Now() + log := logger.FromContext(ctx).With("grpc.method", "AIService/ListModels") + resp, err := c.client.ListModels(ctx, &aiv1.ListModelsRequest{}) + if err != nil { + log.Error("grpc call failed", + "duration_ms", time.Since(start).Milliseconds(), + "error", err.Error(), + ) + return nil, fmt.Errorf("list ai-gateway models: %w", err) + } + log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds()) + return append([]string(nil), resp.GetNames()...), nil } func loadTransportCredentials(tlsDir string) (credentials.TransportCredentials, error) { diff --git a/discord-bot/internal/app/command.go b/discord-bot/internal/app/command.go index 454f35b..32f9bfd 100644 --- a/discord-bot/internal/app/command.go +++ b/discord-bot/internal/app/command.go @@ -7,6 +7,8 @@ import ( "strings" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven" + "gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore" + "gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator" ) // Choice is one Discord autocomplete entry. @@ -17,13 +19,15 @@ type Choice struct { // CommandApp orchestrates Discord command use cases against ha-gateway. type CommandApp struct { - ha driven.HAGateway - ai driven.AIGateway + ha driven.HAGateway + ai driven.AIGateway + models *modelstore.Store + validator *modelvalidator.Validator } // NewCommandApp constructs the Discord command application service. -func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway) *CommandApp { - return &CommandApp{ha: ha, ai: ai} +func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway, models *modelstore.Store, validator *modelvalidator.Validator) *CommandApp { + return &CommandApp{ha: ha, ai: ai, models: models, validator: validator} } // HandleLightList formats discovered lights into a monospace-friendly response. @@ -124,11 +128,72 @@ func (a *CommandApp) HandleSwitchList(ctx context.Context) (string, error) { // HandleAIQuery forwards a free-form request to ai-gateway. func (a *CommandApp) HandleAIQuery(ctx context.Context, text string) (string, error) { - reply, err := a.ai.Query(ctx, text) + reply, modelUsed, err := a.ai.Query(ctx, text, a.models.Get()) if err != nil { return "", fmt.Errorf("handle ai query: %w", err) } - return reply, nil + return fmt.Sprintf("%s\n\n_(via %s)_", reply, modelUsed), nil +} + +// HandleAIModelSet validates and stores the selected model globally. +func (a *CommandApp) HandleAIModelSet(ctx context.Context, name string) (string, error) { + canonical, err := a.validator.Normalize(ctx, name) + if err != nil { + if err.Error() == "ambiguous model name" { + return "", fmt.Errorf("unknown model: %s. Be more specific.", name) + } + if err.Error() == "unknown model" { + return "", fmt.Errorf("unknown model: %s. Try /ai model list.", name) + } + return "", fmt.Errorf("validate model: %w", err) + } + a.models.Set(canonical) + return fmt.Sprintf("Active model set to `%s`.", canonical), nil +} + +// HandleAIModelGet reports the current selected model or default state. +func (a *CommandApp) HandleAIModelGet(ctx context.Context) (string, error) { + cur := a.models.Get() + if cur == "" { + return "No model override set. Using ai-gateway default.", nil + } + return fmt.Sprintf("Active model: `%s`", cur), nil +} + +// HandleAIModelList shows the installed models and marks the active selection. +func (a *CommandApp) HandleAIModelList(ctx context.Context) (string, error) { + models, err := a.validator.Known(ctx) + if err != nil { + return "", fmt.Errorf("list models: %w", err) + } + if len(models) == 0 { + return "No models installed on the Ollama host.", nil + } + + active := a.models.Get() + lines := make([]string, 0, len(models)+1) + lines = append(lines, "Available models:") + for _, model := range models { + marker := "" + if model == active { + marker = " <- active" + } + lines = append(lines, fmt.Sprintf("- `%s`%s", model, marker)) + } + return strings.Join(lines, "\n"), nil +} + +// AutocompleteAIModels returns model names for the /ai model set command. +func (a *CommandApp) AutocompleteAIModels(ctx context.Context) ([]Choice, error) { + models, err := a.validator.Known(ctx) + if err != nil { + return nil, fmt.Errorf("autocomplete ai models: %w", err) + } + choices := make([]Choice, 0, len(models)) + for _, model := range models { + choices = append(choices, Choice{Label: model, Value: model}) + } + return choices, nil } // AutocompleteLights maps discovered lights into Discord autocomplete choices. diff --git a/discord-bot/internal/app/command_test.go b/discord-bot/internal/app/command_test.go index e839385..bff3d67 100644 --- a/discord-bot/internal/app/command_test.go +++ b/discord-bot/internal/app/command_test.go @@ -5,8 +5,11 @@ import ( "errors" "reflect" "testing" + "time" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven" + "gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore" + "gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator" ) type mockHAGateway struct { @@ -18,7 +21,8 @@ type mockHAGateway struct { } type mockAIGateway struct { - queryFunc func(ctx context.Context, text string) (string, error) + queryFunc func(ctx context.Context, text, model string) (string, string, error) + listModelsFunc func(ctx context.Context) ([]string, error) } func (m *mockHAGateway) ListLights(ctx context.Context) ([]driven.Light, error) { @@ -56,11 +60,22 @@ func (m *mockHAGateway) ToggleLight(ctx context.Context, entityID string) error return m.toggleLightFunc(ctx, entityID) } -func (m *mockAIGateway) Query(ctx context.Context, text string) (string, error) { +func (m *mockAIGateway) Query(ctx context.Context, text, model string) (string, string, error) { if m.queryFunc == nil { - return "", nil + return "", "", nil } - return m.queryFunc(ctx, text) + return m.queryFunc(ctx, text, model) +} + +func (m *mockAIGateway) ListModels(ctx context.Context) ([]string, error) { + if m.listModelsFunc == nil { + return nil, nil + } + return m.listModelsFunc(ctx) +} + +func newTestCommandApp(ha *mockHAGateway, ai *mockAIGateway) *CommandApp { + return NewCommandApp(ha, ai, modelstore.New(), modelvalidator.New(ai, time.Minute)) } func TestCommandAppHandleLightList(t *testing.T) { @@ -102,7 +117,7 @@ func TestCommandAppHandleLightList(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := NewCommandApp(&mockHAGateway{ + app := newTestCommandApp(&mockHAGateway{ listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { return tt.lights, nil }, @@ -183,7 +198,7 @@ func TestCommandAppHandleLightOn(t *testing.T) { var gotBrightness *uint32 var gotColorTemp *uint32 - app := NewCommandApp(&mockHAGateway{ + app := newTestCommandApp(&mockHAGateway{ listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { if tt.listErr != nil { return nil, tt.listErr @@ -258,7 +273,7 @@ func TestCommandAppHandleLightOff(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var gotTransition *uint32 - app := NewCommandApp(&mockHAGateway{ + app := newTestCommandApp(&mockHAGateway{ listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { return tt.lights, nil }, @@ -317,7 +332,7 @@ func TestCommandAppHandleLightToggle(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := NewCommandApp(&mockHAGateway{ + app := newTestCommandApp(&mockHAGateway{ listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { if tt.listErr != nil { return nil, tt.listErr @@ -368,7 +383,7 @@ func TestCommandAppHandleSwitchList(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := NewCommandApp(&mockHAGateway{ + app := newTestCommandApp(&mockHAGateway{ listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) { return tt.switches, nil }, @@ -413,7 +428,7 @@ func TestCommandAppAutocompleteLights(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := NewCommandApp(&mockHAGateway{ + app := newTestCommandApp(&mockHAGateway{ listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { if tt.listErr != nil { return nil, tt.listErr @@ -467,7 +482,7 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := NewCommandApp(&mockHAGateway{ + app := newTestCommandApp(&mockHAGateway{ listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) { if tt.listErr != nil { return nil, tt.listErr @@ -494,20 +509,41 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) { } func TestCommandAppHandleAIQuery(t *testing.T) { + store := modelstore.New() + store.Set("llama3:latest") app := NewCommandApp(&mockHAGateway{}, &mockAIGateway{ - queryFunc: func(ctx context.Context, text string) (string, error) { + queryFunc: func(ctx context.Context, text, model string) (string, string, error) { if text != "turn on kitchen" { t.Fatalf("Query() text = %q", text) } - return "Turning on Kitchen.", nil + if model != "llama3:latest" { + t.Fatalf("Query() model = %q", model) + } + return "Turning on Kitchen.", "llama3:latest", nil }, - }) + }, store, modelvalidator.New(&mockAIGateway{}, time.Minute)) got, err := app.HandleAIQuery(context.Background(), "turn on kitchen") if err != nil { t.Fatalf("HandleAIQuery() error = %v", err) } - if got != "Turning on Kitchen." { + if got != "Turning on Kitchen.\n\n_(via llama3:latest)_" { t.Fatalf("HandleAIQuery() = %q", got) } } + +func TestCommandAppHandleAIModelSet(t *testing.T) { + app := newTestCommandApp(&mockHAGateway{}, &mockAIGateway{ + listModelsFunc: func(ctx context.Context) ([]string, error) { + return []string{"llama3:latest"}, nil + }, + }) + + got, err := app.HandleAIModelSet(context.Background(), "llama3") + if err != nil { + t.Fatalf("HandleAIModelSet() error = %v", err) + } + if got != "Active model set to `llama3:latest`." { + t.Fatalf("HandleAIModelSet() = %q", got) + } +} diff --git a/discord-bot/internal/core/ports/driven/ai.go b/discord-bot/internal/core/ports/driven/ai.go index ce52270..4c06154 100644 --- a/discord-bot/internal/core/ports/driven/ai.go +++ b/discord-bot/internal/core/ports/driven/ai.go @@ -4,5 +4,6 @@ import "context" // AIGateway exposes the free-form AI query API used by the Discord bot. type AIGateway interface { - Query(ctx context.Context, text string) (string, error) + Query(ctx context.Context, text, model string) (reply, modelUsed string, err error) + ListModels(ctx context.Context) ([]string, error) } diff --git a/discord-bot/internal/modelstore/store.go b/discord-bot/internal/modelstore/store.go new file mode 100644 index 0000000..3bdd34c --- /dev/null +++ b/discord-bot/internal/modelstore/store.go @@ -0,0 +1,28 @@ +package modelstore + +import "sync" + +// Store keeps the globally selected AI model in memory. +type Store struct { + mu sync.RWMutex + selected string +} + +// New constructs an empty in-memory model store. +func New() *Store { + return &Store{} +} + +// Get returns the currently selected model, or empty for default behavior. +func (s *Store) Get() string { + s.mu.RLock() + defer s.mu.RUnlock() + return s.selected +} + +// Set updates the currently selected model. +func (s *Store) Set(model string) { + s.mu.Lock() + defer s.mu.Unlock() + s.selected = model +} diff --git a/discord-bot/internal/modelstore/store_test.go b/discord-bot/internal/modelstore/store_test.go new file mode 100644 index 0000000..7db6c2f --- /dev/null +++ b/discord-bot/internal/modelstore/store_test.go @@ -0,0 +1,15 @@ +package modelstore + +import "testing" + +func TestStoreGetSet(t *testing.T) { + store := New() + if got := store.Get(); got != "" { + t.Fatalf("Get() = %q, want empty", got) + } + + store.Set("llama3:latest") + if got := store.Get(); got != "llama3:latest" { + t.Fatalf("Get() = %q, want llama3:latest", got) + } +} diff --git a/discord-bot/internal/modelvalidator/validator.go b/discord-bot/internal/modelvalidator/validator.go new file mode 100644 index 0000000..d2031d2 --- /dev/null +++ b/discord-bot/internal/modelvalidator/validator.go @@ -0,0 +1,86 @@ +package modelvalidator + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven" +) + +// Validator caches the model list briefly and normalizes friendly names. +type Validator struct { + client driven.AIGateway + ttl time.Duration + + mu sync.Mutex + cache []string + cachedAt time.Time +} + +// New constructs a model validator with a TTL cache. +func New(client driven.AIGateway, ttl time.Duration) *Validator { + return &Validator{client: client, ttl: ttl} +} + +// Known returns the cached model list, refreshing when stale. +func (v *Validator) Known(ctx context.Context) ([]string, error) { + v.mu.Lock() + defer v.mu.Unlock() + if len(v.cache) > 0 && time.Since(v.cachedAt) < v.ttl { + return append([]string(nil), v.cache...), nil + } + models, err := v.client.ListModels(ctx) + if err != nil { + return nil, err + } + v.cache = append([]string(nil), models...) + v.cachedAt = time.Now() + return append([]string(nil), v.cache...), nil +} + +// Normalize resolves a user-provided name to a canonical installed model name. +func (v *Validator) Normalize(ctx context.Context, name string) (string, error) { + models, err := v.Known(ctx) + if err != nil { + return "", err + } + for _, model := range models { + if model == name { + return model, nil + } + } + latest := name + ":latest" + for _, model := range models { + if model == latest { + return model, nil + } + } + + lower := strings.ToLower(name) + for _, model := range models { + if strings.ToLower(model) == lower { + return model, nil + } + } + lowerLatest := strings.ToLower(latest) + for _, model := range models { + if strings.ToLower(model) == lowerLatest { + return model, nil + } + } + + matches := make([]string, 0, 2) + prefix := lower + ":" + for _, model := range models { + if strings.HasPrefix(strings.ToLower(model), prefix) { + matches = append(matches, model) + } + } + if len(matches) > 1 { + return "", fmt.Errorf("ambiguous model name") + } + return "", fmt.Errorf("unknown model") +} diff --git a/discord-bot/internal/modelvalidator/validator_test.go b/discord-bot/internal/modelvalidator/validator_test.go new file mode 100644 index 0000000..6bbd7d2 --- /dev/null +++ b/discord-bot/internal/modelvalidator/validator_test.go @@ -0,0 +1,70 @@ +package modelvalidator + +import ( + "context" + "errors" + "testing" + "time" +) + +type fakeAIGateway struct { + listModels func(context.Context) ([]string, error) +} + +func (f *fakeAIGateway) Query(ctx context.Context, text, model string) (string, string, error) { + return "", "", nil +} + +func (f *fakeAIGateway) ListModels(ctx context.Context) ([]string, error) { + return f.listModels(ctx) +} + +func TestValidatorNormalize(t *testing.T) { + v := New(&fakeAIGateway{ + listModels: func(ctx context.Context) ([]string, error) { + return []string{"llama3:latest", "qwen3:latest", "qwen3:4b"}, nil + }, + }, time.Minute) + + got, err := v.Normalize(context.Background(), "llama3") + if err != nil || got != "llama3:latest" { + t.Fatalf("Normalize(llama3) = %q, %v", got, err) + } + + got, err = v.Normalize(context.Background(), "qwen3") + if err != nil || got != "qwen3:latest" { + t.Fatalf("Normalize(qwen3) = %q, %v", got, err) + } +} + +func TestValidatorKnownCaches(t *testing.T) { + calls := 0 + v := New(&fakeAIGateway{ + listModels: func(ctx context.Context) ([]string, error) { + calls++ + return []string{"llama3:latest"}, nil + }, + }, time.Minute) + + if _, err := v.Known(context.Background()); err != nil { + t.Fatalf("Known() error = %v", err) + } + if _, err := v.Known(context.Background()); err != nil { + t.Fatalf("Known() error = %v", err) + } + if calls != 1 { + t.Fatalf("calls = %d, want 1", calls) + } +} + +func TestValidatorKnownPropagatesError(t *testing.T) { + v := New(&fakeAIGateway{ + listModels: func(ctx context.Context) ([]string, error) { + return nil, errors.New("boom") + }, + }, time.Minute) + + if _, err := v.Known(context.Background()); err == nil { + t.Fatal("Known() error = nil, want error") + } +} diff --git a/gen/ai/v1/ai.pb.go b/gen/ai/v1/ai.pb.go index 7fdadf4..a7c6dfe 100644 --- a/gen/ai/v1/ai.pb.go +++ b/gen/ai/v1/ai.pb.go @@ -25,6 +25,7 @@ type QueryRequest struct { state protoimpl.MessageState `protogen:"open.v1"` Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"` Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"` + Model string `protobuf:"bytes,3,opt,name=model,proto3" json:"model,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -73,11 +74,19 @@ func (x *QueryRequest) GetSource() string { return "" } +func (x *QueryRequest) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + type QueryResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Reply string `protobuf:"bytes,1,opt,name=reply,proto3" json:"reply,omitempty"` Intent string `protobuf:"bytes,2,opt,name=intent,proto3" json:"intent,omitempty"` ActionTaken bool `protobuf:"varint,3,opt,name=action_taken,json=actionTaken,proto3" json:"action_taken,omitempty"` + ModelUsed string `protobuf:"bytes,4,opt,name=model_used,json=modelUsed,proto3" json:"model_used,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -133,20 +142,115 @@ func (x *QueryResponse) GetActionTaken() bool { return false } +func (x *QueryResponse) GetModelUsed() string { + if x != nil { + return x.ModelUsed + } + return "" +} + +type ListModelsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListModelsRequest) Reset() { + *x = ListModelsRequest{} + mi := &file_ai_v1_ai_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListModelsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListModelsRequest) ProtoMessage() {} + +func (x *ListModelsRequest) ProtoReflect() protoreflect.Message { + mi := &file_ai_v1_ai_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListModelsRequest.ProtoReflect.Descriptor instead. +func (*ListModelsRequest) Descriptor() ([]byte, []int) { + return file_ai_v1_ai_proto_rawDescGZIP(), []int{2} +} + +type ListModelsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Names []string `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListModelsResponse) Reset() { + *x = ListModelsResponse{} + mi := &file_ai_v1_ai_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListModelsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListModelsResponse) ProtoMessage() {} + +func (x *ListModelsResponse) ProtoReflect() protoreflect.Message { + mi := &file_ai_v1_ai_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListModelsResponse.ProtoReflect.Descriptor instead. +func (*ListModelsResponse) Descriptor() ([]byte, []int) { + return file_ai_v1_ai_proto_rawDescGZIP(), []int{3} +} + +func (x *ListModelsResponse) GetNames() []string { + if x != nil { + return x.Names + } + return nil +} + var File_ai_v1_ai_proto protoreflect.FileDescriptor const file_ai_v1_ai_proto_rawDesc = "" + "\n" + - "\x0eai/v1/ai.proto\x12\x05ai.v1\":\n" + + "\x0eai/v1/ai.proto\x12\x05ai.v1\"P\n" + "\fQueryRequest\x12\x12\n" + "\x04text\x18\x01 \x01(\tR\x04text\x12\x16\n" + - "\x06source\x18\x02 \x01(\tR\x06source\"`\n" + + "\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n" + + "\x05model\x18\x03 \x01(\tR\x05model\"\x7f\n" + "\rQueryResponse\x12\x14\n" + "\x05reply\x18\x01 \x01(\tR\x05reply\x12\x16\n" + "\x06intent\x18\x02 \x01(\tR\x06intent\x12!\n" + - "\faction_taken\x18\x03 \x01(\bR\vactionTaken2?\n" + + "\faction_taken\x18\x03 \x01(\bR\vactionTaken\x12\x1d\n" + + "\n" + + "model_used\x18\x04 \x01(\tR\tmodelUsed\"\x13\n" + + "\x11ListModelsRequest\"*\n" + + "\x12ListModelsResponse\x12\x14\n" + + "\x05names\x18\x01 \x03(\tR\x05names2\x82\x01\n" + "\tAIService\x122\n" + - "\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3" + "\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponse\x12A\n" + + "\n" + + "ListModels\x12\x18.ai.v1.ListModelsRequest\x1a\x19.ai.v1.ListModelsResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3" var ( file_ai_v1_ai_proto_rawDescOnce sync.Once @@ -160,16 +264,20 @@ func file_ai_v1_ai_proto_rawDescGZIP() []byte { return file_ai_v1_ai_proto_rawDescData } -var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 4) var file_ai_v1_ai_proto_goTypes = []any{ - (*QueryRequest)(nil), // 0: ai.v1.QueryRequest - (*QueryResponse)(nil), // 1: ai.v1.QueryResponse + (*QueryRequest)(nil), // 0: ai.v1.QueryRequest + (*QueryResponse)(nil), // 1: ai.v1.QueryResponse + (*ListModelsRequest)(nil), // 2: ai.v1.ListModelsRequest + (*ListModelsResponse)(nil), // 3: ai.v1.ListModelsResponse } var file_ai_v1_ai_proto_depIdxs = []int32{ 0, // 0: ai.v1.AIService.Query:input_type -> ai.v1.QueryRequest - 1, // 1: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse - 1, // [1:2] is the sub-list for method output_type - 0, // [0:1] is the sub-list for method input_type + 2, // 1: ai.v1.AIService.ListModels:input_type -> ai.v1.ListModelsRequest + 1, // 2: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse + 3, // 3: ai.v1.AIService.ListModels:output_type -> ai.v1.ListModelsResponse + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -186,7 +294,7 @@ func file_ai_v1_ai_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_v1_ai_proto_rawDesc), len(file_ai_v1_ai_proto_rawDesc)), NumEnums: 0, - NumMessages: 2, + NumMessages: 4, NumExtensions: 0, NumServices: 1, }, diff --git a/gen/ai/v1/ai_grpc.pb.go b/gen/ai/v1/ai_grpc.pb.go index 0ca5f82..4d25667 100644 --- a/gen/ai/v1/ai_grpc.pb.go +++ b/gen/ai/v1/ai_grpc.pb.go @@ -19,7 +19,8 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - AIService_Query_FullMethodName = "/ai.v1.AIService/Query" + AIService_Query_FullMethodName = "/ai.v1.AIService/Query" + AIService_ListModels_FullMethodName = "/ai.v1.AIService/ListModels" ) // AIServiceClient is the client API for AIService service. @@ -27,6 +28,7 @@ const ( // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type AIServiceClient interface { Query(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*QueryResponse, error) + ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error) } type aIServiceClient struct { @@ -47,11 +49,22 @@ func (c *aIServiceClient) Query(ctx context.Context, in *QueryRequest, opts ...g return out, nil } +func (c *aIServiceClient) ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ListModelsResponse) + err := c.cc.Invoke(ctx, AIService_ListModels_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + // AIServiceServer is the server API for AIService service. // All implementations must embed UnimplementedAIServiceServer // for forward compatibility. type AIServiceServer interface { Query(context.Context, *QueryRequest) (*QueryResponse, error) + ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error) mustEmbedUnimplementedAIServiceServer() } @@ -65,6 +78,9 @@ type UnimplementedAIServiceServer struct{} func (UnimplementedAIServiceServer) Query(context.Context, *QueryRequest) (*QueryResponse, error) { return nil, status.Error(codes.Unimplemented, "method Query not implemented") } +func (UnimplementedAIServiceServer) ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error) { + return nil, status.Error(codes.Unimplemented, "method ListModels not implemented") +} func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {} func (UnimplementedAIServiceServer) testEmbeddedByValue() {} @@ -104,6 +120,24 @@ func _AIService_Query_Handler(srv interface{}, ctx context.Context, dec func(int return interceptor(ctx, in, info, handler) } +func _AIService_ListModels_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListModelsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AIServiceServer).ListModels(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: AIService_ListModels_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AIServiceServer).ListModels(ctx, req.(*ListModelsRequest)) + } + return interceptor(ctx, in, info, handler) +} + // AIService_ServiceDesc is the grpc.ServiceDesc for AIService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -115,6 +149,10 @@ var AIService_ServiceDesc = grpc.ServiceDesc{ MethodName: "Query", Handler: _AIService_Query_Handler, }, + { + MethodName: "ListModels", + Handler: _AIService_ListModels_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "ai/v1/ai.proto", diff --git a/proto/ai/v1/ai.proto b/proto/ai/v1/ai.proto index 1cddf31..456a84a 100644 --- a/proto/ai/v1/ai.proto +++ b/proto/ai/v1/ai.proto @@ -6,15 +6,24 @@ option go_package = "gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1"; service AIService { rpc Query(QueryRequest) returns (QueryResponse); + rpc ListModels(ListModelsRequest) returns (ListModelsResponse); } message QueryRequest { string text = 1; string source = 2; + string model = 3; } message QueryResponse { string reply = 1; string intent = 2; bool action_taken = 3; + string model_used = 4; +} + +message ListModelsRequest {} + +message ListModelsResponse { + repeated string names = 1; }