1use serde::{Deserialize, Serialize};
2use strum::EnumIter;
3
/// Execution mode requested for a Bedrock model.
///
/// `Model::mode` yields `Thinking` (with a 4096-token budget) for
/// `Model::Claude3_7SonnetThinking` and `Default` for every other model.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum BedrockModelMode {
    /// Ordinary generation; this is the variant produced by `Default::default()`.
    #[default]
    Default,
    /// Generation with a "thinking" phase.
    // NOTE(review): presumably maps to Anthropic extended thinking — confirm at the call site.
    Thinking {
        /// Optional token budget for the thinking phase; `None` leaves the
        /// choice to whoever consumes this value.
        budget_tokens: Option<u64>,
    },
}
13
/// Catalogue of the models this crate knows how to address on Amazon Bedrock,
/// plus a user-defined `Custom` escape hatch.
///
/// Variants carrying `#[serde(rename = ..., alias = ...)]` are (de)serialized
/// under those short names; variants without a rename use serde's default,
/// i.e. the Rust variant identifier. `EnumIter` walks the variants in
/// declaration order.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    // Anthropic models
    /// The model produced by `Model::default()`.
    #[default]
    #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")]
    Claude3_5SonnetV2,
    #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
    Claude3_7Sonnet,
    /// Shares its Bedrock model ID with `Claude3_7Sonnet`; differs only in
    /// that `Model::mode` reports `BedrockModelMode::Thinking` for it.
    #[serde(
        rename = "claude-3-7-sonnet-thinking",
        alias = "claude-3-7-sonnet-thinking-latest"
    )]
    Claude3_7SonnetThinking,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
    Claude3Sonnet,
    /// The model returned by `Model::default_fast()`.
    #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
    Claude3_5Haiku,
    Claude3_5Sonnet,
    Claude3Haiku,
    // Amazon Nova Models
    AmazonNovaLite,
    AmazonNovaMicro,
    AmazonNovaPro,
    AmazonNovaPremier,
    // AI21 models
    AI21J2GrandeInstruct,
    AI21J2JumboInstruct,
    AI21J2Mid,
    AI21J2MidV1,
    AI21J2Ultra,
    AI21J2UltraV1_8k,
    AI21J2UltraV1,
    AI21JambaInstructV1,
    AI21Jamba15LargeV1,
    AI21Jamba15MiniV1,
    // Cohere models
    CohereCommandTextV14_4k,
    CohereCommandRV1,
    CohereCommandRPlusV1,
    CohereCommandLightTextV14_4k,
    // DeepSeek
    DeepSeekR1,
    // Meta models
    MetaLlama38BInstructV1,
    MetaLlama370BInstructV1,
    MetaLlama318BInstructV1_128k,
    MetaLlama318BInstructV1,
    MetaLlama3170BInstructV1_128k,
    MetaLlama3170BInstructV1,
    MetaLlama3211BInstructV1,
    MetaLlama3290BInstructV1,
    MetaLlama321BInstructV1,
    MetaLlama323BInstructV1,
    // Mistral models
    MistralMistral7BInstructV0,
    MistralMixtral8x7BInstructV0,
    MistralMistralLarge2402V1,
    MistralMistralSmall2402V1,
    MistralPixtralLarge2502V1,
    // Writer models
    PalmyraWriterX5,
    PalmyraWriterX4,
    /// A user-supplied model that is not part of the built-in catalogue.
    #[serde(rename = "custom")]
    Custom {
        /// Identifier returned verbatim by `Model::id()` (and, unprefixed, by
        /// `Model::cross_region_inference_id`).
        name: String,
        /// Value reported by `Model::max_token_count()`.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        /// When absent, `name` is shown instead.
        display_name: Option<String>,
        /// Output-token cap; `Model::max_output_tokens()` falls back to 4096 when absent.
        max_output_tokens: Option<u32>,
        /// Sampling temperature; `Model::default_temperature()` falls back to 1.0 when absent.
        default_temperature: Option<f32>,
    },
}
89
90impl Model {
91 pub fn default_fast() -> Self {
92 Self::Claude3_5Haiku
93 }
94
95 pub fn from_id(id: &str) -> anyhow::Result<Self> {
96 if id.starts_with("claude-3-5-sonnet-v2") {
97 Ok(Self::Claude3_5SonnetV2)
98 } else if id.starts_with("claude-3-opus") {
99 Ok(Self::Claude3Opus)
100 } else if id.starts_with("claude-3-sonnet") {
101 Ok(Self::Claude3Sonnet)
102 } else if id.starts_with("claude-3-5-haiku") {
103 Ok(Self::Claude3_5Haiku)
104 } else if id.starts_with("claude-3-7-sonnet") {
105 Ok(Self::Claude3_7Sonnet)
106 } else if id.starts_with("claude-3-7-sonnet-thinking") {
107 Ok(Self::Claude3_7SonnetThinking)
108 } else {
109 anyhow::bail!("invalid model id {id}");
110 }
111 }
112
113 pub fn id(&self) -> &str {
114 match self {
115 Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0",
116 Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
117 Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0",
118 Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
119 Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
120 Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0",
121 Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => {
122 "anthropic.claude-3-7-sonnet-20250219-v1:0"
123 }
124 Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
125 Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
126 Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
127 Model::AmazonNovaPremier => "amazon.nova-premier-v1:0",
128 Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
129 Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
130 Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
131 Model::AI21J2Mid => "ai21.j2-mid",
132 Model::AI21J2MidV1 => "ai21.j2-mid-v1",
133 Model::AI21J2Ultra => "ai21.j2-ultra",
134 Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k",
135 Model::AI21J2UltraV1 => "ai21.j2-ultra-v1",
136 Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0",
137 Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0",
138 Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0",
139 Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k",
140 Model::CohereCommandRV1 => "cohere.command-r-v1:0",
141 Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0",
142 Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k",
143 Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0",
144 Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0",
145 Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k",
146 Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0",
147 Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k",
148 Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0",
149 Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0",
150 Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0",
151 Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0",
152 Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0",
153 Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2",
154 Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
155 Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
156 Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
157 Model::MistralPixtralLarge2502V1 => "mistral.pixtral-large-2502-v1:0",
158 Model::PalmyraWriterX4 => "writer.palmyra-x4-v1:0",
159 Model::PalmyraWriterX5 => "writer.palmyra-x5-v1:0",
160 Self::Custom { name, .. } => name,
161 }
162 }
163
164 pub fn display_name(&self) -> &str {
165 match self {
166 Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2",
167 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
168 Self::Claude3Opus => "Claude 3 Opus",
169 Self::Claude3Sonnet => "Claude 3 Sonnet",
170 Self::Claude3Haiku => "Claude 3 Haiku",
171 Self::Claude3_5Haiku => "Claude 3.5 Haiku",
172 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
173 Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
174 Self::AmazonNovaLite => "Amazon Nova Lite",
175 Self::AmazonNovaMicro => "Amazon Nova Micro",
176 Self::AmazonNovaPro => "Amazon Nova Pro",
177 Self::AmazonNovaPremier => "Amazon Nova Premier",
178 Self::DeepSeekR1 => "DeepSeek R1",
179 Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
180 Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
181 Self::AI21J2Mid => "AI21 Jurassic2 Mid",
182 Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1",
183 Self::AI21J2Ultra => "AI21 Jurassic2 Ultra",
184 Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K",
185 Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1",
186 Self::AI21JambaInstructV1 => "AI21 Jamba Instruct",
187 Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large",
188 Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini",
189 Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K",
190 Self::CohereCommandRV1 => "Cohere Command R V1",
191 Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1",
192 Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K",
193 Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1",
194 Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1",
195 Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K",
196 Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1",
197 Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K",
198 Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1",
199 Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1",
200 Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1",
201 Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1",
202 Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1",
203 Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0",
204 Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
205 Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
206 Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
207 Self::MistralPixtralLarge2502V1 => "Pixtral Large 25.02 V1",
208 Self::PalmyraWriterX5 => "Writer Palmyra X5",
209 Self::PalmyraWriterX4 => "Writer Palmyra X4",
210 Self::Custom {
211 display_name, name, ..
212 } => display_name.as_deref().unwrap_or(name),
213 }
214 }
215
216 pub fn max_token_count(&self) -> usize {
217 match self {
218 Self::Claude3_5SonnetV2
219 | Self::Claude3Opus
220 | Self::Claude3Sonnet
221 | Self::Claude3_5Haiku
222 | Self::Claude3_7Sonnet => 200_000,
223 Self::AmazonNovaPremier => 1_000_000,
224 Self::PalmyraWriterX5 => 1_000_000,
225 Self::PalmyraWriterX4 => 128_000,
226 Self::Custom { max_tokens, .. } => *max_tokens,
227 _ => 128_000,
228 }
229 }
230
231 pub fn max_output_tokens(&self) -> u32 {
232 match self {
233 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
234 Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
235 Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192,
236 Self::Custom {
237 max_output_tokens, ..
238 } => max_output_tokens.unwrap_or(4_096),
239 _ => 4_096,
240 }
241 }
242
243 pub fn default_temperature(&self) -> f32 {
244 match self {
245 Self::Claude3_5SonnetV2
246 | Self::Claude3Opus
247 | Self::Claude3Sonnet
248 | Self::Claude3_5Haiku
249 | Self::Claude3_7Sonnet => 1.0,
250 Self::Custom {
251 default_temperature,
252 ..
253 } => default_temperature.unwrap_or(1.0),
254 _ => 1.0,
255 }
256 }
257
258 pub fn supports_tool_use(&self) -> bool {
259 match self {
260 // Anthropic Claude 3 models (all support tool use)
261 Self::Claude3Opus
262 | Self::Claude3Sonnet
263 | Self::Claude3_5Sonnet
264 | Self::Claude3_5SonnetV2
265 | Self::Claude3_7Sonnet
266 | Self::Claude3_7SonnetThinking
267 | Self::Claude3_5Haiku => true,
268
269 // Amazon Nova models (all support tool use)
270 Self::AmazonNovaPremier
271 | Self::AmazonNovaPro
272 | Self::AmazonNovaLite
273 | Self::AmazonNovaMicro => true,
274
275 // AI21 Jamba 1.5 models support tool use
276 Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
277
278 // Cohere Command R models support tool use
279 Self::CohereCommandRV1 | Self::CohereCommandRPlusV1 => true,
280
281 // All other models don't support tool use
282 // Including Meta Llama 3.2, AI21 Jurassic, and others
283 _ => false,
284 }
285 }
286
287 pub fn mode(&self) -> BedrockModelMode {
288 match self {
289 Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
290 budget_tokens: Some(4096),
291 },
292 _ => BedrockModelMode::Default,
293 }
294 }
295
296 pub fn cross_region_inference_id(&self, region: &str) -> anyhow::Result<String> {
297 let region_group = if region.starts_with("us-gov-") {
298 "us-gov"
299 } else if region.starts_with("us-") {
300 "us"
301 } else if region.starts_with("eu-") {
302 "eu"
303 } else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" {
304 "apac"
305 } else if region.starts_with("ca-") || region.starts_with("sa-") {
306 // Canada and South America regions - default to US profiles
307 "us"
308 } else {
309 anyhow::bail!("Unsupported Region {region}");
310 };
311
312 let model_id = self.id();
313
314 match (self, region_group) {
315 // Custom models can't have CRI IDs
316 (Model::Custom { .. }, _) => Ok(self.id().into()),
317
318 // Models with US Gov only
319 (Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => {
320 Ok(format!("{}.{}", region_group, model_id))
321 }
322
323 // Models available only in US
324 (Model::Claude3Opus, "us")
325 | (Model::Claude3_5Haiku, "us")
326 | (Model::Claude3_7Sonnet, "us")
327 | (Model::Claude3_7SonnetThinking, "us")
328 | (Model::AmazonNovaPremier, "us")
329 | (Model::MistralPixtralLarge2502V1, "us") => {
330 Ok(format!("{}.{}", region_group, model_id))
331 }
332
333 // Models available in US, EU, and APAC
334 (Model::Claude3_5SonnetV2, "us")
335 | (Model::Claude3_5SonnetV2, "apac")
336 | (Model::Claude3_5Sonnet, _)
337 | (Model::Claude3Haiku, _)
338 | (Model::Claude3Sonnet, _)
339 | (Model::AmazonNovaLite, _)
340 | (Model::AmazonNovaMicro, _)
341 | (Model::AmazonNovaPro, _) => Ok(format!("{}.{}", region_group, model_id)),
342
343 // Models with limited EU availability
344 (Model::MetaLlama321BInstructV1, "us")
345 | (Model::MetaLlama321BInstructV1, "eu")
346 | (Model::MetaLlama323BInstructV1, "us")
347 | (Model::MetaLlama323BInstructV1, "eu") => {
348 Ok(format!("{}.{}", region_group, model_id))
349 }
350
351 // US-only models (all remaining Meta models)
352 (Model::MetaLlama38BInstructV1, "us")
353 | (Model::MetaLlama370BInstructV1, "us")
354 | (Model::MetaLlama318BInstructV1, "us")
355 | (Model::MetaLlama318BInstructV1_128k, "us")
356 | (Model::MetaLlama3170BInstructV1, "us")
357 | (Model::MetaLlama3170BInstructV1_128k, "us")
358 | (Model::MetaLlama3211BInstructV1, "us")
359 | (Model::MetaLlama3290BInstructV1, "us") => {
360 Ok(format!("{}.{}", region_group, model_id))
361 }
362
363 // Writer models only available in the US
364 (Model::PalmyraWriterX4, "us") | (Model::PalmyraWriterX5, "us") => {
365 // They have some goofiness
366 Ok(format!("{}.{}", region_group, model_id))
367 }
368
369 // Any other combination is not supported
370 _ => Ok(self.id().into()),
371 }
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks each `(model, region, expected id)` case against
    /// `Model::cross_region_inference_id`, propagating any lookup error.
    fn assert_inference_ids(cases: &[(Model, &str, &str)]) -> anyhow::Result<()> {
        for (model, region, expected) in cases {
            assert_eq!(model.cross_region_inference_id(region)?, *expected);
        }
        Ok(())
    }

    /// US regions resolve to `us.`-prefixed IDs.
    #[test]
    fn test_us_region_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::Claude3_5SonnetV2,
                "us-east-1",
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (
                Model::Claude3_5SonnetV2,
                "us-west-2",
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (Model::AmazonNovaPro, "us-east-2", "us.amazon.nova-pro-v1:0"),
        ])
    }

    /// European regions resolve to `eu.`-prefixed IDs.
    #[test]
    fn test_eu_region_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::Claude3Sonnet,
                "eu-west-1",
                "eu.anthropic.claude-3-sonnet-20240229-v1:0",
            ),
            (
                Model::AmazonNovaMicro,
                "eu-north-1",
                "eu.amazon.nova-micro-v1:0",
            ),
        ])
    }

    /// Asia-Pacific regions resolve to `apac.`-prefixed IDs.
    #[test]
    fn test_apac_region_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::Claude3_5SonnetV2,
                "ap-northeast-1",
                "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (
                Model::AmazonNovaLite,
                "ap-south-1",
                "apac.amazon.nova-lite-v1:0",
            ),
        ])
    }

    /// Government regions resolve to `us-gov.`-prefixed IDs.
    #[test]
    fn test_gov_region_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::Claude3_5Sonnet,
                "us-gov-east-1",
                "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
            ),
            (
                Model::Claude3Haiku,
                "us-gov-west-1",
                "us-gov.anthropic.claude-3-haiku-20240307-v1:0",
            ),
        ])
    }

    /// Meta models receive a regional prefix where an entry exists.
    #[test]
    fn test_meta_models_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::MetaLlama370BInstructV1,
                "us-east-1",
                "us.meta.llama3-70b-instruct-v1:0",
            ),
            (
                Model::MetaLlama321BInstructV1,
                "eu-west-1",
                "eu.meta.llama3-2-1b-instruct-v1:0",
            ),
        ])
    }

    /// These Mistral models have no regional entry, so the plain ID comes back.
    #[test]
    fn test_mistral_models_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (
                Model::MistralMistralLarge2402V1,
                "us-east-1",
                "mistral.mistral-large-2402-v1:0",
            ),
            (
                Model::MistralMixtral8x7BInstructV0,
                "eu-west-1",
                "mistral.mixtral-8x7b-instruct-v0:1",
            ),
        ])
    }

    /// AI21 models have no regional entry, so the plain ID comes back.
    #[test]
    fn test_ai21_models_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (Model::AI21J2UltraV1, "us-east-1", "ai21.j2-ultra-v1"),
            (
                Model::AI21JambaInstructV1,
                "eu-west-1",
                "ai21.jamba-instruct-v1:0",
            ),
        ])
    }

    /// Cohere models have no regional entry, so the plain ID comes back.
    #[test]
    fn test_cohere_models_inference_ids() -> anyhow::Result<()> {
        assert_inference_ids(&[
            (Model::CohereCommandRV1, "us-east-1", "cohere.command-r-v1:0"),
            (
                Model::CohereCommandTextV14_4k,
                "ap-southeast-1",
                "cohere.command-text-v14:7:4k",
            ),
        ])
    }

    /// A custom model's name is passed through without any regional prefix.
    #[test]
    fn test_custom_model_inference_ids() -> anyhow::Result<()> {
        let user_defined = Model::Custom {
            name: "custom.my-model-v1:0".to_string(),
            max_tokens: 100000,
            display_name: Some("My Custom Model".to_string()),
            max_output_tokens: Some(8192),
            default_temperature: Some(0.7),
        };

        let resolved = user_defined.cross_region_inference_id("us-east-1")?;
        assert_eq!(resolved, "custom.my-model-v1:0");

        Ok(())
    }
}