1use anyhow::anyhow;
2use serde::{Deserialize, Serialize};
3use strum::EnumIter;
4
5#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
6#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
7pub enum BedrockModelMode {
8 #[default]
9 Default,
10 Thinking {
11 budget_tokens: Option<u64>,
12 },
13}
14
15#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
16#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
17pub enum Model {
18 // Anthropic models (already included)
19 #[default]
20 #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")]
21 Claude3_5SonnetV2,
22 #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
23 Claude3_7Sonnet,
24 #[serde(
25 rename = "claude-3-7-sonnet-thinking",
26 alias = "claude-3-7-sonnet-thinking-latest"
27 )]
28 Claude3_7SonnetThinking,
29 #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
30 Claude3Opus,
31 #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
32 Claude3Sonnet,
33 #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
34 Claude3_5Haiku,
35 Claude3_5Sonnet,
36 Claude3Haiku,
37 // Amazon Nova Models
38 AmazonNovaLite,
39 AmazonNovaMicro,
40 AmazonNovaPro,
41 AmazonNovaPremier,
42 // AI21 models
43 AI21J2GrandeInstruct,
44 AI21J2JumboInstruct,
45 AI21J2Mid,
46 AI21J2MidV1,
47 AI21J2Ultra,
48 AI21J2UltraV1_8k,
49 AI21J2UltraV1,
50 AI21JambaInstructV1,
51 AI21Jamba15LargeV1,
52 AI21Jamba15MiniV1,
53 // Cohere models
54 CohereCommandTextV14_4k,
55 CohereCommandRV1,
56 CohereCommandRPlusV1,
57 CohereCommandLightTextV14_4k,
58 // DeepSeek
59 DeepSeekR1,
60 // Meta models
61 MetaLlama38BInstructV1,
62 MetaLlama370BInstructV1,
63 MetaLlama318BInstructV1_128k,
64 MetaLlama318BInstructV1,
65 MetaLlama3170BInstructV1_128k,
66 MetaLlama3170BInstructV1,
67 MetaLlama3211BInstructV1,
68 MetaLlama3290BInstructV1,
69 MetaLlama321BInstructV1,
70 MetaLlama323BInstructV1,
71 // Mistral models
72 MistralMistral7BInstructV0,
73 MistralMixtral8x7BInstructV0,
74 MistralMistralLarge2402V1,
75 MistralMistralSmall2402V1,
76 MistralPixtralLarge2502V1,
77 // Writer models
78 PalmyraWriterX5,
79 PalmyraWriterX4,
80 #[serde(rename = "custom")]
81 Custom {
82 name: String,
83 max_tokens: usize,
84 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
85 display_name: Option<String>,
86 max_output_tokens: Option<u32>,
87 default_temperature: Option<f32>,
88 },
89}
90
91impl Model {
92 pub fn default_fast() -> Self {
93 Self::Claude3_5Haiku
94 }
95
96 pub fn from_id(id: &str) -> anyhow::Result<Self> {
97 if id.starts_with("claude-3-5-sonnet-v2") {
98 Ok(Self::Claude3_5SonnetV2)
99 } else if id.starts_with("claude-3-opus") {
100 Ok(Self::Claude3Opus)
101 } else if id.starts_with("claude-3-sonnet") {
102 Ok(Self::Claude3Sonnet)
103 } else if id.starts_with("claude-3-5-haiku") {
104 Ok(Self::Claude3_5Haiku)
105 } else if id.starts_with("claude-3-7-sonnet") {
106 Ok(Self::Claude3_7Sonnet)
107 } else if id.starts_with("claude-3-7-sonnet-thinking") {
108 Ok(Self::Claude3_7SonnetThinking)
109 } else {
110 Err(anyhow!("invalid model id"))
111 }
112 }
113
114 pub fn id(&self) -> &str {
115 match self {
116 Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0",
117 Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
118 Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0",
119 Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
120 Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
121 Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0",
122 Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => {
123 "anthropic.claude-3-7-sonnet-20250219-v1:0"
124 }
125 Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
126 Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
127 Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
128 Model::AmazonNovaPremier => "amazon.nova-premier-v1:0",
129 Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
130 Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
131 Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
132 Model::AI21J2Mid => "ai21.j2-mid",
133 Model::AI21J2MidV1 => "ai21.j2-mid-v1",
134 Model::AI21J2Ultra => "ai21.j2-ultra",
135 Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k",
136 Model::AI21J2UltraV1 => "ai21.j2-ultra-v1",
137 Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0",
138 Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0",
139 Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0",
140 Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k",
141 Model::CohereCommandRV1 => "cohere.command-r-v1:0",
142 Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0",
143 Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k",
144 Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0",
145 Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0",
146 Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k",
147 Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0",
148 Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k",
149 Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0",
150 Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0",
151 Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0",
152 Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0",
153 Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0",
154 Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2",
155 Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
156 Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
157 Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
158 Model::MistralPixtralLarge2502V1 => "mistral.pixtral-large-2502-v1:0",
159 Model::PalmyraWriterX4 => "writer.palmyra-x4-v1:0",
160 Model::PalmyraWriterX5 => "writer.palmyra-x5-v1:0",
161 Self::Custom { name, .. } => name,
162 }
163 }
164
165 pub fn display_name(&self) -> &str {
166 match self {
167 Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2",
168 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
169 Self::Claude3Opus => "Claude 3 Opus",
170 Self::Claude3Sonnet => "Claude 3 Sonnet",
171 Self::Claude3Haiku => "Claude 3 Haiku",
172 Self::Claude3_5Haiku => "Claude 3.5 Haiku",
173 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
174 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
175 Self::AmazonNovaLite => "Amazon Nova Lite",
176 Self::AmazonNovaMicro => "Amazon Nova Micro",
177 Self::AmazonNovaPro => "Amazon Nova Pro",
178 Self::AmazonNovaPremier => "Amazon Nova Premier",
179 Self::DeepSeekR1 => "DeepSeek R1",
180 Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
181 Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
182 Self::AI21J2Mid => "AI21 Jurassic2 Mid",
183 Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1",
184 Self::AI21J2Ultra => "AI21 Jurassic2 Ultra",
185 Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K",
186 Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1",
187 Self::AI21JambaInstructV1 => "AI21 Jamba Instruct",
188 Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large",
189 Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini",
190 Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K",
191 Self::CohereCommandRV1 => "Cohere Command R V1",
192 Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1",
193 Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K",
194 Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1",
195 Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1",
196 Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K",
197 Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1",
198 Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K",
199 Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1",
200 Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1",
201 Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1",
202 Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1",
203 Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1",
204 Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0",
205 Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
206 Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
207 Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
208 Self::MistralPixtralLarge2502V1 => "Pixtral Large 25.02 V1",
209 Self::PalmyraWriterX5 => "Writer Palmyra X5",
210 Self::PalmyraWriterX4 => "Writer Palmyra X4",
211 Self::Custom {
212 display_name, name, ..
213 } => display_name.as_deref().unwrap_or(name),
214 }
215 }
216
217 pub fn max_token_count(&self) -> usize {
218 match self {
219 Self::Claude3_5SonnetV2
220 | Self::Claude3Opus
221 | Self::Claude3Sonnet
222 | Self::Claude3_5Haiku
223 | Self::Claude3_7Sonnet => 200_000,
224 Self::AmazonNovaPremier => 1_000_000,
225 Self::PalmyraWriterX5 => 1_000_000,
226 Self::PalmyraWriterX4 => 128_000,
227 Self::Custom { max_tokens, .. } => *max_tokens,
228 _ => 128_000,
229 }
230 }
231
232 pub fn max_output_tokens(&self) -> u32 {
233 match self {
234 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
235 Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
236 Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192,
237 Self::Custom {
238 max_output_tokens, ..
239 } => max_output_tokens.unwrap_or(4_096),
240 _ => 4_096,
241 }
242 }
243
244 pub fn default_temperature(&self) -> f32 {
245 match self {
246 Self::Claude3_5SonnetV2
247 | Self::Claude3Opus
248 | Self::Claude3Sonnet
249 | Self::Claude3_5Haiku
250 | Self::Claude3_7Sonnet => 1.0,
251 Self::Custom {
252 default_temperature,
253 ..
254 } => default_temperature.unwrap_or(1.0),
255 _ => 1.0,
256 }
257 }
258
259 pub fn supports_tool_use(&self) -> bool {
260 match self {
261 // Anthropic Claude 3 models (all support tool use)
262 Self::Claude3Opus
263 | Self::Claude3Sonnet
264 | Self::Claude3_5Sonnet
265 | Self::Claude3_5SonnetV2
266 | Self::Claude3_7Sonnet
267 | Self::Claude3_7SonnetThinking
268 | Self::Claude3_5Haiku => true,
269
270 // Amazon Nova models (all support tool use)
271 Self::AmazonNovaPremier
272 | Self::AmazonNovaPro
273 | Self::AmazonNovaLite
274 | Self::AmazonNovaMicro => true,
275
276 // AI21 Jamba 1.5 models support tool use
277 Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
278
279 // Cohere Command R models support tool use
280 Self::CohereCommandRV1 | Self::CohereCommandRPlusV1 => true,
281
282 // All other models don't support tool use
283 // Including Meta Llama 3.2, AI21 Jurassic, and others
284 _ => false,
285 }
286 }
287
288 pub fn mode(&self) -> BedrockModelMode {
289 match self {
290 Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
291 budget_tokens: Some(4096),
292 },
293 _ => BedrockModelMode::Default,
294 }
295 }
296
297 pub fn cross_region_inference_id(&self, region: &str) -> Result<String, anyhow::Error> {
298 let region_group = if region.starts_with("us-gov-") {
299 "us-gov"
300 } else if region.starts_with("us-") {
301 "us"
302 } else if region.starts_with("eu-") {
303 "eu"
304 } else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" {
305 "apac"
306 } else if region.starts_with("ca-") || region.starts_with("sa-") {
307 // Canada and South America regions - default to US profiles
308 "us"
309 } else {
310 // Unknown region
311 return Err(anyhow!("Unsupported Region"));
312 };
313
314 let model_id = self.id();
315
316 match (self, region_group) {
317 // Custom models can't have CRI IDs
318 (Model::Custom { .. }, _) => Ok(self.id().into()),
319
320 // Models with US Gov only
321 (Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => {
322 Ok(format!("{}.{}", region_group, model_id))
323 }
324
325 // Models available only in US
326 (Model::Claude3Opus, "us")
327 | (Model::Claude3_5Haiku, "us")
328 | (Model::Claude3_7Sonnet, "us")
329 | (Model::Claude3_7SonnetThinking, "us")
330 | (Model::AmazonNovaPremier, "us")
331 | (Model::MistralPixtralLarge2502V1, "us") => {
332 Ok(format!("{}.{}", region_group, model_id))
333 }
334
335 // Models available in US, EU, and APAC
336 (Model::Claude3_5SonnetV2, "us")
337 | (Model::Claude3_5SonnetV2, "apac")
338 | (Model::Claude3_5Sonnet, _)
339 | (Model::Claude3Haiku, _)
340 | (Model::Claude3Sonnet, _)
341 | (Model::AmazonNovaLite, _)
342 | (Model::AmazonNovaMicro, _)
343 | (Model::AmazonNovaPro, _) => Ok(format!("{}.{}", region_group, model_id)),
344
345 // Models with limited EU availability
346 (Model::MetaLlama321BInstructV1, "us")
347 | (Model::MetaLlama321BInstructV1, "eu")
348 | (Model::MetaLlama323BInstructV1, "us")
349 | (Model::MetaLlama323BInstructV1, "eu") => {
350 Ok(format!("{}.{}", region_group, model_id))
351 }
352
353 // US-only models (all remaining Meta models)
354 (Model::MetaLlama38BInstructV1, "us")
355 | (Model::MetaLlama370BInstructV1, "us")
356 | (Model::MetaLlama318BInstructV1, "us")
357 | (Model::MetaLlama318BInstructV1_128k, "us")
358 | (Model::MetaLlama3170BInstructV1, "us")
359 | (Model::MetaLlama3170BInstructV1_128k, "us")
360 | (Model::MetaLlama3211BInstructV1, "us")
361 | (Model::MetaLlama3290BInstructV1, "us") => {
362 Ok(format!("{}.{}", region_group, model_id))
363 }
364
365 // Writer models only available in the US
366 (Model::PalmyraWriterX4, "us") | (Model::PalmyraWriterX5, "us") => {
367 // They have some goofiness
368 Ok(format!("{}.{}", region_group, model_id))
369 }
370
371 // Any other combination is not supported
372 _ => Ok(self.id().into()),
373 }
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(model, region, expected)` triple against
    /// `Model::cross_region_inference_id`, propagating any error it returns.
    fn assert_inference_ids(cases: &[(Model, &str, &str)]) -> anyhow::Result<()> {
        for (model, region, expected) in cases {
            let actual = model.cross_region_inference_id(region)?;
            assert_eq!(actual, *expected, "model {model:?} in region {region}");
        }
        Ok(())
    }

    #[test]
    fn test_us_region_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::Claude3_5SonnetV2,
                "us-east-1",
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (
                Model::Claude3_5SonnetV2,
                "us-west-2",
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (Model::AmazonNovaPro, "us-east-2", "us.amazon.nova-pro-v1:0"),
        ])
    }

    #[test]
    fn test_eu_region_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::Claude3Sonnet,
                "eu-west-1",
                "eu.anthropic.claude-3-sonnet-20240229-v1:0",
            ),
            (Model::AmazonNovaMicro, "eu-north-1", "eu.amazon.nova-micro-v1:0"),
        ])
    }

    #[test]
    fn test_apac_region_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::Claude3_5SonnetV2,
                "ap-northeast-1",
                "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (Model::AmazonNovaLite, "ap-south-1", "apac.amazon.nova-lite-v1:0"),
        ])
    }

    #[test]
    fn test_gov_region_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::Claude3_5Sonnet,
                "us-gov-east-1",
                "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
            ),
            (
                Model::Claude3Haiku,
                "us-gov-west-1",
                "us-gov.anthropic.claude-3-haiku-20240307-v1:0",
            ),
        ])
    }

    #[test]
    fn test_meta_models_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::MetaLlama370BInstructV1,
                "us-east-1",
                "us.meta.llama3-70b-instruct-v1:0",
            ),
            (
                Model::MetaLlama321BInstructV1,
                "eu-west-1",
                "eu.meta.llama3-2-1b-instruct-v1:0",
            ),
        ])
    }

    #[test]
    fn test_mistral_models_inference_ids() -> anyhow::Result<()> {
        // These Mistral models have no regional profile, so their plain IDs come back.
        assert_inference_ids(&[
            (
                Model::MistralMistralLarge2402V1,
                "us-east-1",
                "mistral.mistral-large-2402-v1:0",
            ),
            (
                Model::MistralMixtral8x7BInstructV0,
                "eu-west-1",
                "mistral.mixtral-8x7b-instruct-v0:1",
            ),
        ])
    }

    #[test]
    fn test_ai21_models_inference_ids() -> anyhow::Result<()> {
        // AI21 models have no regional profile, so their plain IDs come back.
        assert_inference_ids(&[
            (Model::AI21J2UltraV1, "us-east-1", "ai21.j2-ultra-v1"),
            (Model::AI21JambaInstructV1, "eu-west-1", "ai21.jamba-instruct-v1:0"),
        ])
    }

    #[test]
    fn test_cohere_models_inference_ids() -> anyhow::Result<()> {
        // Cohere models have no regional profile, so their plain IDs come back.
        assert_inference_ids(&[
            (Model::CohereCommandRV1, "us-east-1", "cohere.command-r-v1:0"),
            (
                Model::CohereCommandTextV14_4k,
                "ap-southeast-1",
                "cohere.command-text-v14:7:4k",
            ),
        ])
    }

    #[test]
    fn test_custom_model_inference_ids() -> anyhow::Result<()> {
        let custom = Model::Custom {
            name: "custom.my-model-v1:0".to_string(),
            max_tokens: 100000,
            display_name: Some("My Custom Model".to_string()),
            max_output_tokens: Some(8192),
            default_temperature: Some(0.7),
        };

        // A custom model's name is returned untouched, whatever the region.
        assert_inference_ids(&[(custom, "us-east-1", "custom.my-model-v1:0")])
    }
}