bedrock.rs

   1use std::pin::Pin;
   2use std::sync::Arc;
   3
   4use anyhow::{Context as _, Result, anyhow};
   5use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
   6use aws_config::{BehaviorVersion, Region};
   7use aws_credential_types::{Credentials, Token};
   8use aws_http_client::AwsHttpClient;
   9use bedrock::bedrock_client::Client as BedrockClient;
  10use bedrock::bedrock_client::config::timeout::TimeoutConfig;
  11use bedrock::bedrock_client::types::{
  12    CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
  13    ReasoningContentBlockDelta, StopReason,
  14};
  15use bedrock::{
  16    BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockImageBlock,
  17    BedrockImageFormat, BedrockImageSource, BedrockInnerContent, BedrockMessage, BedrockModelMode,
  18    BedrockStreamingResponse, BedrockThinkingBlock, BedrockThinkingTextBlock, BedrockTool,
  19    BedrockToolChoice, BedrockToolConfig, BedrockToolInputSchema, BedrockToolResultBlock,
  20    BedrockToolResultContentBlock, BedrockToolResultStatus, BedrockToolSpec, BedrockToolUseBlock,
  21    Model, value_to_aws_document,
  22};
  23use collections::{BTreeMap, HashMap};
  24use credentials_provider::CredentialsProvider;
  25use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
  26use gpui::{
  27    AnyView, App, AsyncApp, Context, Entity, FocusHandle, Subscription, Task, TaskExt, Window,
  28    actions,
  29};
  30use gpui_tokio::Tokio;
  31use http_client::HttpClient;
  32use language_model::{
  33    AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
  34    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
  35    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  36    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
  37    LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, RateLimiter, Role,
  38    TokenUsage, env_var,
  39};
  40use schemars::JsonSchema;
  41use serde::{Deserialize, Serialize};
  42use serde_json::Value;
  43use settings::{BedrockAvailableModel as AvailableModel, Settings, SettingsStore};
  44use smol::lock::OnceCell;
  45use std::sync::LazyLock;
  46use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
  47use ui::{ButtonLink, ConfiguredApiCard, Divider, List, ListBulletItem, prelude::*};
  48use ui_input::InputField;
  49use util::ResultExt;
  50
  51use crate::AllLanguageModelSettings;
  52use language_model::util::{fix_streamed_json, parse_tool_arguments};
  53
// Actions registered in the `bedrock` namespace. They are not referenced in the visible part
// of this file — NOTE(review): presumably used by the configuration view for focus cycling;
// confirm.
actions!(bedrock, [Tab, TabPrev]);

/// Identifier returned by `id()` / `provider_id()` for this provider.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock");
/// Human-readable provider name; also used as the `provider` of completion errors and in
/// credential error messages.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock");
  58
  59/// Credentials stored in the keychain for static authentication.
  60/// Region is handled separately since it's orthogonal to auth method.
  61#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
  62pub struct BedrockCredentials {
  63    pub access_key_id: String,
  64    pub secret_access_key: String,
  65    pub session_token: Option<String>,
  66    pub bearer_token: Option<String>,
  67}
  68
  69/// Resolved authentication configuration for Bedrock.
  70/// Settings take priority over UX-provided credentials.
  71#[derive(Clone, Debug, PartialEq)]
  72pub enum BedrockAuth {
  73    /// Use default AWS credential provider chain (IMDSv2, PodIdentity, env vars, etc.)
  74    Automatic,
  75    /// Use AWS named profile from ~/.aws/credentials or ~/.aws/config
  76    NamedProfile { profile_name: String },
  77    /// Use AWS SSO profile
  78    SingleSignOn { profile_name: String },
  79    /// Use IAM credentials (access key + secret + optional session token)
  80    IamCredentials {
  81        access_key_id: String,
  82        secret_access_key: String,
  83        session_token: Option<String>,
  84    },
  85    /// Use Bedrock API Key (bearer token authentication)
  86    ApiKey { api_key: String },
  87}
  88
  89impl BedrockCredentials {
  90    /// Convert stored credentials to the appropriate auth variant.
  91    /// Prefers API key if present, otherwise uses IAM credentials.
  92    fn into_auth(self) -> Option<BedrockAuth> {
  93        if let Some(api_key) = self.bearer_token.filter(|t| !t.is_empty()) {
  94            Some(BedrockAuth::ApiKey { api_key })
  95        } else if !self.access_key_id.is_empty() && !self.secret_access_key.is_empty() {
  96            Some(BedrockAuth::IamCredentials {
  97                access_key_id: self.access_key_id,
  98                secret_access_key: self.secret_access_key,
  99                session_token: self.session_token.filter(|t| !t.is_empty()),
 100            })
 101        } else {
 102            None
 103        }
 104    }
 105}
 106
/// Snapshot of the Bedrock section of the language-model settings
/// (`AllLanguageModelSettings::bedrock`).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AmazonBedrockSettings {
    /// Custom models; they are added to the built-in list and replace built-ins with the same id.
    pub available_models: Vec<AvailableModel>,
    /// AWS region. `ZED_AWS_REGION` takes precedence; `us-east-1` is used when neither is set.
    pub region: Option<String>,
    /// Optional endpoint URL override for the Bedrock client; ignored when empty.
    pub endpoint: Option<String>,
    /// Profile for the named-profile and SSO auth methods (defaults to `"default"`).
    pub profile_name: Option<String>,
    /// NOTE(review): not read anywhere in the visible part of this file — confirm it is used.
    pub role_arn: Option<String>,
    /// Auth method forced by settings. `api_key` loads static credentials from env vars /
    /// keychain; every other method is used directly.
    pub authentication_method: Option<BedrockAuthMethod>,
    /// Whether global cross-region inference may be used (defaults to `false`).
    pub allow_global: Option<bool>,
    /// Whether extended context may be requested for models that support it (defaults to `false`).
    pub allow_extended_context: Option<bool>,
}
 118
// Authentication method that can be forced from settings; the serialized names are the
// `serde(rename)` values below. Plain `//` comments are used for new notes because `///` doc
// comments would become JSON-schema descriptions via the `JsonSchema` derive.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, EnumIter, IntoStaticStr, JsonSchema)]
pub enum BedrockAuthMethod {
    // Resolved to `BedrockAuth::NamedProfile` with the settings `profile_name`.
    #[serde(rename = "named_profile")]
    NamedProfile,
    // Resolved to `BedrockAuth::SingleSignOn` with the settings `profile_name`.
    #[serde(rename = "sso")]
    SingleSignOn,
    // Loads static credentials (API key or IAM keys) from env vars or the keychain.
    #[serde(rename = "api_key")]
    ApiKey,
    /// IMDSv2, PodIdentity, env vars, etc.
    #[serde(rename = "default")]
    Automatic,
}
 131
 132impl From<settings::BedrockAuthMethodContent> for BedrockAuthMethod {
 133    fn from(value: settings::BedrockAuthMethodContent) -> Self {
 134        match value {
 135            settings::BedrockAuthMethodContent::SingleSignOn => BedrockAuthMethod::SingleSignOn,
 136            settings::BedrockAuthMethodContent::Automatic => BedrockAuthMethod::Automatic,
 137            settings::BedrockAuthMethodContent::NamedProfile => BedrockAuthMethod::NamedProfile,
 138            settings::BedrockAuthMethodContent::ApiKey => BedrockAuthMethod::ApiKey,
 139        }
 140    }
 141}
 142
// Reasoning mode for a model. Serialized as an internally tagged enum keyed by `"type"`, with
// lowercased variant names: "default", "thinking", "adaptivethinking". (Plain `//` comments are
// used for new notes so the `JsonSchema` output is unchanged.)
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    // No explicit reasoning configuration.
    #[default]
    Default,
    // Extended thinking with an optional token budget.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u64>,
    },
    // Adaptive thinking driven by an effort level instead of a token budget.
    AdaptiveThinking {
        effort: bedrock::BedrockAdaptiveThinkingEffort,
    },
}
 156
 157impl From<ModelMode> for BedrockModelMode {
 158    fn from(value: ModelMode) -> Self {
 159        match value {
 160            ModelMode::Default => BedrockModelMode::Default,
 161            ModelMode::Thinking { budget_tokens } => BedrockModelMode::Thinking { budget_tokens },
 162            ModelMode::AdaptiveThinking { effort } => BedrockModelMode::AdaptiveThinking { effort },
 163        }
 164    }
 165}
 166
 167impl From<BedrockModelMode> for ModelMode {
 168    fn from(value: BedrockModelMode) -> Self {
 169        match value {
 170            BedrockModelMode::Default => ModelMode::Default,
 171            BedrockModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
 172            BedrockModelMode::AdaptiveThinking { effort } => ModelMode::AdaptiveThinking { effort },
 173        }
 174    }
 175}
 176
