feat: enhance AI model management in Discord bot
- Updated the LLMClient interface to support model-specific generation and model listing.
- Integrated the model store and validator into the command application for managing AI models.
- Implemented commands for setting, getting, and listing active AI models in Discord.
- Enhanced AI query handling to utilize the selected model and return model information in responses.
- Added a caching mechanism for model validation to improve performance.
- Introduced gRPC methods for listing available AI models in the ai-gateway.
- Updated protobuf definitions to include model-related fields and messages.
- Added tests for model store and validator functionalities.
This commit is contained in:
parent
9cc29c2329
commit
ad50d641bd
@ -67,7 +67,7 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
ollamaClient := ollama.New(cfg.OllamaURL, cfg.OllamaModel, &http.Client{Timeout: cfg.OllamaTimeout})
|
ollamaClient := ollama.New(cfg.OllamaURL, &http.Client{Timeout: cfg.OllamaTimeout})
|
||||||
haClient, err := hagateway.New(ctx, cfg.HAGatewayAddr, cfg.TLSDir, cfg.HAGatewayServerName, log)
|
haClient, err := hagateway.New(ctx, cfg.HAGatewayAddr, cfg.TLSDir, cfg.HAGatewayServerName, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("ha-gateway client setup failed", "err", err)
|
log.Error("ha-gateway client setup failed", "err", err)
|
||||||
@ -80,7 +80,7 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
lightCache := domain.NewLightCache(cfg.LightCacheTTL, haClient.ListLights)
|
lightCache := domain.NewLightCache(cfg.LightCacheTTL, haClient.ListLights)
|
||||||
queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, log)
|
queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, cfg.OllamaModel, log)
|
||||||
|
|
||||||
serverOpts := []grpc.ServerOption{
|
serverOpts := []grpc.ServerOption{
|
||||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||||
|
|||||||
@ -33,7 +33,7 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query
|
|||||||
ctx = logger.WithLogger(ctx, log)
|
ctx = logger.WithLogger(ctx, log)
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := s.app.Query(ctx, req.GetText())
|
result, err := s.app.Query(ctx, req.GetText(), req.GetModel())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Unavailable, "query failed: %v", err)
|
return nil, status.Errorf(codes.Unavailable, "query failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -41,5 +41,15 @@ func (s *Server) Query(ctx context.Context, req *aiv1.QueryRequest) (*aiv1.Query
|
|||||||
Reply: result.Reply,
|
Reply: result.Reply,
|
||||||
Intent: result.Intent,
|
Intent: result.Intent,
|
||||||
ActionTaken: result.ActionTaken,
|
ActionTaken: result.ActionTaken,
|
||||||
|
ModelUsed: result.ModelUsed,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListModels returns the installed model names from Ollama.
|
||||||
|
func (s *Server) ListModels(ctx context.Context, _ *aiv1.ListModelsRequest) (*aiv1.ListModelsResponse, error) {
|
||||||
|
names, err := s.app.ListModels(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Unavailable, "list models: %v", err)
|
||||||
|
}
|
||||||
|
return &aiv1.ListModelsResponse{Names: names}, nil
|
||||||
|
}
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||||
)
|
)
|
||||||
@ -14,37 +15,49 @@ type generateRequest struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
|
Think *bool `json:"think,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type generateResponse struct {
|
type generateResponse struct {
|
||||||
Response string `json:"response"`
|
Response string `json:"response"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// listedModel is one entry of the "models" array in a list-models response.
// Only the "name" field is decoded.
type listedModel struct {
	Name string `json:"name"`
}

