bedrock.rs

   1use std::pin::Pin;
   2use std::str::FromStr;
   3use std::sync::Arc;
   4
   5use crate::ui::InstructionListItem;
   6use anyhow::{Context as _, Result, anyhow};
   7use aws_config::Region;
   8use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
   9use aws_credential_types::Credentials;
  10use aws_http_client::AwsHttpClient;
  11use bedrock::bedrock_client::types::{
  12    ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput,
  13};
  14use bedrock::bedrock_client::{self, Config};
  15use bedrock::{BedrockError, BedrockInnerContent, BedrockMessage, BedrockStreamingResponse, Model};
  16use collections::{BTreeMap, HashMap};
  17use credentials_provider::CredentialsProvider;
  18use editor::{Editor, EditorElement, EditorStyle};
  19use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
  20use gpui::{
  21    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
  22};
  23use gpui_tokio::Tokio;
  24use http_client::HttpClient;
  25use language_model::{
  26    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
  27    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  28    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  29    LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role,
  30};
  31use schemars::JsonSchema;
  32use serde::{Deserialize, Serialize};
  33use serde_json::Value;
  34use settings::{Settings, SettingsStore};
  35use strum::IntoEnumIterator;
  36use theme::ThemeSettings;
  37use tokio::runtime::Handle;
  38use ui::{Icon, IconName, List, Tooltip, prelude::*};
  39use util::{ResultExt, maybe};
  40
  41use crate::AllLanguageModelSettings;
  42
/// Identifier reported by the provider (`id`) and by each of its models (`provider_id`).
const PROVIDER_ID: &str = "amazon-bedrock";
/// Display name reported by the provider and its models; also used in error messages.
const PROVIDER_NAME: &str = "Amazon Bedrock";
  45
  46#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
  47pub struct BedrockCredentials {
  48    pub region: String,
  49    pub access_key_id: String,
  50    pub secret_access_key: String,
  51}
  52
/// Bedrock-specific user settings (presumably stored as `AllLanguageModelSettings::bedrock`).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AmazonBedrockSettings {
    // NOTE(review): not read anywhere in this section. Requests are signed with a `None`
    // session token; confirm whether this value should be forwarded. It is also included
    // in the derived `Debug` output.
    pub session_token: Option<String>,
    /// Extra models declared in settings. `provided_models` merges them over the built-ins.
    pub available_models: Vec<AvailableModel>,
}
  58
/// A custom Bedrock model declared in settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The Bedrock model ID.
    // Also the key used when merging with built-in models, so a matching ID overrides one.
    pub name: String,
    /// Optional name to display for this model.
    pub display_name: Option<String>,
    /// Maximum token count reported for this model.
    pub max_tokens: usize,
    // NOTE(review): `provided_models` does not pass this on to `bedrock::Model::Custom`,
    // so it currently has no effect.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// Maximum number of output tokens per request.
    pub max_output_tokens: Option<u32>,
    /// Temperature used when a request does not specify one.
    pub default_temperature: Option<f32>,
}
  68
/// The URL of the base AWS service.
///
/// Right now we're just using this as the key to store the AWS credentials
/// under in the keychain. It is the key passed to `CredentialsProvider` when credentials
/// are read, written and deleted; no network request in this file goes to it.
const AMAZON_AWS_URL: &str = "https://amazonaws.com";
  74
// These environment variables all use a `ZED_` prefix so that they don't collide with (or
// overwrite) the user's regular AWS credential variables.
//
// NOTE(review): only `ZED_AWS_CREDENTIALS_VAR` is referenced in this section. Confirm
// whether the three per-field variables below are still needed.
const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID";
const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY";
const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION";
/// Holds a JSON-serialized `BedrockCredentials`. It takes precedence over the keychain in
/// `State::authenticate`.
const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS";
  80
