feat: enhance AI model management in Discord bot
All checks were successful
CI / test (push) Successful in 5s
CI / build-ai-gateway (push) Successful in 43s
CI / build-ha-gateway (push) Successful in 47s
CI / build-discord-bot (push) Successful in 41s

- Updated LLMClient interface to support model-specific generation and model listing.
- Integrated model store and validator into the command application for managing AI models.
- Implemented commands for setting, getting, and listing active AI models in Discord.
- Enhanced AI query handling to utilize the selected model and return model information in responses.
- Added caching mechanism for model validation to improve performance.
- Introduced gRPC methods for listing available AI models in the ai-gateway.
- Updated protobuf definitions to include model-related fields and messages.
- Added tests for model store and validator functionalities.
This commit is contained in:
Nik Afiq 2026-04-21 22:52:00 +09:00
parent 9cc29c2329
commit ad50d641bd
21 changed files with 876 additions and 80 deletions

View File

@ -67,7 +67,7 @@ func main() {
os.Exit(1)
}
ollamaClient := ollama.New(cfg.OllamaURL, cfg.OllamaModel, &http.Client{Timeout: cfg.OllamaTimeout})
ollamaClient := ollama.New(cfg.OllamaURL, &http.Client{Timeout: cfg.OllamaTimeout})
haClient, err := hagateway.New(ctx, cfg.HAGatewayAddr, cfg.TLSDir, cfg.HAGatewayServerName, log)
if err != nil {
log.Error("ha-gateway client setup failed", "err", err)
@ -80,7 +80,7 @@ func main() {
}()
lightCache := domain.NewLightCache(cfg.LightCacheTTL, haClient.ListLights)
queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, log)
queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, cfg.OllamaModel, log)
serverOpts := []grpc.ServerOption{
grpc.StatsHandler(otelgrpc.NewServerHandler()),

View File

@ -33,7 +33,7 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query
ctx = logger.WithLogger(ctx, log)
}
result, err := s.app.Query(ctx, req.GetText())
result, err := s.app.Query(ctx, req.GetText(), req.GetModel())
if err != nil {
return nil, status.Errorf(codes.Unavailable, "query failed: %v", err)
}
@ -41,5 +41,15 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query
Reply: result.Reply,
Intent: result.Intent,
ActionTaken: result.ActionTaken,
ModelUsed: result.ModelUsed,
}, nil
}
// ListModels returns the installed model names from Ollama.
func (s *Server) ListModels(ctx context.Context, _ *aiv1.ListModelsRequest) (*aiv1.ListModelsResponse, error) {
names, err := s.app.ListModels(ctx)
if err != nil {
return nil, status.Errorf(codes.Unavailable, "list models: %v", err)
}
return &aiv1.ListModelsResponse{Names: names}, nil
}

View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
@ -14,37 +15,49 @@ type generateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
Think *bool `json:"think,omitempty"`
}
type generateResponse struct {
Response string `json:"response"`
}
type listModelsResponse struct {
Models []struct {
Name string `json:"name"`
} `json:"models"`
}
// Client implements the LLM driven port with the Ollama generate API.
type Client struct {
baseURL string
model string
http *http.Client
}
// New constructs an Ollama client with OTel-instrumented transport.
func New(baseURL, model string, httpClient *http.Client) *Client {
func New(baseURL string, httpClient *http.Client) *Client {
if httpClient == nil {
httpClient = &http.Client{Transport: otelhttp.NewTransport(http.DefaultTransport)}
}
if httpClient.Transport == nil {
httpClient.Transport = otelhttp.NewTransport(http.DefaultTransport)
}
return &Client{baseURL: baseURL, model: model, http: httpClient}
return &Client{baseURL: baseURL, http: httpClient}
}
// Generate sends one non-streaming prompt to Ollama.
func (c *Client) Generate(ctx context.Context, prompt string) (string, error) {
body, err := json.Marshal(generateRequest{
Model: c.model,
func (c *Client) Generate(ctx context.Context, model, prompt string) (string, error) {
reqBody := generateRequest{
Model: model,
Prompt: prompt,
Stream: false,
})
}
if isThinkingModel(model) {
disabled := false
reqBody.Think = &disabled
}
body, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshal ollama request: %w", err)
}
@ -71,3 +84,38 @@ func (c *Client) Generate(ctx context.Context, prompt string) (string, error) {
}
return out.Response, nil
}
// ListModels returns the installed model names from Ollama.
func (c *Client) ListModels(ctx context.Context) ([]string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/tags", nil)
if err != nil {
return nil, fmt.Errorf("build ollama list models request: %w", err)
}
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("list ollama models: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("ollama returned status %s", resp.Status)
}
var out listModelsResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, fmt.Errorf("decode ollama list models response: %w", err)
}
names := make([]string, 0, len(out.Models))
for _, model := range out.Models {
if model.Name != "" {
names = append(names, model.Name)
}
}
return names, nil
}
func isThinkingModel(name string) bool {
return strings.HasPrefix(name, "qwen3")
}

View File

