// main.go

  1// Package main provides a command-line tool to fetch models from OpenRouter
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"net/http"
 12	"os"
 13	"slices"
 14	"strconv"
 15	"strings"
 16	"time"
 17
 18	"charm.land/catwalk/pkg/catwalk"
 19)
 20
// Model represents a single model entry as returned by the OpenRouter
// /api/v1/models endpoint.
type Model struct {
	ID              string       `json:"id"`             // OpenRouter model identifier, e.g. "anthropic/claude-sonnet-4"
	CanonicalSlug   string       `json:"canonical_slug"`
	HuggingFaceID   string       `json:"hugging_face_id"`
	Name            string       `json:"name"`           // human-readable display name
	Created         int64        `json:"created"`        // creation timestamp (Unix seconds, per API convention — not verified here)
	Description     string       `json:"description"`
	ContextLength   int64        `json:"context_length"` // maximum context window in tokens
	Architecture    Architecture `json:"architecture"`
	Pricing         Pricing      `json:"pricing"`
	TopProvider     TopProvider  `json:"top_provider"`
	SupportedParams []string     `json:"supported_parameters"` // e.g. "tools", "reasoning"
}
 35
// Architecture defines the model's architecture details, including which
// input and output modalities (e.g. "text", "image") the model supports.
type Architecture struct {
	Modality         string   `json:"modality"`
	InputModalities  []string `json:"input_modalities"`  // accepted inputs, e.g. "text", "image"
	OutputModalities []string `json:"output_modalities"` // produced outputs, e.g. "text"
	Tokenizer        string   `json:"tokenizer"`
	InstructType     *string  `json:"instruct_type"` // nil when the API reports no instruct type
}
 44
// Pricing contains the pricing information for different operations.
// All values are decimal strings (parsed with strconv.ParseFloat elsewhere
// in this file); per-token costs are later scaled to per-million-token costs.
type Pricing struct {
	Prompt            string `json:"prompt"`     // cost per prompt (input) token
	Completion        string `json:"completion"` // cost per completion (output) token
	Request           string `json:"request"`
	Image             string `json:"image"`
	WebSearch         string `json:"web_search"`
	InternalReasoning string `json:"internal_reasoning"`
	InputCacheRead    string `json:"input_cache_read"`  // cost to read cached input
	InputCacheWrite   string `json:"input_cache_write"` // cost to write input to cache
}
 56
// TopProvider describes the capabilities of the model's top-ranked provider.
type TopProvider struct {
	ContextLength       int64  `json:"context_length"`        // provider-specific context window; 0 when unreported
	MaxCompletionTokens *int64 `json:"max_completion_tokens"` // nil when the provider reports no limit
	IsModerated         bool   `json:"is_moderated"`
}
 63
// Endpoint represents a single provider endpoint configuration for a model,
// as returned by the /api/v1/models/{id}/endpoints API.
type Endpoint struct {
	Name                string   `json:"name"`
	ContextLength       int64    `json:"context_length"`
	Pricing             Pricing  `json:"pricing"`
	ProviderName        string   `json:"provider_name"`
	Tag                 string   `json:"tag"`
	Quantization        *string  `json:"quantization"` // nil when unquantized/unreported
	MaxCompletionTokens *int64   `json:"max_completion_tokens"`
	MaxPromptTokens     *int64   `json:"max_prompt_tokens"`
	SupportedParams     []string `json:"supported_parameters"`
	Status              int      `json:"status"`          // negative values are treated as unhealthy by selectBestEndpoint
	UptimeLast30m       float64  `json:"uptime_last_30m"` // percentage; < 90 is treated as unhealthy
}
 78
// EndpointsResponse is the response envelope for the model endpoints API
// (/api/v1/models/{id}/endpoints).
type EndpointsResponse struct {
	Data struct {
		ID          string     `json:"id"`
		Name        string     `json:"name"`
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"`
	} `json:"data"`
}
 89
// ModelsResponse is the response envelope for the models API (/api/v1/models).
type ModelsResponse struct {
	Data []Model `json:"data"`
}
 94
