- Updated LLMClient interface to support model-specific generation and model listing. - Integrated model store and validator into the command application for managing AI models. - Implemented commands for setting, getting, and listing active AI models in Discord. - Enhanced AI query handling to utilize the selected model and return model information in responses. - Added caching mechanism for model validation to improve performance. - Introduced gRPC methods for listing available AI models in the ai-gateway. - Updated protobuf definitions to include model-related fields and messages. - Added tests for model store and validator functionalities.
161 lines
4.6 KiB
Go
161 lines
4.6 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"syscall"
|
|
|
|
"github.com/joho/godotenv"
|
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/health"
|
|
grpc_health_v1 "google.golang.org/grpc/health/grpc_health_v1"
|
|
"google.golang.org/grpc/reflection"
|
|
|
|
aigrpc "gitea.nik4nao.com/nik/home-services/ai-gateway/internal/adapters/primary/grpc"
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/adapters/secondary/hagateway"
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/adapters/secondary/ollama"
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/app"
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/config"
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/core/domain"
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/logger"
|
|
"gitea.nik4nao.com/nik/home-services/ai-gateway/internal/telemetry"
|
|
aiv1 "gitea.nik4nao.com/nik/home-services/gen/ai/v1"
|
|
)
|
|
|
|
// version identifies the running build. It defaults to "dev" and is set at
// build time via -ldflags "-X main.version=<tag>". main logs it at startup and
// passes it to telemetry.Setup.
|
|
var version = "dev"
|
|
|
|
func main() {
|
|
_ = godotenv.Load()
|
|
|
|
cfg, err := config.Load()
|
|
if err != nil {
|
|
os.Stderr.WriteString("config error: " + err.Error() + "\n")
|
|
os.Exit(1)
|
|
}
|
|
|
|
log := logger.New(cfg.LogFormat, cfg.LogLevel)
|
|
slog.SetDefault(log)
|
|
log.Info("starting ai-gateway",
|
|
"version", version,
|
|
"grpc_port", cfg.GRPCPort,
|
|
"ollama_url", cfg.OllamaURL,
|
|
"ollama_model", cfg.OllamaModel,
|
|
"ha_gateway_addr", cfg.HAGatewayAddr,
|
|
"tls_dir", cfg.TLSDir,
|
|
"otel_endpoint", cfg.OTELEndpoint,
|
|
"light_cache_ttl", cfg.LightCacheTTL.String(),
|
|
)
|
|
|
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
|
defer stop()
|
|
ctx = logger.WithLogger(ctx, log)
|
|
|
|
shutdown, err := telemetry.Setup(ctx, "ai-gateway", version, cfg)
|
|
if err != nil {
|
|
log.Error("telemetry setup failed", "err", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
ollamaClient := ollama.New(cfg.OllamaURL, &http.Client{Timeout: cfg.OllamaTimeout})
|
|
haClient, err := hagateway.New(ctx, cfg.HAGatewayAddr, cfg.TLSDir, cfg.HAGatewayServerName, log)
|
|
if err != nil {
|
|
log.Error("ha-gateway client setup failed", "err", err)
|
|
os.Exit(1)
|
|
}
|
|
defer func() {
|
|
if err := haClient.Close(); err != nil {
|
|
log.Error("ha-gateway client close failed", "err", err)
|
|
}
|
|
}()
|
|
|
|
lightCache := domain.NewLightCache(cfg.LightCacheTTL, haClient.ListLights)
|
|
queryApp := app.NewQueryApp(ollamaClient, haClient, lightCache, cfg.OllamaModel, log)
|
|
|
|
serverOpts := []grpc.ServerOption{
|
|
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
|
grpc.ChainUnaryInterceptor(aigrpc.LoggingUnaryInterceptor(log)),
|
|
}
|
|
if cfg.TLSDir != "" {
|
|
creds, err := loadServerCredentials(cfg.TLSDir)
|
|
if err != nil {
|
|
log.Error("load mTLS credentials failed", "tls_dir", cfg.TLSDir, "err", err)
|
|
os.Exit(1)
|
|
}
|
|
serverOpts = append(serverOpts, grpc.Creds(creds))
|
|
log.Info("mTLS enabled", "tls_dir", cfg.TLSDir)
|
|
} else {
|
|
log.Info("mTLS disabled")
|
|
}
|
|
|
|
srv := grpc.NewServer(serverOpts...)
|
|
healthSrv := health.NewServer()
|
|
healthSrv.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING)
|
|
|
|
aiv1.RegisterAIServiceServer(srv, aigrpc.NewServer(queryApp))
|
|
grpc_health_v1.RegisterHealthServer(srv, healthSrv)
|
|
if cfg.LogLevel == "debug" {
|
|
reflection.Register(srv)
|
|
}
|
|
|
|
lis, err := net.Listen("tcp", ":"+cfg.GRPCPort)
|
|
if err != nil {
|
|
log.Error("listen failed", "err", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
go func() {
|
|
log.Info("ai-gateway listening", "addr", lis.Addr().String())
|
|
if err := srv.Serve(lis); err != nil {
|
|
log.Error("serve failed", "err", err)
|
|
}
|
|
}()
|
|
|
|
<-ctx.Done()
|
|
log.Info("shutdown signal received, draining")
|
|
healthSrv.SetServingStatus("", grpc_health_v1.HealthCheckResponse_NOT_SERVING)
|
|
srv.GracefulStop()
|
|
log.Info("shutdown complete")
|
|
|
|
if err := shutdown(context.Background()); err != nil {
|
|
log.Error("telemetry shutdown error", "err", err)
|
|
}
|
|
}
|
|
|
|
func loadServerCredentials(tlsDir string) (credentials.TransportCredentials, error) {
|
|
cert, err := tls.LoadX509KeyPair(
|
|
filepath.Join(tlsDir, "tls.crt"),
|
|
filepath.Join(tlsDir, "tls.key"),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load server key pair: %w", err)
|
|
}
|
|
|
|
caPEM, err := os.ReadFile(filepath.Join(tlsDir, "ca.crt"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read client CA: %w", err)
|
|
}
|
|
|
|
clientCAs := x509.NewCertPool()
|
|
if !clientCAs.AppendCertsFromPEM(caPEM) {
|
|
return nil, fmt.Errorf("append client CA: invalid PEM")
|
|
}
|
|
|
|
return credentials.NewTLS(&tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
ClientCAs: clientCAs,
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
MinVersion: tls.VersionTLS13,
|
|
}), nil
|
|
}
|