@@ -0,0 +1,13 @@
+// Package etag can create the etag value for the given data.
+package etag
+
+import (
+ "crypto/sha256"
+ "fmt"
+)
+
+// Of returns the etag for the given data.
+func Of(data []byte) string {
+ hash := sha256.Sum256(data)
+ return fmt.Sprintf(`%x`, hash[:16])
+}
@@ -4,11 +4,13 @@ package main
import (
"encoding/json"
+ "fmt"
"log"
"net/http"
"time"
"github.com/charmbracelet/catwalk/internal/deprecated"
+ "github.com/charmbracelet/catwalk/internal/etag"
"github.com/charmbracelet/catwalk/internal/providers"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
@@ -22,8 +24,24 @@ var counter = promauto.NewCounter(prometheus.CounterOpts{
Help: "Total number of requests to the providers endpoint",
})
+var (
+ providersJSON []byte
+ providersETag string
+)
+
+func init() {
+ var err error
+ providersJSON, err = json.Marshal(providers.GetAll())
+ if err != nil {
+ log.Fatal("Failed to marshal providers:", err)
+ }
+ providersETag = fmt.Sprintf(`"%s"`, etag.Of(providersJSON))
+}
+
func providersHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("ETag", providersETag)
+
if r.Method == http.MethodHead {
return
}
@@ -34,11 +52,16 @@ func providersHandler(w http.ResponseWriter, r *http.Request) {
}
counter.Inc()
- allProviders := providers.GetAll()
- if err := json.NewEncoder(w).Encode(allProviders); err != nil {
- http.Error(w, "Internal server error", http.StatusInternalServerError)
+
+ if match := r.Header.Get("If-None-Match"); match == providersETag {
+ w.WriteHeader(http.StatusNotModified)
return
}
+
+ if _, err := w.Write(providersJSON); err != nil {
+ log.Printf("Error writing response: %v", err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
}
func providersHandlerDeprecated(w http.ResponseWriter, r *http.Request) {
@@ -1,10 +1,14 @@
package catwalk
import (
+ "cmp"
+ "context"
"encoding/json"
"fmt"
"net/http"
"os"
+
+ "github.com/charmbracelet/catwalk/internal/etag"
)
const defaultURL = "http://localhost:8080"
@@ -18,13 +22,8 @@ type Client struct {
// New creates a new client instance
// Uses CATWALK_URL environment variable or falls back to localhost:8080.
func New() *Client {
- baseURL := os.Getenv("CATWALK_URL")
- if baseURL == "" {
- baseURL = defaultURL
- }
-
return &Client{
- baseURL: baseURL,
+ baseURL: cmp.Or(os.Getenv("CATWALK_URL"), defaultURL),
httpClient: &http.Client{},
}
}
@@ -37,16 +36,40 @@ func NewWithURL(url string) *Client {
}
}
+// ErrNotModified happens when the given ETag matches the server, so no update
+// is needed.
+var ErrNotModified = fmt.Errorf("not modified")
+
+// Etag returns the ETag for the given data.
+func Etag(data []byte) string { return etag.Of(data) }
+
// GetProviders retrieves all available providers from the service.
-func (c *Client) GetProviders() ([]Provider, error) {
- url := fmt.Sprintf("%s/v2/providers", c.baseURL)
+func (c *Client) GetProviders(ctx context.Context, etag string) ([]Provider, error) {
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodGet,
+ fmt.Sprintf("%s/v2/providers", c.baseURL),
+ nil,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("could not create request: %w", err)
+ }
+
+ if etag != "" {
+ // It needs to be quoted:
+ req.Header.Add("If-None-Match", fmt.Sprintf(`"%s"`, etag))
+ }
- resp, err := c.httpClient.Get(url) //nolint:noctx
+ resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close() //nolint:errcheck
+ if resp.StatusCode == http.StatusNotModified {
+ return nil, ErrNotModified
+ }
+
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}