main.go

  1// Package main provides a command-line tool to fetch models from Venice
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"math"
 12	"net/http"
 13	"os"
 14	"slices"
 15	"strings"
 16	"time"
 17
 18	"charm.land/catwalk/pkg/catwalk"
 19)
 20
// ModelsResponse is the top-level payload returned by the Venice
// /models endpoint.
type ModelsResponse struct {
	Data []VeniceModel `json:"data"`
}
 24
// VeniceModel is a single model entry from the Venice /models endpoint.
type VeniceModel struct {
	ID        string          `json:"id"`
	ModelSpec VeniceModel Spec `json:"model_spec"`
	Type      string          `json:"type"` // model category; only "text" models are used downstream
}
 30
// VeniceModelSpec describes a model's limits, capabilities, constraints,
// and pricing as reported by the Venice API.
type VeniceModelSpec struct {
	AvailableContextTokens int64                   `json:"availableContextTokens"` // context window size in tokens
	MaxCompletionTokens    int64                   `json:"maxCompletionTokens"`    // upper bound on generated tokens
	Capabilities           VeniceModelCapabilities `json:"capabilities"`
	Constraints            VeniceModelConstraints  `json:"constraints"`
	Name                   string                  `json:"name"` // human-readable display name
	Offline                bool                    `json:"offline"`
	Pricing                VeniceModelPricing      `json:"pricing"`
	Beta                   bool                    `json:"beta"`
}
 41
// VeniceModelCapabilities enumerates the feature flags Venice reports
// for a model.
type VeniceModelCapabilities struct {
	OptimizedForCode        bool `json:"optimizedForCode"`
	SupportsFunctionCalling bool `json:"supportsFunctionCalling"`
	SupportsReasoning       bool `json:"supportsReasoning"`
	SupportsReasoningEffort bool `json:"supportsReasoningEffort"`
	SupportsVision          bool `json:"supportsVision"`
}
 49
// VeniceModelConstraints carries optional sampling defaults for a model.
// Nil fields mean the API supplied no constraint.
type VeniceModelConstraints struct {
	Temperature *VeniceDefaultFloat `json:"temperature"`
	TopP        *VeniceDefaultFloat `json:"top_p"`
}
 54
// VeniceDefaultFloat wraps a single default value for a sampling
// parameter (e.g. temperature or top_p).
type VeniceDefaultFloat struct {
	Default float64 `json:"default"`
}
 58
// VeniceModelPricing holds the input and output token prices for a model.
type VeniceModelPricing struct {
	Input  VeniceModelPricingValue `json:"input"`
	Output VeniceModelPricingValue `json:"output"`
}
 63
