1// Package main is the main entry point for the HTTP server that serves
2// inference providers.
3package main
4
import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/internal/deprecated"
	"github.com/charmbracelet/catwalk/internal/etag"
	"github.com/charmbracelet/catwalk/internal/providers"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)
19
20var counter = promauto.NewCounter(prometheus.CounterOpts{
21 Namespace: "catwalk",
22 Subsystem: "providers",
23 Name: "requests_total",
24 Help: "Total number of requests to the providers endpoint",
25})
26
// Cached response state for the /v2/providers endpoint. Both values are
// assigned by init; the handlers in this file only read them.
var (
	// providersJSON is the JSON encoding of providers.GetAll().
	providersJSON []byte
	// providersETag is the double-quoted entity tag computed from
	// providersJSON.
	providersETag string
)
31
32func init() {
33 var err error
34 providersJSON, err = json.Marshal(providers.GetAll())
35 if err != nil {
36 log.Fatal("Failed to marshal providers:", err)
37 }
38 providersETag = fmt.Sprintf(`"%s"`, etag.Of(providersJSON))
39}
40
41func providersHandler(w http.ResponseWriter, r *http.Request) {
42 w.Header().Set("Content-Type", "application/json")
43 w.Header().Set("ETag", providersETag)
44
45 if r.Method == http.MethodHead {
46 return
47 }
48
49 if r.Method != http.MethodGet {
50 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
51 return
52 }
53
54 counter.Inc()
55
56 if match := r.Header.Get("If-None-Match"); match == providersETag {
57 w.WriteHeader(http.StatusNotModified)
58 return
59 }
60
61 if _, err := w.Write(providersJSON); err != nil {
62 log.Printf("Error writing response: %v", err)
63 http.Error(w, err.Error(), http.StatusInternalServerError)
64 }
65}
66
67func providersHandlerDeprecated(w http.ResponseWriter, r *http.Request) {
68 w.Header().Set("Content-Type", "application/json")
69 if r.Method == http.MethodHead {
70 return
71 }
72
73 if r.Method != http.MethodGet {
74 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
75 return
76 }
77
78 counter.Inc()
79 allProviders := deprecated.GetAll()
80 if err := json.NewEncoder(w).Encode(allProviders); err != nil {
81 http.Error(w, "Internal server error", http.StatusInternalServerError)
82 return
83 }
84}
85
86func main() {
87 mux := http.NewServeMux()
88 mux.HandleFunc("/v2/providers", providersHandler)
89 mux.HandleFunc("/providers", providersHandlerDeprecated)
90 mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
91 w.WriteHeader(http.StatusOK)
92 _, _ = w.Write([]byte("OK"))
93 })
94 mux.Handle("/metrics", promhttp.Handler())
95
96 server := &http.Server{
97 Addr: ":8080",
98 Handler: mux,
99 ReadTimeout: 15 * time.Second,
100 WriteTimeout: 15 * time.Second,
101 IdleTimeout: 60 * time.Second,
102 }
103
104 log.Println("Server starting on :8080")
105 if err := server.ListenAndServe(); err != nil {
106 log.Fatal("Server failed to start:", err)
107 }
108}