/// The URL of the base AWS service.
///
/// Right now we're just using this as the key to store the AWS credentials
/// under in the keychain: it is the URL passed to `read_credentials`,
/// `write_credentials` and `delete_credentials`.
const AMAZON_AWS_URL: &str = "https://amazonaws.com";
 182
// These environment variables all use a `ZED_` prefix so that configuring Zed does not collide
// with (or overwrite) the AWS credentials the user already has in their environment
// (presumably the standard `AWS_*` variables used by other tools).
// Note: the static names do not always match the variable names, e.g.
// `ZED_BEDROCK_ACCESS_KEY_ID_VAR` reads `ZED_ACCESS_KEY_ID`.
/// IAM access key id (`ZED_ACCESS_KEY_ID`); only used together with the secret access key.
static ZED_BEDROCK_ACCESS_KEY_ID_VAR: LazyLock<EnvVar> = env_var!("ZED_ACCESS_KEY_ID");
/// IAM secret access key (`ZED_SECRET_ACCESS_KEY`).
static ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: LazyLock<EnvVar> = env_var!("ZED_SECRET_ACCESS_KEY");
/// Optional session token for temporary IAM credentials (`ZED_SESSION_TOKEN`).
static ZED_BEDROCK_SESSION_TOKEN_VAR: LazyLock<EnvVar> = env_var!("ZED_SESSION_TOKEN");
/// AWS profile name (`ZED_AWS_PROFILE`). NOTE(review): not read in the visible part of this file.
static ZED_AWS_PROFILE_VAR: LazyLock<EnvVar> = env_var!("ZED_AWS_PROFILE");
/// Region override (`ZED_AWS_REGION`); takes precedence over the settings region.
static ZED_BEDROCK_REGION_VAR: LazyLock<EnvVar> = env_var!("ZED_AWS_REGION");
/// Endpoint override (`ZED_AWS_ENDPOINT`). NOTE(review): not read in the visible part of this file.
static ZED_AWS_ENDPOINT_VAR: LazyLock<EnvVar> = env_var!("ZED_AWS_ENDPOINT");
/// Bedrock API key (`ZED_BEDROCK_BEARER_TOKEN`); preferred over the IAM variables when set.
static ZED_BEDROCK_BEARER_TOKEN_VAR: LazyLock<EnvVar> = env_var!("ZED_BEDROCK_BEARER_TOKEN");
 191
/// Shared, observable state of the Bedrock provider: the resolved auth plus a settings snapshot.
pub struct State {
    /// The resolved authentication method. Settings take priority over UX credentials.
    auth: Option<BedrockAuth>,
    /// Snapshot of the Bedrock settings from settings.json (`AllLanguageModelSettings::bedrock`).
    settings: Option<AmazonBedrockSettings>,
    /// Whether credentials came from environment variables (only relevant for static credentials)
    credentials_from_env: bool,
    /// Keychain access used to read, write and delete the stored static credentials.
    credentials_provider: Arc<dyn CredentialsProvider>,
    /// Keeps the `SettingsStore` observer registered for as long as this state lives.
    _subscription: Subscription,
}
 202
 203impl State {
 204    fn reset_auth(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 205        let credentials_provider = self.credentials_provider.clone();
 206        cx.spawn(async move |this, cx| {
 207            credentials_provider
 208                .delete_credentials(AMAZON_AWS_URL, cx)
 209                .await
 210                .log_err();
 211            this.update(cx, |this, cx| {
 212                this.auth = None;
 213                this.credentials_from_env = false;
 214                cx.notify();
 215            })
 216        })
 217    }
 218
 219    fn set_static_credentials(
 220        &mut self,
 221        credentials: BedrockCredentials,
 222        cx: &mut Context<Self>,
 223    ) -> Task<Result<()>> {
 224        let auth = credentials.clone().into_auth();
 225        let credentials_provider = self.credentials_provider.clone();
 226        cx.spawn(async move |this, cx| {
 227            credentials_provider
 228                .write_credentials(
 229                    AMAZON_AWS_URL,
 230                    "Bearer",
 231                    &serde_json::to_vec(&credentials)?,
 232                    cx,
 233                )
 234                .await?;
 235            this.update(cx, |this, cx| {
 236                this.auth = auth;
 237                this.credentials_from_env = false;
 238                cx.notify();
 239            })
 240        })
 241    }
 242
 243    fn is_authenticated(&self) -> bool {
 244        self.auth.is_some()
 245    }
 246
 247    /// Resolve authentication. Settings take priority over UX-provided credentials.
 248    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 249        if self.is_authenticated() {
 250            return Task::ready(Ok(()));
 251        }
 252
 253        // Step 1: Check if settings specify an auth method (enterprise control)
 254        if let Some(settings) = &self.settings {
 255            if let Some(method) = &settings.authentication_method {
 256                let profile_name = settings
 257                    .profile_name
 258                    .clone()
 259                    .unwrap_or_else(|| "default".to_string());
 260
 261                let auth = match method {
 262                    BedrockAuthMethod::Automatic => BedrockAuth::Automatic,
 263                    BedrockAuthMethod::NamedProfile => BedrockAuth::NamedProfile { profile_name },
 264                    BedrockAuthMethod::SingleSignOn => BedrockAuth::SingleSignOn { profile_name },
 265                    BedrockAuthMethod::ApiKey => {
 266                        // ApiKey method means "use static credentials from keychain/env"
 267                        // Fall through to load them below
 268                        return self.load_static_credentials(cx);
 269                    }
 270                };
 271
 272                return cx.spawn(async move |this, cx| {
 273                    this.update(cx, |this, cx| {
 274                        this.auth = Some(auth);
 275                        this.credentials_from_env = false;
 276                        cx.notify();
 277                    })?;
 278                    Ok(())
 279                });
 280            }
 281        }
 282
 283        // Step 2: No settings auth method - try to load static credentials
 284        self.load_static_credentials(cx)
 285    }
 286
 287    /// Load static credentials from environment variables or keychain.
 288    fn load_static_credentials(
 289        &self,
 290        cx: &mut Context<Self>,
 291    ) -> Task<Result<(), AuthenticateError>> {
 292        let credentials_provider = self.credentials_provider.clone();
 293        cx.spawn(async move |this, cx| {
 294            // Try environment variables first
 295            let (auth, from_env) = if let Some(bearer_token) = &ZED_BEDROCK_BEARER_TOKEN_VAR.value {
 296                if !bearer_token.is_empty() {
 297                    (
 298                        Some(BedrockAuth::ApiKey {
 299                            api_key: bearer_token.to_string(),
 300                        }),
 301                        true,
 302                    )
 303                } else {
 304                    (None, false)
 305                }
 306            } else if let Some(access_key_id) = &ZED_BEDROCK_ACCESS_KEY_ID_VAR.value {
 307                if let Some(secret_access_key) = &ZED_BEDROCK_SECRET_ACCESS_KEY_VAR.value {
 308                    if !access_key_id.is_empty() && !secret_access_key.is_empty() {
 309                        let session_token = ZED_BEDROCK_SESSION_TOKEN_VAR
 310                            .value
 311                            .as_deref()
 312                            .filter(|s| !s.is_empty())
 313                            .map(|s| s.to_string());
 314                        (
 315                            Some(BedrockAuth::IamCredentials {
 316                                access_key_id: access_key_id.to_string(),
 317                                secret_access_key: secret_access_key.to_string(),
 318                                session_token,
 319                            }),
 320                            true,
 321                        )
 322                    } else {
 323                        (None, false)
 324                    }
 325                } else {
 326                    (None, false)
 327                }
 328            } else {
 329                (None, false)
 330            };
 331
 332            // If we got auth from env vars, use it
 333            if let Some(auth) = auth {
 334                this.update(cx, |this, cx| {
 335                    this.auth = Some(auth);
 336                    this.credentials_from_env = from_env;
 337                    cx.notify();
 338                })?;
 339                return Ok(());
 340            }
 341
 342            // Try keychain
 343            let (_, credentials_bytes) = credentials_provider
 344                .read_credentials(AMAZON_AWS_URL, cx)
 345                .await?
 346                .ok_or(AuthenticateError::CredentialsNotFound)?;
 347
 348            let credentials_str = String::from_utf8(credentials_bytes)
 349                .with_context(|| format!("invalid {PROVIDER_NAME} credentials"))?;
 350
 351            let credentials: BedrockCredentials =
 352                serde_json::from_str(&credentials_str).context("failed to parse credentials")?;
 353
 354            let auth = credentials
 355                .into_auth()
 356                .ok_or(AuthenticateError::CredentialsNotFound)?;
 357
 358            this.update(cx, |this, cx| {
 359                this.auth = Some(auth);
 360                this.credentials_from_env = false;
 361                cx.notify();
 362            })?;
 363
 364            Ok(())
 365        })
 366    }
 367
 368    /// Get the resolved region. Checks env var, then settings, then defaults to us-east-1.
 369    fn get_region(&self) -> String {
 370        // Priority: env var > settings > default
 371        if let Some(region) = ZED_BEDROCK_REGION_VAR.value.as_deref() {
 372            if !region.is_empty() {
 373                return region.to_string();
 374            }
 375        }
 376
 377        self.settings
 378            .as_ref()
 379            .and_then(|s| s.region.clone())
 380            .unwrap_or_else(|| "us-east-1".to_string())
 381    }
 382
 383    fn get_allow_global(&self) -> bool {
 384        self.settings
 385            .as_ref()
 386            .and_then(|s| s.allow_global)
 387            .unwrap_or(false)
 388    }
 389
 390    fn get_allow_extended_context(&self) -> bool {
 391        self.settings
 392            .as_ref()
 393            .and_then(|s| s.allow_extended_context)
 394            .unwrap_or(false)
 395    }
 396}
 397
