1use std::pin::Pin;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use crate::ui::InstructionListItem;
6use anyhow::{Context as _, Result, anyhow};
7use aws_config::Region;
8use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
9use aws_credential_types::Credentials;
10use aws_http_client::AwsHttpClient;
11use bedrock::bedrock_client::types::{
12 ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput,
13};
14use bedrock::bedrock_client::{self, Config};
15use bedrock::{BedrockError, BedrockInnerContent, BedrockMessage, BedrockStreamingResponse, Model};
16use collections::{BTreeMap, HashMap};
17use credentials_provider::CredentialsProvider;
18use editor::{Editor, EditorElement, EditorStyle};
19use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
20use gpui::{
21 AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
22};
23use gpui_tokio::Tokio;
24use http_client::HttpClient;
25use language_model::{
26 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
27 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
28 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
29 LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role,
30};
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use serde_json::Value;
34use settings::{Settings, SettingsStore};
35use strum::IntoEnumIterator;
36use theme::ThemeSettings;
37use tokio::runtime::Handle;
38use ui::{Icon, IconName, List, Tooltip, prelude::*};
39use util::{ResultExt, maybe};
40
41use crate::AllLanguageModelSettings;
42
/// Stable identifier under which this provider is registered.
const PROVIDER_ID: &str = "amazon-bedrock";
/// Human-readable provider name.
const PROVIDER_NAME: &str = "Amazon Bedrock";
45
46#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
47pub struct BedrockCredentials {
48 pub region: String,
49 pub access_key_id: String,
50 pub secret_access_key: String,
51}
52
53#[derive(Default, Clone, Debug, PartialEq)]
54pub struct AmazonBedrockSettings {
55 pub session_token: Option<String>,
56 pub available_models: Vec<AvailableModel>,
57}
58
59#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
60pub struct AvailableModel {
61 pub name: String,
62 pub display_name: Option<String>,
63 pub max_tokens: usize,
64 pub cache_configuration: Option<LanguageModelCacheConfiguration>,
65 pub max_output_tokens: Option<u32>,
66 pub default_temperature: Option<f32>,
67}
68
/// Base URL of the AWS service.
///
/// It is not contacted directly from this file; it only serves as the key
/// under which the AWS credentials are stored in the keychain.
const AMAZON_AWS_URL: &str = "https://amazonaws.com";
74
// All of these environment variables carry a `ZED_` prefix so that they never
// clash with (or overwrite) the user's regular AWS credentials.