// VeniceModelPricingValue is a price expressed in US dollars.
type VeniceModelPricingValue struct {
	USD float64 `json:"usd"`
}
 67
 68func fetchVeniceModels(apiEndpoint string) (*ModelsResponse, error) {
 69	client := &http.Client{Timeout: 30 * time.Second}
 70	url := strings.TrimRight(apiEndpoint, "/") + "/models"
 71	req, _ := http.NewRequestWithContext(context.Background(), "GET", url, nil)
 72	req.Header.Set("User-Agent", "Crush-Client/1.0")
 73
 74	if apiKey := strings.TrimSpace(os.Getenv("VENICE_API_KEY")); apiKey != "" && !strings.HasPrefix(apiKey, "$") {
 75		req.Header.Set("Authorization", "Bearer "+apiKey)
 76	}
 77
 78	resp, err := client.Do(req)
 79	if err != nil {
 80		return nil, err //nolint:wrapcheck
 81	}
 82	defer resp.Body.Close() //nolint:errcheck
 83	if resp.StatusCode != 200 {
 84		body, _ := io.ReadAll(resp.Body)
 85		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
 86	}
 87
 88	var mr ModelsResponse
 89	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
 90		return nil, err //nolint:wrapcheck
 91	}
 92	return &mr, nil
 93}
 94
 95func bestLargeModelID(models []catwalk.Model) string {
 96	var best *catwalk.Model
 97	for i := range models {
 98		m := &models[i]
 99
100		if best == nil {
101			best = m
102			continue
103		}
104		mCost := m.CostPer1MIn + m.CostPer1MOut
105		bestCost := best.CostPer1MIn + best.CostPer1MOut
106		if mCost > bestCost {
107			best = m
108			continue
109		}
110		if mCost == bestCost && m.ContextWindow > best.ContextWindow {
111			best = m
112		}
113	}
114	if best == nil {
115		return ""
116	}
117	return best.ID
118}
119
120func bestSmallModelID(models []catwalk.Model) string {
121	var best *catwalk.Model
122	for i := range models {
123		m := &models[i]
124		if best == nil {
125			best = m
126			continue
127		}
128		mCost := m.CostPer1MIn + m.CostPer1MOut
129		bestCost := best.CostPer1MIn + best.CostPer1MOut
130		if mCost < bestCost {
131			best = m
132			continue
133		}
134		if mCost == bestCost && m.ContextWindow < best.ContextWindow {
135			best = m
136		}
137	}
138	if best == nil {
139		return ""
140	}
141	return best.ID
142}
143
144func main() {
145	veniceProvider := catwalk.Provider{
146		Name:        "Venice AI",
147		ID:          catwalk.InferenceProviderVenice,
148		APIKey:      "$VENICE_API_KEY",
149		APIEndpoint: "https://api.venice.ai/api/v1",
150		Type:        catwalk.TypeOpenAICompat,
151		Models:      []catwalk.Model{},
152	}
153
154	var codeOptimizedModels []catwalk.Model
155
156	modelsResp, err := fetchVeniceModels(veniceProvider.APIEndpoint)
157	if err != nil {
158		log.Fatal("Error fetching Venice models:", err)
159	}
160
161	for _, model := range modelsResp.Data {
162		if strings.ToLower(model.Type) != "text" {
163			continue
164		}
165		if model.ModelSpec.Offline {
166			continue
167		}
168		if !model.ModelSpec.Capabilities.SupportsFunctionCalling {
169			continue
170		}
171
172		if model.ModelSpec.Beta {
173			continue
174		}
175
176		contextWindow := model.ModelSpec.AvailableContextTokens
177		if contextWindow <= 0 {
178			continue
179		}
180
181		var (
182			canReason            = model.ModelSpec.Capabilities.SupportsReasoning
183			supportsReasonEffort = model.ModelSpec.Capabilities.SupportsReasoningEffort
184		)
185		var reasoningLevels []string
186		var defaultReasoning string
187		if canReason && supportsReasonEffort {
188			reasoningLevels = []string{"low", "medium", "high"}
189			defaultReasoning = "medium"
190		}
191
192		options := catwalk.ModelOptions{}
193		if model.ModelSpec.Constraints.Temperature != nil {
194			v := model.ModelSpec.Constraints.Temperature.Default
195			if !math.IsNaN(v) {
196				options.Temperature = &v
197			}
198		}
199		if model.ModelSpec.Constraints.TopP != nil {
200			v := model.ModelSpec.Constraints.TopP.Default
201			if !math.IsNaN(v) {
202				options.TopP = &v
203			}
204		}
205
206		roundCost := func(v float64) float64 { return math.Round(v*1e5) / 1e5 }
207		m := catwalk.Model{
208			ID:                     model.ID,
209			Name:                   model.ModelSpec.Name,
210			CostPer1MIn:            roundCost(model.ModelSpec.Pricing.Input.USD),
211			CostPer1MOut:           roundCost(model.ModelSpec.Pricing.Output.USD),
212			CostPer1MInCached:      0,
213			CostPer1MOutCached:     0,
214			ContextWindow:          contextWindow,
215			DefaultMaxTokens:       model.ModelSpec.MaxCompletionTokens,
216			CanReason:              canReason,
217			ReasoningLevels:        reasoningLevels,
218			DefaultReasoningEffort: defaultReasoning,
219			SupportsImages:         model.ModelSpec.Capabilities.SupportsVision,
220			Options:                options,
221		}
222
223		veniceProvider.Models = append(veniceProvider.Models, m)
224		if model.ModelSpec.Capabilities.OptimizedForCode {
225			codeOptimizedModels = append(codeOptimizedModels, m)
226		}
227	}
228
229	candidateModels := veniceProvider.Models
230	if len(codeOptimizedModels) > 0 {
231		candidateModels = codeOptimizedModels
232	}
233
234	veniceProvider.DefaultLargeModelID = bestLargeModelID(candidateModels)
235	veniceProvider.DefaultSmallModelID = bestSmallModelID(candidateModels)
236
237	slices.SortFunc(veniceProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
238		return strings.Compare(a.Name, b.Name)
239	})
240
241	data, err := json.MarshalIndent(veniceProvider, "", "  ")
242	if err != nil {
243		log.Fatal("Error marshaling Venice provider:", err)
244	}
245	data = append(data, '\n')
246
247	if err := os.WriteFile("internal/providers/configs/venice.json", data, 0o600); err != nil {
248		log.Fatal("Error writing Venice provider config:", err)
249	}
250
251	fmt.Printf("Generated venice.json with %d models\n", len(veniceProvider.Models))
252}