models.rs

  1use serde::{Deserialize, Serialize};
  2use strum::EnumIter;
  3
/// The mode a model runs in, as reported by `Model::mode`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum BedrockModelMode {
    /// Regular mode; this is the value produced by `Default::default()`.
    #[default]
    Default,
    /// Thinking mode with an optional token budget.
    Thinking {
        // NOTE(review): presumably the cap on tokens spent on reasoning, with
        // `None` leaving the choice to the consumer — confirm against callers.
        budget_tokens: Option<u64>,
    },
}
 13
/// A model known to this module.
///
/// Built-in models are unit variants; [`Model::Custom`] carries user-supplied
/// settings for a model that is not listed here. `Claude3_5SonnetV2` is the
/// `Default` value. Variant order is significant because `EnumIter` yields
/// variants in declaration order.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    // Anthropic models
    #[default]
    #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")]
    Claude3_5SonnetV2,
    #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
    Claude3_7Sonnet,
    #[serde(
        rename = "claude-3-7-sonnet-thinking",
        alias = "claude-3-7-sonnet-thinking-latest"
    )]
    Claude3_7SonnetThinking,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
    Claude3_5Haiku,
    // The remaining built-in variants carry no serde rename, so they
    // (de)serialize under their Rust variant names.
    Claude3_5Sonnet,
    Claude3Haiku,
    // Amazon Nova Models
    AmazonNovaLite,
    AmazonNovaMicro,
    AmazonNovaPro,
    AmazonNovaPremier,
    // AI21 models
    AI21J2GrandeInstruct,
    AI21J2JumboInstruct,
    AI21J2Mid,
    AI21J2MidV1,
    AI21J2Ultra,
    AI21J2UltraV1_8k,
    AI21J2UltraV1,
    AI21JambaInstructV1,
    AI21Jamba15LargeV1,
    AI21Jamba15MiniV1,
    // Cohere models
    CohereCommandTextV14_4k,
    CohereCommandRV1,
    CohereCommandRPlusV1,
    CohereCommandLightTextV14_4k,
    // DeepSeek
    DeepSeekR1,
    // Meta models
    MetaLlama38BInstructV1,
    MetaLlama370BInstructV1,
    MetaLlama318BInstructV1_128k,
    MetaLlama318BInstructV1,
    MetaLlama3170BInstructV1_128k,
    MetaLlama3170BInstructV1,
    MetaLlama3211BInstructV1,
    MetaLlama3290BInstructV1,
    MetaLlama321BInstructV1,
    MetaLlama323BInstructV1,
    // Mistral models
    MistralMistral7BInstructV0,
    MistralMixtral8x7BInstructV0,
    MistralMistralLarge2402V1,
    MistralMistralSmall2402V1,
    MistralPixtralLarge2502V1,
    // Writer models
    PalmyraWriterX5,
    PalmyraWriterX4,
    // User-defined model; every limit below is supplied by the user rather
    // than hard-coded in this file.
    #[serde(rename = "custom")]
    Custom {
        // Returned verbatim by `Model::id`.
        name: String,
        // Returned by `Model::max_token_count`.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Overrides `Model::max_output_tokens`; 4,096 is used when absent.
        max_output_tokens: Option<u32>,
        // Overrides `Model::default_temperature`; 1.0 is used when absent.
        default_temperature: Option<f32>,
    },
}
 89
 90impl Model {
 91    pub fn default_fast() -> Self {
 92        Self::Claude3_5Haiku
 93    }
 94
 95    pub fn from_id(id: &str) -> anyhow::Result<Self> {
 96        if id.starts_with("claude-3-5-sonnet-v2") {
 97            Ok(Self::Claude3_5SonnetV2)
 98        } else if id.starts_with("claude-3-opus") {
 99            Ok(Self::Claude3Opus)
100        } else if id.starts_with("claude-3-sonnet") {
101            Ok(Self::Claude3Sonnet)
102        } else if id.starts_with("claude-3-5-haiku") {
103            Ok(Self::Claude3_5Haiku)
104        } else if id.starts_with("claude-3-7-sonnet") {
105            Ok(Self::Claude3_7Sonnet)
106        } else if id.starts_with("claude-3-7-sonnet-thinking") {
107            Ok(Self::Claude3_7SonnetThinking)
108        } else {
109            anyhow::bail!("invalid model id {id}");
110        }
111    }
112
113    pub fn id(&self) -> &str {
114        match self {
115            Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0",
116            Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
117            Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0",
118            Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
119            Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
120            Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0",
121            Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => {
122                "anthropic.claude-3-7-sonnet-20250219-v1:0"
123            }
124            Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
125            Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
126            Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
127            Model::AmazonNovaPremier => "amazon.nova-premier-v1:0",
128            Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
129            Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
130            Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
131            Model::AI21J2Mid => "ai21.j2-mid",
132            Model::AI21J2MidV1 => "ai21.j2-mid-v1",
133            Model::AI21J2Ultra => "ai21.j2-ultra",
134            Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k",
135            Model::AI21J2UltraV1 => "ai21.j2-ultra-v1",
136            Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0",
137            Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0",
138            Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0",
139            Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k",
140            Model::CohereCommandRV1 => "cohere.command-r-v1:0",
141            Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0",
142            Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k",
143            Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0",
144            Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0",
145            Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k",
146            Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0",
147            Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k",
148            Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0",
149            Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0",
150            Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0",
151            Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0",
152            Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0",
153            Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2",
154            Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
155            Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
156            Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
157            Model::MistralPixtralLarge2502V1 => "mistral.pixtral-large-2502-v1:0",
158            Model::PalmyraWriterX4 => "writer.palmyra-x4-v1:0",
159            Model::PalmyraWriterX5 => "writer.palmyra-x5-v1:0",
160            Self::Custom { name, .. } => name,
161        }
162    }
163
164    pub fn display_name(&self) -> &str {
165        match self {
166            Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2",
167            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
168            Self::Claude3Opus => "Claude 3 Opus",
169            Self::Claude3Sonnet => "Claude 3 Sonnet",
170            Self::Claude3Haiku => "Claude 3 Haiku",
171            Self::Claude3_5Haiku => "Claude 3.5 Haiku",
172            Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
173            Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
174            Self::AmazonNovaLite => "Amazon Nova Lite",
175            Self::AmazonNovaMicro => "Amazon Nova Micro",
176            Self::AmazonNovaPro => "Amazon Nova Pro",
177            Self::AmazonNovaPremier => "Amazon Nova Premier",
178            Self::DeepSeekR1 => "DeepSeek R1",
179            Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
180            Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
181            Self::AI21J2Mid => "AI21 Jurassic2 Mid",
182            Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1",
183            Self::AI21J2Ultra => "AI21 Jurassic2 Ultra",
184            Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K",
185            Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1",
186            Self::AI21JambaInstructV1 => "AI21 Jamba Instruct",
187            Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large",
188            Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini",
189            Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K",
190            Self::CohereCommandRV1 => "Cohere Command R V1",
191            Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1",
192            Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K",
193            Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1",
194            Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1",
195            Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K",
196            Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1",
197            Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K",
198            Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1",
199            Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1",
200            Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1",
201            Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1",
202            Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1",
203            Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0",
204            Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
205            Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
206            Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
207            Self::MistralPixtralLarge2502V1 => "Pixtral Large 25.02 V1",
208            Self::PalmyraWriterX5 => "Writer Palmyra X5",
209            Self::PalmyraWriterX4 => "Writer Palmyra X4",
210            Self::Custom {
211                display_name, name, ..
212            } => display_name.as_deref().unwrap_or(name),
213        }
214    }
215
216    pub fn max_token_count(&self) -> usize {
217        match self {
218            Self::Claude3_5SonnetV2
219            | Self::Claude3Opus
220            | Self::Claude3Sonnet
221            | Self::Claude3_5Haiku
222            | Self::Claude3_7Sonnet => 200_000,
223            Self::AmazonNovaPremier => 1_000_000,
224            Self::PalmyraWriterX5 => 1_000_000,
225            Self::PalmyraWriterX4 => 128_000,
226            Self::Custom { max_tokens, .. } => *max_tokens,
227            _ => 128_000,
228        }
229    }
230
231    pub fn max_output_tokens(&self) -> u32 {
232        match self {
233            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
234            Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
235            Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192,
236            Self::Custom {
237                max_output_tokens, ..
238            } => max_output_tokens.unwrap_or(4_096),
239            _ => 4_096,
240        }
241    }
242
243    pub fn default_temperature(&self) -> f32 {
244        match self {
245            Self::Claude3_5SonnetV2
246            | Self::Claude3Opus
247            | Self::Claude3Sonnet
248            | Self::Claude3_5Haiku
249            | Self::Claude3_7Sonnet => 1.0,
250            Self::Custom {
251                default_temperature,
252                ..
253            } => default_temperature.unwrap_or(1.0),
254            _ => 1.0,
255        }
256    }
257
258    pub fn supports_tool_use(&self) -> bool {
259        match self {
260            // Anthropic Claude 3 models (all support tool use)
261            Self::Claude3Opus
262            | Self::Claude3Sonnet
263            | Self::Claude3_5Sonnet
264            | Self::Claude3_5SonnetV2
265            | Self::Claude3_7Sonnet
266            | Self::Claude3_7SonnetThinking
267            | Self::Claude3_5Haiku => true,
268
269            // Amazon Nova models (all support tool use)
270            Self::AmazonNovaPremier
271            | Self::AmazonNovaPro
272            | Self::AmazonNovaLite
273            | Self::AmazonNovaMicro => true,
274
275            // AI21 Jamba 1.5 models support tool use
276            Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
277
278            // Cohere Command R models support tool use
279            Self::CohereCommandRV1 | Self::CohereCommandRPlusV1 => true,
280
281            // All other models don't support tool use
282            // Including Meta Llama 3.2, AI21 Jurassic, and others
283            _ => false,
284        }
285    }
286
287    pub fn mode(&self) -> BedrockModelMode {
288        match self {
289            Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
290                budget_tokens: Some(4096),
291            },
292            _ => BedrockModelMode::Default,
293        }
294    }
295
296    pub fn cross_region_inference_id(&self, region: &str) -> anyhow::Result<String> {
297        let region_group = if region.starts_with("us-gov-") {
298            "us-gov"
299        } else if region.starts_with("us-") {
300            "us"
301        } else if region.starts_with("eu-") {
302            "eu"
303        } else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" {
304            "apac"
305        } else if region.starts_with("ca-") || region.starts_with("sa-") {
306            // Canada and South America regions - default to US profiles
307            "us"
308        } else {
309            anyhow::bail!("Unsupported Region {region}");
310        };
311
312        let model_id = self.id();
313
314        match (self, region_group) {
315            // Custom models can't have CRI IDs
316            (Model::Custom { .. }, _) => Ok(self.id().into()),
317
318            // Models with US Gov only
319            (Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => {
320                Ok(format!("{}.{}", region_group, model_id))
321            }
322
323            // Models available only in US
324            (Model::Claude3Opus, "us")
325            | (Model::Claude3_5Haiku, "us")
326            | (Model::Claude3_7Sonnet, "us")
327            | (Model::Claude3_7SonnetThinking, "us")
328            | (Model::AmazonNovaPremier, "us")
329            | (Model::MistralPixtralLarge2502V1, "us") => {
330                Ok(format!("{}.{}", region_group, model_id))
331            }
332
333            // Models available in US, EU, and APAC
334            (Model::Claude3_5SonnetV2, "us")
335            | (Model::Claude3_5SonnetV2, "apac")
336            | (Model::Claude3_5Sonnet, _)
337            | (Model::Claude3Haiku, _)
338            | (Model::Claude3Sonnet, _)
339            | (Model::AmazonNovaLite, _)
340            | (Model::AmazonNovaMicro, _)
341            | (Model::AmazonNovaPro, _) => Ok(format!("{}.{}", region_group, model_id)),
342
343            // Models with limited EU availability
344            (Model::MetaLlama321BInstructV1, "us")
345            | (Model::MetaLlama321BInstructV1, "eu")
346            | (Model::MetaLlama323BInstructV1, "us")
347            | (Model::MetaLlama323BInstructV1, "eu") => {
348                Ok(format!("{}.{}", region_group, model_id))
349            }
350
351            // US-only models (all remaining Meta models)
352            (Model::MetaLlama38BInstructV1, "us")
353            | (Model::MetaLlama370BInstructV1, "us")
354            | (Model::MetaLlama318BInstructV1, "us")
355            | (Model::MetaLlama318BInstructV1_128k, "us")
356            | (Model::MetaLlama3170BInstructV1, "us")
357            | (Model::MetaLlama3170BInstructV1_128k, "us")
358            | (Model::MetaLlama3211BInstructV1, "us")
359            | (Model::MetaLlama3290BInstructV1, "us") => {
360                Ok(format!("{}.{}", region_group, model_id))
361            }
362
363            // Writer models only available in the US
364            (Model::PalmyraWriterX4, "us") | (Model::PalmyraWriterX5, "us") => {
365                // They have some goofiness
366                Ok(format!("{}.{}", region_group, model_id))
367            }
368
369            // Any other combination is not supported
370            _ => Ok(self.id().into()),
371        }
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves `model` in `region` and asserts the resulting identifier.
    fn expect_id(model: &Model, region: &str, expected: &str) -> anyhow::Result<()> {
        assert_eq!(model.cross_region_inference_id(region)?, expected);
        Ok(())
    }

    #[test]
    fn test_us_region_inference_ids() -> anyhow::Result<()> {
        // US regions get the "us." prefix
        expect_id(
            &Model::Claude3_5SonnetV2,
            "us-east-1",
            "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        )?;
        expect_id(
            &Model::Claude3_5SonnetV2,
            "us-west-2",
            "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        )?;
        expect_id(&Model::AmazonNovaPro, "us-east-2", "us.amazon.nova-pro-v1:0")
    }

    #[test]
    fn test_eu_region_inference_ids() -> anyhow::Result<()> {
        // European regions get the "eu." prefix
        expect_id(
            &Model::Claude3Sonnet,
            "eu-west-1",
            "eu.anthropic.claude-3-sonnet-20240229-v1:0",
        )?;
        expect_id(
            &Model::AmazonNovaMicro,
            "eu-north-1",
            "eu.amazon.nova-micro-v1:0",
        )
    }

    #[test]
    fn test_apac_region_inference_ids() -> anyhow::Result<()> {
        // Asia-Pacific regions get the "apac." prefix
        expect_id(
            &Model::Claude3_5SonnetV2,
            "ap-northeast-1",
            "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
        )?;
        expect_id(
            &Model::AmazonNovaLite,
            "ap-south-1",
            "apac.amazon.nova-lite-v1:0",
        )
    }

    #[test]
    fn test_gov_region_inference_ids() -> anyhow::Result<()> {
        // Government regions get the "us-gov." prefix
        expect_id(
            &Model::Claude3_5Sonnet,
            "us-gov-east-1",
            "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
        )?;
        expect_id(
            &Model::Claude3Haiku,
            "us-gov-west-1",
            "us-gov.anthropic.claude-3-haiku-20240307-v1:0",
        )
    }

    #[test]
    fn test_meta_models_inference_ids() -> anyhow::Result<()> {
        // Meta models are prefixed in the regions that list them
        expect_id(
            &Model::MetaLlama370BInstructV1,
            "us-east-1",
            "us.meta.llama3-70b-instruct-v1:0",
        )?;
        expect_id(
            &Model::MetaLlama321BInstructV1,
            "eu-west-1",
            "eu.meta.llama3-2-1b-instruct-v1:0",
        )
    }

    #[test]
    fn test_mistral_models_inference_ids() -> anyhow::Result<()> {
        // Mistral models don't follow the regional prefix pattern,
        // so they should return their original IDs
        expect_id(
            &Model::MistralMistralLarge2402V1,
            "us-east-1",
            "mistral.mistral-large-2402-v1:0",
        )?;
        expect_id(
            &Model::MistralMixtral8x7BInstructV0,
            "eu-west-1",
            "mistral.mixtral-8x7b-instruct-v0:1",
        )
    }

    #[test]
    fn test_ai21_models_inference_ids() -> anyhow::Result<()> {
        // AI21 models don't follow the regional prefix pattern,
        // so they should return their original IDs
        expect_id(&Model::AI21J2UltraV1, "us-east-1", "ai21.j2-ultra-v1")?;
        expect_id(
            &Model::AI21JambaInstructV1,
            "eu-west-1",
            "ai21.jamba-instruct-v1:0",
        )
    }

    #[test]
    fn test_cohere_models_inference_ids() -> anyhow::Result<()> {
        // Cohere models don't follow the regional prefix pattern,
        // so they should return their original IDs
        expect_id(
            &Model::CohereCommandRV1,
            "us-east-1",
            "cohere.command-r-v1:0",
        )?;
        expect_id(
            &Model::CohereCommandTextV14_4k,
            "ap-southeast-1",
            "cohere.command-text-v14:7:4k",
        )
    }

    #[test]
    fn test_custom_model_inference_ids() -> anyhow::Result<()> {
        let custom_model = Model::Custom {
            name: "custom.my-model-v1:0".to_string(),
            max_tokens: 100000,
            display_name: Some("My Custom Model".to_string()),
            max_output_tokens: Some(8192),
            default_temperature: Some(0.7),
        };

        // A custom model keeps its name unchanged
        expect_id(&custom_model, "us-east-1", "custom.my-model-v1:0")
    }
}