models.go

  1package models
  2
import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/cloudwego/eino-ext/components/model/claude"
	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/cloudwego/eino/components/model"
	"github.com/spf13/viper"
)
 13
 14type (
 15	ModelID       string
 16	ModelProvider string
 17)
 18
 19type Model struct {
 20	ID           ModelID       `json:"id"`
 21	Name         string        `json:"name"`
 22	Provider     ModelProvider `json:"provider"`
 23	APIModel     string        `json:"api_model"`
 24	CostPer1MIn  float64       `json:"cost_per_1m_in"`
 25	CostPer1MOut float64       `json:"cost_per_1m_out"`
 26}
 27
 28const (
 29	DefaultBigModel    = Claude37Sonnet
 30	DefaultLittleModel = Claude37Sonnet
 31)
 32
 33// Model IDs
 34const (
 35	// OpenAI
 36	GPT4o     ModelID = "gpt-4o"
 37	GPT4oMini ModelID = "gpt-4o-mini"
 38	GPT45     ModelID = "gpt-4.5"
 39	O1        ModelID = "o1"
 40	O1Mini    ModelID = "o1-mini"
 41	// Anthropic
 42	Claude35Sonnet ModelID = "claude-3.5-sonnet"
 43	Claude3Haiku   ModelID = "claude-3-haiku"
 44	Claude37Sonnet ModelID = "claude-3.7-sonnet"
 45	// Google
 46	Gemini20Pro   ModelID = "gemini-2.0-pro"
 47	Gemini15Flash ModelID = "gemini-1.5-flash"
 48	Gemini20Flash ModelID = "gemini-2.0-flash"
 49	// xAI
 50	Grok3     ModelID = "grok-3"
 51	Grok2Mini ModelID = "grok-2-mini"
 52	// DeepSeek
 53	DeepSeekR1    ModelID = "deepseek-r1"
 54	DeepSeekCoder ModelID = "deepseek-coder"
 55	// Meta
 56	Llama3    ModelID = "llama-3"
 57	Llama270B ModelID = "llama-2-70b"
 58	// GROQ
 59	GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
 60	GroqQwen32BCoder  ModelID = "qwen-2.5-coder-32b"
 61)
 62
 63const (
 64	ProviderOpenAI    ModelProvider = "openai"
 65	ProviderAnthropic ModelProvider = "anthropic"
 66	ProviderGoogle    ModelProvider = "google"
 67	ProviderXAI       ModelProvider = "xai"
 68	ProviderDeepSeek  ModelProvider = "deepseek"
 69	ProviderMeta      ModelProvider = "meta"
 70	ProviderGroq      ModelProvider = "groq"
 71)
 72
 73var SupportedModels = map[ModelID]Model{
 74	// OpenAI
 75	GPT4o: {
 76		ID:       GPT4o,
 77		Name:     "GPT-4o",
 78		Provider: ProviderOpenAI,
 79		APIModel: "gpt-4o",
 80	},
 81	GPT4oMini: {
 82		ID:           GPT4oMini,
 83		Name:         "GPT-4o Mini",
 84		Provider:     ProviderOpenAI,
 85		APIModel:     "gpt-4o-mini",
 86		CostPer1MIn:  0.150,
 87		CostPer1MOut: 0.600,
 88	},
 89	GPT45: {
 90		ID:       GPT45,
 91		Name:     "GPT-4.5",
 92		Provider: ProviderOpenAI,
 93		APIModel: "gpt-4.5",
 94	},
 95	O1: {
 96		ID:       O1,
 97		Name:     "o1",
 98		Provider: ProviderOpenAI,
 99		APIModel: "o1",
100	},
101	O1Mini: {
102		ID:       O1Mini,
103		Name:     "o1 Mini",
104		Provider: ProviderOpenAI,
105		APIModel: "o1-mini",
106	},
107	// Anthropic
108	Claude35Sonnet: {
109		ID:       Claude35Sonnet,
110		Name:     "Claude 3.5 Sonnet",
111		Provider: ProviderAnthropic,
112		APIModel: "claude-3.5-sonnet",
113	},
114	Claude3Haiku: {
115		ID:       Claude3Haiku,
116		Name:     "Claude 3 Haiku",
117		Provider: ProviderAnthropic,
118		APIModel: "claude-3-haiku",
119	},
120	Claude37Sonnet: {
121		ID:           Claude37Sonnet,
122		Name:         "Claude 3.7 Sonnet",
123		Provider:     ProviderAnthropic,
124		APIModel:     "claude-3-7-sonnet-20250219",
125		CostPer1MIn:  3.0,
126		CostPer1MOut: 15.0,
127	},
128	// Google
129	Gemini20Pro: {
130		ID:       Gemini20Pro,
131		Name:     "Gemini 2.0 Pro",
132		Provider: ProviderGoogle,
133		APIModel: "gemini-2.0-pro",
134	},
135	Gemini15Flash: {
136		ID:       Gemini15Flash,
137		Name:     "Gemini 1.5 Flash",
138		Provider: ProviderGoogle,
139		APIModel: "gemini-1.5-flash",
140	},
141	Gemini20Flash: {
142		ID:       Gemini20Flash,
143		Name:     "Gemini 2.0 Flash",
144		Provider: ProviderGoogle,
145		APIModel: "gemini-2.0-flash",
146	},
147	// xAI
148	Grok3: {
149		ID:       Grok3,
150		Name:     "Grok 3",
151		Provider: ProviderXAI,
152		APIModel: "grok-3",
153	},
154	Grok2Mini: {
155		ID:       Grok2Mini,
156		Name:     "Grok 2 Mini",
157		Provider: ProviderXAI,
158		APIModel: "grok-2-mini",
159	},
160	// DeepSeek
161	DeepSeekR1: {
162		ID:       DeepSeekR1,
163		Name:     "DeepSeek R1",
164		Provider: ProviderDeepSeek,
165		APIModel: "deepseek-r1",
166	},
167	DeepSeekCoder: {
168		ID:       DeepSeekCoder,
169		Name:     "DeepSeek Coder",
170		Provider: ProviderDeepSeek,
171		APIModel: "deepseek-coder",
172	},
173	// Meta
174	Llama3: {
175		ID:       Llama3,
176		Name:     "LLaMA 3",
177		Provider: ProviderMeta,
178		APIModel: "llama-3",
179	},
180	Llama270B: {
181		ID:       Llama270B,
182		Name:     "LLaMA 2 70B",
183		Provider: ProviderMeta,
184		APIModel: "llama-2-70b",
185	},
186
187	// GROQ
188	GroqLlama3SpecDec: {
189		ID:       GroqLlama3SpecDec,
190		Name:     "GROQ LLaMA 3 SpecDec",
191		Provider: ProviderGroq,
192		APIModel: "llama-3.3-70b-specdec",
193	},
194	GroqQwen32BCoder: {
195		ID:       GroqQwen32BCoder,
196		Name:     "GROQ Qwen 2.5 Coder 32B",
197		Provider: ProviderGroq,
198		APIModel: "qwen-2.5-coder-32b",
199	},
200}
201
202func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
203	provider := SupportedModels[model].Provider
204	log.Printf("Provider: %s", provider)
205	maxTokens := viper.GetInt("providers.common.max_tokens")
206	switch provider {
207	case ProviderOpenAI:
208		return openai.NewChatModel(ctx, &openai.ChatModelConfig{
209			APIKey:    viper.GetString("providers.openai.key"),
210			Model:     string(SupportedModels[model].APIModel),
211			MaxTokens: &maxTokens,
212		})
213	case ProviderAnthropic:
214		return claude.NewChatModel(ctx, &claude.Config{
215			APIKey:    viper.GetString("providers.anthropic.key"),
216			Model:     string(SupportedModels[model].APIModel),
217			MaxTokens: maxTokens,
218		})
219
220	case ProviderGroq:
221		return openai.NewChatModel(ctx, &openai.ChatModelConfig{
222			BaseURL:   "https://api.groq.com/openai/v1",
223			APIKey:    viper.GetString("providers.groq.key"),
224			Model:     string(SupportedModels[model].APIModel),
225			MaxTokens: &maxTokens,
226		})
227
228	}
229	return nil, errors.New("unsupported provider")
230}