1use std::pin::Pin;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use crate::ui::InstructionListItem;
6use anyhow::{Context as _, Result, anyhow};
7use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
8use aws_config::{BehaviorVersion, Region};
9use aws_credential_types::Credentials;
10use aws_http_client::AwsHttpClient;
11use bedrock::bedrock_client::Client as BedrockClient;
12use bedrock::bedrock_client::config::timeout::TimeoutConfig;
13use bedrock::bedrock_client::types::{
14 ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
15 StopReason,
16};
17use bedrock::{
18 BedrockAutoToolChoice, BedrockError, BedrockInnerContent, BedrockMessage, BedrockModelMode,
19 BedrockStreamingResponse, BedrockTool, BedrockToolChoice, BedrockToolConfig,
20 BedrockToolInputSchema, BedrockToolResultBlock, BedrockToolResultContentBlock,
21 BedrockToolResultStatus, BedrockToolSpec, BedrockToolUseBlock, Model, value_to_aws_document,
22};
23use collections::{BTreeMap, HashMap};
24use credentials_provider::CredentialsProvider;
25use editor::{Editor, EditorElement, EditorStyle};
26use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
27use gpui::{
28 AnyView, App, AsyncApp, Context, Entity, FontStyle, FontWeight, Subscription, Task, TextStyle,
29 WhiteSpace,
30};
31use gpui_tokio::Tokio;
32use http_client::HttpClient;
33use language_model::{
34 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
35 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
36 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
37 LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
38 RateLimiter, Role, TokenUsage,
39};
40use schemars::JsonSchema;
41use serde::{Deserialize, Serialize};
42use serde_json::Value;
43use settings::{Settings, SettingsStore};
44use smol::lock::OnceCell;
45use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
46use theme::ThemeSettings;
47use tokio::runtime::Handle;
48use ui::{Icon, IconName, List, Tooltip, prelude::*};
49use util::{ResultExt, default};
50
51use crate::AllLanguageModelSettings;
52
/// Stable identifier for this provider, returned from `id()` / `provider_id()`.
const PROVIDER_ID: &str = "amazon-bedrock";
/// Human-readable provider name, returned from `name()` / `provider_name()`.
const PROVIDER_NAME: &str = "Amazon Bedrock";
55
/// Static AWS credentials for Bedrock.
///
/// Stored as JSON in the system keychain by `State::set_credentials`, and read
/// back as JSON either from the keychain or from the `ZED_AWS_CREDENTIALS`
/// environment variable by `State::authenticate`.
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
pub struct BedrockCredentials {
    /// AWS access key ID.
    pub access_key_id: String,
    /// AWS secret access key paired with `access_key_id`.
    pub secret_access_key: String,
    /// Optional session token; the configuration view stores `None` when the
    /// field is left blank.
    pub session_token: Option<String>,
    /// AWS region entered with these credentials; the configuration view
    /// stores `us-east-1` when the field is left blank.
    pub region: String,
}
63
/// Bedrock-specific settings, read from `AllLanguageModelSettings::bedrock`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AmazonBedrockSettings {
    /// Extra models exposed in addition to the built-in `bedrock::Model` list;
    /// an entry whose name matches a built-in id replaces it.
    pub available_models: Vec<AvailableModel>,
    /// AWS region for the Bedrock runtime client.
    pub region: Option<String>,
    /// Custom endpoint URL for the Bedrock runtime client; ignored when empty.
    pub endpoint: Option<String>,
    /// AWS profile used by the named-profile and SSO authentication methods.
    pub profile_name: Option<String>,
    // NOTE(review): `role_arn` is not read anywhere in the visible part of this
    // file — confirm whether it is used elsewhere.
    pub role_arn: Option<String>,
    /// How the Bedrock client should obtain AWS credentials.
    pub authentication_method: Option<BedrockAuthMethod>,
}
73
// How the Bedrock client resolves AWS credentials. Serialized in settings as
// the snake_case names given by the `serde(rename)` attributes below.
// (Plain `//` comments are used here so the generated JSON schema is unchanged.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, EnumIter, IntoStaticStr, JsonSchema)]
pub enum BedrockAuthMethod {
    // Use the AWS profile named by `profile_name` in settings.
    #[serde(rename = "named_profile")]
    NamedProfile,
    // Use the access key / secret key pair held in `State::credentials`.
    #[serde(rename = "static_credentials")]
    StaticCredentials,
    // Currently configured exactly like `NamedProfile`; only the instructions differ.
    #[serde(rename = "sso")]
    SingleSignOn,
    /// IMDSv2, PodIdentity, env vars, etc.
    #[serde(rename = "default")]
    Automatic,
}
86
// A custom Bedrock model declared in settings. `provided_models` turns each
// entry into a `bedrock::Model::Custom` keyed by `name`.
// NOTE(review): `cache_configuration` and `mode` are not forwarded to
// `bedrock::Model::Custom` by `provided_models` — confirm whether they are used
// elsewhere. (Plain `//` comments keep the generated JSON schema unchanged.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Bedrock model id; also used as the key that can override a built-in model.
    pub name: String,
    // Optional name shown in the UI.
    pub display_name: Option<String>,
    // Maximum token count reported for this model.
    pub max_tokens: usize,
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    // Maximum number of output tokens per request.
    pub max_output_tokens: Option<u32>,
    // Temperature used when a request does not specify one.
    pub default_temperature: Option<f32>,
    pub mode: Option<ModelMode>,
}
97
// Settings-level model mode, tagged in JSON by a lowercase `type` field
// (`"default"` or `"thinking"`). Convertible to and from `BedrockModelMode`.
// (Plain `//` comments keep the generated JSON schema unchanged.)
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    // Regular completion without extended reasoning.
    #[default]
    Default,
    // Extended reasoning; `into_bedrock` turns this into `bedrock::Thinking::Enabled`.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u64>,
    },
}
108
109impl From<ModelMode> for BedrockModelMode {
110 fn from(value: ModelMode) -> Self {
111 match value {
112 ModelMode::Default => BedrockModelMode::Default,
113 ModelMode::Thinking { budget_tokens } => BedrockModelMode::Thinking { budget_tokens },
114 }
115 }
116}
117
118impl From<BedrockModelMode> for ModelMode {
119 fn from(value: BedrockModelMode) -> Self {
120 match value {
121 BedrockModelMode::Default => ModelMode::Default,
122 BedrockModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
123 }
124 }
125}
126
/// The URL of the base AWS service.
///
/// Right now we're just using this as the key to store the AWS credentials
/// under in the keychain.
const AMAZON_AWS_URL: &str = "https://amazonaws.com";

