1// Package main is the main entry point for the HTTP server that serves
 2// inference providers.
 3package main
 4
 5import (
 6	"encoding/json"
 7	"log"
 8	"net/http"
 9	"time"
10
11	"github.com/charmbracelet/catwalk/internal/deprecated"
12	"github.com/charmbracelet/catwalk/internal/providers"
13	"github.com/prometheus/client_golang/prometheus"
14	"github.com/prometheus/client_golang/prometheus/promauto"
15	"github.com/prometheus/client_golang/prometheus/promhttp"
16)
17
18var counter = promauto.NewCounter(prometheus.CounterOpts{
19	Namespace: "catwalk",
20	Subsystem: "providers",
21	Name:      "requests_total",
22	Help:      "Total number of requests to the providers endpoint",
23})
24
25func providersHandler(w http.ResponseWriter, r *http.Request) {
26	w.Header().Set("Content-Type", "application/json")
27	if r.Method == http.MethodHead {
28		return
29	}
30
31	if r.Method != http.MethodGet {
32		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
33		return
34	}
35
36	counter.Inc()
37	allProviders := providers.GetAll()
38	if err := json.NewEncoder(w).Encode(allProviders); err != nil {
39		http.Error(w, "Internal server error", http.StatusInternalServerError)
40		return
41	}
42}
43
44func providersHandlerDeprecated(w http.ResponseWriter, r *http.Request) {
45	w.Header().Set("Content-Type", "application/json")
46	if r.Method == http.MethodHead {
47		return
48	}
49
50	if r.Method != http.MethodGet {
51		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
52		return
53	}
54
55	counter.Inc()
56	allProviders := deprecated.GetAll()
57	if err := json.NewEncoder(w).Encode(allProviders); err != nil {
58		http.Error(w, "Internal server error", http.StatusInternalServerError)
59		return
60	}
61}
62
63func main() {
64	mux := http.NewServeMux()
65	mux.HandleFunc("/v2/providers", providersHandler)
66	mux.HandleFunc("/providers", providersHandlerDeprecated)
67	mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
68		w.WriteHeader(http.StatusOK)
69		_, _ = w.Write([]byte("OK"))
70	})
71	mux.Handle("/metrics", promhttp.Handler())
72
73	server := &http.Server{
74		Addr:         ":8080",
75		Handler:      mux,
76		ReadTimeout:  15 * time.Second,
77		WriteTimeout: 15 * time.Second,
78		IdleTimeout:  60 * time.Second,
79	}
80
81	log.Println("Server starting on :8080")
82	if err := server.ListenAndServe(); err != nil {
83		log.Fatal("Server failed to start:", err)
84	}
85}