main.go

  1// Package main provides a command-line tool to fetch models from Venice
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"math"
 12	"net/http"
 13	"os"
 14	"slices"
 15	"strings"
 16	"time"
 17
 18	"charm.land/catwalk/pkg/catwalk"
 19)
 20
 21type ModelsResponse struct {
 22	Data []VeniceModel `json:"data"`
 23}
 24
 25type VeniceModel struct {
 26	Created   int64           `json:"created"`
 27	ID        string          `json:"id"`
 28	ModelSpec VeniceModelSpec `json:"model_spec"`
 29	Object    string          `json:"object"`
 30	OwnedBy   string          `json:"owned_by"`
 31	Type      string          `json:"type"`
 32}
 33
 34type VeniceModelSpec struct {
 35	AvailableContextTokens int64                   `json:"availableContextTokens"`
 36	Capabilities           VeniceModelCapabilities `json:"capabilities"`
 37	Constraints            VeniceModelConstraints  `json:"constraints"`
 38	Name                   string                  `json:"name"`
 39	ModelSource            string                  `json:"modelSource"`
 40	Offline                bool                    `json:"offline"`
 41	Pricing                VeniceModelPricing      `json:"pricing"`
 42	Traits                 []string                `json:"traits"`
 43	Beta                   bool                    `json:"beta"`
 44}
 45
 46type VeniceModelCapabilities struct {
 47	OptimizedForCode        bool   `json:"optimizedForCode"`
 48	Quantization            string `json:"quantization"`
 49	SupportsFunctionCalling bool   `json:"supportsFunctionCalling"`
 50	SupportsReasoning       bool   `json:"supportsReasoning"`
 51	SupportsResponseSchema  bool   `json:"supportsResponseSchema"`
 52	SupportsVision          bool   `json:"supportsVision"`
 53	SupportsWebSearch       bool   `json:"supportsWebSearch"`
 54	SupportsLogProbs        bool   `json:"supportsLogProbs"`
 55}
 56
 57type VeniceModelConstraints struct {
 58	Temperature *VeniceDefaultFloat `json:"temperature"`
 59	TopP        *VeniceDefaultFloat `json:"top_p"`
 60}
 61
 62type VeniceDefaultFloat struct {
 63	Default float64 `json:"default"`
 64}
 65
 66type VeniceModelPricing struct {
 67	Input  VeniceModelPricingValue `json:"input"`
 68	Output VeniceModelPricingValue `json:"output"`
 69}
 70
 71type VeniceModelPricingValue struct {
 72	USD  float64 `json:"usd"`
 73	Diem float64 `json:"diem"`
 74}
 75
 76func fetchVeniceModels(apiEndpoint string) (*ModelsResponse, error) {
 77	client := &http.Client{Timeout: 30 * time.Second}
 78	url := strings.TrimRight(apiEndpoint, "/") + "/models"
 79	req, _ := http.NewRequestWithContext(context.Background(), "GET", url, nil)
 80	req.Header.Set("User-Agent", "Crush-Client/1.0")
 81
 82	if apiKey := strings.TrimSpace(os.Getenv("VENICE_API_KEY")); apiKey != "" && !strings.HasPrefix(apiKey, "$") {
 83		req.Header.Set("Authorization", "Bearer "+apiKey)
 84	}
 85
 86	resp, err := client.Do(req)
 87	if err != nil {
 88		return nil, err //nolint:wrapcheck
 89	}
 90	defer resp.Body.Close() //nolint:errcheck
 91	if resp.StatusCode != 200 {
 92		body, _ := io.ReadAll(resp.Body)
 93		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
 94	}
 95
 96	var mr ModelsResponse
 97	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
 98		return nil, err //nolint:wrapcheck
 99	}