// These environment variables all use a `ZED_` prefix so they never collide
// with, or overwrite, the standard AWS credential variables in the user's
// environment.
// NOTE(review): only `ZED_AWS_CREDENTIALS_VAR` is read in the visible part of
// this file; the others are presumably read elsewhere — confirm.

/// Access key ID.
const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID";
/// Secret access key.
const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY";
/// Session token.
const ZED_BEDROCK_SESSION_TOKEN_VAR: &str = "ZED_SESSION_TOKEN";
/// AWS profile name.
const ZED_AWS_PROFILE_VAR: &str = "ZED_AWS_PROFILE";
/// AWS region.
const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION";
/// JSON-encoded `BedrockCredentials`; when set, `State::authenticate` uses it
/// instead of the keychain.
const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS";
/// Custom Bedrock endpoint URL.
const ZED_AWS_ENDPOINT_VAR: &str = "ZED_AWS_ENDPOINT";
141
/// Shared authentication and configuration state for the Bedrock provider.
///
/// Observed by the provider (via `LanguageModelProviderState`) and by the
/// configuration view.
pub struct State {
    /// Static credentials loaded from `ZED_AWS_CREDENTIALS` or the keychain, or
    /// entered in the configuration view.
    credentials: Option<BedrockCredentials>,
    /// Bedrock settings as seen by this provider; `None` after `reset_credentials`.
    settings: Option<AmazonBedrockSettings>,
    /// `true` when `credentials` came from the `ZED_AWS_CREDENTIALS` environment variable.
    credentials_from_env: bool,
    /// Keeps the settings-store observer alive for as long as this state exists.
    _subscription: Subscription,
}
148
149impl State {
150 fn reset_credentials(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
151 let credentials_provider = <dyn CredentialsProvider>::global(cx);
152 cx.spawn(async move |this, cx| {
153 credentials_provider
154 .delete_credentials(AMAZON_AWS_URL, &cx)
155 .await
156 .log_err();
157 this.update(cx, |this, cx| {
158 this.credentials = None;
159 this.credentials_from_env = false;
160 this.settings = None;
161 cx.notify();
162 })
163 })
164 }
165
166 fn set_credentials(
167 &mut self,
168 credentials: BedrockCredentials,
169 cx: &mut Context<Self>,
170 ) -> Task<Result<()>> {
171 let credentials_provider = <dyn CredentialsProvider>::global(cx);
172 cx.spawn(async move |this, cx| {
173 credentials_provider
174 .write_credentials(
175 AMAZON_AWS_URL,
176 "Bearer",
177 &serde_json::to_vec(&credentials)?,
178 &cx,
179 )
180 .await?;
181 this.update(cx, |this, cx| {
182 this.credentials = Some(credentials);
183 cx.notify();
184 })
185 })
186 }
187
188 fn is_authenticated(&self) -> Option<String> {
189 match self
190 .settings
191 .as_ref()
192 .and_then(|s| s.authentication_method.as_ref())
193 {
194 Some(BedrockAuthMethod::StaticCredentials) => Some(String::from(
195 "You are authenticated using Static Credentials.",
196 )),
197 Some(BedrockAuthMethod::NamedProfile) | Some(BedrockAuthMethod::SingleSignOn) => {
198 match self.settings.as_ref() {
199 None => Some(String::from(
200 "You are authenticated using a Named Profile, but no profile is set.",
201 )),
202 Some(settings) => match settings.clone().profile_name {
203 None => Some(String::from(
204 "You are authenticated using a Named Profile, but no profile is set.",
205 )),
206 Some(profile_name) => Some(format!(
207 "You are authenticated using a Named Profile: {profile_name}",
208 )),
209 },
210 }
211 }
212 Some(BedrockAuthMethod::Automatic) => Some(String::from(
213 "You are authenticated using Automatic Credentials.",
214 )),
215 None => {
216 if self.credentials.is_some() {
217 Some(String::from(
218 "You are authenticated using Static Credentials.",
219 ))
220 } else {
221 None
222 }
223 }
224 }
225 }
226
227 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
228 if self.is_authenticated().is_some() {
229 return Task::ready(Ok(()));
230 }
231
232 let credentials_provider = <dyn CredentialsProvider>::global(cx);
233 cx.spawn(async move |this, cx| {
234 let (credentials, from_env) =
235 if let Ok(credentials) = std::env::var(ZED_AWS_CREDENTIALS_VAR) {
236 (credentials, true)
237 } else {
238 let (_, credentials) = credentials_provider
239 .read_credentials(AMAZON_AWS_URL, &cx)
240 .await?
241 .ok_or_else(|| AuthenticateError::CredentialsNotFound)?;
242 (
243 String::from_utf8(credentials)
244 .context("invalid {PROVIDER_NAME} credentials")?,
245 false,
246 )
247 };
248
249 let credentials: BedrockCredentials =
250 serde_json::from_str(&credentials).context("failed to parse credentials")?;
251
252 this.update(cx, |this, cx| {
253 this.credentials = Some(credentials);
254 this.credentials_from_env = from_env;
255 cx.notify();
256 })?;
257
258 Ok(())
259 })
260 }
261}
262
/// The Amazon Bedrock language-model provider.
pub struct BedrockLanguageModelProvider {
    /// HTTP client adapted for the AWS SDK, shared by every model this provider creates.
    http_client: AwsHttpClient,
    /// Handle to the Tokio runtime the AWS SDK runs on.
    handler: tokio::runtime::Handle,
    /// Shared authentication / settings state.
    state: gpui::Entity<State>,
}
268
269impl BedrockLanguageModelProvider {
270 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
271 let state = cx.new(|cx| State {
272 credentials: None,
273 settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()),
274 credentials_from_env: false,
275 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
276 cx.notify();
277 }),
278 });
279
280 let tokio_handle = Tokio::handle(cx);
281
282 let coerced_client = AwsHttpClient::new(http_client.clone(), tokio_handle.clone());
283
284 Self {
285 http_client: coerced_client,
286 handler: tokio_handle.clone(),
287 state,
288 }
289 }
290
291 fn create_language_model(&self, model: bedrock::Model) -> Arc<dyn LanguageModel> {
292 Arc::new(BedrockModel {
293 id: LanguageModelId::from(model.id().to_string()),
294 model,
295 http_client: self.http_client.clone(),
296 handler: self.handler.clone(),
297 state: self.state.clone(),
298 client: OnceCell::new(),
299 request_limiter: RateLimiter::new(4),
300 })
301 }
302}
303
304impl LanguageModelProvider for BedrockLanguageModelProvider {
305 fn id(&self) -> LanguageModelProviderId {
306 LanguageModelProviderId(PROVIDER_ID.into())
307 }
308
309 fn name(&self) -> LanguageModelProviderName {
310 LanguageModelProviderName(PROVIDER_NAME.into())
311 }
312
313 fn icon(&self) -> IconName {
314 IconName::AiBedrock
315 }
316
317 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
318 Some(self.create_language_model(bedrock::Model::default()))
319 }
320
321 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
322 Some(self.create_language_model(bedrock::Model::default_fast()))
323 }
324
325 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
326 let mut models = BTreeMap::default();
327
328 for model in bedrock::Model::iter() {
329 if !matches!(model, bedrock::Model::Custom { .. }) {
330 models.insert(model.id().to_string(), model);
331 }
332 }
333
334 // Override with available models from settings
335 for model in AllLanguageModelSettings::get_global(cx)
336 .bedrock
337 .available_models
338 .iter()
339 {
340 models.insert(
341 model.name.clone(),
342 bedrock::Model::Custom {
343 name: model.name.clone(),
344 display_name: model.display_name.clone(),
345 max_tokens: model.max_tokens,
346 max_output_tokens: model.max_output_tokens,
347 default_temperature: model.default_temperature,
348 },
349 );
350 }
351
352 models
353 .into_values()
354 .map(|model| self.create_language_model(model))
355 .collect()
356 }
357
358 fn is_authenticated(&self, cx: &App) -> bool {
359 self.state.read(cx).is_authenticated().is_some()
360 }
361
362 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
363 self.state.update(cx, |state, cx| state.authenticate(cx))
364 }
365
366 fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
367 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
368 .into()
369 }
370
371 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
372 self.state
373 .update(cx, |state, cx| state.reset_credentials(cx))
374 }
375}
376
377impl LanguageModelProviderState for BedrockLanguageModelProvider {
378 type ObservableEntity = State;
379
380 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
381 Some(self.state.clone())
382 }
383}
384
/// A single Bedrock model exposed through the [`LanguageModel`] trait.
struct BedrockModel {
    /// Id reported to Zed; derived from `model.id()`.
    id: LanguageModelId,
    /// The underlying Bedrock model description.
    model: Model,
    /// HTTP client handed to the AWS SDK config.
    http_client: AwsHttpClient,
    /// Tokio runtime handle used to load the AWS config and drive requests.
    handler: tokio::runtime::Handle,
    /// Bedrock runtime client, built on the first request and then reused.
    client: OnceCell<BedrockClient>,
    /// Shared provider state (credentials and settings).
    state: gpui::Entity<State>,
    /// Wraps every completion request; presumably caps concurrent requests
    /// (created with `RateLimiter::new(4)`) — confirm against `RateLimiter`.
    request_limiter: RateLimiter,
}
394
395impl BedrockModel {
396 fn get_or_init_client(&self, cx: &AsyncApp) -> Result<&BedrockClient, anyhow::Error> {
397 self.client
398 .get_or_try_init_blocking(|| {
399 let Ok((auth_method, credentials, endpoint, region, settings)) =
400 cx.read_entity(&self.state, |state, _cx| {
401 let auth_method = state
402 .settings
403 .as_ref()
404 .and_then(|s| s.authentication_method.clone())
405 .unwrap_or(BedrockAuthMethod::Automatic);
406
407 let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone());
408
409 let region = state
410 .settings
411 .as_ref()
412 .and_then(|s| s.region.clone())
413 .unwrap_or(String::from("us-east-1"));
414
415 (
416 auth_method,
417 state.credentials.clone(),
418 endpoint,
419 region,
420 state.settings.clone(),
421 )
422 })
423 else {
424 return Err(anyhow!("App state dropped"));
425 };
426
427 let mut config_builder = aws_config::defaults(BehaviorVersion::latest())
428 .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
429 .http_client(self.http_client.clone())
430 .region(Region::new(region))
431 .timeout_config(TimeoutConfig::disabled());
432
433 if let Some(endpoint_url) = endpoint {
434 if !endpoint_url.is_empty() {
435 config_builder = config_builder.endpoint_url(endpoint_url);
436 }
437 }
438
439 match auth_method {
440 BedrockAuthMethod::StaticCredentials => {
441 if let Some(creds) = credentials {
442 let aws_creds = Credentials::new(
443 creds.access_key_id,
444 creds.secret_access_key,
445 creds.session_token,
446 None,
447 "zed-bedrock-provider",
448 );
449 config_builder = config_builder.credentials_provider(aws_creds);
450 }
451 }
452 BedrockAuthMethod::NamedProfile | BedrockAuthMethod::SingleSignOn => {
453 // Currently NamedProfile and SSO behave the same way but only the instructions change
454 // Until we support BearerAuth through SSO, this will not change.
455 let profile_name = settings
456 .and_then(|s| s.profile_name)
457 .unwrap_or_else(|| "default".to_string());
458
459 if !profile_name.is_empty() {
460 config_builder = config_builder.profile_name(profile_name);
461 }
462 }
463 BedrockAuthMethod::Automatic => {
464 // Use default credential provider chain
465 }
466 }
467
468 let config = self.handler.block_on(config_builder.load());
469 Ok(BedrockClient::new(&config))
470 })
471 .map_err(|err| anyhow!("Failed to initialize Bedrock client: {err}"))?;
472
473 self.client
474 .get()
475 .ok_or_else(|| anyhow!("Bedrock client not initialized"))
476 }
477
478 fn stream_completion(
479 &self,
480 request: bedrock::Request,
481 cx: &AsyncApp,
482 ) -> Result<
483 BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
484 > {
485 let runtime_client = self
486 .get_or_init_client(cx)
487 .cloned()
488 .context("Bedrock client not initialized")?;
489 let owned_handle = self.handler.clone();
490
491 Ok(async move {
492 let request = bedrock::stream_completion(runtime_client, request, owned_handle);
493 request.await.unwrap_or_else(|e| {
494 futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
495 })
496 }
497 .boxed())
498 }
499}
500
501impl LanguageModel for BedrockModel {
502 fn id(&self) -> LanguageModelId {
503 self.id.clone()
504 }
505
506 fn name(&self) -> LanguageModelName {
507 LanguageModelName::from(self.model.display_name().to_string())
508 }
509
510 fn provider_id(&self) -> LanguageModelProviderId {
511 LanguageModelProviderId(PROVIDER_ID.into())
512 }
513
514 fn provider_name(&self) -> LanguageModelProviderName {
515 LanguageModelProviderName(PROVIDER_NAME.into())
516 }
517
518 fn supports_tools(&self) -> bool {
519 self.model.supports_tool_use()
520 }
521
522 fn telemetry_id(&self) -> String {
523 format!("bedrock/{}", self.model.id())
524 }
525
526 fn max_token_count(&self) -> usize {
527 self.model.max_token_count()
528 }
529
530 fn max_output_tokens(&self) -> Option<u32> {
531 Some(self.model.max_output_tokens())
532 }
533
534 fn count_tokens(
535 &self,
536 request: LanguageModelRequest,
537 cx: &App,
538 ) -> BoxFuture<'static, Result<usize>> {
539 get_bedrock_tokens(request, cx)
540 }
541
542 fn stream_completion(
543 &self,
544 request: LanguageModelRequest,
545 cx: &AsyncApp,
546 ) -> BoxFuture<
547 'static,
548 Result<
549 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
550 >,
551 > {
552 let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
553 // Get region - from credentials or directly from settings
554 let region = state
555 .credentials
556 .as_ref()
557 .map(|s| s.region.clone())
558 .unwrap_or(String::from("us-east-1"));
559
560 region
561 }) else {
562 return async move { Err(anyhow!("App State Dropped")) }.boxed();
563 };
564
565 let model_id = match self.model.cross_region_inference_id(®ion) {
566 Ok(s) => s,
567 Err(e) => {
568 return async move { Err(e) }.boxed();
569 }
570 };
571
572 let request = match into_bedrock(
573 request,
574 model_id,
575 self.model.default_temperature(),
576 self.model.max_output_tokens(),
577 self.model.mode(),
578 ) {
579 Ok(request) => request,
580 Err(err) => return futures::future::ready(Err(err)).boxed(),
581 };
582
583 let owned_handle = self.handler.clone();
584
585 let request = self.stream_completion(request, cx);
586 let future = self.request_limiter.stream(async move {
587 let response = request.map_err(|err| anyhow!(err))?.await;
588 Ok(map_to_language_model_completion_events(
589 response,
590 owned_handle,
591 ))
592 });
593 async move { Ok(future.await?.boxed()) }.boxed()
594 }
595
596 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
597 None
598 }
599}
600
601pub fn into_bedrock(
602 request: LanguageModelRequest,
603 model: String,
604 default_temperature: f32,
605 max_output_tokens: u32,
606 mode: BedrockModelMode,
607) -> Result<bedrock::Request> {
608 let mut new_messages: Vec<BedrockMessage> = Vec::new();
609 let mut system_message = String::new();
610
611 for message in request.messages {
612 if message.contents_empty() {
613 continue;
614 }
615
616 match message.role {
617 Role::User | Role::Assistant => {
618 let bedrock_message_content: Vec<BedrockInnerContent> = message
619 .content
620 .into_iter()
621 .filter_map(|content| match content {
622 MessageContent::Text(text) => {
623 if !text.is_empty() {
624 Some(BedrockInnerContent::Text(text))
625 } else {
626 None
627 }
628 }
629 MessageContent::ToolUse(tool_use) => BedrockToolUseBlock::builder()
630 .name(tool_use.name.to_string())
631 .tool_use_id(tool_use.id.to_string())
632 .input(value_to_aws_document(&tool_use.input))
633 .build()
634 .context("failed to build Bedrock tool use block")
635 .log_err()
636 .map(BedrockInnerContent::ToolUse),
637 MessageContent::ToolResult(tool_result) => {
638 BedrockToolResultBlock::builder()
639 .tool_use_id(tool_result.tool_use_id.to_string())
640 .content(BedrockToolResultContentBlock::Text(
641 tool_result.content.to_string(),
642 ))
643 .status({
644 if tool_result.is_error {
645 BedrockToolResultStatus::Error
646 } else {
647 BedrockToolResultStatus::Success
648 }
649 })
650 .build()
651 .context("failed to build Bedrock tool result block")
652 .log_err()
653 .map(BedrockInnerContent::ToolResult)
654 }
655 _ => None,
656 })
657 .collect();
658 let bedrock_role = match message.role {
659 Role::User => bedrock::BedrockRole::User,
660 Role::Assistant => bedrock::BedrockRole::Assistant,
661 Role::System => unreachable!("System role should never occur here"),
662 };
663 if let Some(last_message) = new_messages.last_mut() {
664 if last_message.role == bedrock_role {
665 last_message.content.extend(bedrock_message_content);
666 continue;
667 }
668 }
669 new_messages.push(
670 BedrockMessage::builder()
671 .role(bedrock_role)
672 .set_content(Some(bedrock_message_content))
673 .build()
674 .context("failed to build Bedrock message")?,
675 );
676 }
677 Role::System => {
678 if !system_message.is_empty() {
679 system_message.push_str("\n\n");
680 }
681 system_message.push_str(&message.string_contents());
682 }
683 }
684 }
685
686 let tool_spec: Vec<BedrockTool> = request
687 .tools
688 .iter()
689 .filter_map(|tool| {
690 Some(BedrockTool::ToolSpec(
691 BedrockToolSpec::builder()
692 .name(tool.name.clone())
693 .description(tool.description.clone())
694 .input_schema(BedrockToolInputSchema::Json(value_to_aws_document(
695 &tool.input_schema,
696 )))
697 .build()
698 .log_err()?,
699 ))
700 })
701 .collect();
702
703 let tool_config: BedrockToolConfig = BedrockToolConfig::builder()
704 .set_tools(Some(tool_spec))
705 .tool_choice(BedrockToolChoice::Auto(
706 BedrockAutoToolChoice::builder().build(),
707 ))
708 .build()?;
709
710 Ok(bedrock::Request {
711 model,
712 messages: new_messages,
713 max_tokens: max_output_tokens,
714 system: Some(system_message),
715 tools: Some(tool_config),
716 thinking: if let BedrockModelMode::Thinking { budget_tokens } = mode {
717 Some(bedrock::Thinking::Enabled { budget_tokens })
718 } else {
719 None
720 },
721 metadata: None,
722 stop_sequences: Vec::new(),
723 temperature: request.temperature.or(Some(default_temperature)),
724 top_k: None,
725 top_p: None,
726 })
727}
728
729// TODO: just call the ConverseOutput.usage() method:
730// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output
731pub fn get_bedrock_tokens(
732 request: LanguageModelRequest,
733 cx: &App,
734) -> BoxFuture<'static, Result<usize>> {
735 cx.background_executor()
736 .spawn(async move {
737 let messages = request.messages;
738 let mut tokens_from_images = 0;
739 let mut string_messages = Vec::with_capacity(messages.len());
740
741 for message in messages {
742 use language_model::MessageContent;
743
744 let mut string_contents = String::new();
745
746 for content in message.content {
747 match content {
748 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
749 string_contents.push_str(&text);
750 }
751 MessageContent::RedactedThinking(_) => {}
752 MessageContent::Image(image) => {
753 tokens_from_images += image.estimate_tokens();
754 }
755 MessageContent::ToolUse(_tool_use) => {
756 // TODO: Estimate token usage from tool uses.
757 }
758 MessageContent::ToolResult(tool_result) => {
759 string_contents.push_str(&tool_result.content);
760 }
761 }
762 }
763
764 if !string_contents.is_empty() {
765 string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
766 role: match message.role {
767 Role::User => "user".into(),
768 Role::Assistant => "assistant".into(),
769 Role::System => "system".into(),
770 },
771 content: Some(string_contents),
772 name: None,
773 function_call: None,
774 });
775 }
776 }
777
778 // Tiktoken doesn't yet support these models, so we manually use the
779 // same tokenizer as GPT-4.
780 tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
781 .map(|tokens| tokens + tokens_from_images)
782 })
783 .boxed()
784}
785
786pub fn map_to_language_model_completion_events(
787 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
788 handle: Handle,
789) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
790 struct RawToolUse {
791 id: String,
792 name: String,
793 input_json: String,
794 }
795
796 struct State {
797 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
798 tool_uses_by_index: HashMap<i32, RawToolUse>,
799 }
800
801 futures::stream::unfold(
802 State {
803 events,
804 tool_uses_by_index: HashMap::default(),
805 },
806 move |mut state: State| {
807 let inner_handle = handle.clone();
808 async move {
809 inner_handle
810 .spawn(async {
811 while let Some(event) = state.events.next().await {
812 match event {
813 Ok(event) => match event {
814 ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
815 match cb_delta.delta {
816 Some(ContentBlockDelta::Text(text_out)) => {
817 let completion_event =
818 LanguageModelCompletionEvent::Text(text_out);
819 return Some((Some(Ok(completion_event)), state));
820 }
821
822 Some(ContentBlockDelta::ToolUse(text_out)) => {
823 if let Some(tool_use) = state
824 .tool_uses_by_index
825 .get_mut(&cb_delta.content_block_index)
826 {
827 tool_use.input_json.push_str(text_out.input());
828 }
829 }
830
831 Some(ContentBlockDelta::ReasoningContent(thinking)) => {
832 match thinking {
833 ReasoningContentBlockDelta::RedactedContent(
834 redacted,
835 ) => {
836 let thinking_event =
837 LanguageModelCompletionEvent::Thinking {
838 text: String::from_utf8(
839 redacted.into_inner(),
840 )
841 .unwrap_or("REDACTED".to_string()),
842 signature: None,
843 };
844
845 return Some((
846 Some(Ok(thinking_event)),
847 state,
848 ));
849 }
850 ReasoningContentBlockDelta::Signature(
851 signature,
852 ) => {
853 return Some((
854 Some(Ok(LanguageModelCompletionEvent::Thinking {
855 text: "".to_string(),
856 signature: Some(signature)
857 })),
858 state,
859 ));
860 }
861 ReasoningContentBlockDelta::Text(thoughts) => {
862 let thinking_event =
863 LanguageModelCompletionEvent::Thinking {
864 text: thoughts.to_string(),
865 signature: None
866 };
867
868 return Some((
869 Some(Ok(thinking_event)),
870 state,
871 ));
872 }
873 _ => {}
874 }
875 }
876 _ => {}
877 }
878 }
879 ConverseStreamOutput::ContentBlockStart(cb_start) => {
880 if let Some(ContentBlockStart::ToolUse(text_out)) =
881 cb_start.start
882 {
883 let tool_use = RawToolUse {
884 id: text_out.tool_use_id,
885 name: text_out.name,
886 input_json: String::new(),
887 };
888
889 state
890 .tool_uses_by_index
891 .insert(cb_start.content_block_index, tool_use);
892 }
893 }
894 ConverseStreamOutput::ContentBlockStop(cb_stop) => {
895 if let Some(tool_use) = state
896 .tool_uses_by_index
897 .remove(&cb_stop.content_block_index)
898 {
899 let tool_use_event = LanguageModelToolUse {
900 id: tool_use.id.into(),
901 name: tool_use.name.into(),
902 is_input_complete: true,
903 raw_input: tool_use.input_json.clone(),
904 input: if tool_use.input_json.is_empty() {
905 Value::Null
906 } else {
907 serde_json::Value::from_str(
908 &tool_use.input_json,
909 )
910 .map_err(|err| anyhow!(err))
911 .unwrap()
912 },
913 };
914
915 return Some((
916 Some(Ok(LanguageModelCompletionEvent::ToolUse(
917 tool_use_event,
918 ))),
919 state,
920 ));
921 }
922 }
923
924 ConverseStreamOutput::Metadata(cb_meta) => {
925 if let Some(metadata) = cb_meta.usage {
926 let completion_event =
927 LanguageModelCompletionEvent::UsageUpdate(
928 TokenUsage {
929 input_tokens: metadata.input_tokens as u32,
930 output_tokens: metadata.output_tokens
931 as u32,
932 cache_creation_input_tokens: default(),
933 cache_read_input_tokens: default(),
934 },
935 );
936 return Some((Some(Ok(completion_event)), state));
937 }
938 }
939 ConverseStreamOutput::MessageStop(message_stop) => {
940 let reason = match message_stop.stop_reason {
941 StopReason::ContentFiltered => {
942 LanguageModelCompletionEvent::Stop(
943 language_model::StopReason::EndTurn,
944 )
945 }
946 StopReason::EndTurn => {
947 LanguageModelCompletionEvent::Stop(
948 language_model::StopReason::EndTurn,
949 )
950 }
951 StopReason::GuardrailIntervened => {
952 LanguageModelCompletionEvent::Stop(
953 language_model::StopReason::EndTurn,
954 )
955 }
956 StopReason::MaxTokens => {
957 LanguageModelCompletionEvent::Stop(
958 language_model::StopReason::EndTurn,
959 )
960 }
961 StopReason::StopSequence => {
962 LanguageModelCompletionEvent::Stop(
963 language_model::StopReason::EndTurn,
964 )
965 }
966 StopReason::ToolUse => {
967 LanguageModelCompletionEvent::Stop(
968 language_model::StopReason::ToolUse,
969 )
970 }
971 _ => LanguageModelCompletionEvent::Stop(
972 language_model::StopReason::EndTurn,
973 ),
974 };
975 return Some((Some(Ok(reason)), state));
976 }
977 _ => {}
978 },
979
980 Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
981 }
982 }
983 None
984 })
985 .await
986 .log_err()
987 .flatten()
988 }
989 },
990 )
991 .filter_map(|event| async move { event })
992}
993
/// Settings UI for entering static Bedrock credentials.
struct ConfigurationView {
    /// Single-line editor for the AWS access key ID.
    access_key_id_editor: Entity<Editor>,
    /// Single-line editor for the AWS secret access key.
    secret_access_key_editor: Entity<Editor>,
    /// Single-line editor for the optional session token.
    session_token_editor: Entity<Editor>,
    /// Single-line editor for the region (blank means `us-east-1`).
    region_editor: Entity<Editor>,
    /// Shared provider state the view reads and updates.
    state: gpui::Entity<State>,
    /// Authentication attempt started when the view is created; set back to
    /// `None` once it finishes.
    load_credentials_task: Option<Task<()>>,
}
1002
1003impl ConfigurationView {
1004 const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
1005 const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
1006 "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
1007 const PLACEHOLDER_SESSION_TOKEN_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
1008 const PLACEHOLDER_REGION: &'static str = "us-east-1";
1009
1010 fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
1011 cx.observe(&state, |_, _, cx| {
1012 cx.notify();
1013 })
1014 .detach();
1015
1016 let load_credentials_task = Some(cx.spawn({
1017 let state = state.clone();
1018 async move |this, cx| {
1019 if let Some(task) = state
1020 .update(cx, |state, cx| state.authenticate(cx))
1021 .log_err()
1022 {
1023 // We don't log an error, because "not signed in" is also an error.
1024 let _ = task.await;
1025 }
1026 this.update(cx, |this, cx| {
1027 this.load_credentials_task = None;
1028 cx.notify();
1029 })
1030 .log_err();
1031 }
1032 }));
1033
1034 Self {
1035 access_key_id_editor: cx.new(|cx| {
1036 let mut editor = Editor::single_line(window, cx);
1037 editor.set_placeholder_text(Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT, cx);
1038 editor
1039 }),
1040 secret_access_key_editor: cx.new(|cx| {
1041 let mut editor = Editor::single_line(window, cx);
1042 editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx);
1043 editor
1044 }),
1045 session_token_editor: cx.new(|cx| {
1046 let mut editor = Editor::single_line(window, cx);
1047 editor.set_placeholder_text(Self::PLACEHOLDER_SESSION_TOKEN_TEXT, cx);
1048 editor
1049 }),
1050 region_editor: cx.new(|cx| {
1051 let mut editor = Editor::single_line(window, cx);
1052 editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx);
1053 editor
1054 }),
1055 state,
1056 load_credentials_task,
1057 }
1058 }
1059
1060 fn save_credentials(
1061 &mut self,
1062 _: &menu::Confirm,
1063 _window: &mut Window,
1064 cx: &mut Context<Self>,
1065 ) {
1066 let access_key_id = self
1067 .access_key_id_editor
1068 .read(cx)
1069 .text(cx)
1070 .to_string()
1071 .trim()
1072 .to_string();
1073 let secret_access_key = self
1074 .secret_access_key_editor
1075 .read(cx)
1076 .text(cx)
1077 .to_string()
1078 .trim()
1079 .to_string();
1080 let session_token = self
1081 .session_token_editor
1082 .read(cx)
1083 .text(cx)
1084 .to_string()
1085 .trim()
1086 .to_string();
1087 let session_token = if session_token.is_empty() {
1088 None
1089 } else {
1090 Some(session_token)
1091 };
1092 let region = self
1093 .region_editor
1094 .read(cx)
1095 .text(cx)
1096 .to_string()
1097 .trim()
1098 .to_string();
1099 let region = if region.is_empty() {
1100 "us-east-1".to_string()
1101 } else {
1102 region
1103 };
1104
1105 let state = self.state.clone();
1106 cx.spawn(async move |_, cx| {
1107 state
1108 .update(cx, |state, cx| {
1109 let credentials: BedrockCredentials = BedrockCredentials {
1110 region: region.clone(),
1111 access_key_id: access_key_id.clone(),
1112 secret_access_key: secret_access_key.clone(),
1113 session_token: session_token.clone(),
1114 };
1115
1116 state.set_credentials(credentials, cx)
1117 })?
1118 .await
1119 })
1120 .detach_and_log_err(cx);
1121 }
1122
1123 fn reset_credentials(&mut self, window: &mut Window, cx: &mut Context<Self>) {
1124 self.access_key_id_editor
1125 .update(cx, |editor, cx| editor.set_text("", window, cx));
1126 self.secret_access_key_editor
1127 .update(cx, |editor, cx| editor.set_text("", window, cx));
1128 self.session_token_editor
1129 .update(cx, |editor, cx| editor.set_text("", window, cx));
1130 self.region_editor
1131 .update(cx, |editor, cx| editor.set_text("", window, cx));
1132
1133 let state = self.state.clone();
1134 cx.spawn(async move |_, cx| {
1135 state
1136 .update(cx, |state, cx| state.reset_credentials(cx))?
1137 .await
1138 })
1139 .detach_and_log_err(cx);
1140 }
1141
1142 fn make_text_style(&self, cx: &Context<Self>) -> TextStyle {
1143 let settings = ThemeSettings::get_global(cx);
1144 TextStyle {
1145 color: cx.theme().colors().text,
1146 font_family: settings.ui_font.family.clone(),
1147 font_features: settings.ui_font.features.clone(),
1148 font_fallbacks: settings.ui_font.fallbacks.clone(),
1149 font_size: rems(0.875).into(),
1150 font_weight: settings.ui_font.weight,
1151 font_style: FontStyle::Normal,
1152 line_height: relative(1.3),
1153 background_color: None,
1154 underline: None,
1155 strikethrough: None,
1156 white_space: WhiteSpace::Normal,
1157 text_overflow: None,
1158 text_align: Default::default(),
1159 line_clamp: None,
1160 }
1161 }
1162
1163 fn make_input_styles(&self, cx: &Context<Self>) -> Div {
1164 let bg_color = cx.theme().colors().editor_background;
1165 let border_color = cx.theme().colors().border;
1166
1167 h_flex()
1168 .w_full()
1169 .px_2()
1170 .py_1()
1171 .bg(bg_color)
1172 .border_1()
1173 .border_color(border_color)
1174 .rounded_sm()
1175 }
1176
1177 fn should_render_editor(&self, cx: &mut Context<Self>) -> Option<String> {
1178 self.state.read(cx).is_authenticated()
1179 }
1180}
1181
1182impl Render for ConfigurationView {
1183 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1184 let env_var_set = self.state.read(cx).credentials_from_env;
1185 let creds_type = self.should_render_editor(cx).is_some();
1186
1187 if self.load_credentials_task.is_some() {
1188 return div().child(Label::new("Loading credentials...")).into_any();
1189 }
1190
1191 if let Some(auth) = self.should_render_editor(cx) {
1192 return h_flex()
1193 .mt_1()
1194 .p_1()
1195 .justify_between()
1196 .rounded_md()
1197 .border_1()
1198 .border_color(cx.theme().colors().border)
1199 .bg(cx.theme().colors().background)
1200 .child(
1201 h_flex()
1202 .gap_1()
1203 .child(Icon::new(IconName::Check).color(Color::Success))
1204 .child(Label::new(if env_var_set {
1205 format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
1206 } else {
1207 auth.clone()
1208 })),
1209 )
1210 .child(
1211 Button::new("reset-key", "Reset Key")
1212 .icon(Some(IconName::Trash))
1213 .icon_size(IconSize::Small)
1214 .icon_position(IconPosition::Start)
1215 // .disabled(env_var_set || creds_type)
1216 .when(env_var_set, |this| {
1217 this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
1218 })
1219 .when(creds_type, |this| {
1220 this.tooltip(Tooltip::text("You cannot reset credentials as they're being derived, check Zed settings to understand how."))
1221 })
1222 .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
1223 )
1224 .into_any();
1225 }
1226
1227 v_flex()
1228 .size_full()
1229 .on_action(cx.listener(ConfigurationView::save_credentials))
1230 .child(Label::new("To use Zed's assistant with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials."))
1231 .child(Label::new("But, to access models on AWS, you need to:").mt_1())
1232 .child(
1233 List::new()
1234 .child(
1235 InstructionListItem::new(
1236 "Grant permissions to the strategy you'll use according to the:",
1237 Some("Prerequisites"),
1238 Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
1239 )
1240 )
1241 .child(
1242 InstructionListItem::new(
1243 "Select the models you would like access to:",
1244 Some("Bedrock Model Catalog"),
1245 Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess"),
1246 )
1247 )
1248 )
1249 .child(self.render_static_credentials_ui(cx))
1250 .child(self.render_common_fields(cx))
1251 .child(
1252 Label::new(
1253 format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR} AND {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."),
1254 )
1255 .size(LabelSize::Small)
1256 .color(Color::Muted)
1257 .my_1(),
1258 )
1259 .child(
1260 Label::new(
1261 format!("Optionally, if your environment uses AWS CLI profiles, you can set {ZED_AWS_PROFILE_VAR}; if it requires a custom endpoint, you can set {ZED_AWS_ENDPOINT_VAR}; and if it requires a Session Token, you can set {ZED_BEDROCK_SESSION_TOKEN_VAR}."),
1262 )
1263 .size(LabelSize::Small)
1264 .color(Color::Muted),
1265 )
1266 .into_any()
1267 }
1268}
1269
1270impl ConfigurationView {
1271 fn render_access_key_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
1272 let text_style = self.make_text_style(cx);
1273
1274 EditorElement::new(
1275 &self.access_key_id_editor,
1276 EditorStyle {
1277 background: cx.theme().colors().editor_background,
1278 local_player: cx.theme().players().local(),
1279 text: text_style,
1280 ..Default::default()
1281 },
1282 )
1283 }
1284
1285 fn render_secret_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
1286 let text_style = self.make_text_style(cx);
1287
1288 EditorElement::new(
1289 &self.secret_access_key_editor,
1290 EditorStyle {
1291 background: cx.theme().colors().editor_background,
1292 local_player: cx.theme().players().local(),
1293 text: text_style,
1294 ..Default::default()
1295 },
1296 )
1297 }
1298
1299 fn render_session_token_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
1300 let text_style = self.make_text_style(cx);
1301
1302 EditorElement::new(
1303 &self.session_token_editor,
1304 EditorStyle {
1305 background: cx.theme().colors().editor_background,
1306 local_player: cx.theme().players().local(),
1307 text: text_style,
1308 ..Default::default()
1309 },
1310 )
1311 }
1312
1313 fn render_region_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
1314 let text_style = self.make_text_style(cx);
1315
1316 EditorElement::new(
1317 &self.region_editor,
1318 EditorStyle {
1319 background: cx.theme().colors().editor_background,
1320 local_player: cx.theme().players().local(),
1321 text: text_style,
1322 ..Default::default()
1323 },
1324 )
1325 }
1326
1327 fn render_static_credentials_ui(&self, cx: &mut Context<Self>) -> AnyElement {
1328 v_flex()
1329 .my_2()
1330 .gap_1p5()
1331 .child(
1332 Label::new("Static Keys")
1333 .size(LabelSize::Default)
1334 .weight(FontWeight::BOLD),
1335 )
1336 .child(
1337 Label::new(
1338 "This method uses your AWS access key ID and secret access key directly.",
1339 )
1340 )
1341 .child(
1342 List::new()
1343 .child(InstructionListItem::new(
1344 "Create an IAM user in the AWS console with programmatic access",
1345 Some("IAM Console"),
1346 Some("https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users"),
1347 ))
1348 .child(InstructionListItem::new(
1349 "Attach the necessary Bedrock permissions to this ",
1350 Some("user"),
1351 Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
1352 ))
1353 .child(InstructionListItem::text_only(
1354 "Copy the access key ID and secret access key when provided",
1355 ))
1356 .child(InstructionListItem::text_only(
1357 "Enter these credentials below",
1358 )),
1359 )
1360 .child(
1361 v_flex()
1362 .gap_0p5()
1363 .child(Label::new("Access Key ID").size(LabelSize::Small))
1364 .child(
1365 self.make_input_styles(cx)
1366 .child(self.render_access_key_id_editor(cx)),
1367 ),
1368 )
1369 .child(
1370 v_flex()
1371 .gap_0p5()
1372 .child(Label::new("Secret Access Key").size(LabelSize::Small))
1373 .child(self.make_input_styles(cx).child(self.render_secret_key_editor(cx))),
1374 )
1375 .child(
1376 v_flex()
1377 .gap_0p5()
1378 .child(Label::new("Session Token (Optional)").size(LabelSize::Small))
1379 .child(
1380 self.make_input_styles(cx)
1381 .child(self.render_session_token_editor(cx)),
1382 ),
1383 )
1384 .into_any_element()
1385 }
1386
1387 fn render_common_fields(&self, cx: &mut Context<Self>) -> AnyElement {
1388 v_flex()
1389 .gap_0p5()
1390 .child(Label::new("Region").size(LabelSize::Small))
1391 .child(
1392 self.make_input_styles(cx)
1393 .child(self.render_region_editor(cx)),
1394 )
1395 .into_any_element()
1396 }
1397}