feat: enhance AI model management in Discord bot
All checks were successful
CI / test (push) Successful in 5s
CI / build-ai-gateway (push) Successful in 43s
CI / build-ha-gateway (push) Successful in 47s
CI / build-discord-bot (push) Successful in 41s

- Updated LLMClient interface to support model-specific generation and model listing.
- Integrated model store and validator into the command application for managing AI models.
- Implemented commands for setting, getting, and listing active AI models in Discord.
- Enhanced AI query handling to utilize the selected model and return model information in responses.
- Added caching mechanism for model validation to improve performance.
- Introduced gRPC methods for listing available AI models in the ai-gateway.
- Updated protobuf definitions to include model-related fields and messages.
- Added tests for model store and validator functionalities.
This commit is contained in:
Nik Afiq 2026-04-21 22:52:00 +09:00
parent 9cc29c2329
commit ad50d641bd
21 changed files with 876 additions and 80 deletions

View File

@ -67,7 +67,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
ollamaClient := ollama.New(cfg.OllamaURL, cfg.OllamaModel, &http.Client{Timeout: cfg.OllamaTimeout}) ollamaClient := ollama.New(cfg.OllamaURL, &http.Client{Timeout: cfg.OllamaTimeout})
haClient, err := hagateway.New(ctx, cfg.HAGatewayAddr, cfg.TLSDir, cfg.HAGatewayServerName, log) haClient, err := hagateway.New(ctx, cfg.HAGatewayAddr, cfg.TLSDir, cfg.HAGatewayServerName, log)
if err != nil { if err != nil {
log.Error("ha-gateway client setup failed", "err", err) log.Error("ha-gateway client setup failed", "err", err)
@ -80,7 +80,7 @@ func main() {
}() }()
lightCache := domain.NewLightCache(cfg.LightCacheTTL, haClient.ListLights) lightCache := domain.NewLightCache(cfg.LightCacheTTL, haClient.ListLights)
queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, log) queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, cfg.OllamaModel, log)
serverOpts := []grpc.ServerOption{ serverOpts := []grpc.ServerOption{
grpc.StatsHandler(otelgrpc.NewServerHandler()), grpc.StatsHandler(otelgrpc.NewServerHandler()),

View File

@ -33,7 +33,7 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query
ctx = logger.WithLogger(ctx, log) ctx = logger.WithLogger(ctx, log)
} }
result, err := s.app.Query(ctx, req.GetText()) result, err := s.app.Query(ctx, req.GetText(), req.GetModel())
if err != nil { if err != nil {
return nil, status.Errorf(codes.Unavailable, "query failed: %v", err) return nil, status.Errorf(codes.Unavailable, "query failed: %v", err)
} }
@ -41,5 +41,15 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query
Reply: result.Reply, Reply: result.Reply,
Intent: result.Intent, Intent: result.Intent,
ActionTaken: result.ActionTaken, ActionTaken: result.ActionTaken,
ModelUsed: result.ModelUsed,
}, nil }, nil
} }
// ListModels returns the installed model names from Ollama.
func (s *Server) ListModels(ctx context.Context, _ *aiv1.ListModelsRequest) (*aiv1.ListModelsResponse, error) {
names, err := s.app.ListModels(ctx)
if err != nil {
return nil, status.Errorf(codes.Unavailable, "list models: %v", err)
}
return &aiv1.ListModelsResponse{Names: names}, nil
}

View File

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
) )
@ -14,37 +15,49 @@ type generateRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
Think *bool `json:"think,omitempty"`
} }
type generateResponse struct { type generateResponse struct {
Response string `json:"response"` Response string `json:"response"`
} }
type listModelsResponse struct {
Models []struct {
Name string `json:"name"`
} `json:"models"`
}
// Client implements the LLM driven port with the Ollama generate API. // Client implements the LLM driven port with the Ollama generate API.
type Client struct { type Client struct {
baseURL string baseURL string
model string
http *http.Client http *http.Client
} }
// New constructs an Ollama client with OTel-instrumented transport. // New constructs an Ollama client with OTel-instrumented transport.
func New(baseURL, model string, httpClient *http.Client) *Client { func New(baseURL string, httpClient *http.Client) *Client {
if httpClient == nil { if httpClient == nil {
httpClient = &http.Client{Transport: otelhttp.NewTransport(http.DefaultTransport)} httpClient = &http.Client{Transport: otelhttp.NewTransport(http.DefaultTransport)}
} }
if httpClient.Transport == nil { if httpClient.Transport == nil {
httpClient.Transport = otelhttp.NewTransport(http.DefaultTransport) httpClient.Transport = otelhttp.NewTransport(http.DefaultTransport)
} }
return &Client{baseURL: baseURL, model: model, http: httpClient} return &Client{baseURL: baseURL, http: httpClient}
} }
// Generate sends one non-streaming prompt to Ollama. // Generate sends one non-streaming prompt to Ollama.
func (c *Client) Generate(ctx context.Context, prompt string) (string, error) { func (c *Client) Generate(ctx context.Context, model, prompt string) (string, error) {
body, err := json.Marshal(generateRequest{ reqBody := generateRequest{
Model: c.model, Model: model,
Prompt: prompt, Prompt: prompt,
Stream: false, Stream: false,
}) }
if isThinkingModel(model) {
disabled := false
reqBody.Think = &disabled
}
body, err := json.Marshal(reqBody)
if err != nil { if err != nil {
return "", fmt.Errorf("marshal ollama request: %w", err) return "", fmt.Errorf("marshal ollama request: %w", err)
} }
@ -71,3 +84,38 @@ func (c *Client) Generate(ctx context.Context, prompt string) (string, error) {
} }
return out.Response, nil return out.Response, nil
} }
// ListModels returns the installed model names from Ollama.
func (c *Client) ListModels(ctx context.Context) ([]string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/tags", nil)
if err != nil {
return nil, fmt.Errorf("build ollama list models request: %w", err)
}
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("list ollama models: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("ollama returned status %s", resp.Status)
}
var out listModelsResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, fmt.Errorf("decode ollama list models response: %w", err)
}
names := make([]string, 0, len(out.Models))
for _, model := range out.Models {
if model.Name != "" {
names = append(names, model.Name)
}
}
return names, nil
}
func isThinkingModel(name string) bool {
return strings.HasPrefix(name, "qwen3")
}

View File

