main.go

  1// Package main provides a command-line tool to fetch models from Neuralwatt
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"math"
 12	"net/http"
 13	"os"
 14	"slices"
 15	"strings"
 16	"time"
 17
 18	"charm.land/catwalk/pkg/catwalk"
 19)
 20
 21type NeuralwattModel struct {
 22	ID          string `json:"id"`
 23	MaxModelLen int64  `json:"max_model_len"`
 24}
 25
 26type ModelsResponse struct {
 27	Data []NeuralwattModel `json:"data"`
 28}
 29
 30// ModelMeta contains the hardcoded metadata for a Neuralwatt model.
 31// The API only returns id and max_model_len, so pricing and capabilities
 32// are sourced from the pricing page at https://portal.neuralwatt.com/pricing.
 33type ModelMeta struct {
 34	Tools        bool
 35	Reasoning    bool
 36	Vision       bool
 37	CostPer1MIn  float64
 38	CostPer1MOut float64
 39}
 40
 41var modelMetadata = map[string]ModelMeta{
 42	"mistralai/Devstral-Small-2-24B-Instruct-2512": {
 43		Tools:        true,
 44		Reasoning:    false,
 45		Vision:       true,
 46		CostPer1MIn:  0.1,
 47		CostPer1MOut: 0.3,
 48	},
 49	"zai-org/GLM-5.1-FP8": {
 50		Tools:        true,
 51		Reasoning:    true,
 52		Vision:       false,
 53		CostPer1MIn:  1.1,
 54		CostPer1MOut: 3.6,
 55	},
 56	"glm-5.1-fast": {
 57		Tools:        true,
 58		Reasoning:    false,
 59		Vision:       false,
 60		CostPer1MIn:  1.1,
 61		CostPer1MOut: 3.6,
 62	},
 63	"openai/gpt-oss-20b": {
 64		Tools:        true,
 65		Reasoning:    false,
 66		Vision:       false,
 67		CostPer1MIn:  0.0,
 68		CostPer1MOut: 0.2,
 69	},
 70	"moonshotai/Kimi-K2.5": {
 71		Tools:        true,
 72		Reasoning:    false,
 73		Vision:       true,
 74		CostPer1MIn:  0.5,
 75		CostPer1MOut: 2.6,
 76	},
 77	"kimi-k2.5-fast": {
 78		Tools:        true,
 79		Reasoning:    false,
 80		Vision:       true,
 81		CostPer1MIn:  0.5,
 82		CostPer1MOut: 2.6,
 83	},
 84	"MiniMaxAI/MiniMax-M2.5": {
 85		Tools:        true,
 86		Reasoning:    true,
 87		Vision:       false,
 88		CostPer1MIn:  0.3,
 89		CostPer1MOut: 1.4,
 90	},
 91	"Qwen/Qwen3.5-35B-A3B": {
 92		Tools:        true,
 93		Reasoning:    true,
 94		Vision:       false,
 95		CostPer1MIn:  0.3,
 96		CostPer1MOut: 1.1,
 97	},
 98	"Qwen/Qwen3.5-397B-A17B-FP8": {
 99		Tools:        true,
100		Reasoning:    true,
101		Vision:       false,
102		CostPer1MIn:  0.7,
103		CostPer1MOut: 4.1,
104	},
105	"qwen3.5-397b-fast": {
106		Tools:        true,
107		Reasoning:    false,
108		Vision:       false,
109		CostPer1MIn:  0.7,
110		CostPer1MOut: 4.1,
111	},
112}
113
// modelNames holds explicit display names for the Neuralwatt-owned models,
// whose IDs are all lowercase and carry no org prefix, so a readable name
// cannot be derived from the ID alone.
var modelNames = map[string]string{
	"glm-5.1-fast":      "GLM 5.1 Fast",
	"kimi-k2.5-fast":    "Kimi K2.5 Fast",
	"qwen3.5-397b-fast": "Qwen3.5 397B Fast",
}
121
// roundCost rounds v to five decimal places.
func roundCost(v float64) float64 {
	const scale = 1e5
	scaled := math.Round(v * scale)
	return scaled / scale
}
125
126// modelDisplayName converts a model ID to a human-readable display name. For
127// models with an org prefix (e.g. "zai-org/GLM-5-FP8"), the prefix is stripped.
128// Neuralwatt-owned models without a prefix are looked up in modelNames for
129// proper casing.
130func modelDisplayName(id string) string {
131	if name, ok := modelNames[id]; ok {
132		return name
133	}
134
135	name := id
136	if idx := strings.Index(name, "/"); idx != -1 {
137		name = name[idx+1:]
138	}
139	name = strings.ReplaceAll(name, "-", " ")
140	return name
141}
142
143func fetchNeuralwattModels(apiEndpoint string) (*ModelsResponse, error) {
144	client := &http.Client{Timeout: 30 * time.Second}
145	req, _ := http.NewRequestWithContext(context.Background(), "GET", apiEndpoint+"/models", nil)
146	req.Header.Set("User-Agent", "Crush-Client/1.0")
147
148	resp, err := client.Do(req)
149	if err != nil {
150		return nil, fmt.Errorf("fetching models: %w", err)
151	}
152	defer func() { _ = resp.Body.Close() }()
153
154	body, err := io.ReadAll(resp.Body)
155	if err != nil {
156		return nil, fmt.Errorf("reading models response: %w", err)
157	}
158
159	if resp.StatusCode != 200 {
160		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
161	}
162
163	_ = os.MkdirAll("tmp", 0o700)
164	_ = os.WriteFile("tmp/neuralwatt-response.json", body, 0o600)
165
166	var mr ModelsResponse
167	if err := json.Unmarshal(body, &mr); err != nil {
168		return nil, fmt.Errorf("decoding models response: %w", err)
169	}
170
171	return &mr, nil
172}
173
174func main() {
175	neuralwattProvider := catwalk.Provider{
176		Name:                "Neuralwatt",
177		ID:                  "neuralwatt",
178		APIKey:              "$NEURALWATT_API_KEY",
179		APIEndpoint:         "https://api.neuralwatt.com/v1",
180		Type:                catwalk.TypeOpenAICompat,
181		DefaultLargeModelID: "zai-org/GLM-5.1-FP8",
182		DefaultSmallModelID: "mistralai/Devstral-Small-2-24B-Instruct-2512",
183	}
184
185	modelsResp, err := fetchNeuralwattModels(neuralwattProvider.APIEndpoint)
186	if err != nil {
187		log.Fatal("Error fetching Neuralwatt models:", err)
188	}
189
190	for _, model := range modelsResp.Data {
191		// Skip models with small context windows
192		if model.MaxModelLen < 20000 {
193			fmt.Printf("Skipping model %s: context %d < 20000\n",
194				model.ID, model.MaxModelLen)
195			continue
196		}
197
198		meta, ok := modelMetadata[model.ID]
199		if !ok {
200			fmt.Printf("Skipping unknown model %s (no metadata)\n", model.ID)
201			continue
202		}
203
204		// Only include models that support tools
205		if !meta.Tools {
206			continue
207		}
208
209		var reasoningLevels []string
210		var defaultReasoning string
211		if meta.Reasoning {
212			reasoningLevels = []string{"low", "medium", "high"}
213			defaultReasoning = "medium"
214		}
215
216		m := catwalk.Model{
217			ID:                     model.ID,
218			Name:                   modelDisplayName(model.ID),
219			CostPer1MIn:            roundCost(meta.CostPer1MIn),
220			CostPer1MOut:           roundCost(meta.CostPer1MOut),
221			CostPer1MInCached:      0, // Not available
222			CostPer1MOutCached:     0, // Not available
223			ContextWindow:          model.MaxModelLen,
224			DefaultMaxTokens:       model.MaxModelLen / 10,
225			CanReason:              meta.Reasoning,
226			DefaultReasoningEffort: defaultReasoning,
227			ReasoningLevels:        reasoningLevels,
228			SupportsImages:         meta.Vision,
229		}
230
231		neuralwattProvider.Models = append(neuralwattProvider.Models, m)
232		fmt.Printf("Added model %s with context window %d\n", model.ID, model.MaxModelLen)
233	}
234
235	slices.SortFunc(neuralwattProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
236		return strings.Compare(a.Name, b.Name)
237	})
238
239	data, err := json.MarshalIndent(neuralwattProvider, "", "  ")
240	if err != nil {
241		log.Fatal("Error marshaling Neuralwatt provider:", err)
242	}
243	data = append(data, '\n')
244
245	if err := os.WriteFile("internal/providers/configs/neuralwatt.json", data, 0o600); err != nil {
246		log.Fatal("Error writing Neuralwatt provider config:", err)
247	}
248
249	fmt.Printf("Generated neuralwatt.json with %d models\n", len(neuralwattProvider.Models))
250}