1use anyhow::anyhow;
2use serde::{Deserialize, Serialize};
3use strum::EnumIter;
4
5#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
6#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
7pub enum BedrockModelMode {
8 #[default]
9 Default,
10 Thinking {
11 budget_tokens: Option<u64>,
12 },
13}
14
/// A model known to this module: one of the built-in variants or a
/// user-described [`Model::Custom`].
///
/// `Claude3_5SonnetV2` is the `Default`. Variants carrying `#[serde(rename = ...)]`
/// are (de)serialized under that short name (and also accept the listed alias);
/// the unit variants without a serde attribute use their Rust identifier.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    // Anthropic models
    #[default]
    #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")]
    Claude3_5SonnetV2,
    #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
    Claude3_7Sonnet,
    // Same underlying model id as `Claude3_7Sonnet`; differs in `mode()`.
    #[serde(
        rename = "claude-3-7-sonnet-thinking",
        alias = "claude-3-7-sonnet-thinking-latest"
    )]
    Claude3_7SonnetThinking,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
    Claude3_5Haiku,
    Claude3_5Sonnet,
    Claude3Haiku,
    // Amazon Nova models
    AmazonNovaLite,
    AmazonNovaMicro,
    AmazonNovaPro,
    AmazonNovaPremier,
    // AI21 models
    AI21J2GrandeInstruct,
    AI21J2JumboInstruct,
    AI21J2Mid,
    AI21J2MidV1,
    AI21J2Ultra,
    AI21J2UltraV1_8k,
    AI21J2UltraV1,
    AI21JambaInstructV1,
    AI21Jamba15LargeV1,
    AI21Jamba15MiniV1,
    // Cohere models
    CohereCommandTextV14_4k,
    CohereCommandRV1,
    CohereCommandRPlusV1,
    CohereCommandLightTextV14_4k,
    // DeepSeek models
    DeepSeekR1,
    // Meta models
    MetaLlama38BInstructV1,
    MetaLlama370BInstructV1,
    MetaLlama318BInstructV1_128k,
    MetaLlama318BInstructV1,
    MetaLlama3170BInstructV1_128k,
    MetaLlama3170BInstructV1,
    MetaLlama3211BInstructV1,
    MetaLlama3290BInstructV1,
    MetaLlama321BInstructV1,
    MetaLlama323BInstructV1,
    // Mistral models
    MistralMistral7BInstructV0,
    MistralMixtral8x7BInstructV0,
    MistralMistralLarge2402V1,
    MistralMistralSmall2402V1,
    MistralPixtralLarge2502V1,
    // Writer models
    PalmyraWriterX5,
    PalmyraWriterX4,
    /// A user-described model; every per-model accessor reads from these fields.
    #[serde(rename = "custom")]
    Custom {
        /// Identifier returned verbatim by `Model::id`.
        name: String,
        /// Value returned by `Model::max_token_count`.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Output-token limit; `Model::max_output_tokens` uses 4096 when this is `None`.
        max_output_tokens: Option<u32>,
        /// Default temperature; `Model::default_temperature` uses 1.0 when this is `None`.
        default_temperature: Option<f32>,
    },
}
90
91impl Model {
92 pub fn default_fast() -> Self {
93 Self::Claude3_5Haiku
94 }
95
96 pub fn from_id(id: &str) -> anyhow::Result<Self> {
97 if id.starts_with("claude-3-5-sonnet-v2") {
98 Ok(Self::Claude3_5SonnetV2)
99 } else if id.starts_with("claude-3-opus") {
100 Ok(Self::Claude3Opus)
101 } else if id.starts_with("claude-3-sonnet") {
102 Ok(Self::Claude3Sonnet)
103 } else if id.starts_with("claude-3-5-haiku") {
104 Ok(Self::Claude3_5Haiku)
105 } else if id.starts_with("claude-3-7-sonnet") {
106 Ok(Self::Claude3_7Sonnet)
107 } else if id.starts_with("claude-3-7-sonnet-thinking") {
108 Ok(Self::Claude3_7SonnetThinking)
109 } else {
110 Err(anyhow!("invalid model id"))
111 }
112 }
113
114 pub fn id(&self) -> &str {
115 match self {
116 Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0",
117 Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
118 Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0",
119 Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
120 Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
121 Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0",
122 Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => {
123 "anthropic.claude-3-7-sonnet-20250219-v1:0"
124 }
125 Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
126 Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
127 Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
128 Model::AmazonNovaPremier => "amazon.nova-premier-v1:0",
129 Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
130 Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
131 Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
132 Model::AI21J2Mid => "ai21.j2-mid",
133 Model::AI21J2MidV1 => "ai21.j2-mid-v1",
134 Model::AI21J2Ultra => "ai21.j2-ultra",
135 Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k",
136 Model::AI21J2UltraV1 => "ai21.j2-ultra-v1",
137 Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0",
138 Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0",
139 Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0",
140 Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k",
141 Model::CohereCommandRV1 => "cohere.command-r-v1:0",
142 Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0",
143 Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k",
144 Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0",
145 Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0",
146 Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k",
147 Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0",
148 Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k",
149 Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0",
150 Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0",
151 Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0",
152 Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0",
153 Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0",
154 Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2",
155 Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
156 Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
157 Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
158 Model::MistralPixtralLarge2502V1 => "mistral.pixtral-large-2502-v1:0",
159 Model::PalmyraWriterX4 => "writer.palmyra-x4-v1:0",
160 Model::PalmyraWriterX5 => "writer.palmyra-x5-v1:0",
161 Self::Custom { name, .. } => name,
162 }
163 }
164
165 pub fn display_name(&self) -> &str {
166 match self {
167 Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2",
168 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
169 Self::Claude3Opus => "Claude 3 Opus",
170 Self::Claude3Sonnet => "Claude 3 Sonnet",
171 Self::Claude3Haiku => "Claude 3 Haiku",
172 Self::Claude3_5Haiku => "Claude 3.5 Haiku",
173 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
174 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
175 Self::AmazonNovaLite => "Amazon Nova Lite",
176 Self::AmazonNovaMicro => "Amazon Nova Micro",
177 Self::AmazonNovaPro => "Amazon Nova Pro",
178 Self::AmazonNovaPremier => "Amazon Nova Premier",
179 Self::DeepSeekR1 => "DeepSeek R1",
180 Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
181 Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
182 Self::AI21J2Mid => "AI21 Jurassic2 Mid",
183 Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1",
184 Self::AI21J2Ultra => "AI21 Jurassic2 Ultra",
185 Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K",
186 Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1",
187 Self::AI21JambaInstructV1 => "AI21 Jamba Instruct",
188 Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large",
189 Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini",
190 Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K",
191 Self::CohereCommandRV1 => "Cohere Command R V1",
192 Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1",
193 Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K",
194 Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1",
195 Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1",
196 Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K",
197 Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1",
198 Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K",
199 Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1",
200 Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1",
201 Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1",
202 Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1",
203 Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1",
204 Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0",
205 Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
206 Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
207 Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
208 Self::MistralPixtralLarge2502V1 => "Pixtral Large 25.02 V1",
209 Self::PalmyraWriterX5 => "Writer Palmyra X5",
210 Self::PalmyraWriterX4 => "Writer Palmyra X4",
211 Self::Custom {
212 display_name, name, ..
213 } => display_name.as_deref().unwrap_or(name),
214 }
215 }
216
217 pub fn max_token_count(&self) -> usize {
218 match self {
219 Self::Claude3_5SonnetV2
220 | Self::Claude3Opus
221 | Self::Claude3Sonnet
222 | Self::Claude3_5Haiku
223 | Self::Claude3_7Sonnet => 200_000,
224 Self::AmazonNovaPremier => 1_000_000,
225 Self::PalmyraWriterX5 => 1_000_000,
226 Self::PalmyraWriterX4 => 128_000,
227 Self::Custom { max_tokens, .. } => *max_tokens,
228 _ => 128_000,
229 }
230 }
231
232 pub fn max_output_tokens(&self) -> u32 {
233 match self {
234 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
235 Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
236 Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192,
237 Self::Custom {
238 max_output_tokens, ..
239 } => max_output_tokens.unwrap_or(4_096),
240 _ => 4_096,
241 }
242 }
243
244 pub fn default_temperature(&self) -> f32 {
245 match self {
246 Self::Claude3_5SonnetV2
247 | Self::Claude3Opus
248 | Self::Claude3Sonnet
249 | Self::Claude3_5Haiku
250 | Self::Claude3_7Sonnet => 1.0,
251 Self::Custom {
252 default_temperature,
253 ..
254 } => default_temperature.unwrap_or(1.0),
255 _ => 1.0,
256 }
257 }
258
259 pub fn supports_tool_use(&self) -> bool {
260 match self {
261 // Anthropic Claude 3 models (all support tool use)
262 Self::Claude3Opus
263 | Self::Claude3Sonnet
264 | Self::Claude3_5Sonnet
265 | Self::Claude3_5SonnetV2
266 | Self::Claude3_7Sonnet
267 | Self::Claude3_7SonnetThinking
268 | Self::Claude3_5Haiku => true,
269
270 // Amazon Nova models (all support tool use)
271 Self::AmazonNovaPremier
272 | Self::AmazonNovaPro
273 | Self::AmazonNovaLite
274 | Self::AmazonNovaMicro => true,
275
276 // AI21 Jamba 1.5 models support tool use
277 Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
278
279 // Cohere Command R models support tool use
280 Self::CohereCommandRV1 | Self::CohereCommandRPlusV1 => true,
281
282 // All other models don't support tool use
283 // Including Meta Llama 3.2, AI21 Jurassic, and others
284 _ => false,
285 }
286 }
287
288 pub fn mode(&self) -> BedrockModelMode {
289 match self {
290 Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
291 budget_tokens: Some(4096),
292 },
293 _ => BedrockModelMode::Default,
294 }
295 }
296
297 pub fn cross_region_inference_id(&self, region: &str) -> Result<String, anyhow::Error> {
298 let region_group = if region.starts_with("us-gov-") {
299 "us-gov"
300 } else if region.starts_with("us-") {
301 "us"
302 } else if region.starts_with("eu-") {
303 "eu"
304 } else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" {
305 "apac"
306 } else if region.starts_with("ca-") || region.starts_with("sa-") {
307 // Canada and South America regions - default to US profiles
308 "us"
309 } else {
310 // Unknown region
311 return Err(anyhow!("Unsupported Region"));
312 };
313
314 let model_id = self.id();
315
316 match (self, region_group) {
317 // Custom models can't have CRI IDs
318 (Model::Custom { .. }, _) => Ok(self.id().into()),
319
320 // Models with US Gov only
321 (Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => {
322 Ok(format!("{}.{}", region_group, model_id))
323 }
324
325 // Models available only in US
326 (Model::Claude3Opus, "us")
327 | (Model::Claude3_7Sonnet, "us")
328 | (Model::Claude3_7SonnetThinking, "us")
329 | (Model::AmazonNovaPremier, "us")
330 | (Model::MistralPixtralLarge2502V1, "us") => {
331 Ok(format!("{}.{}", region_group, model_id))
332 }
333
334 // Models available in US, EU, and APAC
335 (Model::Claude3_5SonnetV2, "us")
336 | (Model::Claude3_5SonnetV2, "apac")
337 | (Model::Claude3_5Sonnet, _)
338 | (Model::Claude3Haiku, _)
339 | (Model::Claude3Sonnet, _)
340 | (Model::AmazonNovaLite, _)
341 | (Model::AmazonNovaMicro, _)
342 | (Model::AmazonNovaPro, _) => Ok(format!("{}.{}", region_group, model_id)),
343
344 // Models with limited EU availability
345 (Model::MetaLlama321BInstructV1, "us")
346 | (Model::MetaLlama321BInstructV1, "eu")
347 | (Model::MetaLlama323BInstructV1, "us")
348 | (Model::MetaLlama323BInstructV1, "eu") => {
349 Ok(format!("{}.{}", region_group, model_id))
350 }
351
352 // US-only models (all remaining Meta models)
353 (Model::MetaLlama38BInstructV1, "us")
354 | (Model::MetaLlama370BInstructV1, "us")
355 | (Model::MetaLlama318BInstructV1, "us")
356 | (Model::MetaLlama318BInstructV1_128k, "us")
357 | (Model::MetaLlama3170BInstructV1, "us")
358 | (Model::MetaLlama3170BInstructV1_128k, "us")
359 | (Model::MetaLlama3211BInstructV1, "us")
360 | (Model::MetaLlama3290BInstructV1, "us") => {
361 Ok(format!("{}.{}", region_group, model_id))
362 }
363
364 // Writer models only available in the US
365 (Model::PalmyraWriterX4, "us") | (Model::PalmyraWriterX5, "us") => {
366 // They have some goofiness
367 Ok(format!("{}.{}", region_group, model_id))
368 }
369
370 // Any other combination is not supported
371 _ => Ok(self.id().into()),
372 }
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks each `(model, region, expected id)` triple against
    /// `Model::cross_region_inference_id`.
    // `#[track_caller]` so a failing assertion is reported at the calling test.
    #[track_caller]
    fn assert_inference_ids(cases: &[(Model, &str, &str)]) -> anyhow::Result<()> {
        for (model, region, expected) in cases {
            assert_eq!(model.cross_region_inference_id(*region)?, *expected);
        }
        Ok(())
    }

    #[test]
    fn test_us_region_inference_ids() -> anyhow::Result<()> {
        // US regions
        assert_inference_ids(&[
            (
                Model::Claude3_5SonnetV2,
                "us-east-1",
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (
                Model::Claude3_5SonnetV2,
                "us-west-2",
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (Model::AmazonNovaPro, "us-east-2", "us.amazon.nova-pro-v1:0"),
        ])
    }

    #[test]
    fn test_eu_region_inference_ids() -> anyhow::Result<()> {
        // European regions
        assert_inference_ids(&[
            (
                Model::Claude3Sonnet,
                "eu-west-1",
                "eu.anthropic.claude-3-sonnet-20240229-v1:0",
            ),
            (
                Model::AmazonNovaMicro,
                "eu-north-1",
                "eu.amazon.nova-micro-v1:0",
            ),
        ])
    }

    #[test]
    fn test_apac_region_inference_ids() -> anyhow::Result<()> {
        // Asia-Pacific regions
        assert_inference_ids(&[
            (
                Model::Claude3_5SonnetV2,
                "ap-northeast-1",
                "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (
                Model::AmazonNovaLite,
                "ap-south-1",
                "apac.amazon.nova-lite-v1:0",
            ),
        ])
    }

    #[test]
    fn test_gov_region_inference_ids() -> anyhow::Result<()> {
        // Government regions
        assert_inference_ids(&[
            (
                Model::Claude3_5Sonnet,
                "us-gov-east-1",
                "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
            ),
            (
                Model::Claude3Haiku,
                "us-gov-west-1",
                "us-gov.anthropic.claude-3-haiku-20240307-v1:0",
            ),
        ])
    }

    #[test]
    fn test_meta_models_inference_ids() -> anyhow::Result<()> {
        // Meta models
        assert_inference_ids(&[
            (
                Model::MetaLlama370BInstructV1,
                "us-east-1",
                "us.meta.llama3-70b-instruct-v1:0",
            ),
            (
                Model::MetaLlama321BInstructV1,
                "eu-west-1",
                "eu.meta.llama3-2-1b-instruct-v1:0",
            ),
        ])
    }

    #[test]
    fn test_mistral_models_inference_ids() -> anyhow::Result<()> {
        // These Mistral models get no regional prefix, so the plain id is expected.
        assert_inference_ids(&[
            (
                Model::MistralMistralLarge2402V1,
                "us-east-1",
                "mistral.mistral-large-2402-v1:0",
            ),
            (
                Model::MistralMixtral8x7BInstructV0,
                "eu-west-1",
                "mistral.mixtral-8x7b-instruct-v0:1",
            ),
        ])
    }

    #[test]
    fn test_ai21_models_inference_ids() -> anyhow::Result<()> {
        // AI21 models get no regional prefix, so the plain id is expected.
        assert_inference_ids(&[
            (Model::AI21J2UltraV1, "us-east-1", "ai21.j2-ultra-v1"),
            (
                Model::AI21JambaInstructV1,
                "eu-west-1",
                "ai21.jamba-instruct-v1:0",
            ),
        ])
    }

    #[test]
    fn test_cohere_models_inference_ids() -> anyhow::Result<()> {
        // Cohere models get no regional prefix, so the plain id is expected.
        assert_inference_ids(&[
            (Model::CohereCommandRV1, "us-east-1", "cohere.command-r-v1:0"),
            (
                Model::CohereCommandTextV14_4k,
                "ap-southeast-1",
                "cohere.command-text-v14:7:4k",
            ),
        ])
    }

    #[test]
    fn test_custom_model_inference_ids() -> anyhow::Result<()> {
        let custom_model = Model::Custom {
            name: "custom.my-model-v1:0".to_string(),
            max_tokens: 100000,
            display_name: Some("My Custom Model".to_string()),
            max_output_tokens: Some(8192),
            default_temperature: Some(0.7),
        };

        // A custom model must come back with its name unchanged.
        let resolved = custom_model.cross_region_inference_id("us-east-1")?;
        assert_eq!(resolved, "custom.my-model-v1:0");

        Ok(())
    }
}