Nik Afiq ad50d641bd
All checks were successful
CI / test (push) Successful in 5s
CI / build-ai-gateway (push) Successful in 43s
CI / build-ha-gateway (push) Successful in 47s
CI / build-discord-bot (push) Successful in 41s
feat: enhance AI model management in Discord bot
- Updated LLMClient interface to support model-specific generation and model listing.
- Integrated model store and validator into the command application for managing AI models.
- Implemented commands for setting, getting, and listing active AI models in Discord.
- Enhanced AI query handling to utilize the selected model and return model information in responses.
- Added caching mechanism for model validation to improve performance.
- Introduced gRPC methods for listing available AI models in the ai-gateway.
- Updated protobuf definitions to include model-related fields and messages.
- Added tests for model store and validator functionalities.
2026-04-21 22:52:00 +09:00

87 lines
2.0 KiB
Go

package modelvalidator
import (
"context"
"fmt"
"strings"
"sync"
"time"
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
)
// Validator resolves user-supplied model names against the model list
// reported by the AI gateway. The most recently fetched list is kept for up
// to ttl so that repeated lookups do not each have to call the gateway.
type Validator struct {
// client is the gateway whose ListModels method supplies the model list.
client driven.AIGateway
// ttl is how long a fetched model list is reused before it is refetched.
ttl time.Duration
// mu is held by Known while it reads or replaces cache and cachedAt.
mu sync.Mutex
// cache is the most recently fetched model list.
cache []string
// cachedAt is the time at which cache was last filled.
cachedAt time.Time
}
// New builds a Validator that looks models up through client and reuses a
// fetched model list for at most ttl. The cache starts out empty, so the
// first call to Known always queries the gateway.
func New(client driven.AIGateway, ttl time.Duration) *Validator {
	v := new(Validator)
	v.client = client
	v.ttl = ttl
	return v
}
// Known reports the models currently available from the gateway.
//
// A previously fetched, non-empty list is reused while it is younger than the
// validator's TTL; otherwise the gateway is queried again and the result
// replaces the cache. An empty list is never served from the cache. Callers
// always receive their own copy, so mutating the returned slice cannot
// corrupt the cached list. A gateway error is returned unchanged and leaves
// any previously cached list in place.
//
// The mutex is held for the whole call, including the gateway request, so
// concurrent callers are serialized.
func (v *Validator) Known(ctx context.Context) ([]string, error) {
	// snapshot detaches a list from its backing array.
	snapshot := func(list []string) []string {
		return append([]string(nil), list...)
	}

	v.mu.Lock()
	defer v.mu.Unlock()

	stale := len(v.cache) == 0 || time.Since(v.cachedAt) >= v.ttl
	if !stale {
		return snapshot(v.cache), nil
	}

	fetched, err := v.client.ListModels(ctx)
	if err != nil {
		return nil, err
	}

	v.cache, v.cachedAt = snapshot(fetched), time.Now()
	return snapshot(v.cache), nil
}
// Normalize resolves a user-provided name to a canonical installed model name.
//
// The name is compared against the list returned by Known, trying
// progressively looser rules and returning the first hit:
//
//  1. exact match on name;
//  2. exact match on name + ":latest";
//  3. case-insensitive match on name;
//  4. case-insensitive match on name + ":latest";
//  5. case-insensitive match on the part before the tag (name + ":"),
//     accepted only when exactly one model qualifies.
//
// The returned string is spelled exactly as it appears in the model list.
// An error from Known is returned unchanged. Otherwise the error is
// "ambiguous model name" when rule 5 finds more than one candidate and
// "unknown model" when nothing matches.
func (v *Validator) Normalize(ctx context.Context, name string) (string, error) {
	models, err := v.Known(ctx)
	if err != nil {
		return "", err
	}

	latest := name + ":latest"

	// Exact spellings are tried before any case folding so that models whose
	// names differ only by case stay individually addressable.
	for _, want := range [...]string{name, latest} {
		for _, model := range models {
			if model == want {
				return model, nil
			}
		}
	}

	lower := strings.ToLower(name)
	for _, want := range [...]string{lower, strings.ToLower(latest)} {
		for _, model := range models {
			if strings.ToLower(model) == want {
				return model, nil
			}
		}
	}

	// Last resort: treat name as a bare model family and look for tagged
	// variants of it ("phi3" -> "phi3:mini").
	prefix := lower + ":"
	var matches []string
	for _, model := range models {
		if strings.HasPrefix(strings.ToLower(model), prefix) {
			matches = append(matches, model)
		}
	}

	switch len(matches) {
	case 0:
		return "", fmt.Errorf("unknown model")
	case 1:
		// A single tagged variant is unambiguous, so resolve to it rather
		// than reporting the model as unknown.
		return matches[0], nil
	default:
		return "", fmt.Errorf("ambiguous model name")
	}
}