main.go

  1// Package main provides a command-line tool to fetch models from Neuralwatt
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"math"
 12	"net/http"
 13	"os"
 14	"slices"
 15	"strings"
 16	"time"
 17
 18	"charm.land/catwalk/pkg/catwalk"
 19)
 20
 21type Pricing struct {
 22	InputPerMillion        *float64 `json:"input_per_million"`
 23	OutputPerMillion       *float64 `json:"output_per_million"`
 24	CachedInputPerMillion  *float64 `json:"cached_input_per_million"`
 25	CachedOutputPerMillion *float64 `json:"cached_output_per_million"`
 26	PricingTBD             bool     `json:"pricing_tbd"`
 27}
 28
 29type Capabilities struct {
 30	Tools           bool `json:"tools"`
 31	Vision          bool `json:"vision"`
 32	Reasoning       bool `json:"reasoning"`
 33	ReasoningEffort bool `json:"reasoning_effort"`
 34}
 35
 36type Limits struct {
 37	MaxOutputTokens *int64 `json:"max_output_tokens"`
 38}
 39
 40type Metadata struct {
 41	DisplayName  string       `json:"display_name"`
 42	Pricing      Pricing      `json:"pricing"`
 43	Capabilities Capabilities `json:"capabilities"`
 44	Limits       Limits       `json:"limits"`
 45	Deprecated   bool         `json:"deprecated"`
 46}
 47
 48type NeuralwattModel struct {
 49	ID          string   `json:"id"`
 50	MaxModelLen int64    `json:"max_model_len"`
 51	Metadata    Metadata `json:"metadata"`
 52}
 53
 54type ModelsResponse struct {
 55	Data []NeuralwattModel `json:"data"`
 56}
 57
 58func roundCost(v float64) float64 {
 59	return math.Round(v*1e5) / 1e5
 60}
 61
 62func ptrDeref[T any](v *T, fallback T) T {
 63	if v == nil {
 64		return fallback
 65	}
 66	return *v
 67}
 68
 69func fetchNeuralwattModels(apiEndpoint string) (*ModelsResponse, error) {
 70	client := &http.Client{Timeout: 30 * time.Second}
 71	req, _ := http.NewRequestWithContext(context.Background(), "GET", apiEndpoint+"/models", nil)
 72	req.Header.Set("User-Agent", "Crush-Client/1.0")
 73
 74	resp, err := client.Do(req)
 75	if err != nil {
 76		return nil, fmt.Errorf("fetching models: %w", err)
 77	}
 78	defer func() { _ = resp.Body.Close() }()
 79
 80	body, err := io.ReadAll(resp.Body)
 81	if err != nil {
 82		return nil, fmt.Errorf("reading models response: %w", err)
 83	}
 84
 85	if resp.StatusCode != 200 {
 86		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
 87	}
 88
 89	_ = os.MkdirAll("tmp", 0o700)
 90	_ = os.WriteFile("tmp/neuralwatt-response.json", body, 0o600)
 91
 92	var mr ModelsResponse
 93	if err := json.Unmarshal(body, &mr); err != nil {
 94		return nil, fmt.Errorf("decoding models response: %w", err)
 95	}
 96
 97	return &mr, nil
 98}
 99
// fallbackDisplayName derives a human-readable name from a model ID: anything
// up to and including the first "/" is dropped, and every "-" in the remainder
// is replaced with a space.
func fallbackDisplayName(id string) string {
	// Only the first slash counts; later slashes stay part of the name.
	if _, rest, found := strings.Cut(id, "/"); found {
		id = rest
	}
	return strings.ReplaceAll(id, "-", " ")
}
107
108func main() {
109	neuralwattProvider := catwalk.Provider{
110		Name:                "Neuralwatt",
111		ID:                  "neuralwatt",
112		APIKey:              "$NEURALWATT_API_KEY",
113		APIEndpoint:         "https://api.neuralwatt.com/v1",
114		Type:                catwalk.TypeOpenAICompat,
115		DefaultLargeModelID: "zai-org/GLM-5.1-FP8",
116		DefaultSmallModelID: "mistralai/Devstral-Small-2-24B-Instruct-2512",
117	}
118
119	modelsResp, err := fetchNeuralwattModels(neuralwattProvider.APIEndpoint)
120	if err != nil {
121		log.Fatal("Error fetching Neuralwatt models:", err)
122	}
123
124	for _, model := range modelsResp.Data {
125		meta := model.Metadata
126
127		if meta.Deprecated {
128			fmt.Printf("Skipping deprecated model %s\n", model.ID)
129			continue
130		}
131
132		// Skip models with small context windows
133		if model.MaxModelLen < 20000 {
134			fmt.Printf("Skipping model %s: context %d < 20000\n",
135				model.ID, model.MaxModelLen)
136			continue
137		}
138
139		if !meta.Capabilities.Tools {
140			fmt.Printf("Skipping model %s (no tool support)\n", model.ID)
141			continue
142		}
143
144		costIn := ptrDeref(meta.Pricing.InputPerMillion, 0)
145		costOut := ptrDeref(meta.Pricing.OutputPerMillion, 0)
146		// Null cached pricing means same as non-cached
147		costInCached := ptrDeref(meta.Pricing.CachedInputPerMillion, costIn)
148		costOutCached := ptrDeref(meta.Pricing.CachedOutputPerMillion, costOut)
149
150		var defaultMaxTokens int64
151		if meta.Limits.MaxOutputTokens != nil {
152			defaultMaxTokens = *meta.Limits.MaxOutputTokens
153		} else {
154			defaultMaxTokens = model.MaxModelLen / 10
155		}
156
157		var reasoningLevels []string
158		var defaultReasoning string
159		if meta.Capabilities.ReasoningEffort {
160			reasoningLevels = []string{"low", "medium", "high"}
161			defaultReasoning = "medium"
162		}
163
164		name := meta.DisplayName
165		if name == "" {
166			name = fallbackDisplayName(model.ID)
167		}
168
169		m := catwalk.Model{
170			ID:                     model.ID,
171			Name:                   name,
172			CostPer1MIn:            roundCost(costIn),
173			CostPer1MOut:           roundCost(costOut),
174			CostPer1MInCached:      roundCost(costInCached),
175			CostPer1MOutCached:     roundCost(costOutCached),
176			ContextWindow:          model.MaxModelLen,
177			DefaultMaxTokens:       defaultMaxTokens,
178			CanReason:              meta.Capabilities.Reasoning,
179			DefaultReasoningEffort: defaultReasoning,
180			ReasoningLevels:        reasoningLevels,
181			SupportsImages:         meta.Capabilities.Vision,
182		}
183
184		neuralwattProvider.Models = append(neuralwattProvider.Models, m)
185		fmt.Printf("Added model %s with context window %d\n", model.ID, model.MaxModelLen)
186	}
187
188	slices.SortFunc(neuralwattProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
189		return strings.Compare(a.Name, b.Name)
190	})
191
192	data, err := json.MarshalIndent(neuralwattProvider, "", "  ")
193	if err != nil {
194		log.Fatal("Error marshaling Neuralwatt provider:", err)
195	}
196	data = append(data, '\n')
197
198	if err := os.WriteFile("internal/providers/configs/neuralwatt.json", data, 0o600); err != nil {
199		log.Fatal("Error writing Neuralwatt provider config:", err)
200	}
201
202	fmt.Printf("Generated neuralwatt.json with %d models\n", len(neuralwattProvider.Models))
203}