main.go

  1// Package main is the main entry point for the HTTP server that serves
  2// inference providers.
  3package main
  4
import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"time"

	"github.com/charmbracelet/catwalk/internal/deprecated"
	"github.com/charmbracelet/catwalk/internal/etag"
	"github.com/charmbracelet/catwalk/internal/providers"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)
 19
 20var counter = promauto.NewCounter(prometheus.CounterOpts{
 21	Namespace: "catwalk",
 22	Subsystem: "providers",
 23	Name:      "requests_total",
 24	Help:      "Total number of requests to the providers endpoint",
 25})
 26
 27var (
 28	providersJSON []byte
 29	providersETag string
 30)
 31
 32func init() {
 33	var err error
 34	providersJSON, err = json.Marshal(providers.GetAll())
 35	if err != nil {
 36		log.Fatal("Failed to marshal providers:", err)
 37	}
 38	providersETag = fmt.Sprintf(`"%s"`, etag.Of(providersJSON))
 39}
 40
 41func providersHandler(w http.ResponseWriter, r *http.Request) {
 42	w.Header().Set("Content-Type", "application/json")
 43	w.Header().Set("ETag", providersETag)
 44
 45	if r.Method == http.MethodHead {
 46		return
 47	}
 48
 49	if r.Method != http.MethodGet {
 50		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 51		return
 52	}
 53
 54	counter.Inc()
 55
 56	if match := r.Header.Get("If-None-Match"); match == providersETag {
 57		w.WriteHeader(http.StatusNotModified)
 58		return
 59	}
 60
 61	if _, err := w.Write(providersJSON); err != nil {
 62		log.Printf("Error writing response: %v", err)
 63		http.Error(w, err.Error(), http.StatusInternalServerError)
 64	}
 65}
 66
 67func providersHandlerDeprecated(w http.ResponseWriter, r *http.Request) {
 68	w.Header().Set("Content-Type", "application/json")
 69	if r.Method == http.MethodHead {
 70		return
 71	}
 72
 73	if r.Method != http.MethodGet {
 74		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 75		return
 76	}
 77
 78	counter.Inc()
 79	allProviders := deprecated.GetAll()
 80	if err := json.NewEncoder(w).Encode(allProviders); err != nil {
 81		http.Error(w, "Internal server error", http.StatusInternalServerError)
 82		return
 83	}
 84}
 85
 86func main() {
 87	mux := http.NewServeMux()
 88	mux.HandleFunc("/v2/providers", providersHandler)
 89	mux.HandleFunc("/providers", providersHandlerDeprecated)
 90	mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
 91		w.WriteHeader(http.StatusOK)
 92		_, _ = w.Write([]byte("OK"))
 93	})
 94	mux.Handle("/metrics", promhttp.Handler())
 95
 96	server := &http.Server{
 97		Addr:         ":8080",
 98		Handler:      mux,
 99		ReadTimeout:  15 * time.Second,
100		WriteTimeout: 15 * time.Second,
101		IdleTimeout:  60 * time.Second,
102	}
103
104	log.Println("Server starting on :8080")
105	if err := server.ListenAndServe(); err != nil {
106		log.Fatal("Server failed to start:", err)
107	}
108}