feat: enhance AI model management in Discord bot
- Updated LLMClient interface to support model-specific generation and model listing. - Integrated model store and validator into the command application for managing AI models. - Implemented commands for setting, getting, and listing active AI models in Discord. - Enhanced AI query handling to utilize the selected model and return model information in responses. - Added caching mechanism for model validation to improve performance. - Introduced gRPC methods for listing available AI models in the ai-gateway. - Updated protobuf definitions to include model-related fields and messages. - Added tests for model store and validator functionalities.
This commit is contained in:
parent
9cc29c2329
commit
ad50d641bd
@ -67,7 +67,7 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ollamaClient := ollama.New(cfg.OllamaURL, cfg.OllamaModel, &http.Client{Timeout: cfg.OllamaTimeout})
|
||||
ollamaClient := ollama.New(cfg.OllamaURL, &http.Client{Timeout: cfg.OllamaTimeout})
|
||||
haClient, err := hagateway.New(ctx, cfg.HAGatewayAddr, cfg.TLSDir, cfg.HAGatewayServerName, log)
|
||||
if err != nil {
|
||||
log.Error("ha-gateway client setup failed", "err", err)
|
||||
@ -80,7 +80,7 @@ func main() {
|
||||
}()
|
||||
|
||||
lightCache := domain.NewLightCache(cfg.LightCacheTTL, haClient.ListLights)
|
||||
queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, log)
|
||||
queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, cfg.OllamaModel, log)
|
||||
|
||||
serverOpts := []grpc.ServerOption{
|
||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||
|
||||
@ -33,7 +33,7 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query
|
||||
ctx = logger.WithLogger(ctx, log)
|
||||
}
|
||||
|
||||
result, err := s.app.Query(ctx, req.GetText())
|
||||
result, err := s.app.Query(ctx, req.GetText(), req.GetModel())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "query failed: %v", err)
|
||||
}
|
||||
@ -41,5 +41,15 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query
|
||||
Reply: result.Reply,
|
||||
Intent: result.Intent,
|
||||
ActionTaken: result.ActionTaken,
|
||||
ModelUsed: result.ModelUsed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListModels returns the installed model names from Ollama.
|
||||
func (s *Server) ListModels(ctx context.Context, _ *aiv1.ListModelsRequest) (*aiv1.ListModelsResponse, error) {
|
||||
names, err := s.app.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "list models: %v", err)
|
||||
}
|
||||
return &aiv1.ListModelsResponse{Names: names}, nil
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
)
|
||||
@ -14,37 +15,49 @@ type generateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stream bool `json:"stream"`
|
||||
Think *bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type generateResponse struct {
|
||||
Response string `json:"response"`
|
||||
}
|
||||
|
||||
type listModelsResponse struct {
|
||||
Models []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"models"`
|
||||
}
|
||||
|
||||
// Client implements the LLM driven port with the Ollama generate API.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
model string
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
// New constructs an Ollama client with OTel-instrumented transport.
|
||||
func New(baseURL, model string, httpClient *http.Client) *Client {
|
||||
func New(baseURL string, httpClient *http.Client) *Client {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{Transport: otelhttp.NewTransport(http.DefaultTransport)}
|
||||
}
|
||||
if httpClient.Transport == nil {
|
||||
httpClient.Transport = otelhttp.NewTransport(http.DefaultTransport)
|
||||
}
|
||||
return &Client{baseURL: baseURL, model: model, http: httpClient}
|
||||
return &Client{baseURL: baseURL, http: httpClient}
|
||||
}
|
||||
|
||||
// Generate sends one non-streaming prompt to Ollama.
|
||||
func (c *Client) Generate(ctx context.Context, prompt string) (string, error) {
|
||||
body, err := json.Marshal(generateRequest{
|
||||
Model: c.model,
|
||||
func (c *Client) Generate(ctx context.Context, model, prompt string) (string, error) {
|
||||
reqBody := generateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Stream: false,
|
||||
})
|
||||
}
|
||||
if isThinkingModel(model) {
|
||||
disabled := false
|
||||
reqBody.Think = &disabled
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal ollama request: %w", err)
|
||||
}
|
||||
@ -71,3 +84,38 @@ func (c *Client) Generate(ctx context.Context, prompt string) (string, error) {
|
||||
}
|
||||
return out.Response, nil
|
||||
}
|
||||
|
||||
// ListModels returns the installed model names from Ollama.
|
||||
func (c *Client) ListModels(ctx context.Context) ([]string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/tags", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build ollama list models request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list ollama models: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("ollama returned status %s", resp.Status)
|
||||
}
|
||||
|
||||
var out listModelsResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, fmt.Errorf("decode ollama list models response: %w", err)
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(out.Models))
|
||||
for _, model := range out.Models {
|
||||
if model.Name != "" {
|
||||
names = append(names, model.Name)
|
||||
}
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
func isThinkingModel(name string) bool {
|
||||
return strings.HasPrefix(name, "qwen3")
|
||||
}
|
||||
|
||||
81
ai-gateway/internal/adapters/secondary/ollama/client_test.go
Normal file
81
ai-gateway/internal/adapters/secondary/ollama/client_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsThinkingModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{name: "qwen3", want: true},
|
||||
{name: "qwen3:4b", want: true},
|
||||
{name: "qwen3:latest", want: true},
|
||||
{name: "llama3", want: false},
|
||||
{name: "mistral", want: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isThinkingModel(tt.name); got != tt.want {
|
||||
t.Fatalf("isThinkingModel(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSetsThinkForQwen3(t *testing.T) {
|
||||
var body map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("Decode() error = %v", err)
|
||||
}
|
||||
_, _ = w.Write([]byte(`{"response":"ok"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := New(srv.URL, srv.Client())
|
||||
if _, err := client.Generate(context.Background(), "qwen3:latest", "prompt"); err != nil {
|
||||
t.Fatalf("Generate() error = %v", err)
|
||||
}
|
||||
if body["think"] != false {
|
||||
t.Fatalf("think = %#v, want false", body["think"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateOmitsThinkForLlama3(t *testing.T) {
|
||||
var body map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("Decode() error = %v", err)
|
||||
}
|
||||
_, _ = w.Write([]byte(`{"response":"ok"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := New(srv.URL, srv.Client())
|
||||
if _, err := client.Generate(context.Background(), "llama3:latest", "prompt"); err != nil {
|
||||
t.Fatalf("Generate() error = %v", err)
|
||||
}
|
||||
if _, ok := body["think"]; ok {
|
||||
t.Fatalf("unexpected think key = %#v", body["think"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModelsReturnsNamesOnly(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"models":[{"name":"llama3:latest"},{"name":"qwen3:latest"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := New(srv.URL, srv.Client())
|
||||
got, err := client.ListModels(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("ListModels() error = %v", err)
|
||||
}
|
||||
if len(got) != 2 || got[0] != "llama3:latest" || got[1] != "qwen3:latest" {
|
||||
t.Fatalf("ListModels() = %#v", got)
|
||||
}
|
||||
}
|
||||
@ -18,34 +18,40 @@ type QueryResult struct {
|
||||
Reply string
|
||||
Intent string
|
||||
ActionTaken bool
|
||||
ModelUsed string
|
||||
}
|
||||
|
||||
// QueryApp orchestrates one AI query request.
|
||||
type QueryApp struct {
|
||||
llm driven.LLMClient
|
||||
ha driven.HAClient
|
||||
cache *domain.LightCache
|
||||
log *slog.Logger
|
||||
llm driven.LLMClient
|
||||
ha driven.HAClient
|
||||
cache *domain.LightCache
|
||||
defaultModel string
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// NewQueryApp constructs the AI query application service.
|
||||
func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, log *slog.Logger) *QueryApp {
|
||||
return &QueryApp{llm: llm, ha: ha, cache: cache, log: log}
|
||||
func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, defaultModel string, log *slog.Logger) *QueryApp {
|
||||
return &QueryApp{llm: llm, ha: ha, cache: cache, defaultModel: defaultModel, log: log}
|
||||
}
|
||||
|
||||
// Query runs the full intent parsing and dispatch flow for one user request.
|
||||
func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) {
|
||||
func (a *QueryApp) Query(ctx context.Context, text, model string) (QueryResult, error) {
|
||||
if model == "" {
|
||||
model = a.defaultModel
|
||||
}
|
||||
lights, err := a.cache.Get(ctx)
|
||||
if err != nil {
|
||||
a.log.Error("light cache refresh failed", "err", err)
|
||||
return QueryResult{
|
||||
Reply: "I couldn't reach Home Assistant right now.",
|
||||
ActionTaken: false,
|
||||
ModelUsed: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
prompt := domain.BuildPrompt(text, promptLightLines(lights))
|
||||
raw, err := a.llm.Generate(ctx, prompt)
|
||||
raw, err := a.llm.Generate(ctx, model, prompt)
|
||||
if err != nil {
|
||||
return QueryResult{}, err
|
||||
}
|
||||
@ -57,6 +63,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
||||
Reply: "I didn't understand that.",
|
||||
Intent: domain.IntentNone,
|
||||
ActionTaken: false,
|
||||
ModelUsed: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -64,7 +71,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
||||
case domain.IntentTurnOnLight:
|
||||
entityID, ok := resolveLightEntity(intent.Entity, lights)
|
||||
if !ok {
|
||||
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil
|
||||
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil
|
||||
}
|
||||
params, err := ParseLightParams(intent.Params)
|
||||
if err != nil {
|
||||
@ -72,6 +79,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
||||
Reply: "I couldn't understand the light settings.",
|
||||
Intent: intent.Name,
|
||||
ActionTaken: false,
|
||||
ModelUsed: model,
|
||||
}, nil
|
||||
}
|
||||
if err := a.ha.TurnOnLight(ctx, entityID, params); err != nil {
|
||||
@ -80,17 +88,19 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
||||
Reply: "I couldn't reach Home Assistant right now.",
|
||||
Intent: intent.Name,
|
||||
ActionTaken: false,
|
||||
ModelUsed: model,
|
||||
}, nil
|
||||
}
|
||||
return QueryResult{
|
||||
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned on `%s`.", displayLightName(entityID, lights))),
|
||||
Intent: intent.Name,
|
||||
ActionTaken: true,
|
||||
ModelUsed: model,
|
||||
}, nil
|
||||
case domain.IntentTurnOffLight:
|
||||
entityID, ok := resolveLightEntity(intent.Entity, lights)
|
||||
if !ok {
|
||||
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil
|
||||
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil
|
||||
}
|
||||
if err := a.ha.TurnOffLight(ctx, entityID); err != nil {
|
||||
a.log.Error("turn off light failed", "entity_id", entityID, "err", err)
|
||||
@ -98,18 +108,21 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
||||
Reply: "I couldn't reach Home Assistant right now.",
|
||||
Intent: intent.Name,
|
||||
ActionTaken: false,
|
||||
ModelUsed: model,
|
||||
}, nil
|
||||
}
|
||||
return QueryResult{
|
||||
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned off `%s`.", displayLightName(entityID, lights))),
|
||||
Intent: intent.Name,
|
||||
ActionTaken: true,
|
||||
ModelUsed: model,
|
||||
}, nil
|
||||
case domain.IntentListLights:
|
||||
return QueryResult{
|
||||
Reply: formatLightListReply(lights),
|
||||
Intent: intent.Name,
|
||||
ActionTaken: false,
|
||||
ModelUsed: model,
|
||||
}, nil
|
||||
case domain.IntentNone:
|
||||
fallthrough
|
||||
@ -118,10 +131,16 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
||||
Reply: fallbackReply(intent.Reply, "I didn't understand that."),
|
||||
Intent: intent.Name,
|
||||
ActionTaken: false,
|
||||
ModelUsed: model,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ListModels returns the currently installed model names.
|
||||
func (a *QueryApp) ListModels(ctx context.Context) ([]string, error) {
|
||||
return a.llm.ListModels(ctx)
|
||||
}
|
||||
|
||||
func promptLightLines(lights []driven.Light) []string {
|
||||
lines := make([]string, 0, len(lights))
|
||||
for _, light := range lights {
|
||||
|
||||
@ -13,11 +13,19 @@ import (
|
||||
)
|
||||
|
||||
type fakeLLM struct {
|
||||
generate func(context.Context, string) (string, error)
|
||||
generate func(context.Context, string, string) (string, error)
|
||||
listModels func(context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
func (f *fakeLLM) Generate(ctx context.Context, prompt string) (string, error) {
|
||||
return f.generate(ctx, prompt)
|
||||
func (f *fakeLLM) Generate(ctx context.Context, model, prompt string) (string, error) {
|
||||
return f.generate(ctx, model, prompt)
|
||||
}
|
||||
|
||||
func (f *fakeLLM) ListModels(ctx context.Context) ([]string, error) {
|
||||
if f.listModels == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return f.listModels(ctx)
|
||||
}
|
||||
|
||||
type fakeHA struct {
|
||||
@ -54,12 +62,15 @@ func TestQueryAppTurnOnLight(t *testing.T) {
|
||||
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
||||
cache := domain.NewLightCache(time.Hour, ha.ListLights)
|
||||
app := NewQueryApp(&fakeLLM{
|
||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
||||
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||
if model != "llama3:latest" {
|
||||
t.Fatalf("Generate() model = %q", model)
|
||||
}
|
||||
return `{"intent":"turn_on_light","entity":"Kitchen","params":{"brightness":"80"},"reply":"Turning on Kitchen."}`, nil
|
||||
},
|
||||
}, ha, cache, slog.Default())
|
||||
}, ha, cache, "llama3:latest", slog.Default())
|
||||
|
||||
got, err := app.Query(context.Background(), "turn on kitchen")
|
||||
got, err := app.Query(context.Background(), "turn on kitchen", "")
|
||||
if err != nil {
|
||||
t.Fatalf("Query() error = %v", err)
|
||||
}
|
||||
@ -77,12 +88,12 @@ func TestQueryAppTurnOnLight(t *testing.T) {
|
||||
func TestQueryAppInvalidJSON(t *testing.T) {
|
||||
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
||||
app := NewQueryApp(&fakeLLM{
|
||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
||||
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||
return `not-json`, nil
|
||||
},
|
||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
|
||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||
|
||||
got, err := app.Query(context.Background(), "turn on kitchen")
|
||||
got, err := app.Query(context.Background(), "turn on kitchen", "")
|
||||
if err != nil {
|
||||
t.Fatalf("Query() error = %v", err)
|
||||
}
|
||||
@ -97,12 +108,12 @@ func TestQueryAppInvalidJSON(t *testing.T) {
|
||||
func TestQueryAppIntentNone(t *testing.T) {
|
||||
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
||||
app := NewQueryApp(&fakeLLM{
|
||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
||||
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
|
||||
},
|
||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
|
||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||
|
||||
got, err := app.Query(context.Background(), "hello")
|
||||
got, err := app.Query(context.Background(), "hello", "")
|
||||
if err != nil {
|
||||
t.Fatalf("Query() error = %v", err)
|
||||
}
|
||||
@ -117,12 +128,12 @@ func TestQueryAppHAFailure(t *testing.T) {
|
||||
turnOnErr: errors.New("boom"),
|
||||
}
|
||||
app := NewQueryApp(&fakeLLM{
|
||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
||||
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||
return `{"intent":"turn_on_light","entity":"light.kitchen","params":{},"reply":"Turning on Kitchen."}`, nil
|
||||
},
|
||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
|
||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||
|
||||
got, err := app.Query(context.Background(), "turn on kitchen")
|
||||
got, err := app.Query(context.Background(), "turn on kitchen", "")
|
||||
if err != nil {
|
||||
t.Fatalf("Query() error = %v", err)
|
||||
}
|
||||
@ -134,12 +145,12 @@ func TestQueryAppHAFailure(t *testing.T) {
|
||||
func TestQueryAppListLights(t *testing.T) {
|
||||
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "on"}}}
|
||||
app := NewQueryApp(&fakeLLM{
|
||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
||||
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||
return `{"intent":"list_lights","entity":"","params":{},"reply":""}`, nil
|
||||
},
|
||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
|
||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||
|
||||
got, err := app.Query(context.Background(), "what lights exist")
|
||||
got, err := app.Query(context.Background(), "what lights exist", "")
|
||||
if err != nil {
|
||||
t.Fatalf("Query() error = %v", err)
|
||||
}
|
||||
@ -164,3 +175,42 @@ func TestLightCacheRefreshAfterTTL(t *testing.T) {
|
||||
t.Fatalf("ListLights calls = %d, want at least 2", ha.listCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryAppExplicitModel(t *testing.T) {
|
||||
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
||||
app := NewQueryApp(&fakeLLM{
|
||||
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||
if model != "qwen3:latest" {
|
||||
t.Fatalf("Generate() model = %q", model)
|
||||
}
|
||||
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
|
||||
},
|
||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||
|
||||
got, err := app.Query(context.Background(), "hello", "qwen3:latest")
|
||||
if err != nil {
|
||||
t.Fatalf("Query() error = %v", err)
|
||||
}
|
||||
if got.ModelUsed != "qwen3:latest" {
|
||||
t.Fatalf("ModelUsed = %q", got.ModelUsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryAppListModels(t *testing.T) {
|
||||
app := NewQueryApp(&fakeLLM{
|
||||
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
listModels: func(ctx context.Context) ([]string, error) {
|
||||
return []string{"llama3:latest", "qwen3:latest"}, nil
|
||||
},
|
||||
}, &fakeHA{}, domain.NewLightCache(time.Hour, func(context.Context) ([]driven.Light, error) { return nil, nil }), "llama3:latest", slog.Default())
|
||||
|
||||
got, err := app.ListModels(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("ListModels() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, []string{"llama3:latest", "qwen3:latest"}) {
|
||||
t.Fatalf("ListModels() = %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,5 +4,6 @@ import "context"
|
||||
|
||||
// LLMClient generates one model response for a prompt.
|
||||
type LLMClient interface {
|
||||
Generate(ctx context.Context, prompt string) (string, error)
|
||||
Generate(ctx context.Context, model, prompt string) (string, error)
|
||||
ListModels(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
@ -17,6 +17,8 @@ import (
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/app"
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/config"
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/logger"
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/telemetry"
|
||||
)
|
||||
|
||||
@ -84,7 +86,9 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
commandApp := app.NewCommandApp(haClient, aiClient)
|
||||
modelStore := modelstore.New()
|
||||
validator := modelvalidator.New(aiClient, 30*time.Second)
|
||||
commandApp := app.NewCommandApp(haClient, aiClient, modelStore, validator)
|
||||
|
||||
// Discord-specific wiring stays at the edge so the app layer remains transport-agnostic.
|
||||
session, err := discordgo.New("Bot " + cfg.DiscordToken)
|
||||
|
||||
@ -24,8 +24,12 @@ type commandHandler interface {
|
||||
HandleLightToggle(ctx context.Context, entityID string) (string, error)
|
||||
HandleSwitchList(ctx context.Context) (string, error)
|
||||
HandleAIQuery(ctx context.Context, text string) (string, error)
|
||||
HandleAIModelSet(ctx context.Context, name string) (string, error)
|
||||
HandleAIModelGet(ctx context.Context) (string, error)
|
||||
HandleAIModelList(ctx context.Context) (string, error)
|
||||
AutocompleteLights(ctx context.Context) ([]apppkg.Choice, error)
|
||||
AutocompleteSwitches(ctx context.Context) ([]apppkg.Choice, error)
|
||||
AutocompleteAIModels(ctx context.Context) ([]apppkg.Choice, error)
|
||||
}
|
||||
|
||||
// Handler adapts Discord interactions to the command application layer.
|
||||
@ -69,6 +73,9 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
|
||||
command := data.Name
|
||||
if len(data.Options) > 0 {
|
||||
command += "." + data.Options[0].Name
|
||||
if len(data.Options[0].Options) > 0 && data.Options[0].Type == discordgo.ApplicationCommandOptionSubCommandGroup {
|
||||
command += "." + data.Options[0].Options[0].Name
|
||||
}
|
||||
}
|
||||
user := ""
|
||||
if i.Member != nil && i.Member.User != nil {
|
||||
@ -90,7 +97,18 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
|
||||
}
|
||||
|
||||
sub := data.Options[0]
|
||||
switch data.Name + "." + sub.Name {
|
||||
commandPath := data.Name + "." + sub.Name
|
||||
target := sub
|
||||
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
|
||||
if len(sub.Options) == 0 {
|
||||
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("missing grouped subcommand"))
|
||||
return
|
||||
}
|
||||
target = sub.Options[0]
|
||||
commandPath += "." + target.Name
|
||||
}
|
||||
|
||||
switch commandPath {
|
||||
case "light.list":
|
||||
msg, err := h.app.HandleLightList(ctx)
|
||||
if err != nil {
|
||||
@ -145,10 +163,38 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
|
||||
)
|
||||
return
|
||||
}
|
||||
msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(sub, "text"))
|
||||
msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(target, "text"))
|
||||
h.followup(ctx, s, i.Interaction, msg, true, start, err)
|
||||
case "ai.model.set":
|
||||
if err := h.deferResponse(s, i.Interaction, true); err != nil {
|
||||
log.Error("discord response failed",
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
"error", err.Error(),
|
||||
)
|
||||
return
|
||||
}
|
||||
msg, err := h.app.HandleAIModelSet(ctx, requiredStringOption(target, "name"))
|
||||
h.followup(ctx, s, i.Interaction, msg, true, start, err)
|
||||
case "ai.model.get":
|
||||
msg, err := h.app.HandleAIModelGet(ctx)
|
||||
if err != nil {
|
||||
h.respondError(ctx, s, i.Interaction, true, start, err)
|
||||
return
|
||||
}
|
||||
h.respondMessage(ctx, s, i.Interaction, msg, true)
|
||||
log.Info("command handled", "duration_ms", time.Since(start).Milliseconds())
|
||||
case "ai.model.list":
|
||||
if err := h.deferResponse(s, i.Interaction, true); err != nil {
|
||||
log.Error("discord response failed",
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
"error", err.Error(),
|
||||
)
|
||||
return
|
||||
}
|
||||
msg, err := h.app.HandleAIModelList(ctx)
|
||||
h.followup(ctx, s, i.Interaction, msg, true, start, err)
|
||||
default:
|
||||
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s.%s", data.Name, sub.Name))
|
||||
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s", commandPath))
|
||||
}
|
||||
}
|
||||
|
||||
@ -167,6 +213,10 @@ func (h *Handler) handleAutocomplete(ctx context.Context, s *discordgo.Session,
|
||||
choices, err = h.app.AutocompleteLights(ctx)
|
||||
case "switch":
|
||||
choices, err = h.app.AutocompleteSwitches(ctx)
|
||||
case "ai":
|
||||
if focusedOptionName(data) == "name" {
|
||||
choices, err = h.app.AutocompleteAIModels(ctx)
|
||||
}
|
||||
default:
|
||||
choices = nil
|
||||
}
|
||||
@ -299,6 +349,15 @@ func optionalUint32Option(sub *discordgo.ApplicationCommandInteractionDataOption
|
||||
|
||||
func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string {
|
||||
for _, sub := range data.Options {
|
||||
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
|
||||
for _, nested := range sub.Options {
|
||||
for _, opt := range nested.Options {
|
||||
if opt.Focused {
|
||||
return opt.StringValue()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, opt := range sub.Options {
|
||||
if opt.Focused {
|
||||
return opt.StringValue()
|
||||
@ -307,3 +366,23 @@ func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func focusedOptionName(data discordgo.ApplicationCommandInteractionData) string {
|
||||
for _, sub := range data.Options {
|
||||
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
|
||||
for _, nested := range sub.Options {
|
||||
for _, opt := range nested.Options {
|
||||
if opt.Focused {
|
||||
return opt.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, opt := range sub.Options {
|
||||
if opt.Focused {
|
||||
return opt.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@ -98,6 +98,37 @@ func RegisterCommands(s *discordgo.Session, guildID string) error {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionSubCommandGroup,
|
||||
Name: "model",
|
||||
Description: "Manage the active AI model",
|
||||
Options: []*discordgo.ApplicationCommandOption{
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionSubCommand,
|
||||
Name: "set",
|
||||
Description: "Set the active model",
|
||||
Options: []*discordgo.ApplicationCommandOption{
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionString,
|
||||
Name: "name",
|
||||
Description: "Model name",
|
||||
Required: true,
|
||||
Autocomplete: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionSubCommand,
|
||||
Name: "get",
|
||||
Description: "Show the active model",
|
||||
},
|
||||
{
|
||||
Type: discordgo.ApplicationCommandOptionSubCommand,
|
||||
Name: "list",
|
||||
Description: "List available models",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -61,22 +61,39 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
// Query forwards one free-form request to ai-gateway.
|
||||
func (c *Client) Query(ctx context.Context, text string) (string, error) {
|
||||
func (c *Client) Query(ctx context.Context, text, model string) (string, string, error) {
|
||||
start := time.Now()
|
||||
log := logger.FromContext(ctx).With("grpc.method", "AIService/Query")
|
||||
resp, err := c.client.Query(ctx, &aiv1.QueryRequest{
|
||||
Text: text,
|
||||
Source: "discord-bot",
|
||||
Model: model,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("grpc call failed",
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
"error", err.Error(),
|
||||
)
|
||||
return "", fmt.Errorf("query ai-gateway: %w", err)
|
||||
return "", "", fmt.Errorf("query ai-gateway: %w", err)
|
||||
}
|
||||
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
|
||||
return resp.GetReply(), nil
|
||||
return resp.GetReply(), resp.GetModelUsed(), nil
|
||||
}
|
||||
|
||||
// ListModels returns the installed model names from ai-gateway.
|
||||
func (c *Client) ListModels(ctx context.Context) ([]string, error) {
|
||||
start := time.Now()
|
||||
log := logger.FromContext(ctx).With("grpc.method", "AIService/ListModels")
|
||||
resp, err := c.client.ListModels(ctx, &aiv1.ListModelsRequest{})
|
||||
if err != nil {
|
||||
log.Error("grpc call failed",
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
"error", err.Error(),
|
||||
)
|
||||
return nil, fmt.Errorf("list ai-gateway models: %w", err)
|
||||
}
|
||||
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
|
||||
return append([]string(nil), resp.GetNames()...), nil
|
||||
}
|
||||
|
||||
func loadTransportCredentials(tlsDir string) (credentials.TransportCredentials, error) {
|
||||
|
||||
@ -7,6 +7,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
|
||||
)
|
||||
|
||||
// Choice is one Discord autocomplete entry.
|
||||
@ -17,13 +19,15 @@ type Choice struct {
|
||||
|
||||
// CommandApp orchestrates Discord command use cases against ha-gateway.
|
||||
type CommandApp struct {
|
||||
ha driven.HAGateway
|
||||
ai driven.AIGateway
|
||||
ha driven.HAGateway
|
||||
ai driven.AIGateway
|
||||
models *modelstore.Store
|
||||
validator *modelvalidator.Validator
|
||||
}
|
||||
|
||||
// NewCommandApp constructs the Discord command application service.
|
||||
func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway) *CommandApp {
|
||||
return &CommandApp{ha: ha, ai: ai}
|
||||
func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway, models *modelstore.Store, validator *modelvalidator.Validator) *CommandApp {
|
||||
return &CommandApp{ha: ha, ai: ai, models: models, validator: validator}
|
||||
}
|
||||
|
||||
// HandleLightList formats discovered lights into a monospace-friendly response.
|
||||
@ -124,11 +128,72 @@ func (a *CommandApp) HandleSwitchList(ctx context.Context) (string, error) {
|
||||
|
||||
// HandleAIQuery forwards a free-form request to ai-gateway.
|
||||
func (a *CommandApp) HandleAIQuery(ctx context.Context, text string) (string, error) {
|
||||
reply, err := a.ai.Query(ctx, text)
|
||||
reply, modelUsed, err := a.ai.Query(ctx, text, a.models.Get())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("handle ai query: %w", err)
|
||||
}
|
||||
return reply, nil
|
||||
return fmt.Sprintf("%s\n\n_(via %s)_", reply, modelUsed), nil
|
||||
}
|
||||
|
||||
// HandleAIModelSet validates and stores the selected model globally.
|
||||
func (a *CommandApp) HandleAIModelSet(ctx context.Context, name string) (string, error) {
|
||||
canonical, err := a.validator.Normalize(ctx, name)
|
||||
if err != nil {
|
||||
if err.Error() == "ambiguous model name" {
|
||||
return "", fmt.Errorf("unknown model: %s. Be more specific.", name)
|
||||
}
|
||||
if err.Error() == "unknown model" {
|
||||
return "", fmt.Errorf("unknown model: %s. Try /ai model list.", name)
|
||||
}
|
||||
return "", fmt.Errorf("validate model: %w", err)
|
||||
}
|
||||
a.models.Set(canonical)
|
||||
return fmt.Sprintf("Active model set to `%s`.", canonical), nil
|
||||
}
|
||||
|
||||
// HandleAIModelGet reports the current selected model or default state.
|
||||
func (a *CommandApp) HandleAIModelGet(ctx context.Context) (string, error) {
|
||||
cur := a.models.Get()
|
||||
if cur == "" {
|
||||
return "No model override set. Using ai-gateway default.", nil
|
||||
}
|
||||
return fmt.Sprintf("Active model: `%s`", cur), nil
|
||||
}
|
||||
|
||||
// HandleAIModelList shows the installed models and marks the active selection.
|
||||
func (a *CommandApp) HandleAIModelList(ctx context.Context) (string, error) {
|
||||
models, err := a.validator.Known(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("list models: %w", err)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
return "No models installed on the Ollama host.", nil
|
||||
}
|
||||
|
||||
active := a.models.Get()
|
||||
lines := make([]string, 0, len(models)+1)
|
||||
lines = append(lines, "Available models:")
|
||||
for _, model := range models {
|
||||
marker := ""
|
||||
if model == active {
|
||||
marker = " <- active"
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- `%s`%s", model, marker))
|
||||
}
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
// AutocompleteAIModels returns model names for the /ai model set command.
|
||||
func (a *CommandApp) AutocompleteAIModels(ctx context.Context) ([]Choice, error) {
|
||||
models, err := a.validator.Known(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("autocomplete ai models: %w", err)
|
||||
}
|
||||
choices := make([]Choice, 0, len(models))
|
||||
for _, model := range models {
|
||||
choices = append(choices, Choice{Label: model, Value: model})
|
||||
}
|
||||
return choices, nil
|
||||
}
|
||||
|
||||
// AutocompleteLights maps discovered lights into Discord autocomplete choices.
|
||||
|
||||
@ -5,8 +5,11 @@ import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
|
||||
)
|
||||
|
||||
type mockHAGateway struct {
|
||||
@ -18,7 +21,8 @@ type mockHAGateway struct {
|
||||
}
|
||||
|
||||
type mockAIGateway struct {
|
||||
queryFunc func(ctx context.Context, text string) (string, error)
|
||||
queryFunc func(ctx context.Context, text, model string) (string, string, error)
|
||||
listModelsFunc func(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
func (m *mockHAGateway) ListLights(ctx context.Context) ([]driven.Light, error) {
|
||||
@ -56,11 +60,22 @@ func (m *mockHAGateway) ToggleLight(ctx context.Context, entityID string) error
|
||||
return m.toggleLightFunc(ctx, entityID)
|
||||
}
|
||||
|
||||
func (m *mockAIGateway) Query(ctx context.Context, text string) (string, error) {
|
||||
func (m *mockAIGateway) Query(ctx context.Context, text, model string) (string, string, error) {
|
||||
if m.queryFunc == nil {
|
||||
return "", nil
|
||||
return "", "", nil
|
||||
}
|
||||
return m.queryFunc(ctx, text)
|
||||
return m.queryFunc(ctx, text, model)
|
||||
}
|
||||
|
||||
func (m *mockAIGateway) ListModels(ctx context.Context) ([]string, error) {
|
||||
if m.listModelsFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.listModelsFunc(ctx)
|
||||
}
|
||||
|
||||
func newTestCommandApp(ha *mockHAGateway, ai *mockAIGateway) *CommandApp {
|
||||
return NewCommandApp(ha, ai, modelstore.New(), modelvalidator.New(ai, time.Minute))
|
||||
}
|
||||
|
||||
func TestCommandAppHandleLightList(t *testing.T) {
|
||||
@ -102,7 +117,7 @@ func TestCommandAppHandleLightList(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := NewCommandApp(&mockHAGateway{
|
||||
app := newTestCommandApp(&mockHAGateway{
|
||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||
return tt.lights, nil
|
||||
},
|
||||
@ -183,7 +198,7 @@ func TestCommandAppHandleLightOn(t *testing.T) {
|
||||
var gotBrightness *uint32
|
||||
var gotColorTemp *uint32
|
||||
|
||||
app := NewCommandApp(&mockHAGateway{
|
||||
app := newTestCommandApp(&mockHAGateway{
|
||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||
if tt.listErr != nil {
|
||||
return nil, tt.listErr
|
||||
@ -258,7 +273,7 @@ func TestCommandAppHandleLightOff(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var gotTransition *uint32
|
||||
app := NewCommandApp(&mockHAGateway{
|
||||
app := newTestCommandApp(&mockHAGateway{
|
||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||
return tt.lights, nil
|
||||
},
|
||||
@ -317,7 +332,7 @@ func TestCommandAppHandleLightToggle(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := NewCommandApp(&mockHAGateway{
|
||||
app := newTestCommandApp(&mockHAGateway{
|
||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||
if tt.listErr != nil {
|
||||
return nil, tt.listErr
|
||||
@ -368,7 +383,7 @@ func TestCommandAppHandleSwitchList(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := NewCommandApp(&mockHAGateway{
|
||||
app := newTestCommandApp(&mockHAGateway{
|
||||
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
|
||||
return tt.switches, nil
|
||||
},
|
||||
@ -413,7 +428,7 @@ func TestCommandAppAutocompleteLights(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := NewCommandApp(&mockHAGateway{
|
||||
app := newTestCommandApp(&mockHAGateway{
|
||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||
if tt.listErr != nil {
|
||||
return nil, tt.listErr
|
||||
@ -467,7 +482,7 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := NewCommandApp(&mockHAGateway{
|
||||
app := newTestCommandApp(&mockHAGateway{
|
||||
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
|
||||
if tt.listErr != nil {
|
||||
return nil, tt.listErr
|
||||
@ -494,20 +509,41 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCommandAppHandleAIQuery(t *testing.T) {
|
||||
store := modelstore.New()
|
||||
store.Set("llama3:latest")
|
||||
app := NewCommandApp(&mockHAGateway{}, &mockAIGateway{
|
||||
queryFunc: func(ctx context.Context, text string) (string, error) {
|
||||
queryFunc: func(ctx context.Context, text, model string) (string, string, error) {
|
||||
if text != "turn on kitchen" {
|
||||
t.Fatalf("Query() text = %q", text)
|
||||
}
|
||||
return "Turning on Kitchen.", nil
|
||||
if model != "llama3:latest" {
|
||||
t.Fatalf("Query() model = %q", model)
|
||||
}
|
||||
return "Turning on Kitchen.", "llama3:latest", nil
|
||||
},
|
||||
})
|
||||
}, store, modelvalidator.New(&mockAIGateway{}, time.Minute))
|
||||
|
||||
got, err := app.HandleAIQuery(context.Background(), "turn on kitchen")
|
||||
if err != nil {
|
||||
t.Fatalf("HandleAIQuery() error = %v", err)
|
||||
}
|
||||
if got != "Turning on Kitchen." {
|
||||
if got != "Turning on Kitchen.\n\n_(via llama3:latest)_" {
|
||||
t.Fatalf("HandleAIQuery() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandAppHandleAIModelSet(t *testing.T) {
|
||||
app := newTestCommandApp(&mockHAGateway{}, &mockAIGateway{
|
||||
listModelsFunc: func(ctx context.Context) ([]string, error) {
|
||||
return []string{"llama3:latest"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
got, err := app.HandleAIModelSet(context.Background(), "llama3")
|
||||
if err != nil {
|
||||
t.Fatalf("HandleAIModelSet() error = %v", err)
|
||||
}
|
||||
if got != "Active model set to `llama3:latest`." {
|
||||
t.Fatalf("HandleAIModelSet() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,5 +4,6 @@ import "context"
|
||||
|
||||
// AIGateway exposes the free-form AI query API used by the Discord bot.
|
||||
type AIGateway interface {
|
||||
Query(ctx context.Context, text string) (string, error)
|
||||
Query(ctx context.Context, text, model string) (reply, modelUsed string, err error)
|
||||
ListModels(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
28
discord-bot/internal/modelstore/store.go
Normal file
28
discord-bot/internal/modelstore/store.go
Normal file
@ -0,0 +1,28 @@
|
||||
package modelstore
|
||||
|
||||
import "sync"
|
||||
|
||||
// Store keeps the globally selected AI model in memory.
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
selected string
|
||||
}
|
||||
|
||||
// New constructs an empty in-memory model store.
|
||||
func New() *Store {
|
||||
return &Store{}
|
||||
}
|
||||
|
||||
// Get returns the currently selected model, or empty for default behavior.
|
||||
func (s *Store) Get() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.selected
|
||||
}
|
||||
|
||||
// Set updates the currently selected model.
|
||||
func (s *Store) Set(model string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.selected = model
|
||||
}
|
||||
15
discord-bot/internal/modelstore/store_test.go
Normal file
15
discord-bot/internal/modelstore/store_test.go
Normal file
@ -0,0 +1,15 @@
|
||||
package modelstore
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestStoreGetSet(t *testing.T) {
|
||||
store := New()
|
||||
if got := store.Get(); got != "" {
|
||||
t.Fatalf("Get() = %q, want empty", got)
|
||||
}
|
||||
|
||||
store.Set("llama3:latest")
|
||||
if got := store.Get(); got != "llama3:latest" {
|
||||
t.Fatalf("Get() = %q, want llama3:latest", got)
|
||||
}
|
||||
}
|
||||
86
discord-bot/internal/modelvalidator/validator.go
Normal file
86
discord-bot/internal/modelvalidator/validator.go
Normal file
@ -0,0 +1,86 @@
|
||||
package modelvalidator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
|
||||
)
|
||||
|
||||
// Validator caches the model list briefly and normalizes friendly names.
|
||||
type Validator struct {
|
||||
client driven.AIGateway
|
||||
ttl time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
cache []string
|
||||
cachedAt time.Time
|
||||
}
|
||||
|
||||
// New constructs a model validator with a TTL cache.
|
||||
func New(client driven.AIGateway, ttl time.Duration) *Validator {
|
||||
return &Validator{client: client, ttl: ttl}
|
||||
}
|
||||
|
||||
// Known returns the cached model list, refreshing when stale.
|
||||
func (v *Validator) Known(ctx context.Context) ([]string, error) {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
if len(v.cache) > 0 && time.Since(v.cachedAt) < v.ttl {
|
||||
return append([]string(nil), v.cache...), nil
|
||||
}
|
||||
models, err := v.client.ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.cache = append([]string(nil), models...)
|
||||
v.cachedAt = time.Now()
|
||||
return append([]string(nil), v.cache...), nil
|
||||
}
|
||||
|
||||
// Normalize resolves a user-provided name to a canonical installed model name.
|
||||
func (v *Validator) Normalize(ctx context.Context, name string) (string, error) {
|
||||
models, err := v.Known(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, model := range models {
|
||||
if model == name {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
latest := name + ":latest"
|
||||
for _, model := range models {
|
||||
if model == latest {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
lower := strings.ToLower(name)
|
||||
for _, model := range models {
|
||||
if strings.ToLower(model) == lower {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
lowerLatest := strings.ToLower(latest)
|
||||
for _, model := range models {
|
||||
if strings.ToLower(model) == lowerLatest {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
matches := make([]string, 0, 2)
|
||||
prefix := lower + ":"
|
||||
for _, model := range models {
|
||||
if strings.HasPrefix(strings.ToLower(model), prefix) {
|
||||
matches = append(matches, model)
|
||||
}
|
||||
}
|
||||
if len(matches) > 1 {
|
||||
return "", fmt.Errorf("ambiguous model name")
|
||||
}
|
||||
return "", fmt.Errorf("unknown model")
|
||||
}
|
||||
70
discord-bot/internal/modelvalidator/validator_test.go
Normal file
70
discord-bot/internal/modelvalidator/validator_test.go
Normal file
@ -0,0 +1,70 @@
|
||||
package modelvalidator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fakeAIGateway struct {
|
||||
listModels func(context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
func (f *fakeAIGateway) Query(ctx context.Context, text, model string) (string, string, error) {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
func (f *fakeAIGateway) ListModels(ctx context.Context) ([]string, error) {
|
||||
return f.listModels(ctx)
|
||||
}
|
||||
|
||||
func TestValidatorNormalize(t *testing.T) {
|
||||
v := New(&fakeAIGateway{
|
||||
listModels: func(ctx context.Context) ([]string, error) {
|
||||
return []string{"llama3:latest", "qwen3:latest", "qwen3:4b"}, nil
|
||||
},
|
||||
}, time.Minute)
|
||||
|
||||
got, err := v.Normalize(context.Background(), "llama3")
|
||||
if err != nil || got != "llama3:latest" {
|
||||
t.Fatalf("Normalize(llama3) = %q, %v", got, err)
|
||||
}
|
||||
|
||||
got, err = v.Normalize(context.Background(), "qwen3")
|
||||
if err != nil || got != "qwen3:latest" {
|
||||
t.Fatalf("Normalize(qwen3) = %q, %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorKnownCaches(t *testing.T) {
|
||||
calls := 0
|
||||
v := New(&fakeAIGateway{
|
||||
listModels: func(ctx context.Context) ([]string, error) {
|
||||
calls++
|
||||
return []string{"llama3:latest"}, nil
|
||||
},
|
||||
}, time.Minute)
|
||||
|
||||
if _, err := v.Known(context.Background()); err != nil {
|
||||
t.Fatalf("Known() error = %v", err)
|
||||
}
|
||||
if _, err := v.Known(context.Background()); err != nil {
|
||||
t.Fatalf("Known() error = %v", err)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("calls = %d, want 1", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorKnownPropagatesError(t *testing.T) {
|
||||
v := New(&fakeAIGateway{
|
||||
listModels: func(ctx context.Context) ([]string, error) {
|
||||
return nil, errors.New("boom")
|
||||
},
|
||||
}, time.Minute)
|
||||
|
||||
if _, err := v.Known(context.Background()); err == nil {
|
||||
t.Fatal("Known() error = nil, want error")
|
||||
}
|
||||
}
|
||||
@ -25,6 +25,7 @@ type QueryRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"`
|
||||
Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"`
|
||||
Model string `protobuf:"bytes,3,opt,name=model,proto3" json:"model,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@ -73,11 +74,19 @@ func (x *QueryRequest) GetSource() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *QueryRequest) GetModel() string {
|
||||
if x != nil {
|
||||
return x.Model
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type QueryResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Reply string `protobuf:"bytes,1,opt,name=reply,proto3" json:"reply,omitempty"`
|
||||
Intent string `protobuf:"bytes,2,opt,name=intent,proto3" json:"intent,omitempty"`
|
||||
ActionTaken bool `protobuf:"varint,3,opt,name=action_taken,json=actionTaken,proto3" json:"action_taken,omitempty"`
|
||||
ModelUsed string `protobuf:"bytes,4,opt,name=model_used,json=modelUsed,proto3" json:"model_used,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@ -133,20 +142,115 @@ func (x *QueryResponse) GetActionTaken() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *QueryResponse) GetModelUsed() string {
|
||||
if x != nil {
|
||||
return x.ModelUsed
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ListModelsRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ListModelsRequest) Reset() {
|
||||
*x = ListModelsRequest{}
|
||||
mi := &file_ai_v1_ai_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ListModelsRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ListModelsRequest) ProtoMessage() {}
|
||||
|
||||
func (x *ListModelsRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_ai_v1_ai_proto_msgTypes[2]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ListModelsRequest.ProtoReflect.Descriptor instead.
|
||||
func (*ListModelsRequest) Descriptor() ([]byte, []int) {
|
||||
return file_ai_v1_ai_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
type ListModelsResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Names []string `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ListModelsResponse) Reset() {
|
||||
*x = ListModelsResponse{}
|
||||
mi := &file_ai_v1_ai_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ListModelsResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ListModelsResponse) ProtoMessage() {}
|
||||
|
||||
func (x *ListModelsResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_ai_v1_ai_proto_msgTypes[3]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ListModelsResponse.ProtoReflect.Descriptor instead.
|
||||
func (*ListModelsResponse) Descriptor() ([]byte, []int) {
|
||||
return file_ai_v1_ai_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (x *ListModelsResponse) GetNames() []string {
|
||||
if x != nil {
|
||||
return x.Names
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_ai_v1_ai_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_ai_v1_ai_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x0eai/v1/ai.proto\x12\x05ai.v1\":\n" +
|
||||
"\x0eai/v1/ai.proto\x12\x05ai.v1\"P\n" +
|
||||
"\fQueryRequest\x12\x12\n" +
|
||||
"\x04text\x18\x01 \x01(\tR\x04text\x12\x16\n" +
|
||||
"\x06source\x18\x02 \x01(\tR\x06source\"`\n" +
|
||||
"\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n" +
|
||||
"\x05model\x18\x03 \x01(\tR\x05model\"\x7f\n" +
|
||||
"\rQueryResponse\x12\x14\n" +
|
||||
"\x05reply\x18\x01 \x01(\tR\x05reply\x12\x16\n" +
|
||||
"\x06intent\x18\x02 \x01(\tR\x06intent\x12!\n" +
|
||||
"\faction_taken\x18\x03 \x01(\bR\vactionTaken2?\n" +
|
||||
"\faction_taken\x18\x03 \x01(\bR\vactionTaken\x12\x1d\n" +
|
||||
"\n" +
|
||||
"model_used\x18\x04 \x01(\tR\tmodelUsed\"\x13\n" +
|
||||
"\x11ListModelsRequest\"*\n" +
|
||||
"\x12ListModelsResponse\x12\x14\n" +
|
||||
"\x05names\x18\x01 \x03(\tR\x05names2\x82\x01\n" +
|
||||
"\tAIService\x122\n" +
|
||||
"\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3"
|
||||
"\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponse\x12A\n" +
|
||||
"\n" +
|
||||
"ListModels\x12\x18.ai.v1.ListModelsRequest\x1a\x19.ai.v1.ListModelsResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3"
|
||||
|
||||
var (
|
||||
file_ai_v1_ai_proto_rawDescOnce sync.Once
|
||||
@ -160,16 +264,20 @@ func file_ai_v1_ai_proto_rawDescGZIP() []byte {
|
||||
return file_ai_v1_ai_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
|
||||
var file_ai_v1_ai_proto_goTypes = []any{
|
||||
(*QueryRequest)(nil), // 0: ai.v1.QueryRequest
|
||||
(*QueryResponse)(nil), // 1: ai.v1.QueryResponse
|
||||
(*QueryRequest)(nil), // 0: ai.v1.QueryRequest
|
||||
(*QueryResponse)(nil), // 1: ai.v1.QueryResponse
|
||||
(*ListModelsRequest)(nil), // 2: ai.v1.ListModelsRequest
|
||||
(*ListModelsResponse)(nil), // 3: ai.v1.ListModelsResponse
|
||||
}
|
||||
var file_ai_v1_ai_proto_depIdxs = []int32{
|
||||
0, // 0: ai.v1.AIService.Query:input_type -> ai.v1.QueryRequest
|
||||
1, // 1: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
2, // 1: ai.v1.AIService.ListModels:input_type -> ai.v1.ListModelsRequest
|
||||
1, // 2: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse
|
||||
3, // 3: ai.v1.AIService.ListModels:output_type -> ai.v1.ListModelsResponse
|
||||
2, // [2:4] is the sub-list for method output_type
|
||||
0, // [0:2] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
@ -186,7 +294,7 @@ func file_ai_v1_ai_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_v1_ai_proto_rawDesc), len(file_ai_v1_ai_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumMessages: 4,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@ -19,7 +19,8 @@ import (
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
AIService_Query_FullMethodName = "/ai.v1.AIService/Query"
|
||||
AIService_Query_FullMethodName = "/ai.v1.AIService/Query"
|
||||
AIService_ListModels_FullMethodName = "/ai.v1.AIService/ListModels"
|
||||
)
|
||||
|
||||
// AIServiceClient is the client API for AIService service.
|
||||
@ -27,6 +28,7 @@ const (
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type AIServiceClient interface {
|
||||
Query(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*QueryResponse, error)
|
||||
ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error)
|
||||
}
|
||||
|
||||
type aIServiceClient struct {
|
||||
@ -47,11 +49,22 @@ func (c *aIServiceClient) Query(ctx context.Context, in *QueryRequest, opts ...g
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *aIServiceClient) ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(ListModelsResponse)
|
||||
err := c.cc.Invoke(ctx, AIService_ListModels_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// AIServiceServer is the server API for AIService service.
|
||||
// All implementations must embed UnimplementedAIServiceServer
|
||||
// for forward compatibility.
|
||||
type AIServiceServer interface {
|
||||
Query(context.Context, *QueryRequest) (*QueryResponse, error)
|
||||
ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error)
|
||||
mustEmbedUnimplementedAIServiceServer()
|
||||
}
|
||||
|
||||
@ -65,6 +78,9 @@ type UnimplementedAIServiceServer struct{}
|
||||
func (UnimplementedAIServiceServer) Query(context.Context, *QueryRequest) (*QueryResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method Query not implemented")
|
||||
}
|
||||
func (UnimplementedAIServiceServer) ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method ListModels not implemented")
|
||||
}
|
||||
func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {}
|
||||
func (UnimplementedAIServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
@ -104,6 +120,24 @@ func _AIService_Query_Handler(srv interface{}, ctx context.Context, dec func(int
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _AIService_ListModels_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ListModelsRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(AIServiceServer).ListModels(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: AIService_ListModels_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(AIServiceServer).ListModels(ctx, req.(*ListModelsRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// AIService_ServiceDesc is the grpc.ServiceDesc for AIService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@ -115,6 +149,10 @@ var AIService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "Query",
|
||||
Handler: _AIService_Query_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "ListModels",
|
||||
Handler: _AIService_ListModels_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "ai/v1/ai.proto",
|
||||
|
||||
@ -6,15 +6,24 @@ option go_package = "gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1";
|
||||
|
||||
service AIService {
|
||||
rpc Query(QueryRequest) returns (QueryResponse);
|
||||
rpc ListModels(ListModelsRequest) returns (ListModelsResponse);
|
||||
}
|
||||
|
||||
message QueryRequest {
|
||||
string text = 1;
|
||||
string source = 2;
|
||||
string model = 3;
|
||||
}
|
||||
|
||||
message QueryResponse {
|
||||
string reply = 1;
|
||||
string intent = 2;
|
||||
bool action_taken = 3;
|
||||
string model_used = 4;
|
||||
}
|
||||
|
||||
message ListModelsRequest {}
|
||||
|
||||
message ListModelsResponse {
|
||||
repeated string names = 1;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user