main.go

  1// Package main provides a command-line tool to fetch models from AIHubMix
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"net/http"
 12	"os"
 13	"slices"
 14	"strings"
 15	"time"
 16
 17	"charm.land/catwalk/pkg/catwalk"
 18)
 19
// APIModel represents a single model entry returned by the AIHubMix models
// API. Each field is decoded from the JSON key named in its tag.
type APIModel struct {
	// ModelID becomes the catwalk model ID in the generated config.
	ModelID         string  `json:"model_id"`
	// ModelName becomes the catwalk model display name.
	ModelName       string  `json:"model_name"`
	// Desc is decoded but not read by main.
	Desc            string  `json:"desc"`
	// Pricing holds the optional per-model prices; see Pricing.
	Pricing         Pricing `json:"pricing"`
	// Types is decoded but not read by main.
	Types           string  `json:"types"`
	// Features is a comma-separated list; main checks it for "thinking"
	// (via hasField) to decide whether the model can reason.
	Features        string  `json:"features"`
	// InputModalities is a comma-separated list; main requires "text" and
	// checks for "image" to set SupportsImages.
	InputModalities string  `json:"input_modalities"`
	// MaxOutput is the API-reported output limit; calculateMaxTokens falls
	// back to a context-derived value when it is 0 or too large.
	MaxOutput       int64   `json:"max_output"`
	// ContextLength is copied to the catwalk ContextWindow and compared
	// against minContextWindow for filtering.
	ContextLength   int64   `json:"context_length"`
}
 32
// Pricing contains the per-model pricing information from the API.
//
// Every field is a pointer so a price the API omitted (or sent as null) can be
// told apart from an explicit zero; main converts nil to 0 via parseFloat.
// NOTE(review): values are copied verbatim into catwalk's CostPer1M* fields,
// so this assumes the API already reports prices per one million tokens —
// confirm against the AIHubMix API documentation.
type Pricing struct {
	// Input is the price for prompt (input) tokens.
	Input      *float64 `json:"input"`
	// Output is the price for completion (output) tokens.
	Output     *float64 `json:"output"`
	// CacheRead is the price for reading cached input.
	CacheRead  *float64 `json:"cache_read"`
	// CacheWrite is the price for writing input to the cache.
	CacheWrite *float64 `json:"cache_write"`
}
 40
// Generator tuning knobs.
const (
	// minContextWindow is the smallest context length a model may report and
	// still be included in the generated config.
	minContextWindow  = 20000
	// defaultLargeModel is written as the provider's DefaultLargeModelID.
	defaultLargeModel = "gpt-5"
	// defaultSmallModel is written as the provider's DefaultSmallModelID.
	defaultSmallModel = "gpt-5-nano"
	// maxTokensFactor is the divisor calculateMaxTokens applies to the context
	// length when the API's max_output is unusable.
	maxTokensFactor   = 10
)
 47