/// Language-model provider for Amazon Bedrock.
pub struct BedrockLanguageModelProvider {
    /// HTTP client handed to each model and passed to the AWS SDK config.
    http_client: AwsHttpClient,
    /// Tokio runtime handle handed to each model (used to block on AWS config loading).
    handle: tokio::runtime::Handle,
    /// Shared auth/settings state, observed by the UI and read by every model.
    state: Entity<State>,
}
 403
 404impl BedrockLanguageModelProvider {
 405    pub fn new(
 406        http_client: Arc<dyn HttpClient>,
 407        credentials_provider: Arc<dyn CredentialsProvider>,
 408        cx: &mut App,
 409    ) -> Self {
 410        let state = cx.new(|cx| State {
 411            auth: None,
 412            settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
 413            credentials_from_env: false,
 414            credentials_provider,
 415            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 416                cx.notify();
 417            }),
 418        });
 419
 420        Self {
 421            http_client: AwsHttpClient::new(http_client),
 422            handle: Tokio::handle(cx),
 423            state,
 424        }
 425    }
 426
 427    fn create_language_model(&self, model: bedrock::Model) -> Arc<dyn LanguageModel> {
 428        Arc::new(BedrockModel {
 429            id: LanguageModelId::from(model.id().to_string()),
 430            model,
 431            http_client: self.http_client.clone(),
 432            handle: self.handle.clone(),
 433            state: self.state.clone(),
 434            client: OnceCell::new(),
 435            request_limiter: RateLimiter::new(4),
 436        })
 437    }
 438}
 439
 440impl LanguageModelProvider for BedrockLanguageModelProvider {
 441    fn id(&self) -> LanguageModelProviderId {
 442        PROVIDER_ID
 443    }
 444
 445    fn name(&self) -> LanguageModelProviderName {
 446        PROVIDER_NAME
 447    }
 448
 449    fn icon(&self) -> IconOrSvg {
 450        IconOrSvg::Icon(IconName::AiBedrock)
 451    }
 452
 453    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 454        Some(self.create_language_model(bedrock::Model::default()))
 455    }
 456
 457    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
 458        let region = self.state.read(cx).get_region();
 459        Some(self.create_language_model(bedrock::Model::default_fast(region.as_str())))
 460    }
 461
 462    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 463        let mut models = BTreeMap::default();
 464
 465        for model in bedrock::Model::iter() {
 466            if !matches!(model, bedrock::Model::Custom { .. }) {
 467                models.insert(model.id().to_string(), model);
 468            }
 469        }
 470
 471        // Override with available models from settings
 472        for model in AllLanguageModelSettings::get_global(cx)
 473            .bedrock
 474            .available_models
 475            .iter()
 476        {
 477            models.insert(
 478                model.name.clone(),
 479                bedrock::Model::Custom {
 480                    name: model.name.clone(),
 481                    display_name: model.display_name.clone(),
 482                    max_tokens: model.max_tokens,
 483                    max_output_tokens: model.max_output_tokens,
 484                    default_temperature: model.default_temperature,
 485                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
 486                        bedrock::BedrockModelCacheConfiguration {
 487                            max_cache_anchors: config.max_cache_anchors,
 488                            min_total_token: config.min_total_token,
 489                        }
 490                    }),
 491                },
 492            );
 493        }
 494
 495        models
 496            .into_values()
 497            .map(|model| self.create_language_model(model))
 498            .collect()
 499    }
 500
 501    fn is_authenticated(&self, cx: &App) -> bool {
 502        self.state.read(cx).is_authenticated()
 503    }
 504
 505    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 506        self.state.update(cx, |state, cx| state.authenticate(cx))
 507    }
 508
 509    fn configuration_view(
 510        &self,
 511        _target_agent: language_model::ConfigurationViewTargetAgent,
 512        window: &mut Window,
 513        cx: &mut App,
 514    ) -> AnyView {
 515        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 516            .into()
 517    }
 518
 519    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 520        self.state.update(cx, |state, cx| state.reset_auth(cx))
 521    }
 522}
 523
 524impl LanguageModelProviderState for BedrockLanguageModelProvider {
 525    type ObservableEntity = State;
 526
 527    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 528        Some(self.state.clone())
 529    }
 530}
 531
