1package models
2
3import (
4 "context"
5 "errors"
6 "log"
7
8 "github.com/cloudwego/eino-ext/components/model/claude"
9 "github.com/cloudwego/eino-ext/components/model/openai"
10 "github.com/cloudwego/eino/components/model"
11 "github.com/spf13/viper"
12)
13
// ModelID is the string identifier used as the key of the model catalog.
type ModelID string

// ModelProvider is the string identifier of the service that hosts a model.
type ModelProvider string
18
// Model is one catalog entry describing a chat model and how to reach it.
type Model struct {
	// ID is the catalog identifier; in SupportedModels it equals the map key.
	ID ModelID `json:"id"`
	// Name is the human-readable display name.
	Name string `json:"name"`
	// Provider selects which client GetModel constructs for this model.
	Provider ModelProvider `json:"provider"`
	// APIModel is the model name handed to the provider client.
	APIModel string `json:"api_model"`
	// CostPer1MIn is the price per one million input tokens.
	// NOTE(review): currency is presumably USD, and the zero value presumably
	// means "not recorded" — confirm with the consumers of this field.
	CostPer1MIn float64 `json:"cost_per_1m_in"`
	// CostPer1MOut is the price per one million output tokens (same caveats
	// as CostPer1MIn).
	CostPer1MOut float64 `json:"cost_per_1m_out"`
}
27
// Default model selections. Both currently resolve to Claude37Sonnet.
// NOTE(review): the "big"/"little" split presumably lets callers pick a cheaper
// model for lightweight tasks; the consumers are not visible in this file.
const (
	// DefaultBigModel is the default ID for the "big" model slot.
	DefaultBigModel = Claude37Sonnet
	// DefaultLittleModel is the default ID for the "little" model slot.
	DefaultLittleModel = Claude37Sonnet
)
32
// Model IDs.
//
// Each constant is a key of SupportedModels. The string value is the catalog
// identifier, which is not necessarily the name sent to the provider API (that
// is Model.APIModel).
const (
	// OpenAI
	GPT4o     ModelID = "gpt-4o"
	GPT4oMini ModelID = "gpt-4o-mini"
	GPT45     ModelID = "gpt-4.5"
	O1        ModelID = "o1"
	O1Mini    ModelID = "o1-mini"
	// Anthropic
	Claude35Sonnet ModelID = "claude-3.5-sonnet"
	Claude3Haiku   ModelID = "claude-3-haiku"
	Claude37Sonnet ModelID = "claude-3.7-sonnet"
	// Google
	Gemini20Pro   ModelID = "gemini-2.0-pro"
	Gemini15Flash ModelID = "gemini-1.5-flash"
	Gemini20Flash ModelID = "gemini-2.0-flash"
	// xAI
	Grok3     ModelID = "grok-3"
	Grok2Mini ModelID = "grok-2-mini"
	// DeepSeek
	DeepSeekR1    ModelID = "deepseek-r1"
	DeepSeekCoder ModelID = "deepseek-coder"
	// Meta
	Llama3    ModelID = "llama-3"
	Llama270B ModelID = "llama-2-70b"
	// GROQ
	// NOTE(review): GroqQwen32BCoder's value has no "groq-" prefix, unlike
	// GroqLlama3SpecDec — confirm whether that is intentional before relying
	// on the prefix to tell Groq-hosted IDs apart.
	GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
	GroqQwen32BCoder  ModelID = "qwen-2.5-coder-32b"
)
62
// Provider identifiers used in Model.Provider.
//
// GetModel constructs clients only for ProviderOpenAI, ProviderAnthropic and
// ProviderGroq; models registered under the remaining providers are listed in
// the catalog but GetModel returns an error for them.
const (
	ProviderOpenAI    ModelProvider = "openai"
	ProviderAnthropic ModelProvider = "anthropic"
	ProviderGoogle    ModelProvider = "google"
	ProviderXAI       ModelProvider = "xai"
	ProviderDeepSeek  ModelProvider = "deepseek"
	ProviderMeta      ModelProvider = "meta"
	ProviderGroq      ModelProvider = "groq"
)
72
73var SupportedModels = map[ModelID]Model{
74 // OpenAI
75 GPT4o: {
76 ID: GPT4o,
77 Name: "GPT-4o",
78 Provider: ProviderOpenAI,
79 APIModel: "gpt-4o",
80 },
81 GPT4oMini: {
82 ID: GPT4oMini,
83 Name: "GPT-4o Mini",
84 Provider: ProviderOpenAI,
85 APIModel: "gpt-4o-mini",
86 CostPer1MIn: 0.150,
87 CostPer1MOut: 0.600,
88 },
89 GPT45: {
90 ID: GPT45,
91 Name: "GPT-4.5",
92 Provider: ProviderOpenAI,
93 APIModel: "gpt-4.5",
94 },
95 O1: {
96 ID: O1,
97 Name: "o1",
98 Provider: ProviderOpenAI,
99 APIModel: "o1",
100 },
101 O1Mini: {
102 ID: O1Mini,
103 Name: "o1 Mini",
104 Provider: ProviderOpenAI,
105 APIModel: "o1-mini",
106 },
107 // Anthropic
108 Claude35Sonnet: {
109 ID: Claude35Sonnet,
110 Name: "Claude 3.5 Sonnet",
111 Provider: ProviderAnthropic,
112 APIModel: "claude-3.5-sonnet",
113 },
114 Claude3Haiku: {
115 ID: Claude3Haiku,
116 Name: "Claude 3 Haiku",
117 Provider: ProviderAnthropic,
118 APIModel: "claude-3-haiku",
119 },
120 Claude37Sonnet: {
121 ID: Claude37Sonnet,
122 Name: "Claude 3.7 Sonnet",
123 Provider: ProviderAnthropic,
124 APIModel: "claude-3-7-sonnet-20250219",
125 CostPer1MIn: 3.0,
126 CostPer1MOut: 15.0,
127 },
128 // Google
129 Gemini20Pro: {
130 ID: Gemini20Pro,
131 Name: "Gemini 2.0 Pro",
132 Provider: ProviderGoogle,
133 APIModel: "gemini-2.0-pro",
134 },
135 Gemini15Flash: {
136 ID: Gemini15Flash,
137 Name: "Gemini 1.5 Flash",
138 Provider: ProviderGoogle,
139 APIModel: "gemini-1.5-flash",
140 },
141 Gemini20Flash: {
142 ID: Gemini20Flash,
143 Name: "Gemini 2.0 Flash",
144 Provider: ProviderGoogle,
145 APIModel: "gemini-2.0-flash",
146 },
147 // xAI
148 Grok3: {
149 ID: Grok3,
150 Name: "Grok 3",
151 Provider: ProviderXAI,
152 APIModel: "grok-3",
153 },
154 Grok2Mini: {
155 ID: Grok2Mini,
156 Name: "Grok 2 Mini",
157 Provider: ProviderXAI,
158 APIModel: "grok-2-mini",
159 },
160 // DeepSeek
161 DeepSeekR1: {
162 ID: DeepSeekR1,
163 Name: "DeepSeek R1",
164 Provider: ProviderDeepSeek,
165 APIModel: "deepseek-r1",
166 },
167 DeepSeekCoder: {
168 ID: DeepSeekCoder,
169 Name: "DeepSeek Coder",
170 Provider: ProviderDeepSeek,
171 APIModel: "deepseek-coder",
172 },
173 // Meta
174 Llama3: {
175 ID: Llama3,
176 Name: "LLaMA 3",
177 Provider: ProviderMeta,
178 APIModel: "llama-3",
179 },
180 Llama270B: {
181 ID: Llama270B,
182 Name: "LLaMA 2 70B",
183 Provider: ProviderMeta,
184 APIModel: "llama-2-70b",
185 },
186
187 // GROQ
188 GroqLlama3SpecDec: {
189 ID: GroqLlama3SpecDec,
190 Name: "GROQ LLaMA 3 SpecDec",
191 Provider: ProviderGroq,
192 APIModel: "llama-3.3-70b-specdec",
193 },
194 GroqQwen32BCoder: {
195 ID: GroqQwen32BCoder,
196 Name: "GROQ Qwen 2.5 Coder 32B",
197 Provider: ProviderGroq,
198 APIModel: "qwen-2.5-coder-32b",
199 },
200}
201
202func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
203 provider := SupportedModels[model].Provider
204 log.Printf("Provider: %s", provider)
205 maxTokens := viper.GetInt("providers.common.max_tokens")
206 switch provider {
207 case ProviderOpenAI:
208 return openai.NewChatModel(ctx, &openai.ChatModelConfig{
209 APIKey: viper.GetString("providers.openai.key"),
210 Model: string(SupportedModels[model].APIModel),
211 MaxTokens: &maxTokens,
212 })
213 case ProviderAnthropic:
214 return claude.NewChatModel(ctx, &claude.Config{
215 APIKey: viper.GetString("providers.anthropic.key"),
216 Model: string(SupportedModels[model].APIModel),
217 MaxTokens: maxTokens,
218 })
219
220 case ProviderGroq:
221 return openai.NewChatModel(ctx, &openai.ChatModelConfig{
222 BaseURL: "https://api.groq.com/openai/v1",
223 APIKey: viper.GetString("providers.groq.key"),
224 Model: string(SupportedModels[model].APIModel),
225 MaxTokens: &maxTokens,
226 })
227
228 }
229 return nil, errors.New("unsupported provider")
230}