models.go

  1package models
  2
import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/cloudwego/eino-ext/components/model/claude"
	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/cloudwego/eino/components/model"
	"github.com/spf13/viper"
)
 13
 14type (
 15	ModelID       string
 16	ModelProvider string
 17)
 18
 19type Model struct {
 20	ID           ModelID       `json:"id"`
 21	Name         string        `json:"name"`
 22	Provider     ModelProvider `json:"provider"`
 23	APIModel     string        `json:"api_model"`
 24	CostPer1MIn  float64       `json:"cost_per_1m_in"`
 25	CostPer1MOut float64       `json:"cost_per_1m_out"`
 26}
 27
 28const (
 29	DefaultBigModel    = GPT4oMini
 30	DefaultLittleModel = GPT4oMini
 31)
 32
 33// Model IDs
 34const (
 35	// OpenAI
 36	GPT4o     ModelID = "gpt-4o"
 37	GPT4oMini ModelID = "gpt-4o-mini"
 38	GPT45     ModelID = "gpt-4.5"
 39	O1        ModelID = "o1"
 40	O1Mini    ModelID = "o1-mini"
 41	// Anthropic
 42	Claude35Sonnet ModelID = "claude-3.5-sonnet"
 43	Claude3Haiku   ModelID = "claude-3-haiku"
 44	Claude37Sonnet ModelID = "claude-3.7-sonnet"
 45	// Google
 46	Gemini20Pro   ModelID = "gemini-2.0-pro"
 47	Gemini15Flash ModelID = "gemini-1.5-flash"
 48	Gemini20Flash ModelID = "gemini-2.0-flash"
 49	// xAI
 50	Grok3     ModelID = "grok-3"
 51	Grok2Mini ModelID = "grok-2-mini"
 52	// DeepSeek
 53	DeepSeekR1    ModelID = "deepseek-r1"
 54	DeepSeekCoder ModelID = "deepseek-coder"
 55	// Meta
 56	Llama3    ModelID = "llama-3"
 57	Llama270B ModelID = "llama-2-70b"
 58	// GROQ
 59	GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
 60	GroqQwen32BCoder  ModelID = "qwen-2.5-coder-32b"
 61)
 62
 63const (
 64	ProviderOpenAI    ModelProvider = "openai"
 65	ProviderAnthropic ModelProvider = "anthropic"
 66	ProviderGoogle    ModelProvider = "google"
 67	ProviderXAI       ModelProvider = "xai"
 68	ProviderDeepSeek  ModelProvider = "deepseek"
 69	ProviderMeta      ModelProvider = "meta"
 70	ProviderGroq      ModelProvider = "groq"
 71)
 72
 73var SupportedModels = map[ModelID]Model{
 74	// OpenAI
 75	GPT4o: {
 76		ID:       GPT4o,
 77		Name:     "GPT-4o",
 78		Provider: ProviderOpenAI,
 79		APIModel: "gpt-4o",
 80	},
 81	GPT4oMini: {
 82		ID:           GPT4oMini,
 83		Name:         "GPT-4o Mini",
 84		Provider:     ProviderOpenAI,
 85		APIModel:     "gpt-4o-mini",
 86		CostPer1MIn:  0.150,
 87		CostPer1MOut: 0.600,
 88	},
 89	GPT45: {
 90		ID:       GPT45,
 91		Name:     "GPT-4.5",
 92		Provider: ProviderOpenAI,
 93		APIModel: "gpt-4.5",
 94	},
 95	O1: {
 96		ID:       O1,
 97		Name:     "o1",
 98		Provider: ProviderOpenAI,
 99		APIModel: "o1",
100	},
101	O1Mini: {
102		ID:       O1Mini,
103		Name:     "o1 Mini",
104		Provider: ProviderOpenAI,
105		APIModel: "o1-mini",
106	},
107	// Anthropic
108	Claude35Sonnet: {
109		ID:       Claude35Sonnet,
110		Name:     "Claude 3.5 Sonnet",
111		Provider: ProviderAnthropic,
112		APIModel: "claude-3.5-sonnet",
113	},
114	Claude3Haiku: {
115		ID:       Claude3Haiku,
116		Name:     "Claude 3 Haiku",
117		Provider: ProviderAnthropic,
118		APIModel: "claude-3-haiku",
119	},
120	Claude37Sonnet: {
121		ID:       Claude37Sonnet,
122		Name:     "Claude 3.7 Sonnet",
123		Provider: ProviderAnthropic,
124		APIModel: "claude-3-7-sonnet-20250219",
125	},
126	// Google
127	Gemini20Pro: {
128		ID:       Gemini20Pro,
129		Name:     "Gemini 2.0 Pro",
130		Provider: ProviderGoogle,
131		APIModel: "gemini-2.0-pro",
132	},
133	Gemini15Flash: {
134		ID:       Gemini15Flash,
135		Name:     "Gemini 1.5 Flash",
136		Provider: ProviderGoogle,
137		APIModel: "gemini-1.5-flash",
138	},
139	Gemini20Flash: {
140		ID:       Gemini20Flash,
141		Name:     "Gemini 2.0 Flash",
142		Provider: ProviderGoogle,
143		APIModel: "gemini-2.0-flash",
144	},
145	// xAI
146	Grok3: {
147		ID:       Grok3,
148		Name:     "Grok 3",
149		Provider: ProviderXAI,
150		APIModel: "grok-3",
151	},
152	Grok2Mini: {
153		ID:       Grok2Mini,
154		Name:     "Grok 2 Mini",
155		Provider: ProviderXAI,
156		APIModel: "grok-2-mini",
157	},
158	// DeepSeek
159	DeepSeekR1: {
160		ID:       DeepSeekR1,
161		Name:     "DeepSeek R1",
162		Provider: ProviderDeepSeek,
163		APIModel: "deepseek-r1",
164	},
165	DeepSeekCoder: {
166		ID:       DeepSeekCoder,
167		Name:     "DeepSeek Coder",
168		Provider: ProviderDeepSeek,
169		APIModel: "deepseek-coder",
170	},
171	// Meta
172	Llama3: {
173		ID:       Llama3,
174		Name:     "LLaMA 3",
175		Provider: ProviderMeta,
176		APIModel: "llama-3",
177	},
178	Llama270B: {
179		ID:       Llama270B,
180		Name:     "LLaMA 2 70B",
181		Provider: ProviderMeta,
182		APIModel: "llama-2-70b",
183	},
184
185	// GROQ
186	GroqLlama3SpecDec: {
187		ID:       GroqLlama3SpecDec,
188		Name:     "GROQ LLaMA 3 SpecDec",
189		Provider: ProviderGroq,
190		APIModel: "llama-3.3-70b-specdec",
191	},
192	GroqQwen32BCoder: {
193		ID:       GroqQwen32BCoder,
194		Name:     "GROQ Qwen 2.5 Coder 32B",
195		Provider: ProviderGroq,
196		APIModel: "qwen-2.5-coder-32b",
197	},
198}
199
200func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
201	provider := SupportedModels[model].Provider
202	log.Printf("Provider: %s", provider)
203	maxTokens := viper.GetInt("providers.common.max_tokens")
204	switch provider {
205	case ProviderOpenAI:
206		return openai.NewChatModel(ctx, &openai.ChatModelConfig{
207			APIKey:    viper.GetString("providers.openai.key"),
208			Model:     string(SupportedModels[model].APIModel),
209			MaxTokens: &maxTokens,
210		})
211	case ProviderAnthropic:
212		return claude.NewChatModel(ctx, &claude.Config{
213			APIKey:    viper.GetString("providers.anthropic.key"),
214			Model:     string(SupportedModels[model].APIModel),
215			MaxTokens: maxTokens,
216		})
217
218	case ProviderGroq:
219		return openai.NewChatModel(ctx, &openai.ChatModelConfig{
220			BaseURL:   "https://api.groq.com/openai/v1",
221			APIKey:    viper.GetString("providers.groq.key"),
222			Model:     string(SupportedModels[model].APIModel),
223			MaxTokens: &maxTokens,
224		})
225
226	}
227	return nil, errors.New("unsupported provider")
228}