1package models
2
3import (
4 "context"
5 "errors"
6
7 "github.com/cloudwego/eino-ext/components/model/claude"
8 "github.com/cloudwego/eino-ext/components/model/openai"
9 "github.com/cloudwego/eino/components/model"
10 "github.com/spf13/viper"
11)
12
// ModelID is this package's own identifier for a model (e.g. "gpt-4o").
// It is not necessarily the string sent to the provider's API.
type ModelID string

// ModelProvider names the vendor that serves a model (e.g. "openai").
type ModelProvider string
17
18type Model struct {
19 ID ModelID `json:"id"`
20 Name string `json:"name"`
21 Provider ModelProvider `json:"provider"`
22 APIModel string `json:"api_model"` // Actual value used when calling the API
23}
24
25const (
26 DefaultBigModel = GPT4oMini
27 DefaultLittleModel = GPT4oMini
28)
29
30// Model IDs
31const (
32 // OpenAI
33 GPT4o ModelID = "gpt-4o"
34 GPT4oMini ModelID = "gpt-4o-mini"
35 GPT45 ModelID = "gpt-4.5"
36 O1 ModelID = "o1"
37 O1Mini ModelID = "o1-mini"
38 // Anthropic
39 Claude35Sonnet ModelID = "claude-3.5-sonnet"
40 Claude3Haiku ModelID = "claude-3-haiku"
41 Claude37Sonnet ModelID = "claude-3.7-sonnet"
42 // Google
43 Gemini20Pro ModelID = "gemini-2.0-pro"
44 Gemini15Flash ModelID = "gemini-1.5-flash"
45 Gemini20Flash ModelID = "gemini-2.0-flash"
46 // xAI
47 Grok3 ModelID = "grok-3"
48 Grok2Mini ModelID = "grok-2-mini"
49 // DeepSeek
50 DeepSeekR1 ModelID = "deepseek-r1"
51 DeepSeekCoder ModelID = "deepseek-coder"
52 // Meta
53 Llama3 ModelID = "llama-3"
54 Llama270B ModelID = "llama-2-70b"
55)
56
57const (
58 ProviderOpenAI ModelProvider = "openai"
59 ProviderAnthropic ModelProvider = "anthropic"
60 ProviderGoogle ModelProvider = "google"
61 ProviderXAI ModelProvider = "xai"
62 ProviderDeepSeek ModelProvider = "deepseek"
63 ProviderMeta ModelProvider = "meta"
64)
65
66var SupportedModels = map[ModelID]Model{
67 // OpenAI
68 GPT4o: {
69 ID: GPT4o,
70 Name: "GPT-4o",
71 Provider: ProviderOpenAI,
72 APIModel: "gpt-4o",
73 },
74 GPT4oMini: {
75 ID: GPT4oMini,
76 Name: "GPT-4o Mini",
77 Provider: ProviderOpenAI,
78 APIModel: "gpt-4o-mini",
79 },
80 GPT45: {
81 ID: GPT45,
82 Name: "GPT-4.5",
83 Provider: ProviderOpenAI,
84 APIModel: "gpt-4.5",
85 },
86 O1: {
87 ID: O1,
88 Name: "o1",
89 Provider: ProviderOpenAI,
90 APIModel: "o1",
91 },
92 O1Mini: {
93 ID: O1Mini,
94 Name: "o1 Mini",
95 Provider: ProviderOpenAI,
96 APIModel: "o1-mini",
97 },
98 // Anthropic
99 Claude35Sonnet: {
100 ID: Claude35Sonnet,
101 Name: "Claude 3.5 Sonnet",
102 Provider: ProviderAnthropic,
103 APIModel: "claude-3.5-sonnet",
104 },
105 Claude3Haiku: {
106 ID: Claude3Haiku,
107 Name: "Claude 3 Haiku",
108 Provider: ProviderAnthropic,
109 APIModel: "claude-3-haiku",
110 },
111 Claude37Sonnet: {
112 ID: Claude37Sonnet,
113 Name: "Claude 3.7 Sonnet",
114 Provider: ProviderAnthropic,
115 APIModel: "claude-3-7-sonnet-20250219",
116 },
117 // Google
118 Gemini20Pro: {
119 ID: Gemini20Pro,
120 Name: "Gemini 2.0 Pro",
121 Provider: ProviderGoogle,
122 APIModel: "gemini-2.0-pro",
123 },
124 Gemini15Flash: {
125 ID: Gemini15Flash,
126 Name: "Gemini 1.5 Flash",
127 Provider: ProviderGoogle,
128 APIModel: "gemini-1.5-flash",
129 },
130 Gemini20Flash: {
131 ID: Gemini20Flash,
132 Name: "Gemini 2.0 Flash",
133 Provider: ProviderGoogle,
134 APIModel: "gemini-2.0-flash",
135 },
136 // xAI
137 Grok3: {
138 ID: Grok3,
139 Name: "Grok 3",
140 Provider: ProviderXAI,
141 APIModel: "grok-3",
142 },
143 Grok2Mini: {
144 ID: Grok2Mini,
145 Name: "Grok 2 Mini",
146 Provider: ProviderXAI,
147 APIModel: "grok-2-mini",
148 },
149 // DeepSeek
150 DeepSeekR1: {
151 ID: DeepSeekR1,
152 Name: "DeepSeek R1",
153 Provider: ProviderDeepSeek,
154 APIModel: "deepseek-r1",
155 },
156 DeepSeekCoder: {
157 ID: DeepSeekCoder,
158 Name: "DeepSeek Coder",
159 Provider: ProviderDeepSeek,
160 APIModel: "deepseek-coder",
161 },
162 // Meta
163 Llama3: {
164 ID: Llama3,
165 Name: "LLaMA 3",
166 Provider: ProviderMeta,
167 APIModel: "llama-3",
168 },
169 Llama270B: {
170 ID: Llama270B,
171 Name: "LLaMA 2 70B",
172 Provider: ProviderMeta,
173 APIModel: "llama-2-70b",
174 },
175}
176
// GetModel builds an eino chat model for the model identified by id.
//
// The client is chosen from SupportedModels[id].Provider and configured from
// viper:
//   - providers.common.max_tokens: generation limit shared by all providers
//   - providers.openai.key / providers.anthropic.key: API keys
//
// Only OpenAI and Anthropic are wired up. IDs missing from SupportedModels
// yield an "unknown model" error; IDs from any other provider yield an
// "unsupported provider" error. On error the returned model is always a
// literal nil.
func GetModel(ctx context.Context, id ModelID) (model.ChatModel, error) {
	// The parameter is deliberately not named "model": that would shadow the
	// imported model package inside the body.
	m, ok := SupportedModels[id]
	if !ok {
		return nil, errors.New("unknown model: " + string(id))
	}

	// viper.GetInt returns 0 when the key is not configured.
	maxTokens := viper.GetInt("providers.common.max_tokens")

	switch m.Provider {
	case ProviderOpenAI:
		cfg := &openai.ChatModelConfig{
			APIKey: viper.GetString("providers.openai.key"),
			Model:  m.APIModel,
		}
		// Leave MaxTokens nil when unconfigured rather than requesting an
		// explicit limit of 0.
		if maxTokens > 0 {
			cfg.MaxTokens = &maxTokens
		}
		cm, err := openai.NewChatModel(ctx, cfg)
		if err != nil {
			// Return a literal nil so callers never see a typed-nil pointer
			// wrapped in a non-nil interface.
			return nil, err
		}
		return cm, nil

	case ProviderAnthropic:
		// NOTE(review): if providers.common.max_tokens is unset this sends 0;
		// Anthropic presumably requires a positive max_tokens — confirm and
		// consider validating the config up front.
		cm, err := claude.NewChatModel(ctx, &claude.Config{
			APIKey:    viper.GetString("providers.anthropic.key"),
			Model:     m.APIModel,
			MaxTokens: maxTokens,
		})
		if err != nil {
			return nil, err
		}
		return cm, nil

	default:
		return nil, errors.New("unsupported provider: " + string(m.Provider))
	}
}