/// Credential state shared by the Bedrock provider, each of its models and the
/// configuration view.
pub struct State {
    /// Loaded credentials; `None` means the provider is not authenticated.
    credentials: Option<BedrockCredentials>,
    /// Whether `credentials` were read from `ZED_AWS_CREDENTIALS` rather than the keychain.
    credentials_from_env: bool,
    // Holds the `SettingsStore` observation registered in `BedrockLanguageModelProvider::new`.
    _subscription: Subscription,
}
  86
  87impl State {
  88    fn reset_credentials(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
  89        let credentials_provider = <dyn CredentialsProvider>::global(cx);
  90        cx.spawn(async move |this, cx| {
  91            credentials_provider
  92                .delete_credentials(AMAZON_AWS_URL, &cx)
  93                .await
  94                .log_err();
  95            this.update(cx, |this, cx| {
  96                this.credentials = None;
  97                this.credentials_from_env = false;
  98                cx.notify();
  99            })
 100        })
 101    }
 102
 103    fn set_credentials(
 104        &mut self,
 105        credentials: BedrockCredentials,
 106        cx: &mut Context<Self>,
 107    ) -> Task<Result<()>> {
 108        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 109        cx.spawn(async move |this, cx| {
 110            credentials_provider
 111                .write_credentials(
 112                    AMAZON_AWS_URL,
 113                    "Bearer",
 114                    &serde_json::to_vec(&credentials)?,
 115                    &cx,
 116                )
 117                .await?;
 118            this.update(cx, |this, cx| {
 119                this.credentials = Some(credentials);
 120                cx.notify();
 121            })
 122        })
 123    }
 124
 125    fn is_authenticated(&self) -> bool {
 126        self.credentials.is_some()
 127    }
 128
 129    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 130        if self.is_authenticated() {
 131            return Task::ready(Ok(()));
 132        }
 133
 134        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 135        cx.spawn(async move |this, cx| {
 136            let (credentials, from_env) =
 137                if let Ok(credentials) = std::env::var(ZED_AWS_CREDENTIALS_VAR) {
 138                    (credentials, true)
 139                } else {
 140                    let (_, credentials) = credentials_provider
 141                        .read_credentials(AMAZON_AWS_URL, &cx)
 142                        .await?
 143                        .ok_or_else(|| AuthenticateError::CredentialsNotFound)?;
 144                    (
 145                        String::from_utf8(credentials)
 146                            .context("invalid {PROVIDER_NAME} credentials")?,
 147                        false,
 148                    )
 149                };
 150
 151            let credentials: BedrockCredentials =
 152                serde_json::from_str(&credentials).context("failed to parse credentials")?;
 153
 154            this.update(cx, |this, cx| {
 155                this.credentials = Some(credentials);
 156                this.credentials_from_env = from_env;
 157                cx.notify();
 158            })?;
 159
 160            Ok(())
 161        })
 162    }
 163}
 164
/// Language model provider for Amazon Bedrock.
pub struct BedrockLanguageModelProvider {
    /// HTTP client adapted for the AWS SDK; cloned into every model.
    http_client: AwsHttpClient,
    /// Tokio runtime handle; cloned into every model and used to run Bedrock streams.
    handler: tokio::runtime::Handle,
    /// Credential state shared with every model and the configuration view.
    state: gpui::Entity<State>,
}
 170
 171impl BedrockLanguageModelProvider {
 172    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 173        let state = cx.new(|cx| State {
 174            credentials: None,
 175            credentials_from_env: false,
 176            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 177                cx.notify();
 178            }),
 179        });
 180
 181        let tokio_handle = Tokio::handle(cx);
 182
 183        let coerced_client = AwsHttpClient::new(http_client.clone(), tokio_handle.clone());
 184
 185        Self {
 186            http_client: coerced_client,
 187            handler: tokio_handle.clone(),
 188            state,
 189        }
 190    }
 191}
 192
 193impl LanguageModelProvider for BedrockLanguageModelProvider {
 194    fn id(&self) -> LanguageModelProviderId {
 195        LanguageModelProviderId(PROVIDER_ID.into())
 196    }
 197
 198    fn name(&self) -> LanguageModelProviderName {
 199        LanguageModelProviderName(PROVIDER_NAME.into())
 200    }
 201
 202    fn icon(&self) -> IconName {
 203        IconName::AiBedrock
 204    }
 205
 206    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 207        let model = bedrock::Model::default();
 208        Some(Arc::new(BedrockModel {
 209            id: LanguageModelId::from(model.id().to_string()),
 210            model,
 211            http_client: self.http_client.clone(),
 212            handler: self.handler.clone(),
 213            state: self.state.clone(),
 214            request_limiter: RateLimiter::new(4),
 215        }))
 216    }
 217
 218    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 219        let mut models = BTreeMap::default();
 220
 221        for model in bedrock::Model::iter() {
 222            if !matches!(model, bedrock::Model::Custom { .. }) {
 223                models.insert(model.id().to_string(), model);
 224            }
 225        }
 226
 227        // Override with available models from settings
 228        for model in AllLanguageModelSettings::get_global(cx)
 229            .bedrock
 230            .available_models
 231            .iter()
 232        {
 233            models.insert(
 234                model.name.clone(),
 235                bedrock::Model::Custom {
 236                    name: model.name.clone(),
 237                    display_name: model.display_name.clone(),
 238                    max_tokens: model.max_tokens,
 239                    max_output_tokens: model.max_output_tokens,
 240                    default_temperature: model.default_temperature,
 241                },
 242            );
 243        }
 244
 245        models
 246            .into_values()
 247            .map(|model| {
 248                Arc::new(BedrockModel {
 249                    id: LanguageModelId::from(model.id().to_string()),
 250                    model,
 251                    http_client: self.http_client.clone(),
 252                    handler: self.handler.clone(),
 253                    state: self.state.clone(),
 254                    request_limiter: RateLimiter::new(4),
 255                }) as Arc<dyn LanguageModel>
 256            })
 257            .collect()
 258    }
 259
 260    fn is_authenticated(&self, cx: &App) -> bool {
 261        self.state.read(cx).is_authenticated()
 262    }
 263
 264    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 265        self.state.update(cx, |state, cx| state.authenticate(cx))
 266    }
 267
 268    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
 269        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 270            .into()
 271    }
 272
 273    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 274        self.state
 275            .update(cx, |state, cx| state.reset_credentials(cx))
 276    }
 277}
 278
 279impl LanguageModelProviderState for BedrockLanguageModelProvider {
 280    type ObservableEntity = State;
 281
 282    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
 283        Some(self.state.clone())
 284    }
 285}
 286
