1// Package main is the main entry point for the HTTP server that serves
2// inference providers.
3package main
4
5import (
6 "encoding/json"
7 "log"
8 "net/http"
9 "time"
10
11 "charm.land/catwalk/internal/providers"
12 "github.com/charmbracelet/x/etag"
13 "github.com/prometheus/client_golang/prometheus"
14 "github.com/prometheus/client_golang/prometheus/promauto"
15 "github.com/prometheus/client_golang/prometheus/promhttp"
16)
17
18var counter = promauto.NewCounter(prometheus.CounterOpts{
19 Namespace: "catwalk",
20 Subsystem: "providers",
21 Name: "requests_total",
22 Help: "Total number of requests to the providers endpoint",
23})
24
// Response payloads. They are filled in once by init and only read afterwards
// by the HTTP handlers.
var (
	// providersJSON holds the JSON encoding of providers.GetAll().
	providersJSON []byte

	// deprecatedJSON holds the JSON error body returned by the removed
	// /providers endpoint.
	deprecatedJSON []byte

	// providersETag is the value etag.Of produced for providersJSON.
	providersETag string
)
31
32func init() {
33 var err error
34 providersJSON, err = json.Marshal(providers.GetAll())
35 if err != nil {
36 log.Fatal("Failed to marshal providers:", err)
37 }
38 providersETag = etag.Of(providersJSON)
39
40 deprecatedJSON, err = json.Marshal(map[string]any{"error": "This endpoint was removed. Please use /v2/providers instead."})
41 if err != nil {
42 log.Fatal("Failed to marshal deprecated response:", err)
43 }
44}
45
46func providersHandler(w http.ResponseWriter, r *http.Request) {
47 w.Header().Set("Content-Type", "application/json")
48 etag.Response(w, providersETag)
49
50 if r.Method == http.MethodHead {
51 return
52 }
53
54 if r.Method != http.MethodGet {
55 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
56 return
57 }
58
59 counter.Inc()
60
61 if etag.Matches(r, providersETag) {
62 w.WriteHeader(http.StatusNotModified)
63 return
64 }
65
66 if _, err := w.Write(providersJSON); err != nil {
67 log.Printf("Error writing response: %v", err)
68 http.Error(w, err.Error(), http.StatusInternalServerError)
69 }
70}
71
72func providersHandlerDeprecated(w http.ResponseWriter, _ *http.Request) {
73 w.Header().Set("Content-Type", "application/json")
74
75 if _, err := w.Write(deprecatedJSON); err != nil {
76 log.Printf("Error writing response: %v", err)
77 http.Error(w, err.Error(), http.StatusInternalServerError)
78 }
79}
80
81func main() {
82 mux := http.NewServeMux()
83 mux.HandleFunc("/v2/providers", providersHandler)
84 mux.HandleFunc("/providers", providersHandlerDeprecated)
85 mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
86 w.WriteHeader(http.StatusOK)
87 _, _ = w.Write([]byte("OK"))
88 })
89 mux.Handle("/metrics", promhttp.Handler())
90
91 server := &http.Server{
92 Addr: ":8080",
93 Handler: mux,
94 ReadTimeout: 15 * time.Second,
95 WriteTimeout: 15 * time.Second,
96 IdleTimeout: 60 * time.Second,
97 }
98
99 log.Println("Server starting on :8080")
100 if err := server.ListenAndServe(); err != nil {
101 log.Fatal("Server failed to start:", err)
102 }
103}