1use std::pin::Pin;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use crate::ui::InstructionListItem;
6use anyhow::{anyhow, Context as _, Result};
7use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
8use aws_config::Region;
9use aws_credential_types::Credentials;
10use aws_http_client::AwsHttpClient;
11use bedrock::bedrock_client::types::{
12 ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput,
13};
14use bedrock::bedrock_client::{self, Config};
15use bedrock::{
16 value_to_aws_document, BedrockError, BedrockInnerContent, BedrockMessage, BedrockSpecificTool,
17 BedrockStreamingResponse, BedrockTool, BedrockToolChoice, BedrockToolInputSchema, Model,
18};
19use collections::{BTreeMap, HashMap};
20use credentials_provider::CredentialsProvider;
21use editor::{Editor, EditorElement, EditorStyle};
22use futures::{future::BoxFuture, stream::BoxStream, FutureExt, Stream, StreamExt};
23use gpui::{
24 AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
25};
26use gpui_tokio::Tokio;
27use http_client::HttpClient;
28use language_model::{
29 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
30 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
31 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
32 LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role,
33};
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::Value;
37use settings::{Settings, SettingsStore};
38use strum::IntoEnumIterator;
39use theme::ThemeSettings;
40use tokio::runtime::Handle;
41use ui::{prelude::*, Icon, IconName, List, Tooltip};
42use util::{maybe, ResultExt};
43
44use crate::AllLanguageModelSettings;
45
/// Stable identifier of this provider, used as its `LanguageModelProviderId`.
const PROVIDER_ID: &str = "amazon-bedrock";
/// Human-readable provider name, used as its `LanguageModelProviderName`.
const PROVIDER_NAME: &str = "Amazon Bedrock";
48
49#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
50pub struct BedrockCredentials {
51 pub region: String,
52 pub access_key_id: String,
53 pub secret_access_key: String,
54}
55
56#[derive(Default, Clone, Debug, PartialEq)]
57pub struct AmazonBedrockSettings {
58 pub session_token: Option<String>,
59 pub available_models: Vec<AvailableModel>,
60}
61
// A model declared in the user's settings (the Bedrock `available_models` list).
// `provided_models` turns each entry into a `bedrock::Model::Custom`, replacing any
// built-in model stored under the same key (`name`).
//
// Plain `//` comments are used on purpose: `///` doc comments would become
// descriptions in the schema produced by the `JsonSchema` derive.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Model identifier; also the key used to override a built-in model.
    pub name: String,
    // Optional display name for the model.
    pub display_name: Option<String>,
    // Forwarded as `max_tokens` of `bedrock::Model::Custom` (presumably the
    // context-window size — confirm in the bedrock crate).
    pub max_tokens: usize,
    // NOTE(review): accepted from settings but never forwarded to
    // `bedrock::Model::Custom`, so it currently has no effect.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    // Optional cap on output tokens, forwarded to `bedrock::Model::Custom`.
    pub max_output_tokens: Option<u32>,
    // Optional default temperature, forwarded to `bedrock::Model::Custom`.
    pub default_temperature: Option<f32>,
}
71
/// The URL of the base AWS service.
///
/// Right now we're just using this as the key to store the AWS credentials
/// under in the keychain.
const AMAZON_AWS_URL: &str = "https://amazonaws.com";

