main.go

  1// Package main is the main entry point for the HTTP server that serves
  2// inference providers.
  3package main
  4
  5import (
  6	"encoding/json"
  7	"log"
  8	"net/http"
  9	"time"
 10
 11	"github.com/charmbracelet/catwalk/internal/deprecated"
 12	"github.com/charmbracelet/catwalk/internal/providers"
 13	"github.com/charmbracelet/x/etag"
 14	"github.com/prometheus/client_golang/prometheus"
 15	"github.com/prometheus/client_golang/prometheus/promauto"
 16	"github.com/prometheus/client_golang/prometheus/promhttp"
 17)
 18
 19var counter = promauto.NewCounter(prometheus.CounterOpts{
 20	Namespace: "catwalk",
 21	Subsystem: "providers",
 22	Name:      "requests_total",
 23	Help:      "Total number of requests to the providers endpoint",
 24})
 25
 26var (
 27	providersJSON []byte
 28	providersETag string
 29)
 30
 31func init() {
 32	var err error
 33	providersJSON, err = json.Marshal(providers.GetAll())
 34	if err != nil {
 35		log.Fatal("Failed to marshal providers:", err)
 36	}
 37	providersETag = etag.Of(providersJSON)
 38}
 39
 40func providersHandler(w http.ResponseWriter, r *http.Request) {
 41	w.Header().Set("Content-Type", "application/json")
 42	etag.Response(w, providersETag)
 43
 44	if r.Method == http.MethodHead {
 45		return
 46	}
 47
 48	if r.Method != http.MethodGet {
 49		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 50		return
 51	}
 52
 53	counter.Inc()
 54
 55	if etag.Matches(r, providersETag) {
 56		w.WriteHeader(http.StatusNotModified)
 57		return
 58	}
 59
 60	if _, err := w.Write(providersJSON); err != nil {
 61		log.Printf("Error writing response: %v", err)
 62		http.Error(w, err.Error(), http.StatusInternalServerError)
 63	}
 64}
 65
 66func providersHandlerDeprecated(w http.ResponseWriter, r *http.Request) {
 67	w.Header().Set("Content-Type", "application/json")
 68	if r.Method == http.MethodHead {
 69		return
 70	}
 71
 72	if r.Method != http.MethodGet {
 73		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 74		return
 75	}
 76
 77	counter.Inc()
 78	allProviders := deprecated.GetAll()
 79	if err := json.NewEncoder(w).Encode(allProviders); err != nil {
 80		http.Error(w, "Internal server error", http.StatusInternalServerError)
 81		return
 82	}
 83}
 84
 85func main() {
 86	mux := http.NewServeMux()
 87	mux.HandleFunc("/v2/providers", providersHandler)
 88	mux.HandleFunc("/providers", providersHandlerDeprecated)
 89	mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
 90		w.WriteHeader(http.StatusOK)
 91		_, _ = w.Write([]byte("OK"))
 92	})
 93	mux.Handle("/metrics", promhttp.Handler())
 94
 95	server := &http.Server{
 96		Addr:         ":8080",
 97		Handler:      mux,
 98		ReadTimeout:  15 * time.Second,
 99		WriteTimeout: 15 * time.Second,
100		IdleTimeout:  60 * time.Second,
101	}
102
103	log.Println("Server starting on :8080")
104	if err := server.ListenAndServe(); err != nil {
105		log.Fatal("Server failed to start:", err)
106	}
107}