1// Package main provides a command-line tool to fetch models from Neuralwatt
2// and generate a configuration file for the provider.
3package main
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log"
11 "math"
12 "net/http"
13 "os"
14 "slices"
15 "strings"
16 "time"
17
18 "charm.land/catwalk/pkg/catwalk"
19)
20
// Pricing holds the fields decoded from the "pricing" object of a model's
// metadata. The pointer fields are nil when the JSON value is null or the
// key is absent.
type Pricing struct {
	// InputPerMillion is the input price per million units (presumably
	// tokens; confirm against the Neuralwatt API documentation).
	InputPerMillion *float64 `json:"input_per_million"`
	// OutputPerMillion is the output price per million units.
	OutputPerMillion *float64 `json:"output_per_million"`
	// CachedInputPerMillion is the cached-input price; main treats a nil
	// value as "same as the non-cached input price".
	CachedInputPerMillion *float64 `json:"cached_input_per_million"`
	// CachedOutputPerMillion is the cached-output price; main treats a nil
	// value as "same as the non-cached output price".
	CachedOutputPerMillion *float64 `json:"cached_output_per_million"`
	// PricingTBD is decoded from "pricing_tbd". It is not read by any code
	// visible in this file.
	PricingTBD bool `json:"pricing_tbd"`
}
28
// Capabilities holds the feature flags decoded from the "capabilities"
// object of a model's metadata.
type Capabilities struct {
	// Tools reports tool support; main skips models where it is false.
	Tools bool `json:"tools"`
	// Vision is copied by main into the generated model's SupportsImages.
	Vision bool `json:"vision"`
	// Reasoning is copied by main into the generated model's CanReason.
	Reasoning bool `json:"reasoning"`
	// ReasoningEffort, when true, makes main emit the low/medium/high
	// reasoning levels with "medium" as the default.
	ReasoningEffort bool `json:"reasoning_effort"`
}
35
// Limits holds the fields decoded from the "limits" object of a model's
// metadata.
type Limits struct {
	// MaxOutputTokens is nil when the JSON value is null or absent; main then
	// falls back to one tenth of the model's MaxModelLen.
	MaxOutputTokens *int64 `json:"max_output_tokens"`
}
39
// Metadata is the "metadata" object attached to each model entry in the
// models response.
type Metadata struct {
	// DisplayName is the human-readable name; when it is empty main derives
	// one from the model ID via fallbackDisplayName.
	DisplayName string `json:"display_name"`
	// Pricing, Capabilities and Limits are the nested metadata objects.
	Pricing      Pricing      `json:"pricing"`
	Capabilities Capabilities `json:"capabilities"`
	Limits       Limits       `json:"limits"`
	// Deprecated models are skipped by main.
	Deprecated bool `json:"deprecated"`
}
47
// NeuralwattModel is one entry of the "data" array in the models response.
type NeuralwattModel struct {
	// ID is the model identifier, used verbatim as the generated model's ID.
	ID string `json:"id"`
	// MaxModelLen is decoded from "max_model_len"; main uses it as the
	// generated model's context window.
	MaxModelLen int64 `json:"max_model_len"`
	// Metadata carries display name, pricing, capabilities and limits.
	Metadata Metadata `json:"metadata"`
}
53
// ModelsResponse is the top-level JSON document that fetchNeuralwattModels
// decodes from the "/models" endpoint.
type ModelsResponse struct {
	// Data lists the models in the response.
	Data []NeuralwattModel `json:"data"`
}
57
// roundCost rounds v to five decimal places using math.Round, which rounds
// half away from zero.
func roundCost(v float64) float64 {
	const scale = 1e5

	scaled := math.Round(v * scale)
	return scaled / scale
}
61
// ptrDeref returns the value v points to, or fallback when v is nil.
func ptrDeref[T any](v *T, fallback T) T {
	if v != nil {
		return *v
	}
	return fallback
}
68
69func fetchNeuralwattModels(apiEndpoint string) (*ModelsResponse, error) {
70 client := &http.Client{Timeout: 30 * time.Second}
71 req, _ := http.NewRequestWithContext(context.Background(), "GET", apiEndpoint+"/models", nil)
72 req.Header.Set("User-Agent", "Crush-Client/1.0")
73
74 resp, err := client.Do(req)
75 if err != nil {
76 return nil, fmt.Errorf("fetching models: %w", err)
77 }
78 defer func() { _ = resp.Body.Close() }()
79
80 body, err := io.ReadAll(resp.Body)
81 if err != nil {
82 return nil, fmt.Errorf("reading models response: %w", err)
83 }
84
85 if resp.StatusCode != 200 {
86 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
87 }
88
89 _ = os.MkdirAll("tmp", 0o700)
90 _ = os.WriteFile("tmp/neuralwatt-response.json", body, 0o600)
91
92 var mr ModelsResponse
93 if err := json.Unmarshal(body, &mr); err != nil {
94 return nil, fmt.Errorf("decoding models response: %w", err)
95 }
96
97 return &mr, nil
98}
99
// fallbackDisplayName builds a display name from a model ID: everything up to
// and including the first "/" is dropped (when one is present) and every "-"
// in the remainder is replaced by a space.
func fallbackDisplayName(id string) string {
	if _, rest, found := strings.Cut(id, "/"); found {
		id = rest
	}
	return strings.ReplaceAll(id, "-", " ")
}
107
108func main() {
109 neuralwattProvider := catwalk.Provider{
110 Name: "Neuralwatt",
111 ID: "neuralwatt",
112 APIKey: "$NEURALWATT_API_KEY",
113 APIEndpoint: "https://api.neuralwatt.com/v1",
114 Type: catwalk.TypeOpenAICompat,
115 DefaultLargeModelID: "zai-org/GLM-5.1-FP8",
116 DefaultSmallModelID: "mistralai/Devstral-Small-2-24B-Instruct-2512",
117 }
118
119 modelsResp, err := fetchNeuralwattModels(neuralwattProvider.APIEndpoint)
120 if err != nil {
121 log.Fatal("Error fetching Neuralwatt models:", err)
122 }
123
124 for _, model := range modelsResp.Data {
125 meta := model.Metadata
126
127 if meta.Deprecated {
128 fmt.Printf("Skipping deprecated model %s\n", model.ID)
129 continue
130 }
131
132 // Skip models with small context windows
133 if model.MaxModelLen < 20000 {
134 fmt.Printf("Skipping model %s: context %d < 20000\n",
135 model.ID, model.MaxModelLen)
136 continue
137 }
138
139 if !meta.Capabilities.Tools {
140 fmt.Printf("Skipping model %s (no tool support)\n", model.ID)
141 continue
142 }
143
144 costIn := ptrDeref(meta.Pricing.InputPerMillion, 0)
145 costOut := ptrDeref(meta.Pricing.OutputPerMillion, 0)
146 // Null cached pricing means same as non-cached
147 costInCached := ptrDeref(meta.Pricing.CachedInputPerMillion, costIn)
148 costOutCached := ptrDeref(meta.Pricing.CachedOutputPerMillion, costOut)
149
150 var defaultMaxTokens int64
151 if meta.Limits.MaxOutputTokens != nil {
152 defaultMaxTokens = *meta.Limits.MaxOutputTokens
153 } else {
154 defaultMaxTokens = model.MaxModelLen / 10
155 }
156
157 var reasoningLevels []string
158 var defaultReasoning string
159 if meta.Capabilities.ReasoningEffort {
160 reasoningLevels = []string{"low", "medium", "high"}
161 defaultReasoning = "medium"
162 }
163
164 name := meta.DisplayName
165 if name == "" {
166 name = fallbackDisplayName(model.ID)
167 }
168
169 m := catwalk.Model{
170 ID: model.ID,
171 Name: name,
172 CostPer1MIn: roundCost(costIn),
173 CostPer1MOut: roundCost(costOut),
174 CostPer1MInCached: roundCost(costInCached),
175 CostPer1MOutCached: roundCost(costOutCached),
176 ContextWindow: model.MaxModelLen,
177 DefaultMaxTokens: defaultMaxTokens,
178 CanReason: meta.Capabilities.Reasoning,
179 DefaultReasoningEffort: defaultReasoning,
180 ReasoningLevels: reasoningLevels,
181 SupportsImages: meta.Capabilities.Vision,
182 }
183
184 neuralwattProvider.Models = append(neuralwattProvider.Models, m)
185 fmt.Printf("Added model %s with context window %d\n", model.ID, model.MaxModelLen)
186 }
187
188 slices.SortFunc(neuralwattProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
189 return strings.Compare(a.Name, b.Name)
190 })
191
192 data, err := json.MarshalIndent(neuralwattProvider, "", " ")
193 if err != nil {
194 log.Fatal("Error marshaling Neuralwatt provider:", err)
195 }
196 data = append(data, '\n')
197
198 if err := os.WriteFile("internal/providers/configs/neuralwatt.json", data, 0o600); err != nil {
199 log.Fatal("Error writing Neuralwatt provider config:", err)
200 }
201
202 fmt.Printf("Generated neuralwatt.json with %d models\n", len(neuralwattProvider.Models))
203}