1// Package main provides a command-line tool to fetch models from xAI
2// and generate a configuration file for the provider.
3package main
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log"
11 "math"
12 "net/http"
13 "os"
14 "slices"
15 "strings"
16 "time"
17
18 "charm.land/catwalk/pkg/catwalk"
19)
20
21type ModelsResponse struct {
22 Models []XAIModel `json:"models"`
23}
24
// XAIModel is one entry of the "models" array returned by xAI's
// language-models endpoint. Only the fields this generator consumes are
// declared; encoding/json ignores any other keys in the payload.
type XAIModel struct {
	// ID is the model identifier reported by the API (e.g. "grok-3").
	ID string `json:"id"`

	// Aliases are alternative identifiers; shortestAlias may publish one of
	// these instead of ID when it is strictly shorter.
	Aliases []string `json:"aliases"`

	// InputModalities lists accepted input kinds; main looks for "image".
	InputModalities []string `json:"input_modalities"`

	// OutputModalities lists produced output kinds (not read by this tool).
	OutputModalities []string `json:"output_modalities"`

	// Raw token prices. NOTE(review): priceToDollarsPerMillion treats these
	// as US cents per 100M tokens — confirm against xAI's API reference.
	PromptTextTokenPrice     int64 `json:"prompt_text_token_price"`
	CompletionTextTokenPrice int64 `json:"completion_text_token_price"`
	CachedPromptTextTokenPrc int64 `json:"cached_prompt_text_token_price"`
}
34
35func shortestAlias(model XAIModel) string {
36 if len(model.Aliases) == 0 {
37 return model.ID
38 }
39 shortest := model.Aliases[0]
40 for _, a := range model.Aliases[1:] {
41 if len(a) < len(shortest) {
42 shortest = a
43 }
44 }
45 if len(shortest) < len(model.ID) {
46 return shortest
47 }
48 return model.ID
49}
50
// prettyNames maps published model identifiers (the IDs written to the
// generated config) to human-readable display names. Identifiers missing
// from this table are displayed verbatim by prettyName.
var prettyNames = map[string]string{
	// Grok 3 family.
	"grok-3":      "Grok 3",
	"grok-3-mini": "Grok 3 Mini",

	// Grok 4 / 4.1 family.
	"grok-4":                      "Grok 4",
	"grok-4-fast":                 "Grok 4 Fast",
	"grok-4-fast-non-reasoning":   "Grok 4 Fast Non-Reasoning",
	"grok-4-1-fast":               "Grok 4.1 Fast",
	"grok-4-1-fast-non-reasoning": "Grok 4.1 Fast Non-Reasoning",

	// Grok 4.20 family.
	"grok-4.20":               "Grok 4.20",
	"grok-4.20-non-reasoning": "Grok 4.20 Non-Reasoning",
	"grok-4.20-multi-agent":   "Grok 4.20 Multi-Agent",

	// Coding models.
	"grok-code-fast": "Grok Code Fast",
}
64
65func prettyName(id string) string {
66 if name, ok := prettyNames[id]; ok {
67 return name
68 }
69 return id
70}
71
// contextWindow returns the context-window size (in tokens) advertised for
// modelID. Any ID containing the substring "grok-4" gets the larger window;
// every other model falls back to the default.
func contextWindow(modelID string) int64 {
	const (
		grok4Window   int64 = 200_000
		defaultWindow int64 = 131_072
	)
	switch {
	case strings.Contains(modelID, "grok-4"):
		return grok4Window
	default:
		return defaultWindow
	}
}
78
// roundCost rounds v to five decimal places using math.Round (halves round
// away from zero), keeping generated prices free of float noise.
func roundCost(v float64) float64 {
	const scale = 1e5
	scaled := v * scale
	return math.Round(scaled) / scale
}
82
83func priceToDollarsPerMillion(centsPerHundredMillion int64) float64 {
84 return roundCost(float64(centsPerHundredMillion) / 10_000)
85}
86
87func fetchXAIModels() (*ModelsResponse, error) {
88 apiKey := os.Getenv("XAI_API_KEY")
89 if apiKey == "" {
90 return nil, fmt.Errorf("XAI_API_KEY environment variable is not set")
91 }
92
93 client := &http.Client{Timeout: 30 * time.Second}
94 req, _ := http.NewRequestWithContext(
95 context.Background(),
96 "GET",
97 "https://api.x.ai/v1/language-models",
98 nil,
99 )
100 req.Header.Set("User-Agent", "Crush-Client/1.0")
101 req.Header.Set("Authorization", "Bearer "+apiKey)
102
103 resp, err := client.Do(req)
104 if err != nil {
105 return nil, err //nolint:wrapcheck
106 }
107 defer resp.Body.Close() //nolint:errcheck
108
109 body, err := io.ReadAll(resp.Body)
110 if err != nil {
111 return nil, fmt.Errorf("unable to read response body: %w", err)
112 }
113
114 if resp.StatusCode != http.StatusOK {
115 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
116 }
117
118 _ = os.MkdirAll("tmp", 0o700)
119 _ = os.WriteFile("tmp/xai-response.json", body, 0o600)
120
121 var mr ModelsResponse
122 if err := json.Unmarshal(body, &mr); err != nil {
123 return nil, err //nolint:wrapcheck
124 }
125 return &mr, nil
126}
127
128func main() {
129 modelsResp, err := fetchXAIModels()
130 if err != nil {
131 log.Fatal("Error fetching xAI models:", err)
132 }
133
134 provider := catwalk.Provider{
135 Name: "xAI",
136 ID: catwalk.InferenceProviderXAI,
137 APIKey: "$XAI_API_KEY",
138 APIEndpoint: "https://api.x.ai/v1",
139 Type: catwalk.TypeOpenAICompat,
140 DefaultLargeModelID: "grok-4.20",
141 DefaultSmallModelID: "grok-4-1-fast",
142 }
143
144 for _, model := range modelsResp.Models {
145 if strings.Contains(model.ID, "multi-agent") {
146 continue
147 }
148
149 id := shortestAlias(model)
150 ctxWindow := contextWindow(model.ID)
151 defaultMaxTokens := ctxWindow / 10
152
153 canReason := !strings.Contains(model.ID, "non-reasoning") &&
154 model.ID != "grok-3"
155 supportsImages := slices.Contains(model.InputModalities, "image")
156
157 m := catwalk.Model{
158 ID: id,
159 Name: prettyName(id),
160 CostPer1MIn: priceToDollarsPerMillion(model.PromptTextTokenPrice),
161 CostPer1MOut: priceToDollarsPerMillion(model.CompletionTextTokenPrice),
162 CostPer1MInCached: 0,
163 CostPer1MOutCached: priceToDollarsPerMillion(model.CachedPromptTextTokenPrc),
164 ContextWindow: ctxWindow,
165 DefaultMaxTokens: defaultMaxTokens,
166 CanReason: canReason,
167 SupportsImages: supportsImages,
168 }
169
170 provider.Models = append(provider.Models, m)
171 fmt.Printf("Added model %s (alias: %s)\n", model.ID, id)
172 }
173
174 slices.SortFunc(provider.Models, func(a, b catwalk.Model) int {
175 return strings.Compare(a.ID, b.ID)
176 })
177
178 data, err := json.MarshalIndent(provider, "", " ")
179 if err != nil {
180 log.Fatal("Error marshaling xAI provider:", err)
181 }
182 data = append(data, '\n')
183
184 if err := os.WriteFile("internal/providers/configs/xai.json", data, 0o600); err != nil {
185 log.Fatal("Error writing xAI provider config:", err)
186 }
187
188 fmt.Printf("Generated xai.json with %d models\n", len(provider.Models))
189}