1use std::pin::Pin;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use crate::ui::{ConfiguredApiCard, InstructionListItem};
6use anyhow::{Context as _, Result, anyhow};
7use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
8use aws_config::{BehaviorVersion, Region};
9use aws_credential_types::Credentials;
10use aws_http_client::AwsHttpClient;
11use bedrock::bedrock_client::Client as BedrockClient;
12use bedrock::bedrock_client::config::timeout::TimeoutConfig;
13use bedrock::bedrock_client::types::{
14 CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
15 ReasoningContentBlockDelta, StopReason,
16};
17use bedrock::{
18 BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
19 BedrockMessage, BedrockModelMode, BedrockStreamingResponse, BedrockThinkingBlock,
20 BedrockThinkingTextBlock, BedrockTool, BedrockToolChoice, BedrockToolConfig,
21 BedrockToolInputSchema, BedrockToolResultBlock, BedrockToolResultContentBlock,
22 BedrockToolResultStatus, BedrockToolSpec, BedrockToolUseBlock, Model, value_to_aws_document,
23};
24use collections::{BTreeMap, HashMap};
25use credentials_provider::CredentialsProvider;
26use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
27use gpui::{
28 AnyView, App, AsyncApp, Context, Entity, FocusHandle, FontWeight, Subscription, Task, Window,
29 actions,
30};
31use gpui_tokio::Tokio;
32use http_client::HttpClient;
33use language_model::{
34 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
35 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
36 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
37 LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
38 LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, RateLimiter, Role,
39 TokenUsage,
40};
41use schemars::JsonSchema;
42use serde::{Deserialize, Serialize};
43use serde_json::Value;
44use settings::{BedrockAvailableModel as AvailableModel, Settings, SettingsStore};
45use smol::lock::OnceCell;
46use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
47use ui::{List, prelude::*};
48use ui_input::InputField;
49use util::ResultExt;
50
51use crate::AllLanguageModelSettings;
52
// Declares the `bedrock::Tab` and `bedrock::TabPrev` actions.
// NOTE(review): presumably used for moving focus between the credential input
// fields of the configuration view — confirm against the view's key bindings.
actions!(bedrock, [Tab, TabPrev]);

/// Stable identifier of this provider, returned by `id()` / `provider_id()`.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock");
/// Human-readable provider name, returned by `name()` / `provider_name()`.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock");
57
/// Static AWS credentials for Bedrock.
///
/// Stored as JSON in the system keychain (under [`AMAZON_AWS_URL`]), or supplied as
/// the same JSON through the `ZED_AWS_CREDENTIALS` environment variable.
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
pub struct BedrockCredentials {
    /// AWS access key id.
    pub access_key_id: String,
    /// AWS secret access key.
    pub secret_access_key: String,
    /// Optional session token for temporary credentials.
    pub session_token: Option<String>,
    /// Region to send requests to; takes precedence over the region in settings.
    pub region: String,
}
65
/// Bedrock provider settings, as read from `AllLanguageModelSettings::bedrock`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AmazonBedrockSettings {
    /// Additional (custom) models; an entry whose name matches a built-in model id
    /// replaces that model in the provided model list.
    pub available_models: Vec<AvailableModel>,
    /// Fallback region used when the stored credentials do not specify one.
    pub region: Option<String>,
    /// Optional endpoint URL override; an empty string is ignored.
    pub endpoint: Option<String>,
    /// AWS profile used by the named-profile and SSO methods (defaults to `"default"`).
    pub profile_name: Option<String>,
    /// NOTE(review): not read by any code visible in this file — confirm where it is used.
    pub role_arn: Option<String>,
    /// When set, the provider counts as authenticated without static credentials and
    /// selects how the AWS client obtains credentials.
    pub authentication_method: Option<BedrockAuthMethod>,
    /// Passed to cross-region inference id resolution; treated as `false` when unset.
    pub allow_global: Option<bool>,
}
76
/// How the Bedrock client obtains AWS credentials when configured through settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, EnumIter, IntoStaticStr, JsonSchema)]
pub enum BedrockAuthMethod {
    /// Use the AWS profile named by `profile_name` in the settings.
    #[serde(rename = "named_profile")]
    NamedProfile,
    /// AWS SSO; currently configured exactly like `NamedProfile` (only the
    /// user-facing instructions differ).
    #[serde(rename = "sso")]
    SingleSignOn,
    /// IMDSv2, PodIdentity, env vars, etc.
    /// Leaves credential resolution to the SDK's default provider chain.
    #[serde(rename = "default")]
    Automatic,
}
87
88impl From<settings::BedrockAuthMethodContent> for BedrockAuthMethod {
89 fn from(value: settings::BedrockAuthMethodContent) -> Self {
90 match value {
91 settings::BedrockAuthMethodContent::SingleSignOn => BedrockAuthMethod::SingleSignOn,
92 settings::BedrockAuthMethodContent::Automatic => BedrockAuthMethod::Automatic,
93 settings::BedrockAuthMethodContent::NamedProfile => BedrockAuthMethod::NamedProfile,
94 }
95 }
96}
97
/// Reasoning mode of a model as written in settings.
///
/// Serialized with a `type` tag whose value is the lowercased variant name
/// (`"default"` or `"thinking"`).
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// No extended reasoning.
    #[default]
    Default,
    /// Extended reasoning ("thinking") enabled.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u64>,
    },
}
108
109impl From<ModelMode> for BedrockModelMode {
110 fn from(value: ModelMode) -> Self {
111 match value {
112 ModelMode::Default => BedrockModelMode::Default,
113 ModelMode::Thinking { budget_tokens } => BedrockModelMode::Thinking { budget_tokens },
114 }
115 }
116}
117
118impl From<BedrockModelMode> for ModelMode {
119 fn from(value: BedrockModelMode) -> Self {
120 match value {
121 BedrockModelMode::Default => ModelMode::Default,
122 BedrockModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
123 }
124 }
125}
126
/// The URL of the base AWS service.
///
/// Right now we're just using this as the key to store the AWS credentials
/// under in the keychain.
const AMAZON_AWS_URL: &str = "https://amazonaws.com";