// ModelsResponse is the top-level envelope returned by the AIHubMix models
// endpoint: the list of models plus the API's status message and flag.
type ModelsResponse struct {
	// Data holds one entry per model.
	Data    []APIModel `json:"data"`
	// Message is presumably a human-readable status or error description.
	Message string     `json:"message"`
	// Success is the API-reported success flag.
	Success bool       `json:"success"`
}
 54
 55func fetchAIHubMixModels() (*ModelsResponse, error) {
 56	req, err := http.NewRequestWithContext(
 57		context.Background(),
 58		"GET",
 59		"https://aihubmix.com/api/v1/models?type=llm",
 60		nil,
 61	)
 62	if err != nil {
 63		return nil, fmt.Errorf("creating request: %w", err)
 64	}
 65	req.Header.Set("User-Agent", "Crush-Client/1.0")
 66
 67	client := &http.Client{Timeout: 30 * time.Second}
 68	resp, err := client.Do(req)
 69	if err != nil {
 70		return nil, fmt.Errorf("fetching models: %w", err)
 71	}
 72	defer resp.Body.Close() //nolint:errcheck
 73
 74	if resp.StatusCode != http.StatusOK {
 75		body, _ := io.ReadAll(resp.Body)
 76		return nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, body)
 77	}
 78
 79	var mr ModelsResponse
 80	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
 81		return nil, fmt.Errorf("parsing response: %w", err)
 82	}
 83	return &mr, nil
 84}
 85
 86func hasField(s, field string) bool {
 87	if s == "" {
 88		return false
 89	}
 90	for item := range strings.SplitSeq(s, ",") {
 91		if strings.TrimSpace(item) == field {
 92			return true
 93		}
 94	}
 95	return false
 96}
 97
 98func parseFloat(p *float64) float64 {
 99	if p == nil {
100		return 0
101	}
102	return *p
103}
104
// calculateMaxTokens picks the default max-tokens value for a model.
//
// The API-reported maxOutput is used as-is when it is positive and no more
// than half of contextLength. Otherwise it falls back to contextLength/factor.
// That covers a missing (zero) value, a nonsensical negative value, or one so
// large it would leave less than half the window for input. factor must be
// positive; a zero factor panics with an integer divide-by-zero.
func calculateMaxTokens(contextLength, maxOutput, factor int64) int64 {
	// Treat non-positive values like a missing one so a bad API value never
	// becomes a negative default.
	if maxOutput <= 0 || maxOutput > contextLength/2 {
		return contextLength / factor
	}
	return maxOutput
}
111
// buildReasoningConfig returns the reasoning-effort levels offered for a model
// and the effort selected by default. Models that cannot reason get a nil
// slice and an empty default. Reasoning models get low/medium/high, with
// medium as the default. A fresh slice is allocated on every call, so callers
// may keep or modify it independently.
func buildReasoningConfig(canReason bool) (levels []string, defaultEffort string) {
	if canReason {
		levels = []string{"low", "medium", "high"}
		defaultEffort = "medium"
	}
	return levels, defaultEffort
}
118
119func main() {
120	modelsResp, err := fetchAIHubMixModels()
121	if err != nil {
122		log.Fatal("Error fetching AIHubMix models:", err)
123	}
124
125	aiHubMixProvider := catwalk.Provider{
126		Name:                "AIHubMix",
127		ID:                  catwalk.InferenceAIHubMix,
128		APIKey:              "$AIHUBMIX_API_KEY",
129		APIEndpoint:         "https://aihubmix.com/v1",
130		Type:                catwalk.TypeOpenAICompat,
131		DefaultLargeModelID: defaultLargeModel,
132		DefaultSmallModelID: defaultSmallModel,
133		DefaultHeaders: map[string]string{
134			"APP-Code": "IUFF7106",
135		},
136	}
137
138	for _, model := range modelsResp.Data {
139		if model.ContextLength < minContextWindow {
140			continue
141		}
142		if !hasField(model.InputModalities, "text") {
143			continue
144		}
145
146		canReason := hasField(model.Features, "thinking")
147		supportsImages := hasField(model.InputModalities, "image")
148
149		reasoningLevels, defaultReasoning := buildReasoningConfig(canReason)
150		maxTokens := calculateMaxTokens(model.ContextLength, model.MaxOutput, maxTokensFactor)
151
152		aiHubMixProvider.Models = append(aiHubMixProvider.Models, catwalk.Model{
153			ID:                     model.ModelID,
154			Name:                   model.ModelName,
155			CostPer1MIn:            parseFloat(model.Pricing.Input),
156			CostPer1MOut:           parseFloat(model.Pricing.Output),
157			CostPer1MInCached:      parseFloat(model.Pricing.CacheWrite),
158			CostPer1MOutCached:     parseFloat(model.Pricing.CacheRead),
159			ContextWindow:          model.ContextLength,
160			DefaultMaxTokens:       maxTokens,
161			CanReason:              canReason,
162			ReasoningLevels:        reasoningLevels,
163			DefaultReasoningEffort: defaultReasoning,
164			SupportsImages:         supportsImages,
165		})
166
167		fmt.Printf("Added model %s with context window %d\n",
168			model.ModelID, model.ContextLength)
169	}
170
171	if len(aiHubMixProvider.Models) == 0 {
172		log.Fatal("No models found or no models met the criteria")
173	}
174
175	slices.SortFunc(aiHubMixProvider.Models, func(a, b catwalk.Model) int {
176		return strings.Compare(a.ID, b.ID)
177	})
178
179	data, err := json.MarshalIndent(aiHubMixProvider, "", "  ")
180	if err != nil {
181		log.Fatal("Error marshaling AIHubMix provider:", err)
182	}
183
184	if err := os.WriteFile("internal/providers/configs/aihubmix.json", data, 0o600); err != nil {
185		log.Fatal("Error writing AIHubMix provider config:", err)
186	}
187
188	fmt.Printf("\nSuccessfully wrote %d models to internal/providers/configs/aihubmix.json\n", len(aiHubMixProvider.Models))
189}