@@ -108,8 +108,12 @@ pub enum Model {
}
impl Model {
- pub fn default_fast() -> Self {
- Self::Claude3_5Haiku
+ pub fn default_fast(region: &str) -> Self {
+ if region.starts_with("us-") {
+ Self::Claude3_5Haiku
+ } else {
+ Self::Claude3Haiku
+ }
}
pub fn from_id(id: &str) -> anyhow::Result<Self> {
@@ -229,6 +229,17 @@ impl State {
Ok(())
})
}
+
+ fn get_region(&self) -> String {
+ // Get region - from credentials or directly from settings
+ let credentials_region = self.credentials.as_ref().map(|s| s.region.clone());
+ let settings_region = self.settings.as_ref().and_then(|s| s.region.clone());
+
+ // Use credentials region if available, otherwise use settings region, finally fall back to default
+ credentials_region
+ .or(settings_region)
+ .unwrap_or(String::from("us-east-1"))
+ }
}
pub struct BedrockLanguageModelProvider {
@@ -289,8 +300,9 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
Some(self.create_language_model(bedrock::Model::default()))
}
- fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- Some(self.create_language_model(bedrock::Model::default_fast()))
+ fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ let region = self.state.read(cx).get_region();
+ Some(self.create_language_model(bedrock::Model::default_fast(region.as_str())))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -377,11 +389,7 @@ impl BedrockModel {
let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone());
- let region = state
- .settings
- .as_ref()
- .and_then(|s| s.region.clone())
- .unwrap_or(String::from("us-east-1"));
+ let region = state.get_region();
(
auth_method,
@@ -530,16 +538,7 @@ impl LanguageModel for BedrockModel {
LanguageModelCompletionError,
>,
> {
- let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
- // Get region - from credentials or directly from settings
- let credentials_region = state.credentials.as_ref().map(|s| s.region.clone());
- let settings_region = state.settings.as_ref().and_then(|s| s.region.clone());
-
- // Use credentials region if available, otherwise use settings region, finally fall back to default
- credentials_region
- .or(settings_region)
- .unwrap_or(String::from("us-east-1"))
- }) else {
+ let Ok(region) = cx.read_entity(&self.state, |state, _cx| state.get_region()) else {
return async move { Err(anyhow::anyhow!("App State Dropped").into()) }.boxed();
};