// ModelPricing is the pricing structure for a model, detailing costs per
// million tokens for input and output, both cached and uncached.
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`         // uncached input cost per 1M tokens
	CostPer1MOut       float64 `json:"cost_per_1m_out"`        // uncached output cost per 1M tokens
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`  // derived from input_cache_write pricing
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` // derived from input_cache_read pricing
}
103
104func getPricing(model Model) ModelPricing {
105	pricing := ModelPricing{}
106	costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
107	if err != nil {
108		costPrompt = 0.0
109	}
110	pricing.CostPer1MIn = costPrompt * 1_000_000
111	costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
112	if err != nil {
113		costCompletion = 0.0
114	}
115	pricing.CostPer1MOut = costCompletion * 1_000_000
116
117	costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
118	if err != nil {
119		costPromptCached = 0.0
120	}
121	pricing.CostPer1MInCached = costPromptCached * 1_000_000
122	costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
123	if err != nil {
124		costCompletionCached = 0.0
125	}
126	pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
127	return pricing
128}
129
130func fetchOpenRouterModels() (*ModelsResponse, error) {
131	client := &http.Client{Timeout: 30 * time.Second}
132	req, _ := http.NewRequestWithContext(
133		context.Background(),
134		"GET",
135		"https://openrouter.ai/api/v1/models",
136		nil,
137	)
138	req.Header.Set("User-Agent", "Crush-Client/1.0")
139
140	resp, err := client.Do(req)
141	if err != nil {
142		return nil, err //nolint:wrapcheck
143	}
144	defer resp.Body.Close() //nolint:errcheck
145
146	body, err := io.ReadAll(resp.Body)
147	if err != nil {
148		return nil, fmt.Errorf("unable to read models response body: %w", err)
149	}
150
151	if resp.StatusCode != http.StatusOK {
152		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
153	}
154
155	// for debugging
156	_ = os.MkdirAll("tmp", 0o700)
157	_ = os.WriteFile("tmp/openrouter-response.json", body, 0o600)
158
159	var mr ModelsResponse
160	if err := json.Unmarshal(body, &mr); err != nil {
161		return nil, err //nolint:wrapcheck
162	}
163	return &mr, nil
164}
165
166func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
167	client := &http.Client{Timeout: 30 * time.Second}
168	url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
169	req, _ := http.NewRequestWithContext(
170		context.Background(),
171		"GET",
172		url,
173		nil,
174	)
175	req.Header.Set("User-Agent", "Crush-Client/1.0")
176	resp, err := client.Do(req)
177	if err != nil {
178		return nil, err //nolint:wrapcheck
179	}
180	defer resp.Body.Close() //nolint:errcheck
181	if resp.StatusCode != 200 {
182		body, _ := io.ReadAll(resp.Body)
183		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
184	}
185	var er EndpointsResponse
186	if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
187		return nil, err //nolint:wrapcheck
188	}
189	return &er, nil
190}
191
192func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
193	if len(endpoints) == 0 {
194		return nil
195	}
196
197	var best *Endpoint
198	for i := range endpoints {
199		endpoint := &endpoints[i]
200		// Skip endpoints with poor status or uptime
201		if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
202			continue
203		}
204
205		if best == nil {
206			best = endpoint
207			continue
208		}
209
210		if isBetterEndpoint(endpoint, best) {
211			best = endpoint
212		}
213	}
214
215	// If no good endpoint found, return the first one as fallback
216	if best == nil {
217		best = &endpoints[0]
218	}
219
220	return best
221}
222
223func isBetterEndpoint(candidate, current *Endpoint) bool {
224	candidateHasTools := slices.Contains(candidate.SupportedParams, "tools")
225	currentHasTools := slices.Contains(current.SupportedParams, "tools")
226
227	// Prefer endpoints with tool support over those without
228	if candidateHasTools && !currentHasTools {
229		return true
230	}
231	if !candidateHasTools && currentHasTools {
232		return false
233	}
234
235	// Both have same tool support status, compare other factors
236	if candidate.ContextLength > current.ContextLength {
237		return true
238	}
239	if candidate.ContextLength == current.ContextLength {
240		return candidate.UptimeLast30m > current.UptimeLast30m
241	}
242
243	return false
244}
245
246// This is used to generate the openrouter.json config file.
247func main() {
248	modelsResp, err := fetchOpenRouterModels()
249	if err != nil {
250		log.Fatal("Error fetching OpenRouter models:", err)
251	}
252
253	openRouterProvider := catwalk.Provider{
254		Name:                "OpenRouter",
255		ID:                  "openrouter",
256		APIKey:              "$OPENROUTER_API_KEY",
257		APIEndpoint:         "https://openrouter.ai/api/v1",
258		Type:                catwalk.TypeOpenRouter,
259		DefaultLargeModelID: "anthropic/claude-sonnet-4",
260		DefaultSmallModelID: "anthropic/claude-3.5-haiku",
261		Models:              []catwalk.Model{},
262		DefaultHeaders: map[string]string{
263			"HTTP-Referer": "https://charm.land",
264			"X-Title":      "Crush",
265		},
266	}
267
268	for _, model := range modelsResp.Data {
269		if model.ContextLength < 20000 {
270			continue
271		}
272		// skip non‐text models or those without tools
273		if !slices.Contains(model.SupportedParams, "tools") ||
274			!slices.Contains(model.Architecture.InputModalities, "text") ||
275			!slices.Contains(model.Architecture.OutputModalities, "text") {
276			continue
277		}
278
279		// Fetch endpoints for this model to get the best configuration
280		endpointsResp, err := fetchModelEndpoints(model.ID)
281		if err != nil {
282			fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
283			// Fall back to using the original model data
284			pricing := getPricing(model)
285			canReason := slices.Contains(model.SupportedParams, "reasoning")
286			supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
287
288			var reasoningLevels []string
289			var defaultReasoning string
290			if canReason {
291				reasoningLevels = []string{"low", "medium", "high"}
292				defaultReasoning = "medium"
293			}
294			m := catwalk.Model{
295				ID:                     model.ID,
296				Name:                   model.Name,
297				CostPer1MIn:            pricing.CostPer1MIn,
298				CostPer1MOut:           pricing.CostPer1MOut,
299				CostPer1MInCached:      pricing.CostPer1MInCached,
300				CostPer1MOutCached:     pricing.CostPer1MOutCached,
301				ContextWindow:          model.ContextLength,
302				CanReason:              canReason,
303				DefaultReasoningEffort: defaultReasoning,
304				ReasoningLevels:        reasoningLevels,
305				SupportsImages:         supportsImages,
306			}
307			if model.TopProvider.MaxCompletionTokens != nil {
308				m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
309			} else {
310				m.DefaultMaxTokens = model.ContextLength / 10
311			}
312			if model.TopProvider.ContextLength > 0 {
313				m.ContextWindow = model.TopProvider.ContextLength
314			}
315			openRouterProvider.Models = append(openRouterProvider.Models, m)
316			continue
317		}
318
319		// Select the best endpoint
320		bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
321		if bestEndpoint == nil {
322			fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
323			continue
324		}
325
326		// Check if the best endpoint supports tools
327		if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
328			continue
329		}
330
331		// Use the best endpoint's configuration
332		pricing := ModelPricing{}
333		costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
334		if err != nil {
335			costPrompt = 0.0
336		}
337		pricing.CostPer1MIn = costPrompt * 1_000_000
338		costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
339		if err != nil {
340			costCompletion = 0.0
341		}
342		pricing.CostPer1MOut = costCompletion * 1_000_000
343
344		costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
345		if err != nil {
346			costPromptCached = 0.0
347		}
348		pricing.CostPer1MInCached = costPromptCached * 1_000_000
349		costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
350		if err != nil {
351			costCompletionCached = 0.0
352		}
353		pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
354
355		canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
356		supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
357
358		var reasoningLevels []string
359		var defaultReasoning string
360		if canReason {
361			reasoningLevels = []string{"low", "medium", "high"}
362			defaultReasoning = "medium"
363		}
364		m := catwalk.Model{
365			ID:                     model.ID,
366			Name:                   model.Name,
367			CostPer1MIn:            pricing.CostPer1MIn,
368			CostPer1MOut:           pricing.CostPer1MOut,
369			CostPer1MInCached:      pricing.CostPer1MInCached,
370			CostPer1MOutCached:     pricing.CostPer1MOutCached,
371			ContextWindow:          bestEndpoint.ContextLength,
372			CanReason:              canReason,
373			DefaultReasoningEffort: defaultReasoning,
374			ReasoningLevels:        reasoningLevels,
375			SupportsImages:         supportsImages,
376		}
377
378		// Set max tokens based on the best endpoint
379		if bestEndpoint.MaxCompletionTokens != nil {
380			m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
381		} else {
382			m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
383		}
384
385		openRouterProvider.Models = append(openRouterProvider.Models, m)
386		fmt.Printf("Added model %s with context window %d from provider %s\n",
387			model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
388	}
389
390	slices.SortFunc(openRouterProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
391		return strings.Compare(a.Name, b.Name)
392	})
393
394	// save the json in internal/providers/config/openrouter.json
395	data, err := json.MarshalIndent(openRouterProvider, "", "  ")
396	if err != nil {
397		log.Fatal("Error marshaling OpenRouter provider:", err)
398	}
399	// write to file
400	if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
401		log.Fatal("Error writing OpenRouter provider config:", err)
402	}
403}