/// A single Bedrock model exposed through the [`LanguageModel`] trait.
struct BedrockModel {
    /// Zed-side identifier, built from `model.id()`.
    id: LanguageModelId,
    /// The Bedrock model definition (built-in, or `Custom` from settings).
    model: Model,
    /// HTTP client handed to the per-request Bedrock runtime client.
    http_client: AwsHttpClient,
    /// Tokio runtime handle used when streaming completions.
    handler: tokio::runtime::Handle,
    /// Shared credential state, read each time a request is built.
    state: gpui::Entity<State>,
    /// Wraps every completion stream (constructed with `RateLimiter::new(4)`).
    request_limiter: RateLimiter,
}
 295
 296impl BedrockModel {
 297    fn stream_completion(
 298        &self,
 299        request: bedrock::Request,
 300        cx: &AsyncApp,
 301    ) -> Result<
 302        BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
 303    > {
 304        let Ok(Ok((access_key_id, secret_access_key, region))) =
 305            cx.read_entity(&self.state, |state, _cx| {
 306                if let Some(credentials) = &state.credentials {
 307                    Ok((
 308                        credentials.access_key_id.clone(),
 309                        credentials.secret_access_key.clone(),
 310                        credentials.region.clone(),
 311                    ))
 312                } else {
 313                    return Err(anyhow!("Failed to read credentials"));
 314                }
 315            })
 316        else {
 317            return Err(anyhow!("App state dropped"));
 318        };
 319
 320        let runtime_client = bedrock_client::Client::from_conf(
 321            Config::builder()
 322                .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
 323                .credentials_provider(Credentials::new(
 324                    access_key_id,
 325                    secret_access_key,
 326                    None,
 327                    None,
 328                    "Keychain",
 329                ))
 330                .region(Region::new(region))
 331                .http_client(self.http_client.clone())
 332                .build(),
 333        );
 334
 335        let owned_handle = self.handler.clone();
 336
 337        Ok(async move {
 338            let request = bedrock::stream_completion(runtime_client, request, owned_handle);
 339            request.await.unwrap_or_else(|e| {
 340                futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
 341            })
 342        }
 343        .boxed())
 344    }
 345}
 346
 347impl LanguageModel for BedrockModel {
 348    fn id(&self) -> LanguageModelId {
 349        self.id.clone()
 350    }
 351
 352    fn name(&self) -> LanguageModelName {
 353        LanguageModelName::from(self.model.display_name().to_string())
 354    }
 355
 356    fn provider_id(&self) -> LanguageModelProviderId {
 357        LanguageModelProviderId(PROVIDER_ID.into())
 358    }
 359
 360    fn provider_name(&self) -> LanguageModelProviderName {
 361        LanguageModelProviderName(PROVIDER_NAME.into())
 362    }
 363
 364    fn supports_tools(&self) -> bool {
 365        true
 366    }
 367
 368    fn telemetry_id(&self) -> String {
 369        format!("bedrock/{}", self.model.id())
 370    }
 371
 372    fn max_token_count(&self) -> usize {
 373        self.model.max_token_count()
 374    }
 375
 376    fn max_output_tokens(&self) -> Option<u32> {
 377        Some(self.model.max_output_tokens())
 378    }
 379
 380    fn count_tokens(
 381        &self,
 382        request: LanguageModelRequest,
 383        cx: &App,
 384    ) -> BoxFuture<'static, Result<usize>> {
 385        get_bedrock_tokens(request, cx)
 386    }
 387
 388    fn stream_completion(
 389        &self,
 390        request: LanguageModelRequest,
 391        cx: &AsyncApp,
 392    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
 393        let request = into_bedrock(
 394            request,
 395            self.model.id().into(),
 396            self.model.default_temperature(),
 397            self.model.max_output_tokens(),
 398        );
 399
 400        let owned_handle = self.handler.clone();
 401
 402        let request = self.stream_completion(request, cx);
 403        let future = self.request_limiter.stream(async move {
 404            let response = request.map_err(|err| anyhow!(err))?.await;
 405            Ok(map_to_language_model_completion_events(
 406                response,
 407                owned_handle,
 408            ))
 409        });
 410        async move { Ok(future.await?.boxed()) }.boxed()
 411    }
 412
 413    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 414        None
 415    }
 416}
 417
 418pub fn into_bedrock(
 419    request: LanguageModelRequest,
 420    model: String,
 421    default_temperature: f32,
 422    max_output_tokens: u32,
 423) -> bedrock::Request {
 424    let mut new_messages: Vec<BedrockMessage> = Vec::new();
 425    let mut system_message = String::new();
 426
 427    for message in request.messages {
 428        if message.contents_empty() {
 429            continue;
 430        }
 431
 432        match message.role {
 433            Role::User | Role::Assistant => {
 434                let bedrock_message_content: Vec<BedrockInnerContent> = message
 435                    .content
 436                    .into_iter()
 437                    .filter_map(|content| match content {
 438                        MessageContent::Text(text) => {
 439                            if !text.is_empty() {
 440                                Some(BedrockInnerContent::Text(text))
 441                            } else {
 442                                None
 443                            }
 444                        }
 445                        _ => None,
 446                    })
 447                    .collect();
 448                let bedrock_role = match message.role {
 449                    Role::User => bedrock::BedrockRole::User,
 450                    Role::Assistant => bedrock::BedrockRole::Assistant,
 451                    Role::System => unreachable!("System role should never occur here"),
 452                };
 453                if let Some(last_message) = new_messages.last_mut() {
 454                    if last_message.role == bedrock_role {
 455                        last_message.content.extend(bedrock_message_content);
 456                        continue;
 457                    }
 458                }
 459                new_messages.push(
 460                    BedrockMessage::builder()
 461                        .role(bedrock_role)
 462                        .set_content(Some(bedrock_message_content))
 463                        .build()
 464                        .expect("failed to build Bedrock message"),
 465                );
 466            }
 467            Role::System => {
 468                if !system_message.is_empty() {
 469                    system_message.push_str("\n\n");
 470                }
 471                system_message.push_str(&message.string_contents());
 472            }
 473        }
 474    }
 475
 476    bedrock::Request {
 477        model,
 478        messages: new_messages,
 479        max_tokens: max_output_tokens,
 480        system: Some(system_message),
 481        tools: vec![],
 482        tool_choice: None,
 483        metadata: None,
 484        stop_sequences: Vec::new(),
 485        temperature: request.temperature.or(Some(default_temperature)),
 486        top_k: None,
 487        top_p: None,
 488    }
 489}
 490
 491// TODO: just call the ConverseOutput.usage() method:
 492// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output
 493pub fn get_bedrock_tokens(
 494    request: LanguageModelRequest,
 495    cx: &App,
 496) -> BoxFuture<'static, Result<usize>> {
 497    cx.background_executor()
 498        .spawn(async move {
 499            let messages = request.messages;
 500            let mut tokens_from_images = 0;
 501            let mut string_messages = Vec::with_capacity(messages.len());
 502
 503            for message in messages {
 504                use language_model::MessageContent;
 505
 506                let mut string_contents = String::new();
 507
 508                for content in message.content {
 509                    match content {
 510                        MessageContent::Text(text) => {
 511                            string_contents.push_str(&text);
 512                        }
 513                        MessageContent::Image(image) => {
 514                            tokens_from_images += image.estimate_tokens();
 515                        }
 516                        MessageContent::ToolUse(_tool_use) => {
 517                            // TODO: Estimate token usage from tool uses.
 518                        }
 519                        MessageContent::ToolResult(tool_result) => {
 520                            string_contents.push_str(&tool_result.content);
 521                        }
 522                    }
 523                }
 524
 525                if !string_contents.is_empty() {
 526                    string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
 527                        role: match message.role {
 528                            Role::User => "user".into(),
 529                            Role::Assistant => "assistant".into(),
 530                            Role::System => "system".into(),
 531                        },
 532                        content: Some(string_contents),
 533                        name: None,
 534                        function_call: None,
 535                    });
 536                }
 537            }
 538
 539            // Tiktoken doesn't yet support these models, so we manually use the
 540            // same tokenizer as GPT-4.
 541            tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
 542                .map(|tokens| tokens + tokens_from_images)
 543        })
 544        .boxed()
 545}
 546
 547pub async fn extract_tool_args_from_events(
 548    name: String,
 549    mut events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
 550    handle: Handle,
 551) -> Result<impl Send + Stream<Item = Result<String>>> {
 552    handle
 553        .spawn(async move {
 554            let mut tool_use_index = None;
 555            while let Some(event) = events.next().await {
 556                if let BedrockStreamingResponse::ContentBlockStart(ContentBlockStartEvent {
 557                    content_block_index,
 558                    start,
 559                    ..
 560                }) = event?
 561                {
 562                    match start {
 563                        None => {
 564                            continue;
 565                        }
 566                        Some(start) => match start.as_tool_use() {
 567                            Ok(tool_use) => {
 568                                if name == tool_use.name {
 569                                    tool_use_index = Some(content_block_index);
 570                                    break;
 571                                }
 572                            }
 573                            Err(err) => {
 574                                return Err(anyhow!("Failed to parse tool use event: {:?}", err));
 575                            }
 576                        },
 577                    }
 578                }
 579            }
 580
 581            let Some(tool_use_index) = tool_use_index else {
 582                return Err(anyhow!("Tool is not used"));
 583            };
 584
 585            Ok(events.filter_map(move |event| {
 586                let result = match event {
 587                    Err(_err) => None,
 588                    Ok(output) => match output.clone() {
 589                        BedrockStreamingResponse::ContentBlockDelta(inner) => {
 590                            match inner.clone().delta {
 591                                Some(ContentBlockDelta::ToolUse(tool_use)) => {
 592                                    if inner.content_block_index == tool_use_index {
 593                                        Some(Ok(tool_use.input))
 594                                    } else {
 595                                        None
 596                                    }
 597                                }
 598                                _ => None,
 599                            }
 600                        }
 601                        _ => None,
 602                    },
 603                };
 604
 605                async move { result }
 606            }))
 607        })
 608        .await?
 609}
 610
