- Updated the LLMClient interface to support model-specific generation and model listing.
- Integrated the model store and validator into the command application for managing AI models.
- Implemented Discord commands for setting, getting, and listing the active AI model.
- Enhanced AI query handling to use the selected model and return model information in responses.
- Added a caching mechanism for model validation to improve performance.
- Introduced gRPC methods for listing available AI models in the ai-gateway.
- Updated protobuf definitions to include model-related fields and messages.
- Added tests for the model store and validator.
224 lines
6.4 KiB
Go
224 lines
6.4 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/core/domain"
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/core/ports/driven"
|
|
)
|
|
|
|
// QueryResult is the app-layer response mapped onto the gRPC API.
type QueryResult struct {
	// Reply is the user-facing message. Intent names the intent the request
	// was handled as; it may be empty when the request failed before an
	// intent was obtained.
	Reply, Intent string
	// ActionTaken reports whether a Home Assistant action was carried out.
	ActionTaken bool
	// ModelUsed names the LLM model the request was resolved with.
	ModelUsed string
}
|
|
|
|
// QueryApp orchestrates one AI query request.
|
|
type QueryApp struct {
|
|
llm driven.LLMClient
|
|
ha driven.HAClient
|
|
cache *domain.LightCache
|
|
defaultModel string
|
|
log *slog.Logger
|
|
}
|
|
|
|
// NewQueryApp constructs the AI query application service.
//
// llm produces intents and lists models, ha executes light actions, cache
// supplies the known lights, and defaultModel is used by Query whenever the
// caller does not name a model. A nil log falls back to slog.Default() so the
// logging calls on Query's failure paths never dereference a nil logger.
func NewQueryApp(llm driven.LLMClient, ha driven.HAClient, cache *domain.LightCache, defaultModel string, log *slog.Logger) *QueryApp {
	if log == nil {
		log = slog.Default()
	}
	return &QueryApp{llm: llm, ha: ha, cache: cache, defaultModel: defaultModel, log: log}
}
|
|
|
|
// Query runs the full intent-parsing and dispatch flow for one user request.
//
// An empty model selects the default model configured in NewQueryApp; every
// returned QueryResult reports the model that was used in ModelUsed.
//
// Problems the user should simply be told about produce a friendly Reply with
// a nil error. These are:
//   - the light cache cannot be refreshed;
//   - the LLM output is not valid JSON;
//   - the target light cannot be resolved;
//   - its parameters are invalid;
//   - a Home Assistant call fails.
//
// Only a failure of the LLM Generate call is returned as an error. It is
// wrapped with the model name, and the cause stays reachable through
// errors.Is and errors.As.
func (a *QueryApp) Query(ctx context.Context, text, model string) (QueryResult, error) {
	if model == "" {
		model = a.defaultModel
	}

	lights, err := a.cache.Get(ctx)
	if err != nil {
		a.log.Error("light cache refresh failed", "err", err)
		return QueryResult{
			Reply:     "I couldn't reach Home Assistant right now.",
			ModelUsed: model,
		}, nil
	}

	prompt := domain.BuildPrompt(text, promptLightLines(lights))
	raw, err := a.llm.Generate(ctx, model, prompt)
	if err != nil {
		return QueryResult{}, fmt.Errorf("generating with model %q: %w", model, err)
	}

	var intent domain.Intent
	if err := json.Unmarshal([]byte(raw), &intent); err != nil {
		// Include the decode error so malformed model output can be diagnosed.
		a.log.Warn("llm returned invalid json", "text", text, "raw_output", raw, "err", err)
		return QueryResult{
			Reply:     "I didn't understand that.",
			Intent:    domain.IntentNone,
			ModelUsed: model,
		}, nil
	}

	return a.dispatchIntent(ctx, intent, lights, model), nil
}

// dispatchIntent routes a successfully decoded intent to its handler and
// returns the user-facing result.
func (a *QueryApp) dispatchIntent(ctx context.Context, intent domain.Intent, lights []driven.Light, model string) QueryResult {
	switch intent.Name {
	case domain.IntentTurnOnLight:
		return a.handleTurnOnLight(ctx, intent, lights, model)
	case domain.IntentTurnOffLight:
		return a.handleTurnOffLight(ctx, intent, lights, model)
	case domain.IntentListLights:
		return QueryResult{
			Reply:     formatLightListReply(lights),
			Intent:    intent.Name,
			ModelUsed: model,
		}
	default:
		// Covers domain.IntentNone as well as any intent name this service
		// does not handle.
		return QueryResult{
			Reply:     fallbackReply(intent.Reply, "I didn't understand that."),
			Intent:    intent.Name,
			ModelUsed: model,
		}
	}
}

// handleTurnOnLight resolves the target light, validates its parameters and
// asks Home Assistant to switch it on.
func (a *QueryApp) handleTurnOnLight(ctx context.Context, intent domain.Intent, lights []driven.Light, model string) QueryResult {
	result := QueryResult{Intent: intent.Name, ModelUsed: model}

	entityID, ok := resolveLightEntity(intent.Entity, lights)
	if !ok {
		result.Reply = "I couldn't find that light."
		return result
	}

	params, err := ParseLightParams(intent.Params)
	if err != nil {
		// Previously swallowed silently; log it so bad LLM params are visible.
		a.log.Warn("invalid light params", "entity_id", entityID, "params", intent.Params, "err", err)
		result.Reply = "I couldn't understand the light settings."
		return result
	}

	if err := a.ha.TurnOnLight(ctx, entityID, params); err != nil {
		a.log.Error("turn on light failed", "entity_id", entityID, "err", err)
		result.Reply = "I couldn't reach Home Assistant right now."
		return result
	}

	result.Reply = fallbackReply(intent.Reply, fmt.Sprintf("Turned on `%s`.", displayLightName(entityID, lights)))
	result.ActionTaken = true
	return result
}