@ -0,0 +1,81 @@
package ollama
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestIsThinkingModel(t *testing.T) {
tests := []struct {
name string
want bool
}{
{name: "qwen3", want: true},
{name: "qwen3:4b", want: true},
{name: "qwen3:latest", want: true},
{name: "llama3", want: false},
{name: "mistral", want: false},
}
for _, tt := range tests {
if got := isThinkingModel(tt.name); got != tt.want {
t.Fatalf("isThinkingModel(%q) = %v, want %v", tt.name, got, tt.want)
}
}
}
func TestGenerateSetsThinkForQwen3(t *testing.T) {
var body map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("Decode() error = %v", err)
}
_, _ = w.Write([]byte(`{"response":"ok"}`))
}))
defer srv.Close()
client := New(srv.URL, srv.Client())
if _, err := client.Generate(context.Background(), "qwen3:latest", "prompt"); err != nil {
t.Fatalf("Generate() error = %v", err)
}
if body["think"] != false {
t.Fatalf("think = %#v, want false", body["think"])
}
}
func TestGenerateOmitsThinkForLlama3(t *testing.T) {
var body map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("Decode() error = %v", err)
}
_, _ = w.Write([]byte(`{"response":"ok"}`))
}))
defer srv.Close()
client := New(srv.URL, srv.Client())
if _, err := client.Generate(context.Background(), "llama3:latest", "prompt"); err != nil {
t.Fatalf("Generate() error = %v", err)
}
if _, ok := body["think"]; ok {
t.Fatalf("unexpected think key = %#v", body["think"])
}
}
func TestListModelsReturnsNamesOnly(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"models":[{"name":"llama3:latest"},{"name":"qwen3:latest"}]}`))
}))
defer srv.Close()
client := New(srv.URL, srv.Client())
got, err := client.ListModels(context.Background())
if err != nil {
t.Fatalf("ListModels() error = %v", err)
}
if len(got) != 2 || got[0] != "llama3:latest" || got[1] != "qwen3:latest" {
t.Fatalf("ListModels() = %#v", got)
}
}

View File

@ -18,34 +18,40 @@ type QueryResult struct {
Reply string
Intent string
ActionTaken bool
ModelUsed string
}
// QueryApp orchestrates one AI query request.
type QueryApp struct {
llm driven.LLMClient
ha driven.HAClient
cache *domain.LightCache
log *slog.Logger
llm driven.LLMClient
ha driven.HAClient
cache *domain.LightCache
defaultModel string
log *slog.Logger
}
// NewQueryApp constructs the AI query application service.
func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, log *slog.Logger) *QueryApp {
return &QueryApp{llm: llm, ha: ha, cache: cache, log: log}
func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, defaultModel string, log *slog.Logger) *QueryApp {
return &QueryApp{llm: llm, ha: ha, cache: cache, defaultModel: defaultModel, log: log}
}
// Query runs the full intent parsing and dispatch flow for one user request.
func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) {
func (a *QueryApp) Query(ctx context.Context, text, model string) (QueryResult, error) {
if model == "" {
model = a.defaultModel
}
lights, err := a.cache.Get(ctx)
if err != nil {
a.log.Error("light cache refresh failed", "err", err)
return QueryResult{
Reply: "I couldn't reach Home Assistant right now.",
ActionTaken: false,
ModelUsed: model,
}, nil
}
prompt := domain.BuildPrompt(text, promptLightLines(lights))
raw, err := a.llm.Generate(ctx, prompt)
raw, err := a.llm.Generate(ctx, model, prompt)
if err != nil {
return QueryResult{}, err
}
@ -57,6 +63,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: "I didn't understand that.",
Intent: domain.IntentNone,
ActionTaken: false,
ModelUsed: model,
}, nil
}
@ -64,7 +71,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
case domain.IntentTurnOnLight:
entityID, ok := resolveLightEntity(intent.Entity, lights)
if !ok {
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil
}
params, err := ParseLightParams(intent.Params)
if err != nil {
@ -72,6 +79,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: "I couldn't understand the light settings.",
Intent: intent.Name,
ActionTaken: false,
ModelUsed: model,
}, nil
}
if err := a.ha.TurnOnLight(ctx, entityID, params); err != nil {
@ -80,17 +88,19 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: "I couldn't reach Home Assistant right now.",
Intent: intent.Name,
ActionTaken: false,
ModelUsed: model,
}, nil
}
return QueryResult{
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned on `%s`.", displayLightName(entityID, lights))),
Intent: intent.Name,
ActionTaken: true,
ModelUsed: model,
}, nil
case domain.IntentTurnOffLight:
entityID, ok := resolveLightEntity(intent.Entity, lights)
if !ok {
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil
}
if err := a.ha.TurnOffLight(ctx, entityID); err != nil {
a.log.Error("turn off light failed", "entity_id", entityID, "err", err)
@ -98,18 +108,21 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: "I couldn't reach Home Assistant right now.",
Intent: intent.Name,
ActionTaken: false,
ModelUsed: model,
}, nil
}
return QueryResult{
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned off `%s`.", displayLightName(entityID, lights))),
Intent: intent.Name,
ActionTaken: true,
ModelUsed: model,
}, nil
case domain.IntentListLights:
return QueryResult{
Reply: formatLightListReply(lights),
Intent: intent.Name,
ActionTaken: false,
ModelUsed: model,
}, nil
case domain.IntentNone:
fallthrough
@ -118,10 +131,16 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
Reply: fallbackReply(intent.Reply, "I didn't understand that."),
Intent: intent.Name,
ActionTaken: false,
ModelUsed: model,
}, nil
}
}
// ListModels returns the currently installed model names.
func (a *QueryApp) ListModels(ctx context.Context) ([]string, error) {
return a.llm.ListModels(ctx)
}
func promptLightLines(lights []driven.Light) []string {
lines := make([]string, 0, len(lights))
for _, light := range lights {

View File

@ -13,11 +13,19 @@ import (
)
type fakeLLM struct {
generate func(context.Context, string) (string, error)
generate func(context.Context, string, string) (string, error)
listModels func(context.Context) ([]string, error)
}
func (f *fakeLLM) Generate(ctx context.Context, prompt string) (string, error) {
return f.generate(ctx, prompt)
func (f *fakeLLM) Generate(ctx context.Context, model, prompt string) (string, error) {
return f.generate(ctx, model, prompt)
}
func (f *fakeLLM) ListModels(ctx context.Context) ([]string, error) {
if f.listModels == nil {
return nil, nil
}
return f.listModels(ctx)
}
type fakeHA struct {
@ -54,12 +62,15 @@ func TestQueryAppTurnOnLight(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
cache := domain.NewLightCache(time.Hour, ha.ListLights)
app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) {
generate: func(ctx context.Context, model, prompt string) (string, error) {
if model != "llama3:latest" {
t.Fatalf("Generate() model = %q", model)
}
return `{"intent":"turn_on_light","entity":"Kitchen","params":{"brightness":"80"},"reply":"Turning on Kitchen."}`, nil
},
}, ha, cache, slog.Default())
}, ha, cache, "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "turn on kitchen")
got, err := app.Query(context.Background(), "turn on kitchen", "")
if err != nil {
t.Fatalf("Query() error = %v", err)
}
@ -77,12 +88,12 @@ func TestQueryAppTurnOnLight(t *testing.T) {
func TestQueryAppInvalidJSON(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) {
generate: func(ctx context.Context, model, prompt string) (string, error) {
return `not-json`, nil
},
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "turn on kitchen")
got, err := app.Query(context.Background(), "turn on kitchen", "")
if err != nil {
t.Fatalf("Query() error = %v", err)
}
@ -97,12 +108,12 @@ func TestQueryAppInvalidJSON(t *testing.T) {
func TestQueryAppIntentNone(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) {
generate: func(ctx context.Context, model, prompt string) (string, error) {
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
},
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "hello")
got, err := app.Query(context.Background(), "hello", "")
if err != nil {
t.Fatalf("Query() error = %v", err)
}
@ -117,12 +128,12 @@ func TestQueryAppHAFailure(t *testing.T) {
turnOnErr: errors.New("boom"),
}
app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) {
generate: func(ctx context.Context, model, prompt string) (string, error) {
return `{"intent":"turn_on_light","entity":"light.kitchen","params":{},"reply":"Turning on Kitchen."}`, nil
},
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "turn on kitchen")
got, err := app.Query(context.Background(), "turn on kitchen", "")
if err != nil {
t.Fatalf("Query() error = %v", err)
}
@ -134,12 +145,12 @@ func TestQueryAppHAFailure(t *testing.T) {
func TestQueryAppListLights(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "on"}}}
app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, prompt string) (string, error) {
generate: func(ctx context.Context, model, prompt string) (string, error) {
return `{"intent":"list_lights","entity":"","params":{},"reply":""}`, nil
},
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "what lights exist")
got, err := app.Query(context.Background(), "what lights exist", "")
if err != nil {
t.Fatalf("Query() error = %v", err)
}
@ -164,3 +175,42 @@ func TestLightCacheRefreshAfterTTL(t *testing.T) {
t.Fatalf("ListLights calls = %d, want at least 2", ha.listCalls)
}
}
func TestQueryAppExplicitModel(t *testing.T) {
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, model, prompt string) (string, error) {
if model != "qwen3:latest" {
t.Fatalf("Generate() model = %q", model)
}
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
},
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
got, err := app.Query(context.Background(), "hello", "qwen3:latest")
if err != nil {
t.Fatalf("Query() error = %v", err)
}
if got.ModelUsed != "qwen3:latest" {
t.Fatalf("ModelUsed = %q", got.ModelUsed)
}
}
func TestQueryAppListModels(t *testing.T) {
app := NewQueryApp(&fakeLLM{
generate: func(ctx context.Context, model, prompt string) (string, error) {
return "", nil
},
listModels: func(ctx context.Context) ([]string, error) {
return []string{"llama3:latest", "qwen3:latest"}, nil
},
}, &fakeHA{}, domain.NewLightCache(time.Hour, func(context.Context) ([]driven.Light, error) { return nil, nil }), "llama3:latest", slog.Default())
got, err := app.ListModels(context.Background())
if err != nil {
t.Fatalf("ListModels() error = %v", err)
}
if !reflect.DeepEqual(got, []string{"llama3:latest", "qwen3:latest"}) {
t.Fatalf("ListModels() = %#v", got)
}
}

View File

@ -4,5 +4,6 @@ import "context"
// LLMClient generates one model response for a prompt.
type LLMClient interface {
Generate(ctx context.Context, prompt string) (string, error)
Generate(ctx context.Context, model, prompt string) (string, error)
ListModels(ctx context.Context) ([]string, error)
}

View File

@ -17,6 +17,8 @@ import (
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/app"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/config"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/logger"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/telemetry"
)
@ -84,7 +86,9 @@ func main() {
}
}()
commandApp := app.NewCommandApp(haClient, aiClient)
modelStore := modelstore.New()
validator := modelvalidator.New(aiClient, 30*time.Second)
commandApp := app.NewCommandApp(haClient, aiClient, modelStore, validator)
// Discord-specific wiring stays at the edge so the app layer remains transport-agnostic.
session, err := discordgo.New("Bot " + cfg.DiscordToken)

View File

@ -24,8 +24,12 @@ type commandHandler interface {
HandleLightToggle(ctx context.Context, entityID string) (string, error)
HandleSwitchList(ctx context.Context) (string, error)
HandleAIQuery(ctx context.Context, text string) (string, error)
HandleAIModelSet(ctx context.Context, name string) (string, error)
HandleAIModelGet(ctx context.Context) (string, error)
HandleAIModelList(ctx context.Context) (string, error)
AutocompleteLights(ctx context.Context) ([]apppkg.Choice, error)
AutocompleteSwitches(ctx context.Context) ([]apppkg.Choice, error)
AutocompleteAIModels(ctx context.Context) ([]apppkg.Choice, error)
}
// Handler adapts Discord interactions to the command application layer.
@ -69,6 +73,9 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
command := data.Name
if len(data.Options) > 0 {
command += "." + data.Options[0].Name
if len(data.Options[0].Options) > 0 && data.Options[0].Type == discordgo.ApplicationCommandOptionSubCommandGroup {
command += "." + data.Options[0].Options[0].Name
}
}
user := ""
if i.Member != nil && i.Member.User != nil {
@ -90,7 +97,18 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
}
sub := data.Options[0]
switch data.Name + "." + sub.Name {
commandPath := data.Name + "." + sub.Name
target := sub
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
if len(sub.Options) == 0 {
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("missing grouped subcommand"))
return
}
target = sub.Options[0]
commandPath += "." + target.Name
}
switch commandPath {
case "light.list":
msg, err := h.app.HandleLightList(ctx)
if err != nil {
@ -145,10 +163,38 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
)
return
}
msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(sub, "text"))
msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(target, "text"))
h.followup(ctx, s, i.Interaction, msg, true, start, err)
case "ai.model.set":
if err := h.deferResponse(s, i.Interaction, true); err != nil {
log.Error("discord response failed",
"duration_ms", time.Since(start).Milliseconds(),
"error", err.Error(),
)
return
}
msg, err := h.app.HandleAIModelSet(ctx, requiredStringOption(target, "name"))
h.followup(ctx, s, i.Interaction, msg, true, start, err)
case "ai.model.get":
msg, err := h.app.HandleAIModelGet(ctx)
if err != nil {
h.respondError(ctx, s, i.Interaction, true, start, err)
return
}
h.respondMessage(ctx, s, i.Interaction, msg, true)
log.Info("command handled", "duration_ms", time.Since(start).Milliseconds())
case "ai.model.list":
if err := h.deferResponse(s, i.Interaction, true); err != nil {
log.Error("discord response failed",
"duration_ms", time.Since(start).Milliseconds(),
"error", err.Error(),
)
return
}
msg, err := h.app.HandleAIModelList(ctx)
h.followup(ctx, s, i.Interaction, msg, true, start, err)
default:
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s.%s", data.Name, sub.Name))
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s", commandPath))
}
}
@ -167,6 +213,10 @@ func (h *Handler) handleAutocomplete(ctx context.Context, s *discordgo.Session,
choices, err = h.app.AutocompleteLights(ctx)
case "switch":
choices, err = h.app.AutocompleteSwitches(ctx)
case "ai":
if focusedOptionName(data) == "name" {
choices, err = h.app.AutocompleteAIModels(ctx)
}
default:
choices = nil
}
@ -299,6 +349,15 @@ func optionalUint32Option(sub *discordgo.ApplicationCommandInteractionDataOption
func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string {
for _, sub := range data.Options {
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
for _, nested := range sub.Options {
for _, opt := range nested.Options {
if opt.Focused {
return opt.StringValue()
}
}
}
}
for _, opt := range sub.Options {
if opt.Focused {
return opt.StringValue()
@ -307,3 +366,23 @@ func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string
}
return ""
}
func focusedOptionName(data discordgo.ApplicationCommandInteractionData) string {
for _, sub := range data.Options {
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
for _, nested := range sub.Options {
for _, opt := range nested.Options {
if opt.Focused {
return opt.Name
}
}
}
}
for _, opt := range sub.Options {
if opt.Focused {
return opt.Name
}
}
}
return ""
}

View File

@ -98,6 +98,37 @@ func RegisterCommands(s *discordgo.Session, guildID string) error {
},
},
},
{
Type: discordgo.ApplicationCommandOptionSubCommandGroup,
Name: "model",
Description: "Manage the active AI model",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionSubCommand,
Name: "set",
Description: "Set the active model",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionString,
Name: "name",
Description: "Model name",
Required: true,
Autocomplete: true,
},
},
},
{
Type: discordgo.ApplicationCommandOptionSubCommand,
Name: "get",
Description: "Show the active model",
},
{
Type: discordgo.ApplicationCommandOptionSubCommand,
Name: "list",
Description: "List available models",
},
},
},
},
},
}

View File

@ -61,22 +61,39 @@ func (c *Client) Close() error {
}
// Query forwards one free-form request to ai-gateway.
func (c *Client) Query(ctx context.Context, text string) (string, error) {
func (c *Client) Query(ctx context.Context, text, model string) (string, string, error) {
start := time.Now()
log := logger.FromContext(ctx).With("grpc.method", "AIService/Query")
resp, err := c.client.Query(ctx, &aiv1.QueryRequest{
Text: text,
Source: "discord-bot",
Model: model,
})
if err != nil {
log.Error("grpc call failed",
"duration_ms", time.Since(start).Milliseconds(),
"error", err.Error(),
)
return "", fmt.Errorf("query ai-gateway: %w", err)
return "", "", fmt.Errorf("query ai-gateway: %w", err)
}
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
return resp.GetReply(), nil
return resp.GetReply(), resp.GetModelUsed(), nil
}
// ListModels returns the installed model names from ai-gateway.
func (c *Client) ListModels(ctx context.Context) ([]string, error) {
start := time.Now()
log := logger.FromContext(ctx).With("grpc.method", "AIService/ListModels")
resp, err := c.client.ListModels(ctx, &aiv1.ListModelsRequest{})
if err != nil {
log.Error("grpc call failed",
"duration_ms", time.Since(start).Milliseconds(),
"error", err.Error(),
)
return nil, fmt.Errorf("list ai-gateway models: %w", err)
}
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
return append([]string(nil), resp.GetNames()...), nil
}
func loadTransportCredentials(tlsDir string) (credentials.TransportCredentials, error) {

View File

@ -7,6 +7,8 @@ import (
"strings"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
)
// Choice is one Discord autocomplete entry.
@ -17,13 +19,15 @@ type Choice struct {
// CommandApp orchestrates Discord command use cases against ha-gateway.
type CommandApp struct {
ha driven.HAGateway
ai driven.AIGateway
ha driven.HAGateway
ai driven.AIGateway
models *modelstore.Store
validator *modelvalidator.Validator
}
// NewCommandApp constructs the Discord command application service.
func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway) *CommandApp {
return &CommandApp{ha: ha, ai: ai}
func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway, models *modelstore.Store, validator *modelvalidator.Validator) *CommandApp {
return &CommandApp{ha: ha, ai: ai, models: models, validator: validator}
}
// HandleLightList formats discovered lights into a monospace-friendly response.
@ -124,11 +128,72 @@ func (a *CommandApp) HandleSwitchList(ctx context.Context) (string, error) {
// HandleAIQuery forwards a free-form request to ai-gateway.
func (a *CommandApp) HandleAIQuery(ctx context.Context, text string) (string, error) {
reply, err := a.ai.Query(ctx, text)
reply, modelUsed, err := a.ai.Query(ctx, text, a.models.Get())
if err != nil {
return "", fmt.Errorf("handle ai query: %w", err)
}
return reply, nil
return fmt.Sprintf("%s\n\n_(via %s)_", reply, modelUsed), nil
}
// HandleAIModelSet validates and stores the selected model globally.
func (a *CommandApp) HandleAIModelSet(ctx context.Context, name string) (string, error) {
canonical, err := a.validator.Normalize(ctx, name)
if err != nil {
if err.Error() == "ambiguous model name" {
return "", fmt.Errorf("unknown model: %s. Be more specific.", name)
}
if err.Error() == "unknown model" {
return "", fmt.Errorf("unknown model: %s. Try /ai model list.", name)
}
return "", fmt.Errorf("validate model: %w", err)
}
a.models.Set(canonical)
return fmt.Sprintf("Active model set to `%s`.", canonical), nil
}
// HandleAIModelGet reports the current selected model or default state.
func (a *CommandApp) HandleAIModelGet(ctx context.Context) (string, error) {
cur := a.models.Get()
if cur == "" {
return "No model override set. Using ai-gateway default.", nil
}
return fmt.Sprintf("Active model: `%s`", cur), nil
}
// HandleAIModelList shows the installed models and marks the active selection.
func (a *CommandApp) HandleAIModelList(ctx context.Context) (string, error) {
models, err := a.validator.Known(ctx)
if err != nil {
return "", fmt.Errorf("list models: %w", err)
}
if len(models) == 0 {
return "No models installed on the Ollama host.", nil
}
active := a.models.Get()
lines := make([]string, 0, len(models)+1)
lines = append(lines, "Available models:")
for _, model := range models {
marker := ""
if model == active {
marker = " <- active"
}
lines = append(lines, fmt.Sprintf("- `%s`%s", model, marker))
}
return strings.Join(lines, "\n"), nil
}
// AutocompleteAIModels returns model names for the /ai model set command.
func (a *CommandApp) AutocompleteAIModels(ctx context.Context) ([]Choice, error) {
models, err := a.validator.Known(ctx)
if err != nil {
return nil, fmt.Errorf("autocomplete ai models: %w", err)
}
choices := make([]Choice, 0, len(models))
for _, model := range models {
choices = append(choices, Choice{Label: model, Value: model})
}
return choices, nil
}
// AutocompleteLights maps discovered lights into Discord autocomplete choices.

View File

@ -5,8 +5,11 @@ import (
"errors"
"reflect"
"testing"
"time"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
)
type mockHAGateway struct {
@ -18,7 +21,8 @@ type mockHAGateway struct {
}
type mockAIGateway struct {
queryFunc func(ctx context.Context, text string) (string, error)
queryFunc func(ctx context.Context, text, model string) (string, string, error)
listModelsFunc func(ctx context.Context) ([]string, error)
}
func (m *mockHAGateway) ListLights(ctx context.Context) ([]driven.Light, error) {
@ -56,11 +60,22 @@ func (m *mockHAGateway) ToggleLight(ctx context.Context, entityID string) error
return m.toggleLightFunc(ctx, entityID)
}
func (m *mockAIGateway) Query(ctx context.Context, text string) (string, error) {
func (m *mockAIGateway) Query(ctx context.Context, text, model string) (string, string, error) {
if m.queryFunc == nil {
return "", nil
return "", "", nil
}
return m.queryFunc(ctx, text)
return m.queryFunc(ctx, text, model)
}
func (m *mockAIGateway) ListModels(ctx context.Context) ([]string, error) {
if m.listModelsFunc == nil {
return nil, nil
}
return m.listModelsFunc(ctx)
}
func newTestCommandApp(ha *mockHAGateway, ai *mockAIGateway) *CommandApp {
return NewCommandApp(ha, ai, modelstore.New(), modelvalidator.New(ai, time.Minute))
}
func TestCommandAppHandleLightList(t *testing.T) {
@ -102,7 +117,7 @@ func TestCommandAppHandleLightList(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{
app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
return tt.lights, nil
},
@ -183,7 +198,7 @@ func TestCommandAppHandleLightOn(t *testing.T) {
var gotBrightness *uint32
var gotColorTemp *uint32
app := NewCommandApp(&mockHAGateway{
app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
if tt.listErr != nil {
return nil, tt.listErr
@ -258,7 +273,7 @@ func TestCommandAppHandleLightOff(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var gotTransition *uint32
app := NewCommandApp(&mockHAGateway{
app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
return tt.lights, nil
},
@ -317,7 +332,7 @@ func TestCommandAppHandleLightToggle(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{
app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
if tt.listErr != nil {
return nil, tt.listErr
@ -368,7 +383,7 @@ func TestCommandAppHandleSwitchList(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{
app := newTestCommandApp(&mockHAGateway{
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
return tt.switches, nil
},
@ -413,7 +428,7 @@ func TestCommandAppAutocompleteLights(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{
app := newTestCommandApp(&mockHAGateway{
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
if tt.listErr != nil {
return nil, tt.listErr
@ -467,7 +482,7 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := NewCommandApp(&mockHAGateway{
app := newTestCommandApp(&mockHAGateway{
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
if tt.listErr != nil {
return nil, tt.listErr
@ -494,20 +509,41 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) {
}
func TestCommandAppHandleAIQuery(t *testing.T) {
store := modelstore.New()
store.Set("llama3:latest")
app := NewCommandApp(&mockHAGateway{}, &mockAIGateway{
queryFunc: func(ctx context.Context, text string) (string, error) {
queryFunc: func(ctx context.Context, text, model string) (string, string, error) {
if text != "turn on kitchen" {
t.Fatalf("Query() text = %q", text)
}
return "Turning on Kitchen.", nil
if model != "llama3:latest" {
t.Fatalf("Query() model = %q", model)
}
return "Turning on Kitchen.", "llama3:latest", nil
},
})
}, store, modelvalidator.New(&mockAIGateway{}, time.Minute))
got, err := app.HandleAIQuery(context.Background(), "turn on kitchen")
if err != nil {
t.Fatalf("HandleAIQuery() error = %v", err)
}
if got != "Turning on Kitchen." {
if got != "Turning on Kitchen.\n\n_(via llama3:latest)_" {
t.Fatalf("HandleAIQuery() = %q", got)
}
}
func TestCommandAppHandleAIModelSet(t *testing.T) {
app := newTestCommandApp(&mockHAGateway{}, &mockAIGateway{
listModelsFunc: func(ctx context.Context) ([]string, error) {
return []string{"llama3:latest"}, nil
},
})
got, err := app.HandleAIModelSet(context.Background(), "llama3")
if err != nil {
t.Fatalf("HandleAIModelSet() error = %v", err)
}
if got != "Active model set to `llama3:latest`." {
t.Fatalf("HandleAIModelSet() = %q", got)
}
}

View File

@ -4,5 +4,6 @@ import "context"
// AIGateway exposes the free-form AI query API used by the Discord bot.
type AIGateway interface {
Query(ctx context.Context, text string) (string, error)
Query(ctx context.Context, text, model string) (reply, modelUsed string, err error)
ListModels(ctx context.Context) ([]string, error)
}

View File

@ -0,0 +1,28 @@
package modelstore
import "sync"
// Store keeps the globally selected AI model in memory.
type Store struct {
mu sync.RWMutex
selected string
}
// New constructs an empty in-memory model store.
func New() *Store {
return &Store{}
}
// Get returns the currently selected model, or empty for default behavior.
func (s *Store) Get() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.selected
}
// Set updates the currently selected model.
func (s *Store) Set(model string) {
s.mu.Lock()
defer s.mu.Unlock()
s.selected = model
}

View File

@ -0,0 +1,15 @@
package modelstore
import "testing"
func TestStoreGetSet(t *testing.T) {
store := New()
if got := store.Get(); got != "" {
t.Fatalf("Get() = %q, want empty", got)
}
store.Set("llama3:latest")
if got := store.Get(); got != "llama3:latest" {
t.Fatalf("Get() = %q, want llama3:latest", got)
}
}

View File

@ -0,0 +1,86 @@
package modelvalidator
import (
"context"
"fmt"
"strings"
"sync"
"time"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
)
// Validator caches the model list briefly and normalizes friendly names.
type Validator struct {
client driven.AIGateway
ttl time.Duration
mu sync.Mutex
cache []string
cachedAt time.Time
}
// New constructs a model validator with a TTL cache.
func New(client driven.AIGateway, ttl time.Duration) *Validator {
return &Validator{client: client, ttl: ttl}
}
// Known returns the cached model list, refreshing when stale.
func (v *Validator) Known(ctx context.Context) ([]string, error) {
v.mu.Lock()
defer v.mu.Unlock()
if len(v.cache) > 0 && time.Since(v.cachedAt) < v.ttl {
return append([]string(nil), v.cache...), nil
}
models, err := v.client.ListModels(ctx)
if err != nil {
return nil, err
}
v.cache = append([]string(nil), models...)
v.cachedAt = time.Now()
return append([]string(nil), v.cache...), nil
}
// Normalize resolves a user-provided name to a canonical installed model name.
func (v *Validator) Normalize(ctx context.Context, name string) (string, error) {
models, err := v.Known(ctx)
if err != nil {
return "", err
}
for _, model := range models {
if model == name {
return model, nil
}
}
latest := name + ":latest"
for _, model := range models {
if model == latest {
return model, nil
}
}
lower := strings.ToLower(name)
for _, model := range models {
if strings.ToLower(model) == lower {
return model, nil
}
}
lowerLatest := strings.ToLower(latest)
for _, model := range models {
if strings.ToLower(model) == lowerLatest {
return model, nil
}
}
matches := make([]string, 0, 2)
prefix := lower + ":"
for _, model := range models {
if strings.HasPrefix(strings.ToLower(model), prefix) {
matches = append(matches, model)
}
}
if len(matches) > 1 {
return "", fmt.Errorf("ambiguous model name")
}
return "", fmt.Errorf("unknown model")
}

View File

@ -0,0 +1,70 @@
package modelvalidator
import (
"context"
"errors"
"testing"
"time"
)
type fakeAIGateway struct {
listModels func(context.Context) ([]string, error)
}
func (f *fakeAIGateway) Query(ctx context.Context, text, model string) (string, string, error) {
return "", "", nil
}
func (f *fakeAIGateway) ListModels(ctx context.Context) ([]string, error) {
return f.listModels(ctx)
}
func TestValidatorNormalize(t *testing.T) {
v := New(&fakeAIGateway{
listModels: func(ctx context.Context) ([]string, error) {
return []string{"llama3:latest", "qwen3:latest", "qwen3:4b"}, nil
},
}, time.Minute)
got, err := v.Normalize(context.Background(), "llama3")
if err != nil || got != "llama3:latest" {
t.Fatalf("Normalize(llama3) = %q, %v", got, err)
}
got, err = v.Normalize(context.Background(), "qwen3")
if err != nil || got != "qwen3:latest" {
t.Fatalf("Normalize(qwen3) = %q, %v", got, err)
}
}
func TestValidatorKnownCaches(t *testing.T) {
calls := 0
v := New(&fakeAIGateway{
listModels: func(ctx context.Context) ([]string, error) {
calls++
return []string{"llama3:latest"}, nil
},
}, time.Minute)
if _, err := v.Known(context.Background()); err != nil {
t.Fatalf("Known() error = %v", err)
}
if _, err := v.Known(context.Background()); err != nil {
t.Fatalf("Known() error = %v", err)
}
if calls != 1 {
t.Fatalf("calls = %d, want 1", calls)
}
}
func TestValidatorKnownPropagatesError(t *testing.T) {
v := New(&fakeAIGateway{
listModels: func(ctx context.Context) ([]string, error) {
return nil, errors.New("boom")
},
}, time.Minute)
if _, err := v.Known(context.Background()); err == nil {
t.Fatal("Known() error = nil, want error")
}
}

View File

@ -25,6 +25,7 @@ type QueryRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"`
Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"`
Model string `protobuf:"bytes,3,opt,name=model,proto3" json:"model,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@ -73,11 +74,19 @@ func (x *QueryRequest) GetSource() string {
return ""
}
func (x *QueryRequest) GetModel() string {
if x != nil {
return x.Model
}
return ""
}
type QueryResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Reply string `protobuf:"bytes,1,opt,name=reply,proto3" json:"reply,omitempty"`
Intent string `protobuf:"bytes,2,opt,name=intent,proto3" json:"intent,omitempty"`
ActionTaken bool `protobuf:"varint,3,opt,name=action_taken,json=actionTaken,proto3" json:"action_taken,omitempty"`
ModelUsed string `protobuf:"bytes,4,opt,name=model_used,json=modelUsed,proto3" json:"model_used,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@ -133,20 +142,115 @@ func (x *QueryResponse) GetActionTaken() bool {
return false
}
func (x *QueryResponse) GetModelUsed() string {
if x != nil {
return x.ModelUsed
}
return ""
}
type ListModelsRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ListModelsRequest) Reset() {
*x = ListModelsRequest{}
mi := &file_ai_v1_ai_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ListModelsRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListModelsRequest) ProtoMessage() {}
func (x *ListModelsRequest) ProtoReflect() protoreflect.Message {
mi := &file_ai_v1_ai_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListModelsRequest.ProtoReflect.Descriptor instead.
func (*ListModelsRequest) Descriptor() ([]byte, []int) {
return file_ai_v1_ai_proto_rawDescGZIP(), []int{2}
}
type ListModelsResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Names []string `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ListModelsResponse) Reset() {
*x = ListModelsResponse{}
mi := &file_ai_v1_ai_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ListModelsResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListModelsResponse) ProtoMessage() {}
func (x *ListModelsResponse) ProtoReflect() protoreflect.Message {
mi := &file_ai_v1_ai_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListModelsResponse.ProtoReflect.Descriptor instead.
func (*ListModelsResponse) Descriptor() ([]byte, []int) {
return file_ai_v1_ai_proto_rawDescGZIP(), []int{3}
}
func (x *ListModelsResponse) GetNames() []string {
if x != nil {
return x.Names
}
return nil
}
var File_ai_v1_ai_proto protoreflect.FileDescriptor
const file_ai_v1_ai_proto_rawDesc = "" +
"\n" +
"\x0eai/v1/ai.proto\x12\x05ai.v1\":\n" +
"\x0eai/v1/ai.proto\x12\x05ai.v1\"P\n" +
"\fQueryRequest\x12\x12\n" +
"\x04text\x18\x01 \x01(\tR\x04text\x12\x16\n" +
"\x06source\x18\x02 \x01(\tR\x06source\"`\n" +
"\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n" +
"\x05model\x18\x03 \x01(\tR\x05model\"\x7f\n" +
"\rQueryResponse\x12\x14\n" +
"\x05reply\x18\x01 \x01(\tR\x05reply\x12\x16\n" +
"\x06intent\x18\x02 \x01(\tR\x06intent\x12!\n" +
"\faction_taken\x18\x03 \x01(\bR\vactionTaken2?\n" +
"\faction_taken\x18\x03 \x01(\bR\vactionTaken\x12\x1d\n" +
"\n" +
"model_used\x18\x04 \x01(\tR\tmodelUsed\"\x13\n" +
"\x11ListModelsRequest\"*\n" +
"\x12ListModelsResponse\x12\x14\n" +
"\x05names\x18\x01 \x03(\tR\x05names2\x82\x01\n" +
"\tAIService\x122\n" +
"\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3"
"\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponse\x12A\n" +
"\n" +
"ListModels\x12\x18.ai.v1.ListModelsRequest\x1a\x19.ai.v1.ListModelsResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3"
var (
file_ai_v1_ai_proto_rawDescOnce sync.Once
@ -160,16 +264,20 @@ func file_ai_v1_ai_proto_rawDescGZIP() []byte {
return file_ai_v1_ai_proto_rawDescData
}
var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_ai_v1_ai_proto_goTypes = []any{
(*QueryRequest)(nil), // 0: ai.v1.QueryRequest
(*QueryResponse)(nil), // 1: ai.v1.QueryResponse
(*QueryRequest)(nil), // 0: ai.v1.QueryRequest
(*QueryResponse)(nil), // 1: ai.v1.QueryResponse
(*ListModelsRequest)(nil), // 2: ai.v1.ListModelsRequest
(*ListModelsResponse)(nil), // 3: ai.v1.ListModelsResponse
}
var file_ai_v1_ai_proto_depIdxs = []int32{
0, // 0: ai.v1.AIService.Query:input_type -> ai.v1.QueryRequest
1, // 1: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
2, // 1: ai.v1.AIService.ListModels:input_type -> ai.v1.ListModelsRequest
1, // 2: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse
3, // 3: ai.v1.AIService.ListModels:output_type -> ai.v1.ListModelsResponse
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
@ -186,7 +294,7 @@ func file_ai_v1_ai_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_v1_ai_proto_rawDesc), len(file_ai_v1_ai_proto_rawDesc)),
NumEnums: 0,
NumMessages: 2,
NumMessages: 4,
NumExtensions: 0,
NumServices: 1,
},