// These environment variables all use a `ZED_` prefix so they don't collide with the
// user's own AWS credential variables.
/// Environment variable name for the AWS access key ID.
const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID";
/// Environment variable name for the AWS secret access key.
const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY";
/// Environment variable name for the AWS region.
const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION";
/// Environment variable holding a JSON-serialized `BedrockCredentials` object
/// (`region`, `access_key_id`, `secret_access_key`). When set, `State::authenticate`
/// uses it instead of the keychain.
const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS";
83
/// Shared authentication state for the Bedrock provider.
///
/// Lives in a gpui entity that is shared by the provider, every `BedrockModel`,
/// and the configuration view, so all of them see the same credentials.
pub struct State {
    /// Credentials currently in use; `None` means "not authenticated".
    credentials: Option<BedrockCredentials>,
    /// `true` when `credentials` came from an environment variable rather than
    /// the keychain. The configuration view disables "Reset key" in that case.
    credentials_from_env: bool,
    /// Keeps the settings-store observer alive. It re-notifies this entity's
    /// observers whenever settings change.
    _subscription: Subscription,
}
89
90impl State {
91 fn reset_credentials(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
92 let credentials_provider = <dyn CredentialsProvider>::global(cx);
93 cx.spawn(async move |this, cx| {
94 credentials_provider
95 .delete_credentials(AMAZON_AWS_URL, &cx)
96 .await
97 .log_err();
98 this.update(cx, |this, cx| {
99 this.credentials = None;
100 this.credentials_from_env = false;
101 cx.notify();
102 })
103 })
104 }
105
106 fn set_credentials(
107 &mut self,
108 credentials: BedrockCredentials,
109 cx: &mut Context<Self>,
110 ) -> Task<Result<()>> {
111 let credentials_provider = <dyn CredentialsProvider>::global(cx);
112 cx.spawn(async move |this, cx| {
113 credentials_provider
114 .write_credentials(
115 AMAZON_AWS_URL,
116 "Bearer",
117 &serde_json::to_vec(&credentials)?,
118 &cx,
119 )
120 .await?;
121 this.update(cx, |this, cx| {
122 this.credentials = Some(credentials);
123 cx.notify();
124 })
125 })
126 }
127
128 fn is_authenticated(&self) -> bool {
129 self.credentials.is_some()
130 }
131
132 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
133 if self.is_authenticated() {
134 return Task::ready(Ok(()));
135 }
136
137 let credentials_provider = <dyn CredentialsProvider>::global(cx);
138 cx.spawn(async move |this, cx| {
139 let (credentials, from_env) =
140 if let Ok(credentials) = std::env::var(ZED_AWS_CREDENTIALS_VAR) {
141 (credentials, true)
142 } else {
143 let (_, credentials) = credentials_provider
144 .read_credentials(AMAZON_AWS_URL, &cx)
145 .await?
146 .ok_or_else(|| AuthenticateError::CredentialsNotFound)?;
147 (
148 String::from_utf8(credentials)
149 .context("invalid {PROVIDER_NAME} credentials")?,
150 false,
151 )
152 };
153
154 let credentials: BedrockCredentials =
155 serde_json::from_str(&credentials).context("failed to parse credentials")?;
156
157 this.update(cx, |this, cx| {
158 this.credentials = Some(credentials);
159 this.credentials_from_env = from_env;
160 cx.notify();
161 })?;
162
163 Ok(())
164 })
165 }
166}
167
/// Language model provider backed by Amazon Bedrock.
pub struct BedrockLanguageModelProvider {
    /// AWS-SDK-compatible HTTP client; cloned into every model this provider creates.
    http_client: AwsHttpClient,
    /// Tokio runtime handle, cloned into every model. The models use it to run
    /// Bedrock requests and to spawn stream-processing tasks.
    handler: tokio::runtime::Handle,
    /// Shared credential state.
    state: gpui::Entity<State>,
}
173
174impl BedrockLanguageModelProvider {
175 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
176 let state = cx.new(|cx| State {
177 credentials: None,
178 credentials_from_env: false,
179 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
180 cx.notify();
181 }),
182 });
183
184 let tokio_handle = Tokio::handle(cx);
185
186 let coerced_client = AwsHttpClient::new(http_client.clone(), tokio_handle.clone());
187
188 Self {
189 http_client: coerced_client,
190 handler: tokio_handle.clone(),
191 state,
192 }
193 }
194}
195
196impl LanguageModelProvider for BedrockLanguageModelProvider {
197 fn id(&self) -> LanguageModelProviderId {
198 LanguageModelProviderId(PROVIDER_ID.into())
199 }
200
201 fn name(&self) -> LanguageModelProviderName {
202 LanguageModelProviderName(PROVIDER_NAME.into())
203 }
204
205 fn icon(&self) -> IconName {
206 IconName::AiBedrock
207 }
208
209 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
210 let model = bedrock::Model::default();
211 Some(Arc::new(BedrockModel {
212 id: LanguageModelId::from(model.id().to_string()),
213 model,
214 http_client: self.http_client.clone(),
215 handler: self.handler.clone(),
216 state: self.state.clone(),
217 request_limiter: RateLimiter::new(4),
218 }))
219 }
220
221 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
222 let mut models = BTreeMap::default();
223
224 for model in bedrock::Model::iter() {
225 if !matches!(model, bedrock::Model::Custom { .. }) {
226 models.insert(model.id().to_string(), model);
227 }
228 }
229
230 // Override with available models from settings
231 for model in AllLanguageModelSettings::get_global(cx)
232 .bedrock
233 .available_models
234 .iter()
235 {
236 models.insert(
237 model.name.clone(),
238 bedrock::Model::Custom {
239 name: model.name.clone(),
240 display_name: model.display_name.clone(),
241 max_tokens: model.max_tokens,
242 max_output_tokens: model.max_output_tokens,
243 default_temperature: model.default_temperature,
244 },
245 );
246 }
247
248 models
249 .into_values()
250 .map(|model| {
251 Arc::new(BedrockModel {
252 id: LanguageModelId::from(model.id().to_string()),
253 model,
254 http_client: self.http_client.clone(),
255 handler: self.handler.clone(),
256 state: self.state.clone(),
257 request_limiter: RateLimiter::new(4),
258 }) as Arc<dyn LanguageModel>
259 })
260 .collect()
261 }
262
263 fn is_authenticated(&self, cx: &App) -> bool {
264 self.state.read(cx).is_authenticated()
265 }
266
267 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
268 self.state.update(cx, |state, cx| state.authenticate(cx))
269 }
270
271 fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
272 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
273 .into()
274 }
275
276 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
277 self.state
278 .update(cx, |state, cx| state.reset_credentials(cx))
279 }
280}
281
282impl LanguageModelProviderState for BedrockLanguageModelProvider {
283 type ObservableEntity = State;
284
285 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
286 Some(self.state.clone())
287 }
288}
289
/// A single Bedrock model exposed to Zed as a `LanguageModel`.
struct BedrockModel {
    /// ID reported through `LanguageModel::id` (built from `model.id()`).
    id: LanguageModelId,
    /// The underlying bedrock model description.
    model: Model,
    /// HTTP client handed to the AWS SDK client for each request.
    http_client: AwsHttpClient,
    /// Tokio handle used to run Bedrock requests and spawn stream processing.
    handler: tokio::runtime::Handle,
    /// Shared credential state; read on every request.
    state: gpui::Entity<State>,
    /// Throttles this model's completion and tool requests.
    request_limiter: RateLimiter,
}
298
299impl BedrockModel {
300 fn stream_completion(
301 &self,
302 request: bedrock::Request,
303 cx: &AsyncApp,
304 ) -> Result<
305 BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
306 > {
307 let Ok(Ok((access_key_id, secret_access_key, region))) =
308 cx.read_entity(&self.state, |state, _cx| {
309 if let Some(credentials) = &state.credentials {
310 Ok((
311 credentials.access_key_id.clone(),
312 credentials.secret_access_key.clone(),
313 credentials.region.clone(),
314 ))
315 } else {
316 return Err(anyhow!("Failed to read credentials"));
317 }
318 })
319 else {
320 return Err(anyhow!("App state dropped"));
321 };
322
323 let runtime_client = bedrock_client::Client::from_conf(
324 Config::builder()
325 .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
326 .credentials_provider(Credentials::new(
327 access_key_id,
328 secret_access_key,
329 None,
330 None,
331 "Keychain",
332 ))
333 .region(Region::new(region))
334 .http_client(self.http_client.clone())
335 .build(),
336 );
337
338 let owned_handle = self.handler.clone();
339
340 Ok(async move {
341 let request = bedrock::stream_completion(runtime_client, request, owned_handle);
342 request.await.unwrap_or_else(|e| {
343 futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
344 })
345 }
346 .boxed())
347 }
348}
349
350impl LanguageModel for BedrockModel {
351 fn id(&self) -> LanguageModelId {
352 self.id.clone()
353 }
354
355 fn name(&self) -> LanguageModelName {
356 LanguageModelName::from(self.model.display_name().to_string())
357 }
358
359 fn provider_id(&self) -> LanguageModelProviderId {
360 LanguageModelProviderId(PROVIDER_ID.into())
361 }
362
363 fn provider_name(&self) -> LanguageModelProviderName {
364 LanguageModelProviderName(PROVIDER_NAME.into())
365 }
366
367 fn telemetry_id(&self) -> String {
368 format!("bedrock/{}", self.model.id())
369 }
370
371 fn max_token_count(&self) -> usize {
372 self.model.max_token_count()
373 }
374
375 fn max_output_tokens(&self) -> Option<u32> {
376 Some(self.model.max_output_tokens())
377 }
378
379 fn count_tokens(
380 &self,
381 request: LanguageModelRequest,
382 cx: &App,
383 ) -> BoxFuture<'static, Result<usize>> {
384 get_bedrock_tokens(request, cx)
385 }
386
387 fn stream_completion(
388 &self,
389 request: LanguageModelRequest,
390 cx: &AsyncApp,
391 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
392 let request = into_bedrock(
393 request,
394 self.model.id().into(),
395 self.model.default_temperature(),
396 self.model.max_output_tokens(),
397 );
398
399 let owned_handle = self.handler.clone();
400
401 let request = self.stream_completion(request, cx);
402 let future = self.request_limiter.stream(async move {
403 let response = request.map_err(|err| anyhow!(err))?.await;
404 Ok(map_to_language_model_completion_events(
405 response,
406 owned_handle,
407 ))
408 });
409 async move { Ok(future.await?.boxed()) }.boxed()
410 }
411
412 fn use_any_tool(
413 &self,
414 request: LanguageModelRequest,
415 name: String,
416 description: String,
417 schema: Value,
418 _cx: &AsyncApp,
419 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
420 let mut request = into_bedrock(
421 request,
422 self.model.id().into(),
423 self.model.default_temperature(),
424 self.model.max_output_tokens(),
425 );
426
427 request.tool_choice = BedrockSpecificTool::builder()
428 .name(name.clone())
429 .build()
430 .log_err()
431 .map(BedrockToolChoice::Tool);
432
433 if let Some(tool) = BedrockTool::builder()
434 .name(name.clone())
435 .description(description.clone())
436 .input_schema(BedrockToolInputSchema::Json(value_to_aws_document(&schema)))
437 .build()
438 .log_err()
439 {
440 request.tools.push(tool);
441 }
442
443 let handle = self.handler.clone();
444
445 let request = self.stream_completion(request, _cx);
446 self.request_limiter
447 .run(async move {
448 let response = request.map_err(|err| anyhow!(err))?.await;
449 Ok(extract_tool_args_from_events(name, response, handle)
450 .await?
451 .boxed())
452 })
453 .boxed()
454 }
455
456 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
457 None
458 }
459}
460
461pub fn into_bedrock(
462 request: LanguageModelRequest,
463 model: String,
464 default_temperature: f32,
465 max_output_tokens: u32,
466) -> bedrock::Request {
467 let mut new_messages: Vec<BedrockMessage> = Vec::new();
468 let mut system_message = String::new();
469
470 for message in request.messages {
471 if message.contents_empty() {
472 continue;
473 }
474
475 match message.role {
476 Role::User | Role::Assistant => {
477 let bedrock_message_content: Vec<BedrockInnerContent> = message
478 .content
479 .into_iter()
480 .filter_map(|content| match content {
481 MessageContent::Text(text) => {
482 if !text.is_empty() {
483 Some(BedrockInnerContent::Text(text))
484 } else {
485 None
486 }
487 }
488 _ => None,
489 })
490 .collect();
491 let bedrock_role = match message.role {
492 Role::User => bedrock::BedrockRole::User,
493 Role::Assistant => bedrock::BedrockRole::Assistant,
494 Role::System => unreachable!("System role should never occur here"),
495 };
496 if let Some(last_message) = new_messages.last_mut() {
497 if last_message.role == bedrock_role {
498 last_message.content.extend(bedrock_message_content);
499 continue;
500 }
501 }
502 new_messages.push(
503 BedrockMessage::builder()
504 .role(bedrock_role)
505 .set_content(Some(bedrock_message_content))
506 .build()
507 .expect("failed to build Bedrock message"),
508 );
509 }
510 Role::System => {
511 if !system_message.is_empty() {
512 system_message.push_str("\n\n");
513 }
514 system_message.push_str(&message.string_contents());
515 }
516 }
517 }
518
519 bedrock::Request {
520 model,
521 messages: new_messages,
522 max_tokens: max_output_tokens,
523 system: Some(system_message),
524 tools: vec![],
525 tool_choice: None,
526 metadata: None,
527 stop_sequences: Vec::new(),
528 temperature: request.temperature.or(Some(default_temperature)),
529 top_k: None,
530 top_p: None,
531 }
532}
533
534// TODO: just call the ConverseOutput.usage() method:
535// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output
536pub fn get_bedrock_tokens(
537 request: LanguageModelRequest,
538 cx: &App,
539) -> BoxFuture<'static, Result<usize>> {
540 cx.background_executor()
541 .spawn(async move {
542 let messages = request.messages;
543 let mut tokens_from_images = 0;
544 let mut string_messages = Vec::with_capacity(messages.len());
545
546 for message in messages {
547 use language_model::MessageContent;
548
549 let mut string_contents = String::new();
550
551 for content in message.content {
552 match content {
553 MessageContent::Text(text) => {
554 string_contents.push_str(&text);
555 }
556 MessageContent::Image(image) => {
557 tokens_from_images += image.estimate_tokens();
558 }
559 MessageContent::ToolUse(_tool_use) => {
560 // TODO: Estimate token usage from tool uses.
561 }
562 MessageContent::ToolResult(tool_result) => {
563 string_contents.push_str(&tool_result.content);
564 }
565 }
566 }
567
568 if !string_contents.is_empty() {
569 string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
570 role: match message.role {
571 Role::User => "user".into(),
572 Role::Assistant => "assistant".into(),
573 Role::System => "system".into(),
574 },
575 content: Some(string_contents),
576 name: None,
577 function_call: None,
578 });
579 }
580 }
581
582 // Tiktoken doesn't yet support these models, so we manually use the
583 // same tokenizer as GPT-4.
584 tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
585 .map(|tokens| tokens + tokens_from_images)
586 })
587 .boxed()
588}
589
590pub async fn extract_tool_args_from_events(
591 name: String,
592 mut events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
593 handle: Handle,
594) -> Result<impl Send + Stream<Item = Result<String>>> {
595 handle
596 .spawn(async move {
597 let mut tool_use_index = None;
598 while let Some(event) = events.next().await {
599 if let BedrockStreamingResponse::ContentBlockStart(ContentBlockStartEvent {
600 content_block_index,
601 start,
602 ..
603 }) = event?
604 {
605 match start {
606 None => {
607 continue;
608 }
609 Some(start) => match start.as_tool_use() {
610 Ok(tool_use) => {
611 if name == tool_use.name {
612 tool_use_index = Some(content_block_index);
613 break;
614 }
615 }
616 Err(err) => {
617 return Err(anyhow!("Failed to parse tool use event: {:?}", err));
618 }
619 },
620 }
621 }
622 }
623
624 let Some(tool_use_index) = tool_use_index else {
625 return Err(anyhow!("Tool is not used"));
626 };
627
628 Ok(events.filter_map(move |event| {
629 let result = match event {
630 Err(_err) => None,
631 Ok(output) => match output.clone() {
632 BedrockStreamingResponse::ContentBlockDelta(inner) => {
633 match inner.clone().delta {
634 Some(ContentBlockDelta::ToolUse(tool_use)) => {
635 if inner.content_block_index == tool_use_index {
636 Some(Ok(tool_use.input))
637 } else {
638 None
639 }
640 }
641 _ => None,
642 }
643 }
644 _ => None,
645 },
646 };
647
648 async move { result }
649 }))
650 })
651 .await?
652}
653
654pub fn map_to_language_model_completion_events(
655 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
656 handle: Handle,
657) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
658 struct RawToolUse {
659 id: String,
660 name: String,
661 input_json: String,
662 }
663
664 struct State {
665 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
666 tool_uses_by_index: HashMap<i32, RawToolUse>,
667 }
668
669 futures::stream::unfold(
670 State {
671 events,
672 tool_uses_by_index: HashMap::default(),
673 },
674 move |mut state: State| {
675 let inner_handle = handle.clone();
676 async move {
677 inner_handle
678 .spawn(async {
679 while let Some(event) = state.events.next().await {
680 match event {
681 Ok(event) => match event {
682 ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
683 if let Some(ContentBlockDelta::Text(text_out)) =
684 cb_delta.delta
685 {
686 return Some((
687 Some(Ok(LanguageModelCompletionEvent::Text(
688 text_out,
689 ))),
690 state,
691 ));
692 } else if let Some(ContentBlockDelta::ToolUse(text_out)) =
693 cb_delta.delta
694 {
695 if let Some(tool_use) = state
696 .tool_uses_by_index
697 .get_mut(&cb_delta.content_block_index)
698 {
699 tool_use.input_json.push_str(text_out.input());
700 return Some((None, state));
701 };
702
703 return Some((None, state));
704 } else if cb_delta.delta.is_none() {
705 return Some((None, state));
706 }
707 }
708 ConverseStreamOutput::ContentBlockStart(cb_start) => {
709 if let Some(start) = cb_start.start {
710 match start {
711 ContentBlockStart::ToolUse(text_out) => {
712 let tool_use = RawToolUse {
713 id: text_out.tool_use_id,
714 name: text_out.name,
715 input_json: String::new(),
716 };
717
718 state.tool_uses_by_index.insert(
719 cb_start.content_block_index,
720 tool_use,
721 );
722 }
723 _ => {}
724 }
725 }
726 }
727 ConverseStreamOutput::ContentBlockStop(cb_stop) => {
728 if let Some(tool_use) = state
729 .tool_uses_by_index
730 .remove(&cb_stop.content_block_index)
731 {
732 return Some((
733 Some(maybe!({
734 Ok(LanguageModelCompletionEvent::ToolUse(
735 LanguageModelToolUse {
736 id: tool_use.id.into(),
737 name: tool_use.name.into(),
738 input: if tool_use.input_json.is_empty()
739 {
740 Value::Null
741 } else {
742 serde_json::Value::from_str(
743 &tool_use.input_json,
744 )
745 .map_err(|err| anyhow!(err))?
746 },
747 },
748 ))
749 })),
750 state,
751 ));
752 }
753 }
754 _ => {}
755 },
756 Err(err) => return Some((Some(Err(anyhow!(err))), state)),
757 }
758 }
759 None
760 })
761 .await
762 .log_err()
763 .flatten()
764 }
765 },
766 )
767 .filter_map(|event| async move { event })
768}
769
/// Settings UI for entering, displaying, and resetting Bedrock credentials.
struct ConfigurationView {
    /// Single-line editor for the AWS access key ID.
    access_key_id_editor: Entity<Editor>,
    /// Single-line editor for the AWS secret access key.
    secret_access_key_editor: Entity<Editor>,
    /// Single-line editor for the AWS region.
    region_editor: Entity<Editor>,
    /// Shared provider credential state.
    state: gpui::Entity<State>,
    /// The initial `authenticate` task. `Some` while stored credentials are
    /// loading; set back to `None` when it finishes.
    load_credentials_task: Option<Task<()>>,
}
777
778impl ConfigurationView {
779 const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
780 const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
781 "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
782 const PLACEHOLDER_REGION: &'static str = "us-east-1";
783
784 fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
785 cx.observe(&state, |_, _, cx| {
786 cx.notify();
787 })
788 .detach();
789
790 let load_credentials_task = Some(cx.spawn({
791 let state = state.clone();
792 async move |this, cx| {
793 if let Some(task) = state
794 .update(cx, |state, cx| state.authenticate(cx))
795 .log_err()
796 {
797 // We don't log an error, because "not signed in" is also an error.
798 let _ = task.await;
799 }
800 this.update(cx, |this, cx| {
801 this.load_credentials_task = None;
802 cx.notify();
803 })
804 .log_err();
805 }
806 }));
807
808 Self {
809 access_key_id_editor: cx.new(|cx| {
810 let mut editor = Editor::single_line(window, cx);
811 editor.set_placeholder_text(Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT, cx);
812 editor
813 }),
814 secret_access_key_editor: cx.new(|cx| {
815 let mut editor = Editor::single_line(window, cx);
816 editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx);
817 editor
818 }),
819 region_editor: cx.new(|cx| {
820 let mut editor = Editor::single_line(window, cx);
821 editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx);
822 editor
823 }),
824 state,
825 load_credentials_task,
826 }
827 }
828
829 fn save_credentials(
830 &mut self,
831 _: &menu::Confirm,
832 _window: &mut Window,
833 cx: &mut Context<Self>,
834 ) {
835 let access_key_id = self
836 .access_key_id_editor
837 .read(cx)
838 .text(cx)
839 .to_string()
840 .trim()
841 .to_string();
842 let secret_access_key = self
843 .secret_access_key_editor
844 .read(cx)
845 .text(cx)
846 .to_string()
847 .trim()
848 .to_string();
849 let region = self
850 .region_editor
851 .read(cx)
852 .text(cx)
853 .to_string()
854 .trim()
855 .to_string();
856
857 let state = self.state.clone();
858 cx.spawn(async move |_, cx| {
859 state
860 .update(cx, |state, cx| {
861 let credentials: BedrockCredentials = BedrockCredentials {
862 access_key_id: access_key_id.clone(),
863 secret_access_key: secret_access_key.clone(),
864 region: region.clone(),
865 };
866
867 state.set_credentials(credentials, cx)
868 })?
869 .await
870 })
871 .detach_and_log_err(cx);
872 }
873
874 fn reset_credentials(&mut self, window: &mut Window, cx: &mut Context<Self>) {
875 self.access_key_id_editor
876 .update(cx, |editor, cx| editor.set_text("", window, cx));
877 self.secret_access_key_editor
878 .update(cx, |editor, cx| editor.set_text("", window, cx));
879 self.region_editor
880 .update(cx, |editor, cx| editor.set_text("", window, cx));
881
882 let state = self.state.clone();
883 cx.spawn(async move |_, cx| {
884 state
885 .update(cx, |state, cx| state.reset_credentials(cx))?
886 .await
887 })
888 .detach_and_log_err(cx);
889 }
890
891 fn make_text_style(&self, cx: &Context<Self>) -> TextStyle {
892 let settings = ThemeSettings::get_global(cx);
893 TextStyle {
894 color: cx.theme().colors().text,
895 font_family: settings.ui_font.family.clone(),
896 font_features: settings.ui_font.features.clone(),
897 font_fallbacks: settings.ui_font.fallbacks.clone(),
898 font_size: rems(0.875).into(),
899 font_weight: settings.ui_font.weight,
900 font_style: FontStyle::Normal,
901 line_height: relative(1.3),
902 background_color: None,
903 underline: None,
904 strikethrough: None,
905 white_space: WhiteSpace::Normal,
906 text_overflow: None,
907 text_align: Default::default(),
908 line_clamp: None,
909 }
910 }
911
912 fn render_aa_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
913 let text_style = self.make_text_style(cx);
914
915 EditorElement::new(
916 &self.access_key_id_editor,
917 EditorStyle {
918 background: cx.theme().colors().editor_background,
919 local_player: cx.theme().players().local(),
920 text: text_style,
921 ..Default::default()
922 },
923 )
924 }
925
926 fn render_sk_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
927 let text_style = self.make_text_style(cx);
928
929 EditorElement::new(
930 &self.secret_access_key_editor,
931 EditorStyle {
932 background: cx.theme().colors().editor_background,
933 local_player: cx.theme().players().local(),
934 text: text_style,
935 ..Default::default()
936 },
937 )
938 }
939
940 fn render_region_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
941 let text_style = self.make_text_style(cx);
942
943 EditorElement::new(
944 &self.region_editor,
945 EditorStyle {
946 background: cx.theme().colors().editor_background,
947 local_player: cx.theme().players().local(),
948 text: text_style,
949 ..Default::default()
950 },
951 )
952 }
953
954 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
955 !self.state.read(cx).is_authenticated()
956 }
957}
958
impl Render for ConfigurationView {
    /// Renders one of three states:
    /// - a loading label while stored credentials are being read;
    /// - setup instructions plus the credential form when not authenticated;
    /// - a confirmation row with a "Reset key" button once authenticated.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).credentials_from_env;
        let bg_color = cx.theme().colors().editor_background;
        let border_color = cx.theme().colors().border_variant;
        // Shared container styling for each of the three input fields.
        let input_base_styles = || {
            h_flex()
                .w_full()
                .px_2()
                .py_1()
                .bg(bg_color)
                .border_1()
                .border_color(border_color)
                .rounded_sm()
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            // Not authenticated: instructions plus the form. `menu::Confirm`
            // (enter) anywhere inside this container calls `save_credentials`.
            v_flex()
                .size_full()
                .on_action(cx.listener(ConfigurationView::save_credentials))
                .child(Label::new("To use Zed's assistant with Bedrock, you need to add the Access Key ID, Secret Access Key and AWS Region. Follow these steps:"))
                .child(
                    List::new()
                        .child(
                            InstructionListItem::new(
                                "Start by",
                                Some("creating a user and security credentials"),
                                Some("https://us-east-1.console.aws.amazon.com/iam/home")
                            )
                        )
                        .child(
                            InstructionListItem::new(
                                "Grant that user permissions according to this documentation:",
                                Some("Prerequisites"),
                                Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html")
                            )
                        )
                        .child(
                            InstructionListItem::new(
                                "Select the models you would like access to:",
                                Some("Bedrock Model Catalog"),
                                Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess")
                            )
                        )
                        .child(
                            InstructionListItem::text_only("Fill the fields below and hit enter to start using the assistant")
                        )
                )
                .child(
                    v_flex()
                        .my_2()
                        .gap_1p5()
                        .child(
                            v_flex()
                                .gap_0p5()
                                .child(Label::new("Access Key ID").size(LabelSize::Small))
                                .child(
                                    input_base_styles().child(self.render_aa_id_editor(cx))
                                )
                        )
                        .child(
                            v_flex()
                                .gap_0p5()
                                .child(Label::new("Secret Access Key").size(LabelSize::Small))
                                .child(
                                    input_base_styles().child(self.render_sk_editor(cx))
                                )
                        )
                        .child(
                            v_flex()
                                .gap_0p5()
                                .child(Label::new("Region").size(LabelSize::Small))
                                .child(
                                    input_base_styles().child(self.render_region_editor(cx))
                                )
                        )
                )
                .child(
                    Label::new(
                        format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."),
                    )
                        .size(LabelSize::Small)
                        .color(Color::Muted),
                )
                .into_any()
        } else {
            // Authenticated. NOTE(review): when `env_var_set` is true this text names
            // the three individual variables, even though credentials may have come
            // from the `ZED_AWS_CREDENTIALS` JSON variable instead.
            h_flex()
                .size_full()
                .justify_between()
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
                        } else {
                            "Credentials configured.".to_string()
                        })),
                )
                .child(
                    // Environment-provided credentials can't be reset from the UI, so the
                    // button is disabled and a tooltip explains what to unset instead.
                    Button::new("reset-key", "Reset key")
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
                )
                .into_any()
        }
    }
}