/// Converts a raw Bedrock Converse event stream into Zed completion events.
///
/// Each step of the returned stream polls `events` inside a task spawned on `handle`.
/// Events are handled as follows:
/// - Text deltas are yielded as [`LanguageModelCompletionEvent::Text`].
/// - A tool-use content-block start records the tool's id and name under its content-block
///   index.
/// - Tool-use deltas append their JSON fragment to the recorded tool use.
/// - The matching content-block stop yields a [`LanguageModelCompletionEvent::ToolUse`].
///   Its input is the accumulated JSON parsed into a [`Value`], or `Value::Null` if no
///   input was streamed. A JSON parse failure is yielded as an `Err` item.
/// - Errors from `events` are yielded as `Err` items.
/// - All other events produce no output.
///
/// If a spawned task itself fails, the error is logged and the stream ends.
pub fn map_to_language_model_completion_events(
    events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
    handle: Handle,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
    // A tool use whose input JSON is still being streamed.
    struct RawToolUse {
        id: String,
        name: String,
        input_json: String,
    }

    // Unfold state: the remaining events plus in-progress tool uses, keyed by
    // content-block index.
    struct State {
        events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
        tool_uses_by_index: HashMap<i32, RawToolUse>,
    }

    futures::stream::unfold(
        State {
            events,
            tool_uses_by_index: HashMap::default(),
        },
        move |mut state: State| {
            let inner_handle = handle.clone();
            async move {
                // Poll on the Tokio runtime. The task returns `Some((item, state))` to
                // continue the unfold, or `None` once `events` is exhausted. `item` is
                // `None` for events that produce no output.
                inner_handle
                    .spawn(async {
                        while let Some(event) = state.events.next().await {
                            match event {
                                Ok(event) => match event {
                                    ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
                                        if let Some(ContentBlockDelta::Text(text_out)) =
                                            cb_delta.delta
                                        {
                                            return Some((
                                                Some(Ok(LanguageModelCompletionEvent::Text(
                                                    text_out,
                                                ))),
                                                state,
                                            ));
                                        } else if let Some(ContentBlockDelta::ToolUse(text_out)) =
                                            cb_delta.delta
                                        {
                                            // Append the JSON fragment to the matching tool
                                            // use; fragments for an index without a recorded
                                            // start are dropped.
                                            if let Some(tool_use) = state
                                                .tool_uses_by_index
                                                .get_mut(&cb_delta.content_block_index)
                                            {
                                                tool_use.input_json.push_str(text_out.input());
                                                return Some((None, state));
                                            };

                                            return Some((None, state));
                                        } else if cb_delta.delta.is_none() {
                                            return Some((None, state));
                                        }
                                        // Any other delta kind falls through and the next
                                        // event is read.
                                    }
                                    ConverseStreamOutput::ContentBlockStart(cb_start) => {
                                        if let Some(start) = cb_start.start {
                                            match start {
                                                ContentBlockStart::ToolUse(text_out) => {
                                                    let tool_use = RawToolUse {
                                                        id: text_out.tool_use_id,
                                                        name: text_out.name,
                                                        input_json: String::new(),
                                                    };

                                                    state.tool_uses_by_index.insert(
                                                        cb_start.content_block_index,
                                                        tool_use,
                                                    );
                                                }
                                                _ => {}
                                            }
                                        }
                                    }
                                    ConverseStreamOutput::ContentBlockStop(cb_stop) => {
                                        // Closing a tool-use block emits the completed tool
                                        // use with its accumulated input.
                                        if let Some(tool_use) = state
                                            .tool_uses_by_index
                                            .remove(&cb_stop.content_block_index)
                                        {
                                            return Some((
                                                Some(maybe!({
                                                    Ok(LanguageModelCompletionEvent::ToolUse(
                                                        LanguageModelToolUse {
                                                            id: tool_use.id.into(),
                                                            name: tool_use.name.into(),
                                                            input: if tool_use.input_json.is_empty()
                                                            {
                                                                Value::Null
                                                            } else {
                                                                serde_json::Value::from_str(
                                                                    &tool_use.input_json,
                                                                )
                                                                .map_err(|err| anyhow!(err))?
                                                            },
                                                        },
                                                    ))
                                                })),
                                                state,
                                            ));
                                        }
                                    }
                                    // Other stream events (e.g. message start/stop,
                                    // metadata) are ignored.
                                    _ => {}
                                },
                                Err(err) => return Some((Some(Err(anyhow!(err))), state)),
                            }
                        }
                        // `events` is exhausted: end the completion stream.
                        None
                    })
                    .await
                    .log_err()
                    .flatten()
            }
        },
    )
    // Drop the `None` placeholders produced for events that carry no output.
    .filter_map(|event| async move { event })
}
 726
