- Updated the LLMClient interface to support model-specific generation and model listing.
- Integrated the model store and validator into the command application for managing AI models.
- Implemented Discord commands for setting, getting, and listing the active AI model.
- Enhanced AI query handling to use the selected model and to return model information in responses.
- Added a caching mechanism for model validation to improve performance.
- Introduced gRPC methods for listing available AI models in the ai-gateway.
- Updated protobuf definitions to include model-related fields and messages.
- Added tests for the model store and validator functionality.
82 lines
2.2 KiB
Go
82 lines
2.2 KiB
Go
package ollama
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
func TestIsThinkingModel(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
want bool
|
|
}{
|
|
{name: "qwen3", want: true},
|
|
{name: "qwen3:4b", want: true},
|
|
{name: "qwen3:latest", want: true},
|
|
{name: "llama3", want: false},
|
|
{name: "mistral", want: false},
|
|
}
|
|
for _, tt := range tests {
|
|
if got := isThinkingModel(tt.name); got != tt.want {
|
|
t.Fatalf("isThinkingModel(%q) = %v, want %v", tt.name, got, tt.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGenerateSetsThinkForQwen3(t *testing.T) {
|
|
var body map[string]any
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
t.Fatalf("Decode() error = %v", err)
|
|
}
|
|
_, _ = w.Write([]byte(`{"response":"ok"}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client := New(srv.URL, srv.Client())
|
|
if _, err := client.Generate(context.Background(), "qwen3:latest", "prompt"); err != nil {
|
|
t.Fatalf("Generate() error = %v", err)
|
|
}
|
|
if body["think"] != false {
|
|
t.Fatalf("think = %#v, want false", body["think"])
|
|
}
|
|
}
|
|
|
|
func TestGenerateOmitsThinkForLlama3(t *testing.T) {
|
|
var body map[string]any
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
t.Fatalf("Decode() error = %v", err)
|
|
}
|
|
_, _ = w.Write([]byte(`{"response":"ok"}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client := New(srv.URL, srv.Client())
|
|
if _, err := client.Generate(context.Background(), "llama3:latest", "prompt"); err != nil {
|
|
t.Fatalf("Generate() error = %v", err)
|
|
}
|
|
if _, ok := body["think"]; ok {
|
|
t.Fatalf("unexpected think key = %#v", body["think"])
|
|
}
|
|
}
|
|
|
|
func TestListModelsReturnsNamesOnly(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{"models":[{"name":"llama3:latest"},{"name":"qwen3:latest"}]}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
client := New(srv.URL, srv.Client())
|
|
got, err := client.ListModels(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("ListModels() error = %v", err)
|
|
}
|
|
if len(got) != 2 || got[0] != "llama3:latest" || got[1] != "qwen3:latest" {
|
|
t.Fatalf("ListModels() = %#v", got)
|
|
}
|
|
}
|