/// Environment variable named in the UI for the access key ID.
const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID";
/// Environment variable named in the UI for the secret access key.
const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY";
/// Environment variable named in the UI for the AWS region.
const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION";
/// Environment variable holding the complete JSON-encoded credentials; this
/// is the variable `State::authenticate` reads.
const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS";
80
81pub struct State {
82 credentials: Option<BedrockCredentials>,
83 credentials_from_env: bool,
84 _subscription: Subscription,
85}
86
87impl State {
88 fn reset_credentials(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
89 let credentials_provider = <dyn CredentialsProvider>::global(cx);
90 cx.spawn(async move |this, cx| {
91 credentials_provider
92 .delete_credentials(AMAZON_AWS_URL, &cx)
93 .await
94 .log_err();
95 this.update(cx, |this, cx| {
96 this.credentials = None;
97 this.credentials_from_env = false;
98 cx.notify();
99 })
100 })
101 }
102
103 fn set_credentials(
104 &mut self,
105 credentials: BedrockCredentials,
106 cx: &mut Context<Self>,
107 ) -> Task<Result<()>> {
108 let credentials_provider = <dyn CredentialsProvider>::global(cx);
109 cx.spawn(async move |this, cx| {
110 credentials_provider
111 .write_credentials(
112 AMAZON_AWS_URL,
113 "Bearer",
114 &serde_json::to_vec(&credentials)?,
115 &cx,
116 )
117 .await?;
118 this.update(cx, |this, cx| {
119 this.credentials = Some(credentials);
120 cx.notify();
121 })
122 })
123 }
124
125 fn is_authenticated(&self) -> bool {
126 self.credentials.is_some()
127 }
128
129 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
130 if self.is_authenticated() {
131 return Task::ready(Ok(()));
132 }
133
134 let credentials_provider = <dyn CredentialsProvider>::global(cx);
135 cx.spawn(async move |this, cx| {
136 let (credentials, from_env) =
137 if let Ok(credentials) = std::env::var(ZED_AWS_CREDENTIALS_VAR) {
138 (credentials, true)
139 } else {
140 let (_, credentials) = credentials_provider
141 .read_credentials(AMAZON_AWS_URL, &cx)
142 .await?
143 .ok_or_else(|| AuthenticateError::CredentialsNotFound)?;
144 (
145 String::from_utf8(credentials)
146 .context("invalid {PROVIDER_NAME} credentials")?,
147 false,
148 )
149 };
150
151 let credentials: BedrockCredentials =
152 serde_json::from_str(&credentials).context("failed to parse credentials")?;
153
154 this.update(cx, |this, cx| {
155 this.credentials = Some(credentials);
156 this.credentials_from_env = from_env;
157 cx.notify();
158 })?;
159
160 Ok(())
161 })
162 }
163}
164
165pub struct BedrockLanguageModelProvider {
166 http_client: AwsHttpClient,
167 handler: tokio::runtime::Handle,
168 state: gpui::Entity<State>,
169}
170
171impl BedrockLanguageModelProvider {
172 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
173 let state = cx.new(|cx| State {
174 credentials: None,
175 credentials_from_env: false,
176 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
177 cx.notify();
178 }),
179 });
180
181 let tokio_handle = Tokio::handle(cx);
182
183 let coerced_client = AwsHttpClient::new(http_client.clone(), tokio_handle.clone());
184
185 Self {
186 http_client: coerced_client,
187 handler: tokio_handle.clone(),
188 state,
189 }
190 }
191}
192
193impl LanguageModelProvider for BedrockLanguageModelProvider {
194 fn id(&self) -> LanguageModelProviderId {
195 LanguageModelProviderId(PROVIDER_ID.into())
196 }
197
198 fn name(&self) -> LanguageModelProviderName {
199 LanguageModelProviderName(PROVIDER_NAME.into())
200 }
201
202 fn icon(&self) -> IconName {
203 IconName::AiBedrock
204 }
205
206 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
207 let model = bedrock::Model::default();
208 Some(Arc::new(BedrockModel {
209 id: LanguageModelId::from(model.id().to_string()),
210 model,
211 http_client: self.http_client.clone(),
212 handler: self.handler.clone(),
213 state: self.state.clone(),
214 request_limiter: RateLimiter::new(4),
215 }))
216 }
217
218 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
219 let mut models = BTreeMap::default();
220
221 for model in bedrock::Model::iter() {
222 if !matches!(model, bedrock::Model::Custom { .. }) {
223 models.insert(model.id().to_string(), model);
224 }
225 }
226
227 // Override with available models from settings
228 for model in AllLanguageModelSettings::get_global(cx)
229 .bedrock
230 .available_models
231 .iter()
232 {
233 models.insert(
234 model.name.clone(),
235 bedrock::Model::Custom {
236 name: model.name.clone(),
237 display_name: model.display_name.clone(),
238 max_tokens: model.max_tokens,
239 max_output_tokens: model.max_output_tokens,
240 default_temperature: model.default_temperature,
241 },
242 );
243 }
244
245 models
246 .into_values()
247 .map(|model| {
248 Arc::new(BedrockModel {
249 id: LanguageModelId::from(model.id().to_string()),
250 model,
251 http_client: self.http_client.clone(),
252 handler: self.handler.clone(),
253 state: self.state.clone(),
254 request_limiter: RateLimiter::new(4),
255 }) as Arc<dyn LanguageModel>
256 })
257 .collect()
258 }
259
260 fn is_authenticated(&self, cx: &App) -> bool {
261 self.state.read(cx).is_authenticated()
262 }
263
264 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
265 self.state.update(cx, |state, cx| state.authenticate(cx))
266 }
267
268 fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
269 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
270 .into()
271 }
272
273 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
274 self.state
275 .update(cx, |state, cx| state.reset_credentials(cx))
276 }
277}
278
279impl LanguageModelProviderState for BedrockLanguageModelProvider {
280 type ObservableEntity = State;
281
282 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
283 Some(self.state.clone())
284 }
285}
286
287struct BedrockModel {
288 id: LanguageModelId,
289 model: Model,
290 http_client: AwsHttpClient,
291 handler: tokio::runtime::Handle,
292 state: gpui::Entity<State>,
293 request_limiter: RateLimiter,
294}
295
296impl BedrockModel {
297 fn stream_completion(
298 &self,
299 request: bedrock::Request,
300 cx: &AsyncApp,
301 ) -> Result<
302 BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
303 > {
304 let Ok(Ok((access_key_id, secret_access_key, region))) =
305 cx.read_entity(&self.state, |state, _cx| {
306 if let Some(credentials) = &state.credentials {
307 Ok((
308 credentials.access_key_id.clone(),
309 credentials.secret_access_key.clone(),
310 credentials.region.clone(),
311 ))
312 } else {
313 return Err(anyhow!("Failed to read credentials"));
314 }
315 })
316 else {
317 return Err(anyhow!("App state dropped"));
318 };
319
320 let runtime_client = bedrock_client::Client::from_conf(
321 Config::builder()
322 .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
323 .credentials_provider(Credentials::new(
324 access_key_id,
325 secret_access_key,
326 None,
327 None,
328 "Keychain",
329 ))
330 .region(Region::new(region))
331 .http_client(self.http_client.clone())
332 .build(),
333 );
334
335 let owned_handle = self.handler.clone();
336
337 Ok(async move {
338 let request = bedrock::stream_completion(runtime_client, request, owned_handle);
339 request.await.unwrap_or_else(|e| {
340 futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
341 })
342 }
343 .boxed())
344 }
345}
346
347impl LanguageModel for BedrockModel {
348 fn id(&self) -> LanguageModelId {
349 self.id.clone()
350 }
351
352 fn name(&self) -> LanguageModelName {
353 LanguageModelName::from(self.model.display_name().to_string())
354 }
355
356 fn provider_id(&self) -> LanguageModelProviderId {
357 LanguageModelProviderId(PROVIDER_ID.into())
358 }
359
360 fn provider_name(&self) -> LanguageModelProviderName {
361 LanguageModelProviderName(PROVIDER_NAME.into())
362 }
363
364 fn supports_tools(&self) -> bool {
365 true
366 }
367
368 fn telemetry_id(&self) -> String {
369 format!("bedrock/{}", self.model.id())
370 }
371
372 fn max_token_count(&self) -> usize {
373 self.model.max_token_count()
374 }
375
376 fn max_output_tokens(&self) -> Option<u32> {
377 Some(self.model.max_output_tokens())
378 }
379
380 fn count_tokens(
381 &self,
382 request: LanguageModelRequest,
383 cx: &App,
384 ) -> BoxFuture<'static, Result<usize>> {
385 get_bedrock_tokens(request, cx)
386 }
387
388 fn stream_completion(
389 &self,
390 request: LanguageModelRequest,
391 cx: &AsyncApp,
392 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
393 let request = into_bedrock(
394 request,
395 self.model.id().into(),
396 self.model.default_temperature(),
397 self.model.max_output_tokens(),
398 );
399
400 let owned_handle = self.handler.clone();
401
402 let request = self.stream_completion(request, cx);
403 let future = self.request_limiter.stream(async move {
404 let response = request.map_err(|err| anyhow!(err))?.await;
405 Ok(map_to_language_model_completion_events(
406 response,
407 owned_handle,
408 ))
409 });
410 async move { Ok(future.await?.boxed()) }.boxed()
411 }
412
413 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
414 None
415 }
416}
417
418pub fn into_bedrock(
419 request: LanguageModelRequest,
420 model: String,
421 default_temperature: f32,
422 max_output_tokens: u32,
423) -> bedrock::Request {
424 let mut new_messages: Vec<BedrockMessage> = Vec::new();
425 let mut system_message = String::new();
426
427 for message in request.messages {
428 if message.contents_empty() {
429 continue;
430 }
431
432 match message.role {
433 Role::User | Role::Assistant => {
434 let bedrock_message_content: Vec<BedrockInnerContent> = message
435 .content
436 .into_iter()
437 .filter_map(|content| match content {
438 MessageContent::Text(text) => {
439 if !text.is_empty() {
440 Some(BedrockInnerContent::Text(text))
441 } else {
442 None
443 }
444 }
445 _ => None,
446 })
447 .collect();
448 let bedrock_role = match message.role {
449 Role::User => bedrock::BedrockRole::User,
450 Role::Assistant => bedrock::BedrockRole::Assistant,
451 Role::System => unreachable!("System role should never occur here"),
452 };
453 if let Some(last_message) = new_messages.last_mut() {
454 if last_message.role == bedrock_role {
455 last_message.content.extend(bedrock_message_content);
456 continue;
457 }
458 }
459 new_messages.push(
460 BedrockMessage::builder()
461 .role(bedrock_role)
462 .set_content(Some(bedrock_message_content))
463 .build()
464 .expect("failed to build Bedrock message"),
465 );
466 }
467 Role::System => {
468 if !system_message.is_empty() {
469 system_message.push_str("\n\n");
470 }
471 system_message.push_str(&message.string_contents());
472 }
473 }
474 }
475
476 bedrock::Request {
477 model,
478 messages: new_messages,
479 max_tokens: max_output_tokens,
480 system: Some(system_message),
481 tools: vec![],
482 tool_choice: None,
483 metadata: None,
484 stop_sequences: Vec::new(),
485 temperature: request.temperature.or(Some(default_temperature)),
486 top_k: None,
487 top_p: None,
488 }
489}
490
491// TODO: just call the ConverseOutput.usage() method:
492// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output
493pub fn get_bedrock_tokens(
494 request: LanguageModelRequest,
495 cx: &App,
496) -> BoxFuture<'static, Result<usize>> {
497 cx.background_executor()
498 .spawn(async move {
499 let messages = request.messages;
500 let mut tokens_from_images = 0;
501 let mut string_messages = Vec::with_capacity(messages.len());
502
503 for message in messages {
504 use language_model::MessageContent;
505
506 let mut string_contents = String::new();
507
508 for content in message.content {
509 match content {
510 MessageContent::Text(text) => {
511 string_contents.push_str(&text);
512 }
513 MessageContent::Image(image) => {
514 tokens_from_images += image.estimate_tokens();
515 }
516 MessageContent::ToolUse(_tool_use) => {
517 // TODO: Estimate token usage from tool uses.
518 }
519 MessageContent::ToolResult(tool_result) => {
520 string_contents.push_str(&tool_result.content);
521 }
522 }
523 }
524
525 if !string_contents.is_empty() {
526 string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
527 role: match message.role {
528 Role::User => "user".into(),
529 Role::Assistant => "assistant".into(),
530 Role::System => "system".into(),
531 },
532 content: Some(string_contents),
533 name: None,
534 function_call: None,
535 });
536 }
537 }
538
539 // Tiktoken doesn't yet support these models, so we manually use the
540 // same tokenizer as GPT-4.
541 tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
542 .map(|tokens| tokens + tokens_from_images)
543 })
544 .boxed()
545}
546
547pub async fn extract_tool_args_from_events(
548 name: String,
549 mut events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
550 handle: Handle,
551) -> Result<impl Send + Stream<Item = Result<String>>> {
552 handle
553 .spawn(async move {
554 let mut tool_use_index = None;
555 while let Some(event) = events.next().await {
556 if let BedrockStreamingResponse::ContentBlockStart(ContentBlockStartEvent {
557 content_block_index,
558 start,
559 ..
560 }) = event?
561 {
562 match start {
563 None => {
564 continue;
565 }
566 Some(start) => match start.as_tool_use() {
567 Ok(tool_use) => {
568 if name == tool_use.name {
569 tool_use_index = Some(content_block_index);
570 break;
571 }
572 }
573 Err(err) => {
574 return Err(anyhow!("Failed to parse tool use event: {:?}", err));
575 }
576 },
577 }
578 }
579 }
580
581 let Some(tool_use_index) = tool_use_index else {
582 return Err(anyhow!("Tool is not used"));
583 };
584
585 Ok(events.filter_map(move |event| {
586 let result = match event {
587 Err(_err) => None,
588 Ok(output) => match output.clone() {
589 BedrockStreamingResponse::ContentBlockDelta(inner) => {
590 match inner.clone().delta {
591 Some(ContentBlockDelta::ToolUse(tool_use)) => {
592 if inner.content_block_index == tool_use_index {
593 Some(Ok(tool_use.input))
594 } else {
595 None
596 }
597 }
598 _ => None,
599 }
600 }
601 _ => None,
602 },
603 };
604
605 async move { result }
606 }))
607 })
608 .await?
609}
610
611pub fn map_to_language_model_completion_events(
612 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
613 handle: Handle,
614) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
615 struct RawToolUse {
616 id: String,
617 name: String,
618 input_json: String,
619 }
620
621 struct State {
622 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
623 tool_uses_by_index: HashMap<i32, RawToolUse>,
624 }
625
626 futures::stream::unfold(
627 State {
628 events,
629 tool_uses_by_index: HashMap::default(),
630 },
631 move |mut state: State| {
632 let inner_handle = handle.clone();
633 async move {
634 inner_handle
635 .spawn(async {
636 while let Some(event) = state.events.next().await {
637 match event {
638 Ok(event) => match event {
639 ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
640 if let Some(ContentBlockDelta::Text(text_out)) =
641 cb_delta.delta
642 {
643 return Some((
644 Some(Ok(LanguageModelCompletionEvent::Text(
645 text_out,
646 ))),
647 state,
648 ));
649 } else if let Some(ContentBlockDelta::ToolUse(text_out)) =
650 cb_delta.delta
651 {
652 if let Some(tool_use) = state
653 .tool_uses_by_index
654 .get_mut(&cb_delta.content_block_index)
655 {
656 tool_use.input_json.push_str(text_out.input());
657 return Some((None, state));
658 };
659
660 return Some((None, state));
661 } else if cb_delta.delta.is_none() {
662 return Some((None, state));
663 }
664 }
665 ConverseStreamOutput::ContentBlockStart(cb_start) => {
666 if let Some(start) = cb_start.start {
667 match start {
668 ContentBlockStart::ToolUse(text_out) => {
669 let tool_use = RawToolUse {
670 id: text_out.tool_use_id,
671 name: text_out.name,
672 input_json: String::new(),
673 };
674
675 state.tool_uses_by_index.insert(
676 cb_start.content_block_index,
677 tool_use,
678 );
679 }
680 _ => {}
681 }
682 }
683 }
684 ConverseStreamOutput::ContentBlockStop(cb_stop) => {
685 if let Some(tool_use) = state
686 .tool_uses_by_index
687 .remove(&cb_stop.content_block_index)
688 {
689 return Some((
690 Some(maybe!({
691 Ok(LanguageModelCompletionEvent::ToolUse(
692 LanguageModelToolUse {
693 id: tool_use.id.into(),
694 name: tool_use.name.into(),
695 input: if tool_use.input_json.is_empty()
696 {
697 Value::Null
698 } else {
699 serde_json::Value::from_str(
700 &tool_use.input_json,
701 )
702 .map_err(|err| anyhow!(err))?
703 },
704 },
705 ))
706 })),
707 state,
708 ));
709 }
710 }
711 _ => {}
712 },
713 Err(err) => return Some((Some(Err(anyhow!(err))), state)),
714 }
715 }
716 None
717 })
718 .await
719 .log_err()
720 .flatten()
721 }
722 },
723 )
724 .filter_map(|event| async move { event })
725}
726
727struct ConfigurationView {
728 access_key_id_editor: Entity<Editor>,
729 secret_access_key_editor: Entity<Editor>,
730 region_editor: Entity<Editor>,
731 state: gpui::Entity<State>,
732 load_credentials_task: Option<Task<()>>,
733}
734
735impl ConfigurationView {
736 const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
737 const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
738 "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
739 const PLACEHOLDER_REGION: &'static str = "us-east-1";
740
741 fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
742 cx.observe(&state, |_, _, cx| {
743 cx.notify();
744 })
745 .detach();
746
747 let load_credentials_task = Some(cx.spawn({
748 let state = state.clone();
749 async move |this, cx| {
750 if let Some(task) = state
751 .update(cx, |state, cx| state.authenticate(cx))
752 .log_err()
753 {
754 // We don't log an error, because "not signed in" is also an error.
755 let _ = task.await;
756 }
757 this.update(cx, |this, cx| {
758 this.load_credentials_task = None;
759 cx.notify();
760 })
761 .log_err();
762 }
763 }));
764
765 Self {
766 access_key_id_editor: cx.new(|cx| {
767 let mut editor = Editor::single_line(window, cx);
768 editor.set_placeholder_text(Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT, cx);
769 editor
770 }),
771 secret_access_key_editor: cx.new(|cx| {
772 let mut editor = Editor::single_line(window, cx);
773 editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx);
774 editor
775 }),
776 region_editor: cx.new(|cx| {
777 let mut editor = Editor::single_line(window, cx);
778 editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx);
779 editor
780 }),
781 state,
782 load_credentials_task,
783 }
784 }
785
786 fn save_credentials(
787 &mut self,
788 _: &menu::Confirm,
789 _window: &mut Window,
790 cx: &mut Context<Self>,
791 ) {
792 let access_key_id = self
793 .access_key_id_editor
794 .read(cx)
795 .text(cx)
796 .to_string()
797 .trim()
798 .to_string();
799 let secret_access_key = self
800 .secret_access_key_editor
801 .read(cx)
802 .text(cx)
803 .to_string()
804 .trim()
805 .to_string();
806 let region = self
807 .region_editor
808 .read(cx)
809 .text(cx)
810 .to_string()
811 .trim()
812 .to_string();
813
814 let state = self.state.clone();
815 cx.spawn(async move |_, cx| {
816 state
817 .update(cx, |state, cx| {
818 let credentials: BedrockCredentials = BedrockCredentials {
819 access_key_id: access_key_id.clone(),
820 secret_access_key: secret_access_key.clone(),
821 region: region.clone(),
822 };
823
824 state.set_credentials(credentials, cx)
825 })?
826 .await
827 })
828 .detach_and_log_err(cx);
829 }
830
831 fn reset_credentials(&mut self, window: &mut Window, cx: &mut Context<Self>) {
832 self.access_key_id_editor
833 .update(cx, |editor, cx| editor.set_text("", window, cx));
834 self.secret_access_key_editor
835 .update(cx, |editor, cx| editor.set_text("", window, cx));
836 self.region_editor
837 .update(cx, |editor, cx| editor.set_text("", window, cx));
838
839 let state = self.state.clone();
840 cx.spawn(async move |_, cx| {
841 state
842 .update(cx, |state, cx| state.reset_credentials(cx))?
843 .await
844 })
845 .detach_and_log_err(cx);
846 }
847
848 fn make_text_style(&self, cx: &Context<Self>) -> TextStyle {
849 let settings = ThemeSettings::get_global(cx);
850 TextStyle {
851 color: cx.theme().colors().text,
852 font_family: settings.ui_font.family.clone(),
853 font_features: settings.ui_font.features.clone(),
854 font_fallbacks: settings.ui_font.fallbacks.clone(),
855 font_size: rems(0.875).into(),
856 font_weight: settings.ui_font.weight,
857 font_style: FontStyle::Normal,
858 line_height: relative(1.3),
859 background_color: None,
860 underline: None,
861 strikethrough: None,
862 white_space: WhiteSpace::Normal,
863 text_overflow: None,
864 text_align: Default::default(),
865 line_clamp: None,
866 }
867 }
868
869 fn render_aa_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
870 let text_style = self.make_text_style(cx);
871
872 EditorElement::new(
873 &self.access_key_id_editor,
874 EditorStyle {
875 background: cx.theme().colors().editor_background,
876 local_player: cx.theme().players().local(),
877 text: text_style,
878 ..Default::default()
879 },
880 )
881 }
882
883 fn render_sk_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
884 let text_style = self.make_text_style(cx);
885
886 EditorElement::new(
887 &self.secret_access_key_editor,
888 EditorStyle {
889 background: cx.theme().colors().editor_background,
890 local_player: cx.theme().players().local(),
891 text: text_style,
892 ..Default::default()
893 },
894 )
895 }
896
897 fn render_region_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
898 let text_style = self.make_text_style(cx);
899
900 EditorElement::new(
901 &self.region_editor,
902 EditorStyle {
903 background: cx.theme().colors().editor_background,
904 local_player: cx.theme().players().local(),
905 text: text_style,
906 ..Default::default()
907 },
908 )
909 }
910
911 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
912 !self.state.read(cx).is_authenticated()
913 }
914}
915
916impl Render for ConfigurationView {
917 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
918 let env_var_set = self.state.read(cx).credentials_from_env;
919 let bg_color = cx.theme().colors().editor_background;
920 let border_color = cx.theme().colors().border_variant;
921 let input_base_styles = || {
922 h_flex()
923 .w_full()
924 .px_2()
925 .py_1()
926 .bg(bg_color)
927 .border_1()
928 .border_color(border_color)
929 .rounded_sm()
930 };
931
932 if self.load_credentials_task.is_some() {
933 div().child(Label::new("Loading credentials...")).into_any()
934 } else if self.should_render_editor(cx) {
935 v_flex()
936 .size_full()
937 .on_action(cx.listener(ConfigurationView::save_credentials))
938 .child(Label::new("To use Zed's assistant with Bedrock, you need to add the Access Key ID, Secret Access Key and AWS Region. Follow these steps:"))
939 .child(
940 List::new()
941 .child(
942 InstructionListItem::new(
943 "Start by",
944 Some("creating a user and security credentials"),
945 Some("https://us-east-1.console.aws.amazon.com/iam/home")
946 )
947 )
948 .child(
949 InstructionListItem::new(
950 "Grant that user permissions according to this documentation:",
951 Some("Prerequisites"),
952 Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html")
953 )
954 )
955 .child(
956 InstructionListItem::new(
957 "Select the models you would like access to:",
958 Some("Bedrock Model Catalog"),
959 Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess")
960 )
961 )
962 .child(
963 InstructionListItem::text_only("Fill the fields below and hit enter to start using the assistant")
964 )
965 )
966 .child(
967 v_flex()
968 .my_2()
969 .gap_1p5()
970 .child(
971 v_flex()
972 .gap_0p5()
973 .child(Label::new("Access Key ID").size(LabelSize::Small))
974 .child(
975 input_base_styles().child(self.render_aa_id_editor(cx))
976 )
977 )
978 .child(
979 v_flex()
980 .gap_0p5()
981 .child(Label::new("Secret Access Key").size(LabelSize::Small))
982 .child(
983 input_base_styles().child(self.render_sk_editor(cx))
984 )
985 )
986 .child(
987 v_flex()
988 .gap_0p5()
989 .child(Label::new("Region").size(LabelSize::Small))
990 .child(
991 input_base_styles().child(self.render_region_editor(cx))
992 )
993 )
994 )
995 .child(
996 Label::new(
997 format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."),
998 )
999 .size(LabelSize::Small)
1000 .color(Color::Muted),
1001 )
1002 .into_any()
1003 } else {
1004 h_flex()
1005 .size_full()
1006 .justify_between()
1007 .child(
1008 h_flex()
1009 .gap_1()
1010 .child(Icon::new(IconName::Check).color(Color::Success))
1011 .child(Label::new(if env_var_set {
1012 format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
1013 } else {
1014 "Credentials configured.".to_string()
1015 })),
1016 )
1017 .child(
1018 Button::new("reset-key", "Reset key")
1019 .icon(Some(IconName::Trash))
1020 .icon_size(IconSize::Small)
1021 .icon_position(IconPosition::Start)
1022 .disabled(env_var_set)
1023 .when(env_var_set, |this| {
1024 this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
1025 })
1026 .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
1027 )
1028 .into_any()
1029 }
1030 }
1031}