// listModelsResponse is the JSON body decoded by Client.ListModels.
type listModelsResponse struct {
	Models []listedModel `json:"models"`
}
|
||||||
|
|
||||||
// Client implements the LLM driven port with the Ollama generate API.
|
// Client implements the LLM driven port with the Ollama generate API.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
baseURL string
|
baseURL string
|
||||||
model string
|
|
||||||
http *http.Client
|
http *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// New constructs an Ollama client with OTel-instrumented transport.
|
// New constructs an Ollama client with OTel-instrumented transport.
|
||||||
func New(baseURL, model string, httpClient *http.Client) *Client {
|
func New(baseURL string, httpClient *http.Client) *Client {
|
||||||
if httpClient == nil {
|
if httpClient == nil {
|
||||||
httpClient = &http.Client{Transport: otelhttp.NewTransport(http.DefaultTransport)}
|
httpClient = &http.Client{Transport: otelhttp.NewTransport(http.DefaultTransport)}
|
||||||
}
|
}
|
||||||
if httpClient.Transport == nil {
|
if httpClient.Transport == nil {
|
||||||
httpClient.Transport = otelhttp.NewTransport(http.DefaultTransport)
|
httpClient.Transport = otelhttp.NewTransport(http.DefaultTransport)
|
||||||
}
|
}
|
||||||
return &Client{baseURL: baseURL, model: model, http: httpClient}
|
return &Client{baseURL: baseURL, http: httpClient}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate sends one non-streaming prompt to Ollama.
|
// Generate sends one non-streaming prompt to Ollama.
|
||||||
func (c *Client) Generate(ctx context.Context, prompt string) (string, error) {
|
func (c *Client) Generate(ctx context.Context, model, prompt string) (string, error) {
|
||||||
body, err := json.Marshal(generateRequest{
|
reqBody := generateRequest{
|
||||||
Model: c.model,
|
Model: model,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Stream: false,
|
Stream: false,
|
||||||
})
|
}
|
||||||
|
if isThinkingModel(model) {
|
||||||
|
disabled := false
|
||||||
|
reqBody.Think = &disabled
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(reqBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("marshal ollama request: %w", err)
|
return "", fmt.Errorf("marshal ollama request: %w", err)
|
||||||
}
|
}
|
||||||
@ -71,3 +84,38 @@ func (c *Client) Generate(ctx context.Context, prompt string) (string, error) {
|
|||||||
}
|
}
|
||||||
return out.Response, nil
|
return out.Response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListModels returns the installed model names from Ollama.
|
||||||
|
func (c *Client) ListModels(ctx context.Context) ([]string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/tags", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build ollama list models request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list ollama models: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("ollama returned status %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out listModelsResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode ollama list models response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
names := make([]string, 0, len(out.Models))
|
||||||
|
for _, model := range out.Models {
|
||||||
|
if model.Name != "" {
|
||||||
|
names = append(names, model.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return names, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isThinkingModel reports whether name starts with "qwen3", ignoring letter
// case. Generate uses the result to decide whether to send an explicit
// "think": false field in the request.
//
// The comparison is case-insensitive so that a name such as "Qwen3:4b" is
// treated the same as "qwen3:4b".
// NOTE(review): assumes Ollama resolves model names case-insensitively — confirm.
func isThinkingModel(name string) bool {
	const family = "qwen3"
	if len(name) < len(family) {
		return false
	}
	return strings.EqualFold(name[:len(family)], family)
}
|
||||||
|
|||||||
81
ai-gateway/internal/adapters/secondary/ollama/client_test.go
Normal file
81
ai-gateway/internal/adapters/secondary/ollama/client_test.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsThinkingModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "qwen3", want: true},
|
||||||
|
{name: "qwen3:4b", want: true},
|
||||||
|
{name: "qwen3:latest", want: true},
|
||||||
|
{name: "llama3", want: false},
|
||||||
|
{name: "mistral", want: false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := isThinkingModel(tt.name); got != tt.want {
|
||||||
|
t.Fatalf("isThinkingModel(%q) = %v, want %v", tt.name, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSetsThinkForQwen3(t *testing.T) {
|
||||||
|
var body map[string]any
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("Decode() error = %v", err)
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte(`{"response":"ok"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := New(srv.URL, srv.Client())
|
||||||
|
if _, err := client.Generate(context.Background(), "qwen3:latest", "prompt"); err != nil {
|
||||||
|
t.Fatalf("Generate() error = %v", err)
|
||||||
|
}
|
||||||
|
if body["think"] != false {
|
||||||
|
t.Fatalf("think = %#v, want false", body["think"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateOmitsThinkForLlama3(t *testing.T) {
|
||||||
|
var body map[string]any
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("Decode() error = %v", err)
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte(`{"response":"ok"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := New(srv.URL, srv.Client())
|
||||||
|
if _, err := client.Generate(context.Background(), "llama3:latest", "prompt"); err != nil {
|
||||||
|
t.Fatalf("Generate() error = %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := body["think"]; ok {
|
||||||
|
t.Fatalf("unexpected think key = %#v", body["think"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListModelsReturnsNamesOnly(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte(`{"models":[{"name":"llama3:latest"},{"name":"qwen3:latest"}]}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := New(srv.URL, srv.Client())
|
||||||
|
got, err := client.ListModels(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListModels() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(got) != 2 || got[0] != "llama3:latest" || got[1] != "qwen3:latest" {
|
||||||
|
t.Fatalf("ListModels() = %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -18,6 +18,7 @@ type QueryResult struct {
|
|||||||
Reply string
|
Reply string
|
||||||
Intent string
|
Intent string
|
||||||
ActionTaken bool
|
ActionTaken bool
|
||||||
|
ModelUsed string
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryApp orchestrates one AI query request.
|
// QueryApp orchestrates one AI query request.
|
||||||
@ -25,27 +26,32 @@ type QueryApp struct {
|
|||||||
llm driven.LLMClient
|
llm driven.LLMClient
|
||||||
ha driven.HAClient
|
ha driven.HAClient
|
||||||
cache *domain.LightCache
|
cache *domain.LightCache
|
||||||
|
defaultModel string
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQueryApp constructs the AI query application service.
|
// NewQueryApp constructs the AI query application service.
|
||||||
func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, log *slog.Logger) *QueryApp {
|
func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, defaultModel string, log *slog.Logger) *QueryApp {
|
||||||
return &QueryApp{llm: llm, ha: ha, cache: cache, log: log}
|
return &QueryApp{llm: llm, ha: ha, cache: cache, defaultModel: defaultModel, log: log}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query runs the full intent parsing and dispatch flow for one user request.
|
// Query runs the full intent parsing and dispatch flow for one user request.
|
||||||
func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error) {
|
func (a *QueryApp) Query(ctx context.Context, text, model string) (QueryResult, error) {
|
||||||
|
if model == "" {
|
||||||
|
model = a.defaultModel
|
||||||
|
}
|
||||||
lights, err := a.cache.Get(ctx)
|
lights, err := a.cache.Get(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error("light cache refresh failed", "err", err)
|
a.log.Error("light cache refresh failed", "err", err)
|
||||||
return QueryResult{
|
return QueryResult{
|
||||||
Reply: "I couldn't reach Home Assistant right now.",
|
Reply: "I couldn't reach Home Assistant right now.",
|
||||||
ActionTaken: false,
|
ActionTaken: false,
|
||||||
|
ModelUsed: model,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt := domain.BuildPrompt(text, promptLightLines(lights))
|
prompt := domain.BuildPrompt(text, promptLightLines(lights))
|
||||||
raw, err := a.llm.Generate(ctx, prompt)
|
raw, err := a.llm.Generate(ctx, model, prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return QueryResult{}, err
|
return QueryResult{}, err
|
||||||
}
|
}
|
||||||
@ -57,6 +63,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
|||||||
Reply: "I didn't understand that.",
|
Reply: "I didn't understand that.",
|
||||||
Intent: domain.IntentNone,
|
Intent: domain.IntentNone,
|
||||||
ActionTaken: false,
|
ActionTaken: false,
|
||||||
|
ModelUsed: model,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,7 +71,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
|||||||
case domain.IntentTurnOnLight:
|
case domain.IntentTurnOnLight:
|
||||||
entityID, ok := resolveLightEntity(intent.Entity, lights)
|
entityID, ok := resolveLightEntity(intent.Entity, lights)
|
||||||
if !ok {
|
if !ok {
|
||||||
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil
|
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil
|
||||||
}
|
}
|
||||||
params, err := ParseLightParams(intent.Params)
|
params, err := ParseLightParams(intent.Params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -72,6 +79,7 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
|||||||
Reply: "I couldn't understand the light settings.",
|
Reply: "I couldn't understand the light settings.",
|
||||||
Intent: intent.Name,
|
Intent: intent.Name,
|
||||||
ActionTaken: false,
|
ActionTaken: false,
|
||||||
|
ModelUsed: model,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
if err := a.ha.TurnOnLight(ctx, entityID, params); err != nil {
|
if err := a.ha.TurnOnLight(ctx, entityID, params); err != nil {
|
||||||
@ -80,17 +88,19 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
|||||||
Reply: "I couldn't reach Home Assistant right now.",
|
Reply: "I couldn't reach Home Assistant right now.",
|
||||||
Intent: intent.Name,
|
Intent: intent.Name,
|
||||||
ActionTaken: false,
|
ActionTaken: false,
|
||||||
|
ModelUsed: model,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
return QueryResult{
|
return QueryResult{
|
||||||
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned on `%s`.", displayLightName(entityID, lights))),
|
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned on `%s`.", displayLightName(entityID, lights))),
|
||||||
Intent: intent.Name,
|
Intent: intent.Name,
|
||||||
ActionTaken: true,
|
ActionTaken: true,
|
||||||
|
ModelUsed: model,
|
||||||
}, nil
|
}, nil
|
||||||
case domain.IntentTurnOffLight:
|
case domain.IntentTurnOffLight:
|
||||||
entityID, ok := resolveLightEntity(intent.Entity, lights)
|
entityID, ok := resolveLightEntity(intent.Entity, lights)
|
||||||
if !ok {
|
if !ok {
|
||||||
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name}, nil
|
return QueryResult{Reply: "I couldn't find that light.", Intent: intent.Name, ModelUsed: model}, nil
|
||||||
}
|
}
|
||||||
if err := a.ha.TurnOffLight(ctx, entityID); err != nil {
|
if err := a.ha.TurnOffLight(ctx, entityID); err != nil {
|
||||||
a.log.Error("turn off light failed", "entity_id", entityID, "err", err)
|
a.log.Error("turn off light failed", "entity_id", entityID, "err", err)
|
||||||
@ -98,18 +108,21 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
|||||||
Reply: "I couldn't reach Home Assistant right now.",
|
Reply: "I couldn't reach Home Assistant right now.",
|
||||||
Intent: intent.Name,
|
Intent: intent.Name,
|
||||||
ActionTaken: false,
|
ActionTaken: false,
|
||||||
|
ModelUsed: model,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
return QueryResult{
|
return QueryResult{
|
||||||
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned off `%s`.", displayLightName(entityID, lights))),
|
Reply: fallbackReply(intent.Reply, fmt.Sprintf("Turned off `%s`.", displayLightName(entityID, lights))),
|
||||||
Intent: intent.Name,
|
Intent: intent.Name,
|
||||||
ActionTaken: true,
|
ActionTaken: true,
|
||||||
|
ModelUsed: model,
|
||||||
}, nil
|
}, nil
|
||||||
case domain.IntentListLights:
|
case domain.IntentListLights:
|
||||||
return QueryResult{
|
return QueryResult{
|
||||||
Reply: formatLightListReply(lights),
|
Reply: formatLightListReply(lights),
|
||||||
Intent: intent.Name,
|
Intent: intent.Name,
|
||||||
ActionTaken: false,
|
ActionTaken: false,
|
||||||
|
ModelUsed: model,
|
||||||
}, nil
|
}, nil
|
||||||
case domain.IntentNone:
|
case domain.IntentNone:
|
||||||
fallthrough
|
fallthrough
|
||||||
@ -118,10 +131,16 @@ func (a *QueryApp) Query(ctx context.Context, text string) (QueryResult, error)
|
|||||||
Reply: fallbackReply(intent.Reply, "I didn't understand that."),
|
Reply: fallbackReply(intent.Reply, "I didn't understand that."),
|
||||||
Intent: intent.Name,
|
Intent: intent.Name,
|
||||||
ActionTaken: false,
|
ActionTaken: false,
|
||||||
|
ModelUsed: model,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListModels returns the currently installed model names.
|
||||||
|
func (a *QueryApp) ListModels(ctx context.Context) ([]string, error) {
|
||||||
|
return a.llm.ListModels(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func promptLightLines(lights []driven.Light) []string {
|
func promptLightLines(lights []driven.Light) []string {
|
||||||
lines := make([]string, 0, len(lights))
|
lines := make([]string, 0, len(lights))
|
||||||
for _, light := range lights {
|
for _, light := range lights {
|
||||||
|
|||||||
@ -13,11 +13,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type fakeLLM struct {
|
type fakeLLM struct {
|
||||||
generate func(context.Context, string) (string, error)
|
generate func(context.Context, string, string) (string, error)
|
||||||
|
listModels func(context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fakeLLM) Generate(ctx context.Context, prompt string) (string, error) {
|
func (f *fakeLLM) Generate(ctx context.Context, model, prompt string) (string, error) {
|
||||||
return f.generate(ctx, prompt)
|
return f.generate(ctx, model, prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeLLM) ListModels(ctx context.Context) ([]string, error) {
|
||||||
|
if f.listModels == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return f.listModels(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeHA struct {
|
type fakeHA struct {
|
||||||
@ -54,12 +62,15 @@ func TestQueryAppTurnOnLight(t *testing.T) {
|
|||||||
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
||||||
cache := domain.NewLightCache(time.Hour, ha.ListLights)
|
cache := domain.NewLightCache(time.Hour, ha.ListLights)
|
||||||
app := NewQueryApp(&fakeLLM{
|
app := NewQueryApp(&fakeLLM{
|
||||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||||
|
if model != "llama3:latest" {
|
||||||
|
t.Fatalf("Generate() model = %q", model)
|
||||||
|
}
|
||||||
return `{"intent":"turn_on_light","entity":"Kitchen","params":{"brightness":"80"},"reply":"Turning on Kitchen."}`, nil
|
return `{"intent":"turn_on_light","entity":"Kitchen","params":{"brightness":"80"},"reply":"Turning on Kitchen."}`, nil
|
||||||
},
|
},
|
||||||
}, ha, cache, slog.Default())
|
}, ha, cache, "llama3:latest", slog.Default())
|
||||||
|
|
||||||
got, err := app.Query(context.Background(), "turn on kitchen")
|
got, err := app.Query(context.Background(), "turn on kitchen", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Query() error = %v", err)
|
t.Fatalf("Query() error = %v", err)
|
||||||
}
|
}
|
||||||
@ -77,12 +88,12 @@ func TestQueryAppTurnOnLight(t *testing.T) {
|
|||||||
func TestQueryAppInvalidJSON(t *testing.T) {
|
func TestQueryAppInvalidJSON(t *testing.T) {
|
||||||
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
||||||
app := NewQueryApp(&fakeLLM{
|
app := NewQueryApp(&fakeLLM{
|
||||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||||
return `not-json`, nil
|
return `not-json`, nil
|
||||||
},
|
},
|
||||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||||
|
|
||||||
got, err := app.Query(context.Background(), "turn on kitchen")
|
got, err := app.Query(context.Background(), "turn on kitchen", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Query() error = %v", err)
|
t.Fatalf("Query() error = %v", err)
|
||||||
}
|
}
|
||||||
@ -97,12 +108,12 @@ func TestQueryAppInvalidJSON(t *testing.T) {
|
|||||||
func TestQueryAppIntentNone(t *testing.T) {
|
func TestQueryAppIntentNone(t *testing.T) {
|
||||||
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
||||||
app := NewQueryApp(&fakeLLM{
|
app := NewQueryApp(&fakeLLM{
|
||||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||||
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
|
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
|
||||||
},
|
},
|
||||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||||
|
|
||||||
got, err := app.Query(context.Background(), "hello")
|
got, err := app.Query(context.Background(), "hello", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Query() error = %v", err)
|
t.Fatalf("Query() error = %v", err)
|
||||||
}
|
}
|
||||||
@ -117,12 +128,12 @@ func TestQueryAppHAFailure(t *testing.T) {
|
|||||||
turnOnErr: errors.New("boom"),
|
turnOnErr: errors.New("boom"),
|
||||||
}
|
}
|
||||||
app := NewQueryApp(&fakeLLM{
|
app := NewQueryApp(&fakeLLM{
|
||||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||||
return `{"intent":"turn_on_light","entity":"light.kitchen","params":{},"reply":"Turning on Kitchen."}`, nil
|
return `{"intent":"turn_on_light","entity":"light.kitchen","params":{},"reply":"Turning on Kitchen."}`, nil
|
||||||
},
|
},
|
||||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||||
|
|
||||||
got, err := app.Query(context.Background(), "turn on kitchen")
|
got, err := app.Query(context.Background(), "turn on kitchen", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Query() error = %v", err)
|
t.Fatalf("Query() error = %v", err)
|
||||||
}
|
}
|
||||||
@ -134,12 +145,12 @@ func TestQueryAppHAFailure(t *testing.T) {
|
|||||||
func TestQueryAppListLights(t *testing.T) {
|
func TestQueryAppListLights(t *testing.T) {
|
||||||
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "on"}}}
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "on"}}}
|
||||||
app := NewQueryApp(&fakeLLM{
|
app := NewQueryApp(&fakeLLM{
|
||||||
generate: func(ctx context.Context, prompt string) (string, error) {
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||||
return `{"intent":"list_lights","entity":"","params":{},"reply":""}`, nil
|
return `{"intent":"list_lights","entity":"","params":{},"reply":""}`, nil
|
||||||
},
|
},
|
||||||
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), slog.Default())
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||||
|
|
||||||
got, err := app.Query(context.Background(), "what lights exist")
|
got, err := app.Query(context.Background(), "what lights exist", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Query() error = %v", err)
|
t.Fatalf("Query() error = %v", err)
|
||||||
}
|
}
|
||||||
@ -164,3 +175,42 @@ func TestLightCacheRefreshAfterTTL(t *testing.T) {
|
|||||||
t.Fatalf("ListLights calls = %d, want at least 2", ha.listCalls)
|
t.Fatalf("ListLights calls = %d, want at least 2", ha.listCalls)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryAppExplicitModel(t *testing.T) {
|
||||||
|
ha := &fakeHA{lights: []driven.Light{{EntityID: "light.kitchen", FriendlyName: "Kitchen", State: "off"}}}
|
||||||
|
app := NewQueryApp(&fakeLLM{
|
||||||
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||||
|
if model != "qwen3:latest" {
|
||||||
|
t.Fatalf("Generate() model = %q", model)
|
||||||
|
}
|
||||||
|
return `{"intent":"none","entity":"","params":{},"reply":"Hello there."}`, nil
|
||||||
|
},
|
||||||
|
}, ha, domain.NewLightCache(time.Hour, ha.ListLights), "llama3:latest", slog.Default())
|
||||||
|
|
||||||
|
got, err := app.Query(context.Background(), "hello", "qwen3:latest")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query() error = %v", err)
|
||||||
|
}
|
||||||
|
if got.ModelUsed != "qwen3:latest" {
|
||||||
|
t.Fatalf("ModelUsed = %q", got.ModelUsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryAppListModels(t *testing.T) {
|
||||||
|
app := NewQueryApp(&fakeLLM{
|
||||||
|
generate: func(ctx context.Context, model, prompt string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
},
|
||||||
|
listModels: func(ctx context.Context) ([]string, error) {
|
||||||
|
return []string{"llama3:latest", "qwen3:latest"}, nil
|
||||||
|
},
|
||||||
|
}, &fakeHA{}, domain.NewLightCache(time.Hour, func(context.Context) ([]driven.Light, error) { return nil, nil }), "llama3:latest", slog.Default())
|
||||||
|
|
||||||
|
got, err := app.ListModels(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListModels() error = %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, []string{"llama3:latest", "qwen3:latest"}) {
|
||||||
|
t.Fatalf("ListModels() = %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -4,5 +4,6 @@ import "context"
|
|||||||
|
|
||||||
// LLMClient generates one model response for a prompt.
|
// LLMClient generates one model response for a prompt.
|
||||||
type LLMClient interface {
|
type LLMClient interface {
|
||||||
Generate(ctx context.Context, prompt string) (string, error)
|
Generate(ctx context.Context, model, prompt string) (string, error)
|
||||||
|
ListModels(ctx context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,6 +17,8 @@ import (
|
|||||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/app"
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/app"
|
||||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/config"
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/config"
|
||||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/logger"
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/logger"
|
||||||
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
|
||||||
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
|
||||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/telemetry"
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -84,7 +86,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
commandApp := app.NewCommandApp(haClient, aiClient)
|
modelStore := modelstore.New()
|
||||||
|
validator := modelvalidator.New(aiClient, 30*time.Second)
|
||||||
|
commandApp := app.NewCommandApp(haClient, aiClient, modelStore, validator)
|
||||||
|
|
||||||
// Discord-specific wiring stays at the edge so the app layer remains transport-agnostic.
|
// Discord-specific wiring stays at the edge so the app layer remains transport-agnostic.
|
||||||
session, err := discordgo.New("Bot " + cfg.DiscordToken)
|
session, err := discordgo.New("Bot " + cfg.DiscordToken)
|
||||||
|
|||||||
@ -24,8 +24,12 @@ type commandHandler interface {
|
|||||||
HandleLightToggle(ctx context.Context, entityID string) (string, error)
|
HandleLightToggle(ctx context.Context, entityID string) (string, error)
|
||||||
HandleSwitchList(ctx context.Context) (string, error)
|
HandleSwitchList(ctx context.Context) (string, error)
|
||||||
HandleAIQuery(ctx context.Context, text string) (string, error)
|
HandleAIQuery(ctx context.Context, text string) (string, error)
|
||||||
|
HandleAIModelSet(ctx context.Context, name string) (string, error)
|
||||||
|
HandleAIModelGet(ctx context.Context) (string, error)
|
||||||
|
HandleAIModelList(ctx context.Context) (string, error)
|
||||||
AutocompleteLights(ctx context.Context) ([]apppkg.Choice, error)
|
AutocompleteLights(ctx context.Context) ([]apppkg.Choice, error)
|
||||||
AutocompleteSwitches(ctx context.Context) ([]apppkg.Choice, error)
|
AutocompleteSwitches(ctx context.Context) ([]apppkg.Choice, error)
|
||||||
|
AutocompleteAIModels(ctx context.Context) ([]apppkg.Choice, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler adapts Discord interactions to the command application layer.
|
// Handler adapts Discord interactions to the command application layer.
|
||||||
@ -69,6 +73,9 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
|
|||||||
command := data.Name
|
command := data.Name
|
||||||
if len(data.Options) > 0 {
|
if len(data.Options) > 0 {
|
||||||
command += "." + data.Options[0].Name
|
command += "." + data.Options[0].Name
|
||||||
|
if len(data.Options[0].Options) > 0 && data.Options[0].Type == discordgo.ApplicationCommandOptionSubCommandGroup {
|
||||||
|
command += "." + data.Options[0].Options[0].Name
|
||||||
|
}
|
||||||
}
|
}
|
||||||
user := ""
|
user := ""
|
||||||
if i.Member != nil && i.Member.User != nil {
|
if i.Member != nil && i.Member.User != nil {
|
||||||
@ -90,7 +97,18 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
|
|||||||
}
|
}
|
||||||
|
|
||||||
sub := data.Options[0]
|
sub := data.Options[0]
|
||||||
switch data.Name + "." + sub.Name {
|
commandPath := data.Name + "." + sub.Name
|
||||||
|
target := sub
|
||||||
|
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
|
||||||
|
if len(sub.Options) == 0 {
|
||||||
|
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("missing grouped subcommand"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
target = sub.Options[0]
|
||||||
|
commandPath += "." + target.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
switch commandPath {
|
||||||
case "light.list":
|
case "light.list":
|
||||||
msg, err := h.app.HandleLightList(ctx)
|
msg, err := h.app.HandleLightList(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -145,10 +163,38 @@ func (h *Handler) handleApplicationCommand(ctx context.Context, s *discordgo.Ses
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(sub, "text"))
|
msg, err := h.app.HandleAIQuery(ctx, requiredStringOption(target, "text"))
|
||||||
|
h.followup(ctx, s, i.Interaction, msg, true, start, err)
|
||||||
|
case "ai.model.set":
|
||||||
|
if err := h.deferResponse(s, i.Interaction, true); err != nil {
|
||||||
|
log.Error("discord response failed",
|
||||||
|
"duration_ms", time.Since(start).Milliseconds(),
|
||||||
|
"error", err.Error(),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg, err := h.app.HandleAIModelSet(ctx, requiredStringOption(target, "name"))
|
||||||
|
h.followup(ctx, s, i.Interaction, msg, true, start, err)
|
||||||
|
case "ai.model.get":
|
||||||
|
msg, err := h.app.HandleAIModelGet(ctx)
|
||||||
|
if err != nil {
|
||||||
|
h.respondError(ctx, s, i.Interaction, true, start, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.respondMessage(ctx, s, i.Interaction, msg, true)
|
||||||
|
log.Info("command handled", "duration_ms", time.Since(start).Milliseconds())
|
||||||
|
case "ai.model.list":
|
||||||
|
if err := h.deferResponse(s, i.Interaction, true); err != nil {
|
||||||
|
log.Error("discord response failed",
|
||||||
|
"duration_ms", time.Since(start).Milliseconds(),
|
||||||
|
"error", err.Error(),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg, err := h.app.HandleAIModelList(ctx)
|
||||||
h.followup(ctx, s, i.Interaction, msg, true, start, err)
|
h.followup(ctx, s, i.Interaction, msg, true, start, err)
|
||||||
default:
|
default:
|
||||||
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s.%s", data.Name, sub.Name))
|
h.respondError(ctx, s, i.Interaction, true, start, fmt.Errorf("unsupported command: %s", commandPath))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -167,6 +213,10 @@ func (h *Handler) handleAutocomplete(ctx context.Context, s *discordgo.Session,
|
|||||||
choices, err = h.app.AutocompleteLights(ctx)
|
choices, err = h.app.AutocompleteLights(ctx)
|
||||||
case "switch":
|
case "switch":
|
||||||
choices, err = h.app.AutocompleteSwitches(ctx)
|
choices, err = h.app.AutocompleteSwitches(ctx)
|
||||||
|
case "ai":
|
||||||
|
if focusedOptionName(data) == "name" {
|
||||||
|
choices, err = h.app.AutocompleteAIModels(ctx)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
choices = nil
|
choices = nil
|
||||||
}
|
}
|
||||||
@ -299,6 +349,15 @@ func optionalUint32Option(sub *discordgo.ApplicationCommandInteractionDataOption
|
|||||||
|
|
||||||
func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string {
|
func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string {
|
||||||
for _, sub := range data.Options {
|
for _, sub := range data.Options {
|
||||||
|
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
|
||||||
|
for _, nested := range sub.Options {
|
||||||
|
for _, opt := range nested.Options {
|
||||||
|
if opt.Focused {
|
||||||
|
return opt.StringValue()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
for _, opt := range sub.Options {
|
for _, opt := range sub.Options {
|
||||||
if opt.Focused {
|
if opt.Focused {
|
||||||
return opt.StringValue()
|
return opt.StringValue()
|
||||||
@ -307,3 +366,23 @@ func focusedOptionValue(data discordgo.ApplicationCommandInteractionData) string
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func focusedOptionName(data discordgo.ApplicationCommandInteractionData) string {
|
||||||
|
for _, sub := range data.Options {
|
||||||
|
if sub.Type == discordgo.ApplicationCommandOptionSubCommandGroup {
|
||||||
|
for _, nested := range sub.Options {
|
||||||
|
for _, opt := range nested.Options {
|
||||||
|
if opt.Focused {
|
||||||
|
return opt.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, opt := range sub.Options {
|
||||||
|
if opt.Focused {
|
||||||
|
return opt.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@ -98,6 +98,37 @@ func RegisterCommands(s *discordgo.Session, guildID string) error {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Type: discordgo.ApplicationCommandOptionSubCommandGroup,
|
||||||
|
Name: "model",
|
||||||
|
Description: "Manage the active AI model",
|
||||||
|
Options: []*discordgo.ApplicationCommandOption{
|
||||||
|
{
|
||||||
|
Type: discordgo.ApplicationCommandOptionSubCommand,
|
||||||
|
Name: "set",
|
||||||
|
Description: "Set the active model",
|
||||||
|
Options: []*discordgo.ApplicationCommandOption{
|
||||||
|
{
|
||||||
|
Type: discordgo.ApplicationCommandOptionString,
|
||||||
|
Name: "name",
|
||||||
|
Description: "Model name",
|
||||||
|
Required: true,
|
||||||
|
Autocomplete: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: discordgo.ApplicationCommandOptionSubCommand,
|
||||||
|
Name: "get",
|
||||||
|
Description: "Show the active model",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: discordgo.ApplicationCommandOptionSubCommand,
|
||||||
|
Name: "list",
|
||||||
|
Description: "List available models",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -61,22 +61,39 @@ func (c *Client) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Query forwards one free-form request to ai-gateway.
|
// Query forwards one free-form request to ai-gateway.
|
||||||
func (c *Client) Query(ctx context.Context, text string) (string, error) {
|
func (c *Client) Query(ctx context.Context, text, model string) (string, string, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
log := logger.FromContext(ctx).With("grpc.method", "AIService/Query")
|
log := logger.FromContext(ctx).With("grpc.method", "AIService/Query")
|
||||||
resp, err := c.client.Query(ctx, &aiv1.QueryRequest{
|
resp, err := c.client.Query(ctx, &aiv1.QueryRequest{
|
||||||
Text: text,
|
Text: text,
|
||||||
Source: "discord-bot",
|
Source: "discord-bot",
|
||||||
|
Model: model,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("grpc call failed",
|
log.Error("grpc call failed",
|
||||||
"duration_ms", time.Since(start).Milliseconds(),
|
"duration_ms", time.Since(start).Milliseconds(),
|
||||||
"error", err.Error(),
|
"error", err.Error(),
|
||||||
)
|
)
|
||||||
return "", fmt.Errorf("query ai-gateway: %w", err)
|
return "", "", fmt.Errorf("query ai-gateway: %w", err)
|
||||||
}
|
}
|
||||||
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
|
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
|
||||||
return resp.GetReply(), nil
|
return resp.GetReply(), resp.GetModelUsed(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModels returns the installed model names from ai-gateway.
|
||||||
|
func (c *Client) ListModels(ctx context.Context) ([]string, error) {
|
||||||
|
start := time.Now()
|
||||||
|
log := logger.FromContext(ctx).With("grpc.method", "AIService/ListModels")
|
||||||
|
resp, err := c.client.ListModels(ctx, &aiv1.ListModelsRequest{})
|
||||||
|
if err != nil {
|
||||||
|
log.Error("grpc call failed",
|
||||||
|
"duration_ms", time.Since(start).Milliseconds(),
|
||||||
|
"error", err.Error(),
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("list ai-gateway models: %w", err)
|
||||||
|
}
|
||||||
|
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
|
||||||
|
return append([]string(nil), resp.GetNames()...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadTransportCredentials(tlsDir string) (credentials.TransportCredentials, error) {
|
func loadTransportCredentials(tlsDir string) (credentials.TransportCredentials, error) {
|
||||||
|
|||||||
@ -7,6 +7,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
|
||||||
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
|
||||||
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Choice is one Discord autocomplete entry.
|
// Choice is one Discord autocomplete entry.
|
||||||
@ -19,11 +21,13 @@ type Choice struct {
|
|||||||
type CommandApp struct {
|
type CommandApp struct {
|
||||||
ha driven.HAGateway
|
ha driven.HAGateway
|
||||||
ai driven.AIGateway
|
ai driven.AIGateway
|
||||||
|
models *modelstore.Store
|
||||||
|
validator *modelvalidator.Validator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCommandApp constructs the Discord command application service.
|
// NewCommandApp constructs the Discord command application service.
|
||||||
func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway) *CommandApp {
|
func NewCommandApp(ha driven.HAGateway, ai driven.AIGateway, models *modelstore.Store, validator *modelvalidator.Validator) *CommandApp {
|
||||||
return &CommandApp{ha: ha, ai: ai}
|
return &CommandApp{ha: ha, ai: ai, models: models, validator: validator}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleLightList formats discovered lights into a monospace-friendly response.
|
// HandleLightList formats discovered lights into a monospace-friendly response.
|
||||||
@ -124,11 +128,72 @@ func (a *CommandApp) HandleSwitchList(ctx context.Context) (string, error) {
|
|||||||
|
|
||||||
// HandleAIQuery forwards a free-form request to ai-gateway.
|
// HandleAIQuery forwards a free-form request to ai-gateway.
|
||||||
func (a *CommandApp) HandleAIQuery(ctx context.Context, text string) (string, error) {
|
func (a *CommandApp) HandleAIQuery(ctx context.Context, text string) (string, error) {
|
||||||
reply, err := a.ai.Query(ctx, text)
|
reply, modelUsed, err := a.ai.Query(ctx, text, a.models.Get())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("handle ai query: %w", err)
|
return "", fmt.Errorf("handle ai query: %w", err)
|
||||||
}
|
}
|
||||||
return reply, nil
|
return fmt.Sprintf("%s\n\n_(via %s)_", reply, modelUsed), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleAIModelSet validates and stores the selected model globally.
|
||||||
|
func (a *CommandApp) HandleAIModelSet(ctx context.Context, name string) (string, error) {
|
||||||
|
canonical, err := a.validator.Normalize(ctx, name)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == "ambiguous model name" {
|
||||||
|
return "", fmt.Errorf("unknown model: %s. Be more specific.", name)
|
||||||
|
}
|
||||||
|
if err.Error() == "unknown model" {
|
||||||
|
return "", fmt.Errorf("unknown model: %s. Try /ai model list.", name)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("validate model: %w", err)
|
||||||
|
}
|
||||||
|
a.models.Set(canonical)
|
||||||
|
return fmt.Sprintf("Active model set to `%s`.", canonical), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleAIModelGet reports the current selected model or default state.
|
||||||
|
func (a *CommandApp) HandleAIModelGet(ctx context.Context) (string, error) {
|
||||||
|
cur := a.models.Get()
|
||||||
|
if cur == "" {
|
||||||
|
return "No model override set. Using ai-gateway default.", nil
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Active model: `%s`", cur), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleAIModelList shows the installed models and marks the active selection.
|
||||||
|
func (a *CommandApp) HandleAIModelList(ctx context.Context) (string, error) {
|
||||||
|
models, err := a.validator.Known(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("list models: %w", err)
|
||||||
|
}
|
||||||
|
if len(models) == 0 {
|
||||||
|
return "No models installed on the Ollama host.", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
active := a.models.Get()
|
||||||
|
lines := make([]string, 0, len(models)+1)
|
||||||
|
lines = append(lines, "Available models:")
|
||||||
|
for _, model := range models {
|
||||||
|
marker := ""
|
||||||
|
if model == active {
|
||||||
|
marker = " <- active"
|
||||||
|
}
|
||||||
|
lines = append(lines, fmt.Sprintf("- `%s`%s", model, marker))
|
||||||
|
}
|
||||||
|
return strings.Join(lines, "\n"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AutocompleteAIModels returns model names for the /ai model set command.
|
||||||
|
func (a *CommandApp) AutocompleteAIModels(ctx context.Context) ([]Choice, error) {
|
||||||
|
models, err := a.validator.Known(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("autocomplete ai models: %w", err)
|
||||||
|
}
|
||||||
|
choices := make([]Choice, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
choices = append(choices, Choice{Label: model, Value: model})
|
||||||
|
}
|
||||||
|
return choices, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutocompleteLights maps discovered lights into Discord autocomplete choices.
|
// AutocompleteLights maps discovered lights into Discord autocomplete choices.
|
||||||
|
|||||||
@ -5,8 +5,11 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
|
||||||
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelstore"
|
||||||
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/modelvalidator"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockHAGateway struct {
|
type mockHAGateway struct {
|
||||||
@ -18,7 +21,8 @@ type mockHAGateway struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockAIGateway struct {
|
type mockAIGateway struct {
|
||||||
queryFunc func(ctx context.Context, text string) (string, error)
|
queryFunc func(ctx context.Context, text, model string) (string, string, error)
|
||||||
|
listModelsFunc func(ctx context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHAGateway) ListLights(ctx context.Context) ([]driven.Light, error) {
|
func (m *mockHAGateway) ListLights(ctx context.Context) ([]driven.Light, error) {
|
||||||
@ -56,11 +60,22 @@ func (m *mockHAGateway) ToggleLight(ctx context.Context, entityID string) error
|
|||||||
return m.toggleLightFunc(ctx, entityID)
|
return m.toggleLightFunc(ctx, entityID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAIGateway) Query(ctx context.Context, text string) (string, error) {
|
func (m *mockAIGateway) Query(ctx context.Context, text, model string) (string, string, error) {
|
||||||
if m.queryFunc == nil {
|
if m.queryFunc == nil {
|
||||||
return "", nil
|
return "", "", nil
|
||||||
}
|
}
|
||||||
return m.queryFunc(ctx, text)
|
return m.queryFunc(ctx, text, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAIGateway) ListModels(ctx context.Context) ([]string, error) {
|
||||||
|
if m.listModelsFunc == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return m.listModelsFunc(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestCommandApp(ha *mockHAGateway, ai *mockAIGateway) *CommandApp {
|
||||||
|
return NewCommandApp(ha, ai, modelstore.New(), modelvalidator.New(ai, time.Minute))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCommandAppHandleLightList(t *testing.T) {
|
func TestCommandAppHandleLightList(t *testing.T) {
|
||||||
@ -102,7 +117,7 @@ func TestCommandAppHandleLightList(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
app := NewCommandApp(&mockHAGateway{
|
app := newTestCommandApp(&mockHAGateway{
|
||||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||||
return tt.lights, nil
|
return tt.lights, nil
|
||||||
},
|
},
|
||||||
@ -183,7 +198,7 @@ func TestCommandAppHandleLightOn(t *testing.T) {
|
|||||||
var gotBrightness *uint32
|
var gotBrightness *uint32
|
||||||
var gotColorTemp *uint32
|
var gotColorTemp *uint32
|
||||||
|
|
||||||
app := NewCommandApp(&mockHAGateway{
|
app := newTestCommandApp(&mockHAGateway{
|
||||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||||
if tt.listErr != nil {
|
if tt.listErr != nil {
|
||||||
return nil, tt.listErr
|
return nil, tt.listErr
|
||||||
@ -258,7 +273,7 @@ func TestCommandAppHandleLightOff(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
var gotTransition *uint32
|
var gotTransition *uint32
|
||||||
app := NewCommandApp(&mockHAGateway{
|
app := newTestCommandApp(&mockHAGateway{
|
||||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||||
return tt.lights, nil
|
return tt.lights, nil
|
||||||
},
|
},
|
||||||
@ -317,7 +332,7 @@ func TestCommandAppHandleLightToggle(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
app := NewCommandApp(&mockHAGateway{
|
app := newTestCommandApp(&mockHAGateway{
|
||||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||||
if tt.listErr != nil {
|
if tt.listErr != nil {
|
||||||
return nil, tt.listErr
|
return nil, tt.listErr
|
||||||
@ -368,7 +383,7 @@ func TestCommandAppHandleSwitchList(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
app := NewCommandApp(&mockHAGateway{
|
app := newTestCommandApp(&mockHAGateway{
|
||||||
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
|
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
|
||||||
return tt.switches, nil
|
return tt.switches, nil
|
||||||
},
|
},
|
||||||
@ -413,7 +428,7 @@ func TestCommandAppAutocompleteLights(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
app := NewCommandApp(&mockHAGateway{
|
app := newTestCommandApp(&mockHAGateway{
|
||||||
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
listLightsFunc: func(ctx context.Context) ([]driven.Light, error) {
|
||||||
if tt.listErr != nil {
|
if tt.listErr != nil {
|
||||||
return nil, tt.listErr
|
return nil, tt.listErr
|
||||||
@ -467,7 +482,7 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
app := NewCommandApp(&mockHAGateway{
|
app := newTestCommandApp(&mockHAGateway{
|
||||||
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
|
listSwitchesFunc: func(ctx context.Context) ([]driven.Switch, error) {
|
||||||
if tt.listErr != nil {
|
if tt.listErr != nil {
|
||||||
return nil, tt.listErr
|
return nil, tt.listErr
|
||||||
@ -494,20 +509,41 @@ func TestCommandAppAutocompleteSwitches(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCommandAppHandleAIQuery(t *testing.T) {
|
func TestCommandAppHandleAIQuery(t *testing.T) {
|
||||||
|
store := modelstore.New()
|
||||||
|
store.Set("llama3:latest")
|
||||||
app := NewCommandApp(&mockHAGateway{}, &mockAIGateway{
|
app := NewCommandApp(&mockHAGateway{}, &mockAIGateway{
|
||||||
queryFunc: func(ctx context.Context, text string) (string, error) {
|
queryFunc: func(ctx context.Context, text, model string) (string, string, error) {
|
||||||
if text != "turn on kitchen" {
|
if text != "turn on kitchen" {
|
||||||
t.Fatalf("Query() text = %q", text)
|
t.Fatalf("Query() text = %q", text)
|
||||||
}
|
}
|
||||||
return "Turning on Kitchen.", nil
|
if model != "llama3:latest" {
|
||||||
|
t.Fatalf("Query() model = %q", model)
|
||||||
|
}
|
||||||
|
return "Turning on Kitchen.", "llama3:latest", nil
|
||||||
},
|
},
|
||||||
})
|
}, store, modelvalidator.New(&mockAIGateway{}, time.Minute))
|
||||||
|
|
||||||
got, err := app.HandleAIQuery(context.Background(), "turn on kitchen")
|
got, err := app.HandleAIQuery(context.Background(), "turn on kitchen")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("HandleAIQuery() error = %v", err)
|
t.Fatalf("HandleAIQuery() error = %v", err)
|
||||||
}
|
}
|
||||||
if got != "Turning on Kitchen." {
|
if got != "Turning on Kitchen.\n\n_(via llama3:latest)_" {
|
||||||
t.Fatalf("HandleAIQuery() = %q", got)
|
t.Fatalf("HandleAIQuery() = %q", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCommandAppHandleAIModelSet(t *testing.T) {
|
||||||
|
app := newTestCommandApp(&mockHAGateway{}, &mockAIGateway{
|
||||||
|
listModelsFunc: func(ctx context.Context) ([]string, error) {
|
||||||
|
return []string{"llama3:latest"}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := app.HandleAIModelSet(context.Background(), "llama3")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HandleAIModelSet() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != "Active model set to `llama3:latest`." {
|
||||||
|
t.Fatalf("HandleAIModelSet() = %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -4,5 +4,6 @@ import "context"
|
|||||||
|
|
||||||
// AIGateway exposes the free-form AI query API used by the Discord bot.
|
// AIGateway exposes the free-form AI query API used by the Discord bot.
|
||||||
type AIGateway interface {
|
type AIGateway interface {
|
||||||
Query(ctx context.Context, text string) (string, error)
|
Query(ctx context.Context, text, model string) (reply, modelUsed string, err error)
|
||||||
|
ListModels(ctx context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|||||||
28
discord-bot/internal/modelstore/store.go
Normal file
28
discord-bot/internal/modelstore/store.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package modelstore
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// Store holds a single selected AI model name in memory, guarded by a
// read/write mutex.
type Store struct {
	mu       sync.RWMutex // guards selected
	selected string
}

// New returns a Store with no model selected.
func New() *Store {
	return new(Store)
}

// Get reports the selected model name; the empty string means nothing has
// been selected.
func (s *Store) Get() string {
	s.mu.RLock()
	current := s.selected
	s.mu.RUnlock()
	return current
}

// Set replaces the selected model name.
func (s *Store) Set(model string) {
	s.mu.Lock()
	s.selected = model
	s.mu.Unlock()
}
|
||||||
15
discord-bot/internal/modelstore/store_test.go
Normal file
15
discord-bot/internal/modelstore/store_test.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package modelstore
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestStoreGetSet(t *testing.T) {
|
||||||
|
store := New()
|
||||||
|
if got := store.Get(); got != "" {
|
||||||
|
t.Fatalf("Get() = %q, want empty", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
store.Set("llama3:latest")
|
||||||
|
if got := store.Get(); got != "llama3:latest" {
|
||||||
|
t.Fatalf("Get() = %q, want llama3:latest", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
86
discord-bot/internal/modelvalidator/validator.go
Normal file
86
discord-bot/internal/modelvalidator/validator.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package modelvalidator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Validator caches the model list briefly and normalizes friendly names.
|
||||||
|
type Validator struct {
|
||||||
|
client driven.AIGateway
|
||||||
|
ttl time.Duration
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
cache []string
|
||||||
|
cachedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// New constructs a model validator with a TTL cache.
|
||||||
|
func New(client driven.AIGateway, ttl time.Duration) *Validator {
|
||||||
|
return &Validator{client: client, ttl: ttl}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Known returns the cached model list, refreshing when stale.
|
||||||
|
func (v *Validator) Known(ctx context.Context) ([]string, error) {
|
||||||
|
v.mu.Lock()
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
if len(v.cache) > 0 && time.Since(v.cachedAt) < v.ttl {
|
||||||
|
return append([]string(nil), v.cache...), nil
|
||||||
|
}
|
||||||
|
models, err := v.client.ListModels(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
v.cache = append([]string(nil), models...)
|
||||||
|
v.cachedAt = time.Now()
|
||||||
|
return append([]string(nil), v.cache...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize resolves a user-provided name to a canonical installed model name.
|
||||||
|
func (v *Validator) Normalize(ctx context.Context, name string) (string, error) {
|
||||||
|
models, err := v.Known(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
for _, model := range models {
|
||||||
|
if model == name {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
latest := name + ":latest"
|
||||||
|
for _, model := range models {
|
||||||
|
if model == latest {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lower := strings.ToLower(name)
|
||||||
|
for _, model := range models {
|
||||||
|
if strings.ToLower(model) == lower {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lowerLatest := strings.ToLower(latest)
|
||||||
|
for _, model := range models {
|
||||||
|
if strings.ToLower(model) == lowerLatest {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := make([]string, 0, 2)
|
||||||
|
prefix := lower + ":"
|
||||||
|
for _, model := range models {
|
||||||
|
if strings.HasPrefix(strings.ToLower(model), prefix) {
|
||||||
|
matches = append(matches, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(matches) > 1 {
|
||||||
|
return "", fmt.Errorf("ambiguous model name")
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("unknown model")
|
||||||
|
}
|
||||||
70
discord-bot/internal/modelvalidator/validator_test.go
Normal file
70
discord-bot/internal/modelvalidator/validator_test.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package modelvalidator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeAIGateway is a test double for the AI gateway whose model listing is
// supplied by the test through listModels.
type fakeAIGateway struct {
	listModels func(context.Context) ([]string, error)
}

// Query is a stub; it always returns empty strings and no error.
func (f *fakeAIGateway) Query(ctx context.Context, text, model string) (string, string, error) {
	return "", "", nil
}

// ListModels delegates to listModels. When no function is configured it
// returns a nil list and no error instead of panicking on a nil call.
func (f *fakeAIGateway) ListModels(ctx context.Context) ([]string, error) {
	if f.listModels == nil {
		return nil, nil
	}
	return f.listModels(ctx)
}
|
||||||
|
|
||||||
|
func TestValidatorNormalize(t *testing.T) {
|
||||||
|
v := New(&fakeAIGateway{
|
||||||
|
listModels: func(ctx context.Context) ([]string, error) {
|
||||||
|
return []string{"llama3:latest", "qwen3:latest", "qwen3:4b"}, nil
|
||||||
|
},
|
||||||
|
}, time.Minute)
|
||||||
|
|
||||||
|
got, err := v.Normalize(context.Background(), "llama3")
|
||||||
|
if err != nil || got != "llama3:latest" {
|
||||||
|
t.Fatalf("Normalize(llama3) = %q, %v", got, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err = v.Normalize(context.Background(), "qwen3")
|
||||||
|
if err != nil || got != "qwen3:latest" {
|
||||||
|
t.Fatalf("Normalize(qwen3) = %q, %v", got, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatorKnownCaches(t *testing.T) {
|
||||||
|
calls := 0
|
||||||
|
v := New(&fakeAIGateway{
|
||||||
|
listModels: func(ctx context.Context) ([]string, error) {
|
||||||
|
calls++
|
||||||
|
return []string{"llama3:latest"}, nil
|
||||||
|
},
|
||||||
|
}, time.Minute)
|
||||||
|
|
||||||
|
if _, err := v.Known(context.Background()); err != nil {
|
||||||
|
t.Fatalf("Known() error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := v.Known(context.Background()); err != nil {
|
||||||
|
t.Fatalf("Known() error = %v", err)
|
||||||
|
}
|
||||||
|
if calls != 1 {
|
||||||
|
t.Fatalf("calls = %d, want 1", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatorKnownPropagatesError(t *testing.T) {
|
||||||
|
v := New(&fakeAIGateway{
|
||||||
|
listModels: func(ctx context.Context) ([]string, error) {
|
||||||
|
return nil, errors.New("boom")
|
||||||
|
},
|
||||||
|
}, time.Minute)
|
||||||
|
|
||||||
|
if _, err := v.Known(context.Background()); err == nil {
|
||||||
|
t.Fatal("Known() error = nil, want error")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -25,6 +25,7 @@ type QueryRequest struct {
|
|||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"`
|
Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"`
|
||||||
Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"`
|
Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"`
|
||||||
|
Model string `protobuf:"bytes,3,opt,name=model,proto3" json:"model,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@ -73,11 +74,19 @@ func (x *QueryRequest) GetSource() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *QueryRequest) GetModel() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Model
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
type QueryResponse struct {
|
type QueryResponse struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
Reply string `protobuf:"bytes,1,opt,name=reply,proto3" json:"reply,omitempty"`
|
Reply string `protobuf:"bytes,1,opt,name=reply,proto3" json:"reply,omitempty"`
|
||||||
Intent string `protobuf:"bytes,2,opt,name=intent,proto3" json:"intent,omitempty"`
|
Intent string `protobuf:"bytes,2,opt,name=intent,proto3" json:"intent,omitempty"`
|
||||||
ActionTaken bool `protobuf:"varint,3,opt,name=action_taken,json=actionTaken,proto3" json:"action_taken,omitempty"`
|
ActionTaken bool `protobuf:"varint,3,opt,name=action_taken,json=actionTaken,proto3" json:"action_taken,omitempty"`
|
||||||
|
ModelUsed string `protobuf:"bytes,4,opt,name=model_used,json=modelUsed,proto3" json:"model_used,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@ -133,20 +142,115 @@ func (x *QueryResponse) GetActionTaken() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *QueryResponse) GetModelUsed() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.ModelUsed
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListModelsRequest struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListModelsRequest) Reset() {
|
||||||
|
*x = ListModelsRequest{}
|
||||||
|
mi := &file_ai_v1_ai_proto_msgTypes[2]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListModelsRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ListModelsRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ListModelsRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_ai_v1_ai_proto_msgTypes[2]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ListModelsRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ListModelsRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_ai_v1_ai_proto_rawDescGZIP(), []int{2}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListModelsResponse struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Names []string `protobuf:"bytes,1,rep,name=names,proto3" json:"names,omitempty"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListModelsResponse) Reset() {
|
||||||
|
*x = ListModelsResponse{}
|
||||||
|
mi := &file_ai_v1_ai_proto_msgTypes[3]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListModelsResponse) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ListModelsResponse) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ListModelsResponse) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_ai_v1_ai_proto_msgTypes[3]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ListModelsResponse.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ListModelsResponse) Descriptor() ([]byte, []int) {
|
||||||
|
return file_ai_v1_ai_proto_rawDescGZIP(), []int{3}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ListModelsResponse) GetNames() []string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Names
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var File_ai_v1_ai_proto protoreflect.FileDescriptor
|
var File_ai_v1_ai_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
const file_ai_v1_ai_proto_rawDesc = "" +
|
const file_ai_v1_ai_proto_rawDesc = "" +
|
||||||
"\n" +
|
"\n" +
|
||||||
"\x0eai/v1/ai.proto\x12\x05ai.v1\":\n" +
|
"\x0eai/v1/ai.proto\x12\x05ai.v1\"P\n" +
|
||||||
"\fQueryRequest\x12\x12\n" +
|
"\fQueryRequest\x12\x12\n" +
|
||||||
"\x04text\x18\x01 \x01(\tR\x04text\x12\x16\n" +
|
"\x04text\x18\x01 \x01(\tR\x04text\x12\x16\n" +
|
||||||
"\x06source\x18\x02 \x01(\tR\x06source\"`\n" +
|
"\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n" +
|
||||||
|
"\x05model\x18\x03 \x01(\tR\x05model\"\x7f\n" +
|
||||||
"\rQueryResponse\x12\x14\n" +
|
"\rQueryResponse\x12\x14\n" +
|
||||||
"\x05reply\x18\x01 \x01(\tR\x05reply\x12\x16\n" +
|
"\x05reply\x18\x01 \x01(\tR\x05reply\x12\x16\n" +
|
||||||
"\x06intent\x18\x02 \x01(\tR\x06intent\x12!\n" +
|
"\x06intent\x18\x02 \x01(\tR\x06intent\x12!\n" +
|
||||||
"\faction_taken\x18\x03 \x01(\bR\vactionTaken2?\n" +
|
"\faction_taken\x18\x03 \x01(\bR\vactionTaken\x12\x1d\n" +
|
||||||
|
"\n" +
|
||||||
|
"model_used\x18\x04 \x01(\tR\tmodelUsed\"\x13\n" +
|
||||||
|
"\x11ListModelsRequest\"*\n" +
|
||||||
|
"\x12ListModelsResponse\x12\x14\n" +
|
||||||
|
"\x05names\x18\x01 \x03(\tR\x05names2\x82\x01\n" +
|
||||||
"\tAIService\x122\n" +
|
"\tAIService\x122\n" +
|
||||||
"\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3"
|
"\x05Query\x12\x13.ai.v1.QueryRequest\x1a\x14.ai.v1.QueryResponse\x12A\n" +
|
||||||
|
"\n" +
|
||||||
|
"ListModels\x12\x18.ai.v1.ListModelsRequest\x1a\x19.ai.v1.ListModelsResponseB4Z2gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1b\x06proto3"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
file_ai_v1_ai_proto_rawDescOnce sync.Once
|
file_ai_v1_ai_proto_rawDescOnce sync.Once
|
||||||
@ -160,16 +264,20 @@ func file_ai_v1_ai_proto_rawDescGZIP() []byte {
|
|||||||
return file_ai_v1_ai_proto_rawDescData
|
return file_ai_v1_ai_proto_rawDescData
|
||||||
}
|
}
|
||||||
|
|
||||||
var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
var file_ai_v1_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
|
||||||
var file_ai_v1_ai_proto_goTypes = []any{
|
var file_ai_v1_ai_proto_goTypes = []any{
|
||||||
(*QueryRequest)(nil), // 0: ai.v1.QueryRequest
|
(*QueryRequest)(nil), // 0: ai.v1.QueryRequest
|
||||||
(*QueryResponse)(nil), // 1: ai.v1.QueryResponse
|
(*QueryResponse)(nil), // 1: ai.v1.QueryResponse
|
||||||
|
(*ListModelsRequest)(nil), // 2: ai.v1.ListModelsRequest
|
||||||
|
(*ListModelsResponse)(nil), // 3: ai.v1.ListModelsResponse
|
||||||
}
|
}
|
||||||
var file_ai_v1_ai_proto_depIdxs = []int32{
|
var file_ai_v1_ai_proto_depIdxs = []int32{
|
||||||
0, // 0: ai.v1.AIService.Query:input_type -> ai.v1.QueryRequest
|
0, // 0: ai.v1.AIService.Query:input_type -> ai.v1.QueryRequest
|
||||||
1, // 1: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse
|
2, // 1: ai.v1.AIService.ListModels:input_type -> ai.v1.ListModelsRequest
|
||||||
1, // [1:2] is the sub-list for method output_type
|
1, // 2: ai.v1.AIService.Query:output_type -> ai.v1.QueryResponse
|
||||||
0, // [0:1] is the sub-list for method input_type
|
3, // 3: ai.v1.AIService.ListModels:output_type -> ai.v1.ListModelsResponse
|
||||||
|
2, // [2:4] is the sub-list for method output_type
|
||||||
|
0, // [0:2] is the sub-list for method input_type
|
||||||
0, // [0:0] is the sub-list for extension type_name
|
0, // [0:0] is the sub-list for extension type_name
|
||||||
0, // [0:0] is the sub-list for extension extendee
|
0, // [0:0] is the sub-list for extension extendee
|
||||||
0, // [0:0] is the sub-list for field type_name
|
0, // [0:0] is the sub-list for field type_name
|
||||||
@ -186,7 +294,7 @@ func file_ai_v1_ai_proto_init() {
|
|||||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_v1_ai_proto_rawDesc), len(file_ai_v1_ai_proto_rawDesc)),
|
RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_v1_ai_proto_rawDesc), len(file_ai_v1_ai_proto_rawDesc)),
|
||||||
NumEnums: 0,
|
NumEnums: 0,
|
||||||
NumMessages: 2,
|
NumMessages: 4,
|
||||||
NumExtensions: 0,
|
NumExtensions: 0,
|
||||||
NumServices: 1,
|
NumServices: 1,
|
||||||
},
|
},
|
||||||
|
|||||||
@ -20,6 +20,7 @@ const _ = grpc.SupportPackageIsVersion9
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
AIService_Query_FullMethodName = "/ai.v1.AIService/Query"
|
AIService_Query_FullMethodName = "/ai.v1.AIService/Query"
|
||||||
|
AIService_ListModels_FullMethodName = "/ai.v1.AIService/ListModels"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AIServiceClient is the client API for AIService service.
|
// AIServiceClient is the client API for AIService service.
|
||||||
@ -27,6 +28,7 @@ const (
|
|||||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||||
type AIServiceClient interface {
|
type AIServiceClient interface {
|
||||||
Query(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*QueryResponse, error)
|
Query(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*QueryResponse, error)
|
||||||
|
ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type aIServiceClient struct {
|
type aIServiceClient struct {
|
||||||
@ -47,11 +49,22 @@ func (c *aIServiceClient) Query(ctx context.Context, in *QueryRequest, opts ...g
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *aIServiceClient) ListModels(ctx context.Context, in *ListModelsRequest, opts ...grpc.CallOption) (*ListModelsResponse, error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
out := new(ListModelsResponse)
|
||||||
|
err := c.cc.Invoke(ctx, AIService_ListModels_FullMethodName, in, out, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// AIServiceServer is the server API for AIService service.
|
// AIServiceServer is the server API for AIService service.
|
||||||
// All implementations must embed UnimplementedAIServiceServer
|
// All implementations must embed UnimplementedAIServiceServer
|
||||||
// for forward compatibility.
|
// for forward compatibility.
|
||||||
type AIServiceServer interface {
|
type AIServiceServer interface {
|
||||||
Query(context.Context, *QueryRequest) (*QueryResponse, error)
|
Query(context.Context, *QueryRequest) (*QueryResponse, error)
|
||||||
|
ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error)
|
||||||
mustEmbedUnimplementedAIServiceServer()
|
mustEmbedUnimplementedAIServiceServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,6 +78,9 @@ type UnimplementedAIServiceServer struct{}
|
|||||||
func (UnimplementedAIServiceServer) Query(context.Context, *QueryRequest) (*QueryResponse, error) {
|
func (UnimplementedAIServiceServer) Query(context.Context, *QueryRequest) (*QueryResponse, error) {
|
||||||
return nil, status.Error(codes.Unimplemented, "method Query not implemented")
|
return nil, status.Error(codes.Unimplemented, "method Query not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedAIServiceServer) ListModels(context.Context, *ListModelsRequest) (*ListModelsResponse, error) {
|
||||||
|
return nil, status.Error(codes.Unimplemented, "method ListModels not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {}
|
func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {}
|
||||||
func (UnimplementedAIServiceServer) testEmbeddedByValue() {}
|
func (UnimplementedAIServiceServer) testEmbeddedByValue() {}
|
||||||
|
|
||||||
@ -104,6 +120,24 @@ func _AIService_Query_Handler(srv interface{}, ctx context.Context, dec func(int
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _AIService_ListModels_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(ListModelsRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(AIServiceServer).ListModels(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: AIService_ListModels_FullMethodName,
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(AIServiceServer).ListModels(ctx, req.(*ListModelsRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
// AIService_ServiceDesc is the grpc.ServiceDesc for AIService service.
|
// AIService_ServiceDesc is the grpc.ServiceDesc for AIService service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
@ -115,6 +149,10 @@ var AIService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "Query",
|
MethodName: "Query",
|
||||||
Handler: _AIService_Query_Handler,
|
Handler: _AIService_Query_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "ListModels",
|
||||||
|
Handler: _AIService_ListModels_Handler,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{},
|
Streams: []grpc.StreamDesc{},
|
||||||
Metadata: "ai/v1/ai.proto",
|
Metadata: "ai/v1/ai.proto",
|
||||||
|
|||||||
@ -6,15 +6,24 @@ option go_package = "gitea.nik4nao.com/nik/home-services/gen/ai/v1;aiv1";
|
|||||||
|
|
||||||
service AIService {
|
service AIService {
|
||||||
rpc Query(QueryRequest) returns (QueryResponse);
|
rpc Query(QueryRequest) returns (QueryResponse);
|
||||||
|
rpc ListModels(ListModelsRequest) returns (ListModelsResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
message QueryRequest {
|
message QueryRequest {
|
||||||
string text = 1;
|
string text = 1;
|
||||||
string source = 2;
|
string source = 2;
|
||||||
|
string model = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message QueryResponse {
|
message QueryResponse {
|
||||||
string reply = 1;
|
string reply = 1;
|
||||||
string intent = 2;
|
string intent = 2;
|
||||||
bool action_taken = 3;
|
bool action_taken = 3;
|
||||||
|
string model_used = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListModelsRequest {}
|
||||||
|
|
||||||
|
message ListModelsResponse {
|
||||||
|
repeated string names = 1;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user