100	return &mr, nil
101}
102
// minInt64 returns the smaller of a and b.
func minInt64(a, b int64) int64 {
	return min(a, b)
}
109
// maxInt64 returns the larger of a and b.
func maxInt64(a, b int64) int64 {
	return max(a, b)
}
116
117func bestLargeModelID(models []catwalk.Model) string {
118	var best *catwalk.Model
119	for i := range models {
120		m := &models[i]
121
122		if best == nil {
123			best = m
124			continue
125		}
126		mCost := m.CostPer1MIn + m.CostPer1MOut
127		bestCost := best.CostPer1MIn + best.CostPer1MOut
128		if mCost > bestCost {
129			best = m
130			continue
131		}
132		if mCost == bestCost && m.ContextWindow > best.ContextWindow {
133			best = m
134		}
135	}
136	if best == nil {
137		return ""
138	}
139	return best.ID
140}
141
142func bestSmallModelID(models []catwalk.Model) string {
143	var best *catwalk.Model
144	for i := range models {
145		m := &models[i]
146		if best == nil {
147			best = m
148			continue
149		}
150		mCost := m.CostPer1MIn + m.CostPer1MOut
151		bestCost := best.CostPer1MIn + best.CostPer1MOut
152		if mCost < bestCost {
153			best = m
154			continue
155		}
156		if mCost == bestCost && m.ContextWindow < best.ContextWindow {
157			best = m
158		}
159	}
160	if best == nil {
161		return ""
162	}
163	return best.ID
164}
165
166func main() {
167	veniceProvider := catwalk.Provider{
168		Name:        "Venice AI",
169		ID:          catwalk.InferenceProviderVenice,
170		APIKey:      "$VENICE_API_KEY",
171		APIEndpoint: "https://api.venice.ai/api/v1",
172		Type:        catwalk.TypeOpenAICompat,
173		Models:      []catwalk.Model{},
174	}
175
176	codeOptimizedModels := []catwalk.Model{}
177
178	modelsResp, err := fetchVeniceModels(veniceProvider.APIEndpoint)
179	if err != nil {
180		log.Fatal("Error fetching Venice models:", err)
181	}
182
183	for _, model := range modelsResp.Data {
184		if strings.ToLower(model.Type) != "text" {
185			continue
186		}
187		if model.ModelSpec.Offline {
188			continue
189		}
190		if !model.ModelSpec.Capabilities.SupportsFunctionCalling {
191			continue
192		}
193
194		if model.ModelSpec.Beta {
195			continue
196		}
197
198		contextWindow := model.ModelSpec.AvailableContextTokens
199		if contextWindow <= 0 {
200			continue
201		}
202
203		defaultMaxTokens := minInt64(contextWindow/4, 32768)
204		defaultMaxTokens = maxInt64(defaultMaxTokens, 2048)
205
206		canReason := model.ModelSpec.Capabilities.SupportsReasoning
207		var reasoningLevels []string
208		var defaultReasoning string
209		if canReason {
210			reasoningLevels = []string{"low", "medium", "high"}
211			defaultReasoning = "medium"
212		}
213
214		options := catwalk.ModelOptions{}
215		if model.ModelSpec.Constraints.Temperature != nil {
216			v := model.ModelSpec.Constraints.Temperature.Default
217			if !math.IsNaN(v) {
218				options.Temperature = &v
219			}
220		}
221		if model.ModelSpec.Constraints.TopP != nil {
222			v := model.ModelSpec.Constraints.TopP.Default
223			if !math.IsNaN(v) {
224				options.TopP = &v
225			}
226		}
227
228		roundCost := func(v float64) float64 { return math.Round(v*1e5) / 1e5 }
229		m := catwalk.Model{
230			ID:                     model.ID,
231			Name:                   model.ModelSpec.Name,
232			CostPer1MIn:            roundCost(model.ModelSpec.Pricing.Input.USD),
233			CostPer1MOut:           roundCost(model.ModelSpec.Pricing.Output.USD),
234			CostPer1MInCached:      0,
235			CostPer1MOutCached:     0,
236			ContextWindow:          contextWindow,
237			DefaultMaxTokens:       defaultMaxTokens,
238			CanReason:              canReason,
239			ReasoningLevels:        reasoningLevels,
240			DefaultReasoningEffort: defaultReasoning,
241			SupportsImages:         model.ModelSpec.Capabilities.SupportsVision,
242			Options:                options,
243		}
244
245		veniceProvider.Models = append(veniceProvider.Models, m)
246		if model.ModelSpec.Capabilities.OptimizedForCode {
247			codeOptimizedModels = append(codeOptimizedModels, m)
248		}
249	}
250
251	candidateModels := veniceProvider.Models
252	if len(codeOptimizedModels) > 0 {
253		candidateModels = codeOptimizedModels
254	}
255
256	veniceProvider.DefaultLargeModelID = bestLargeModelID(candidateModels)
257	veniceProvider.DefaultSmallModelID = bestSmallModelID(candidateModels)
258
259	slices.SortFunc(veniceProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
260		return strings.Compare(a.Name, b.Name)
261	})
262
263	data, err := json.MarshalIndent(veniceProvider, "", "  ")
264	if err != nil {
265		log.Fatal("Error marshaling Venice provider:", err)
266	}
267
268	if err := os.WriteFile("internal/providers/configs/venice.json", data, 0o600); err != nil {
269		log.Fatal("Error writing Venice provider config:", err)
270	}
271
272	fmt.Printf("Generated venice.json with %d models\n", len(veniceProvider.Models))
273}