1package models
2
3import (
4 "context"
5 "errors"
6 "log"
7
8 "github.com/cloudwego/eino-ext/components/model/claude"
9 "github.com/cloudwego/eino-ext/components/model/openai"
10 "github.com/cloudwego/eino/components/model"
11 "github.com/spf13/viper"
12)
13
// ModelID is the identifier of a model; it is the key type of SupportedModels.
type ModelID string

// ModelProvider identifies the provider a model is attributed to.
type ModelProvider string
18
19type Model struct {
20 ID ModelID `json:"id"`
21 Name string `json:"name"`
22 Provider ModelProvider `json:"provider"`
23 APIModel string `json:"api_model"`
24 CostPer1MIn float64 `json:"cost_per_1m_in"`
25 CostPer1MOut float64 `json:"cost_per_1m_out"`
26}
27
28const (
29 DefaultBigModel = GPT4oMini
30 DefaultLittleModel = GPT4oMini
31)
32
33// Model IDs
34const (
35 // OpenAI
36 GPT4o ModelID = "gpt-4o"
37 GPT4oMini ModelID = "gpt-4o-mini"
38 GPT45 ModelID = "gpt-4.5"
39 O1 ModelID = "o1"
40 O1Mini ModelID = "o1-mini"
41 // Anthropic
42 Claude35Sonnet ModelID = "claude-3.5-sonnet"
43 Claude3Haiku ModelID = "claude-3-haiku"
44 Claude37Sonnet ModelID = "claude-3.7-sonnet"
45 // Google
46 Gemini20Pro ModelID = "gemini-2.0-pro"
47 Gemini15Flash ModelID = "gemini-1.5-flash"
48 Gemini20Flash ModelID = "gemini-2.0-flash"
49 // xAI
50 Grok3 ModelID = "grok-3"
51 Grok2Mini ModelID = "grok-2-mini"
52 // DeepSeek
53 DeepSeekR1 ModelID = "deepseek-r1"
54 DeepSeekCoder ModelID = "deepseek-coder"
55 // Meta
56 Llama3 ModelID = "llama-3"
57 Llama270B ModelID = "llama-2-70b"
58 // GROQ
59 GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
60 GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b"
61)
62
63const (
64 ProviderOpenAI ModelProvider = "openai"
65 ProviderAnthropic ModelProvider = "anthropic"
66 ProviderGoogle ModelProvider = "google"
67 ProviderXAI ModelProvider = "xai"
68 ProviderDeepSeek ModelProvider = "deepseek"
69 ProviderMeta ModelProvider = "meta"
70 ProviderGroq ModelProvider = "groq"
71)
72
73var SupportedModels = map[ModelID]Model{
74 // OpenAI
75 GPT4o: {
76 ID: GPT4o,
77 Name: "GPT-4o",
78 Provider: ProviderOpenAI,
79 APIModel: "gpt-4o",
80 },
81 GPT4oMini: {
82 ID: GPT4oMini,
83 Name: "GPT-4o Mini",
84 Provider: ProviderOpenAI,
85 APIModel: "gpt-4o-mini",
86 CostPer1MIn: 0.150,
87 CostPer1MOut: 0.600,
88 },
89 GPT45: {
90 ID: GPT45,
91 Name: "GPT-4.5",
92 Provider: ProviderOpenAI,
93 APIModel: "gpt-4.5",
94 },
95 O1: {
96 ID: O1,
97 Name: "o1",
98 Provider: ProviderOpenAI,
99 APIModel: "o1",
100 },
101 O1Mini: {
102 ID: O1Mini,
103 Name: "o1 Mini",
104 Provider: ProviderOpenAI,
105 APIModel: "o1-mini",
106 },
107 // Anthropic
108 Claude35Sonnet: {
109 ID: Claude35Sonnet,
110 Name: "Claude 3.5 Sonnet",
111 Provider: ProviderAnthropic,
112 APIModel: "claude-3.5-sonnet",
113 },
114 Claude3Haiku: {
115 ID: Claude3Haiku,
116 Name: "Claude 3 Haiku",
117 Provider: ProviderAnthropic,
118 APIModel: "claude-3-haiku",
119 },
120 Claude37Sonnet: {
121 ID: Claude37Sonnet,
122 Name: "Claude 3.7 Sonnet",
123 Provider: ProviderAnthropic,
124 APIModel: "claude-3-7-sonnet-20250219",
125 },
126 // Google
127 Gemini20Pro: {
128 ID: Gemini20Pro,
129 Name: "Gemini 2.0 Pro",
130 Provider: ProviderGoogle,
131 APIModel: "gemini-2.0-pro",
132 },
133 Gemini15Flash: {
134 ID: Gemini15Flash,
135 Name: "Gemini 1.5 Flash",
136 Provider: ProviderGoogle,
137 APIModel: "gemini-1.5-flash",
138 },
139 Gemini20Flash: {
140 ID: Gemini20Flash,
141 Name: "Gemini 2.0 Flash",
142 Provider: ProviderGoogle,
143 APIModel: "gemini-2.0-flash",
144 },
145 // xAI
146 Grok3: {
147 ID: Grok3,
148 Name: "Grok 3",
149 Provider: ProviderXAI,
150 APIModel: "grok-3",
151 },
152 Grok2Mini: {
153 ID: Grok2Mini,
154 Name: "Grok 2 Mini",
155 Provider: ProviderXAI,
156 APIModel: "grok-2-mini",
157 },
158 // DeepSeek
159 DeepSeekR1: {
160 ID: DeepSeekR1,
161 Name: "DeepSeek R1",
162 Provider: ProviderDeepSeek,
163 APIModel: "deepseek-r1",
164 },
165 DeepSeekCoder: {
166 ID: DeepSeekCoder,
167 Name: "DeepSeek Coder",
168 Provider: ProviderDeepSeek,
169 APIModel: "deepseek-coder",
170 },
171 // Meta
172 Llama3: {
173 ID: Llama3,
174 Name: "LLaMA 3",
175 Provider: ProviderMeta,
176 APIModel: "llama-3",
177 },
178 Llama270B: {
179 ID: Llama270B,
180 Name: "LLaMA 2 70B",
181 Provider: ProviderMeta,
182 APIModel: "llama-2-70b",
183 },
184
185 // GROQ
186 GroqLlama3SpecDec: {
187 ID: GroqLlama3SpecDec,
188 Name: "GROQ LLaMA 3 SpecDec",
189 Provider: ProviderGroq,
190 APIModel: "llama-3.3-70b-specdec",
191 },
192 GroqQwen32BCoder: {
193 ID: GroqQwen32BCoder,
194 Name: "GROQ Qwen 2.5 Coder 32B",
195 Provider: ProviderGroq,
196 APIModel: "qwen-2.5-coder-32b",
197 },
198}
199
200func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
201 provider := SupportedModels[model].Provider
202 log.Printf("Provider: %s", provider)
203 maxTokens := viper.GetInt("providers.common.max_tokens")
204 switch provider {
205 case ProviderOpenAI:
206 return openai.NewChatModel(ctx, &openai.ChatModelConfig{
207 APIKey: viper.GetString("providers.openai.key"),
208 Model: string(SupportedModels[model].APIModel),
209 MaxTokens: &maxTokens,
210 })
211 case ProviderAnthropic:
212 return claude.NewChatModel(ctx, &claude.Config{
213 APIKey: viper.GetString("providers.anthropic.key"),
214 Model: string(SupportedModels[model].APIModel),
215 MaxTokens: maxTokens,
216 })
217
218 case ProviderGroq:
219 return openai.NewChatModel(ctx, &openai.ChatModelConfig{
220 BaseURL: "https://api.groq.com/openai/v1",
221 APIKey: viper.GetString("providers.groq.key"),
222 Model: string(SupportedModels[model].APIModel),
223 MaxTokens: &maxTokens,
224 })
225
226 }
227 return nil, errors.New("unsupported provider")
228}