// handleTurnOffLight resolves the target light and asks Home Assistant to
// switch it off.
func (a *QueryApp) handleTurnOffLight(ctx context.Context, intent domain.Intent, lights []driven.Light, model string) QueryResult {
	result := QueryResult{Intent: intent.Name, ModelUsed: model}

	entityID, ok := resolveLightEntity(intent.Entity, lights)
	if !ok {
		result.Reply = "I couldn't find that light."
		return result
	}

	if err := a.ha.TurnOffLight(ctx, entityID); err != nil {
		a.log.Error("turn off light failed", "entity_id", entityID, "err", err)
		result.Reply = "I couldn't reach Home Assistant right now."
		return result
	}

	result.Reply = fallbackReply(intent.Reply, fmt.Sprintf("Turned off `%s`.", displayLightName(entityID, lights)))
	result.ActionTaken = true
	return result
}
|
|
|
|
// ListModels reports the model names the configured LLM client says are
// installed, passing through its result and error unchanged.
func (a *QueryApp) ListModels(ctx context.Context) ([]string, error) {
	models, err := a.llm.ListModels(ctx)
	return models, err
}
|
|
|
|
// promptLightLines renders each light as a "- <label> (<entity_id>) state=<state>"
// line for the LLM prompt. The label is the friendly name, or the entity ID
// when no friendly name is set. The result has one line per light, in order.
func promptLightLines(lights []driven.Light) []string {
	out := make([]string, len(lights))
	for i := range lights {
		name := lights[i].EntityID
		if lights[i].FriendlyName != "" {
			name = lights[i].FriendlyName
		}
		out[i] = fmt.Sprintf("- %s (%s) state=%s", name, lights[i].EntityID, lights[i].State)
	}
	return out
}
|
|
|
|
// resolveLightEntity maps an entity reference produced by the LLM onto a known
// light's entity ID. Matching is case-insensitive and ignores whitespace
// surrounding value:
//
//  1. the first light whose EntityID or FriendlyName equals value wins;
//  2. otherwise the first light whose FriendlyName contains value is used.
//
// It reports false when value is blank or no light matches.
func resolveLightEntity(value string, lights []driven.Light) (string, bool) {
	query := strings.TrimSpace(strings.ToLower(value))
	if query == "" {
		return "", false
	}

	// An exact match anywhere in the list beats an earlier partial match, so
	// partial hits are only remembered, never returned, during the scan.
	partial := -1
	for i, l := range lights {
		friendly := strings.ToLower(l.FriendlyName)
		if strings.ToLower(l.EntityID) == query || friendly == query {
			return l.EntityID, true
		}
		if partial == -1 && strings.Contains(friendly, query) {
			partial = i
		}
	}

	if partial == -1 {
		return "", false
	}
	return lights[partial].EntityID, true
}
|
|
|
|
func displayLightName(entityID string, lights []driven.Light) string {
|
|
idx := slices.IndexFunc(lights, func(light driven.Light) bool { return light.EntityID == entityID })
|
|
if idx == -1 || lights[idx].FriendlyName == "" {
|
|
return entityID
|
|
}
|
|
return lights[idx].FriendlyName
|
|
}
|
|
|
|
// fallbackReply returns reply unchanged, including any surrounding whitespace,
// unless it is empty or whitespace-only. In that case it returns fallback.
func fallbackReply(reply, fallback string) string {
	if trimmed := strings.TrimSpace(reply); trimmed != "" {
		return reply
	}
	return fallback
}
|
|
|
|
func formatLightListReply(lights []driven.Light) string {
|
|
if len(lights) == 0 {
|
|
return "No lights found."
|
|
}
|
|
lines := make([]string, 0, len(lights)+1)
|
|
lines = append(lines, "Known lights:")
|
|
for _, light := range lights {
|
|
label := light.FriendlyName
|
|
if label == "" {
|
|
label = light.EntityID
|
|
}
|
|
lines = append(lines, fmt.Sprintf("- %s (%s) [%s]", label, light.EntityID, light.State))
|
|
}
|
|
return strings.Join(lines, "\n")
|
|
}
|
|
|
|
// ParseLightParams converts string params into the protobuf-compatible values the HA adapter expects.
|
|
func ParseLightParams(params map[string]string) (map[string]string, error) {
|
|
if len(params) == 0 {
|
|
return map[string]string{}, nil
|
|
}
|
|
normalized := make(map[string]string, len(params))
|
|
for key, value := range params {
|
|
switch key {
|
|
case "brightness", "brightness_pct", "color_temp", "color_temp_kelvin":
|
|
if _, err := strconv.ParseUint(value, 10, 32); err != nil {
|
|
return nil, fmt.Errorf("invalid %s value %q: %w", key, value, err)
|
|
}
|
|
}
|
|
normalized[key] = value
|
|
}
|
|
return normalized, nil
|
|
}
|