108 lines
2.8 KiB
Go
108 lines
2.8 KiB
Go
package aigateway
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"log/slog"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"gitea.nik4nao.com/nik/home-services/discord-bot/internal/logger"
|
|
aiv1 "gitea.nik4nao.com/nik/home-services/gen/ai/v1"
|
|
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
)
|
|
|
|
// Client implements the app's AI driven port over gRPC.
//
// Construct it with New and release its connection with Close.
type Client struct {
	// conn is the gRPC client connection created by New and closed by Close.
	conn *grpc.ClientConn
	// client is the generated AIService stub bound to conn.
	client aiv1.AIServiceClient
	// log is the logger passed to New. NOTE(review): no method in this file
	// reads it — Query logs through logger.FromContext(ctx) instead; confirm
	// whether other files in the package use it.
	log *slog.Logger
}
|
|
|
|
// New constructs a gRPC client for the internal ai-gateway service.
|
|
func New(ctx context.Context, addr, tlsDir string, log *slog.Logger) (*Client, error) {
|
|
transportCreds := insecure.NewCredentials()
|
|
if tlsDir != "" {
|
|
creds, err := loadTransportCredentials(tlsDir)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load mTLS credentials: %w", err)
|
|
}
|
|
transportCreds = creds
|
|
}
|
|
|
|
conn, err := grpc.NewClient(
|
|
addr,
|
|
grpc.WithTransportCredentials(transportCreds),
|
|
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dial ai-gateway: %w", err)
|
|
}
|
|
|
|
return &Client{
|
|
conn: conn,
|
|
client: aiv1.NewAIServiceClient(conn),
|
|
log: log,
|
|
}, nil
|
|
}
|
|
|
|
// Close closes the underlying gRPC connection.
|
|
func (c *Client) Close() error {
|
|
if err := c.conn.Close(); err != nil {
|
|
return fmt.Errorf("close ai-gateway client: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Query forwards one free-form request to ai-gateway.
|
|
func (c *Client) Query(ctx context.Context, text string) (string, error) {
|
|
start := time.Now()
|
|
log := logger.FromContext(ctx).With("grpc.method", "AIService/Query")
|
|
resp, err := c.client.Query(ctx, &aiv1.QueryRequest{
|
|
Text: text,
|
|
Source: "discord-bot",
|
|
})
|
|
if err != nil {
|
|
log.Error("grpc call failed",
|
|
"duration_ms", time.Since(start).Milliseconds(),
|
|
"error", err.Error(),
|
|
)
|
|
return "", fmt.Errorf("query ai-gateway: %w", err)
|
|
}
|
|
log.Debug("grpc call completed", "duration_ms", time.Since(start).Milliseconds())
|
|
return resp.GetReply(), nil
|
|
}
|
|
|
|
func loadTransportCredentials(tlsDir string) (credentials.TransportCredentials, error) {
|
|
cert, err := tls.LoadX509KeyPair(
|
|
filepath.Join(tlsDir, "tls.crt"),
|
|
filepath.Join(tlsDir, "tls.key"),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load client key pair: %w", err)
|
|
}
|
|
|
|
caPEM, err := os.ReadFile(filepath.Join(tlsDir, "ca.crt"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read server CA: %w", err)
|
|
}
|
|
|
|
rootCAs := x509.NewCertPool()
|
|
if !rootCAs.AppendCertsFromPEM(caPEM) {
|
|
return nil, fmt.Errorf("append server CA: invalid PEM")
|
|
}
|
|
|
|
return credentials.NewTLS(&tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
RootCAs: rootCAs,
|
|
ServerName: "ai-gateway.home-services.svc.cluster.local",
|
|
MinVersion: tls.VersionTLS13,
|
|
}), nil
|
|
}
|