models.rs

  1use anyhow::anyhow;
  2use serde::{Deserialize, Serialize};
  3use strum::EnumIter;
  4
// Inference mode a Bedrock model runs in; chosen per model by `Model::mode`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum BedrockModelMode {
    // Regular request mode; this is the derived `Default`.
    #[default]
    Default,
    // Extended-thinking mode. `budget_tokens` is an optional token budget for
    // the thinking phase (`Model::mode` sets `Some(4096)` for the Claude 3.7
    // Sonnet thinking variant). NOTE(review): how the budget is applied to the
    // request is decided by the caller — confirm there.
    Thinking {
        budget_tokens: Option<u64>,
    },
}
 14
// Models addressable through Amazon Bedrock (see `Model::id` for the Bedrock
// identifiers), plus a user-configured `Custom` entry.
//
// Variants with `#[serde(rename = ...)]` (de)serialize under that name and also
// accept the listed `alias` when deserializing; variants without a rename use
// serde's default, i.e. the variant name itself.
// NOTE(review): variant order is observable through the derived `EnumIter`
// (presumably used to list models) — reorder with care.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    // Anthropic models
    // `Claude3_5SonnetV2` is the derived `Default`.
    #[default]
    #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")]
    Claude3_5SonnetV2,
    #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
    Claude3_7Sonnet,
    // Shares the Bedrock id of `Claude3_7Sonnet`; it differs in that
    // `Model::mode` returns `BedrockModelMode::Thinking` for it.
    #[serde(
        rename = "claude-3-7-sonnet-thinking",
        alias = "claude-3-7-sonnet-thinking-latest"
    )]
    Claude3_7SonnetThinking,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
    Claude3_5Haiku,
    // NOTE(review): unlike the Claude variants above, these two carry no serde
    // rename, so they (de)serialize as "Claude3_5Sonnet" / "Claude3Haiku" and
    // `Model::from_id` does not recognise them — confirm this is intended.
    Claude3_5Sonnet,
    Claude3Haiku,
    // Amazon Nova Models
    AmazonNovaLite,
    AmazonNovaMicro,
    AmazonNovaPro,
    // AI21 models
    AI21J2GrandeInstruct,
    AI21J2JumboInstruct,
    AI21J2Mid,
    AI21J2MidV1,
    AI21J2Ultra,
    AI21J2UltraV1_8k,
    AI21J2UltraV1,
    AI21JambaInstructV1,
    AI21Jamba15LargeV1,
    AI21Jamba15MiniV1,
    // Cohere models
    CohereCommandTextV14_4k,
    CohereCommandRV1,
    CohereCommandRPlusV1,
    CohereCommandLightTextV14_4k,
    // DeepSeek
    DeepSeekR1,
    // Meta models
    MetaLlama38BInstructV1,
    MetaLlama370BInstructV1,
    MetaLlama318BInstructV1_128k,
    MetaLlama318BInstructV1,
    MetaLlama3170BInstructV1_128k,
    MetaLlama3170BInstructV1,
    MetaLlama3211BInstructV1,
    MetaLlama3290BInstructV1,
    MetaLlama321BInstructV1,
    MetaLlama323BInstructV1,
    // Mistral models
    MistralMistral7BInstructV0,
    MistralMixtral8x7BInstructV0,
    MistralMistralLarge2402V1,
    MistralMistralSmall2402V1,
    // A user-configured model; every limit comes from these fields.
    #[serde(rename = "custom")]
    Custom {
        // Bedrock model id: `Model::id` returns it verbatim and
        // `Model::cross_region_inference_id` never adds a region prefix to it.
        name: String,
        // Value returned by `Model::max_token_count`.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // `Model::max_output_tokens` falls back to 4_096 when this is `None`.
        max_output_tokens: Option<u32>,
        // `Model::default_temperature` falls back to 1.0 when this is `None`.
        default_temperature: Option<f32>,
    },
}
 85
 86impl Model {
 87    pub fn from_id(id: &str) -> anyhow::Result<Self> {
 88        if id.starts_with("claude-3-5-sonnet-v2") {
 89            Ok(Self::Claude3_5SonnetV2)
 90        } else if id.starts_with("claude-3-opus") {
 91            Ok(Self::Claude3Opus)
 92        } else if id.starts_with("claude-3-sonnet") {
 93            Ok(Self::Claude3Sonnet)
 94        } else if id.starts_with("claude-3-5-haiku") {
 95            Ok(Self::Claude3_5Haiku)
 96        } else if id.starts_with("claude-3-7-sonnet") {
 97            Ok(Self::Claude3_7Sonnet)
 98        } else if id.starts_with("claude-3-7-sonnet-thinking") {
 99            Ok(Self::Claude3_7SonnetThinking)
100        } else {
101            Err(anyhow!("invalid model id"))
102        }
103    }
104
105    pub fn id(&self) -> &str {
106        match self {
107            Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0",
108            Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
109            Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0",
110            Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
111            Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
112            Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0",
113            Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => {
114                "anthropic.claude-3-7-sonnet-20250219-v1:0"
115            }
116            Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
117            Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
118            Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
119            Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
120            Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
121            Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
122            Model::AI21J2Mid => "ai21.j2-mid",
123            Model::AI21J2MidV1 => "ai21.j2-mid-v1",
124            Model::AI21J2Ultra => "ai21.j2-ultra",
125            Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k",
126            Model::AI21J2UltraV1 => "ai21.j2-ultra-v1",
127            Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0",
128            Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0",
129            Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0",
130            Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k",
131            Model::CohereCommandRV1 => "cohere.command-r-v1:0",
132            Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0",
133            Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k",
134            Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0",
135            Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0",
136            Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k",
137            Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0",
138            Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k",
139            Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0",
140            Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0",
141            Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0",
142            Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0",
143            Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0",
144            Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2",
145            Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
146            Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
147            Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
148            Self::Custom { name, .. } => name,
149        }
150    }
151
152    pub fn display_name(&self) -> &str {
153        match self {
154            Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2",
155            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
156            Self::Claude3Opus => "Claude 3 Opus",
157            Self::Claude3Sonnet => "Claude 3 Sonnet",
158            Self::Claude3Haiku => "Claude 3 Haiku",
159            Self::Claude3_5Haiku => "Claude 3.5 Haiku",
160            Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
161            Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
162            Self::AmazonNovaLite => "Amazon Nova Lite",
163            Self::AmazonNovaMicro => "Amazon Nova Micro",
164            Self::AmazonNovaPro => "Amazon Nova Pro",
165            Self::DeepSeekR1 => "DeepSeek R1",
166            Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
167            Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
168            Self::AI21J2Mid => "AI21 Jurassic2 Mid",
169            Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1",
170            Self::AI21J2Ultra => "AI21 Jurassic2 Ultra",
171            Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K",
172            Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1",
173            Self::AI21JambaInstructV1 => "AI21 Jamba Instruct",
174            Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large",
175            Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini",
176            Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K",
177            Self::CohereCommandRV1 => "Cohere Command R V1",
178            Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1",
179            Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K",
180            Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1",
181            Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1",
182            Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K",
183            Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1",
184            Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K",
185            Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1",
186            Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1",
187            Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1",
188            Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1",
189            Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1",
190            Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0",
191            Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
192            Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
193            Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
194            Self::Custom {
195                display_name, name, ..
196            } => display_name.as_deref().unwrap_or(name),
197        }
198    }
199
200    pub fn max_token_count(&self) -> usize {
201        match self {
202            Self::Claude3_5SonnetV2
203            | Self::Claude3Opus
204            | Self::Claude3Sonnet
205            | Self::Claude3_5Haiku
206            | Self::Claude3_7Sonnet => 200_000,
207            Self::Custom { max_tokens, .. } => *max_tokens,
208            _ => 200_000,
209        }
210    }
211
212    pub fn max_output_tokens(&self) -> u32 {
213        match self {
214            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
215            Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
216            Self::Claude3_5SonnetV2 => 8_192,
217            Self::Custom {
218                max_output_tokens, ..
219            } => max_output_tokens.unwrap_or(4_096),
220            _ => 4_096,
221        }
222    }
223
224    pub fn default_temperature(&self) -> f32 {
225        match self {
226            Self::Claude3_5SonnetV2
227            | Self::Claude3Opus
228            | Self::Claude3Sonnet
229            | Self::Claude3_5Haiku
230            | Self::Claude3_7Sonnet => 1.0,
231            Self::Custom {
232                default_temperature,
233                ..
234            } => default_temperature.unwrap_or(1.0),
235            _ => 1.0,
236        }
237    }
238
239    pub fn supports_tool_use(&self) -> bool {
240        match self {
241            // Anthropic Claude 3 models (all support tool use)
242            Self::Claude3Opus
243            | Self::Claude3Sonnet
244            | Self::Claude3_5Sonnet
245            | Self::Claude3_5SonnetV2
246            | Self::Claude3_7Sonnet
247            | Self::Claude3_7SonnetThinking
248            | Self::Claude3_5Haiku => true,
249
250            // Amazon Nova models (all support tool use)
251            Self::AmazonNovaPro | Self::AmazonNovaLite | Self::AmazonNovaMicro => true,
252
253            // AI21 Jamba 1.5 models support tool use
254            Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
255
256            // Cohere Command R models support tool use
257            Self::CohereCommandRV1 | Self::CohereCommandRPlusV1 => true,
258
259            // All other models don't support tool use
260            // Including Meta Llama 3.2, AI21 Jurassic, and others
261            _ => false,
262        }
263    }
264
265    pub fn mode(&self) -> BedrockModelMode {
266        match self {
267            Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
268                budget_tokens: Some(4096),
269            },
270            _ => BedrockModelMode::Default,
271        }
272    }
273
274    pub fn cross_region_inference_id(&self, region: &str) -> Result<String, anyhow::Error> {
275        let region_group = if region.starts_with("us-gov-") {
276            "us-gov"
277        } else if region.starts_with("us-") {
278            "us"
279        } else if region.starts_with("eu-") {
280            "eu"
281        } else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" {
282            "apac"
283        } else if region.starts_with("ca-") || region.starts_with("sa-") {
284            // Canada and South America regions - default to US profiles
285            "us"
286        } else {
287            // Unknown region
288            return Err(anyhow!("Unsupported Region"));
289        };
290
291        let model_id = self.id();
292
293        match (self, region_group) {
294            // Custom models can't have CRI IDs
295            (Model::Custom { .. }, _) => Ok(self.id().into()),
296
297            // Models with US Gov only
298            (Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => {
299                Ok(format!("{}.{}", region_group, model_id))
300            }
301
302            // Models available only in US
303            (Model::Claude3Opus, "us")
304            | (Model::Claude3_7Sonnet, "us")
305            | (Model::Claude3_7SonnetThinking, "us") => {
306                Ok(format!("{}.{}", region_group, model_id))
307            }
308
309            // Models available in US, EU, and APAC
310            (Model::Claude3_5SonnetV2, "us")
311            | (Model::Claude3_5SonnetV2, "apac")
312            | (Model::Claude3_5Sonnet, _)
313            | (Model::Claude3Haiku, _)
314            | (Model::Claude3Sonnet, _)
315            | (Model::AmazonNovaLite, _)
316            | (Model::AmazonNovaMicro, _)
317            | (Model::AmazonNovaPro, _) => Ok(format!("{}.{}", region_group, model_id)),
318
319            // Models with limited EU availability
320            (Model::MetaLlama321BInstructV1, "us")
321            | (Model::MetaLlama321BInstructV1, "eu")
322            | (Model::MetaLlama323BInstructV1, "us")
323            | (Model::MetaLlama323BInstructV1, "eu") => {
324                Ok(format!("{}.{}", region_group, model_id))
325            }
326
327            // US-only models (all remaining Meta models)
328            (Model::MetaLlama38BInstructV1, "us")
329            | (Model::MetaLlama370BInstructV1, "us")
330            | (Model::MetaLlama318BInstructV1, "us")
331            | (Model::MetaLlama318BInstructV1_128k, "us")
332            | (Model::MetaLlama3170BInstructV1, "us")
333            | (Model::MetaLlama3170BInstructV1_128k, "us")
334            | (Model::MetaLlama3211BInstructV1, "us")
335            | (Model::MetaLlama3290BInstructV1, "us") => {
336                Ok(format!("{}.{}", region_group, model_id))
337            }
338
339            // Any other combination is not supported
340            _ => Ok(self.id().into()),
341        }
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_us_region_inference_ids() -> anyhow::Result<()> {
        // Test US regions
        assert_eq!(
            Model::Claude3_5SonnetV2.cross_region_inference_id("us-east-1")?,
            "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
        );
        assert_eq!(
            Model::Claude3_5SonnetV2.cross_region_inference_id("us-west-2")?,
            "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
        );
        assert_eq!(
            Model::AmazonNovaPro.cross_region_inference_id("us-east-2")?,
            "us.amazon.nova-pro-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_eu_region_inference_ids() -> anyhow::Result<()> {
        // Test European regions
        assert_eq!(
            Model::Claude3Sonnet.cross_region_inference_id("eu-west-1")?,
            "eu.anthropic.claude-3-sonnet-20240229-v1:0"
        );
        assert_eq!(
            Model::AmazonNovaMicro.cross_region_inference_id("eu-north-1")?,
            "eu.amazon.nova-micro-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_apac_region_inference_ids() -> anyhow::Result<()> {
        // Test Asia-Pacific regions
        assert_eq!(
            Model::Claude3_5SonnetV2.cross_region_inference_id("ap-northeast-1")?,
            "apac.anthropic.claude-3-5-sonnet-20241022-v2:0"
        );
        assert_eq!(
            Model::AmazonNovaLite.cross_region_inference_id("ap-south-1")?,
            "apac.amazon.nova-lite-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_gov_region_inference_ids() -> anyhow::Result<()> {
        // Test Government regions
        assert_eq!(
            Model::Claude3_5Sonnet.cross_region_inference_id("us-gov-east-1")?,
            "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0"
        );
        assert_eq!(
            Model::Claude3Haiku.cross_region_inference_id("us-gov-west-1")?,
            "us-gov.anthropic.claude-3-haiku-20240307-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_canada_and_south_america_use_us_profiles() -> anyhow::Result<()> {
        // ca-* and sa-* regions are mapped onto the "us" region group.
        assert_eq!(
            Model::Claude3_5SonnetV2.cross_region_inference_id("ca-central-1")?,
            "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
        );
        assert_eq!(
            Model::AmazonNovaPro.cross_region_inference_id("sa-east-1")?,
            "us.amazon.nova-pro-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_middle_east_regions_use_apac_profiles() -> anyhow::Result<()> {
        // me-central-1 and me-south-1 are mapped onto the "apac" region group.
        assert_eq!(
            Model::Claude3_5SonnetV2.cross_region_inference_id("me-central-1")?,
            "apac.anthropic.claude-3-5-sonnet-20241022-v2:0"
        );
        assert_eq!(
            Model::AmazonNovaMicro.cross_region_inference_id("me-south-1")?,
            "apac.amazon.nova-micro-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_unsupported_region_is_an_error() {
        assert!(Model::Claude3_5SonnetV2
            .cross_region_inference_id("af-south-1")
            .is_err());
        assert!(Model::AmazonNovaPro.cross_region_inference_id("").is_err());
    }

    #[test]
    fn test_missing_profile_falls_back_to_base_id() -> anyhow::Result<()> {
        // Claude 3.7 Sonnet is only prefixed in the "us" group.
        assert_eq!(
            Model::Claude3_7Sonnet.cross_region_inference_id("eu-west-1")?,
            "anthropic.claude-3-7-sonnet-20250219-v1:0"
        );
        // Claude 3.5 Sonnet v2 is prefixed in "us"/"apac", not "eu"/"us-gov".
        assert_eq!(
            Model::Claude3_5SonnetV2.cross_region_inference_id("eu-central-1")?,
            "anthropic.claude-3-5-sonnet-20241022-v2:0"
        );
        assert_eq!(
            Model::Claude3_5SonnetV2.cross_region_inference_id("us-gov-west-1")?,
            "anthropic.claude-3-5-sonnet-20241022-v2:0"
        );
        // Llama 3.2 1B is prefixed in "us"/"eu" only.
        assert_eq!(
            Model::MetaLlama321BInstructV1.cross_region_inference_id("ap-northeast-1")?,
            "meta.llama3-2-1b-instruct-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_thinking_variant_shares_base_model_id() -> anyhow::Result<()> {
        assert_eq!(Model::Claude3_7SonnetThinking.id(), Model::Claude3_7Sonnet.id());
        assert_eq!(
            Model::Claude3_7SonnetThinking.cross_region_inference_id("us-west-2")?,
            "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_mode() {
        assert_eq!(
            Model::Claude3_7SonnetThinking.mode(),
            BedrockModelMode::Thinking {
                budget_tokens: Some(4096)
            }
        );
        assert_eq!(Model::Claude3_7Sonnet.mode(), BedrockModelMode::Default);
        assert_eq!(Model::AmazonNovaPro.mode(), BedrockModelMode::Default);
    }

    #[test]
    fn test_from_id_prefixes() -> anyhow::Result<()> {
        assert_eq!(
            Model::from_id("claude-3-5-sonnet-v2")?,
            Model::Claude3_5SonnetV2
        );
        assert_eq!(Model::from_id("claude-3-opus-latest")?, Model::Claude3Opus);
        assert_eq!(Model::from_id("claude-3-5-haiku")?, Model::Claude3_5Haiku);
        assert_eq!(Model::from_id("claude-3-7-sonnet")?, Model::Claude3_7Sonnet);
        assert!(Model::from_id("not-a-model").is_err());
        Ok(())
    }

    #[test]
    fn test_meta_models_inference_ids() -> anyhow::Result<()> {
        // Test Meta models
        assert_eq!(
            Model::MetaLlama370BInstructV1.cross_region_inference_id("us-east-1")?,
            "us.meta.llama3-70b-instruct-v1:0"
        );
        assert_eq!(
            Model::MetaLlama321BInstructV1.cross_region_inference_id("eu-west-1")?,
            "eu.meta.llama3-2-1b-instruct-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_mistral_models_inference_ids() -> anyhow::Result<()> {
        // Mistral models don't follow the regional prefix pattern,
        // so they should return their original IDs
        assert_eq!(
            Model::MistralMistralLarge2402V1.cross_region_inference_id("us-east-1")?,
            "mistral.mistral-large-2402-v1:0"
        );
        assert_eq!(
            Model::MistralMixtral8x7BInstructV0.cross_region_inference_id("eu-west-1")?,
            "mistral.mixtral-8x7b-instruct-v0:1"
        );
        Ok(())
    }

    #[test]
    fn test_ai21_models_inference_ids() -> anyhow::Result<()> {
        // AI21 models don't follow the regional prefix pattern,
        // so they should return their original IDs
        assert_eq!(
            Model::AI21J2UltraV1.cross_region_inference_id("us-east-1")?,
            "ai21.j2-ultra-v1"
        );
        assert_eq!(
            Model::AI21JambaInstructV1.cross_region_inference_id("eu-west-1")?,
            "ai21.jamba-instruct-v1:0"
        );
        Ok(())
    }

    #[test]
    fn test_cohere_models_inference_ids() -> anyhow::Result<()> {
        // Cohere models don't follow the regional prefix pattern,
        // so they should return their original IDs
        assert_eq!(
            Model::CohereCommandRV1.cross_region_inference_id("us-east-1")?,
            "cohere.command-r-v1:0"
        );
        assert_eq!(
            Model::CohereCommandTextV14_4k.cross_region_inference_id("ap-southeast-1")?,
            "cohere.command-text-v14:7:4k"
        );
        Ok(())
    }

    #[test]
    fn test_custom_model_inference_ids() -> anyhow::Result<()> {
        // Test custom models
        let custom_model = Model::Custom {
            name: "custom.my-model-v1:0".to_string(),
            max_tokens: 100000,
            display_name: Some("My Custom Model".to_string()),
            max_output_tokens: Some(8192),
            default_temperature: Some(0.7),
        };

        // Custom model should return its name unchanged
        assert_eq!(
            custom_model.cross_region_inference_id("us-east-1")?,
            "custom.my-model-v1:0"
        );

        Ok(())
    }
}