/// One Bedrock model exposed to Zed, sharing the provider's HTTP client, runtime and state.
struct BedrockModel {
    /// Id reported to Zed, built from the Bedrock model id.
    id: LanguageModelId,
    /// The underlying Bedrock model description.
    model: Model,
    /// HTTP client passed to the AWS SDK config when the client is built.
    http_client: AwsHttpClient,
    /// Tokio runtime handle used to block on AWS config loading in `get_or_init_client`.
    handle: tokio::runtime::Handle,
    /// Bedrock client, built on first use by `get_or_init_client` and reused afterwards.
    client: OnceCell<BedrockClient>,
    /// Shared provider state (auth, settings) read when building the client and requests.
    state: Entity<State>,
    /// Wraps each completion request; the provider creates it with `RateLimiter::new(4)`
    /// (presumably a cap of 4 concurrent requests — confirm).
    request_limiter: RateLimiter,
}
 541
 542impl BedrockModel {
 543    fn get_or_init_client(&self, cx: &AsyncApp) -> anyhow::Result<&BedrockClient> {
 544        self.client
 545            .get_or_try_init_blocking(|| {
 546                let (auth, endpoint, region) = cx.read_entity(&self.state, |state, _cx| {
 547                    let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone());
 548                    let region = state.get_region();
 549                    (state.auth.clone(), endpoint, region)
 550                });
 551
 552                let mut config_builder = aws_config::defaults(BehaviorVersion::latest())
 553                    .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
 554                    .http_client(self.http_client.clone())
 555                    .region(Region::new(region))
 556                    .timeout_config(TimeoutConfig::disabled());
 557
 558                if let Some(endpoint_url) = endpoint
 559                    && !endpoint_url.is_empty()
 560                {
 561                    config_builder = config_builder.endpoint_url(endpoint_url);
 562                }
 563
 564                match auth {
 565                    Some(BedrockAuth::Automatic) | None => {
 566                        // Use default AWS credential provider chain
 567                    }
 568                    Some(BedrockAuth::NamedProfile { profile_name })
 569                    | Some(BedrockAuth::SingleSignOn { profile_name }) => {
 570                        if !profile_name.is_empty() {
 571                            config_builder = config_builder.profile_name(profile_name);
 572                        }
 573                    }
 574                    Some(BedrockAuth::IamCredentials {
 575                        access_key_id,
 576                        secret_access_key,
 577                        session_token,
 578                    }) => {
 579                        let aws_creds = Credentials::new(
 580                            access_key_id,
 581                            secret_access_key,
 582                            session_token,
 583                            None,
 584                            "zed-bedrock-provider",
 585                        );
 586                        config_builder = config_builder.credentials_provider(aws_creds);
 587                    }
 588                    Some(BedrockAuth::ApiKey { api_key }) => {
 589                        config_builder = config_builder
 590                            .auth_scheme_preference(["httpBearerAuth".into()]) // https://github.com/smithy-lang/smithy-rs/pull/4241
 591                            .token_provider(Token::new(api_key, None));
 592                    }
 593                }
 594
 595                let config = self.handle.block_on(config_builder.load());
 596
 597                anyhow::Ok(BedrockClient::new(&config))
 598            })
 599            .context("initializing Bedrock client")?;
 600
 601        self.client.get().context("Bedrock client not initialized")
 602    }
 603
 604    fn stream_completion(
 605        &self,
 606        request: bedrock::Request,
 607        cx: &AsyncApp,
 608    ) -> BoxFuture<
 609        'static,
 610        Result<BoxStream<'static, Result<BedrockStreamingResponse, anyhow::Error>>, BedrockError>,
 611    > {
 612        let Ok(runtime_client) = self
 613            .get_or_init_client(cx)
 614            .cloned()
 615            .context("Bedrock client not initialized")
 616        else {
 617            return futures::future::ready(Err(BedrockError::Other(anyhow!("App state dropped"))))
 618                .boxed();
 619        };
 620
 621        let task = Tokio::spawn(cx, bedrock::stream_completion(runtime_client, request));
 622        async move { task.await.map_err(|e| BedrockError::Other(e.into()))? }.boxed()
 623    }
 624}
 625
 626impl LanguageModel for BedrockModel {
 627    fn id(&self) -> LanguageModelId {
 628        self.id.clone()
 629    }
 630
 631    fn name(&self) -> LanguageModelName {
 632        LanguageModelName::from(self.model.display_name().to_string())
 633    }
 634
 635    fn provider_id(&self) -> LanguageModelProviderId {
 636        PROVIDER_ID
 637    }
 638
 639    fn provider_name(&self) -> LanguageModelProviderName {
 640        PROVIDER_NAME
 641    }
 642
 643    fn supports_tools(&self) -> bool {
 644        self.model.supports_tool_use()
 645    }
 646
 647    fn supports_images(&self) -> bool {
 648        self.model.supports_images()
 649    }
 650
 651    fn supports_thinking(&self) -> bool {
 652        self.model.supports_thinking()
 653    }
 654
 655    fn supported_effort_levels(&self) -> Vec<language_model::LanguageModelEffortLevel> {
 656        if self.model.supports_adaptive_thinking() {
 657            vec![
 658                language_model::LanguageModelEffortLevel {
 659                    name: "Low".into(),
 660                    value: "low".into(),
 661                    is_default: false,
 662                },
 663                language_model::LanguageModelEffortLevel {
 664                    name: "Medium".into(),
 665                    value: "medium".into(),
 666                    is_default: false,
 667                },
 668                language_model::LanguageModelEffortLevel {
 669                    name: "High".into(),
 670                    value: "high".into(),
 671                    is_default: true,
 672                },
 673                language_model::LanguageModelEffortLevel {
 674                    name: "Max".into(),
 675                    value: "max".into(),
 676                    is_default: false,
 677                },
 678            ]
 679        } else {
 680            Vec::new()
 681        }
 682    }
 683
 684    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 685        match choice {
 686            LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => {
 687                self.model.supports_tool_use()
 688            }
 689            // Add support for None - we'll filter tool calls at response
 690            LanguageModelToolChoice::None => self.model.supports_tool_use(),
 691        }
 692    }
 693
 694    fn supports_streaming_tools(&self) -> bool {
 695        true
 696    }
 697
 698    fn telemetry_id(&self) -> String {
 699        format!("bedrock/{}", self.model.id())
 700    }
 701
 702    fn max_token_count(&self) -> u64 {
 703        self.model.max_token_count()
 704    }
 705
 706    fn max_output_tokens(&self) -> Option<u64> {
 707        Some(self.model.max_output_tokens())
 708    }
 709
 710    fn count_tokens(
 711        &self,
 712        request: LanguageModelRequest,
 713        cx: &App,
 714    ) -> BoxFuture<'static, Result<u64>> {
 715        get_bedrock_tokens(request, cx)
 716    }
 717
 718    fn stream_completion(
 719        &self,
 720        request: LanguageModelRequest,
 721        cx: &AsyncApp,
 722    ) -> BoxFuture<
 723        'static,
 724        Result<
 725            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 726            LanguageModelCompletionError,
 727        >,
 728    > {
 729        let (region, allow_global, allow_extended_context) =
 730            cx.read_entity(&self.state, |state, _cx| {
 731                (
 732                    state.get_region(),
 733                    state.get_allow_global(),
 734                    state.get_allow_extended_context(),
 735                )
 736            });
 737
 738        let model_id = match self.model.cross_region_inference_id(&region, allow_global) {
 739            Ok(s) => s,
 740            Err(e) => {
 741                return async move { Err(e.into()) }.boxed();
 742            }
 743        };
 744
 745        let deny_tool_calls = request.tool_choice == Some(LanguageModelToolChoice::None);
 746
 747        let use_extended_context = allow_extended_context && self.model.supports_extended_context();
 748
 749        let request = match into_bedrock(
 750            request,
 751            model_id,
 752            self.model.default_temperature(),
 753            self.model.max_output_tokens(),
 754            self.model.thinking_mode(),
 755            self.model.supports_caching(),
 756            self.model.supports_tool_use(),
 757            use_extended_context,
 758        ) {
 759            Ok(request) => request,
 760            Err(err) => return futures::future::ready(Err(err.into())).boxed(),
 761        };
 762
 763        let request = self.stream_completion(request, cx);
 764        let display_name = self.model.display_name().to_string();
 765        let future = self.request_limiter.stream(async move {
 766            let response = request.await.map_err(|err| match err {
 767                BedrockError::Validation(ref msg) => {
 768                    if msg.contains("model identifier is invalid") {
 769                        LanguageModelCompletionError::Other(anyhow!(
 770                            "{display_name} is not available in {region}. \
 771                                 Try switching to a region where this model is supported."
 772                        ))
 773                    } else {
 774                        LanguageModelCompletionError::BadRequestFormat {
 775                            provider: PROVIDER_NAME,
 776                            message: msg.clone(),
 777                        }
 778                    }
 779                }
 780                BedrockError::RateLimited => LanguageModelCompletionError::RateLimitExceeded {
 781                    provider: PROVIDER_NAME,
 782                    retry_after: None,
 783                },
 784                BedrockError::ServiceUnavailable => {
 785                    LanguageModelCompletionError::ServerOverloaded {
 786                        provider: PROVIDER_NAME,
 787                        retry_after: None,
 788                    }
 789                }
 790                BedrockError::AccessDenied(msg) => LanguageModelCompletionError::PermissionError {
 791                    provider: PROVIDER_NAME,
 792                    message: msg,
 793                },
 794                BedrockError::InternalServer(msg) => {
 795                    LanguageModelCompletionError::ApiInternalServerError {
 796                        provider: PROVIDER_NAME,
 797                        message: msg,
 798                    }
 799                }
 800                other => LanguageModelCompletionError::Other(anyhow!(other)),
 801            })?;
 802            let events = map_to_language_model_completion_events(response);
 803
 804            if deny_tool_calls {
 805                Ok(deny_tool_use_events(events).boxed())
 806            } else {
 807                Ok(events.boxed())
 808            }
 809        });
 810
 811        async move { Ok(future.await?.boxed()) }.boxed()
 812    }
 813
 814    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 815        self.model
 816            .cache_configuration()
 817            .map(|config| LanguageModelCacheConfiguration {
 818                max_cache_anchors: config.max_cache_anchors,
 819                should_speculate: false,
 820                min_total_token: config.min_total_token,
 821            })
 822    }
 823}
 824
 825fn deny_tool_use_events(
 826    events: impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 827) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 828    events.map(|event| {
 829        match event {
 830            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
 831                // Convert tool use to an error message if model decided to call it
 832                Ok(LanguageModelCompletionEvent::Text(format!(
 833                    "\n\n[Error: Tool calls are disabled in this context. Attempted to call '{}']",
 834                    tool_use.name
 835                )))
 836            }
 837            other => other,
 838        }
 839    })
 840}
 841
 842pub fn into_bedrock(
 843    request: LanguageModelRequest,
 844    model: String,
 845    default_temperature: f32,
 846    max_output_tokens: u64,
 847    thinking_mode: BedrockModelMode,
 848    supports_caching: bool,
 849    supports_tool_use: bool,
 850    allow_extended_context: bool,
 851) -> Result<bedrock::Request> {
 852    let mut new_messages: Vec<BedrockMessage> = Vec::new();
 853    let mut system_message = String::new();
 854
 855    // Track whether messages contain tool content - Bedrock requires toolConfig
 856    // when tool blocks are present, so we may need to add a dummy tool
 857    let mut messages_contain_tool_content = false;
 858
 859    for message in request.messages {
 860        if message.contents_empty() {
 861            continue;
 862        }
 863
 864        match message.role {
 865            Role::User | Role::Assistant => {
 866                let mut bedrock_message_content: Vec<BedrockInnerContent> = message
 867                    .content
 868                    .into_iter()
 869                    .filter_map(|content| match content {
 870                        MessageContent::Text(text) => {
 871                            if !text.is_empty() {
 872                                Some(BedrockInnerContent::Text(text))
 873                            } else {
 874                                None
 875                            }
 876                        }
 877                        MessageContent::Thinking { text, signature } => {
 878                            if model.contains(Model::DeepSeekR1.request_id()) {
 879                                // DeepSeekR1 doesn't support thinking blocks
 880                                // And the AWS API demands that you strip them
 881                                return None;
 882                            }
 883                            if signature.is_none() {
 884                                // Thinking blocks without a signature are invalid
 885                                // (e.g. from cancellation mid-think) and must be
 886                                // stripped to avoid API errors.
 887                                return None;
 888                            }
 889                            let thinking = BedrockThinkingTextBlock::builder()
 890                                .text(text)
 891                                .set_signature(signature)
 892                                .build()
 893                                .context("failed to build reasoning block")
 894                                .log_err()?;
 895
 896                            Some(BedrockInnerContent::ReasoningContent(
 897                                BedrockThinkingBlock::ReasoningText(thinking),
 898                            ))
 899                        }
 900                        MessageContent::RedactedThinking(blob) => {
 901                            if model.contains(Model::DeepSeekR1.request_id()) {
 902                                // DeepSeekR1 doesn't support thinking blocks
 903                                // And the AWS API demands that you strip them
 904                                return None;
 905                            }
 906                            let redacted =
 907                                BedrockThinkingBlock::RedactedContent(BedrockBlob::new(blob));
 908
 909                            Some(BedrockInnerContent::ReasoningContent(redacted))
 910                        }
 911                        MessageContent::ToolUse(tool_use) => {
 912                            messages_contain_tool_content = true;
 913                            let input = if tool_use.input.is_null() {
 914                                // Bedrock API requires valid JsonValue, not null, for tool use input
 915                                value_to_aws_document(&serde_json::json!({}))
 916                            } else {
 917                                value_to_aws_document(&tool_use.input)
 918                            };
 919                            BedrockToolUseBlock::builder()
 920                                .name(tool_use.name.to_string())
 921                                .tool_use_id(tool_use.id.to_string())
 922                                .input(input)
 923                                .build()
 924                                .context("failed to build Bedrock tool use block")
 925                                .log_err()
 926                                .map(BedrockInnerContent::ToolUse)
 927                        }
 928                        MessageContent::ToolResult(tool_result) => {
 929                            messages_contain_tool_content = true;
 930                            BedrockToolResultBlock::builder()
 931                                .tool_use_id(tool_result.tool_use_id.to_string())
 932                                .content(match tool_result.content {
 933                                    LanguageModelToolResultContent::Text(text) => {
 934                                        BedrockToolResultContentBlock::Text(text.to_string())
 935                                    }
 936                                    LanguageModelToolResultContent::Image(image) => {
 937                                        use base64::Engine;
 938
 939                                        match base64::engine::general_purpose::STANDARD
 940                                            .decode(image.source.as_bytes())
 941                                        {
 942                                            Ok(image_bytes) => {
 943                                                match BedrockImageBlock::builder()
 944                                                    .format(BedrockImageFormat::Png)
 945                                                    .source(BedrockImageSource::Bytes(
 946                                                        BedrockBlob::new(image_bytes),
 947                                                    ))
 948                                                    .build()
 949                                                {
 950                                                    Ok(image_block) => {
 951                                                        BedrockToolResultContentBlock::Image(
 952                                                            image_block,
 953                                                        )
 954                                                    }
 955                                                    Err(err) => {
 956                                                        BedrockToolResultContentBlock::Text(
 957                                                            format!(
 958                                                                "[Failed to build image block: {}]",
 959                                                                err
 960                                                            ),
 961                                                        )
 962                                                    }
 963                                                }
 964                                            }
 965                                            Err(err) => {
 966                                                BedrockToolResultContentBlock::Text(format!(
 967                                                    "[Failed to decode tool result image: {}]",
 968                                                    err
 969                                                ))
 970                                            }
 971                                        }
 972                                    }
 973                                })
 974                                .status({
 975                                    if tool_result.is_error {
 976                                        BedrockToolResultStatus::Error
 977                                    } else {
 978                                        BedrockToolResultStatus::Success
 979                                    }
 980                                })
 981                                .build()
 982                                .context("failed to build Bedrock tool result block")
 983                                .log_err()
 984                                .map(BedrockInnerContent::ToolResult)
 985                        }
 986                        MessageContent::Image(image) => {
 987                            use base64::Engine;
 988
 989                            let image_bytes = base64::engine::general_purpose::STANDARD
 990                                .decode(image.source.as_bytes())
 991                                .context("failed to decode base64 image data")
 992                                .log_err()?;
 993
 994                            BedrockImageBlock::builder()
 995                                .format(BedrockImageFormat::Png)
 996                                .source(BedrockImageSource::Bytes(BedrockBlob::new(image_bytes)))
 997                                .build()
 998                                .context("failed to build Bedrock image block")
 999                                .log_err()
1000                                .map(BedrockInnerContent::Image)
1001                        }
1002                    })
1003                    .collect();
1004                if message.cache && supports_caching {
1005                    bedrock_message_content.push(BedrockInnerContent::CachePoint(
1006                        CachePointBlock::builder()
1007                            .r#type(CachePointType::Default)
1008                            .build()
1009                            .context("failed to build cache point block")?,
1010                    ));
1011                }
1012                let bedrock_role = match message.role {
1013                    Role::User => bedrock::BedrockRole::User,
1014                    Role::Assistant => bedrock::BedrockRole::Assistant,
1015                    Role::System => unreachable!("System role should never occur here"),
1016                };
1017                if bedrock_message_content.is_empty() {
1018                    continue;
1019                }
1020
1021                if let Some(last_message) = new_messages.last_mut()
1022                    && last_message.role == bedrock_role
1023                {
1024                    last_message.content.extend(bedrock_message_content);
1025                    continue;
1026                }
1027                new_messages.push(
1028                    BedrockMessage::builder()
1029                        .role(bedrock_role)
1030                        .set_content(Some(bedrock_message_content))
1031                        .build()
1032                        .context("failed to build Bedrock message")?,
1033                );
1034            }
1035            Role::System => {
1036                if !system_message.is_empty() {
1037                    system_message.push_str("\n\n");
1038                }
1039                system_message.push_str(&message.string_contents());
1040            }
1041        }
1042    }
1043
1044    let mut tool_spec: Vec<BedrockTool> = if supports_tool_use {
1045        request
1046            .tools
1047            .iter()
1048            .filter_map(|tool| {
1049                Some(BedrockTool::ToolSpec(
1050                    BedrockToolSpec::builder()
1051                        .name(tool.name.clone())
1052                        .description(tool.description.clone())
1053                        .input_schema(BedrockToolInputSchema::Json(value_to_aws_document(
1054                            &tool.input_schema,
1055                        )))
1056                        .build()
1057                        .log_err()?,
1058                ))
1059            })
1060            .collect()
1061    } else {
1062        Vec::new()
1063    };
1064
1065    // Bedrock requires toolConfig when messages contain tool use/result blocks.
1066    // If no tools are defined but messages contain tool content (e.g., when
1067    // summarising a conversation that used tools), add a dummy tool to satisfy
1068    // the API requirement.
1069    if supports_tool_use && tool_spec.is_empty() && messages_contain_tool_content {
1070        tool_spec.push(BedrockTool::ToolSpec(
1071            BedrockToolSpec::builder()
1072                .name("_placeholder")
1073                .description("Placeholder tool to satisfy Bedrock API requirements when conversation history contains tool usage")
1074                .input_schema(BedrockToolInputSchema::Json(value_to_aws_document(
1075                    &serde_json::json!({"type": "object", "properties": {}}),
1076                )))
1077                .build()
1078                .context("failed to build placeholder tool spec")?,
1079        ));
1080    }
1081
1082    if !tool_spec.is_empty() && supports_caching {
1083        tool_spec.push(BedrockTool::CachePoint(
1084            CachePointBlock::builder()
1085                .r#type(CachePointType::Default)
1086                .build()
1087                .context("failed to build cache point block")?,
1088        ));
1089    }
1090
1091    let tool_choice = match request.tool_choice {
1092        Some(LanguageModelToolChoice::Auto) | None => {
1093            BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
1094        }
1095        Some(LanguageModelToolChoice::Any) => {
1096            BedrockToolChoice::Any(BedrockAnyToolChoice::builder().build())
1097        }
1098        Some(LanguageModelToolChoice::None) => {
1099            // For None, we still use Auto but will filter out tool calls in the response
1100            BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
1101        }
1102    };
1103    let tool_config = if tool_spec.is_empty() {
1104        None
1105    } else {
1106        Some(
1107            BedrockToolConfig::builder()
1108                .set_tools(Some(tool_spec))
1109                .tool_choice(tool_choice)
1110                .build()?,
1111        )
1112    };
1113
1114    Ok(bedrock::Request {
1115        model,
1116        messages: new_messages,
1117        max_tokens: max_output_tokens,
1118        system: Some(system_message),
1119        tools: tool_config,
1120        thinking: if request.thinking_allowed {
1121            match thinking_mode {
1122                BedrockModelMode::Thinking { budget_tokens } => {
1123                    Some(bedrock::Thinking::Enabled { budget_tokens })
1124                }
1125                BedrockModelMode::AdaptiveThinking {
1126                    effort: default_effort,
1127                } => {
1128                    let effort = request
1129                        .thinking_effort
1130                        .as_deref()
1131                        .and_then(|e| match e {
1132                            "low" => Some(bedrock::BedrockAdaptiveThinkingEffort::Low),
1133                            "medium" => Some(bedrock::BedrockAdaptiveThinkingEffort::Medium),
1134                            "high" => Some(bedrock::BedrockAdaptiveThinkingEffort::High),
1135                            "max" => Some(bedrock::BedrockAdaptiveThinkingEffort::Max),
1136                            _ => None,
1137                        })
1138                        .unwrap_or(default_effort);
1139                    Some(bedrock::Thinking::Adaptive { effort })
1140                }
1141                BedrockModelMode::Default => None,
1142            }
1143        } else {
1144            None
1145        },
1146        metadata: None,
1147        stop_sequences: Vec::new(),
1148        temperature: request.temperature.or(Some(default_temperature)),
1149        top_k: None,
1150        top_p: None,
1151        allow_extended_context,
1152    })
1153}
1154
1155// TODO: just call the ConverseOutput.usage() method:
1156// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output
1157pub fn get_bedrock_tokens(
1158    request: LanguageModelRequest,
1159    cx: &App,
1160) -> BoxFuture<'static, Result<u64>> {
1161    cx.background_executor()
1162        .spawn(async move {
1163            let messages = request.messages;
1164            let mut tokens_from_images = 0;
1165            let mut string_messages = Vec::with_capacity(messages.len());
1166
1167            for message in messages {
1168                use language_model::MessageContent;
1169
1170                let mut string_contents = String::new();
1171
1172                for content in message.content {
1173                    match content {
1174                        MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
1175                            string_contents.push_str(&text);
1176                        }
1177                        MessageContent::RedactedThinking(_) => {}
1178                        MessageContent::Image(image) => {
1179                            tokens_from_images += image.estimate_tokens();
1180                        }
1181                        MessageContent::ToolUse(_tool_use) => {
1182                            // TODO: Estimate token usage from tool uses.
1183                        }
1184                        MessageContent::ToolResult(tool_result) => match tool_result.content {
1185                            LanguageModelToolResultContent::Text(text) => {
1186                                string_contents.push_str(&text);
1187                            }
1188                            LanguageModelToolResultContent::Image(image) => {
1189                                tokens_from_images += image.estimate_tokens();
1190                            }
1191                        },
1192                    }
1193                }
1194
1195                if !string_contents.is_empty() {
1196                    string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
1197                        role: match message.role {
1198                            Role::User => "user".into(),
1199                            Role::Assistant => "assistant".into(),
1200                            Role::System => "system".into(),
1201                        },
1202                        content: Some(string_contents),
1203                        name: None,
1204                        function_call: None,
1205                    });
1206                }
1207            }
1208
1209            // Tiktoken doesn't yet support these models, so we manually use the
1210            // same tokenizer as GPT-4.
1211            tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
1212                .map(|tokens| (tokens + tokens_from_images) as u64)
1213        })
1214        .boxed()
1215}
1216
1217pub fn map_to_language_model_completion_events(
1218    events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, anyhow::Error>>>>,
1219) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
1220    struct RawToolUse {
1221        id: String,
1222        name: String,
1223        input_json: String,
1224    }
1225
1226    struct State {
1227        events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, anyhow::Error>>>>,
1228        tool_uses_by_index: HashMap<i32, RawToolUse>,
1229        emitted_tool_use: bool,
1230    }
1231
1232    let initial_state = State {
1233        events,
1234        tool_uses_by_index: HashMap::default(),
1235        emitted_tool_use: false,
1236    };
1237
1238    futures::stream::unfold(initial_state, |mut state| async move {
1239        match state.events.next().await {
1240            Some(event_result) => match event_result {
1241                Ok(event) => {
1242                    let result = match event {
1243                        ConverseStreamOutput::ContentBlockDelta(cb_delta) => match cb_delta.delta {
1244                            Some(ContentBlockDelta::Text(text)) => {
1245                                Some(Ok(LanguageModelCompletionEvent::Text(text)))
1246                            }
1247                            Some(ContentBlockDelta::ToolUse(tool_output)) => {
1248                                if let Some(tool_use) = state
1249                                    .tool_uses_by_index
1250                                    .get_mut(&cb_delta.content_block_index)
1251                                {
1252                                    tool_use.input_json.push_str(tool_output.input());
1253                                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
1254                                        &fix_streamed_json(&tool_use.input_json),
1255                                    ) {
1256                                        Some(Ok(LanguageModelCompletionEvent::ToolUse(
1257                                            LanguageModelToolUse {
1258                                                id: tool_use.id.clone().into(),
1259                                                name: tool_use.name.clone().into(),
1260                                                is_input_complete: false,
1261                                                raw_input: tool_use.input_json.clone(),
1262                                                input,
1263                                                thought_signature: None,
1264                                            },
1265                                        )))
1266                                    } else {
1267                                        None
1268                                    }
1269                                } else {
1270                                    None
1271                                }
1272                            }
1273                            Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking {
1274                                ReasoningContentBlockDelta::Text(thoughts) => {
1275                                    Some(Ok(LanguageModelCompletionEvent::Thinking {
1276                                        text: thoughts,
1277                                        signature: None,
1278                                    }))
1279                                }
1280                                ReasoningContentBlockDelta::Signature(sig) => {
1281                                    Some(Ok(LanguageModelCompletionEvent::Thinking {
1282                                        text: "".into(),
1283                                        signature: Some(sig),
1284                                    }))
1285                                }
1286                                ReasoningContentBlockDelta::RedactedContent(redacted) => {
1287                                    let content = String::from_utf8(redacted.into_inner())
1288                                        .unwrap_or("REDACTED".to_string());
1289                                    Some(Ok(LanguageModelCompletionEvent::Thinking {
1290                                        text: content,
1291                                        signature: None,
1292                                    }))
1293                                }
1294                                _ => None,
1295                            },
1296                            _ => None,
1297                        },
1298                        ConverseStreamOutput::ContentBlockStart(cb_start) => {
1299                            if let Some(ContentBlockStart::ToolUse(tool_start)) = cb_start.start {
1300                                state.tool_uses_by_index.insert(
1301                                    cb_start.content_block_index,
1302                                    RawToolUse {
1303                                        id: tool_start.tool_use_id,
1304                                        name: tool_start.name,
1305                                        input_json: String::new(),
1306                                    },
1307                                );
1308                            }
1309                            None
1310                        }
1311                        ConverseStreamOutput::MessageStart(_) => None,
1312                        ConverseStreamOutput::ContentBlockStop(cb_stop) => state
1313                            .tool_uses_by_index
1314                            .remove(&cb_stop.content_block_index)
1315                            .map(|tool_use| {
1316                                state.emitted_tool_use = true;
1317
1318                                let input = parse_tool_arguments(&tool_use.input_json)
1319                                    .unwrap_or_else(|_| Value::Object(Default::default()));
1320
1321                                Ok(LanguageModelCompletionEvent::ToolUse(
1322                                    LanguageModelToolUse {
1323                                        id: tool_use.id.into(),
1324                                        name: tool_use.name.into(),
1325                                        is_input_complete: true,
1326                                        raw_input: tool_use.input_json,
1327                                        input,
1328                                        thought_signature: None,
1329                                    },
1330                                ))
1331                            }),
1332                        ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
1333                            Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
1334                                input_tokens: metadata.input_tokens as u64,
1335                                output_tokens: metadata.output_tokens as u64,
1336                                cache_creation_input_tokens: metadata
1337                                    .cache_write_input_tokens
1338                                    .unwrap_or_default()
1339                                    as u64,
1340                                cache_read_input_tokens: metadata
1341                                    .cache_read_input_tokens
1342                                    .unwrap_or_default()
1343                                    as u64,
1344                            }))
1345                        }),
1346                        ConverseStreamOutput::MessageStop(message_stop) => {
1347                            let stop_reason = if state.emitted_tool_use {
1348                                // Some models (e.g. Kimi) send EndTurn even when
1349                                // they've made tool calls. Trust the content over
1350                                // the stop reason.
1351                                language_model::StopReason::ToolUse
1352                            } else {
1353                                match message_stop.stop_reason {
1354                                    StopReason::ToolUse => language_model::StopReason::ToolUse,
1355                                    _ => language_model::StopReason::EndTurn,
1356                                }
1357                            };
1358                            Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason)))
1359                        }
1360                        _ => None,
1361                    };
1362
1363                    Some((result, state))
1364                }
1365                Err(err) => Some((
1366                    Some(Err(LanguageModelCompletionError::Other(anyhow!(err)))),
1367                    state,
1368                )),
1369            },
1370            None => None,
1371        }
1372    })
1373    .filter_map(|result| async move { result })
1374}
1375
/// Configuration UI for entering AWS Bedrock credentials.
struct ConfigurationView {
    /// Input labelled "Access Key ID".
    access_key_id_editor: Entity<InputField>,
    /// Input labelled "Secret Access Key".
    secret_access_key_editor: Entity<InputField>,
    /// Input labelled "Session Token (Optional)".
    session_token_editor: Entity<InputField>,
    /// Input labelled "Bedrock API Key".
    bearer_token_editor: Entity<InputField>,
    /// Provider state; the view is notified to re-render whenever it changes.
    state: Entity<State>,
    /// Initial authentication attempt started in `new`; reset to `None` once it finishes.
    load_credentials_task: Option<Task<()>>,
    focus_handle: FocusHandle,
}
1385
1386impl ConfigurationView {
    // Placeholder text passed to `InputField::new` for each credential input;
    // masked stand-ins only, not real credential values.
    const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
    const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
        "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
    const PLACEHOLDER_SESSION_TOKEN_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
    const PLACEHOLDER_BEARER_TOKEN_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
1392
1393    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
1394        let focus_handle = cx.focus_handle();
1395
1396        cx.observe(&state, |_, _, cx| {
1397            cx.notify();
1398        })
1399        .detach();
1400
1401        let access_key_id_editor = cx.new(|cx| {
1402            InputField::new(window, cx, Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT)
1403                .label("Access Key ID")
1404                .tab_index(0)
1405                .tab_stop(true)
1406        });
1407
1408        let secret_access_key_editor = cx.new(|cx| {
1409            InputField::new(window, cx, Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT)
1410                .label("Secret Access Key")
1411                .tab_index(1)
1412                .tab_stop(true)
1413        });
1414
1415        let session_token_editor = cx.new(|cx| {
1416            InputField::new(window, cx, Self::PLACEHOLDER_SESSION_TOKEN_TEXT)
1417                .label("Session Token (Optional)")
1418                .tab_index(2)
1419                .tab_stop(true)
1420        });
1421
1422        let bearer_token_editor = cx.new(|cx| {
1423            InputField::new(window, cx, Self::PLACEHOLDER_BEARER_TOKEN_TEXT)
1424                .label("Bedrock API Key")
1425                .tab_index(3)
1426                .tab_stop(true)
1427        });
1428
1429        let load_credentials_task = Some(cx.spawn({
1430            let state = state.clone();
1431            async move |this, cx| {
1432                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
1433                    // We don't log an error, because "not signed in" is also an error.
1434                    let _ = task.await;
1435                }
1436                this.update(cx, |this, cx| {
1437                    this.load_credentials_task = None;
1438                    cx.notify();
1439                })
1440                .log_err();
1441            }
1442        }));
1443
1444        Self {
1445            access_key_id_editor,
1446            secret_access_key_editor,
1447            session_token_editor,
1448            bearer_token_editor,
1449            state,
1450            load_credentials_task,
1451            focus_handle,
1452        }
1453    }
1454
1455    fn save_credentials(
1456        &mut self,
1457        _: &menu::Confirm,
1458        _window: &mut Window,
1459        cx: &mut Context<Self>,
1460    ) {
1461        let access_key_id = self
1462            .access_key_id_editor
1463            .read(cx)
1464            .text(cx)
1465            .trim()
1466            .to_string();
1467        let secret_access_key = self
1468            .secret_access_key_editor
1469            .read(cx)
1470            .text(cx)
1471            .trim()
1472            .to_string();
1473        let session_token = self
1474            .session_token_editor
1475            .read(cx)
1476            .text(cx)
1477            .trim()
1478            .to_string();
1479        let session_token = if session_token.is_empty() {
1480            None
1481        } else {
1482            Some(session_token)
1483        };
1484        let bearer_token = self
1485            .bearer_token_editor
1486            .read(cx)
1487            .text(cx)
1488            .trim()
1489            .to_string();
1490        let bearer_token = if bearer_token.is_empty() {
1491            None
1492        } else {
1493            Some(bearer_token)
1494        };
1495
1496        let state = self.state.clone();
1497        cx.spawn(async move |_, cx| {
1498            state
1499                .update(cx, |state, cx| {
1500                    let credentials = BedrockCredentials {
1501                        access_key_id,
1502                        secret_access_key,
1503                        session_token,
1504                        bearer_token,
1505                    };
1506
1507                    state.set_static_credentials(credentials, cx)
1508                })
1509                .await
1510        })
1511        .detach_and_log_err(cx);
1512    }
1513
1514    fn reset_credentials(&mut self, window: &mut Window, cx: &mut Context<Self>) {
1515        self.access_key_id_editor
1516            .update(cx, |editor, cx| editor.set_text("", window, cx));
1517        self.secret_access_key_editor
1518            .update(cx, |editor, cx| editor.set_text("", window, cx));
1519        self.session_token_editor
1520            .update(cx, |editor, cx| editor.set_text("", window, cx));
1521        self.bearer_token_editor
1522            .update(cx, |editor, cx| editor.set_text("", window, cx));
1523
1524        let state = self.state.clone();
1525        cx.spawn(async move |_, cx| state.update(cx, |state, cx| state.reset_auth(cx)).await)
1526            .detach_and_log_err(cx);
1527    }
1528
1529    fn should_render_editor(&self, cx: &Context<Self>) -> bool {
1530        self.state.read(cx).is_authenticated()
1531    }
1532
1533    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
1534        window.focus_next(cx);
1535    }
1536
1537    fn on_tab_prev(
1538        &mut self,
1539        _: &menu::SelectPrevious,
1540        window: &mut Window,
1541        cx: &mut Context<Self>,
1542    ) {
1543        window.focus_prev(cx);
1544    }
1545}
1546
1547impl Render for ConfigurationView {
1548    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1549        let state = self.state.read(cx);
1550        let env_var_set = state.credentials_from_env;
1551        let auth = state.auth.clone();
1552        let settings_auth_method = state
1553            .settings
1554            .as_ref()
1555            .and_then(|s| s.authentication_method.clone());
1556
1557        if self.load_credentials_task.is_some() {
1558            return div().child(Label::new("Loading credentials...")).into_any();
1559        }
1560
1561        let configured_label = match &auth {
1562            Some(BedrockAuth::Automatic) => {
1563                "Using automatic credentials (AWS default chain)".into()
1564            }
1565            Some(BedrockAuth::NamedProfile { profile_name }) => {
1566                format!("Using AWS profile: {profile_name}")
1567            }
1568            Some(BedrockAuth::SingleSignOn { profile_name }) => {
1569                format!("Using AWS SSO profile: {profile_name}")
1570            }
1571            Some(BedrockAuth::IamCredentials { .. }) if env_var_set => {
1572                format!(
1573                    "Using IAM credentials from {} and {} environment variables",
1574                    ZED_BEDROCK_ACCESS_KEY_ID_VAR.name, ZED_BEDROCK_SECRET_ACCESS_KEY_VAR.name
1575                )
1576            }
1577            Some(BedrockAuth::IamCredentials { .. }) => "Using IAM credentials".into(),
1578            Some(BedrockAuth::ApiKey { .. }) if env_var_set => {
1579                format!(
1580                    "Using Bedrock API Key from {} environment variable",
1581                    ZED_BEDROCK_BEARER_TOKEN_VAR.name
1582                )
1583            }
1584            Some(BedrockAuth::ApiKey { .. }) => "Using Bedrock API Key".into(),
1585            None => "Not authenticated".into(),
1586        };
1587
1588        // Determine if credentials can be reset
1589        // Settings-derived auth (non-ApiKey) cannot be reset from UI
1590        let is_settings_derived = matches!(
1591            settings_auth_method,
1592            Some(BedrockAuthMethod::Automatic)
1593                | Some(BedrockAuthMethod::NamedProfile)
1594                | Some(BedrockAuthMethod::SingleSignOn)
1595        );
1596
1597        let tooltip_label = if env_var_set {
1598            Some(format!(
1599                "To reset your credentials, unset the {}, {}, and {} or {} environment variables.",
1600                ZED_BEDROCK_ACCESS_KEY_ID_VAR.name,
1601                ZED_BEDROCK_SECRET_ACCESS_KEY_VAR.name,
1602                ZED_BEDROCK_SESSION_TOKEN_VAR.name,
1603                ZED_BEDROCK_BEARER_TOKEN_VAR.name
1604            ))
1605        } else if is_settings_derived {
1606            Some(
1607                "Authentication method is configured in settings. Edit settings.json to change."
1608                    .to_string(),
1609            )
1610        } else {
1611            None
1612        };
1613
1614        if self.should_render_editor(cx) {
1615            return ConfiguredApiCard::new(configured_label)
1616                .disabled(env_var_set || is_settings_derived)
1617                .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx)))
1618                .when_some(tooltip_label, |this, label| this.tooltip_label(label))
1619                .into_any_element();
1620        }
1621
1622        v_flex()
1623            .min_w_0()
1624            .w_full()
1625            .track_focus(&self.focus_handle)
1626            .on_action(cx.listener(Self::on_tab))
1627            .on_action(cx.listener(Self::on_tab_prev))
1628            .on_action(cx.listener(ConfigurationView::save_credentials))
1629            .child(Label::new("To use Zed's agent with Bedrock, you can set a custom authentication strategy through your settings file or use static credentials."))
1630            .child(Label::new("But first, to access models on AWS, you need to:").mt_1())
1631            .child(
1632                List::new()
1633                    .child(
1634                        ListBulletItem::new("")
1635                            .child(Label::new(
1636                                "Grant permissions to the strategy you'll use according to the:",
1637                            ))
1638                            .child(ButtonLink::new(
1639                                "Prerequisites",
1640                                "https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html",
1641                            )),
1642                    )
1643                    .child(
1644                        ListBulletItem::new("")
1645                            .child(Label::new("Select the models you would like access to:"))
1646                            .child(ButtonLink::new(
1647                                "Bedrock Model Catalog",
1648                                "https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/model-catalog",
1649                            )),
1650                    ),
1651            )
1652            .child(self.render_static_credentials_ui())
1653            .into_any()
1654    }
1655}
1656
1657impl ConfigurationView {
1658    fn render_static_credentials_ui(&self) -> impl IntoElement {
1659        let section_header = |title: SharedString| {
1660            h_flex()
1661                .gap_2()
1662                .child(Label::new(title).size(LabelSize::Default))
1663                .child(Divider::horizontal())
1664        };
1665
1666        let list_item = List::new()
1667            .child(
1668                ListBulletItem::new("")
1669                    .child(Label::new(
1670                        "For access keys: Create an IAM user in the AWS console with programmatic access",
1671                    ))
1672                    .child(ButtonLink::new(
1673                        "IAM Console",
1674                        "https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users",
1675                    )),
1676            )
1677            .child(
1678                ListBulletItem::new("")
1679                    .child(Label::new("For Bedrock API Keys: Generate an API key from the"))
1680                    .child(ButtonLink::new(
1681                        "Bedrock Console",
1682                        "https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys-use.html",
1683                    )),
1684            )
1685            .child(
1686                ListBulletItem::new("")
1687                    .child(Label::new("Attach the necessary Bedrock permissions to"))
1688                    .child(ButtonLink::new(
1689                        "this user",
1690                        "https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html",
1691                    )),
1692            )
1693            .child(ListBulletItem::new(
1694                "Enter either access keys OR a Bedrock API Key below (not both)",
1695            ));
1696
1697        v_flex()
1698            .my_2()
1699            .tab_group()
1700            .gap_1p5()
1701            .child(section_header("Static Credentials".into()))
1702            .child(Label::new(
1703                "This method uses your AWS access key ID and secret access key, or a Bedrock API Key.",
1704            ))
1705            .child(list_item)
1706            .child(self.access_key_id_editor.clone())
1707            .child(self.secret_access_key_editor.clone())
1708            .child(self.session_token_editor.clone())
1709            .child(
1710                Label::new(format!(
1711                    "You can also set the {}, {} and {} environment variables (or {} for Bedrock API Key authentication) and restart Zed.",
1712                    ZED_BEDROCK_ACCESS_KEY_ID_VAR.name,
1713                    ZED_BEDROCK_SECRET_ACCESS_KEY_VAR.name,
1714                    ZED_BEDROCK_REGION_VAR.name,
1715                    ZED_BEDROCK_BEARER_TOKEN_VAR.name
1716                ))
1717                .size(LabelSize::Small)
1718                .color(Color::Muted),
1719            )
1720            .child(
1721                Label::new(format!(
1722                    "Optionally, if your environment uses AWS CLI profiles, you can set {}; if it requires a custom endpoint, you can set {}; and if it requires a Session Token, you can set {}.",
1723                    ZED_AWS_PROFILE_VAR.name,
1724                    ZED_AWS_ENDPOINT_VAR.name,
1725                    ZED_BEDROCK_SESSION_TOKEN_VAR.name
1726                ))
1727                .size(LabelSize::Small)
1728                .color(Color::Muted)
1729                .mt_1()
1730                .mb_2p5(),
1731            )
1732            .child(section_header("Using the an API key".into()))
1733            .child(self.bearer_token_editor.clone())
1734            .child(
1735                Label::new(format!(
1736                    "Region is configured via {} environment variable or settings.json (defaults to us-east-1).",
1737                    ZED_BEDROCK_REGION_VAR.name
1738                ))
1739                .size(LabelSize::Small)
1740                .color(Color::Muted)
1741            )
1742    }
1743}