View File

@ -19,7 +19,8 @@ import (
const _ = grpc.SupportPackageIsVersion9
const (
AIService_Query_FullMethodName = "/ai.v1.AIService/Query"
AIService_Query_FullMethodName = "/ai.v1.AIService/Query"
AIService_ListModels_FullMethodName = "/ai.v1.AIService/ListModels"
)
// AIServiceClient is the client API for AIService service.
@ -27,6 +28,7 @@ const (
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type AIServiceClient interface {
Query(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*QueryResponse, error)
ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error)
}
type aIServiceClient struct {
@ -47,11 +49,22 @@ func (c *aIServiceClient) Query(ctx context.Context, in *QueryRequest, opts ...g
return out, nil
}
func (c *aIServiceClient) ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ListModelsResponse)
err := c.cc.Invoke(ctx, AIService_ListModels_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// AIServiceServer is the server API for AIService service.
// All implementations must embed UnimplementedAIServiceServer
// for forward compatibility.
type AIServiceServer interface {
Query(context.Context, *QueryRequest) (*QueryResponse, error)
ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error)
mustEmbedUnimplementedAIServiceServer()
}
@ -65,6 +78,9 @@ type UnimplementedAIServiceServer struct{}
func (UnimplementedAIServiceServer) Query(context.Context, *QueryRequest) (*QueryResponse, error) {
return nil, status.Error(codes.Unimplemented, "method Query not implemented")
}
func (UnimplementedAIServiceServer) ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error) {
return nil, status.Error(codes.Unimplemented, "method ListModels not implemented")
}
func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {}
func (UnimplementedAIServiceServer) testEmbeddedByValue() {}
@ -104,6 +120,24 @@ func _AIService_Query_Handler(srv interface{}, ctx context.Context, dec func(int
return interceptor(ctx, in, info, handler)
}
func _AIService_ListModels_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListModelsRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AIServiceServer).ListModels(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: AIService_ListModels_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AIServiceServer).ListModels(ctx, req.(*ListModelsRequest))
}
return interceptor(ctx, in, info, handler)
}
// AIService_ServiceDesc is the grpc.ServiceDesc for AIService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@ -115,6 +149,10 @@ var AIService_ServiceDesc = grpc.ServiceDesc{
MethodName: "Query",
Handler: _AIService_Query_Handler,
},
{
MethodName: "ListModels",
Handler: _AIService_ListModels_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "ai/v1/ai.proto",

View File

@ -6,15 +6,24 @@ option go_package = "gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1";
service AIService {
rpc Query(QueryRequest) returns (QueryResponse);
rpc ListModels(ListModelsRequest) returns (ListModelsResponse);
}
message QueryRequest {
string text = 1;
string source = 2;
string model = 3;
}
message QueryResponse {
string reply = 1;
string intent = 2;
bool action_taken = 3;
string model_used = 4;
}
message ListModelsRequest {}
message ListModelsResponse {
repeated string names = 1;
}