/// Settings UI for entering Amazon Bedrock credentials.
struct ConfigurationView {
    /// Single-line editor for the AWS access key ID.
    access_key_id_editor: Entity<Editor>,
    /// Single-line editor for the AWS secret access key.
    secret_access_key_editor: Entity<Editor>,
    /// Single-line editor for the AWS region.
    region_editor: Entity<Editor>,
    /// Shared credential state that saving writes to.
    state: gpui::Entity<State>,
    /// Initial credential-loading task started in `new`; `None` once it has finished.
    load_credentials_task: Option<Task<()>>,
}
 734
 735impl ConfigurationView {
    /// Placeholder shown in the empty access key ID editor.
    const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
    /// Placeholder shown in the empty secret access key editor.
    const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
        "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
    /// Placeholder shown in the empty region editor.
    const PLACEHOLDER_REGION: &'static str = "us-east-1";
 740
 741    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 742        cx.observe(&state, |_, _, cx| {
 743            cx.notify();
 744        })
 745        .detach();
 746
 747        let load_credentials_task = Some(cx.spawn({
 748            let state = state.clone();
 749            async move |this, cx| {
 750                if let Some(task) = state
 751                    .update(cx, |state, cx| state.authenticate(cx))
 752                    .log_err()
 753                {
 754                    // We don't log an error, because "not signed in" is also an error.
 755                    let _ = task.await;
 756                }
 757                this.update(cx, |this, cx| {
 758                    this.load_credentials_task = None;
 759                    cx.notify();
 760                })
 761                .log_err();
 762            }
 763        }));
 764
 765        Self {
 766            access_key_id_editor: cx.new(|cx| {
 767                let mut editor = Editor::single_line(window, cx);
 768                editor.set_placeholder_text(Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT, cx);
 769                editor
 770            }),
 771            secret_access_key_editor: cx.new(|cx| {
 772                let mut editor = Editor::single_line(window, cx);
 773                editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx);
 774                editor
 775            }),
 776            region_editor: cx.new(|cx| {
 777                let mut editor = Editor::single_line(window, cx);
 778                editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx);
 779                editor
 780            }),
 781            state,
 782            load_credentials_task,
 783        }
 784    }
 785
 786    fn save_credentials(
 787        &mut self,
 788        _: &menu::Confirm,
 789        _window: &mut Window,
 790        cx: &mut Context<Self>,
 791    ) {
 792        let access_key_id = self
 793            .access_key_id_editor
 794            .read(cx)
 795            .text(cx)
 796            .to_string()
 797            .trim()
 798            .to_string();
 799        let secret_access_key = self
 800            .secret_access_key_editor
 801            .read(cx)
 802            .text(cx)
 803            .to_string()
 804            .trim()
 805            .to_string();
 806        let region = self
 807            .region_editor
 808            .read(cx)
 809            .text(cx)
 810            .to_string()
 811            .trim()
 812            .to_string();
 813
 814        let state = self.state.clone();
 815        cx.spawn(async move |_, cx| {
 816            state
 817                .update(cx, |state, cx| {
 818                    let credentials: BedrockCredentials = BedrockCredentials {
 819                        access_key_id: access_key_id.clone(),
 820                        secret_access_key: secret_access_key.clone(),
 821                        region: region.clone(),
 822                    };
 823
 824                    state.set_credentials(credentials, cx)
 825                })?
 826                .await
 827        })
 828        .detach_and_log_err(cx);
 829    }
 830
 831    fn reset_credentials(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 832        self.access_key_id_editor
 833            .update(cx, |editor, cx| editor.set_text("", window, cx));
 834        self.secret_access_key_editor
 835            .update(cx, |editor, cx| editor.set_text("", window, cx));
 836        self.region_editor
 837            .update(cx, |editor, cx| editor.set_text("", window, cx));
 838
 839        let state = self.state.clone();
 840        cx.spawn(async move |_, cx| {
 841            state
 842                .update(cx, |state, cx| state.reset_credentials(cx))?
 843                .await
 844        })
 845        .detach_and_log_err(cx);
 846    }
 847
 848    fn make_text_style(&self, cx: &Context<Self>) -> TextStyle {
 849        let settings = ThemeSettings::get_global(cx);
 850        TextStyle {
 851            color: cx.theme().colors().text,
 852            font_family: settings.ui_font.family.clone(),
 853            font_features: settings.ui_font.features.clone(),
 854            font_fallbacks: settings.ui_font.fallbacks.clone(),
 855            font_size: rems(0.875).into(),
 856            font_weight: settings.ui_font.weight,
 857            font_style: FontStyle::Normal,
 858            line_height: relative(1.3),
 859            background_color: None,
 860            underline: None,
 861            strikethrough: None,
 862            white_space: WhiteSpace::Normal,
 863            text_overflow: None,
 864            text_align: Default::default(),
 865            line_clamp: None,
 866        }
 867    }
 868
 869    fn render_aa_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
 870        let text_style = self.make_text_style(cx);
 871
 872        EditorElement::new(
 873            &self.access_key_id_editor,
 874            EditorStyle {
 875                background: cx.theme().colors().editor_background,
 876                local_player: cx.theme().players().local(),
 877                text: text_style,
 878                ..Default::default()
 879            },
 880        )
 881    }
 882
 883    fn render_sk_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
 884        let text_style = self.make_text_style(cx);
 885
 886        EditorElement::new(
 887            &self.secret_access_key_editor,
 888            EditorStyle {
 889                background: cx.theme().colors().editor_background,
 890                local_player: cx.theme().players().local(),
 891                text: text_style,
 892                ..Default::default()
 893            },
 894        )
 895    }
 896
 897    fn render_region_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
 898        let text_style = self.make_text_style(cx);
 899
 900        EditorElement::new(
 901            &self.region_editor,
 902            EditorStyle {
 903                background: cx.theme().colors().editor_background,
 904                local_player: cx.theme().players().local(),
 905                text: text_style,
 906                ..Default::default()
 907            },
 908        )
 909    }
 910
 911    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
 912        !self.state.read(cx).is_authenticated()
 913    }
 914}
 915
 916impl Render for ConfigurationView {
 917    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 918        let env_var_set = self.state.read(cx).credentials_from_env;
 919        let bg_color = cx.theme().colors().editor_background;
 920        let border_color = cx.theme().colors().border_variant;
 921        let input_base_styles = || {
 922            h_flex()
 923                .w_full()
 924                .px_2()
 925                .py_1()
 926                .bg(bg_color)
 927                .border_1()
 928                .border_color(border_color)
 929                .rounded_sm()
 930        };
 931
 932        if self.load_credentials_task.is_some() {
 933            div().child(Label::new("Loading credentials...")).into_any()
 934        } else if self.should_render_editor(cx) {
 935            v_flex()
 936                .size_full()
 937                .on_action(cx.listener(ConfigurationView::save_credentials))
 938                .child(Label::new("To use Zed's assistant with Bedrock, you need to add the Access Key ID, Secret Access Key and AWS Region. Follow these steps:"))
 939                .child(
 940                    List::new()
 941                        .child(
 942                            InstructionListItem::new(
 943                                "Start by",
 944                                Some("creating a user and security credentials"),
 945                                Some("https://us-east-1.console.aws.amazon.com/iam/home")
 946                            )
 947                        )
 948                        .child(
 949                            InstructionListItem::new(
 950                                "Grant that user permissions according to this documentation:",
 951                                Some("Prerequisites"),
 952                                Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html")
 953                            )
 954                        )
 955                        .child(
 956                            InstructionListItem::new(
 957                                "Select the models you would like access to:",
 958                                Some("Bedrock Model Catalog"),
 959                                Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess")
 960                            )
 961                        )
 962                        .child(
 963                            InstructionListItem::text_only("Fill the fields below and hit enter to start using the assistant")
 964                        )
 965                )
 966                .child(
 967                    v_flex()
 968                        .my_2()
 969                        .gap_1p5()
 970                        .child(
 971                            v_flex()
 972                                .gap_0p5()
 973                                .child(Label::new("Access Key ID").size(LabelSize::Small))
 974                                .child(
 975                                    input_base_styles().child(self.render_aa_id_editor(cx))
 976                                )
 977                        )
 978                        .child(
 979                            v_flex()
 980                                .gap_0p5()
 981                                .child(Label::new("Secret Access Key").size(LabelSize::Small))
 982                                .child(
 983                                    input_base_styles().child(self.render_sk_editor(cx))
 984                                )
 985                        )
 986                        .child(
 987                            v_flex()
 988                                .gap_0p5()
 989                                .child(Label::new("Region").size(LabelSize::Small))
 990                                .child(
 991                                    input_base_styles().child(self.render_region_editor(cx))
 992                                )
 993                            )
 994                )
 995                .child(
 996                    Label::new(
 997                        format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."),
 998                    )
 999                        .size(LabelSize::Small)
1000                        .color(Color::Muted),
1001                )
1002                .into_any()
1003        } else {
1004            h_flex()
1005                .size_full()
1006                .justify_between()
1007                .child(
1008                    h_flex()
1009                        .gap_1()
1010                        .child(Icon::new(IconName::Check).color(Color::Success))
1011                        .child(Label::new(if env_var_set {
1012                            format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
1013                        } else {
1014                            "Credentials configured.".to_string()
1015                        })),
1016                )
1017                .child(
1018                    Button::new("reset-key", "Reset key")
1019                        .icon(Some(IconName::Trash))
1020                        .icon_size(IconSize::Small)
1021                        .icon_position(IconPosition::Start)
1022                        .disabled(env_var_set)
1023                        .when(env_var_set, |this| {
1024                            this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
1025                        })
1026                        .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
1027                )
1028                .into_any()
1029        }
1030    }
1031}