perf: set etag, read if-none-match, marshal once (#116)

Carlos Alexandro Becker created

* perf: set etag, read if-none-match, marshal once

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: lint

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: etag.Of

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fixup! fix: etag.Of

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fixup! fixup! fix: etag.Of

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

---------

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/etag/etag.go | 13 +++++++++++++
main.go               | 29 ++++++++++++++++++++++++++---
pkg/catwalk/client.go | 41 ++++++++++++++++++++++++++++++++---------
3 files changed, 71 insertions(+), 12 deletions(-)

Detailed changes

internal/etag/etag.go 🔗

@@ -0,0 +1,13 @@
+// Package etag can create the etag value for the given data.
+package etag
+
+import (
+	"crypto/sha256"
+	"fmt"
+)
+
+// Of returns the etag for the given data.
+func Of(data []byte) string {
+	hash := sha256.Sum256(data)
+	return fmt.Sprintf(`%x`, hash[:16])
+}

main.go 🔗

@@ -4,11 +4,13 @@ package main
 
 import (
 	"encoding/json"
+	"fmt"
 	"log"
 	"net/http"
 	"time"
 
 	"github.com/charmbracelet/catwalk/internal/deprecated"
+	"github.com/charmbracelet/catwalk/internal/etag"
 	"github.com/charmbracelet/catwalk/internal/providers"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promauto"
@@ -22,8 +24,24 @@ var counter = promauto.NewCounter(prometheus.CounterOpts{
 	Help:      "Total number of requests to the providers endpoint",
 })
 
+var (
+	providersJSON []byte
+	providersETag string
+)
+
+func init() {
+	var err error
+	providersJSON, err = json.Marshal(providers.GetAll())
+	if err != nil {
+		log.Fatal("Failed to marshal providers:", err)
+	}
+	providersETag = fmt.Sprintf(`"%s"`, etag.Of(providersJSON))
+}
+
 func providersHandler(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("ETag", providersETag)
+
 	if r.Method == http.MethodHead {
 		return
 	}
@@ -34,11 +52,16 @@ func providersHandler(w http.ResponseWriter, r *http.Request) {
 	}
 
 	counter.Inc()
-	allProviders := providers.GetAll()
-	if err := json.NewEncoder(w).Encode(allProviders); err != nil {
-		http.Error(w, "Internal server error", http.StatusInternalServerError)
+
+	if match := r.Header.Get("If-None-Match"); match == providersETag {
+		w.WriteHeader(http.StatusNotModified)
 		return
 	}
+
+	if _, err := w.Write(providersJSON); err != nil {
+		log.Printf("Error writing response: %v", err)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+	}
 }
 
 func providersHandlerDeprecated(w http.ResponseWriter, r *http.Request) {

pkg/catwalk/client.go 🔗

@@ -1,10 +1,14 @@
 package catwalk
 
 import (
+	"cmp"
+	"context"
 	"encoding/json"
 	"fmt"
 	"net/http"
 	"os"
+
+	"github.com/charmbracelet/catwalk/internal/etag"
 )
 
 const defaultURL = "http://localhost:8080"
@@ -18,13 +22,8 @@ type Client struct {
 // New creates a new client instance
 // Uses CATWALK_URL environment variable or falls back to localhost:8080.
 func New() *Client {
-	baseURL := os.Getenv("CATWALK_URL")
-	if baseURL == "" {
-		baseURL = defaultURL
-	}
-
 	return &Client{
-		baseURL:    baseURL,
+		baseURL:    cmp.Or(os.Getenv("CATWALK_URL"), defaultURL),
 		httpClient: &http.Client{},
 	}
 }
@@ -37,16 +36,40 @@ func NewWithURL(url string) *Client {
 	}
 }
 
+// ErrNotModified happens when the given ETag matches the server, so no update
+// is needed.
+var ErrNotModified = fmt.Errorf("not modified")
+
+// Etag returns the ETag for the given data.
+func Etag(data []byte) string { return etag.Of(data) }
+
 // GetProviders retrieves all available providers from the service.
-func (c *Client) GetProviders() ([]Provider, error) {
-	url := fmt.Sprintf("%s/v2/providers", c.baseURL)
+func (c *Client) GetProviders(ctx context.Context, etag string) ([]Provider, error) {
+	req, err := http.NewRequestWithContext(
+		ctx,
+		http.MethodGet,
+		fmt.Sprintf("%s/v2/providers", c.baseURL),
+		nil,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("could not create request: %w", err)
+	}
+
+	if etag != "" {
+		// It needs to be quoted:
+		req.Header.Add("If-None-Match", fmt.Sprintf(`"%s"`, etag))
+	}
 
-	resp, err := c.httpClient.Get(url) //nolint:noctx
+	resp, err := c.httpClient.Do(req)
 	if err != nil {
 		return nil, fmt.Errorf("failed to make request: %w", err)
 	}
 	defer resp.Body.Close() //nolint:errcheck
 
+	if resp.StatusCode == http.StatusNotModified {
+		return nil, ErrNotModified
+	}
+
 	if resp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
 	}