main.go

  1// Package main provides a command-line tool to fetch models from xAI
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"math"
 12	"net/http"
 13	"os"
 14	"slices"
 15	"strings"
 16	"time"
 17
 18	"charm.land/catwalk/pkg/catwalk"
 19)
 20
// ModelsResponse is the top-level JSON payload returned by the xAI
// language-models endpoint (https://api.x.ai/v1/language-models); it is the
// type fetchXAIModels decodes the response body into.
type ModelsResponse struct {
	Models []XAIModel `json:"models"`
}
 24
// XAIModel is one entry of ModelsResponse.Models. Only the fields declared
// here are decoded; any other keys in the payload are ignored by
// encoding/json.
//
// ID and Aliases are read by shortestAlias to choose the identifier that is
// published in the generated config, and InputModalities is checked for
// "image" to decide image support. The three price fields are integers that
// priceToDollarsPerMillion treats as cents per 100 million tokens
// (NOTE(review): unit inferred from that helper's parameter name — confirm
// against the xAI API docs). CachedPromptTextTokenPrc holds the cached
// prompt-token price; its Go name is abbreviated, the JSON key is
// cached_prompt_text_token_price.
type XAIModel struct {
	ID                       string   `json:"id"`
	Aliases                  []string `json:"aliases"`
	InputModalities          []string `json:"input_modalities"`
	OutputModalities         []string `json:"output_modalities"`
	PromptTextTokenPrice     int64    `json:"prompt_text_token_price"`
	CompletionTextTokenPrice int64    `json:"completion_text_token_price"`
	CachedPromptTextTokenPrc int64    `json:"cached_prompt_text_token_price"`
}
 34
 35func shortestAlias(model XAIModel) string {
 36	if len(model.Aliases) == 0 {
 37		return model.ID
 38	}
 39	shortest := model.Aliases[0]
 40	for _, a := range model.Aliases[1:] {
 41		if len(a) < len(shortest) {
 42			shortest = a
 43		}
 44	}
 45	if len(shortest) < len(model.ID) {
 46		return shortest
 47	}
 48	return model.ID
 49}
 50
 51var prettyNames = map[string]string{
 52	"grok-3":                      "Grok 3",
 53	"grok-3-mini":                 "Grok 3 Mini",
 54	"grok-4":                      "Grok 4",
 55	"grok-4-fast":                 "Grok 4 Fast",
 56	"grok-4-fast-non-reasoning":   "Grok 4 Fast Non-Reasoning",
 57	"grok-4-1-fast":               "Grok 4.1 Fast",
 58	"grok-4-1-fast-non-reasoning": "Grok 4.1 Fast Non-Reasoning",
 59	"grok-4.20":                   "Grok 4.20",
 60	"grok-4.20-non-reasoning":     "Grok 4.20 Non-Reasoning",
 61	"grok-4.20-multi-agent":       "Grok 4.20 Multi-Agent",
 62	"grok-code-fast":              "Grok Code Fast",
 63}
 64
 65func prettyName(id string) string {
 66	if name, ok := prettyNames[id]; ok {
 67		return name
 68	}
 69	return id
 70}
 71
 72func contextWindow(modelID string) int64 {
 73	if strings.Contains(modelID, "grok-4") {
 74		return 200_000
 75	}
 76	return 131_072
 77}
 78
 79func roundCost(v float64) float64 {
 80	return math.Round(v*1e5) / 1e5
 81}
 82
 83func priceToDollarsPerMillion(centsPerHundredMillion int64) float64 {
 84	return roundCost(float64(centsPerHundredMillion) / 10_000)
 85}
 86
 87func fetchXAIModels() (*ModelsResponse, error) {
 88	apiKey := os.Getenv("XAI_API_KEY")
 89	if apiKey == "" {
 90		return nil, fmt.Errorf("XAI_API_KEY environment variable is not set")
 91	}
 92
 93	client := &http.Client{Timeout: 30 * time.Second}
 94	req, _ := http.NewRequestWithContext(
 95		context.Background(),
 96		"GET",
 97		"https://api.x.ai/v1/language-models",
 98		nil,
 99	)
100	req.Header.Set("User-Agent", "Crush-Client/1.0")
101	req.Header.Set("Authorization", "Bearer "+apiKey)
102
103	resp, err := client.Do(req)
104	if err != nil {
105		return nil, err //nolint:wrapcheck
106	}
107	defer resp.Body.Close() //nolint:errcheck
108
109	body, err := io.ReadAll(resp.Body)
110	if err != nil {
111		return nil, fmt.Errorf("unable to read response body: %w", err)
112	}
113
114	if resp.StatusCode != http.StatusOK {
115		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
116	}
117
118	_ = os.MkdirAll("tmp", 0o700)
119	_ = os.WriteFile("tmp/xai-response.json", body, 0o600)
120
121	var mr ModelsResponse
122	if err := json.Unmarshal(body, &mr); err != nil {
123		return nil, err //nolint:wrapcheck
124	}
125	return &mr, nil
126}
127
128func main() {
129	modelsResp, err := fetchXAIModels()
130	if err != nil {
131		log.Fatal("Error fetching xAI models:", err)
132	}
133
134	provider := catwalk.Provider{
135		Name:                "xAI",
136		ID:                  catwalk.InferenceProviderXAI,
137		APIKey:              "$XAI_API_KEY",
138		APIEndpoint:         "https://api.x.ai/v1",
139		Type:                catwalk.TypeOpenAICompat,
140		DefaultLargeModelID: "grok-4.20",
141		DefaultSmallModelID: "grok-4-1-fast",
142	}
143
144	for _, model := range modelsResp.Models {
145		if strings.Contains(model.ID, "multi-agent") {
146			continue
147		}
148
149		id := shortestAlias(model)
150		ctxWindow := contextWindow(model.ID)
151		defaultMaxTokens := ctxWindow / 10
152
153		canReason := !strings.Contains(model.ID, "non-reasoning") &&
154			model.ID != "grok-3"
155		supportsImages := slices.Contains(model.InputModalities, "image")
156
157		m := catwalk.Model{
158			ID:                 id,
159			Name:               prettyName(id),
160			CostPer1MIn:        priceToDollarsPerMillion(model.PromptTextTokenPrice),
161			CostPer1MOut:       priceToDollarsPerMillion(model.CompletionTextTokenPrice),
162			CostPer1MInCached:  0,
163			CostPer1MOutCached: priceToDollarsPerMillion(model.CachedPromptTextTokenPrc),
164			ContextWindow:      ctxWindow,
165			DefaultMaxTokens:   defaultMaxTokens,
166			CanReason:          canReason,
167			SupportsImages:     supportsImages,
168		}
169
170		provider.Models = append(provider.Models, m)
171		fmt.Printf("Added model %s (alias: %s)\n", model.ID, id)
172	}
173
174	slices.SortFunc(provider.Models, func(a, b catwalk.Model) int {
175		return strings.Compare(a.ID, b.ID)
176	})
177
178	data, err := json.MarshalIndent(provider, "", "  ")
179	if err != nil {
180		log.Fatal("Error marshaling xAI provider:", err)
181	}
182	data = append(data, '\n')
183
184	if err := os.WriteFile("internal/providers/configs/xai.json", data, 0o600); err != nil {
185		log.Fatal("Error writing xAI provider config:", err)
186	}
187
188	fmt.Printf("Generated xai.json with %d models\n", len(provider.Models))
189}