@@ -2,21 +2,38 @@ use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use strum::EnumIter;
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub enum BedrockModelMode {
+ #[default]
+ Default,
+ Thinking {
+ budget_tokens: Option<u64>,
+ },
+}
+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
// Anthropic models (already included)
#[default]
#[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")]
- Claude3_5Sonnet,
+ Claude3_5SonnetV2,
#[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
Claude3_7Sonnet,
+ #[serde(
+ rename = "claude-3-7-sonnet-thinking",
+ alias = "claude-3-7-sonnet-thinking-latest"
+ )]
+ Claude3_7SonnetThinking,
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
Claude3Sonnet,
#[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
Claude3_5Haiku,
+ Claude3_5Sonnet,
+ Claude3Haiku,
// Amazon Nova Models
AmazonNovaLite,
AmazonNovaMicro,
@@ -69,7 +86,7 @@ pub enum Model {
impl Model {
pub fn from_id(id: &str) -> anyhow::Result<Self> {
if id.starts_with("claude-3-5-sonnet-v2") {
- Ok(Self::Claude3_5Sonnet)
+ Ok(Self::Claude3_5SonnetV2)
} else if id.starts_with("claude-3-opus") {
Ok(Self::Claude3Opus)
} else if id.starts_with("claude-3-sonnet") {
@@ -78,6 +95,8 @@ impl Model {
Ok(Self::Claude3_5Haiku)
} else if id.starts_with("claude-3-7-sonnet") {
Ok(Self::Claude3_7Sonnet)
+ } else if id.starts_with("claude-3-7-sonnet-thinking") {
+ Ok(Self::Claude3_7SonnetThinking)
} else {
Err(anyhow!("invalid model id"))
}
@@ -85,14 +104,18 @@ impl Model {
pub fn id(&self) -> &str {
match self {
- Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
- Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0",
- Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0",
- Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0",
- Model::Claude3_7Sonnet => "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
- Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0",
- Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0",
- Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0",
+ Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0",
+ Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
+ Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0",
+ Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
+ Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
+ Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0",
+ Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => {
+ "anthropic.claude-3-7-sonnet-20250219-v1:0"
+ }
+ Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
+ Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
+ Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
@@ -128,11 +151,14 @@ impl Model {
pub fn display_name(&self) -> &str {
match self {
- Self::Claude3_5Sonnet => "Claude 3.5 Sonnet v2",
+ Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2",
+ Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
+ Self::Claude3Haiku => "Claude 3 Haiku",
Self::Claude3_5Haiku => "Claude 3.5 Haiku",
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
+ Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
Self::AmazonNovaLite => "Amazon Nova Lite",
Self::AmazonNovaMicro => "Amazon Nova Micro",
Self::AmazonNovaPro => "Amazon Nova Pro",
@@ -173,7 +199,7 @@ impl Model {
pub fn max_token_count(&self) -> usize {
match self {
- Self::Claude3_5Sonnet
+ Self::Claude3_5SonnetV2
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3_5Haiku
@@ -186,7 +212,8 @@ impl Model {
pub fn max_output_tokens(&self) -> u32 {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
- Self::Claude3_5Sonnet => 8_192,
+ Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
+ Self::Claude3_5SonnetV2 => 8_192,
Self::Custom {
max_output_tokens, ..
} => max_output_tokens.unwrap_or(4_096),
@@ -196,7 +223,7 @@ impl Model {
pub fn default_temperature(&self) -> f32 {
match self {
- Self::Claude3_5Sonnet
+ Self::Claude3_5SonnetV2
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3_5Haiku
@@ -208,4 +235,253 @@ impl Model {
_ => 1.0,
}
}
+
+ pub fn supports_tool_use(&self) -> bool {
+ match self {
+ // Anthropic Claude 3 models (all support tool use)
+ Self::Claude3Opus
+ | Self::Claude3Sonnet
+ | Self::Claude3_5Sonnet
+ | Self::Claude3_5SonnetV2
+ | Self::Claude3_7Sonnet
+ | Self::Claude3_7SonnetThinking
+ | Self::Claude3_5Haiku => true,
+
+ // Amazon Nova models (all support tool use)
+ Self::AmazonNovaPro | Self::AmazonNovaLite | Self::AmazonNovaMicro => true,
+
+ // AI21 Jamba 1.5 models support tool use
+ Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
+
+ // Cohere Command R models support tool use
+ Self::CohereCommandRV1 | Self::CohereCommandRPlusV1 => true,
+
+ // All other models don't support tool use
+ // Including Meta Llama 3.2, AI21 Jurassic, and others
+ _ => false,
+ }
+ }
+
+ pub fn mode(&self) -> BedrockModelMode {
+ match self {
+ Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
+ budget_tokens: Some(4096),
+ },
+ _ => BedrockModelMode::Default,
+ }
+ }
+
+ pub fn cross_region_inference_id(&self, region: &str) -> Result<String, anyhow::Error> {
+ let region_group = if region.starts_with("us-gov-") {
+ "us-gov"
+ } else if region.starts_with("us-") {
+ "us"
+ } else if region.starts_with("eu-") {
+ "eu"
+ } else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" {
+ "apac"
+ } else if region.starts_with("ca-") || region.starts_with("sa-") {
+ // Canada and South America regions - default to US profiles
+ "us"
+ } else {
+ // Unknown region
+ return Err(anyhow!("Unsupported Region"));
+ };
+
+ let model_id = self.id();
+
+ match (self, region_group) {
+ // Custom models can't have CRI IDs
+ (Model::Custom { .. }, _) => Ok(self.id().into()),
+
+ // Models with US Gov only
+ (Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => {
+ Ok(format!("{}.{}", region_group, model_id))
+ }
+
+ // Models available only in US
+ (Model::Claude3Opus, "us")
+ | (Model::Claude3_7Sonnet, "us")
+ | (Model::Claude3_7SonnetThinking, "us") => {
+ Ok(format!("{}.{}", region_group, model_id))
+ }
+
+ // Models available in US, EU, and APAC
+ (Model::Claude3_5SonnetV2, "us")
+ | (Model::Claude3_5SonnetV2, "apac")
+ | (Model::Claude3_5Sonnet, _)
+ | (Model::Claude3Haiku, _)
+ | (Model::Claude3Sonnet, _)
+ | (Model::AmazonNovaLite, _)
+ | (Model::AmazonNovaMicro, _)
+ | (Model::AmazonNovaPro, _) => Ok(format!("{}.{}", region_group, model_id)),
+
+ // Models with limited EU availability
+ (Model::MetaLlama321BInstructV1, "us")
+ | (Model::MetaLlama321BInstructV1, "eu")
+ | (Model::MetaLlama323BInstructV1, "us")
+ | (Model::MetaLlama323BInstructV1, "eu") => {
+ Ok(format!("{}.{}", region_group, model_id))
+ }
+
+ // US-only models (all remaining Meta models)
+ (Model::MetaLlama38BInstructV1, "us")
+ | (Model::MetaLlama370BInstructV1, "us")
+ | (Model::MetaLlama318BInstructV1, "us")
+ | (Model::MetaLlama318BInstructV1_128k, "us")
+ | (Model::MetaLlama3170BInstructV1, "us")
+ | (Model::MetaLlama3170BInstructV1_128k, "us")
+ | (Model::MetaLlama3211BInstructV1, "us")
+ | (Model::MetaLlama3290BInstructV1, "us") => {
+ Ok(format!("{}.{}", region_group, model_id))
+ }
+
+ // Any other combination is not supported
+ _ => Ok(self.id().into()),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_us_region_inference_ids() -> anyhow::Result<()> {
+ // Test US regions
+ assert_eq!(
+ Model::Claude3_5SonnetV2.cross_region_inference_id("us-east-1")?,
+ "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
+ );
+ assert_eq!(
+ Model::Claude3_5SonnetV2.cross_region_inference_id("us-west-2")?,
+ "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
+ );
+ assert_eq!(
+ Model::AmazonNovaPro.cross_region_inference_id("us-east-2")?,
+ "us.amazon.nova-pro-v1:0"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_eu_region_inference_ids() -> anyhow::Result<()> {
+ // Test European regions
+ assert_eq!(
+ Model::Claude3Sonnet.cross_region_inference_id("eu-west-1")?,
+ "eu.anthropic.claude-3-sonnet-20240229-v1:0"
+ );
+ assert_eq!(
+ Model::AmazonNovaMicro.cross_region_inference_id("eu-north-1")?,
+ "eu.amazon.nova-micro-v1:0"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_apac_region_inference_ids() -> anyhow::Result<()> {
+ // Test Asia-Pacific regions
+ assert_eq!(
+ Model::Claude3_5SonnetV2.cross_region_inference_id("ap-northeast-1")?,
+ "apac.anthropic.claude-3-5-sonnet-20241022-v2:0"
+ );
+ assert_eq!(
+ Model::AmazonNovaLite.cross_region_inference_id("ap-south-1")?,
+ "apac.amazon.nova-lite-v1:0"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_gov_region_inference_ids() -> anyhow::Result<()> {
+ // Test Government regions
+ assert_eq!(
+ Model::Claude3_5Sonnet.cross_region_inference_id("us-gov-east-1")?,
+ "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0"
+ );
+ assert_eq!(
+ Model::Claude3Haiku.cross_region_inference_id("us-gov-west-1")?,
+ "us-gov.anthropic.claude-3-haiku-20240307-v1:0"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_meta_models_inference_ids() -> anyhow::Result<()> {
+ // Test Meta models
+ assert_eq!(
+ Model::MetaLlama370BInstructV1.cross_region_inference_id("us-east-1")?,
+ "us.meta.llama3-70b-instruct-v1:0"
+ );
+ assert_eq!(
+ Model::MetaLlama321BInstructV1.cross_region_inference_id("eu-west-1")?,
+ "eu.meta.llama3-2-1b-instruct-v1:0"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_mistral_models_inference_ids() -> anyhow::Result<()> {
+ // Mistral models don't follow the regional prefix pattern,
+ // so they should return their original IDs
+ assert_eq!(
+ Model::MistralMistralLarge2402V1.cross_region_inference_id("us-east-1")?,
+ "mistral.mistral-large-2402-v1:0"
+ );
+ assert_eq!(
+ Model::MistralMixtral8x7BInstructV0.cross_region_inference_id("eu-west-1")?,
+ "mistral.mixtral-8x7b-instruct-v0:1"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_ai21_models_inference_ids() -> anyhow::Result<()> {
+ // AI21 models don't follow the regional prefix pattern,
+ // so they should return their original IDs
+ assert_eq!(
+ Model::AI21J2UltraV1.cross_region_inference_id("us-east-1")?,
+ "ai21.j2-ultra-v1"
+ );
+ assert_eq!(
+ Model::AI21JambaInstructV1.cross_region_inference_id("eu-west-1")?,
+ "ai21.jamba-instruct-v1:0"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_cohere_models_inference_ids() -> anyhow::Result<()> {
+ // Cohere models don't follow the regional prefix pattern,
+ // so they should return their original IDs
+ assert_eq!(
+ Model::CohereCommandRV1.cross_region_inference_id("us-east-1")?,
+ "cohere.command-r-v1:0"
+ );
+ assert_eq!(
+ Model::CohereCommandTextV14_4k.cross_region_inference_id("ap-southeast-1")?,
+ "cohere.command-text-v14:7:4k"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_custom_model_inference_ids() -> anyhow::Result<()> {
+ // Test custom models
+ let custom_model = Model::Custom {
+ name: "custom.my-model-v1:0".to_string(),
+ max_tokens: 100000,
+ display_name: Some("My Custom Model".to_string()),
+ max_output_tokens: Some(8192),
+ default_temperature: Some(0.7),
+ };
+
+ // Custom model should return its name unchanged
+ assert_eq!(
+ custom_model.cross_region_inference_id("us-east-1")?,
+ "custom.my-model-v1:0"
+ );
+
+ Ok(())
+ }
}
@@ -4,19 +4,29 @@ use std::sync::Arc;
use crate::ui::InstructionListItem;
use anyhow::{Context as _, Result, anyhow};
-use aws_config::Region;
use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
+use aws_config::{BehaviorVersion, Region};
use aws_credential_types::Credentials;
use aws_http_client::AwsHttpClient;
-use bedrock::bedrock_client::types::{ContentBlockDelta, ContentBlockStart, ConverseStreamOutput};
-use bedrock::bedrock_client::{self, Config};
-use bedrock::{BedrockError, BedrockInnerContent, BedrockMessage, BedrockStreamingResponse, Model};
+use bedrock::bedrock_client::Client as BedrockClient;
+use bedrock::bedrock_client::config::timeout::TimeoutConfig;
+use bedrock::bedrock_client::types::{
+ ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
+ StopReason,
+};
+use bedrock::{
+ BedrockAutoToolChoice, BedrockError, BedrockInnerContent, BedrockMessage, BedrockModelMode,
+ BedrockStreamingResponse, BedrockTool, BedrockToolChoice, BedrockToolConfig,
+ BedrockToolInputSchema, BedrockToolResultBlock, BedrockToolResultContentBlock,
+ BedrockToolResultStatus, BedrockToolSpec, BedrockToolUseBlock, Model, value_to_aws_document,
+};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{
- AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
+ AnyView, App, AsyncApp, Context, Entity, FontStyle, FontWeight, Subscription, Task, TextStyle,
+ WhiteSpace,
};
use gpui_tokio::Tokio;
use http_client::HttpClient;
@@ -24,17 +34,18 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role,
+ LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use settings::{Settings, SettingsStore};
-use strum::IntoEnumIterator;
+use smol::lock::OnceCell;
+use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
use theme::ThemeSettings;
use tokio::runtime::Handle;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::{ResultExt, maybe};
+use util::{ResultExt, default};
use crate::AllLanguageModelSettings;
@@ -43,15 +54,33 @@ const PROVIDER_NAME: &str = "Amazon Bedrock";
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
pub struct BedrockCredentials {
- pub region: String,
pub access_key_id: String,
pub secret_access_key: String,
+ pub session_token: Option<String>,
+ pub region: String,
}
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AmazonBedrockSettings {
- pub session_token: Option<String>,
pub available_models: Vec<AvailableModel>,
+ pub region: Option<String>,
+ pub endpoint: Option<String>,
+ pub profile_name: Option<String>,
+ pub role_arn: Option<String>,
+ pub authentication_method: Option<BedrockAuthMethod>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, EnumIter, IntoStaticStr, JsonSchema)]
+pub enum BedrockAuthMethod {
+ #[serde(rename = "named_profile")]
+ NamedProfile,
+ #[serde(rename = "static_credentials")]
+ StaticCredentials,
+ #[serde(rename = "sso")]
+ SingleSignOn,
+ /// IMDSv2, PodIdentity, env vars, etc.
+ #[serde(rename = "default")]
+ Automatic,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@@ -62,6 +91,36 @@ pub struct AvailableModel {
pub cache_configuration: Option<LanguageModelCacheConfiguration>,
pub max_output_tokens: Option<u32>,
pub default_temperature: Option<f32>,
+ pub mode: Option<ModelMode>,
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+ #[default]
+ Default,
+ Thinking {
+ /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
+ budget_tokens: Option<u64>,
+ },
+}
+
+impl From<ModelMode> for BedrockModelMode {
+ fn from(value: ModelMode) -> Self {
+ match value {
+ ModelMode::Default => BedrockModelMode::Default,
+ ModelMode::Thinking { budget_tokens } => BedrockModelMode::Thinking { budget_tokens },
+ }
+ }
+}
+
+impl From<BedrockModelMode> for ModelMode {
+ fn from(value: BedrockModelMode) -> Self {
+ match value {
+ BedrockModelMode::Default => ModelMode::Default,
+ BedrockModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
+ }
+ }
}
/// The URL of the base AWS service.
@@ -73,11 +132,15 @@ const AMAZON_AWS_URL: &str = "https://amazonaws.com";
// These environment variables all use a `ZED_` prefix because we don't want to overwrite the user's AWS credentials.
const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID";
const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY";
+const ZED_BEDROCK_SESSION_TOKEN_VAR: &str = "ZED_SESSION_TOKEN";
+const ZED_AWS_PROFILE_VAR: &str = "ZED_AWS_PROFILE";
const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION";
const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS";
+const ZED_AWS_ENDPOINT_VAR: &str = "ZED_AWS_ENDPOINT";
pub struct State {
credentials: Option<BedrockCredentials>,
+ settings: Option<AmazonBedrockSettings>,
credentials_from_env: bool,
_subscription: Subscription,
}
@@ -93,6 +156,7 @@ impl State {
this.update(cx, |this, cx| {
this.credentials = None;
this.credentials_from_env = false;
+ this.settings = None;
cx.notify();
})
})
@@ -120,12 +184,47 @@ impl State {
})
}
- fn is_authenticated(&self) -> bool {
- self.credentials.is_some()
+ fn is_authenticated(&self) -> Option<String> {
+ match self
+ .settings
+ .as_ref()
+ .and_then(|s| s.authentication_method.as_ref())
+ {
+ Some(BedrockAuthMethod::StaticCredentials) => Some(String::from(
+ "You are authenticated using Static Credentials.",
+ )),
+ Some(BedrockAuthMethod::NamedProfile) | Some(BedrockAuthMethod::SingleSignOn) => {
+ match self.settings.as_ref() {
+ None => Some(String::from(
+ "You are authenticated using a Named Profile, but no profile is set.",
+ )),
+ Some(settings) => match settings.clone().profile_name {
+ None => Some(String::from(
+ "You are authenticated using a Named Profile, but no profile is set.",
+ )),
+ Some(profile_name) => Some(format!(
+ "You are authenticated using a Named Profile: {profile_name}",
+ )),
+ },
+ }
+ }
+ Some(BedrockAuthMethod::Automatic) => Some(String::from(
+ "You are authenticated using Automatic Credentials.",
+ )),
+ None => {
+ if self.credentials.is_some() {
+ Some(String::from(
+ "You are authenticated using Static Credentials.",
+ ))
+ } else {
+ None
+ }
+ }
+ }
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
+ if self.is_authenticated().is_some() {
return Task::ready(Ok(()));
}
@@ -170,6 +269,7 @@ impl BedrockLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
credentials: None,
+ settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
credentials_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
@@ -209,6 +309,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
http_client: self.http_client.clone(),
handler: self.handler.clone(),
state: self.state.clone(),
+ client: OnceCell::new(),
request_limiter: RateLimiter::new(4),
}))
}
@@ -249,6 +350,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
http_client: self.http_client.clone(),
handler: self.handler.clone(),
state: self.state.clone(),
+ client: OnceCell::new(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
@@ -256,7 +358,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
}
fn is_authenticated(&self, cx: &App) -> bool {
- self.state.read(cx).is_authenticated()
+ self.state.read(cx).is_authenticated().is_some()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
@@ -287,11 +389,94 @@ struct BedrockModel {
model: Model,
http_client: AwsHttpClient,
handler: tokio::runtime::Handle,
+ client: OnceCell<BedrockClient>,
state: gpui::Entity<State>,
request_limiter: RateLimiter,
}
impl BedrockModel {
+ fn get_or_init_client(&self, cx: &AsyncApp) -> Result<&BedrockClient, anyhow::Error> {
+ self.client
+ .get_or_try_init_blocking(|| {
+ let Ok((auth_method, credentials, endpoint, region, settings)) =
+ cx.read_entity(&self.state, |state, _cx| {
+ let auth_method = state
+ .settings
+ .as_ref()
+ .and_then(|s| s.authentication_method.clone())
+ .unwrap_or(BedrockAuthMethod::Automatic);
+
+ let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone());
+
+ let region = state
+ .settings
+ .as_ref()
+ .and_then(|s| s.region.clone())
+ .unwrap_or(String::from("us-east-1"));
+
+ (
+ auth_method,
+ state.credentials.clone(),
+ endpoint,
+ region,
+ state.settings.clone(),
+ )
+ })
+ else {
+ return Err(anyhow!("App state dropped"));
+ };
+
+ let mut config_builder = aws_config::defaults(BehaviorVersion::latest())
+ .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
+ .http_client(self.http_client.clone())
+ .region(Region::new(region))
+ .timeout_config(TimeoutConfig::disabled());
+
+ if let Some(endpoint_url) = endpoint {
+ if !endpoint_url.is_empty() {
+ config_builder = config_builder.endpoint_url(endpoint_url);
+ }
+ }
+
+ match auth_method {
+ BedrockAuthMethod::StaticCredentials => {
+ if let Some(creds) = credentials {
+ let aws_creds = Credentials::new(
+ creds.access_key_id,
+ creds.secret_access_key,
+ creds.session_token,
+ None,
+ "zed-bedrock-provider",
+ );
+ config_builder = config_builder.credentials_provider(aws_creds);
+ }
+ }
+ BedrockAuthMethod::NamedProfile | BedrockAuthMethod::SingleSignOn => {
+ // Currently NamedProfile and SSO behave the same way but only the instructions change
+ // Until we support BearerAuth through SSO, this will not change.
+ let profile_name = settings
+ .and_then(|s| s.profile_name)
+ .unwrap_or_else(|| "default".to_string());
+
+ if !profile_name.is_empty() {
+ config_builder = config_builder.profile_name(profile_name);
+ }
+ }
+ BedrockAuthMethod::Automatic => {
+ // Use default credential provider chain
+ }
+ }
+
+ let config = self.handler.block_on(config_builder.load());
+ Ok(BedrockClient::new(&config))
+ })
+ .map_err(|err| anyhow!("Failed to initialize Bedrock client: {err}"))?;
+
+ self.client
+ .get()
+ .ok_or_else(|| anyhow!("Bedrock client not initialized"))
+ }
+
fn stream_completion(
&self,
request: bedrock::Request,
@@ -299,37 +484,10 @@ impl BedrockModel {
) -> Result<
BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
> {
- let Ok(Ok((access_key_id, secret_access_key, region))) =
- cx.read_entity(&self.state, |state, _cx| {
- if let Some(credentials) = &state.credentials {
- Ok((
- credentials.access_key_id.clone(),
- credentials.secret_access_key.clone(),
- credentials.region.clone(),
- ))
- } else {
- return Err(anyhow!("Failed to read credentials"));
- }
- })
- else {
- return Err(anyhow!("App state dropped"));
- };
-
- let runtime_client = bedrock_client::Client::from_conf(
- Config::builder()
- .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
- .credentials_provider(Credentials::new(
- access_key_id,
- secret_access_key,
- None,
- None,
- "Keychain",
- ))
- .region(Region::new(region))
- .http_client(self.http_client.clone())
- .build(),
- );
-
+ let runtime_client = self
+ .get_or_init_client(cx)
+ .cloned()
+ .context("Bedrock client not initialized")?;
let owned_handle = self.handler.clone();
Ok(async move {
@@ -360,7 +518,7 @@ impl LanguageModel for BedrockModel {
}
fn supports_tools(&self) -> bool {
- true
+ self.model.supports_tool_use()
}
fn telemetry_id(&self) -> String {
@@ -388,12 +546,36 @@ impl LanguageModel for BedrockModel {
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
- let request = into_bedrock(
+ let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
+ // Get region - from credentials or directly from settings
+ let region = state
+ .credentials
+ .as_ref()
+ .map(|s| s.region.clone())
+ .unwrap_or(String::from("us-east-1"));
+
+ region
+ }) else {
+ return async move { Err(anyhow!("App State Dropped")) }.boxed();
+ };
+
+ let model_id = match self.model.cross_region_inference_id(®ion) {
+ Ok(s) => s,
+ Err(e) => {
+ return async move { Err(e) }.boxed();
+ }
+ };
+
+ let request = match into_bedrock(
request,
- self.model.id().into(),
+ model_id,
self.model.default_temperature(),
self.model.max_output_tokens(),
- );
+ self.model.mode(),
+ ) {
+ Ok(request) => request,
+ Err(err) => return futures::future::ready(Err(err)).boxed(),
+ };
let owned_handle = self.handler.clone();
@@ -418,7 +600,8 @@ pub fn into_bedrock(
model: String,
default_temperature: f32,
max_output_tokens: u32,
-) -> bedrock::Request {
+ mode: BedrockModelMode,
+) -> Result<bedrock::Request> {
let mut new_messages: Vec<BedrockMessage> = Vec::new();
let mut system_message = String::new();
@@ -440,6 +623,32 @@ pub fn into_bedrock(
None
}
}
+ MessageContent::ToolUse(tool_use) => BedrockToolUseBlock::builder()
+ .name(tool_use.name.to_string())
+ .tool_use_id(tool_use.id.to_string())
+ .input(value_to_aws_document(&tool_use.input))
+ .build()
+ .context("failed to build Bedrock tool use block")
+ .log_err()
+ .map(BedrockInnerContent::ToolUse),
+ MessageContent::ToolResult(tool_result) => {
+ BedrockToolResultBlock::builder()
+ .tool_use_id(tool_result.tool_use_id.to_string())
+ .content(BedrockToolResultContentBlock::Text(
+ tool_result.content.to_string(),
+ ))
+ .status({
+ if tool_result.is_error {
+ BedrockToolResultStatus::Error
+ } else {
+ BedrockToolResultStatus::Success
+ }
+ })
+ .build()
+ .context("failed to build Bedrock tool result block")
+ .log_err()
+ .map(BedrockInnerContent::ToolResult)
+ }
_ => None,
})
.collect();
@@ -459,7 +668,7 @@ pub fn into_bedrock(
.role(bedrock_role)
.set_content(Some(bedrock_message_content))
.build()
- .expect("failed to build Bedrock message"),
+ .context("failed to build Bedrock message")?,
);
}
Role::System => {
@@ -471,19 +680,47 @@ pub fn into_bedrock(
}
}
- bedrock::Request {
+ let tool_spec: Vec<BedrockTool> = request
+ .tools
+ .iter()
+ .filter_map(|tool| {
+ Some(BedrockTool::ToolSpec(
+ BedrockToolSpec::builder()
+ .name(tool.name.clone())
+ .description(tool.description.clone())
+ .input_schema(BedrockToolInputSchema::Json(value_to_aws_document(
+ &tool.input_schema,
+ )))
+ .build()
+ .log_err()?,
+ ))
+ })
+ .collect();
+
+ let tool_config: BedrockToolConfig = BedrockToolConfig::builder()
+ .set_tools(Some(tool_spec))
+ .tool_choice(BedrockToolChoice::Auto(
+ BedrockAutoToolChoice::builder().build(),
+ ))
+ .build()?;
+
+ Ok(bedrock::Request {
model,
messages: new_messages,
max_tokens: max_output_tokens,
system: Some(system_message),
- tools: vec![],
- tool_choice: None,
+ tools: Some(tool_config),
+ thinking: if let BedrockModelMode::Thinking { budget_tokens } = mode {
+ Some(bedrock::Thinking::Enabled { budget_tokens })
+ } else {
+ None
+ },
metadata: None,
stop_sequences: Vec::new(),
temperature: request.temperature.or(Some(default_temperature)),
top_k: None,
top_p: None,
- }
+ })
}
// TODO: just call the ConverseOutput.usage() method:
@@ -571,48 +808,72 @@ pub fn map_to_language_model_completion_events(
match event {
Ok(event) => match event {
ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
- if let Some(ContentBlockDelta::Text(text_out)) =
- cb_delta.delta
- {
- return Some((
- Some(Ok(LanguageModelCompletionEvent::Text(
- text_out,
- ))),
- state,
- ));
- } else if let Some(ContentBlockDelta::ToolUse(text_out)) =
- cb_delta.delta
- {
- if let Some(tool_use) = state
- .tool_uses_by_index
- .get_mut(&cb_delta.content_block_index)
- {
- tool_use.input_json.push_str(text_out.input());
- return Some((None, state));
- };
+ match cb_delta.delta {
+ Some(ContentBlockDelta::Text(text_out)) => {
+ let completion_event =
+ LanguageModelCompletionEvent::Text(text_out);
+ return Some((Some(Ok(completion_event)), state));
+ }
+
+ Some(ContentBlockDelta::ToolUse(text_out)) => {
+ if let Some(tool_use) = state
+ .tool_uses_by_index
+ .get_mut(&cb_delta.content_block_index)
+ {
+ tool_use.input_json.push_str(text_out.input());
+ }
+ }
- return Some((None, state));
- } else if cb_delta.delta.is_none() {
- return Some((None, state));
+ Some(ContentBlockDelta::ReasoningContent(thinking)) => {
+ match thinking {
+ ReasoningContentBlockDelta::RedactedContent(
+ redacted,
+ ) => {
+ let thinking_event =
+ LanguageModelCompletionEvent::Thinking(
+ String::from_utf8(
+ redacted.into_inner(),
+ )
+ .unwrap_or("REDACTED".to_string()),
+ );
+
+ return Some((
+ Some(Ok(thinking_event)),
+ state,
+ ));
+ }
+ ReasoningContentBlockDelta::Signature(_sig) => {
+ }
+ ReasoningContentBlockDelta::Text(thoughts) => {
+ let thinking_event =
+ LanguageModelCompletionEvent::Thinking(
+ thoughts.to_string(),
+ );
+
+ return Some((
+ Some(Ok(thinking_event)),
+ state,
+ ));
+ }
+ _ => {}
+ }
+ }
+ _ => {}
}
}
ConverseStreamOutput::ContentBlockStart(cb_start) => {
- if let Some(start) = cb_start.start {
- match start {
- ContentBlockStart::ToolUse(text_out) => {
- let tool_use = RawToolUse {
- id: text_out.tool_use_id,
- name: text_out.name,
- input_json: String::new(),
- };
-
- state.tool_uses_by_index.insert(
- cb_start.content_block_index,
- tool_use,
- );
- }
- _ => {}
- }
+ if let Some(ContentBlockStart::ToolUse(text_out)) =
+ cb_start.start
+ {
+ let tool_use = RawToolUse {
+ id: text_out.tool_use_id,
+ name: text_out.name,
+ input_json: String::new(),
+ };
+
+ state
+ .tool_uses_by_index
+ .insert(cb_start.content_block_index, tool_use);
}
}
ConverseStreamOutput::ContentBlockStop(cb_stop) => {
@@ -620,30 +881,85 @@ pub fn map_to_language_model_completion_events(
.tool_uses_by_index
.remove(&cb_stop.content_block_index)
{
+ let tool_use_event = LanguageModelToolUse {
+ id: tool_use.id.into(),
+ name: tool_use.name.into(),
+ input: if tool_use.input_json.is_empty() {
+ Value::Null
+ } else {
+ serde_json::Value::from_str(
+ &tool_use.input_json,
+ )
+ .map_err(|err| anyhow!(err))
+ .unwrap()
+ },
+ };
+
return Some((
- Some(maybe!({
- Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_use.id.into(),
- name: tool_use.name.into(),
- input: if tool_use.input_json.is_empty()
- {
- Value::Null
- } else {
- serde_json::Value::from_str(
- &tool_use.input_json,
- )
- .map_err(|err| anyhow!(err))?
- },
- },
- ))
- })),
+ Some(Ok(LanguageModelCompletionEvent::ToolUse(
+ tool_use_event,
+ ))),
state,
));
}
}
+
+ ConverseStreamOutput::Metadata(cb_meta) => {
+ if let Some(metadata) = cb_meta.usage {
+ let completion_event =
+ LanguageModelCompletionEvent::UsageUpdate(
+ TokenUsage {
+ input_tokens: metadata.input_tokens as u32,
+ output_tokens: metadata.output_tokens
+ as u32,
+ cache_creation_input_tokens: default(),
+ cache_read_input_tokens: default(),
+ },
+ );
+ return Some((Some(Ok(completion_event)), state));
+ }
+ }
+ ConverseStreamOutput::MessageStop(message_stop) => {
+ let reason = match message_stop.stop_reason {
+ StopReason::ContentFiltered => {
+ LanguageModelCompletionEvent::Stop(
+ language_model::StopReason::EndTurn,
+ )
+ }
+ StopReason::EndTurn => {
+ LanguageModelCompletionEvent::Stop(
+ language_model::StopReason::EndTurn,
+ )
+ }
+ StopReason::GuardrailIntervened => {
+ LanguageModelCompletionEvent::Stop(
+ language_model::StopReason::EndTurn,
+ )
+ }
+ StopReason::MaxTokens => {
+ LanguageModelCompletionEvent::Stop(
+ language_model::StopReason::EndTurn,
+ )
+ }
+ StopReason::StopSequence => {
+ LanguageModelCompletionEvent::Stop(
+ language_model::StopReason::EndTurn,
+ )
+ }
+ StopReason::ToolUse => {
+ LanguageModelCompletionEvent::Stop(
+ language_model::StopReason::ToolUse,
+ )
+ }
+ _ => LanguageModelCompletionEvent::Stop(
+ language_model::StopReason::EndTurn,
+ ),
+ };
+ return Some((Some(Ok(reason)), state));
+ }
_ => {}
},
+
Err(err) => return Some((Some(Err(anyhow!(err))), state)),
}
}
@@ -661,6 +977,7 @@ pub fn map_to_language_model_completion_events(
struct ConfigurationView {
access_key_id_editor: Entity<Editor>,
secret_access_key_editor: Entity<Editor>,
+ session_token_editor: Entity<Editor>,
region_editor: Entity<Editor>,
state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>,
@@ -670,6 +987,7 @@ impl ConfigurationView {
const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
"XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
+ const PLACEHOLDER_SESSION_TOKEN_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
const PLACEHOLDER_REGION: &'static str = "us-east-1";
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
@@ -707,6 +1025,11 @@ impl ConfigurationView {
editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx);
editor
}),
+ session_token_editor: cx.new(|cx| {
+ let mut editor = Editor::single_line(window, cx);
+ editor.set_placeholder_text(Self::PLACEHOLDER_SESSION_TOKEN_TEXT, cx);
+ editor
+ }),
region_editor: cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx);
@@ -737,6 +1060,18 @@ impl ConfigurationView {
.to_string()
.trim()
.to_string();
+ let session_token = self
+ .session_token_editor
+ .read(cx)
+ .text(cx)
+ .to_string()
+ .trim()
+ .to_string();
+ let session_token = if session_token.is_empty() {
+ None
+ } else {
+ Some(session_token)
+ };
let region = self
.region_editor
.read(cx)
@@ -744,15 +1079,21 @@ impl ConfigurationView {
.to_string()
.trim()
.to_string();
+ let region = if region.is_empty() {
+ "us-east-1".to_string()
+ } else {
+ region
+ };
let state = self.state.clone();
cx.spawn(async move |_, cx| {
state
.update(cx, |state, cx| {
let credentials: BedrockCredentials = BedrockCredentials {
+ region: region.clone(),
access_key_id: access_key_id.clone(),
secret_access_key: secret_access_key.clone(),
- region: region.clone(),
+ session_token: session_token.clone(),
};
state.set_credentials(credentials, cx)
@@ -767,6 +1108,8 @@ impl ConfigurationView {
.update(cx, |editor, cx| editor.set_text("", window, cx));
self.secret_access_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
+ self.session_token_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
self.region_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
@@ -800,7 +1143,102 @@ impl ConfigurationView {
}
}
- fn render_aa_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
+ fn make_input_styles(&self, cx: &Context<Self>) -> Div {
+ let bg_color = cx.theme().colors().editor_background;
+ let border_color = cx.theme().colors().border_variant;
+
+ h_flex()
+ .w_full()
+ .px_2()
+ .py_1()
+ .bg(bg_color)
+ .border_1()
+ .border_color(border_color)
+ .rounded_sm()
+ }
+
+ fn should_render_editor(&self, cx: &mut Context<Self>) -> Option<String> {
+ self.state.read(cx).is_authenticated()
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).credentials_from_env;
+ let creds_type = self.should_render_editor(cx).is_some();
+
+ if self.load_credentials_task.is_some() {
+ return div().child(Label::new("Loading credentials...")).into_any();
+ }
+
+ if let Some(auth) = self.should_render_editor(cx) {
+ return h_flex()
+ .size_full()
+ .justify_between()
+ .child(
+ h_flex()
+ .gap_1()
+ .child(Icon::new(IconName::Check).color(Color::Success))
+ .child(Label::new(if env_var_set {
+ format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
+ } else {
+ auth.clone()
+ })),
+ )
+ .child(
+ Button::new("reset-key", "Reset key")
+ .icon(Some(IconName::Trash))
+ .icon_size(IconSize::Small)
+ .icon_position(IconPosition::Start)
+ .disabled(env_var_set || creds_type)
+ .when(env_var_set, |this| {
+ this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
+ })
+ .when(creds_type, |this| {
+ this.tooltip(Tooltip::text("You cannot reset credentials as they're being derived, check Zed settings to understand how"))
+ })
+ .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
+ )
+ .into_any();
+ }
+
+ v_flex()
+ .size_full()
+ .on_action(cx.listener(ConfigurationView::save_credentials))
+ .child(Label::new("To use Zed's assistant with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials."))
+ .child(Label::new("Though to access models on AWS first, you will have to: "))
+ .child(
+ List::new()
+ .child(
+ InstructionListItem::new(
+ "Grant permissions to the strategy you plan to use according to this documentation: ",
+ Some("Prerequisites"),
+ Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
+ )
+ )
+ .child(
+ InstructionListItem::new(
+ "Select the models you would like access to: ",
+ Some("Bedrock Model Catalog"),
+ Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess"),
+ )
+ )
+ )
+ .child(self.render_static_credentials_ui(cx))
+ .child(self.render_common_fields(cx))
+ .child(
+ Label::new(
+ format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR} AND {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed.\n Optionally, if your environment uses AWS CLI profiles, you can set {ZED_AWS_PROFILE_VAR}; if it requires a custom endpoint, you can set {ZED_AWS_ENDPOINT_VAR}; and if it requires a Session Token, you can set {ZED_BEDROCK_SESSION_TOKEN_VAR}."),
+ )
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any()
+ }
+}
+
+impl ConfigurationView {
+ fn render_access_key_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
let text_style = self.make_text_style(cx);
EditorElement::new(
@@ -814,7 +1252,7 @@ impl ConfigurationView {
)
}
- fn render_sk_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
+ fn render_secret_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
let text_style = self.make_text_style(cx);
EditorElement::new(
@@ -828,6 +1266,20 @@ impl ConfigurationView {
)
}
+ fn render_session_token_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
+ let text_style = self.make_text_style(cx);
+
+ EditorElement::new(
+ &self.session_token_editor,
+ EditorStyle {
+ background: cx.theme().colors().editor_background,
+ local_player: cx.theme().players().local(),
+ text: text_style,
+ ..Default::default()
+ },
+ )
+ }
+
fn render_region_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
let text_style = self.make_text_style(cx);
@@ -842,124 +1294,80 @@ impl ConfigurationView {
)
}
- fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
- !self.state.read(cx).is_authenticated()
+ fn render_static_credentials_ui(&self, cx: &mut Context<Self>) -> AnyElement {
+ v_flex()
+ .my_2()
+ .gap_1p5()
+ .child(
+ Label::new("Static Keys")
+ .size(LabelSize::Default)
+ .weight(FontWeight::BOLD),
+ )
+ .child(
+ Label::new(
+ "This method uses your AWS access key ID and secret access key directly.",
+ )
+ .size(LabelSize::Small),
+ )
+ .child(
+ List::new()
+ .child(InstructionListItem::new(
+ "Create an IAM user in the AWS console with programmatic access",
+ Some("IAM Console"),
+ Some("https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users"),
+ ))
+ .child(InstructionListItem::new(
+ "Attach the necessary Bedrock permissions to this ",
+ Some("user"),
+ Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
+ ))
+ .child(InstructionListItem::text_only(
+ "Copy the access key ID and secret access key when provided",
+ ))
+ .child(InstructionListItem::text_only(
+ "Enter these credentials below",
+ )),
+ )
+ .child(
+ v_flex()
+ .gap_0p5()
+ .child(Label::new("Access Key ID").size(LabelSize::Small))
+ .child(
+ self.make_input_styles(cx)
+ .child(self.render_access_key_id_editor(cx)),
+ ),
+ )
+ .child(
+ v_flex()
+ .gap_0p5()
+ .child(Label::new("Secret Access Key").size(LabelSize::Small))
+ .child(self.make_input_styles(cx).child(self.render_secret_key_editor(cx))),
+ )
+ .child(
+ v_flex()
+ .gap_0p5()
+ .child(Label::new("Session Token (Optional)").size(LabelSize::Small))
+ .child(
+ self.make_input_styles(cx)
+ .child(self.render_session_token_editor(cx)),
+ ),
+ )
+ .into_any_element()
}
-}
-impl Render for ConfigurationView {
- fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).credentials_from_env;
- let bg_color = cx.theme().colors().editor_background;
- let border_color = cx.theme().colors().border_variant;
- let input_base_styles = || {
- h_flex()
- .w_full()
- .px_2()
- .py_1()
- .bg(bg_color)
- .border_1()
- .border_color(border_color)
- .rounded_sm()
- };
-
- if self.load_credentials_task.is_some() {
- div().child(Label::new("Loading credentials...")).into_any()
- } else if self.should_render_editor(cx) {
- v_flex()
- .size_full()
- .on_action(cx.listener(ConfigurationView::save_credentials))
- .child(Label::new("To use Zed's assistant with Bedrock, you need to add the Access Key ID, Secret Access Key and AWS Region. Follow these steps:"))
- .child(
- List::new()
- .child(
- InstructionListItem::new(
- "Start by",
- Some("creating a user and security credentials"),
- Some("https://us-east-1.console.aws.amazon.com/iam/home")
- )
- )
- .child(
- InstructionListItem::new(
- "Grant that user permissions according to this documentation:",
- Some("Prerequisites"),
- Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html")
- )
- )
- .child(
- InstructionListItem::new(
- "Select the models you would like access to:",
- Some("Bedrock Model Catalog"),
- Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess")
- )
- )
- .child(
- InstructionListItem::text_only("Fill the fields below and hit enter to start using the assistant")
- )
- )
- .child(
- v_flex()
- .my_2()
- .gap_1p5()
- .child(
- v_flex()
- .gap_0p5()
- .child(Label::new("Access Key ID").size(LabelSize::Small))
- .child(
- input_base_styles().child(self.render_aa_id_editor(cx))
- )
- )
- .child(
- v_flex()
- .gap_0p5()
- .child(Label::new("Secret Access Key").size(LabelSize::Small))
- .child(
- input_base_styles().child(self.render_sk_editor(cx))
- )
- )
- .child(
- v_flex()
- .gap_0p5()
- .child(Label::new("Region").size(LabelSize::Small))
- .child(
- input_base_styles().child(self.render_region_editor(cx))
- )
- )
- )
- .child(
- Label::new(
- format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."),
- )
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .into_any()
- } else {
- h_flex()
- .size_full()
- .justify_between()
- .child(
- h_flex()
- .gap_1()
- .child(Icon::new(IconName::Check).color(Color::Success))
- .child(Label::new(if env_var_set {
- format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
- } else {
- "Credentials configured.".to_string()
- })),
- )
- .child(
- Button::new("reset-key", "Reset key")
- .icon(Some(IconName::Trash))
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
- .disabled(env_var_set)
- .when(env_var_set, |this| {
- this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
- })
- .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
- )
- .into_any()
- }
+ fn render_common_fields(&self, cx: &mut Context<Self>) -> AnyElement {
+ v_flex()
+ .my_2()
+ .gap_1p5()
+ .child(
+ v_flex()
+ .gap_0p5()
+ .child(Label::new("Region").size(LabelSize::Small))
+ .child(
+ self.make_input_styles(cx)
+ .child(self.render_region_editor(cx)),
+ ),
+ )
+ .into_any_element()
}
}