1// Package main provides a command-line tool to fetch models from Venice
2// and generate a configuration file for the provider.
3package main
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log"
11 "math"
12 "net/http"
13 "os"
14 "slices"
15 "strings"
16 "time"
17
18 "charm.land/catwalk/pkg/catwalk"
19)
20
// ModelsResponse is the top-level payload returned by the Venice
// /models endpoint.
type ModelsResponse struct {
	Data []VeniceModel `json:"data"`
}
24
// VeniceModel is a single model entry from the Venice /models endpoint.
type VeniceModel struct {
	Created   int64           `json:"created"` // creation timestamp (presumably Unix seconds — confirm against API docs)
	ID        string          `json:"id"`
	ModelSpec VeniceModelSpec `json:"model_spec"`
	Object    string          `json:"object"`
	OwnedBy   string          `json:"owned_by"`
	Type      string          `json:"type"` // e.g. "text"; only text models are kept by this tool
}
33
// VeniceModelSpec describes a model's context size, capabilities,
// sampling constraints, and pricing as reported by the Venice API.
type VeniceModelSpec struct {
	AvailableContextTokens int64                   `json:"availableContextTokens"` // context window size in tokens
	Capabilities           VeniceModelCapabilities `json:"capabilities"`
	Constraints            VeniceModelConstraints  `json:"constraints"`
	Name                   string                  `json:"name"` // human-readable display name
	ModelSource            string                  `json:"modelSource"`
	Offline                bool                    `json:"offline"` // offline models are skipped by this tool
	Pricing                VeniceModelPricing      `json:"pricing"`
	Traits                 []string                `json:"traits"`
	Beta                   bool                    `json:"beta"` // beta models are skipped by this tool
}
45
// VeniceModelCapabilities is the set of feature flags Venice reports
// for a model. This tool requires SupportsFunctionCalling and maps
// SupportsReasoning and SupportsVision onto the generated config.
type VeniceModelCapabilities struct {
	OptimizedForCode        bool   `json:"optimizedForCode"`
	Quantization            string `json:"quantization"`
	SupportsFunctionCalling bool   `json:"supportsFunctionCalling"`
	SupportsReasoning       bool   `json:"supportsReasoning"`
	SupportsResponseSchema  bool   `json:"supportsResponseSchema"`
	SupportsVision          bool   `json:"supportsVision"`
	SupportsWebSearch       bool   `json:"supportsWebSearch"`
	SupportsLogProbs        bool   `json:"supportsLogProbs"`
}
56
// VeniceModelConstraints carries the sampling-parameter defaults a
// model advertises. Either pointer may be nil when the API omits the
// corresponding constraint.
type VeniceModelConstraints struct {
	Temperature *VeniceDefaultFloat `json:"temperature"`
	TopP        *VeniceDefaultFloat `json:"top_p"`
}
61
// VeniceDefaultFloat wraps a single default value for a sampling
// parameter (temperature or top_p).
type VeniceDefaultFloat struct {
	Default float64 `json:"default"`
}
65
// VeniceModelPricing holds per-direction token pricing. The USD values
// are treated as cost per 1M tokens when building the catwalk config.
type VeniceModelPricing struct {
	Input  VeniceModelPricingValue `json:"input"`  // cost for prompt/input tokens
	Output VeniceModelPricingValue `json:"output"` // cost for completion/output tokens
}
70
// VeniceModelPricingValue is a price quoted in two currencies: USD and
// Diem (presumably Venice's own token currency — confirm with API docs;
// only USD is used by this tool).
type VeniceModelPricingValue struct {
	USD  float64 `json:"usd"`
	Diem float64 `json:"diem"`
}
75
76func fetchVeniceModels(apiEndpoint string) (*ModelsResponse, error) {
77 client := &http.Client{Timeout: 30 * time.Second}
78 url := strings.TrimRight(apiEndpoint, "/") + "/models"
79 req, _ := http.NewRequestWithContext(context.Background(), "GET", url, nil)
80 req.Header.Set("User-Agent", "Crush-Client/1.0")
81
82 if apiKey := strings.TrimSpace(os.Getenv("VENICE_API_KEY")); apiKey != "" && !strings.HasPrefix(apiKey, "$") {
83 req.Header.Set("Authorization", "Bearer "+apiKey)
84 }
85
86 resp, err := client.Do(req)
87 if err != nil {
88 return nil, err //nolint:wrapcheck
89 }
90 defer resp.Body.Close() //nolint:errcheck
91 if resp.StatusCode != 200 {
92 body, _ := io.ReadAll(resp.Body)
93 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
94 }
95
96 var mr ModelsResponse
97 if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
98 return nil, err //nolint:wrapcheck
99 }
100 return &mr, nil
101}
102
// minInt64 returns the smaller of a and b.
//
// Kept for existing callers; delegates to the built-in min (Go 1.21+,
// which this file already requires via the slices package).
func minInt64(a, b int64) int64 {
	return min(a, b)
}
109
// maxInt64 returns the larger of a and b.
//
// Kept for existing callers; delegates to the built-in max (Go 1.21+,
// which this file already requires via the slices package).
func maxInt64(a, b int64) int64 {
	return max(a, b)
}
116
117func bestLargeModelID(models []catwalk.Model) string {
118 var best *catwalk.Model
119 for i := range models {
120 m := &models[i]
121
122 if best == nil {
123 best = m
124 continue
125 }
126 mCost := m.CostPer1MIn + m.CostPer1MOut
127 bestCost := best.CostPer1MIn + best.CostPer1MOut
128 if mCost > bestCost {
129 best = m
130 continue
131 }
132 if mCost == bestCost && m.ContextWindow > best.ContextWindow {
133 best = m
134 }
135 }
136 if best == nil {
137 return ""
138 }
139 return best.ID
140}
141
142func bestSmallModelID(models []catwalk.Model) string {
143 var best *catwalk.Model
144 for i := range models {
145 m := &models[i]
146 if best == nil {
147 best = m
148 continue
149 }
150 mCost := m.CostPer1MIn + m.CostPer1MOut
151 bestCost := best.CostPer1MIn + best.CostPer1MOut
152 if mCost < bestCost {
153 best = m
154 continue
155 }
156 if mCost == bestCost && m.ContextWindow < best.ContextWindow {
157 best = m
158 }
159 }
160 if best == nil {
161 return ""
162 }
163 return best.ID
164}
165
166func main() {
167 veniceProvider := catwalk.Provider{
168 Name: "Venice AI",
169 ID: catwalk.InferenceProviderVenice,
170 APIKey: "$VENICE_API_KEY",
171 APIEndpoint: "https://api.venice.ai/api/v1",
172 Type: catwalk.TypeOpenAICompat,
173 Models: []catwalk.Model{},
174 }
175
176 codeOptimizedModels := []catwalk.Model{}
177
178 modelsResp, err := fetchVeniceModels(veniceProvider.APIEndpoint)
179 if err != nil {
180 log.Fatal("Error fetching Venice models:", err)
181 }
182
183 for _, model := range modelsResp.Data {
184 if strings.ToLower(model.Type) != "text" {
185 continue
186 }
187 if model.ModelSpec.Offline {
188 continue
189 }
190 if !model.ModelSpec.Capabilities.SupportsFunctionCalling {
191 continue
192 }
193
194 if model.ModelSpec.Beta {
195 continue
196 }
197
198 contextWindow := model.ModelSpec.AvailableContextTokens
199 if contextWindow <= 0 {
200 continue
201 }
202
203 defaultMaxTokens := minInt64(contextWindow/4, 32768)
204 defaultMaxTokens = maxInt64(defaultMaxTokens, 2048)
205
206 canReason := model.ModelSpec.Capabilities.SupportsReasoning
207 var reasoningLevels []string
208 var defaultReasoning string
209 if canReason {
210 reasoningLevels = []string{"low", "medium", "high"}
211 defaultReasoning = "medium"
212 }
213
214 options := catwalk.ModelOptions{}
215 if model.ModelSpec.Constraints.Temperature != nil {
216 v := model.ModelSpec.Constraints.Temperature.Default
217 if !math.IsNaN(v) {
218 options.Temperature = &v
219 }
220 }
221 if model.ModelSpec.Constraints.TopP != nil {
222 v := model.ModelSpec.Constraints.TopP.Default
223 if !math.IsNaN(v) {
224 options.TopP = &v
225 }
226 }
227
228 m := catwalk.Model{
229 ID: model.ID,
230 Name: model.ModelSpec.Name,
231 CostPer1MIn: model.ModelSpec.Pricing.Input.USD,
232 CostPer1MOut: model.ModelSpec.Pricing.Output.USD,
233 CostPer1MInCached: 0,
234 CostPer1MOutCached: 0,
235 ContextWindow: contextWindow,
236 DefaultMaxTokens: defaultMaxTokens,
237 CanReason: canReason,
238 ReasoningLevels: reasoningLevels,
239 DefaultReasoningEffort: defaultReasoning,
240 SupportsImages: model.ModelSpec.Capabilities.SupportsVision,
241 Options: options,
242 }
243
244 veniceProvider.Models = append(veniceProvider.Models, m)
245 if model.ModelSpec.Capabilities.OptimizedForCode {
246 codeOptimizedModels = append(codeOptimizedModels, m)
247 }
248 }
249
250 candidateModels := veniceProvider.Models
251 if len(codeOptimizedModels) > 0 {
252 candidateModels = codeOptimizedModels
253 }
254
255 veniceProvider.DefaultLargeModelID = bestLargeModelID(candidateModels)
256 veniceProvider.DefaultSmallModelID = bestSmallModelID(candidateModels)
257
258 slices.SortFunc(veniceProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
259 return strings.Compare(a.Name, b.Name)
260 })
261
262 data, err := json.MarshalIndent(veniceProvider, "", " ")
263 if err != nil {
264 log.Fatal("Error marshaling Venice provider:", err)
265 }
266
267 if err := os.WriteFile("internal/providers/configs/venice.json", data, 0o600); err != nil {
268 log.Fatal("Error writing Venice provider config:", err)
269 }
270
271 fmt.Printf("Generated venice.json with %d models\n", len(veniceProvider.Models))
272}