1// Package main is the main entry point for the HTTP server that serves
2// inference providers.
3package main
4
5import (
6 "encoding/json"
7 "log"
8 "net/http"
9 "time"
10
11 "github.com/charmbracelet/catwalk/internal/deprecated"
12 "github.com/charmbracelet/catwalk/internal/providers"
13 "github.com/charmbracelet/x/etag"
14 "github.com/prometheus/client_golang/prometheus"
15 "github.com/prometheus/client_golang/prometheus/promauto"
16 "github.com/prometheus/client_golang/prometheus/promhttp"
17)
18
19var counter = promauto.NewCounter(prometheus.CounterOpts{
20 Namespace: "catwalk",
21 Subsystem: "providers",
22 Name: "requests_total",
23 Help: "Total number of requests to the providers endpoint",
24})
25
26var (
27 providersJSON []byte
28 providersETag string
29)
30
31func init() {
32 var err error
33 providersJSON, err = json.Marshal(providers.GetAll())
34 if err != nil {
35 log.Fatal("Failed to marshal providers:", err)
36 }
37 providersETag = etag.Of(providersJSON)
38}
39
40func providersHandler(w http.ResponseWriter, r *http.Request) {
41 w.Header().Set("Content-Type", "application/json")
42 etag.Response(w, providersETag)
43
44 if r.Method == http.MethodHead {
45 return
46 }
47
48 if r.Method != http.MethodGet {
49 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
50 return
51 }
52
53 counter.Inc()
54
55 if etag.Matches(r, providersETag) {
56 w.WriteHeader(http.StatusNotModified)
57 return
58 }
59
60 if _, err := w.Write(providersJSON); err != nil {
61 log.Printf("Error writing response: %v", err)
62 http.Error(w, err.Error(), http.StatusInternalServerError)
63 }
64}
65
66func providersHandlerDeprecated(w http.ResponseWriter, r *http.Request) {
67 w.Header().Set("Content-Type", "application/json")
68 if r.Method == http.MethodHead {
69 return
70 }
71
72 if r.Method != http.MethodGet {
73 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
74 return
75 }
76
77 counter.Inc()
78 allProviders := deprecated.GetAll()
79 if err := json.NewEncoder(w).Encode(allProviders); err != nil {
80 http.Error(w, "Internal server error", http.StatusInternalServerError)
81 return
82 }
83}
84
85func main() {
86 mux := http.NewServeMux()
87 mux.HandleFunc("/v2/providers", providersHandler)
88 mux.HandleFunc("/providers", providersHandlerDeprecated)
89 mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
90 w.WriteHeader(http.StatusOK)
91 _, _ = w.Write([]byte("OK"))
92 })
93 mux.Handle("/metrics", promhttp.Handler())
94
95 server := &http.Server{
96 Addr: ":8080",
97 Handler: mux,
98 ReadTimeout: 15 * time.Second,
99 WriteTimeout: 15 * time.Second,
100 IdleTimeout: 60 * time.Second,
101 }
102
103 log.Println("Server starting on :8080")
104 if err := server.ListenAndServe(); err != nil {
105 log.Fatal("Server failed to start:", err)
106 }
107}