- Updated LLMClient interface to support model-specific generation and model listing. - Integrated model store and validator into the command application for managing AI models. - Implemented commands for setting, getting, and listing active AI models in Discord. - Enhanced AI query handling to utilize the selected model and return model information in responses. - Added caching mechanism for model validation to improve performance. - Introduced gRPC methods for listing available AI models in the ai-gateway. - Updated protobuf definitions to include model-related fields and messages. - Added tests for model store and validator functionalities.
87 lines
2.0 KiB
Go
87 lines
2.0 KiB
Go
package modelvalidator
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/core/ports/driven"
|
|
)
|
|
|
|
// Validator caches the model list briefly and normalizes friendly names.
//
// It resolves user-supplied model names against the list of models reported
// by the AI gateway, reusing that list for ttl so repeated lookups do not hit
// the gateway every time. Validator contains a sync.Mutex and must not be
// copied after first use; construct one with New and share the pointer.
type Validator struct {
	// client is asked (via ListModels) for the installed model names.
	client driven.AIGateway
	// ttl is how long a fetched model list is reused before it is refreshed.
	ttl time.Duration

	// mu guards cache and cachedAt.
	mu sync.Mutex
	// cache holds a private copy of the most recent list fetched from client.
	cache []string
	// cachedAt records when cache was last populated.
	cachedAt time.Time
}
|
|
|
|
// New constructs a model validator with a TTL cache.
|
|
func New(client driven.AIGateway, ttl time.Duration) *Validator {
|
|
return &Validator{client: client, ttl: ttl}
|
|
}
|
|
|
|
// Known returns the cached model list, refreshing when stale.
|
|
func (v *Validator) Known(ctx context.Context) ([]string, error) {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
if len(v.cache) > 0 && time.Since(v.cachedAt) < v.ttl {
|
|
return append([]string(nil), v.cache...), nil
|
|
}
|
|
models, err := v.client.ListModels(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
v.cache = append([]string(nil), models...)
|
|
v.cachedAt = time.Now()
|
|
return append([]string(nil), v.cache...), nil
|
|
}
|
|
|
|
// Normalize resolves a user-provided name to a canonical installed model name.
|
|
func (v *Validator) Normalize(ctx context.Context, name string) (string, error) {
|
|
models, err := v.Known(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
for _, model := range models {
|
|
if model == name {
|
|
return model, nil
|
|
}
|
|
}
|
|
latest := name + ":latest"
|
|
for _, model := range models {
|
|
if model == latest {
|
|
return model, nil
|
|
}
|
|
}
|
|
|
|
lower := strings.ToLower(name)
|
|
for _, model := range models {
|
|
if strings.ToLower(model) == lower {
|
|
return model, nil
|
|
}
|
|
}
|
|
lowerLatest := strings.ToLower(latest)
|
|
for _, model := range models {
|
|
if strings.ToLower(model) == lowerLatest {
|
|
return model, nil
|
|
}
|
|
}
|
|
|
|
matches := make([]string, 0, 2)
|
|
prefix := lower + ":"
|
|
for _, model := range models {
|
|
if strings.HasPrefix(strings.ToLower(model), prefix) {
|
|
matches = append(matches, model)
|
|
}
|
|
}
|
|
if len(matches) > 1 {
|
|
return "", fmt.Errorf("ambiguous model name")
|
|
}
|
|
return "", fmt.Errorf("unknown model")
|
|
}
|