// These environment variables all use a `ZED_` prefix so they do not clash with the
// user's own standard AWS credential variables (e.g. `AWS_ACCESS_KEY_ID`).
// NOTE(review): only `ZED_AWS_CREDENTIALS_VAR` (JSON-encoded `BedrockCredentials`) is
// read by the code visible in this file; confirm the others are used elsewhere.
const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID";
const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY";
const ZED_BEDROCK_SESSION_TOKEN_VAR: &str = "ZED_SESSION_TOKEN";
const ZED_AWS_PROFILE_VAR: &str = "ZED_AWS_PROFILE";
const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION";
const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS";
const ZED_AWS_ENDPOINT_VAR: &str = "ZED_AWS_ENDPOINT";
141
/// Shared, observable authentication and configuration state of the Bedrock provider.
pub struct State {
    /// Static credentials loaded from the keychain or from `ZED_AWS_CREDENTIALS`.
    credentials: Option<BedrockCredentials>,
    /// Bedrock settings this state is configured with; cleared to `None` by
    /// `reset_credentials`.
    settings: Option<AmazonBedrockSettings>,
    /// Whether `credentials` was loaded from the environment instead of the keychain.
    credentials_from_env: bool,
    /// Keeps the settings-store observer registered for the lifetime of this state.
    _subscription: Subscription,
}
148
149impl State {
150 fn reset_credentials(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
151 let credentials_provider = <dyn CredentialsProvider>::global(cx);
152 cx.spawn(async move |this, cx| {
153 credentials_provider
154 .delete_credentials(AMAZON_AWS_URL, cx)
155 .await
156 .log_err();
157 this.update(cx, |this, cx| {
158 this.credentials = None;
159 this.credentials_from_env = false;
160 this.settings = None;
161 cx.notify();
162 })
163 })
164 }
165
166 fn set_credentials(
167 &mut self,
168 credentials: BedrockCredentials,
169 cx: &mut Context<Self>,
170 ) -> Task<Result<()>> {
171 let credentials_provider = <dyn CredentialsProvider>::global(cx);
172 cx.spawn(async move |this, cx| {
173 credentials_provider
174 .write_credentials(
175 AMAZON_AWS_URL,
176 "Bearer",
177 &serde_json::to_vec(&credentials)?,
178 cx,
179 )
180 .await?;
181 this.update(cx, |this, cx| {
182 this.credentials = Some(credentials);
183 cx.notify();
184 })
185 })
186 }
187
188 fn is_authenticated(&self) -> bool {
189 let derived = self
190 .settings
191 .as_ref()
192 .and_then(|s| s.authentication_method.as_ref());
193 let creds = self.credentials.as_ref();
194
195 derived.is_some() || creds.is_some()
196 }
197
198 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
199 if self.is_authenticated() {
200 return Task::ready(Ok(()));
201 }
202
203 let credentials_provider = <dyn CredentialsProvider>::global(cx);
204 cx.spawn(async move |this, cx| {
205 let (credentials, from_env) =
206 if let Ok(credentials) = std::env::var(ZED_AWS_CREDENTIALS_VAR) {
207 (credentials, true)
208 } else {
209 let (_, credentials) = credentials_provider
210 .read_credentials(AMAZON_AWS_URL, cx)
211 .await?
212 .ok_or_else(|| AuthenticateError::CredentialsNotFound)?;
213 (
214 String::from_utf8(credentials)
215 .context("invalid {PROVIDER_NAME} credentials")?,
216 false,
217 )
218 };
219
220 let credentials: BedrockCredentials =
221 serde_json::from_str(&credentials).context("failed to parse credentials")?;
222
223 this.update(cx, |this, cx| {
224 this.credentials = Some(credentials);
225 this.credentials_from_env = from_env;
226 cx.notify();
227 })?;
228
229 Ok(())
230 })
231 }
232
233 fn get_region(&self) -> String {
234 // Get region - from credentials or directly from settings
235 let credentials_region = self.credentials.as_ref().map(|s| s.region.clone());
236 let settings_region = self.settings.as_ref().and_then(|s| s.region.clone());
237
238 // Use credentials region if available, otherwise use settings region, finally fall back to default
239 credentials_region
240 .or(settings_region)
241 .unwrap_or(String::from("us-east-1"))
242 }
243
244 fn get_allow_global(&self) -> bool {
245 self.settings
246 .as_ref()
247 .and_then(|s| s.allow_global)
248 .unwrap_or(false)
249 }
250}
251
/// Language model provider backed by Amazon Bedrock.
pub struct BedrockLanguageModelProvider {
    /// HTTP client adapter handed to the AWS SDK by every model this provider creates.
    http_client: AwsHttpClient,
    /// Tokio runtime handle, cloned into every model (used there to load AWS config).
    handle: tokio::runtime::Handle,
    /// Authentication/configuration state shared with all models.
    state: Entity<State>,
}
257
258impl BedrockLanguageModelProvider {
259 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
260 let state = cx.new(|cx| State {
261 credentials: None,
262 settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
263 credentials_from_env: false,
264 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
265 cx.notify();
266 }),
267 });
268
269 Self {
270 http_client: AwsHttpClient::new(http_client.clone()),
271 handle: Tokio::handle(cx),
272 state,
273 }
274 }
275
276 fn create_language_model(&self, model: bedrock::Model) -> Arc<dyn LanguageModel> {
277 Arc::new(BedrockModel {
278 id: LanguageModelId::from(model.id().to_string()),
279 model,
280 http_client: self.http_client.clone(),
281 handle: self.handle.clone(),
282 state: self.state.clone(),
283 client: OnceCell::new(),
284 request_limiter: RateLimiter::new(4),
285 })
286 }
287}
288
289impl LanguageModelProvider for BedrockLanguageModelProvider {
290 fn id(&self) -> LanguageModelProviderId {
291 PROVIDER_ID
292 }
293
294 fn name(&self) -> LanguageModelProviderName {
295 PROVIDER_NAME
296 }
297
298 fn icon(&self) -> IconName {
299 IconName::AiBedrock
300 }
301
302 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
303 Some(self.create_language_model(bedrock::Model::default()))
304 }
305
306 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
307 let region = self.state.read(cx).get_region();
308 Some(self.create_language_model(bedrock::Model::default_fast(region.as_str())))
309 }
310
311 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
312 let mut models = BTreeMap::default();
313
314 for model in bedrock::Model::iter() {
315 if !matches!(model, bedrock::Model::Custom { .. }) {
316 // TODO: Sonnet 3.7 vs. 3.7 Thinking bug is here.
317 models.insert(model.id().to_string(), model);
318 }
319 }
320
321 // Override with available models from settings
322 for model in AllLanguageModelSettings::get_global(cx)
323 .bedrock
324 .available_models
325 .iter()
326 {
327 models.insert(
328 model.name.clone(),
329 bedrock::Model::Custom {
330 name: model.name.clone(),
331 display_name: model.display_name.clone(),
332 max_tokens: model.max_tokens,
333 max_output_tokens: model.max_output_tokens,
334 default_temperature: model.default_temperature,
335 cache_configuration: model.cache_configuration.as_ref().map(|config| {
336 bedrock::BedrockModelCacheConfiguration {
337 max_cache_anchors: config.max_cache_anchors,
338 min_total_token: config.min_total_token,
339 }
340 }),
341 },
342 );
343 }
344
345 models
346 .into_values()
347 .map(|model| self.create_language_model(model))
348 .collect()
349 }
350
351 fn is_authenticated(&self, cx: &App) -> bool {
352 self.state.read(cx).is_authenticated()
353 }
354
355 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
356 self.state.update(cx, |state, cx| state.authenticate(cx))
357 }
358
359 fn configuration_view(
360 &self,
361 _target_agent: language_model::ConfigurationViewTargetAgent,
362 window: &mut Window,
363 cx: &mut App,
364 ) -> AnyView {
365 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
366 .into()
367 }
368
369 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
370 self.state
371 .update(cx, |state, cx| state.reset_credentials(cx))
372 }
373}
374
375impl LanguageModelProviderState for BedrockLanguageModelProvider {
376 type ObservableEntity = State;
377
378 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
379 Some(self.state.clone())
380 }
381}
382
/// A single Bedrock model exposed as a [`LanguageModel`].
struct BedrockModel {
    /// Identifier derived from the model's id.
    id: LanguageModelId,
    /// The underlying Bedrock model description.
    model: Model,
    /// HTTP client handed to the AWS SDK config.
    http_client: AwsHttpClient,
    /// Tokio runtime handle; `get_or_init_client` blocks on it while loading AWS config.
    handle: tokio::runtime::Handle,
    /// Runtime client, built lazily on first use and then reused.
    client: OnceCell<BedrockClient>,
    /// Provider state (credentials, settings, region) shared with the provider.
    state: Entity<State>,
    /// Wraps every completion stream.
    /// NOTE(review): created with `RateLimiter::new(4)`; presumably caps concurrent
    /// requests at 4 — confirm against `RateLimiter`.
    request_limiter: RateLimiter,
}
392
393impl BedrockModel {
394 fn get_or_init_client(&self, cx: &AsyncApp) -> anyhow::Result<&BedrockClient> {
395 self.client
396 .get_or_try_init_blocking(|| {
397 let (auth_method, credentials, endpoint, region, settings) =
398 cx.read_entity(&self.state, |state, _cx| {
399 let auth_method = state
400 .settings
401 .as_ref()
402 .and_then(|s| s.authentication_method.clone());
403
404 let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone());
405
406 let region = state.get_region();
407
408 (
409 auth_method,
410 state.credentials.clone(),
411 endpoint,
412 region,
413 state.settings.clone(),
414 )
415 })?;
416
417 let mut config_builder = aws_config::defaults(BehaviorVersion::latest())
418 .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
419 .http_client(self.http_client.clone())
420 .region(Region::new(region))
421 .timeout_config(TimeoutConfig::disabled());
422
423 if let Some(endpoint_url) = endpoint
424 && !endpoint_url.is_empty()
425 {
426 config_builder = config_builder.endpoint_url(endpoint_url);
427 }
428
429 match auth_method {
430 None => {
431 if let Some(creds) = credentials {
432 let aws_creds = Credentials::new(
433 creds.access_key_id,
434 creds.secret_access_key,
435 creds.session_token,
436 None,
437 "zed-bedrock-provider",
438 );
439 config_builder = config_builder.credentials_provider(aws_creds);
440 }
441 }
442 Some(BedrockAuthMethod::NamedProfile)
443 | Some(BedrockAuthMethod::SingleSignOn) => {
444 // Currently NamedProfile and SSO behave the same way but only the instructions change
445 // Until we support BearerAuth through SSO, this will not change.
446 let profile_name = settings
447 .and_then(|s| s.profile_name)
448 .unwrap_or_else(|| "default".to_string());
449
450 if !profile_name.is_empty() {
451 config_builder = config_builder.profile_name(profile_name);
452 }
453 }
454 Some(BedrockAuthMethod::Automatic) => {
455 // Use default credential provider chain
456 }
457 }
458
459 let config = self.handle.block_on(config_builder.load());
460 anyhow::Ok(BedrockClient::new(&config))
461 })
462 .context("initializing Bedrock client")?;
463
464 self.client.get().context("Bedrock client not initialized")
465 }
466
467 fn stream_completion(
468 &self,
469 request: bedrock::Request,
470 cx: &AsyncApp,
471 ) -> BoxFuture<
472 'static,
473 Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
474 > {
475 let Ok(runtime_client) = self
476 .get_or_init_client(cx)
477 .cloned()
478 .context("Bedrock client not initialized")
479 else {
480 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
481 };
482
483 match Tokio::spawn(cx, bedrock::stream_completion(runtime_client, request)) {
484 Ok(res) => async { res.await.map_err(|err| anyhow!(err))? }.boxed(),
485 Err(err) => futures::future::ready(Err(anyhow!(err))).boxed(),
486 }
487 }
488}
489
490impl LanguageModel for BedrockModel {
491 fn id(&self) -> LanguageModelId {
492 self.id.clone()
493 }
494
495 fn name(&self) -> LanguageModelName {
496 LanguageModelName::from(self.model.display_name().to_string())
497 }
498
499 fn provider_id(&self) -> LanguageModelProviderId {
500 PROVIDER_ID
501 }
502
503 fn provider_name(&self) -> LanguageModelProviderName {
504 PROVIDER_NAME
505 }
506
507 fn supports_tools(&self) -> bool {
508 self.model.supports_tool_use()
509 }
510
511 fn supports_images(&self) -> bool {
512 false
513 }
514
515 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
516 match choice {
517 LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => {
518 self.model.supports_tool_use()
519 }
520 // Add support for None - we'll filter tool calls at response
521 LanguageModelToolChoice::None => self.model.supports_tool_use(),
522 }
523 }
524
525 fn telemetry_id(&self) -> String {
526 format!("bedrock/{}", self.model.id())
527 }
528
529 fn max_token_count(&self) -> u64 {
530 self.model.max_token_count()
531 }
532
533 fn max_output_tokens(&self) -> Option<u64> {
534 Some(self.model.max_output_tokens())
535 }
536
537 fn count_tokens(
538 &self,
539 request: LanguageModelRequest,
540 cx: &App,
541 ) -> BoxFuture<'static, Result<u64>> {
542 get_bedrock_tokens(request, cx)
543 }
544
545 fn stream_completion(
546 &self,
547 request: LanguageModelRequest,
548 cx: &AsyncApp,
549 ) -> BoxFuture<
550 'static,
551 Result<
552 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
553 LanguageModelCompletionError,
554 >,
555 > {
556 let Ok((region, allow_global)) = cx.read_entity(&self.state, |state, _cx| {
557 (state.get_region(), state.get_allow_global())
558 }) else {
559 return async move { Err(anyhow::anyhow!("App State Dropped").into()) }.boxed();
560 };
561
562 let model_id = match self.model.cross_region_inference_id(®ion, allow_global) {
563 Ok(s) => s,
564 Err(e) => {
565 return async move { Err(e.into()) }.boxed();
566 }
567 };
568
569 let deny_tool_calls = request.tool_choice == Some(LanguageModelToolChoice::None);
570
571 let request = match into_bedrock(
572 request,
573 model_id,
574 self.model.default_temperature(),
575 self.model.max_output_tokens(),
576 self.model.mode(),
577 self.model.supports_caching(),
578 ) {
579 Ok(request) => request,
580 Err(err) => return futures::future::ready(Err(err.into())).boxed(),
581 };
582
583 let request = self.stream_completion(request, cx);
584 let future = self.request_limiter.stream(async move {
585 let response = request.await.map_err(|err| anyhow!(err))?;
586 let events = map_to_language_model_completion_events(response);
587
588 if deny_tool_calls {
589 Ok(deny_tool_use_events(events).boxed())
590 } else {
591 Ok(events.boxed())
592 }
593 });
594
595 async move { Ok(future.await?.boxed()) }.boxed()
596 }
597
598 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
599 self.model
600 .cache_configuration()
601 .map(|config| LanguageModelCacheConfiguration {
602 max_cache_anchors: config.max_cache_anchors,
603 should_speculate: false,
604 min_total_token: config.min_total_token,
605 })
606 }
607}
608
609fn deny_tool_use_events(
610 events: impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
611) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
612 events.map(|event| {
613 match event {
614 Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
615 // Convert tool use to an error message if model decided to call it
616 Ok(LanguageModelCompletionEvent::Text(format!(
617 "\n\n[Error: Tool calls are disabled in this context. Attempted to call '{}']",
618 tool_use.name
619 )))
620 }
621 other => other,
622 }
623 })
624}
625
626pub fn into_bedrock(
627 request: LanguageModelRequest,
628 model: String,
629 default_temperature: f32,
630 max_output_tokens: u64,
631 mode: BedrockModelMode,
632 supports_caching: bool,
633) -> Result<bedrock::Request> {
634 let mut new_messages: Vec<BedrockMessage> = Vec::new();
635 let mut system_message = String::new();
636
637 for message in request.messages {
638 if message.contents_empty() {
639 continue;
640 }
641
642 match message.role {
643 Role::User | Role::Assistant => {
644 let mut bedrock_message_content: Vec<BedrockInnerContent> = message
645 .content
646 .into_iter()
647 .filter_map(|content| match content {
648 MessageContent::Text(text) => {
649 if !text.is_empty() {
650 Some(BedrockInnerContent::Text(text))
651 } else {
652 None
653 }
654 }
655 MessageContent::Thinking { text, signature } => {
656 if model.contains(Model::DeepSeekR1.request_id()) {
657 // DeepSeekR1 doesn't support thinking blocks
658 // And the AWS API demands that you strip them
659 return None;
660 }
661 let thinking = BedrockThinkingTextBlock::builder()
662 .text(text)
663 .set_signature(signature)
664 .build()
665 .context("failed to build reasoning block")
666 .log_err()?;
667
668 Some(BedrockInnerContent::ReasoningContent(
669 BedrockThinkingBlock::ReasoningText(thinking),
670 ))
671 }
672 MessageContent::RedactedThinking(blob) => {
673 if model.contains(Model::DeepSeekR1.request_id()) {
674 // DeepSeekR1 doesn't support thinking blocks
675 // And the AWS API demands that you strip them
676 return None;
677 }
678 let redacted =
679 BedrockThinkingBlock::RedactedContent(BedrockBlob::new(blob));
680
681 Some(BedrockInnerContent::ReasoningContent(redacted))
682 }
683 MessageContent::ToolUse(tool_use) => {
684 let input = if tool_use.input.is_null() {
685 // Bedrock API requires valid JsonValue, not null, for tool use input
686 value_to_aws_document(&serde_json::json!({}))
687 } else {
688 value_to_aws_document(&tool_use.input)
689 };
690 BedrockToolUseBlock::builder()
691 .name(tool_use.name.to_string())
692 .tool_use_id(tool_use.id.to_string())
693 .input(input)
694 .build()
695 .context("failed to build Bedrock tool use block")
696 .log_err()
697 .map(BedrockInnerContent::ToolUse)
698 },
699 MessageContent::ToolResult(tool_result) => {
700 BedrockToolResultBlock::builder()
701 .tool_use_id(tool_result.tool_use_id.to_string())
702 .content(match tool_result.content {
703 LanguageModelToolResultContent::Text(text) => {
704 BedrockToolResultContentBlock::Text(text.to_string())
705 }
706 LanguageModelToolResultContent::Image(_) => {
707 BedrockToolResultContentBlock::Text(
708 // TODO: Bedrock image support
709 "[Tool responded with an image, but Zed doesn't support these in Bedrock models yet]".to_string()
710 )
711 }
712 })
713 .status({
714 if tool_result.is_error {
715 BedrockToolResultStatus::Error
716 } else {
717 BedrockToolResultStatus::Success
718 }
719 })
720 .build()
721 .context("failed to build Bedrock tool result block")
722 .log_err()
723 .map(BedrockInnerContent::ToolResult)
724 }
725 _ => None,
726 })
727 .collect();
728 if message.cache && supports_caching {
729 bedrock_message_content.push(BedrockInnerContent::CachePoint(
730 CachePointBlock::builder()
731 .r#type(CachePointType::Default)
732 .build()
733 .context("failed to build cache point block")?,
734 ));
735 }
736 let bedrock_role = match message.role {
737 Role::User => bedrock::BedrockRole::User,
738 Role::Assistant => bedrock::BedrockRole::Assistant,
739 Role::System => unreachable!("System role should never occur here"),
740 };
741 if let Some(last_message) = new_messages.last_mut()
742 && last_message.role == bedrock_role
743 {
744 last_message.content.extend(bedrock_message_content);
745 continue;
746 }
747 new_messages.push(
748 BedrockMessage::builder()
749 .role(bedrock_role)
750 .set_content(Some(bedrock_message_content))
751 .build()
752 .context("failed to build Bedrock message")?,
753 );
754 }
755 Role::System => {
756 if !system_message.is_empty() {
757 system_message.push_str("\n\n");
758 }
759 system_message.push_str(&message.string_contents());
760 }
761 }
762 }
763
764 let mut tool_spec: Vec<BedrockTool> = request
765 .tools
766 .iter()
767 .filter_map(|tool| {
768 Some(BedrockTool::ToolSpec(
769 BedrockToolSpec::builder()
770 .name(tool.name.clone())
771 .description(tool.description.clone())
772 .input_schema(BedrockToolInputSchema::Json(value_to_aws_document(
773 &tool.input_schema,
774 )))
775 .build()
776 .log_err()?,
777 ))
778 })
779 .collect();
780
781 if !tool_spec.is_empty() && supports_caching {
782 tool_spec.push(BedrockTool::CachePoint(
783 CachePointBlock::builder()
784 .r#type(CachePointType::Default)
785 .build()
786 .context("failed to build cache point block")?,
787 ));
788 }
789
790 let tool_choice = match request.tool_choice {
791 Some(LanguageModelToolChoice::Auto) | None => {
792 BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
793 }
794 Some(LanguageModelToolChoice::Any) => {
795 BedrockToolChoice::Any(BedrockAnyToolChoice::builder().build())
796 }
797 Some(LanguageModelToolChoice::None) => {
798 // For None, we still use Auto but will filter out tool calls in the response
799 BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
800 }
801 };
802 let tool_config: BedrockToolConfig = BedrockToolConfig::builder()
803 .set_tools(Some(tool_spec))
804 .tool_choice(tool_choice)
805 .build()?;
806
807 Ok(bedrock::Request {
808 model,
809 messages: new_messages,
810 max_tokens: max_output_tokens,
811 system: Some(system_message),
812 tools: Some(tool_config),
813 thinking: if request.thinking_allowed
814 && let BedrockModelMode::Thinking { budget_tokens } = mode
815 {
816 Some(bedrock::Thinking::Enabled { budget_tokens })
817 } else {
818 None
819 },
820 metadata: None,
821 stop_sequences: Vec::new(),
822 temperature: request.temperature.or(Some(default_temperature)),
823 top_k: None,
824 top_p: None,
825 })
826}
827
828// TODO: just call the ConverseOutput.usage() method:
829// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output
830pub fn get_bedrock_tokens(
831 request: LanguageModelRequest,
832 cx: &App,
833) -> BoxFuture<'static, Result<u64>> {
834 cx.background_executor()
835 .spawn(async move {
836 let messages = request.messages;
837 let mut tokens_from_images = 0;
838 let mut string_messages = Vec::with_capacity(messages.len());
839
840 for message in messages {
841 use language_model::MessageContent;
842
843 let mut string_contents = String::new();
844
845 for content in message.content {
846 match content {
847 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
848 string_contents.push_str(&text);
849 }
850 MessageContent::RedactedThinking(_) => {}
851 MessageContent::Image(image) => {
852 tokens_from_images += image.estimate_tokens();
853 }
854 MessageContent::ToolUse(_tool_use) => {
855 // TODO: Estimate token usage from tool uses.
856 }
857 MessageContent::ToolResult(tool_result) => match tool_result.content {
858 LanguageModelToolResultContent::Text(text) => {
859 string_contents.push_str(&text);
860 }
861 LanguageModelToolResultContent::Image(image) => {
862 tokens_from_images += image.estimate_tokens();
863 }
864 },
865 }
866 }
867
868 if !string_contents.is_empty() {
869 string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
870 role: match message.role {
871 Role::User => "user".into(),
872 Role::Assistant => "assistant".into(),
873 Role::System => "system".into(),
874 },
875 content: Some(string_contents),
876 name: None,
877 function_call: None,
878 });
879 }
880 }
881
882 // Tiktoken doesn't yet support these models, so we manually use the
883 // same tokenizer as GPT-4.
884 tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
885 .map(|tokens| (tokens + tokens_from_images) as u64)
886 })
887 .boxed()
888}
889
890pub fn map_to_language_model_completion_events(
891 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
892) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
893 struct RawToolUse {
894 id: String,
895 name: String,
896 input_json: String,
897 }
898
899 struct State {
900 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
901 tool_uses_by_index: HashMap<i32, RawToolUse>,
902 }
903
904 let initial_state = State {
905 events,
906 tool_uses_by_index: HashMap::default(),
907 };
908
909 futures::stream::unfold(initial_state, |mut state| async move {
910 match state.events.next().await {
911 Some(event_result) => match event_result {
912 Ok(event) => {
913 let result = match event {
914 ConverseStreamOutput::ContentBlockDelta(cb_delta) => match cb_delta.delta {
915 Some(ContentBlockDelta::Text(text)) => {
916 Some(Ok(LanguageModelCompletionEvent::Text(text)))
917 }
918 Some(ContentBlockDelta::ToolUse(tool_output)) => {
919 if let Some(tool_use) = state
920 .tool_uses_by_index
921 .get_mut(&cb_delta.content_block_index)
922 {
923 tool_use.input_json.push_str(tool_output.input());
924 }
925 None
926 }
927 Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking {
928 ReasoningContentBlockDelta::Text(thoughts) => {
929 Some(Ok(LanguageModelCompletionEvent::Thinking {
930 text: thoughts,
931 signature: None,
932 }))
933 }
934 ReasoningContentBlockDelta::Signature(sig) => {
935 Some(Ok(LanguageModelCompletionEvent::Thinking {
936 text: "".into(),
937 signature: Some(sig),
938 }))
939 }
940 ReasoningContentBlockDelta::RedactedContent(redacted) => {
941 let content = String::from_utf8(redacted.into_inner())
942 .unwrap_or("REDACTED".to_string());
943 Some(Ok(LanguageModelCompletionEvent::Thinking {
944 text: content,
945 signature: None,
946 }))
947 }
948 _ => None,
949 },
950 _ => None,
951 },
952 ConverseStreamOutput::ContentBlockStart(cb_start) => {
953 if let Some(ContentBlockStart::ToolUse(tool_start)) = cb_start.start {
954 state.tool_uses_by_index.insert(
955 cb_start.content_block_index,
956 RawToolUse {
957 id: tool_start.tool_use_id,
958 name: tool_start.name,
959 input_json: String::new(),
960 },
961 );
962 }
963 None
964 }
965 ConverseStreamOutput::ContentBlockStop(cb_stop) => state
966 .tool_uses_by_index
967 .remove(&cb_stop.content_block_index)
968 .map(|tool_use| {
969 let input = if tool_use.input_json.is_empty() {
970 Value::Null
971 } else {
972 serde_json::Value::from_str(&tool_use.input_json)
973 .unwrap_or(Value::Null)
974 };
975
976 Ok(LanguageModelCompletionEvent::ToolUse(
977 LanguageModelToolUse {
978 id: tool_use.id.into(),
979 name: tool_use.name.into(),
980 is_input_complete: true,
981 raw_input: tool_use.input_json,
982 input,
983 thought_signature: None,
984 },
985 ))
986 }),
987 ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
988 Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
989 input_tokens: metadata.input_tokens as u64,
990 output_tokens: metadata.output_tokens as u64,
991 cache_creation_input_tokens: metadata
992 .cache_write_input_tokens
993 .unwrap_or_default()
994 as u64,
995 cache_read_input_tokens: metadata
996 .cache_read_input_tokens
997 .unwrap_or_default()
998 as u64,
999 }))
1000 }),
1001 ConverseStreamOutput::MessageStop(message_stop) => {
1002 let stop_reason = match message_stop.stop_reason {
1003 StopReason::ToolUse => language_model::StopReason::ToolUse,
1004 _ => language_model::StopReason::EndTurn,
1005 };
1006 Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason)))
1007 }
1008 _ => None,
1009 };
1010
1011 Some((result, state))
1012 }
1013 Err(err) => Some((
1014 Some(Err(LanguageModelCompletionError::Other(anyhow!(err)))),
1015 state,
1016 )),
1017 },
1018 None => None,
1019 }
1020 })
1021 .filter_map(|result| async move { result })
1022}
1023
/// Settings UI for entering static Bedrock credentials.
struct ConfigurationView {
    /// Input for the AWS access key id.
    access_key_id_editor: Entity<InputField>,
    /// Input for the AWS secret access key.
    secret_access_key_editor: Entity<InputField>,
    /// Input for the optional session token.
    session_token_editor: Entity<InputField>,
    /// Input for the AWS region.
    region_editor: Entity<InputField>,
    /// Provider state shared with the models; the view re-renders when it changes.
    state: Entity<State>,
    /// `Some` while the initial authentication attempt started in `new` is running;
    /// reset to `None` once it finishes.
    load_credentials_task: Option<Task<()>>,
    focus_handle: FocusHandle,
}
1033
1034impl ConfigurationView {
1035 const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
1036 const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
1037 "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
1038 const PLACEHOLDER_SESSION_TOKEN_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
1039 const PLACEHOLDER_REGION: &'static str = "us-east-1";
1040
1041 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
1042 let focus_handle = cx.focus_handle();
1043
1044 cx.observe(&state, |_, _, cx| {
1045 cx.notify();
1046 })
1047 .detach();
1048
1049 let access_key_id_editor = cx.new(|cx| {
1050 InputField::new(window, cx, Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT)
1051 .label("Access Key ID")
1052 .tab_index(0)
1053 .tab_stop(true)
1054 });
1055
1056 let secret_access_key_editor = cx.new(|cx| {
1057 InputField::new(window, cx, Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT)
1058 .label("Secret Access Key")
1059 .tab_index(1)
1060 .tab_stop(true)
1061 });
1062
1063 let session_token_editor = cx.new(|cx| {
1064 InputField::new(window, cx, Self::PLACEHOLDER_SESSION_TOKEN_TEXT)
1065 .label("Session Token (Optional)")
1066 .tab_index(2)
1067 .tab_stop(true)
1068 });
1069
1070 let region_editor = cx.new(|cx| {
1071 InputField::new(window, cx, Self::PLACEHOLDER_REGION)
1072 .label("Region")
1073 .tab_index(3)
1074 .tab_stop(true)
1075 });
1076
1077 let load_credentials_task = Some(cx.spawn({
1078 let state = state.clone();
1079 async move |this, cx| {
1080 if let Some(task) = state
1081 .update(cx, |state, cx| state.authenticate(cx))
1082 .log_err()
1083 {
1084 // We don't log an error, because "not signed in" is also an error.
1085 let _ = task.await;
1086 }
1087 this.update(cx, |this, cx| {
1088 this.load_credentials_task = None;
1089 cx.notify();
1090 })
1091 .log_err();
1092 }
1093 }));
1094
1095 Self {
1096 access_key_id_editor,
1097 secret_access_key_editor,
1098 session_token_editor,
1099 region_editor,
1100 state,
1101 load_credentials_task,
1102 focus_handle,
1103 }
1104 }
1105
1106 fn save_credentials(
1107 &mut self,
1108 _: &menu::Confirm,
1109 _window: &mut Window,
1110 cx: &mut Context<Self>,
1111 ) {
1112 let access_key_id = self
1113 .access_key_id_editor
1114 .read(cx)
1115 .text(cx)
1116 .trim()
1117 .to_string();
1118 let secret_access_key = self
1119 .secret_access_key_editor
1120 .read(cx)
1121 .text(cx)
1122 .trim()
1123 .to_string();
1124 let session_token = self
1125 .session_token_editor
1126 .read(cx)
1127 .text(cx)
1128 .trim()
1129 .to_string();
1130 let session_token = if session_token.is_empty() {
1131 None
1132 } else {
1133 Some(session_token)
1134 };
1135 let region = self.region_editor.read(cx).text(cx).trim().to_string();
1136 let region = if region.is_empty() {
1137 "us-east-1".to_string()
1138 } else {
1139 region
1140 };
1141
1142 let state = self.state.clone();
1143 cx.spawn(async move |_, cx| {
1144 state
1145 .update(cx, |state, cx| {
1146 let credentials: BedrockCredentials = BedrockCredentials {
1147 region: region.clone(),
1148 access_key_id: access_key_id.clone(),
1149 secret_access_key: secret_access_key.clone(),
1150 session_token: session_token.clone(),
1151 };
1152
1153 state.set_credentials(credentials, cx)
1154 })?
1155 .await
1156 })
1157 .detach_and_log_err(cx);
1158 }
1159
1160 fn reset_credentials(&mut self, window: &mut Window, cx: &mut Context<Self>) {
1161 self.access_key_id_editor
1162 .update(cx, |editor, cx| editor.set_text("", window, cx));
1163 self.secret_access_key_editor
1164 .update(cx, |editor, cx| editor.set_text("", window, cx));
1165 self.session_token_editor
1166 .update(cx, |editor, cx| editor.set_text("", window, cx));
1167 self.region_editor
1168 .update(cx, |editor, cx| editor.set_text("", window, cx));
1169
1170 let state = self.state.clone();
1171 cx.spawn(async move |_, cx| {
1172 state
1173 .update(cx, |state, cx| state.reset_credentials(cx))?
1174 .await
1175 })
1176 .detach_and_log_err(cx);
1177 }
1178
1179 fn should_render_editor(&self, cx: &Context<Self>) -> bool {
1180 self.state.read(cx).is_authenticated()
1181 }
1182
1183 fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, _: &mut Context<Self>) {
1184 window.focus_next();
1185 }
1186
1187 fn on_tab_prev(
1188 &mut self,
1189 _: &menu::SelectPrevious,
1190 window: &mut Window,
1191 _: &mut Context<Self>,
1192 ) {
1193 window.focus_prev();
1194 }
1195}
1196
1197impl Render for ConfigurationView {
1198 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1199 let env_var_set = self.state.read(cx).credentials_from_env;
1200 let bedrock_settings = self.state.read(cx).settings.as_ref();
1201 let bedrock_method = bedrock_settings
1202 .as_ref()
1203 .and_then(|s| s.authentication_method.clone());
1204
1205 if self.load_credentials_task.is_some() {
1206 return div().child(Label::new("Loading credentials...")).into_any();
1207 }
1208
1209 let configured_label = if env_var_set {
1210 format!(
1211 "Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables."
1212 )
1213 } else {
1214 match bedrock_method {
1215 Some(BedrockAuthMethod::Automatic) => "You are using automatic credentials.".into(),
1216 Some(BedrockAuthMethod::NamedProfile) => "You are using named profile.".into(),
1217 Some(BedrockAuthMethod::SingleSignOn) => {
1218 "You are using a single sign on profile.".into()
1219 }
1220 None => "You are using static credentials.".into(),
1221 }
1222 };
1223
1224 let tooltip_label = if env_var_set {
1225 Some(format!(
1226 "To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables."
1227 ))
1228 } else if bedrock_method.is_some() {
1229 Some("You cannot reset credentials as they're being derived, check Zed settings to understand how.".to_string())
1230 } else {
1231 None
1232 };
1233
1234 if self.should_render_editor(cx) {
1235 return ConfiguredApiCard::new(configured_label)
1236 .disabled(env_var_set || bedrock_method.is_some())
1237 .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx)))
1238 .when_some(tooltip_label, |this, label| this.tooltip_label(label))
1239 .into_any_element();
1240 }
1241
1242 v_flex()
1243 .size_full()
1244 .track_focus(&self.focus_handle)
1245 .on_action(cx.listener(Self::on_tab))
1246 .on_action(cx.listener(Self::on_tab_prev))
1247 .on_action(cx.listener(ConfigurationView::save_credentials))
1248 .child(Label::new("To use Zed's agent with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials."))
1249 .child(Label::new("But, to access models on AWS, you need to:").mt_1())
1250 .child(
1251 List::new()
1252 .child(
1253 InstructionListItem::new(
1254 "Grant permissions to the strategy you'll use according to the:",
1255 Some("Prerequisites"),
1256 Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
1257 )
1258 )
1259 .child(
1260 InstructionListItem::new(
1261 "Select the models you would like access to:",
1262 Some("Bedrock Model Catalog"),
1263 Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess"),
1264 )
1265 )
1266 )
1267 .child(self.render_static_credentials_ui())
1268 .child(
1269 Label::new(
1270 format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR} AND {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."),
1271 )
1272 .size(LabelSize::Small)
1273 .color(Color::Muted)
1274 .my_1(),
1275 )
1276 .child(
1277 Label::new(
1278 format!("Optionally, if your environment uses AWS CLI profiles, you can set {ZED_AWS_PROFILE_VAR}; if it requires a custom endpoint, you can set {ZED_AWS_ENDPOINT_VAR}; and if it requires a Session Token, you can set {ZED_BEDROCK_SESSION_TOKEN_VAR}."),
1279 )
1280 .size(LabelSize::Small)
1281 .color(Color::Muted),
1282 )
1283 .into_any()
1284 }
1285}
1286
1287impl ConfigurationView {
1288 fn render_static_credentials_ui(&self) -> impl IntoElement {
1289 v_flex()
1290 .my_2()
1291 .tab_group()
1292 .gap_1p5()
1293 .child(
1294 Label::new("Static Keys")
1295 .size(LabelSize::Default)
1296 .weight(FontWeight::BOLD),
1297 )
1298 .child(
1299 Label::new(
1300 "This method uses your AWS access key ID and secret access key directly.",
1301 )
1302 )
1303 .child(
1304 List::new()
1305 .child(InstructionListItem::new(
1306 "Create an IAM user in the AWS console with programmatic access",
1307 Some("IAM Console"),
1308 Some("https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users"),
1309 ))
1310 .child(InstructionListItem::new(
1311 "Attach the necessary Bedrock permissions to this ",
1312 Some("user"),
1313 Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
1314 ))
1315 .child(InstructionListItem::text_only(
1316 "Copy the access key ID and secret access key when provided",
1317 ))
1318 .child(InstructionListItem::text_only(
1319 "Enter these credentials below",
1320 )),
1321 )
1322 .child(self.access_key_id_editor.clone())
1323 .child(self.secret_access_key_editor.clone())
1324 .child(self.session_token_editor.clone())
1325 .child(self.region_editor.clone())
1326 }
1327}