@ -0,0 +1,81 @@
package ollama
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestIsThinkingModel(t *testing.T) {
tests := []struct {
name string
want bool
}{
{name: "qwen3", want: true},
{name: "qwen3:4b", want: true},
{name: "qwen3:latest", want: true},
{name: "llama3", want: false},
{name: "mistral", want: false},
}
for _, tt := range tests {
if got := isThinkingModel(tt.name); got != tt.want {
t.Fatalf("isThinkingModel(%q) = %v, want %v", tt.name, got, tt.want)
}
}
}
func TestGenerateSetsThinkForQwen3(t *testing.T) {
var body map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("Decode() error = %v", err)
}
_, _ = w.Write([]byte(`{"response":"ok"}`))
}))
defer srv.Close()
client := New(srv.URL, srv.Client())
if _, err := client.Generate(context.Background(), "qwen3:latest", "prompt"); err != nil {
t.Fatalf("Generate() error = %v", err)
}
if body["think"] != false {
t.Fatalf("think = %#v, want false", body["think"])
}
}
func TestGenerateOmitsThinkForLlama3(t *testing.T) {
var body map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("Decode() error = %v", err)
}
_, _ = w.Write([]byte(`{"response":"ok"}`))
}))
defer srv.Close()
client := New(srv.URL, srv.Client())
if _, err := client.Generate(context.Background(), "llama3:latest", "prompt"); err != nil {
t.Fatalf("Generate() error = %v", err)
}
if _, ok := body["think"]; ok {
t.Fatalf("unexpected think key = %#v", body["think"])
}
}
func TestListModelsReturnsNamesOnly(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"models":[{"name":"llama3:latest"},{"name":"qwen3:latest"}]}`))
}))
defer srv.Close()
client := New(srv.URL, srv.Client())
got, err := client.ListModels(context.Background())
if err != nil {
t.Fatalf("ListModels() error = %v", err)
}
if len(got) != 2 || got[0] != "llama3:latest" || got[1] != "qwen3:latest" {
t.Fatalf("ListModels() = %#v", got)
}
}

View File

@ -18,6 +18,7 @@ type QueryResult struct {
Reply string Reply string
Intent string Intent string
ActionTaken bool ActionTaken bool
ModelUsed string
} }
// QueryApp orchestrates one AI query request. // QueryApp orchestrates one AI query request.
@ -25,27 +26,32 @@ type QueryApp struct {
llm driven.LLMClient llm driven.LLMClient
ha driven.HAClient ha driven.HAClient
cache *domain.LightCache cache *domain.LightCache
defaultModel string
log *slog.Logger log *slog.Logger
} }
// NewQueryApp constructs the AI query application service. // NewQueryApp constructs the AI query application service.
func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, log *slog.Logger) *QueryApp { func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, defaultModel string, log *slog.Logger) *QueryApp {
return &QueryApp{llm: llm, ha: ha, cache: cache, log: log} return &QueryApp{llm: llm, ha: ha, cache: cache, defaultModel: defaultModel, log: log}
} }
// Query runs the full intent parsing and dispatch flow for one user request. // Query runs the full intent parsing and dispatch flow for one user request.
func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) { func (a *QueryApp) Query(ctx context.Context, text, model string) (QueryResult, error) {
if model == "" {
model = a.defaultModel
}
lights, err := a.cache.Get(ctx) lights, err := a.cache.Get(ctx)
if err != nil { if err != nil {
a.log.Error("light cache refresh failed", "err", err) a.log.Error("light cache refresh failed", "err", err)
return QueryResult{ return QueryResult{
Reply: "I couldn't reach Home Assistant right now.", Reply: "I couldn't reach Home Assistant right now.",
ActionTaken: false, ActionTaken: false,
ModelUsed: model,
}, nil }, nil
} }
prompt := domain.BuildPrompt(text, promptLightLines(lights)) prompt := domain.BuildPrompt(text, promptLightLines(lights))
raw, err := a.llm.Generate(ctx, prompt) raw, err := a.llm.Generate(ctx, model, prompt)
if err != nil { if err != nil {
return QueryResult{}, err return QueryResult{}, err
} }
@ -57,6 +63,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: "I didn't understand that.", Reply: "I didn't understand that.",
Intent: domain.IntentNone, Intent: domain.IntentNone,
ActionTaken: false, ActionTaken: false,
ModelUsed: model,
}, nil }, nil
} }
@ -64,7 +71,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
case domain.IntentTurnOnLight: case domain.IntentTurnOnLight:
entityID, ok := resolveLightEntity(intent.Entity, lights) entityID, ok := resolveLightEntity(intent.Entity, lights)
if !ok { if !ok {
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil
} }
params, err := ParseLightParams(intent.Params) params, err := ParseLightParams(intent.Params)
if err != nil { if err != nil {
@ -72,6 +79,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: "I couldn't understand the light settings.", Reply: "I couldn't understand the light settings.",
Intent: intent.Name, Intent: intent.Name,
ActionTaken: false, ActionTaken: false,
ModelUsed: model,
}, nil }, nil
} }
if err := a.ha.TurnOnLight(ctx, entityID, params); err != nil { if err := a.ha.TurnOnLight(ctx, entityID, params); err != nil {
@ -80,17 +88,19 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: "I couldn't reach Home Assistant right now.", Reply: "I couldn't reach Home Assistant right now.",
Intent: intent.Name, Intent: intent.Name,
ActionTaken: false, ActionTaken: false,
ModelUsed: model,
}, nil }, nil
} }
return QueryResult{ return QueryResult{
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned on `%s`.", displayLightName(entityID, lights))), Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned on `%s`.", displayLightName(entityID, lights))),
Intent: intent.Name, Intent: intent.Name,
ActionTaken: true, ActionTaken: true,
ModelUsed: model,
}, nil }, nil
case domain.IntentTurnOffLight: case domain.IntentTurnOffLight:
entityID, ok := resolveLightEntity(intent.Entity, lights) entityID, ok := resolveLightEntity(intent.Entity, lights)
if !ok { if !ok {
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil
} }
if err := a.ha.TurnOffLight(ctx, entityID); err != nil { if err := a.ha.TurnOffLight(ctx, entityID); err != nil {
a.log.Error("turn off light failed", "entity_id", entityID, "err", err) a.log.Error("turn off light failed", "entity_id", entityID, "err", err)
@ -98,18 +108,21 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: "I couldn't reach Home Assistant right now.", Reply: "I couldn't reach Home Assistant right now.",
Intent: intent.Name, Intent: intent.Name,
ActionTaken: false, ActionTaken: false,
ModelUsed: model,
}, nil }, nil
} }
return QueryResult{ return QueryResult{
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned off `%s`.", displayLightName(entityID, lights))), Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned off `%s`.", displayLightName(entityID, lights))),
Intent: intent.Name, Intent: intent.Name,
ActionTaken: true, ActionTaken: true,
ModelUsed: model,
}, nil }, nil
case domain.IntentListLights: case domain.IntentListLights:
return QueryResult{ return QueryResult{
Reply: formatLightListReply(lights), Reply: formatLightListReply(lights),
Intent: intent.Name, Intent: intent.Name,
ActionTaken: false, ActionTaken: false,
ModelUsed: model,
}, nil }, nil
case domain.IntentNone: case domain.IntentNone:
fallthrough fallthrough
@ -118,10 +131,16 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: fallbackReply(intent.Reply, "I didn't understand that."), Reply: fallbackReply(intent.Reply, "I didn't understand that."),
Intent: intent.Name, Intent: intent.Name,
ActionTaken: false, ActionTaken: false,
ModelUsed: model,
}, nil }, nil
} }
} }
// ListModels returns the currently installed model names.
func (a *QueryApp) ListModels(ctx context.Context) ([]string, error) {
return a.llm.ListModels(ctx)
}
func promptLightLines(lights []driven.Light) []string { func promptLightLines(lights []driven.Light) []string {
lines := make([]string, 0, len(lights)) lines := make([]string, 0, len(lights))
for _, light := range lights { for _, light := range lights {

View File

@ -13,11 +13,19 @@ import (
) )
type fakeLLM struct { type fakeLLM struct {
generate func(context.Context, string) (string, error) generate func(context.Context, string, string) (string, error)
listModels func(context.Context) ([]string, error)
} }
func (f *fakeLLM) Generate(ctx context.Context, prompt string) (string, error) { func (f *fakeLLM) Generate(ctx context.Context, model, prompt string) (string, error) {
return f.generate(ctx, prompt) return f.generate(ctx, model, prompt)
}
func (f *fakeLLM) ListModels(ctx context.Context) ([]string, error) {
if f.listModels == nil {
return nil, nil
}
return f.listModels(ctx)
} }
type fakeHA struct { type fakeHA struct {
@ -54,12 +62,15 @@ func TestQueryAppTurnOnLight(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}} ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
cache := domain.NewLightCache(time.Hour, ha.ListLights) cache := domain.NewLightCache(time.Hour, ha.ListLights)
app := NewQueryApp(&fakeLLM{ app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) { generate: func(ctx context.Context, model, prompt string) (string, error) {
if model != "llama3:latest" {
t.Fatalf("Generate() model = %q", model)
}
return `{"intent":"turn_on_light","entity":"Kitchen","params":{"brightness":"80"},"reply":"Turning on Kitchen."}`, nil return `{"intent":"turn_on_light","entity":"Kitchen","params":{"brightness":"80"},"reply":"Turning on Kitchen."}`, nil
}, },
}, ha, cache, slog.Default()) }, ha, cache, "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "turn on kitchen") got, err := app.Query(context.Background(), "turn on kitchen", "")
if err != nil { if err != nil {
t.Fatalf("Query() error = %v", err) t.Fatalf("Query() error = %v", err)
} }
@ -77,12 +88,12 @@ func TestQueryAppTurnOnLight(t *testing.T) {
func TestQueryAppInvalidJSON(t *testing.T) { func TestQueryAppInvalidJSON(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}} ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
app := NewQueryApp(&fakeLLM{ app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) { generate: func(ctx context.Context, model, prompt string) (string, error) {
return `not-json`, nil return `not-json`, nil
}, },
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default()) }, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "turn on kitchen") got, err := app.Query(context.Background(), "turn on kitchen", "")
if err != nil { if err != nil {
t.Fatalf("Query() error = %v", err) t.Fatalf("Query() error = %v", err)
} }
@ -97,12 +108,12 @@ func TestQueryAppInvalidJSON(t *testing.T) {
func TestQueryAppIntentNone(t *testing.T) { func TestQueryAppIntentNone(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}} ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
app := NewQueryApp(&fakeLLM{ app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) { generate: func(ctx context.Context, model, prompt string) (string, error) {
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
}, },
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default()) }, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "hello") got, err := app.Query(context.Background(), "hello", "")
if err != nil { if err != nil {
t.Fatalf("Query() error = %v", err) t.Fatalf("Query() error = %v", err)
} }
@ -117,12 +128,12 @@ func TestQueryAppHAFailure(t *testing.T) {
turnOnErr: errors.New("boom"), turnOnErr: errors.New("boom"),
} }
app := NewQueryApp(&fakeLLM{ app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) { generate: func(ctx context.Context, model, prompt string) (string, error) {
return `{"intent":"turn_on_light","entity":"light.kitchen","params":{},"reply":"Turning on Kitchen."}`, nil return `{"intent":"turn_on_light","entity":"light.kitchen","params":{},"reply":"Turning on Kitchen."}`, nil
}, },
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default()) }, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "turn on kitchen") got, err := app.Query(context.Background(), "turn on kitchen", "")
if err != nil { if err != nil {
t.Fatalf("Query() error = %v", err) t.Fatalf("Query() error = %v", err)
} }
@ -134,12 +145,12 @@ func TestQueryAppHAFailure(t *testing.T) {
func TestQueryAppListLights(t *testing.T) { func TestQueryAppListLights(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "on"}}} ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "on"}}}
app := NewQueryApp(&fakeLLM{ app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) { generate: func(ctx context.Context, model, prompt string) (string, error) {
return `{"intent":"list_lights","entity":"","params":{},"reply":""}`, nil return `{"intent":"list_lights","entity":"","params":{},"reply":""}`, nil
}, },
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default()) }, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "what lights exist") got, err := app.Query(context.Background(), "what lights exist", "")
if err != nil { if err != nil {
t.Fatalf("Query() error = %v", err) t.Fatalf("Query() error = %v", err)
} }
@ -164,3 +175,42 @@ func TestLightCacheRefreshAfterTTL(t *testing.T) {
t.Fatalf("ListLights calls = %d, want at least 2", ha.listCalls) t.Fatalf("ListLights calls = %d, want at least 2", ha.listCalls)
} }
} }
func TestQueryAppExplicitModel(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, model, prompt string) (string, error) {
if model != "qwen3:latest" {
t.Fatalf("Generate() model = %q", model)
}
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
},
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "hello", "qwen3:latest")
if err != nil {
t.Fatalf("Query() error = %v", err)
}
if got.ModelUsed != "qwen3:latest" {
t.Fatalf("ModelUsed = %q", got.ModelUsed)
}
}
func TestQueryAppListModels(t *testing.T) {
app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, model, prompt string) (string, error) {
return "", nil
},
listModels: func(ctx context.Context) ([]string, error) {
return []string{"llama3:latest", "qwen3:latest"}, nil
},
}, &fakeHA{}, domain.NewLightCache(time.Hour, func(context.Context) ([]driven.Light, error) { return nil, nil }), "llama3:latest", slog.Default())
got, err := app.ListModels(context.Background())
if err != nil {
t.Fatalf("ListModels() error = %v", err)
}
if !reflect.DeepEqual(got, []string{"llama3:latest", "qwen3:latest"}) {
t.Fatalf("ListModels() = %#v", got)
}
}

View File

@ -4,5 +4,6 @@ import "context"
// LLMClient generates one model response for a prompt. // LLMClient generates one model response for a prompt.
type LLMClient interface { type LLMClient interface {
Generate(ctx context.Context, prompt string) (string, error) Generate(ctx context.Context, model, prompt string) (string, error)
ListModels(ctx context.Context) ([]string, error)
} }

View File

@ -17,6 +17,8 @@ import (
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/app" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/app"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/config" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/config"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/logger" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/logger"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/telemetry" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/telemetry"
) )
@ -84,7 +86,9 @@ func main() {
} }
}() }()
commandApp := app.NewCommandApp(haClient, aiClient) modelStore := modelstore.New()
validator := modelvalidator.New(aiClient, 30*time.Second)
commandApp := app.NewCommandApp(haClient, aiClient, modelStore, validator)
// Discord-specific wiring stays at the edge so the app layer remains transport-agnostic. // Discord-specific wiring stays at the edge so the app layer remains transport-agnostic.
session, err := discordgo.New("Bot " + cfg.DiscordToken) session, err := discordgo.New("Bot " + cfg.DiscordToken)

View File

@ -24,8 +24,12 @@ type commandHandler interface {
HandleLightToggle(ctx context.Context, entityID string) (string, error) HandleLightToggle(ctx context.Context, entityID string) (string, error)
HandleSwitchList(ctx context.Context) (string, error) HandleSwitchList(ctx context.Context) (string, error)
HandleAIQuery(ctx context.Context, text string) (string, error) HandleAIQuery(ctx context.Context, text string) (string, error)
HandleAIModelSet(ctx context.Context, name string) (string, error)
HandleAIModelGet(ctx context.Context) (string, error)
HandleAIModelList(ctx context.Context) (string, error)
AutocompleteLights(ctx context.Context) ([]apppkg.Choice, error) AutocompleteLights(ctx context.Context) ([]apppkg.Choice, error)
AutocompleteSwitches(ctx context.Context) ([]apppkg.Choice, error) AutocompleteSwitches(ctx context.Context) ([]apppkg.Choice, error)
AutocompleteAIModels(ctx context.Context) ([]apppkg.Choice, error)
} }
// Handler adapts Discord interactions to the command application layer. // Handler adapts Discord interactions to the command application layer.
@ -69,6 +73,9 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
command := data.Name command := data.Name
if len(data.Options) > 0 { if len(data.Options) > 0 {
command += "." + data.Options[0].Name command += "." + data.Options[0].Name
if len(data.Options[0].Options) > 0 && data.Options[0].Type == discordgo.ApplicationCommandOptionSubCommandGroup {
command += "." + data.Options[0].Options[0].Name
}
} }
user := "" user := ""
if i.Member != nil && i.Member.User != nil { if i.Member != nil && i.Member.User != nil {
@ -90,7 +97,18 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
} }
sub := data.Options[0] sub := data.Options[0]
switch data.Name + "." + sub.Name { commandPath := data.Name + "." + sub.Name
target := sub
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
if len(sub.Options) == 0 {
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("missing grouped subcommand"))
return
}
target = sub.Options[0]
commandPath += "." + target.Name
}
switch commandPath {
case "light.list": case "light.list":
msg, err := h.app.HandleLightList(ctx) msg, err := h.app.HandleLightList(ctx)
if err != nil { if err != nil {
@ -145,10 +163,38 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
) )
return return
} }
msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(sub, "text")) msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(target, "text"))
h.followup(ctx, s, i.Interaction, msg, true, start, err)
case "ai.model.set":
if err := h.deferResponse(s, i.Interaction, true); err != nil {
log.Error("discord response failed",
"duration_ms", time.Since(start).Milliseconds(),
"error", err.Error(),
)
return
}
msg, err := h.app.HandleAIModelSet(ctx, requiredStringOption(target, "name"))
h.followup(ctx, s, i.Interaction, msg, true, start, err)
case "ai.model.get":
msg, err := h.app.HandleAIModelGet(ctx)
if err != nil {
h.respondError(ctx, s, i.Interaction, true, start, err)
return
}
h.respondMessage(ctx, s, i.Interaction, msg, true)
log.Info("command handled", "duration_ms", time.Since(start).Milliseconds())
case "ai.model.list":
if err := h.deferResponse(s, i.Interaction, true); err != nil {
log.Error("discord response failed",
"duration_ms", time.Since(start).Milliseconds(),
"error", err.Error(),
)
return
}
msg, err := h.app.HandleAIModelList(ctx)
h.followup(ctx, s, i.Interaction, msg, true, start, err) h.followup(ctx, s, i.Interaction, msg, true, start, err)
default: default:
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s.%s", data.Name, sub.Name)) h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s", commandPath))
} }
} }
@ -167,6 +213,10 @@ func (h *Handler) handleAutocomplete(ctx context.Context, s *discordgo.Session,
choices, err = h.app.AutocompleteLights(ctx) choices, err = h.app.AutocompleteLights(ctx)
case "switch": case "switch":
choices, err = h.app.AutocompleteSwitches(ctx) choices, err = h.app.AutocompleteSwitches(ctx)
case "ai":
if focusedOptionName(data) == "name" {
choices, err = h.app.AutocompleteAIModels(ctx)
}
default: default:
choices = nil choices = nil
} }
@ -299,6 +349,15 @@ func optionalUint32Option(sub *discordgo.ApplicationCommandInteractionDataOption
func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string { func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string {
for _, sub := range data.Options { for _, sub := range data.Options {
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
for _, nested := range sub.Options {
for _, opt := range nested.Options {
if opt.Focused {
return opt.StringValue()
}
}
}
}
for _, opt := range sub.Options { for _, opt := range sub.Options {
if opt.Focused { if opt.Focused {
return opt.StringValue() return opt.StringValue()
@ -307,3 +366,23 @@ func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string
} }
return "" return ""
} }
func focusedOptionName(data discordgo.ApplicationCommandInteractionData) string {
for _, sub := range data.Options {
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
for _, nested := range sub.Options {
for _, opt := range nested.Options {
if opt.Focused {
return opt.Name
}
}
}
}
for _, opt := range sub.Options {
if opt.Focused {
return opt.Name
}
}
}
return ""
}

View File

@ -98,6 +98,37 @@ func RegisterCommands(s *discordgo.Session, guildID string) error {
}, },
}, },
}, },
{
Type: discordgo.ApplicationCommandOptionSubCommandGroup,
Name: "model",
Description: "Manage the active AI model",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionSubCommand,
Name: "set",
Description: "Set the active model",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionString,
Name: "name",
Description: "Model name",
Required: true,
Autocomplete: true,
},
},
},
{
Type: discordgo.ApplicationCommandOptionSubCommand,
Name: "get",
Description: "Show the active model",
},
{
Type: discordgo.ApplicationCommandOptionSubCommand,
Name: "list",
Description: "List available models",
},
},
},
}, },
}, },
} }

View File

@ -61,22 +61,39 @@ func (c *Client) Close() error {
} }
// Query forwards one free-form request to ai-gateway. // Query forwards one free-form request to ai-gateway.
func (c *Client) Query(ctx context.Context, text string) (string, error) { func (c *Client) Query(ctx context.Context, text, model string) (string, string, error) {
start := time.Now() start := time.Now()
log := logger.FromContext(ctx).With("grpc.method", "AIService/Query") log := logger.FromContext(ctx).With("grpc.method", "AIService/Query")
resp, err := c.client.Query(ctx, &aiv1.QueryRequest{ resp, err := c.client.Query(ctx, &aiv1.QueryRequest{
Text: text, Text: text,
Source: "discord-bot", Source: "discord-bot",
Model: model,
}) })
if err != nil { if err != nil {
log.Error("grpc call failed", log.Error("grpc call failed",
"duration_ms", time.Since(start).Milliseconds(), "duration_ms", time.Since(start).Milliseconds(),
"error", err.Error(), "error", err.Error(),
) )
return "", fmt.Errorf("query ai-gateway: %w", err) return "", "", fmt.Errorf("query ai-gateway: %w", err)
} }
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds()) log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
return resp.GetReply(), nil return resp.GetReply(), resp.GetModelUsed(), nil
}
// ListModels returns the installed model names from ai-gateway.
func (c *Client) ListModels(ctx context.Context) ([]string, error) {
start := time.Now()
log := logger.FromContext(ctx).With("grpc.method", "AIService/ListModels")
resp, err := c.client.ListModels(ctx, &aiv1.ListModelsRequest{})
if err != nil {
log.Error("grpc call failed",
"duration_ms", time.Since(start).Milliseconds(),
"error", err.Error(),
)
return nil, fmt.Errorf("list ai-gateway models: %w", err)
}
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
return append([]string(nil), resp.GetNames()...), nil
} }
func loadTransportCredentials(tlsDir string) (credentials.TransportCredentials, error) { func loadTransportCredentials(tlsDir string) (credentials.TransportCredentials, error) {

View File

@ -7,6 +7,8 @@ import (
"strings" "strings"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
) )
// Choice is one Discord autocomplete entry. // Choice is one Discord autocomplete entry.
@ -19,11 +21,13 @@ type Choice struct {
type CommandApp struct { type CommandApp struct {
ha driven.HAGateway ha driven.HAGateway
ai driven.AIGateway ai driven.AIGateway
models *modelstore.Store
validator *modelvalidator.Validator
} }
// NewCommandApp constructs the Discord command application service. // NewCommandApp constructs the Discord command application service.
func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway) *CommandApp { func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway, models *modelstore.Store, validator *modelvalidator.Validator) *CommandApp {
return &CommandApp{ha: ha, ai: ai} return &CommandApp{ha: ha, ai: ai, models: models, validator: validator}
} }
// HandleLightList formats discovered lights into a monospace-friendly response. // HandleLightList formats discovered lights into a monospace-friendly response.
@ -124,11 +128,72 @@ func (a *CommandApp) HandleSwitchList(ctx context.Context) (string, error) {
// HandleAIQuery forwards a free-form request to ai-gateway. // HandleAIQuery forwards a free-form request to ai-gateway.
func (a *CommandApp) HandleAIQuery(ctx context.Context, text string) (string, error) { func (a *CommandApp) HandleAIQuery(ctx context.Context, text string) (string, error) {
reply, err := a.ai.Query(ctx, text) reply, modelUsed, err := a.ai.Query(ctx, text, a.models.Get())
if err != nil { if err != nil {
return "", fmt.Errorf("handle ai query: %w", err) return "", fmt.Errorf("handle ai query: %w", err)
} }
return reply, nil return fmt.Sprintf("%s\n\n_(via %s)_", reply, modelUsed), nil
}
// HandleAIModelSet validates and stores the selected model globally.
func (a *CommandApp) HandleAIModelSet(ctx context.Context, name string) (string, error) {
canonical, err := a.validator.Normalize(ctx, name)
if err != nil {
if err.Error() == "ambiguous model name" {
return "", fmt.Errorf("unknown model: %s. Be more specific.", name)
}
if err.Error() == "unknown model" {
return "", fmt.Errorf("unknown model: %s. Try /ai model list.", name)
}
return "", fmt.Errorf("validate model: %w", err)
}
a.models.Set(canonical)
return fmt.Sprintf("Active model set to `%s`.", canonical), nil
}
// HandleAIModelGet reports the current selected model or default state.
func (a *CommandApp) HandleAIModelGet(ctx context.Context) (string, error) {
cur := a.models.Get()
if cur == "" {
return "No model override set. Using ai-gateway default.", nil
}
return fmt.Sprintf("Active model: `%s`", cur), nil
}
// HandleAIModelList shows the installed models and marks the active selection.
func (a *CommandApp) HandleAIModelList(ctx context.Context) (string, error) {
models, err := a.validator.Known(ctx)
if err != nil {
return "", fmt.Errorf("list models: %w", err)
}
if len(models) == 0 {
return "No models installed on the Ollama host.", nil
}
active := a.models.Get()
lines := make([]string, 0, len(models)+1)
lines = append(lines, "Available models:")
for _, model := range models {
marker := ""
if model == active {
marker = " <- active"
}
lines = append(lines, fmt.Sprintf("- `%s`%s", model, marker))
}
return strings.Join(lines, "\n"), nil
}
// AutocompleteAIModels returns model names for the /ai model set command.
func (a *CommandApp) AutocompleteAIModels(ctx context.Context) ([]Choice, error) {
models, err := a.validator.Known(ctx)
if err != nil {
return nil, fmt.Errorf("autocomplete ai models: %w", err)
}
choices := make([]Choice, 0, len(models))
for _, model := range models {
choices = append(choices, Choice{Label: model, Value: model})
}
return choices, nil
} }
// AutocompleteLights maps discovered lights into Discord autocomplete choices. // AutocompleteLights maps discovered lights into Discord autocomplete choices.

View File

@ -5,8 +5,11 @@ import (
"errors" "errors"
"reflect" "reflect"
"testing" "testing"
"time"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven" "gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
) )
type mockHAGateway struct { type mockHAGateway struct {
@ -18,7 +21,8 @@ type mockHAGateway struct {
} }
type mockAIGateway struct { type mockAIGateway struct {
queryFunc func(ctx context.Context, text string) (string, error) queryFunc func(ctx context.Context, text, model string) (string, string, error)
listModelsFunc func(ctx context.Context) ([]string, error)
} }
func (m *mockHAGateway) ListLights(ctx context.Context) ([]driven.Light, error) { func (m *mockHAGateway) ListLights(ctx context.Context) ([]driven.Light, error) {
@ -56,11 +60,22 @@ func (m *mockHAGateway) ToggleLight(ctx context.Context, entityID string) error
return m.toggleLightFunc(ctx, entityID) return m.toggleLightFunc(ctx, entityID)
} }
func (m *mockAIGateway) Query(ctx context.Context, text string) (string, error) { func (m *mockAIGateway) Query(ctx context.Context, text, model string) (string, string, error) {
if m.queryFunc == nil { if m.queryFunc == nil {
return "", nil return "", "", nil
} }
return m.queryFunc(ctx, text) return m.queryFunc(ctx, text, model)
}
func (m *mockAIGateway) ListModels(ctx context.Context) ([]string, error) {
if m.listModelsFunc == nil {
return nil, nil
}
return m.listModelsFunc(ctx)
}
func newTestCommandApp(ha *mockHAGateway, ai *mockAIGateway) *CommandApp {
return NewCommandApp(ha, ai, modelstore.New(), modelvalidator.New(ai, time.Minute))
} }
func TestCommandAppHandleLightList(t *testing.T) { func TestCommandAppHandleLightList(t *testing.T) {
@ -102,7 +117,7 @@ func TestCommandAppHandleLightList(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{ app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
return tt.lights, nil return tt.lights, nil
}, },
@ -183,7 +198,7 @@ func TestCommandAppHandleLightOn(t *testing.T) {
var gotBrightness *uint32 var gotBrightness *uint32
var gotColorTemp *uint32 var gotColorTemp *uint32
app := NewCommandApp(&mockHAGateway{ app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
if tt.listErr != nil { if tt.listErr != nil {
return nil, tt.listErr return nil, tt.listErr
@ -258,7 +273,7 @@ func TestCommandAppHandleLightOff(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
var gotTransition *uint32 var gotTransition *uint32
app := NewCommandApp(&mockHAGateway{ app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
return tt.lights, nil return tt.lights, nil
}, },
@ -317,7 +332,7 @@ func TestCommandAppHandleLightToggle(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{ app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
if tt.listErr != nil { if tt.listErr != nil {
return nil, tt.listErr return nil, tt.listErr
@ -368,7 +383,7 @@ func TestCommandAppHandleSwitchList(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{ app := newTestCommandApp(&mockHAGateway{
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) { listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
return tt.switches, nil return tt.switches, nil
}, },
@ -413,7 +428,7 @@ func TestCommandAppAutocompleteLights(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{ app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) { listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
if tt.listErr != nil { if tt.listErr != nil {
return nil, tt.listErr return nil, tt.listErr
@ -467,7 +482,7 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{ app := newTestCommandApp(&mockHAGateway{
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) { listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
if tt.listErr != nil { if tt.listErr != nil {
return nil, tt.listErr return nil, tt.listErr
@ -494,20 +509,41 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) {
} }
func TestCommandAppHandleAIQuery(t *testing.T) { func TestCommandAppHandleAIQuery(t *testing.T) {
store := modelstore.New()
store.Set("llama3:latest")
app := NewCommandApp(&mockHAGateway{}, &mockAIGateway{ app := NewCommandApp(&mockHAGateway{}, &mockAIGateway{
queryFunc: func(ctx context.Context, text string) (string, error) { queryFunc: func(ctx context.Context, text, model string) (string, string, error) {
if text != "turn on kitchen" { if text != "turn on kitchen" {
t.Fatalf("Query() text = %q", text) t.Fatalf("Query() text = %q", text)
} }
return "Turning on Kitchen.", nil if model != "llama3:latest" {
t.Fatalf("Query() model = %q", model)
}
return "Turning on Kitchen.", "llama3:latest", nil
}, },
}) }, store, modelvalidator.New(&mockAIGateway{}, time.Minute))
got, err := app.HandleAIQuery(context.Background(), "turn on kitchen") got, err := app.HandleAIQuery(context.Background(), "turn on kitchen")
if err != nil { if err != nil {
t.Fatalf("HandleAIQuery() error = %v", err) t.Fatalf("HandleAIQuery() error = %v", err)
} }
if got != "Turning on Kitchen." { if got != "Turning on Kitchen.\n\n_(via llama3:latest)_" {
t.Fatalf("HandleAIQuery() = %q", got) t.Fatalf("HandleAIQuery() = %q", got)
} }
} }
func TestCommandAppHandleAIModelSet(t *testing.T) {
app := newTestCommandApp(&mockHAGateway{}, &mockAIGateway{
listModelsFunc: func(ctx context.Context) ([]string, error) {
return []string{"llama3:latest"}, nil
},
})
got, err := app.HandleAIModelSet(context.Background(), "llama3")
if err != nil {
t.Fatalf("HandleAIModelSet() error = %v", err)
}
if got != "Active model set to `llama3:latest`." {
t.Fatalf("HandleAIModelSet() = %q", got)
}
}

View File

@ -4,5 +4,6 @@ import "context"
// AIGateway exposes the free-form AI query API used by the Discord bot. // AIGateway exposes the free-form AI query API used by the Discord bot.
type AIGateway interface { type AIGateway interface {
Query(ctx context.Context, text string) (string, error) Query(ctx context.Context, text, model string) (reply, modelUsed string, err error)
ListModels(ctx context.Context) ([]string, error)
} }

View File

@ -0,0 +1,28 @@
package modelstore
import "sync"
// Store keeps the globally selected AI model in memory.
type Store struct {
mu sync.RWMutex
selected string
}
// New constructs an empty in-memory model store.
func New() *Store {
return &Store{}
}
// Get returns the currently selected model, or empty for default behavior.
func (s *Store) Get() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.selected
}
// Set updates the currently selected model.
func (s *Store) Set(model string) {
s.mu.Lock()
defer s.mu.Unlock()
s.selected = model
}

View File

@ -0,0 +1,15 @@
package modelstore
import "testing"
func TestStoreGetSet(t *testing.T) {
store := New()
if got := store.Get(); got != "" {
t.Fatalf("Get() = %q, want empty", got)
}
store.Set("llama3:latest")
if got := store.Get(); got != "llama3:latest" {
t.Fatalf("Get() = %q, want llama3:latest", got)
}
}

View File

@ -0,0 +1,86 @@
package modelvalidator
import (
"context"
"fmt"
"strings"
"sync"
"time"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
)
// Validator caches the model list briefly and normalizes friendly names.
type Validator struct {
client driven.AIGateway
ttl time.Duration
mu sync.Mutex
cache []string
cachedAt time.Time
}
// New constructs a model validator with a TTL cache.
func New(client driven.AIGateway, ttl time.Duration) *Validator {
return &Validator{client: client, ttl: ttl}
}
// Known returns the cached model list, refreshing when stale.
func (v *Validator) Known(ctx context.Context) ([]string, error) {
v.mu.Lock()
defer v.mu.Unlock()
if len(v.cache) > 0 && time.Since(v.cachedAt) < v.ttl {
return append([]string(nil), v.cache...), nil
}
models, err := v.client.ListModels(ctx)
if err != nil {
return nil, err
}
v.cache = append([]string(nil), models...)
v.cachedAt = time.Now()
return append([]string(nil), v.cache...), nil
}
// Normalize resolves a user-provided name to a canonical installed model name.
func (v *Validator) Normalize(ctx context.Context, name string) (string, error) {
models, err := v.Known(ctx)
if err != nil {
return "", err
}
for _, model := range models {
if model == name {
return model, nil
}
}
latest := name + ":latest"
for _, model := range models {
if model == latest {
return model, nil
}
}
lower := strings.ToLower(name)
for _, model := range models {
if strings.ToLower(model) == lower {
return model, nil
}
}
lowerLatest := strings.ToLower(latest)
for _, model := range models {
if strings.ToLower(model) == lowerLatest {
return model, nil
}
}
matches := make([]string, 0, 2)
prefix := lower + ":"
for _, model := range models {
if strings.HasPrefix(strings.ToLower(model), prefix) {
matches = append(matches, model)
}
}
if len(matches) > 1 {
return "", fmt.Errorf("ambiguous model name")
}
return "", fmt.Errorf("unknown model")
}

View File

@ -0,0 +1,70 @@
package modelvalidator
import (
"context"
"errors"
"testing"
"time"
)
type fakeAIGateway struct {
listModels func(context.Context) ([]string, error)
}
func (f *fakeAIGateway) Query(ctx context.Context, text, model string) (string, string, error) {
return "", "", nil
}
func (f *fakeAIGateway) ListModels(ctx context.Context) ([]string, error) {
return f.listModels(ctx)
}
func TestValidatorNormalize(t *testing.T) {
v := New(&fakeAIGateway{
listModels: func(ctx context.Context) ([]string, error) {
return []string{"llama3:latest", "qwen3:latest", "qwen3:4b"}, nil
},
}, time.Minute)
got, err := v.Normalize(context.Background(), "llama3")
if err != nil || got != "llama3:latest" {
t.Fatalf("Normalize(llama3) = %q, %v", got, err)
}
got, err = v.Normalize(context.Background(), "qwen3")
if err != nil || got != "qwen3:latest" {
t.Fatalf("Normalize(qwen3) = %q, %v", got, err)
}
}
func TestValidatorKnownCaches(t *testing.T) {
calls := 0
v := New(&fakeAIGateway{
listModels: func(ctx context.Context) ([]string, error) {
calls++
return []string{"llama3:latest"}, nil
},
}, time.Minute)
if _, err := v.Known(context.Background()); err != nil {
t.Fatalf("Known() error = %v", err)
}
if _, err := v.Known(context.Background()); err != nil {
t.Fatalf("Known() error = %v", err)
}
if calls != 1 {
t.Fatalf("calls = %d, want 1", calls)
}
}
func TestValidatorKnownPropagatesError(t *testing.T) {
v := New(&fakeAIGateway{
listModels: func(ctx context.Context) ([]string, error) {
return nil, errors.New("boom")
},
}, time.Minute)
if _, err := v.Known(context.Background()); err == nil {
t.Fatal("Known() error = nil, want error")
}
}

View File

@ -25,6 +25,7 @@ type QueryRequest struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"` Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"`
Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"` Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"`
Model string `protobuf:"bytes,3,opt,name=model,proto3" json:"model,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@ -73,11 +74,19 @@ func (x *QueryRequest) GetSource() string {
return "" return ""
} }
func (x *QueryRequest) GetModel() string {
if x != nil {
return x.Model
}
return ""
}
type QueryResponse struct { type QueryResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Reply string `protobuf:"bytes,1,opt,name=reply,proto3" json:"reply,omitempty"` Reply string `protobuf:"bytes,1,opt,name=reply,proto3" json:"reply,omitempty"`
Intent string `protobuf:"bytes,2,opt,name=intent,proto3" json:"intent,omitempty"` Intent string `protobuf:"bytes,2,opt,name=intent,proto3" json:"intent,omitempty"`
ActionTaken bool `protobuf:"varint,3,opt,name=action_taken,json=actionTaken,proto3" json:"action_taken,omitempty"` ActionTaken bool `protobuf:"varint,3,opt,name=action_taken,json=actionTaken,proto3" json:"action_taken,omitempty"`
ModelUsed string `protobuf:"bytes,4,opt,name=model_used,json=modelUsed,proto3" json:"model_used,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@ -133,20 +142,115 @@ func (x *QueryResponse) GetActionTaken() bool {
return false return false
} }
func (x *QueryResponse) GetModelUsed() string {
if x != nil {
return x.ModelUsed
}
return ""
}
type ListModelsRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ListModelsRequest) Reset() {
*x = ListModelsRequest{}
mi := &file_ai_v1_ai_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ListModelsRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListModelsRequest) ProtoMessage() {}
func (x *ListModelsRequest) ProtoReflect() protoreflect.Message {
mi := &file_ai_v1_ai_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListModelsRequest.ProtoReflect.Descriptor instead.
func (*ListModelsRequest) Descriptor() ([]byte, []int) {
return file_ai_v1_ai_proto_rawDescGZIP(), []int{2}
}
type ListModelsResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Names []string `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ListModelsResponse) Reset() {
*x = ListModelsResponse{}
mi := &file_ai_v1_ai_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ListModelsResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListModelsResponse) ProtoMessage() {}
func (x *ListModelsResponse) ProtoReflect() protoreflect.Message {
mi := &file_ai_v1_ai_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListModelsResponse.ProtoReflect.Descriptor instead.
func (*ListModelsResponse) Descriptor() ([]byte, []int) {
return file_ai_v1_ai_proto_rawDescGZIP(), []int{3}
}
func (x *ListModelsResponse) GetNames() []string {
if x != nil {
return x.Names
}
return nil
}
var File_ai_v1_ai_proto protoreflect.FileDescriptor var File_ai_v1_ai_proto protoreflect.FileDescriptor
const file_ai_v1_ai_proto_rawDesc = "" + const file_ai_v1_ai_proto_rawDesc = "" +
"\n" + "\n" +
"\x0eai/v1/ai.proto\x12\x05ai.v1\":\n" + "\x0eai/v1/ai.proto\x12\x05ai.v1\"P\n" +
"\fQueryRequest\x12\x12\n" + "\fQueryRequest\x12\x12\n" +
"\x04text\x18\x01 \x01(\tR\x04text\x12\x16\n" + "\x04text\x18\x01 \x01(\tR\x04text\x12\x16\n" +
"\x06source\x18\x02 \x01(\tR\x06source\"`\n" + "\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n" +
"\x05model\x18\x03 \x01(\tR\x05model\"\x7f\n" +
"\rQueryResponse\x12\x14\n" + "\rQueryResponse\x12\x14\n" +
"\x05reply\x18\x01 \x01(\tR\x05reply\x12\x16\n" + "\x05reply\x18\x01 \x01(\tR\x05reply\x12\x16\n" +
"\x06intent\x18\x02 \x01(\tR\x06intent\x12!\n" + "\x06intent\x18\x02 \x01(\tR\x06intent\x12!\n" +
"\faction_taken\x18\x03 \x01(\bR\vactionTaken2?\n" + "\faction_taken\x18\x03 \x01(\bR\vactionTaken\x12\x1d\n" +
"\n" +
"model_used\x18\x04 \x01(\tR\tmodelUsed\"\x13\n" +
"\x11ListModelsRequest\"*\n" +
"\x12ListModelsResponse\x12\x14\n" +
"\x05names\x18\x01 \x03(\tR\x05names2\x82\x01\n" +
"\tAIService\x122\n" + "\tAIService\x122\n" +
"\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3" "\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponse\x12A\n" +
"\n" +
"ListModels\x12\x18.ai.v1.ListModelsRequest\x1a\x19.ai.v1.ListModelsResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3"
var ( var (
file_ai_v1_ai_proto_rawDescOnce sync.Once file_ai_v1_ai_proto_rawDescOnce sync.Once
@ -160,16 +264,20 @@ func file_ai_v1_ai_proto_rawDescGZIP() []byte {
return file_ai_v1_ai_proto_rawDescData return file_ai_v1_ai_proto_rawDescData
} }
var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_ai_v1_ai_proto_goTypes = []any{ var file_ai_v1_ai_proto_goTypes = []any{
(*QueryRequest)(nil), // 0: ai.v1.QueryRequest (*QueryRequest)(nil), // 0: ai.v1.QueryRequest
(*QueryResponse)(nil), // 1: ai.v1.QueryResponse (*QueryResponse)(nil), // 1: ai.v1.QueryResponse
(*ListModelsRequest)(nil), // 2: ai.v1.ListModelsRequest
(*ListModelsResponse)(nil), // 3: ai.v1.ListModelsResponse
} }
var file_ai_v1_ai_proto_depIdxs = []int32{ var file_ai_v1_ai_proto_depIdxs = []int32{
0, // 0: ai.v1.AIService.Query:input_type -> ai.v1.QueryRequest 0, // 0: ai.v1.AIService.Query:input_type -> ai.v1.QueryRequest
1, // 1: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse 2, // 1: ai.v1.AIService.ListModels:input_type -> ai.v1.ListModelsRequest
1, // [1:2] is the sub-list for method output_type 1, // 2: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse
0, // [0:1] is the sub-list for method input_type 3, // 3: ai.v1.AIService.ListModels:output_type -> ai.v1.ListModelsResponse
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name 0, // [0:0] is the sub-list for field type_name
@ -186,7 +294,7 @@ func file_ai_v1_ai_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_v1_ai_proto_rawDesc), len(file_ai_v1_ai_proto_rawDesc)), RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_v1_ai_proto_rawDesc), len(file_ai_v1_ai_proto_rawDesc)),
NumEnums: 0, NumEnums: 0,
NumMessages: 2, NumMessages: 4,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@ -20,6 +20,7 @@ const _ = grpc.SupportPackageIsVersion9
const ( const (
AIService_Query_FullMethodName = "/ai.v1.AIService/Query" AIService_Query_FullMethodName = "/ai.v1.AIService/Query"
AIService_ListModels_FullMethodName = "/ai.v1.AIService/ListModels"
) )
// AIServiceClient is the client API for AIService service. // AIServiceClient is the client API for AIService service.
@ -27,6 +28,7 @@ const (
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type AIServiceClient interface { type AIServiceClient interface {
Query(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*QueryResponse, error) Query(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*QueryResponse, error)
ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error)
} }
type aIServiceClient struct { type aIServiceClient struct {
@ -47,11 +49,22 @@ func (c *aIServiceClient) Query(ctx context.Context, in *QueryRequest, opts ...g
return out, nil return out, nil
} }
func (c *aIServiceClient) ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ListModelsResponse)
err := c.cc.Invoke(ctx, AIService_ListModels_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// AIServiceServer is the server API for AIService service. // AIServiceServer is the server API for AIService service.
// All implementations must embed UnimplementedAIServiceServer // All implementations must embed UnimplementedAIServiceServer
// for forward compatibility. // for forward compatibility.
type AIServiceServer interface { type AIServiceServer interface {
Query(context.Context, *QueryRequest) (*QueryResponse, error) Query(context.Context, *QueryRequest) (*QueryResponse, error)
ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error)
mustEmbedUnimplementedAIServiceServer() mustEmbedUnimplementedAIServiceServer()
} }
@ -65,6 +78,9 @@ type UnimplementedAIServiceServer struct{}
func (UnimplementedAIServiceServer) Query(context.Context, *QueryRequest) (*QueryResponse, error) { func (UnimplementedAIServiceServer) Query(context.Context, *QueryRequest) (*QueryResponse, error) {
return nil, status.Error(codes.Unimplemented, "method Query not implemented") return nil, status.Error(codes.Unimplemented, "method Query not implemented")
} }
func (UnimplementedAIServiceServer) ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error) {
return nil, status.Error(codes.Unimplemented, "method ListModels not implemented")
}
func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {} func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {}
func (UnimplementedAIServiceServer) testEmbeddedByValue() {} func (UnimplementedAIServiceServer) testEmbeddedByValue() {}
@ -104,6 +120,24 @@ func _AIService_Query_Handler(srv interface{}, ctx context.Context, dec func(int
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _AIService_ListModels_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListModelsRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AIServiceServer).ListModels(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: AIService_ListModels_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AIServiceServer).ListModels(ctx, req.(*ListModelsRequest))
}
return interceptor(ctx, in, info, handler)
}
// AIService_ServiceDesc is the grpc.ServiceDesc for AIService service. // AIService_ServiceDesc is the grpc.ServiceDesc for AIService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@ -115,6 +149,10 @@ var AIService_ServiceDesc = grpc.ServiceDesc{
MethodName: "Query", MethodName: "Query",
Handler: _AIService_Query_Handler, Handler: _AIService_Query_Handler,
}, },
{
MethodName: "ListModels",
Handler: _AIService_ListModels_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "ai/v1/ai.proto", Metadata: "ai/v1/ai.proto",

View File

@ -6,15 +6,24 @@ option go_package = "gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1";
service AIService { service AIService {
rpc Query(QueryRequest) returns (QueryResponse); rpc Query(QueryRequest) returns (QueryResponse);
rpc ListModels(ListModelsRequest) returns (ListModelsResponse);
} }
message QueryRequest { message QueryRequest {
string text = 1; string text = 1;
string source = 2; string source = 2;
string model = 3;
} }
message QueryResponse { message QueryResponse {
string reply = 1; string reply = 1;
string intent = 2; string intent = 2;
bool action_taken = 3; bool action_taken = 3;
string model_used = 4;
}
message ListModelsRequest {}
message ListModelsResponse {
repeated string names = 1;
} }