models.go

  1package models
  2
  import (
	"context"
	"errors"
	"fmt"

	"github.com/cloudwego/eino-ext/components/model/claude"
	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/cloudwego/eino/components/model"
	"github.com/spf13/viper"
)
 12
 // ModelID is the application-level identifier of a chat model. It is the key
// type of SupportedModels and is kept separate from the provider-facing model
// string (Model.APIModel).
type ModelID string

// ModelProvider names the vendor that serves a model (see the Provider*
// constants).
type ModelProvider string
 17
 18type Model struct {
 19	ID       ModelID       `json:"id"`
 20	Name     string        `json:"name"`
 21	Provider ModelProvider `json:"provider"`
 22	APIModel string        `json:"api_model"` // Actual value used when calling the API
 23}
 24
 25const (
 26	DefaultBigModel    = GPT4oMini
 27	DefaultLittleModel = GPT4oMini
 28)
 29
 30// Model IDs
 31const (
 32	// OpenAI
 33	GPT4o     ModelID = "gpt-4o"
 34	GPT4oMini ModelID = "gpt-4o-mini"
 35	GPT45     ModelID = "gpt-4.5"
 36	O1        ModelID = "o1"
 37	O1Mini    ModelID = "o1-mini"
 38	// Anthropic
 39	Claude35Sonnet ModelID = "claude-3.5-sonnet"
 40	Claude3Haiku   ModelID = "claude-3-haiku"
 41	Claude37Sonnet ModelID = "claude-3.7-sonnet"
 42	// Google
 43	Gemini20Pro   ModelID = "gemini-2.0-pro"
 44	Gemini15Flash ModelID = "gemini-1.5-flash"
 45	Gemini20Flash ModelID = "gemini-2.0-flash"
 46	// xAI
 47	Grok3     ModelID = "grok-3"
 48	Grok2Mini ModelID = "grok-2-mini"
 49	// DeepSeek
 50	DeepSeekR1    ModelID = "deepseek-r1"
 51	DeepSeekCoder ModelID = "deepseek-coder"
 52	// Meta
 53	Llama3    ModelID = "llama-3"
 54	Llama270B ModelID = "llama-2-70b"
 55)
 56
 57const (
 58	ProviderOpenAI    ModelProvider = "openai"
 59	ProviderAnthropic ModelProvider = "anthropic"
 60	ProviderGoogle    ModelProvider = "google"
 61	ProviderXAI       ModelProvider = "xai"
 62	ProviderDeepSeek  ModelProvider = "deepseek"
 63	ProviderMeta      ModelProvider = "meta"
 64)
 65
 66var SupportedModels = map[ModelID]Model{
 67	// OpenAI
 68	GPT4o: {
 69		ID:       GPT4o,
 70		Name:     "GPT-4o",
 71		Provider: ProviderOpenAI,
 72		APIModel: "gpt-4o",
 73	},
 74	GPT4oMini: {
 75		ID:       GPT4oMini,
 76		Name:     "GPT-4o Mini",
 77		Provider: ProviderOpenAI,
 78		APIModel: "gpt-4o-mini",
 79	},
 80	GPT45: {
 81		ID:       GPT45,
 82		Name:     "GPT-4.5",
 83		Provider: ProviderOpenAI,
 84		APIModel: "gpt-4.5",
 85	},
 86	O1: {
 87		ID:       O1,
 88		Name:     "o1",
 89		Provider: ProviderOpenAI,
 90		APIModel: "o1",
 91	},
 92	O1Mini: {
 93		ID:       O1Mini,
 94		Name:     "o1 Mini",
 95		Provider: ProviderOpenAI,
 96		APIModel: "o1-mini",
 97	},
 98	// Anthropic
 99	Claude35Sonnet: {
100		ID:       Claude35Sonnet,
101		Name:     "Claude 3.5 Sonnet",
102		Provider: ProviderAnthropic,
103		APIModel: "claude-3.5-sonnet",
104	},
105	Claude3Haiku: {
106		ID:       Claude3Haiku,
107		Name:     "Claude 3 Haiku",
108		Provider: ProviderAnthropic,
109		APIModel: "claude-3-haiku",
110	},
111	Claude37Sonnet: {
112		ID:       Claude37Sonnet,
113		Name:     "Claude 3.7 Sonnet",
114		Provider: ProviderAnthropic,
115		APIModel: "claude-3-7-sonnet-20250219",
116	},
117	// Google
118	Gemini20Pro: {
119		ID:       Gemini20Pro,
120		Name:     "Gemini 2.0 Pro",
121		Provider: ProviderGoogle,
122		APIModel: "gemini-2.0-pro",
123	},
124	Gemini15Flash: {
125		ID:       Gemini15Flash,
126		Name:     "Gemini 1.5 Flash",
127		Provider: ProviderGoogle,
128		APIModel: "gemini-1.5-flash",
129	},
130	Gemini20Flash: {
131		ID:       Gemini20Flash,
132		Name:     "Gemini 2.0 Flash",
133		Provider: ProviderGoogle,
134		APIModel: "gemini-2.0-flash",
135	},
136	// xAI
137	Grok3: {
138		ID:       Grok3,
139		Name:     "Grok 3",
140		Provider: ProviderXAI,
141		APIModel: "grok-3",
142	},
143	Grok2Mini: {
144		ID:       Grok2Mini,
145		Name:     "Grok 2 Mini",
146		Provider: ProviderXAI,
147		APIModel: "grok-2-mini",
148	},
149	// DeepSeek
150	DeepSeekR1: {
151		ID:       DeepSeekR1,
152		Name:     "DeepSeek R1",
153		Provider: ProviderDeepSeek,
154		APIModel: "deepseek-r1",
155	},
156	DeepSeekCoder: {
157		ID:       DeepSeekCoder,
158		Name:     "DeepSeek Coder",
159		Provider: ProviderDeepSeek,
160		APIModel: "deepseek-coder",
161	},
162	// Meta
163	Llama3: {
164		ID:       Llama3,
165		Name:     "LLaMA 3",
166		Provider: ProviderMeta,
167		APIModel: "llama-3",
168	},
169	Llama270B: {
170		ID:       Llama270B,
171		Name:     "LLaMA 2 70B",
172		Provider: ProviderMeta,
173		APIModel: "llama-2-70b",
174	},
175}
176
177func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
178	provider := SupportedModels[model].Provider
179	maxTokens := viper.GetInt("providers.common.max_tokens")
180	switch provider {
181	case ProviderOpenAI:
182		return openai.NewChatModel(ctx, &openai.ChatModelConfig{
183			APIKey:    viper.GetString("providers.openai.key"),
184			Model:     string(SupportedModels[model].APIModel),
185			MaxTokens: &maxTokens,
186		})
187	case ProviderAnthropic:
188		return claude.NewChatModel(ctx, &claude.Config{
189			APIKey:    viper.GetString("providers.anthropic.key"),
190			Model:     string(SupportedModels[model].APIModel),
191			MaxTokens: maxTokens,
192		})
193
194	}
195	return nil, errors.New("unsupported provider")
196}