main.go

  1// Package main provides a command-line tool to fetch models from Nebius Token Factory
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"math"
 12	"net/http"
 13	"os"
 14	"slices"
 15	"strconv"
 16	"strings"
 17	"time"
 18
 19	"charm.land/catwalk/pkg/catwalk"
 20)
 21
 22// Model represents a model from the Nebius Token Factory API.
 23type Model struct {
 24	ID                string   `json:"id"`
 25	DisplayName       string   `json:"name"`
 26	ContextLength     int64    `json:"context_length"`
 27	SupportedFeatures []string `json:"supported_features,omitempty"`
 28	Pricing           Pricing  `json:"pricing"`
 29	Architecture      struct {
 30		Modality string `json:"modality"`
 31	} `json:"architecture,omitempty"`
 32}
 33
 34type Pricing struct {
 35	Prompt     string `json:"prompt"`
 36	Completion string `json:"completion"`
 37}
 38
 39type ModelsResponse struct {
 40	Data []Model `json:"data"`
 41}
 42
 43func (m Model) hasFeature(featureValue string) bool {
 44	if m.SupportedFeatures != nil {
 45		for _, feature := range m.SupportedFeatures {
 46			if strings.EqualFold(feature, featureValue) {
 47				return true
 48			}
 49		}
 50	}
 51	return false
 52}
 53
 54func fetchNebiusModels() (*ModelsResponse, error) {
 55	client := &http.Client{Timeout: 30 * time.Second}
 56	req, _ := http.NewRequestWithContext(
 57		context.Background(),
 58		"GET",
 59		"https://api.tokenfactory.nebius.com/v1/models?verbose=true",
 60		nil,
 61	)
 62	req.Header.Set("User-Agent", "Crush-Client/1.0")
 63
 64	// Read API key from environment variable
 65	apiKey := os.Getenv("NEBIUS_API_KEY")
 66	if apiKey == "" {
 67		return nil, fmt.Errorf("$NEBIUS_API_KEY is required")
 68	}
 69	req.Header.Set("Authorization", "Bearer "+apiKey)
 70
 71	resp, err := client.Do(req)
 72	if err != nil {
 73		return nil, err //nolint:wrapcheck
 74	}
 75	defer resp.Body.Close() //nolint:errcheck
 76	if resp.StatusCode != 200 {
 77		body, _ := io.ReadAll(resp.Body)
 78		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
 79	}
 80	var mr ModelsResponse
 81	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
 82		return nil, err //nolint:wrapcheck
 83	}
 84	return &mr, nil
 85}
 86
 87func main() {
 88	modelsResp, err := fetchNebiusModels()
 89	if err != nil {
 90		log.Fatal("Error fetching Nebius models:", err)
 91	}
 92
 93	nebiusProvider := catwalk.Provider{
 94		Name:                "Nebius Token Factory",
 95		ID:                  catwalk.InferenceProviderNebius,
 96		APIKey:              "$NEBIUS_API_KEY",
 97		APIEndpoint:         "https://api.tokenfactory.nebius.com/v1", // this is their default region, eu-north1
 98		Type:                catwalk.TypeOpenAICompat,
 99		DefaultLargeModelID: "moonshotai/Kimi-K2.5",
100		DefaultSmallModelID: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B",
101	}
102
103	for _, model := range modelsResp.Data {
104		// we skip models that don't support tool calling
105		if !model.hasFeature("tools") {
106			continue
107		}
108
109		// Convert pricing from string to float64
110		var costPer1MIn, costPer1MOut float64
111
112		// Handle prompt price conversion
113		promptPrice, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
114		if err != nil {
115			promptPrice = 0.0
116		}
117		costPer1MIn = math.Round(promptPrice*1_000_000*100) / 100 // Round to 2 decimal places
118
119		// Handle completion price conversion
120		completionPrice, err := strconv.ParseFloat(model.Pricing.Completion, 64)
121		if err != nil {
122			completionPrice = 0.0
123		}
124		costPer1MOut = math.Round(completionPrice*1_000_000*100) / 100 // Round to 2 decimal places
125
126		var (
127			supportsImages   = strings.Contains(strings.ToLower(model.Architecture.Modality), "image")
128			canReason        = model.hasFeature("reasoning")
129			reasoningLevels  []string
130			defaultReasoning string
131		)
132		if canReason {
133			reasoningLevels = []string{"low", "medium", "high"}
134			defaultReasoning = "medium"
135		}
136
137		m := catwalk.Model{
138			ID:                     model.ID,
139			Name:                   model.DisplayName,
140			CostPer1MIn:            costPer1MIn,
141			CostPer1MOut:           costPer1MOut,
142			CostPer1MInCached:      0,
143			CostPer1MOutCached:     0,
144			ContextWindow:          model.ContextLength,
145			DefaultMaxTokens:       model.ContextLength / 10, // there is no MaxTokens exposed, so play safe
146			CanReason:              canReason,
147			ReasoningLevels:        reasoningLevels,
148			DefaultReasoningEffort: defaultReasoning,
149			SupportsImages:         supportsImages,
150		}
151
152		nebiusProvider.Models = append(nebiusProvider.Models, m)
153		fmt.Printf("Added model %s with context window %d\n", model.ID, model.ContextLength)
154	}
155
156	slices.SortFunc(nebiusProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
157		return strings.Compare(a.Name, b.Name)
158	})
159
160	// Save the JSON in internal/providers/configs/nebius.json
161	data, err := json.MarshalIndent(nebiusProvider, "", "  ")
162	if err != nil {
163		log.Fatal("Error marshaling Nebius provider:", err)
164	}
165	data = append(data, '\n')
166
167	if err := os.WriteFile("internal/providers/configs/nebius.json", data, 0o600); err != nil {
168		log.Fatal("Error writing Nebius provider config:", err)
169	}
170
171	fmt.Printf("Generated nebius.json with %d models\n", len(nebiusProvider.Models))
172}