184 lines
4.7 KiB
Go

package ha
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"gitea.nik4nao.com/nik/home-services/ha-gateway/internal/config"
"gitea.nik4nao.com/nik/home-services/ha-gateway/internal/core/ports/driven"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
)
// tracer is the OpenTelemetry tracer that every span in this client is
// started from; spans carry the "ha-gateway/ha-client" instrumentation name.
var (
	tracer = otel.Tracer("ha-gateway/ha-client")
)
// Client is a Home Assistant REST API client. Requests go to paths under
// baseURL, authenticate with token as a bearer credential, and are sent
// through httpClient.
type Client struct {
	// baseURL is the HA origin (without trailing "/" when built by
	// NewClient); token is the bearer token sent on every request.
	baseURL, token string
	// httpClient is reused for every request made by this Client.
	httpClient *http.Client
}
// NewClient builds a Client from cfg.
//
// All trailing slashes are stripped from cfg.HABaseURL so request paths can be
// appended directly, cfg.HAToken becomes the bearer token, and the underlying
// http.Client is given a 10-second overall timeout.
func NewClient(cfg *config.Config) *Client {
	c := new(Client)
	c.baseURL = strings.TrimRight(cfg.HABaseURL, "/")
	c.token = cfg.HAToken
	c.httpClient = &http.Client{Timeout: 10 * time.Second}
	return c
}
// GetState fetches the current state of a single entity via
// GET /api/states/<entityID>.
//
// entityID must be non-empty and may not contain '/', '?', '#', '%' or "..",
// so the request always names one state resource rather than some other API
// path reachable with the client's bearer token. Every error, including a
// conversion failure, is recorded on the "ha.GetState" span before it is
// returned.
func (c *Client) GetState(ctx context.Context, entityID string) (*driven.HAState, error) {
	ctx, span := tracer.Start(ctx, "ha.GetState")
	defer span.End()
	span.SetAttributes(attribute.String("entity_id", entityID))

	// fail marks the span as errored so every error exit shows up in traces.
	fail := func(err error) (*driven.HAState, error) {
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		return nil, err
	}

	// entityID is spliced into the URL path verbatim; reject anything that
	// could escape the /api/states/ segment.
	if entityID == "" || strings.ContainsAny(entityID, "/?#%") || strings.Contains(entityID, "..") {
		return fail(fmt.Errorf("invalid entity id %q", entityID))
	}

	var raw haStateRaw
	if err := c.get(ctx, "/api/states/"+entityID, &raw); err != nil {
		return fail(err)
	}
	state, err := raw.toDriven()
	if err != nil {
		return fail(err)
	}
	return state, nil
}
// ListStates fetches every entity state from GET /api/states and converts each
// entry to a driven.HAState, keeping the order in which HA returned them.
// A fetch error is recorded on the "ha.ListStates" span; any conversion error
// aborts the whole call.
func (c *Client) ListStates(ctx context.Context) ([]*driven.HAState, error) {
	ctx, span := tracer.Start(ctx, "ha.ListStates")
	defer span.End()

	var payload []haStateRaw
	err := c.get(ctx, "/api/states", &payload)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		return nil, err
	}

	states := make([]*driven.HAState, len(payload))
	for idx := range payload {
		st, convErr := payload[idx].toDriven()
		if convErr != nil {
			return nil, convErr
		}
		states[idx] = st
	}
	return states, nil
}
// CallService invokes the Home Assistant service <domain>.<service> by POSTing
// payload as JSON to /api/services/<domain>/<service>.
//
// domain and service must be non-empty and may not contain '/', '?', '#', '%'
// or "..", so they cannot redirect the request to another API path.
//
// A non-2xx response is returned as an error carrying a preview of at most
// 200 bytes of the body. After a 2xx response the service has already run, so
// problems with the response body are never turned into errors. The body is
// decoded as a JSON array of states, and an empty, unreadable or non-array
// body yields (nil, nil). Failures are recorded on the "ha.CallService" span.
func (c *Client) CallService(ctx context.Context, domain, service string, payload map[string]any) ([]*driven.HAState, error) {
	// Same cap as GET responses. A call that changes several entities easily
	// returns more than a few KiB of states, and truncating the JSON would make
	// the decode below fail and silently drop them.
	const maxBody = 1 << 20

	ctx, span := tracer.Start(ctx, "ha.CallService")
	defer span.End()
	span.SetAttributes(
		attribute.String("ha.domain", domain),
		attribute.String("ha.service", service),
	)

	// fail marks the span as errored so every error exit shows up in traces.
	fail := func(err error) ([]*driven.HAState, error) {
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		return nil, err
	}

	// domain and service are spliced into the URL path verbatim; reject
	// anything that could escape their path segments.
	for _, seg := range []string{domain, service} {
		if seg == "" || strings.ContainsAny(seg, "/?#%") || strings.Contains(seg, "..") {
			return fail(fmt.Errorf("invalid service %q/%q", domain, service))
		}
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return fail(fmt.Errorf("marshal payload: %w", err))
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost,
		c.baseURL+"/api/services/"+domain+"/"+service,
		strings.NewReader(string(body)))
	if err != nil {
		return fail(fmt.Errorf("build request: %w", err))
	}
	req.Header.Set("Authorization", "Bearer "+c.token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fail(fmt.Errorf("call service %s/%s: %w", domain, service, err))
	}
	defer resp.Body.Close()

	respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, maxBody))
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// The preview is best-effort, so a read error here is secondary to
		// the status.
		preview := respBody
		if len(preview) > 200 {
			preview = preview[:200]
		}
		// A byte cut can split a multi-byte rune. Drop the fragment so the
		// error text and span status stay valid UTF-8.
		return fail(fmt.Errorf("HA returned %d: %s", resp.StatusCode, strings.ToValidUTF8(string(preview), "")))
	}
	if readErr != nil {
		// The service already ran. Report "no states" rather than an error so
		// callers don't retry a possibly non-idempotent call, but keep the
		// failure visible in the trace.
		span.RecordError(readErr)
		return nil, nil
	}

	var raw []haStateRaw
	if err := json.Unmarshal(respBody, &raw); err != nil {
		// HA may return an empty body or non-array on some calls; treat as empty.
		span.AddEvent("response body is not a state array")
		return nil, nil
	}
	out := make([]*driven.HAState, 0, len(raw))
	for i := range raw {
		s, err := raw[i].toDriven()
		if err != nil {
			return fail(err)
		}
		out = append(out, s)
	}
	return out, nil
}
// get performs an authenticated GET of c.baseURL+path and JSON-decodes a 2xx
// response body into dst.
//
// At most maxBody bytes of the body are accepted.
//   - A non-2xx status becomes an error carrying a preview of at most 200
//     bytes of the body.
//   - A body larger than the cap is reported explicitly rather than as a
//     confusing decode error on truncated JSON.
//   - A failure while reading the body is reported rather than ignored.
func (c *Client) get(ctx context.Context, path string, dst any) error {
	const maxBody = 1 << 20

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+path, nil)
	if err != nil {
		return fmt.Errorf("build request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.token)
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("GET %s: %w", path, err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap so an oversized body can be told apart from
	// one that is exactly maxBody bytes long.
	body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxBody+1))
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// The preview is best-effort, so a read error here is secondary to
		// the status.
		preview := body
		if len(preview) > 200 {
			preview = preview[:200]
		}
		// A byte cut can split a multi-byte rune. Drop the fragment so the
		// error text stays valid UTF-8 (callers put it into span status).
		return fmt.Errorf("HA returned %d for GET %s: %s", resp.StatusCode, path, strings.ToValidUTF8(string(preview), ""))
	}
	if readErr != nil {
		return fmt.Errorf("read response for GET %s: %w", path, readErr)
	}
	if len(body) > maxBody {
		return fmt.Errorf("response for GET %s exceeds %d bytes", path, maxBody)
	}
	if err := json.Unmarshal(body, dst); err != nil {
		return fmt.Errorf("decode response for GET %s: %w", path, err)
	}
	return nil
}
// haStateRaw is the raw JSON shape of one entity state as returned by the HA
// REST API. This client decodes it from /api/states, /api/states/<id> and
// service-call responses.
type haStateRaw struct {
	EntityID    string         `json:"entity_id"`
	State       string         `json:"state"`
	Attributes  map[string]any `json:"attributes"`
	LastChanged string         `json:"last_changed"`
	LastUpdated string         `json:"last_updated"`
}

// toDriven converts r into the domain-level driven.HAState.
//
// Both timestamps are parsed with the time.RFC3339 layout. A missing or
// unparseable value becomes the zero time.Time instead of an error. The
// Attributes map is shared with r, not copied. The returned error is always
// nil at present; the error return lets callers stay unchanged if conversion
// ever becomes fallible.
func (r *haStateRaw) toDriven() (*driven.HAState, error) {
	parseOrZero := func(s string) time.Time {
		ts, err := time.Parse(time.RFC3339, s)
		if err != nil {
			return time.Time{}
		}
		return ts
	}
	return &driven.HAState{
		EntityID:    r.EntityID,
		State:       r.State,
		Attributes:  r.Attributes,
		LastChanged: parseOrZero(r.LastChanged),
		LastUpdated: parseOrZero(r.LastUpdated),
	}, nil
}