bedrock: Add support for tool use, cross-region inference, and Claude 3.7 Thinking (#28137)

Shardul Vaidya and Marshall Bowers created

Closes #27223
Merges: #27996, #26734, #27949 

Release Notes:

- AWS Bedrock: Added advanced authentication strategies with:
  - Short lived credentials with Session Tokens 
  - AWS Named Profile
  - EC2 Identity, Pod Identity, Web Identity
- AWS Bedrock: Added Claude 3.7 Thinking support.
- AWS Bedrock: Adding Cross Region Inference for all combinations of
regions and model availability.
- Agent Beta: Added support for AWS Bedrock.

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>

Change summary

Cargo.lock                                     |  74 
Cargo.toml                                     |  10 
crates/bedrock/src/bedrock.rs                  |  73 
crates/bedrock/src/models.rs                   | 304 ++++++
crates/language_models/src/provider/bedrock.rs | 868 ++++++++++++++-----
crates/language_models/src/settings.rs         |  29 
6 files changed, 1,041 insertions(+), 317 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2141,9 +2141,9 @@ dependencies = [
 
 [[package]]
 name = "blake3"
-version = "1.8.0"
+version = "1.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "34a796731680be7931955498a16a10b2270c7762963d5d570fdbfe02dcbf314f"
+checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3"
 dependencies = [
  "arrayref",
  "arrayvec",
@@ -2455,9 +2455,9 @@ dependencies = [
 
 [[package]]
 name = "cap-fs-ext"
-version = "3.4.2"
+version = "3.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7f78efdd7378980d79c0f36b519e51191742d2c9f91ffa5e228fba9f3806d2e1"
+checksum = "f6323b9baffb4d6d9c65bfef3129db62b1391f7fb953fe67c0d7cb0804feb77b"
 dependencies = [
  "cap-primitives",
  "cap-std",
@@ -2467,9 +2467,9 @@ dependencies = [
 
 [[package]]
 name = "cap-net-ext"
-version = "3.4.2"
+version = "3.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ac68674a6042af2bcee1adad9f6abd432642cf03444ce3a5b36c3f39f23baf8"
+checksum = "66022e5e076ea27041a05ebd4349022e2133e6f4977974dffd54ceb7b982e871"
 dependencies = [
  "cap-primitives",
  "cap-std",
@@ -2479,9 +2479,9 @@ dependencies = [
 
 [[package]]
 name = "cap-primitives"
-version = "3.4.2"
+version = "3.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8fc15faeed2223d8b8e8cc1857f5861935a06d06713c4ac106b722ae9ce3c369"
+checksum = "50ad0183a9659850877cefe8f5b87d564b2dd1fe78b18945813687f29c0a6878"
 dependencies = [
  "ambient-authority",
  "fs-set-times",
@@ -2496,9 +2496,9 @@ dependencies = [
 
 [[package]]
 name = "cap-rand"
-version = "3.4.2"
+version = "3.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dea13372b49df066d1ae654e5c6e41799c1efd9f6b36794b921e877ea4037977"
+checksum = "ab78a9f6301e70c0fe5df7328adbcb9228277fdb7bab36f312fc072f505e38c2"
 dependencies = [
  "ambient-authority",
  "rand 0.8.5",
@@ -2506,9 +2506,9 @@ dependencies = [
 
 [[package]]
 name = "cap-std"
-version = "3.4.2"
+version = "3.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c3dbd3e8e8d093d6ccb4b512264869e1281cdb032f7940bd50b2894f96f25609"
+checksum = "1c41814365b796ed12688026cb90a1e03236a84ccf009628f9c43c8aa3af250a"
 dependencies = [
  "cap-primitives",
  "io-extras",
@@ -2518,9 +2518,9 @@ dependencies = [
 
 [[package]]
 name = "cap-time-ext"
-version = "3.4.2"
+version = "3.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bd736b20fc033f564a1995fb82fc349146de43aabba19c7368b4cb17d8f9ea53"
+checksum = "eb57b71bb69b97c638ec38b477e9b33fa1c1cff0e437dd55d15c117cf17ed5dc"
 dependencies = [
  "ambient-authority",
  "cap-primitives",
@@ -2598,9 +2598,9 @@ dependencies = [
 
 [[package]]
 name = "cc"
-version = "1.2.17"
+version = "1.2.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a"
+checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c"
 dependencies = [
  "jobserver",
  "libc",
@@ -3926,9 +3926,9 @@ checksum = "4f211af61d8efdd104f96e57adf5e426ba1bc3ed7a4ead616e15e5881fd79c4d"
 
 [[package]]
 name = "ctrlc"
-version = "3.4.5"
+version = "3.4.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3"
+checksum = "697b5419f348fd5ae2478e8018cb016c00a5881c7f46c717de98ffd135a5651c"
 dependencies = [
  "nix",
  "windows-sys 0.59.0",
@@ -4805,9 +4805,9 @@ dependencies = [
 
 [[package]]
 name = "errno"
-version = "0.3.10"
+version = "0.3.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
+checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e"
 dependencies = [
  "libc",
  "windows-sys 0.59.0",
@@ -5252,9 +5252,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
 
 [[package]]
 name = "flate2"
-version = "1.1.0"
+version = "1.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc"
+checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece"
 dependencies = [
  "crc32fast",
  "miniz_oxide",
@@ -7891,9 +7891,9 @@ checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa"
 
 [[package]]
 name = "libmimalloc-sys"
-version = "0.1.40"
+version = "0.1.41"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "07d0e07885d6a754b9c7993f2625187ad694ee985d60f23355ff0e7077261502"
+checksum = "6b20daca3a4ac14dbdc753c5e90fc7b490a48a9131daed3c9a9ced7b2defd37b"
 dependencies = [
  "cc",
  "libc",
@@ -8562,9 +8562,9 @@ dependencies = [
 
 [[package]]
 name = "mimalloc"
-version = "0.1.44"
+version = "0.1.45"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "99585191385958383e13f6b822e6b6d8d9cf928e7d286ceb092da92b43c87bc1"
+checksum = "03cb1f88093fe50061ca1195d336ffec131347c7b833db31f9ab62a2d1b7925f"
 dependencies = [
  "libmimalloc-sys",
 ]
@@ -8593,9 +8593,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
 
 [[package]]
 name = "miniz_oxide"
-version = "0.8.5"
+version = "0.8.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5"
+checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430"
 dependencies = [
  "adler2",
  "simd-adler32",
@@ -9474,9 +9474,9 @@ dependencies = [
 
 [[package]]
 name = "openssl"
-version = "0.10.71"
+version = "0.10.72"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd"
+checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da"
 dependencies = [
  "bitflags 2.9.0",
  "cfg-if",
@@ -9515,9 +9515,9 @@ dependencies = [
 
 [[package]]
 name = "openssl-sys"
-version = "0.9.106"
+version = "0.9.107"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd"
+checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07"
 dependencies = [
  "cc",
  "libc",
@@ -12149,7 +12149,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
 dependencies = [
  "bitflags 2.9.0",
- "errno 0.3.10",
+ "errno 0.3.11",
  "itoa",
  "libc",
  "linux-raw-sys 0.4.15",
@@ -12164,7 +12164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d97817398dd4bb2e6da002002db259209759911da105da92bec29ccb12cf58bf"
 dependencies = [
  "bitflags 2.9.0",
- "errno 0.3.10",
+ "errno 0.3.11",
  "libc",
  "linux-raw-sys 0.9.3",
  "windows-sys 0.59.0",
@@ -12176,7 +12176,7 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a25c3aad9fc1424eb82c88087789a7d938e1829724f3e4043163baf0d13cfc12"
 dependencies = [
- "errno 0.3.10",
+ "errno 0.3.11",
  "libc",
  "rustix 0.38.44",
 ]
@@ -13798,9 +13798,9 @@ dependencies = [
 
 [[package]]
 name = "swash"
-version = "0.2.1"
+version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "13d5bbc2aa266907ed8ee977c9c9e16363cc2b001266104e13397b57f1d15f71"
+checksum = "fae9a562c7b46107d9c78cd78b75bbe1e991c16734c0aee8ff0ee711fb8b620a"
 dependencies = [
  "skrifa",
  "yazi",

Cargo.toml 🔗

@@ -399,11 +399,11 @@ async-trait = "0.1"
 async-tungstenite = "0.28"
 async-watch = "0.3.1"
 async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
-aws-config = { version = "1.5.16", features = ["behavior-version-latest"] }
-aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] }
-aws-sdk-bedrockruntime = { version = "1.73.0", features = ["behavior-version-latest"] }
-aws-smithy-runtime-api = { version = "1.7.3", features = ["http-1x", "client"] }
-aws-smithy-types = { version = "1.2.13", features = ["http-body-1-x"] }
+aws-config = { version = "1.6.1", features = ["behavior-version-latest"] }
+aws-credential-types = { version = "1.2.2", features = ["hardcoded-credentials"] }
+aws-sdk-bedrockruntime = { version = "1.80.0", features = ["behavior-version-latest"] }
+aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
+aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
 base64 = "0.22"
 bitflags = "2.6.0"
 blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }

crates/bedrock/src/bedrock.rs 🔗

@@ -1,21 +1,25 @@
 mod models;
 
+use std::collections::HashMap;
 use std::pin::Pin;
 
-use anyhow::{Context, Error, Result, anyhow};
+use anyhow::{Error, Result, anyhow};
 use aws_sdk_bedrockruntime as bedrock;
 pub use aws_sdk_bedrockruntime as bedrock_client;
 pub use aws_sdk_bedrockruntime::types::{
-    ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool,
-    ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema,
-    ToolSpecification as BedrockTool,
+    AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent,
+    Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig,
+    ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec,
 };
 use aws_smithy_types::{Document, Number as AwsNumber};
 pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
 pub use bedrock::types::{
     ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
     ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
-    Message as BedrockMessage, ResponseStream as BedrockResponseStream,
+    ImageBlock as BedrockImageBlock, Message as BedrockMessage,
+    ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
+    ToolResultContentBlock as BedrockToolResultContentBlock,
+    ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
 };
 use futures::stream::{self, BoxStream, Stream};
 use serde::{Deserialize, Serialize};
@@ -24,25 +28,6 @@ use thiserror::Error;
 
 pub use crate::models::*;
 
-pub async fn complete(
-    client: &bedrock::Client,
-    request: Request,
-) -> Result<BedrockResponse, BedrockError> {
-    let response = bedrock::Client::converse(client)
-        .model_id(request.model.clone())
-        .set_messages(request.messages.into())
-        .send()
-        .await
-        .context("failed to send request to Bedrock");
-
-    match response {
-        Ok(output) => output
-            .output
-            .ok_or_else(|| BedrockError::Other(anyhow!("no output"))),
-        Err(err) => Err(BedrockError::Other(err)),
-    }
-}
-
 pub async fn stream_completion(
     client: bedrock::Client,
     request: Request,
@@ -50,11 +35,32 @@ pub async fn stream_completion(
 ) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
     handle
         .spawn(async move {
-            let response = bedrock::Client::converse_stream(&client)
+            let mut response = bedrock::Client::converse_stream(&client)
                 .model_id(request.model.clone())
-                .set_messages(request.messages.into())
-                .send()
-                .await;
+                .set_messages(request.messages.into());
+
+            if let Some(Thinking::Enabled {
+                budget_tokens: Some(budget_tokens),
+            }) = request.thinking
+            {
+                response =
+                    response.additional_model_request_fields(Document::Object(HashMap::from([(
+                        "thinking".to_string(),
+                        Document::from(HashMap::from([
+                            ("type".to_string(), Document::String("enabled".to_string())),
+                            (
+                                "budget_tokens".to_string(),
+                                Document::Number(AwsNumber::PosInt(budget_tokens)),
+                            ),
+                        ])),
+                    )])));
+            }
+
+            if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
+                response = response.set_tool_config(request.tools);
+            }
+
+            let response = response.send().await;
 
             match response {
                 Ok(output) => {
@@ -65,7 +71,7 @@ pub async fn stream_completion(
                         >,
                     > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
                         match stream.recv().await {
-                            Ok(Some(output)) => Some((Ok(output), stream)),
+                            Ok(Some(output)) => Some(({ Ok(output) }, stream)),
                             Ok(None) => None,
                             Err(err) => {
                                 Some((
@@ -135,13 +141,18 @@ pub fn value_to_aws_document(value: &Value) -> Document {
     }
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Thinking {
+    Enabled { budget_tokens: Option<u64> },
+}
+
 #[derive(Debug)]
 pub struct Request {
     pub model: String,
     pub max_tokens: u32,
     pub messages: Vec<BedrockMessage>,
-    pub tools: Vec<BedrockTool>,
-    pub tool_choice: Option<BedrockToolChoice>,
+    pub tools: Option<BedrockToolConfig>,
+    pub thinking: Option<Thinking>,
     pub system: Option<String>,
     pub metadata: Option<Metadata>,
     pub stop_sequences: Vec<String>,

crates/bedrock/src/models.rs 🔗

@@ -2,21 +2,38 @@ use anyhow::anyhow;
 use serde::{Deserialize, Serialize};
 use strum::EnumIter;
 
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub enum BedrockModelMode {
+    #[default]
+    Default,
+    Thinking {
+        budget_tokens: Option<u64>,
+    },
+}
+
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 pub enum Model {
     // Anthropic models (already included)
     #[default]
     #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")]
-    Claude3_5Sonnet,
+    Claude3_5SonnetV2,
     #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
     Claude3_7Sonnet,
+    #[serde(
+        rename = "claude-3-7-sonnet-thinking",
+        alias = "claude-3-7-sonnet-thinking-latest"
+    )]
+    Claude3_7SonnetThinking,
     #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
     Claude3Opus,
     #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
     Claude3Sonnet,
     #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
     Claude3_5Haiku,
+    Claude3_5Sonnet,
+    Claude3Haiku,
     // Amazon Nova Models
     AmazonNovaLite,
     AmazonNovaMicro,
@@ -69,7 +86,7 @@ pub enum Model {
 impl Model {
     pub fn from_id(id: &str) -> anyhow::Result<Self> {
         if id.starts_with("claude-3-5-sonnet-v2") {
-            Ok(Self::Claude3_5Sonnet)
+            Ok(Self::Claude3_5SonnetV2)
         } else if id.starts_with("claude-3-opus") {
             Ok(Self::Claude3Opus)
         } else if id.starts_with("claude-3-sonnet") {
@@ -78,6 +95,8 @@ impl Model {
             Ok(Self::Claude3_5Haiku)
         } else if id.starts_with("claude-3-7-sonnet") {
             Ok(Self::Claude3_7Sonnet)
+        } else if id.starts_with("claude-3-7-sonnet-thinking") {
+            Ok(Self::Claude3_7SonnetThinking)
         } else {
             Err(anyhow!("invalid model id"))
         }
@@ -85,14 +104,18 @@ impl Model {
 
     pub fn id(&self) -> &str {
         match self {
-            Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
-            Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0",
-            Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0",
-            Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0",
-            Model::Claude3_7Sonnet => "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
-            Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0",
-            Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0",
-            Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0",
+            Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0",
+            Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
+            Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0",
+            Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
+            Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
+            Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0",
+            Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => {
+                "anthropic.claude-3-7-sonnet-20250219-v1:0"
+            }
+            Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
+            Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
+            Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
             Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
             Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
             Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
@@ -128,11 +151,14 @@ impl Model {
 
     pub fn display_name(&self) -> &str {
         match self {
-            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet v2",
+            Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2",
+            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
             Self::Claude3Opus => "Claude 3 Opus",
             Self::Claude3Sonnet => "Claude 3 Sonnet",
+            Self::Claude3Haiku => "Claude 3 Haiku",
             Self::Claude3_5Haiku => "Claude 3.5 Haiku",
             Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
+            Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
             Self::AmazonNovaLite => "Amazon Nova Lite",
             Self::AmazonNovaMicro => "Amazon Nova Micro",
             Self::AmazonNovaPro => "Amazon Nova Pro",
@@ -173,7 +199,7 @@ impl Model {
 
     pub fn max_token_count(&self) -> usize {
         match self {
-            Self::Claude3_5Sonnet
+            Self::Claude3_5SonnetV2
             | Self::Claude3Opus
             | Self::Claude3Sonnet
             | Self::Claude3_5Haiku
@@ -186,7 +212,8 @@ impl Model {
     pub fn max_output_tokens(&self) -> u32 {
         match self {
             Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
-            Self::Claude3_5Sonnet => 8_192,
+            Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
+            Self::Claude3_5SonnetV2 => 8_192,
             Self::Custom {
                 max_output_tokens, ..
             } => max_output_tokens.unwrap_or(4_096),
@@ -196,7 +223,7 @@ impl Model {
 
     pub fn default_temperature(&self) -> f32 {
         match self {
-            Self::Claude3_5Sonnet
+            Self::Claude3_5SonnetV2
             | Self::Claude3Opus
             | Self::Claude3Sonnet
             | Self::Claude3_5Haiku
@@ -208,4 +235,253 @@ impl Model {
             _ => 1.0,
         }
     }
+
+    pub fn supports_tool_use(&self) -> bool {
+        match self {
+            // Anthropic Claude 3 models (all support tool use)
+            Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3_5Sonnet
+            | Self::Claude3_5SonnetV2
+            | Self::Claude3_7Sonnet
+            | Self::Claude3_7SonnetThinking
+            | Self::Claude3_5Haiku => true,
+
+            // Amazon Nova models (all support tool use)
+            Self::AmazonNovaPro | Self::AmazonNovaLite | Self::AmazonNovaMicro => true,
+
+            // AI21 Jamba 1.5 models support tool use
+            Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
+
+            // Cohere Command R models support tool use
+            Self::CohereCommandRV1 | Self::CohereCommandRPlusV1 => true,
+
+            // All other models don't support tool use
+            // Including Meta Llama 3.2, AI21 Jurassic, and others
+            _ => false,
+        }
+    }
+
+    pub fn mode(&self) -> BedrockModelMode {
+        match self {
+            Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
+                budget_tokens: Some(4096),
+            },
+            _ => BedrockModelMode::Default,
+        }
+    }
+
+    pub fn cross_region_inference_id(&self, region: &str) -> Result<String, anyhow::Error> {
+        let region_group = if region.starts_with("us-gov-") {
+            "us-gov"
+        } else if region.starts_with("us-") {
+            "us"
+        } else if region.starts_with("eu-") {
+            "eu"
+        } else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" {
+            "apac"
+        } else if region.starts_with("ca-") || region.starts_with("sa-") {
+            // Canada and South America regions - default to US profiles
+            "us"
+        } else {
+            // Unknown region
+            return Err(anyhow!("Unsupported Region"));
+        };
+
+        let model_id = self.id();
+
+        match (self, region_group) {
+            // Custom models can't have CRI IDs
+            (Model::Custom { .. }, _) => Ok(self.id().into()),
+
+            // Models with US Gov only
+            (Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => {
+                Ok(format!("{}.{}", region_group, model_id))
+            }
+
+            // Models available only in US
+            (Model::Claude3Opus, "us")
+            | (Model::Claude3_7Sonnet, "us")
+            | (Model::Claude3_7SonnetThinking, "us") => {
+                Ok(format!("{}.{}", region_group, model_id))
+            }
+
+            // Models available in US, EU, and APAC
+            (Model::Claude3_5SonnetV2, "us")
+            | (Model::Claude3_5SonnetV2, "apac")
+            | (Model::Claude3_5Sonnet, _)
+            | (Model::Claude3Haiku, _)
+            | (Model::Claude3Sonnet, _)
+            | (Model::AmazonNovaLite, _)
+            | (Model::AmazonNovaMicro, _)
+            | (Model::AmazonNovaPro, _) => Ok(format!("{}.{}", region_group, model_id)),
+
+            // Models with limited EU availability
+            (Model::MetaLlama321BInstructV1, "us")
+            | (Model::MetaLlama321BInstructV1, "eu")
+            | (Model::MetaLlama323BInstructV1, "us")
+            | (Model::MetaLlama323BInstructV1, "eu") => {
+                Ok(format!("{}.{}", region_group, model_id))
+            }
+
+            // US-only models (all remaining Meta models)
+            (Model::MetaLlama38BInstructV1, "us")
+            | (Model::MetaLlama370BInstructV1, "us")
+            | (Model::MetaLlama318BInstructV1, "us")
+            | (Model::MetaLlama318BInstructV1_128k, "us")
+            | (Model::MetaLlama3170BInstructV1, "us")
+            | (Model::MetaLlama3170BInstructV1_128k, "us")
+            | (Model::MetaLlama3211BInstructV1, "us")
+            | (Model::MetaLlama3290BInstructV1, "us") => {
+                Ok(format!("{}.{}", region_group, model_id))
+            }
+
+            // Any other combination is not supported
+            _ => Ok(self.id().into()),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_us_region_inference_ids() -> anyhow::Result<()> {
+        // Test US regions
+        assert_eq!(
+            Model::Claude3_5SonnetV2.cross_region_inference_id("us-east-1")?,
+            "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
+        );
+        assert_eq!(
+            Model::Claude3_5SonnetV2.cross_region_inference_id("us-west-2")?,
+            "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
+        );
+        assert_eq!(
+            Model::AmazonNovaPro.cross_region_inference_id("us-east-2")?,
+            "us.amazon.nova-pro-v1:0"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_eu_region_inference_ids() -> anyhow::Result<()> {
+        // Test European regions
+        assert_eq!(
+            Model::Claude3Sonnet.cross_region_inference_id("eu-west-1")?,
+            "eu.anthropic.claude-3-sonnet-20240229-v1:0"
+        );
+        assert_eq!(
+            Model::AmazonNovaMicro.cross_region_inference_id("eu-north-1")?,
+            "eu.amazon.nova-micro-v1:0"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_apac_region_inference_ids() -> anyhow::Result<()> {
+        // Test Asia-Pacific regions
+        assert_eq!(
+            Model::Claude3_5SonnetV2.cross_region_inference_id("ap-northeast-1")?,
+            "apac.anthropic.claude-3-5-sonnet-20241022-v2:0"
+        );
+        assert_eq!(
+            Model::AmazonNovaLite.cross_region_inference_id("ap-south-1")?,
+            "apac.amazon.nova-lite-v1:0"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_gov_region_inference_ids() -> anyhow::Result<()> {
+        // Test Government regions
+        assert_eq!(
+            Model::Claude3_5Sonnet.cross_region_inference_id("us-gov-east-1")?,
+            "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0"
+        );
+        assert_eq!(
+            Model::Claude3Haiku.cross_region_inference_id("us-gov-west-1")?,
+            "us-gov.anthropic.claude-3-haiku-20240307-v1:0"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_meta_models_inference_ids() -> anyhow::Result<()> {
+        // Test Meta models
+        assert_eq!(
+            Model::MetaLlama370BInstructV1.cross_region_inference_id("us-east-1")?,
+            "us.meta.llama3-70b-instruct-v1:0"
+        );
+        assert_eq!(
+            Model::MetaLlama321BInstructV1.cross_region_inference_id("eu-west-1")?,
+            "eu.meta.llama3-2-1b-instruct-v1:0"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_mistral_models_inference_ids() -> anyhow::Result<()> {
+        // Mistral models don't follow the regional prefix pattern,
+        // so they should return their original IDs
+        assert_eq!(
+            Model::MistralMistralLarge2402V1.cross_region_inference_id("us-east-1")?,
+            "mistral.mistral-large-2402-v1:0"
+        );
+        assert_eq!(
+            Model::MistralMixtral8x7BInstructV0.cross_region_inference_id("eu-west-1")?,
+            "mistral.mixtral-8x7b-instruct-v0:1"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_ai21_models_inference_ids() -> anyhow::Result<()> {
+        // AI21 models don't follow the regional prefix pattern,
+        // so they should return their original IDs
+        assert_eq!(
+            Model::AI21J2UltraV1.cross_region_inference_id("us-east-1")?,
+            "ai21.j2-ultra-v1"
+        );
+        assert_eq!(
+            Model::AI21JambaInstructV1.cross_region_inference_id("eu-west-1")?,
+            "ai21.jamba-instruct-v1:0"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_cohere_models_inference_ids() -> anyhow::Result<()> {
+        // Cohere models don't follow the regional prefix pattern,
+        // so they should return their original IDs
+        assert_eq!(
+            Model::CohereCommandRV1.cross_region_inference_id("us-east-1")?,
+            "cohere.command-r-v1:0"
+        );
+        assert_eq!(
+            Model::CohereCommandTextV14_4k.cross_region_inference_id("ap-southeast-1")?,
+            "cohere.command-text-v14:7:4k"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_custom_model_inference_ids() -> anyhow::Result<()> {
+        // Test custom models
+        let custom_model = Model::Custom {
+            name: "custom.my-model-v1:0".to_string(),
+            max_tokens: 100000,
+            display_name: Some("My Custom Model".to_string()),
+            max_output_tokens: Some(8192),
+            default_temperature: Some(0.7),
+        };
+
+        // Custom model should return its name unchanged
+        assert_eq!(
+            custom_model.cross_region_inference_id("us-east-1")?,
+            "custom.my-model-v1:0"
+        );
+
+        Ok(())
+    }
 }

crates/language_models/src/provider/bedrock.rs 🔗

@@ -4,19 +4,29 @@ use std::sync::Arc;
 
 use crate::ui::InstructionListItem;
 use anyhow::{Context as _, Result, anyhow};
-use aws_config::Region;
 use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
+use aws_config::{BehaviorVersion, Region};
 use aws_credential_types::Credentials;
 use aws_http_client::AwsHttpClient;
-use bedrock::bedrock_client::types::{ContentBlockDelta, ContentBlockStart, ConverseStreamOutput};
-use bedrock::bedrock_client::{self, Config};
-use bedrock::{BedrockError, BedrockInnerContent, BedrockMessage, BedrockStreamingResponse, Model};
+use bedrock::bedrock_client::Client as BedrockClient;
+use bedrock::bedrock_client::config::timeout::TimeoutConfig;
+use bedrock::bedrock_client::types::{
+    ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
+    StopReason,
+};
+use bedrock::{
+    BedrockAutoToolChoice, BedrockError, BedrockInnerContent, BedrockMessage, BedrockModelMode,
+    BedrockStreamingResponse, BedrockTool, BedrockToolChoice, BedrockToolConfig,
+    BedrockToolInputSchema, BedrockToolResultBlock, BedrockToolResultContentBlock,
+    BedrockToolResultStatus, BedrockToolSpec, BedrockToolUseBlock, Model, value_to_aws_document,
+};
 use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{
-    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
+    AnyView, App, AsyncApp, Context, Entity, FontStyle, FontWeight, Subscription, Task, TextStyle,
+    WhiteSpace,
 };
 use gpui_tokio::Tokio;
 use http_client::HttpClient;
@@ -24,17 +34,18 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
     LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role,
+    LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use settings::{Settings, SettingsStore};
-use strum::IntoEnumIterator;
+use smol::lock::OnceCell;
+use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
 use theme::ThemeSettings;
 use tokio::runtime::Handle;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::{ResultExt, maybe};
+use util::{ResultExt, default};
 
 use crate::AllLanguageModelSettings;
 
@@ -43,15 +54,33 @@ const PROVIDER_NAME: &str = "Amazon Bedrock";
 
 #[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
 pub struct BedrockCredentials {
-    pub region: String,
     pub access_key_id: String,
     pub secret_access_key: String,
+    pub session_token: Option<String>,
+    pub region: String,
 }
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct AmazonBedrockSettings {
-    pub session_token: Option<String>,
     pub available_models: Vec<AvailableModel>,
+    pub region: Option<String>,
+    pub endpoint: Option<String>,
+    pub profile_name: Option<String>,
+    pub role_arn: Option<String>,
+    pub authentication_method: Option<BedrockAuthMethod>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, EnumIter, IntoStaticStr, JsonSchema)]
+pub enum BedrockAuthMethod {
+    #[serde(rename = "named_profile")]
+    NamedProfile,
+    #[serde(rename = "static_credentials")]
+    StaticCredentials,
+    #[serde(rename = "sso")]
+    SingleSignOn,
+    /// IMDSv2, PodIdentity, env vars, etc.
+    #[serde(rename = "default")]
+    Automatic,
 }
 
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@@ -62,6 +91,36 @@ pub struct AvailableModel {
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
     pub max_output_tokens: Option<u32>,
     pub default_temperature: Option<f32>,
+    pub mode: Option<ModelMode>,
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+    #[default]
+    Default,
+    Thinking {
+        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
+        budget_tokens: Option<u64>,
+    },
+}
+
+impl From<ModelMode> for BedrockModelMode {
+    fn from(value: ModelMode) -> Self {
+        match value {
+            ModelMode::Default => BedrockModelMode::Default,
+            ModelMode::Thinking { budget_tokens } => BedrockModelMode::Thinking { budget_tokens },
+        }
+    }
+}
+
+impl From<BedrockModelMode> for ModelMode {
+    fn from(value: BedrockModelMode) -> Self {
+        match value {
+            BedrockModelMode::Default => ModelMode::Default,
+            BedrockModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
+        }
+    }
 }
 
 /// The URL of the base AWS service.
@@ -73,11 +132,15 @@ const AMAZON_AWS_URL: &str = "https://amazonaws.com";
 // These environment variables all use a `ZED_` prefix because we don't want to overwrite the user's AWS credentials.
 const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID";
 const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY";
+const ZED_BEDROCK_SESSION_TOKEN_VAR: &str = "ZED_SESSION_TOKEN";
+const ZED_AWS_PROFILE_VAR: &str = "ZED_AWS_PROFILE";
 const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION";
 const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS";
+const ZED_AWS_ENDPOINT_VAR: &str = "ZED_AWS_ENDPOINT";
 
 pub struct State {
     credentials: Option<BedrockCredentials>,
+    settings: Option<AmazonBedrockSettings>,
     credentials_from_env: bool,
     _subscription: Subscription,
 }
@@ -93,6 +156,7 @@ impl State {
             this.update(cx, |this, cx| {
                 this.credentials = None;
                 this.credentials_from_env = false;
+                this.settings = None;
                 cx.notify();
             })
         })
@@ -120,12 +184,47 @@ impl State {
         })
     }
 
-    fn is_authenticated(&self) -> bool {
-        self.credentials.is_some()
+    fn is_authenticated(&self) -> Option<String> {
+        match self
+            .settings
+            .as_ref()
+            .and_then(|s| s.authentication_method.as_ref())
+        {
+            Some(BedrockAuthMethod::StaticCredentials) => Some(String::from(
+                "You are authenticated using Static Credentials.",
+            )),
+            Some(BedrockAuthMethod::NamedProfile) | Some(BedrockAuthMethod::SingleSignOn) => {
+                match self.settings.as_ref() {
+                    None => Some(String::from(
+                        "You are authenticated using a Named Profile, but no profile is set.",
+                    )),
+                    Some(settings) => match settings.clone().profile_name {
+                        None => Some(String::from(
+                            "You are authenticated using a Named Profile, but no profile is set.",
+                        )),
+                        Some(profile_name) => Some(format!(
+                            "You are authenticated using a Named Profile: {profile_name}",
+                        )),
+                    },
+                }
+            }
+            Some(BedrockAuthMethod::Automatic) => Some(String::from(
+                "You are authenticated using Automatic Credentials.",
+            )),
+            None => {
+                if self.credentials.is_some() {
+                    Some(String::from(
+                        "You are authenticated using Static Credentials.",
+                    ))
+                } else {
+                    None
+                }
+            }
+        }
     }
 
     fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
-        if self.is_authenticated() {
+        if self.is_authenticated().is_some() {
             return Task::ready(Ok(()));
         }
 
@@ -170,6 +269,7 @@ impl BedrockLanguageModelProvider {
     pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
         let state = cx.new(|cx| State {
             credentials: None,
+            settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
             credentials_from_env: false,
             _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                 cx.notify();
@@ -209,6 +309,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
             http_client: self.http_client.clone(),
             handler: self.handler.clone(),
             state: self.state.clone(),
+            client: OnceCell::new(),
             request_limiter: RateLimiter::new(4),
         }))
     }
@@ -249,6 +350,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
                     http_client: self.http_client.clone(),
                     handler: self.handler.clone(),
                     state: self.state.clone(),
+                    client: OnceCell::new(),
                     request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
             })
@@ -256,7 +358,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
     }
 
     fn is_authenticated(&self, cx: &App) -> bool {
-        self.state.read(cx).is_authenticated()
+        self.state.read(cx).is_authenticated().is_some()
     }
 
     fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
@@ -287,11 +389,94 @@ struct BedrockModel {
     model: Model,
     http_client: AwsHttpClient,
     handler: tokio::runtime::Handle,
+    client: OnceCell<BedrockClient>,
     state: gpui::Entity<State>,
     request_limiter: RateLimiter,
 }
 
 impl BedrockModel {
+    fn get_or_init_client(&self, cx: &AsyncApp) -> Result<&BedrockClient, anyhow::Error> {
+        self.client
+            .get_or_try_init_blocking(|| {
+                let Ok((auth_method, credentials, endpoint, region, settings)) =
+                    cx.read_entity(&self.state, |state, _cx| {
+                        let auth_method = state
+                            .settings
+                            .as_ref()
+                            .and_then(|s| s.authentication_method.clone())
+                            .unwrap_or(BedrockAuthMethod::Automatic);
+
+                        let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone());
+
+                        let region = state
+                            .settings
+                            .as_ref()
+                            .and_then(|s| s.region.clone())
+                            .unwrap_or(String::from("us-east-1"));
+
+                        (
+                            auth_method,
+                            state.credentials.clone(),
+                            endpoint,
+                            region,
+                            state.settings.clone(),
+                        )
+                    })
+                else {
+                    return Err(anyhow!("App state dropped"));
+                };
+
+                let mut config_builder = aws_config::defaults(BehaviorVersion::latest())
+                    .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
+                    .http_client(self.http_client.clone())
+                    .region(Region::new(region))
+                    .timeout_config(TimeoutConfig::disabled());
+
+                if let Some(endpoint_url) = endpoint {
+                    if !endpoint_url.is_empty() {
+                        config_builder = config_builder.endpoint_url(endpoint_url);
+                    }
+                }
+
+                match auth_method {
+                    BedrockAuthMethod::StaticCredentials => {
+                        if let Some(creds) = credentials {
+                            let aws_creds = Credentials::new(
+                                creds.access_key_id,
+                                creds.secret_access_key,
+                                creds.session_token,
+                                None,
+                                "zed-bedrock-provider",
+                            );
+                            config_builder = config_builder.credentials_provider(aws_creds);
+                        }
+                    }
+                    BedrockAuthMethod::NamedProfile | BedrockAuthMethod::SingleSignOn => {
+                        // Currently NamedProfile and SSO behave the same way but only the instructions change
+                        // Until we support BearerAuth through SSO, this will not change.
+                        let profile_name = settings
+                            .and_then(|s| s.profile_name)
+                            .unwrap_or_else(|| "default".to_string());
+
+                        if !profile_name.is_empty() {
+                            config_builder = config_builder.profile_name(profile_name);
+                        }
+                    }
+                    BedrockAuthMethod::Automatic => {
+                        // Use default credential provider chain
+                    }
+                }
+
+                let config = self.handler.block_on(config_builder.load());
+                Ok(BedrockClient::new(&config))
+            })
+            .map_err(|err| anyhow!("Failed to initialize Bedrock client: {err}"))?;
+
+        self.client
+            .get()
+            .ok_or_else(|| anyhow!("Bedrock client not initialized"))
+    }
+
     fn stream_completion(
         &self,
         request: bedrock::Request,
@@ -299,37 +484,10 @@ impl BedrockModel {
     ) -> Result<
         BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
     > {
-        let Ok(Ok((access_key_id, secret_access_key, region))) =
-            cx.read_entity(&self.state, |state, _cx| {
-                if let Some(credentials) = &state.credentials {
-                    Ok((
-                        credentials.access_key_id.clone(),
-                        credentials.secret_access_key.clone(),
-                        credentials.region.clone(),
-                    ))
-                } else {
-                    return Err(anyhow!("Failed to read credentials"));
-                }
-            })
-        else {
-            return Err(anyhow!("App state dropped"));
-        };
-
-        let runtime_client = bedrock_client::Client::from_conf(
-            Config::builder()
-                .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
-                .credentials_provider(Credentials::new(
-                    access_key_id,
-                    secret_access_key,
-                    None,
-                    None,
-                    "Keychain",
-                ))
-                .region(Region::new(region))
-                .http_client(self.http_client.clone())
-                .build(),
-        );
-
+        let runtime_client = self
+            .get_or_init_client(cx)
+            .cloned()
+            .context("Bedrock client not initialized")?;
         let owned_handle = self.handler.clone();
 
         Ok(async move {
@@ -360,7 +518,7 @@ impl LanguageModel for BedrockModel {
     }
 
     fn supports_tools(&self) -> bool {
-        true
+        self.model.supports_tool_use()
     }
 
     fn telemetry_id(&self) -> String {
@@ -388,12 +546,36 @@ impl LanguageModel for BedrockModel {
         request: LanguageModelRequest,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request = into_bedrock(
+        let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
+            // Get region - from credentials or directly from settings
+            let region = state
+                .credentials
+                .as_ref()
+                .map(|s| s.region.clone())
+                .unwrap_or(String::from("us-east-1"));
+
+            region
+        }) else {
+            return async move { Err(anyhow!("App State Dropped")) }.boxed();
+        };
+
+        let model_id = match self.model.cross_region_inference_id(&region) {
+            Ok(s) => s,
+            Err(e) => {
+                return async move { Err(e) }.boxed();
+            }
+        };
+
+        let request = match into_bedrock(
             request,
-            self.model.id().into(),
+            model_id,
             self.model.default_temperature(),
             self.model.max_output_tokens(),
-        );
+            self.model.mode(),
+        ) {
+            Ok(request) => request,
+            Err(err) => return futures::future::ready(Err(err)).boxed(),
+        };
 
         let owned_handle = self.handler.clone();
 
@@ -418,7 +600,8 @@ pub fn into_bedrock(
     model: String,
     default_temperature: f32,
     max_output_tokens: u32,
-) -> bedrock::Request {
+    mode: BedrockModelMode,
+) -> Result<bedrock::Request> {
     let mut new_messages: Vec<BedrockMessage> = Vec::new();
     let mut system_message = String::new();
 
@@ -440,6 +623,32 @@ pub fn into_bedrock(
                                 None
                             }
                         }
+                        MessageContent::ToolUse(tool_use) => BedrockToolUseBlock::builder()
+                            .name(tool_use.name.to_string())
+                            .tool_use_id(tool_use.id.to_string())
+                            .input(value_to_aws_document(&tool_use.input))
+                            .build()
+                            .context("failed to build Bedrock tool use block")
+                            .log_err()
+                            .map(BedrockInnerContent::ToolUse),
+                        MessageContent::ToolResult(tool_result) => {
+                            BedrockToolResultBlock::builder()
+                                .tool_use_id(tool_result.tool_use_id.to_string())
+                                .content(BedrockToolResultContentBlock::Text(
+                                    tool_result.content.to_string(),
+                                ))
+                                .status({
+                                    if tool_result.is_error {
+                                        BedrockToolResultStatus::Error
+                                    } else {
+                                        BedrockToolResultStatus::Success
+                                    }
+                                })
+                                .build()
+                                .context("failed to build Bedrock tool result block")
+                                .log_err()
+                                .map(BedrockInnerContent::ToolResult)
+                        }
                         _ => None,
                     })
                     .collect();
@@ -459,7 +668,7 @@ pub fn into_bedrock(
                         .role(bedrock_role)
                         .set_content(Some(bedrock_message_content))
                         .build()
-                        .expect("failed to build Bedrock message"),
+                        .context("failed to build Bedrock message")?,
                 );
             }
             Role::System => {
@@ -471,19 +680,47 @@ pub fn into_bedrock(
         }
     }
 
-    bedrock::Request {
+    let tool_spec: Vec<BedrockTool> = request
+        .tools
+        .iter()
+        .filter_map(|tool| {
+            Some(BedrockTool::ToolSpec(
+                BedrockToolSpec::builder()
+                    .name(tool.name.clone())
+                    .description(tool.description.clone())
+                    .input_schema(BedrockToolInputSchema::Json(value_to_aws_document(
+                        &tool.input_schema,
+                    )))
+                    .build()
+                    .log_err()?,
+            ))
+        })
+        .collect();
+
+    let tool_config: BedrockToolConfig = BedrockToolConfig::builder()
+        .set_tools(Some(tool_spec))
+        .tool_choice(BedrockToolChoice::Auto(
+            BedrockAutoToolChoice::builder().build(),
+        ))
+        .build()?;
+
+    Ok(bedrock::Request {
         model,
         messages: new_messages,
         max_tokens: max_output_tokens,
         system: Some(system_message),
-        tools: vec![],
-        tool_choice: None,
+        tools: Some(tool_config),
+        thinking: if let BedrockModelMode::Thinking { budget_tokens } = mode {
+            Some(bedrock::Thinking::Enabled { budget_tokens })
+        } else {
+            None
+        },
         metadata: None,
         stop_sequences: Vec::new(),
         temperature: request.temperature.or(Some(default_temperature)),
         top_k: None,
         top_p: None,
-    }
+    })
 }
 
 // TODO: just call the ConverseOutput.usage() method:
@@ -571,48 +808,72 @@ pub fn map_to_language_model_completion_events(
                             match event {
                                 Ok(event) => match event {
                                     ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
-                                        if let Some(ContentBlockDelta::Text(text_out)) =
-                                            cb_delta.delta
-                                        {
-                                            return Some((
-                                                Some(Ok(LanguageModelCompletionEvent::Text(
-                                                    text_out,
-                                                ))),
-                                                state,
-                                            ));
-                                        } else if let Some(ContentBlockDelta::ToolUse(text_out)) =
-                                            cb_delta.delta
-                                        {
-                                            if let Some(tool_use) = state
-                                                .tool_uses_by_index
-                                                .get_mut(&cb_delta.content_block_index)
-                                            {
-                                                tool_use.input_json.push_str(text_out.input());
-                                                return Some((None, state));
-                                            };
+                                        match cb_delta.delta {
+                                            Some(ContentBlockDelta::Text(text_out)) => {
+                                                let completion_event =
+                                                    LanguageModelCompletionEvent::Text(text_out);
+                                                return Some((Some(Ok(completion_event)), state));
+                                            }
+
+                                            Some(ContentBlockDelta::ToolUse(text_out)) => {
+                                                if let Some(tool_use) = state
+                                                    .tool_uses_by_index
+                                                    .get_mut(&cb_delta.content_block_index)
+                                                {
+                                                    tool_use.input_json.push_str(text_out.input());
+                                                }
+                                            }
 
-                                            return Some((None, state));
-                                        } else if cb_delta.delta.is_none() {
-                                            return Some((None, state));
+                                            Some(ContentBlockDelta::ReasoningContent(thinking)) => {
+                                                match thinking {
+                                                    ReasoningContentBlockDelta::RedactedContent(
+                                                        redacted,
+                                                    ) => {
+                                                        let thinking_event =
+                                                            LanguageModelCompletionEvent::Thinking(
+                                                                String::from_utf8(
+                                                                    redacted.into_inner(),
+                                                                )
+                                                                .unwrap_or("REDACTED".to_string()),
+                                                            );
+
+                                                        return Some((
+                                                            Some(Ok(thinking_event)),
+                                                            state,
+                                                        ));
+                                                    }
+                                                    ReasoningContentBlockDelta::Signature(_sig) => {
+                                                    }
+                                                    ReasoningContentBlockDelta::Text(thoughts) => {
+                                                        let thinking_event =
+                                                            LanguageModelCompletionEvent::Thinking(
+                                                                thoughts.to_string(),
+                                                            );
+
+                                                        return Some((
+                                                            Some(Ok(thinking_event)),
+                                                            state,
+                                                        ));
+                                                    }
+                                                    _ => {}
+                                                }
+                                            }
+                                            _ => {}
                                         }
                                     }
                                     ConverseStreamOutput::ContentBlockStart(cb_start) => {
-                                        if let Some(start) = cb_start.start {
-                                            match start {
-                                                ContentBlockStart::ToolUse(text_out) => {
-                                                    let tool_use = RawToolUse {
-                                                        id: text_out.tool_use_id,
-                                                        name: text_out.name,
-                                                        input_json: String::new(),
-                                                    };
-
-                                                    state.tool_uses_by_index.insert(
-                                                        cb_start.content_block_index,
-                                                        tool_use,
-                                                    );
-                                                }
-                                                _ => {}
-                                            }
+                                        if let Some(ContentBlockStart::ToolUse(text_out)) =
+                                            cb_start.start
+                                        {
+                                            let tool_use = RawToolUse {
+                                                id: text_out.tool_use_id,
+                                                name: text_out.name,
+                                                input_json: String::new(),
+                                            };
+
+                                            state
+                                                .tool_uses_by_index
+                                                .insert(cb_start.content_block_index, tool_use);
                                         }
                                     }
                                     ConverseStreamOutput::ContentBlockStop(cb_stop) => {
@@ -620,30 +881,85 @@ pub fn map_to_language_model_completion_events(
                                             .tool_uses_by_index
                                             .remove(&cb_stop.content_block_index)
                                         {
+                                            let tool_use_event = LanguageModelToolUse {
+                                                id: tool_use.id.into(),
+                                                name: tool_use.name.into(),
+                                                input: if tool_use.input_json.is_empty() {
+                                                    Value::Null
+                                                } else {
+                                                    serde_json::Value::from_str(
+                                                        &tool_use.input_json,
+                                                    )
+                                                    .map_err(|err| anyhow!(err))
+                                                    .unwrap()
+                                                },
+                                            };
+
                                             return Some((
-                                                Some(maybe!({
-                                                    Ok(LanguageModelCompletionEvent::ToolUse(
-                                                        LanguageModelToolUse {
-                                                            id: tool_use.id.into(),
-                                                            name: tool_use.name.into(),
-                                                            input: if tool_use.input_json.is_empty()
-                                                            {
-                                                                Value::Null
-                                                            } else {
-                                                                serde_json::Value::from_str(
-                                                                    &tool_use.input_json,
-                                                                )
-                                                                .map_err(|err| anyhow!(err))?
-                                                            },
-                                                        },
-                                                    ))
-                                                })),
+                                                Some(Ok(LanguageModelCompletionEvent::ToolUse(
+                                                    tool_use_event,
+                                                ))),
                                                 state,
                                             ));
                                         }
                                     }
+
+                                    ConverseStreamOutput::Metadata(cb_meta) => {
+                                        if let Some(metadata) = cb_meta.usage {
+                                            let completion_event =
+                                                LanguageModelCompletionEvent::UsageUpdate(
+                                                    TokenUsage {
+                                                        input_tokens: metadata.input_tokens as u32,
+                                                        output_tokens: metadata.output_tokens
+                                                            as u32,
+                                                        cache_creation_input_tokens: default(),
+                                                        cache_read_input_tokens: default(),
+                                                    },
+                                                );
+                                            return Some((Some(Ok(completion_event)), state));
+                                        }
+                                    }
+                                    ConverseStreamOutput::MessageStop(message_stop) => {
+                                        let reason = match message_stop.stop_reason {
+                                            StopReason::ContentFiltered => {
+                                                LanguageModelCompletionEvent::Stop(
+                                                    language_model::StopReason::EndTurn,
+                                                )
+                                            }
+                                            StopReason::EndTurn => {
+                                                LanguageModelCompletionEvent::Stop(
+                                                    language_model::StopReason::EndTurn,
+                                                )
+                                            }
+                                            StopReason::GuardrailIntervened => {
+                                                LanguageModelCompletionEvent::Stop(
+                                                    language_model::StopReason::EndTurn,
+                                                )
+                                            }
+                                            StopReason::MaxTokens => {
+                                                LanguageModelCompletionEvent::Stop(
+                                                    language_model::StopReason::EndTurn,
+                                                )
+                                            }
+                                            StopReason::StopSequence => {
+                                                LanguageModelCompletionEvent::Stop(
+                                                    language_model::StopReason::EndTurn,
+                                                )
+                                            }
+                                            StopReason::ToolUse => {
+                                                LanguageModelCompletionEvent::Stop(
+                                                    language_model::StopReason::ToolUse,
+                                                )
+                                            }
+                                            _ => LanguageModelCompletionEvent::Stop(
+                                                language_model::StopReason::EndTurn,
+                                            ),
+                                        };
+                                        return Some((Some(Ok(reason)), state));
+                                    }
                                     _ => {}
                                 },
+
                                 Err(err) => return Some((Some(Err(anyhow!(err))), state)),
                             }
                         }
@@ -661,6 +977,7 @@ pub fn map_to_language_model_completion_events(
 struct ConfigurationView {
     access_key_id_editor: Entity<Editor>,
     secret_access_key_editor: Entity<Editor>,
+    session_token_editor: Entity<Editor>,
     region_editor: Entity<Editor>,
     state: gpui::Entity<State>,
     load_credentials_task: Option<Task<()>>,
@@ -670,6 +987,7 @@ impl ConfigurationView {
     const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
     const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
         "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
+    const PLACEHOLDER_SESSION_TOKEN_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
     const PLACEHOLDER_REGION: &'static str = "us-east-1";
 
     fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
@@ -707,6 +1025,11 @@ impl ConfigurationView {
                 editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx);
                 editor
             }),
+            session_token_editor: cx.new(|cx| {
+                let mut editor = Editor::single_line(window, cx);
+                editor.set_placeholder_text(Self::PLACEHOLDER_SESSION_TOKEN_TEXT, cx);
+                editor
+            }),
             region_editor: cx.new(|cx| {
                 let mut editor = Editor::single_line(window, cx);
                 editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx);
@@ -737,6 +1060,18 @@ impl ConfigurationView {
             .to_string()
             .trim()
             .to_string();
+        let session_token = self
+            .session_token_editor
+            .read(cx)
+            .text(cx)
+            .to_string()
+            .trim()
+            .to_string();
+        let session_token = if session_token.is_empty() {
+            None
+        } else {
+            Some(session_token)
+        };
         let region = self
             .region_editor
             .read(cx)
@@ -744,15 +1079,21 @@ impl ConfigurationView {
             .to_string()
             .trim()
             .to_string();
+        let region = if region.is_empty() {
+            "us-east-1".to_string()
+        } else {
+            region
+        };
 
         let state = self.state.clone();
         cx.spawn(async move |_, cx| {
             state
                 .update(cx, |state, cx| {
                     let credentials: BedrockCredentials = BedrockCredentials {
+                        region: region.clone(),
                         access_key_id: access_key_id.clone(),
                         secret_access_key: secret_access_key.clone(),
-                        region: region.clone(),
+                        session_token: session_token.clone(),
                     };
 
                     state.set_credentials(credentials, cx)
@@ -767,6 +1108,8 @@ impl ConfigurationView {
             .update(cx, |editor, cx| editor.set_text("", window, cx));
         self.secret_access_key_editor
             .update(cx, |editor, cx| editor.set_text("", window, cx));
+        self.session_token_editor
+            .update(cx, |editor, cx| editor.set_text("", window, cx));
         self.region_editor
             .update(cx, |editor, cx| editor.set_text("", window, cx));
 
@@ -800,7 +1143,102 @@ impl ConfigurationView {
         }
     }
 
-    fn render_aa_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
+    fn make_input_styles(&self, cx: &Context<Self>) -> Div {
+        let bg_color = cx.theme().colors().editor_background;
+        let border_color = cx.theme().colors().border_variant;
+
+        h_flex()
+            .w_full()
+            .px_2()
+            .py_1()
+            .bg(bg_color)
+            .border_1()
+            .border_color(border_color)
+            .rounded_sm()
+    }
+
+    fn should_render_editor(&self, cx: &mut Context<Self>) -> Option<String> {
+        self.state.read(cx).is_authenticated()
+    }
+}
+
+impl Render for ConfigurationView {
+    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+        let env_var_set = self.state.read(cx).credentials_from_env;
+        let creds_type = self.should_render_editor(cx).is_some();
+
+        if self.load_credentials_task.is_some() {
+            return div().child(Label::new("Loading credentials...")).into_any();
+        }
+
+        if let Some(auth) = self.should_render_editor(cx) {
+            return h_flex()
+                .size_full()
+                .justify_between()
+                .child(
+                    h_flex()
+                        .gap_1()
+                        .child(Icon::new(IconName::Check).color(Color::Success))
+                        .child(Label::new(if env_var_set {
+                            format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
+                        } else {
+                            auth.clone()
+                        })),
+                )
+                .child(
+                    Button::new("reset-key", "Reset key")
+                        .icon(Some(IconName::Trash))
+                        .icon_size(IconSize::Small)
+                        .icon_position(IconPosition::Start)
+                        .disabled(env_var_set || creds_type)
+                        .when(env_var_set, |this| {
+                            this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
+                        })
+                        .when(creds_type, |this| {
+                            this.tooltip(Tooltip::text("You cannot reset credentials as they're being derived, check Zed settings to understand how"))
+                        })
+                        .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
+                )
+                .into_any();
+        }
+
+        v_flex()
+            .size_full()
+            .on_action(cx.listener(ConfigurationView::save_credentials))
+            .child(Label::new("To use Zed's assistant with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials."))
+            .child(Label::new("Though to access models on AWS first, you will have to: "))
+            .child(
+                List::new()
+                    .child(
+                        InstructionListItem::new(
+                            "Grant permissions to the strategy you plan to use according to this documentation: ",
+                            Some("Prerequisites"),
+                            Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
+                        )
+                    )
+                    .child(
+                        InstructionListItem::new(
+                            "Select the models you would like access to: ",
+                            Some("Bedrock Model Catalog"),
+                            Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess"),
+                        )
+                    )
+            )
+            .child(self.render_static_credentials_ui(cx))
+            .child(self.render_common_fields(cx))
+            .child(
+                Label::new(
+                    format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR} AND {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed.\n Optionally, if your environment uses AWS CLI profiles, you can set {ZED_AWS_PROFILE_VAR}; if it requires a custom endpoint, you can set {ZED_AWS_ENDPOINT_VAR}; and if it requires a Session Token, you can set {ZED_BEDROCK_SESSION_TOKEN_VAR}."),
+                )
+                    .size(LabelSize::Small)
+                    .color(Color::Muted),
+            )
+            .into_any()
+    }
+}
+
+impl ConfigurationView {
+    fn render_access_key_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
         let text_style = self.make_text_style(cx);
 
         EditorElement::new(
@@ -814,7 +1252,7 @@ impl ConfigurationView {
         )
     }
 
-    fn render_sk_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
+    fn render_secret_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
         let text_style = self.make_text_style(cx);
 
         EditorElement::new(
@@ -828,6 +1266,20 @@ impl ConfigurationView {
         )
     }
 
+    fn render_session_token_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
+        let text_style = self.make_text_style(cx);
+
+        EditorElement::new(
+            &self.session_token_editor,
+            EditorStyle {
+                background: cx.theme().colors().editor_background,
+                local_player: cx.theme().players().local(),
+                text: text_style,
+                ..Default::default()
+            },
+        )
+    }
+
     fn render_region_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
         let text_style = self.make_text_style(cx);
 
@@ -842,124 +1294,80 @@ impl ConfigurationView {
         )
     }
 
-    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
-        !self.state.read(cx).is_authenticated()
+    fn render_static_credentials_ui(&self, cx: &mut Context<Self>) -> AnyElement {
+        v_flex()
+            .my_2()
+            .gap_1p5()
+            .child(
+                Label::new("Static Keys")
+                    .size(LabelSize::Default)
+                    .weight(FontWeight::BOLD),
+            )
+            .child(
+                Label::new(
+                    "This method uses your AWS access key ID and secret access key directly.",
+                )
+                    .size(LabelSize::Small),
+            )
+            .child(
+                List::new()
+                    .child(InstructionListItem::new(
+                        "Create an IAM user in the AWS console with programmatic access",
+                        Some("IAM Console"),
+                        Some("https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users"),
+                    ))
+                    .child(InstructionListItem::new(
+                        "Attach the necessary Bedrock permissions to this ",
+                        Some("user"),
+                        Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
+                    ))
+                    .child(InstructionListItem::text_only(
+                        "Copy the access key ID and secret access key when provided",
+                    ))
+                    .child(InstructionListItem::text_only(
+                        "Enter these credentials below",
+                    )),
+            )
+            .child(
+                v_flex()
+                    .gap_0p5()
+                    .child(Label::new("Access Key ID").size(LabelSize::Small))
+                    .child(
+                        self.make_input_styles(cx)
+                            .child(self.render_access_key_id_editor(cx)),
+                    ),
+            )
+            .child(
+                v_flex()
+                    .gap_0p5()
+                    .child(Label::new("Secret Access Key").size(LabelSize::Small))
+                    .child(self.make_input_styles(cx).child(self.render_secret_key_editor(cx))),
+            )
+            .child(
+                v_flex()
+                    .gap_0p5()
+                    .child(Label::new("Session Token (Optional)").size(LabelSize::Small))
+                    .child(
+                        self.make_input_styles(cx)
+                            .child(self.render_session_token_editor(cx)),
+                    ),
+            )
+            .into_any_element()
     }
-}
 
-impl Render for ConfigurationView {
-    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
-        let env_var_set = self.state.read(cx).credentials_from_env;
-        let bg_color = cx.theme().colors().editor_background;
-        let border_color = cx.theme().colors().border_variant;
-        let input_base_styles = || {
-            h_flex()
-                .w_full()
-                .px_2()
-                .py_1()
-                .bg(bg_color)
-                .border_1()
-                .border_color(border_color)
-                .rounded_sm()
-        };
-
-        if self.load_credentials_task.is_some() {
-            div().child(Label::new("Loading credentials...")).into_any()
-        } else if self.should_render_editor(cx) {
-            v_flex()
-                .size_full()
-                .on_action(cx.listener(ConfigurationView::save_credentials))
-                .child(Label::new("To use Zed's assistant with Bedrock, you need to add the Access Key ID, Secret Access Key and AWS Region. Follow these steps:"))
-                .child(
-                    List::new()
-                        .child(
-                            InstructionListItem::new(
-                                "Start by",
-                                Some("creating a user and security credentials"),
-                                Some("https://us-east-1.console.aws.amazon.com/iam/home")
-                            )
-                        )
-                        .child(
-                            InstructionListItem::new(
-                                "Grant that user permissions according to this documentation:",
-                                Some("Prerequisites"),
-                                Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html")
-                            )
-                        )
-                        .child(
-                            InstructionListItem::new(
-                                "Select the models you would like access to:",
-                                Some("Bedrock Model Catalog"),
-                                Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess")
-                            )
-                        )
-                        .child(
-                            InstructionListItem::text_only("Fill the fields below and hit enter to start using the assistant")
-                        )
-                )
-                .child(
-                    v_flex()
-                        .my_2()
-                        .gap_1p5()
-                        .child(
-                            v_flex()
-                                .gap_0p5()
-                                .child(Label::new("Access Key ID").size(LabelSize::Small))
-                                .child(
-                                    input_base_styles().child(self.render_aa_id_editor(cx))
-                                )
-                        )
-                        .child(
-                            v_flex()
-                                .gap_0p5()
-                                .child(Label::new("Secret Access Key").size(LabelSize::Small))
-                                .child(
-                                    input_base_styles().child(self.render_sk_editor(cx))
-                                )
-                        )
-                        .child(
-                            v_flex()
-                                .gap_0p5()
-                                .child(Label::new("Region").size(LabelSize::Small))
-                                .child(
-                                    input_base_styles().child(self.render_region_editor(cx))
-                                )
-                            )
-                )
-                .child(
-                    Label::new(
-                        format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."),
-                    )
-                        .size(LabelSize::Small)
-                        .color(Color::Muted),
-                )
-                .into_any()
-        } else {
-            h_flex()
-                .size_full()
-                .justify_between()
-                .child(
-                    h_flex()
-                        .gap_1()
-                        .child(Icon::new(IconName::Check).color(Color::Success))
-                        .child(Label::new(if env_var_set {
-                            format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
-                        } else {
-                            "Credentials configured.".to_string()
-                        })),
-                )
-                .child(
-                    Button::new("reset-key", "Reset key")
-                        .icon(Some(IconName::Trash))
-                        .icon_size(IconSize::Small)
-                        .icon_position(IconPosition::Start)
-                        .disabled(env_var_set)
-                        .when(env_var_set, |this| {
-                            this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
-                        })
-                        .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
-                )
-                .into_any()
-        }
+    fn render_common_fields(&self, cx: &mut Context<Self>) -> AnyElement {
+        v_flex()
+            .my_2()
+            .gap_1p5()
+            .child(
+                v_flex()
+                    .gap_0p5()
+                    .child(Label::new("Region").size(LabelSize::Small))
+                    .child(
+                        self.make_input_styles(cx)
+                            .child(self.render_region_editor(cx)),
+                    ),
+            )
+            .into_any_element()
     }
 }

crates/language_models/src/settings.rs 🔗

@@ -72,6 +72,7 @@ pub struct AllLanguageModelSettings {
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct AllLanguageModelSettingsContent {
     pub anthropic: Option<AnthropicSettingsContent>,
+    pub bedrock: Option<AmazonBedrockSettingsContent>,
     pub ollama: Option<OllamaSettingsContent>,
     pub lmstudio: Option<LmStudioSettingsContent>,
     pub openai: Option<OpenAiSettingsContent>,
@@ -160,6 +161,15 @@ pub struct AnthropicSettingsContentV1 {
     pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
 }
 
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct AmazonBedrockSettingsContent {
+    available_models: Option<Vec<provider::bedrock::AvailableModel>>,
+    endpoint_url: Option<String>,
+    region: Option<String>,
+    profile: Option<String>,
+    authentication_method: Option<provider::bedrock::BedrockAuthMethod>,
+}
+
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct OllamaSettingsContent {
     pub api_url: Option<String>,
@@ -297,6 +307,25 @@ impl settings::Settings for AllLanguageModelSettings {
                 anthropic.as_ref().and_then(|s| s.available_models.clone()),
             );
 
+            // Bedrock
+            let bedrock = value.bedrock.clone();
+            merge(
+                &mut settings.bedrock.profile_name,
+                bedrock.as_ref().map(|s| s.profile.clone()),
+            );
+            merge(
+                &mut settings.bedrock.authentication_method,
+                bedrock.as_ref().map(|s| s.authentication_method.clone()),
+            );
+            merge(
+                &mut settings.bedrock.region,
+                bedrock.as_ref().map(|s| s.region.clone()),
+            );
+            merge(
+                &mut settings.bedrock.endpoint,
+                bedrock.as_ref().map(|s| s.endpoint_url.clone()),
+            );
+
             // Ollama
             let ollama = value.ollama.clone();