1use std::pin::Pin;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use crate::ui::InstructionListItem;
6use anyhow::{Context as _, Result, anyhow};
7use aws_config::Region;
8use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
9use aws_credential_types::Credentials;
10use aws_http_client::AwsHttpClient;
11use bedrock::bedrock_client::types::{ContentBlockDelta, ContentBlockStart, ConverseStreamOutput};
12use bedrock::bedrock_client::{self, Config};
13use bedrock::{BedrockError, BedrockInnerContent, BedrockMessage, BedrockStreamingResponse, Model};
14use collections::{BTreeMap, HashMap};
15use credentials_provider::CredentialsProvider;
16use editor::{Editor, EditorElement, EditorStyle};
17use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
18use gpui::{
19 AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
20};
21use gpui_tokio::Tokio;
22use http_client::HttpClient;
23use language_model::{
24 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
25 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
26 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
27 LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use settings::{Settings, SettingsStore};
33use strum::IntoEnumIterator;
34use theme::ThemeSettings;
35use tokio::runtime::Handle;
36use ui::{Icon, IconName, List, Tooltip, prelude::*};
37use util::{ResultExt, maybe};
38
39use crate::AllLanguageModelSettings;
40
/// Stable identifier of this provider, used in `LanguageModelProviderId`.
const PROVIDER_ID: &str = "amazon-bedrock";
/// Human-readable provider name shown in the UI and in error messages.
const PROVIDER_NAME: &str = "Amazon Bedrock";
43
/// Static AWS credentials used to sign Bedrock requests.
///
/// This struct is serialized to JSON when stored in the keychain and parsed
/// from JSON when read back (or when supplied via an environment variable).
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
pub struct BedrockCredentials {
    /// AWS region passed to the Bedrock client (e.g. `us-east-1`).
    pub region: String,
    /// AWS access key ID.
    pub access_key_id: String,
    /// AWS secret access key.
    pub secret_access_key: String,
}
50
/// User settings for the Amazon Bedrock provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AmazonBedrockSettings {
    /// Optional AWS session token.
    /// NOTE(review): not read anywhere in this file; confirm where it is consumed.
    pub session_token: Option<String>,
    /// Extra models to expose; entries override built-in models with the same name.
    pub available_models: Vec<AvailableModel>,
}
56
/// A user-configured Bedrock model, turned into `bedrock::Model::Custom`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Bedrock model identifier sent with requests.
    pub name: String,
    /// Optional name shown in the UI instead of `name`.
    pub display_name: Option<String>,
    /// Maximum context size reported for this model.
    pub max_tokens: usize,
    /// Prompt-cache configuration.
    /// NOTE(review): not forwarded when building `Model::Custom` in `provided_models`.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// Maximum number of tokens the model may generate per response.
    pub max_output_tokens: Option<u32>,
    /// Temperature used when a request does not specify one.
    pub default_temperature: Option<f32>,
}
66
/// The URL of the base AWS service.
///
/// Right now we're just using this as the key under which the AWS credentials
/// are stored in the keychain.
const AMAZON_AWS_URL: &str = "https://amazonaws.com";

// These environment variables all use a `ZED_` prefix so that they don't
// collide with (or overwrite) the user's regular AWS credential variables.
/// Environment variable name shown in the configuration UI for the access key ID.
const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID";
/// Environment variable name shown in the configuration UI for the secret access key.
const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY";
/// Environment variable name shown in the configuration UI for the region.
const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION";
/// Environment variable holding the full credentials as a JSON-encoded `BedrockCredentials`.
const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS";
78
/// Shared authentication state for the Bedrock provider and its models.
pub struct State {
    /// Currently active credentials; `None` until authentication succeeds.
    credentials: Option<BedrockCredentials>,
    /// Whether `credentials` were loaded from the environment rather than the keychain.
    credentials_from_env: bool,
    /// Keeps the settings-store observation alive; it re-notifies observers on settings changes.
    _subscription: Subscription,
}
84
85impl State {
86 fn reset_credentials(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
87 let credentials_provider = <dyn CredentialsProvider>::global(cx);
88 cx.spawn(async move |this, cx| {
89 credentials_provider
90 .delete_credentials(AMAZON_AWS_URL, &cx)
91 .await
92 .log_err();
93 this.update(cx, |this, cx| {
94 this.credentials = None;
95 this.credentials_from_env = false;
96 cx.notify();
97 })
98 })
99 }
100
101 fn set_credentials(
102 &mut self,
103 credentials: BedrockCredentials,
104 cx: &mut Context<Self>,
105 ) -> Task<Result<()>> {
106 let credentials_provider = <dyn CredentialsProvider>::global(cx);
107 cx.spawn(async move |this, cx| {
108 credentials_provider
109 .write_credentials(
110 AMAZON_AWS_URL,
111 "Bearer",
112 &serde_json::to_vec(&credentials)?,
113 &cx,
114 )
115 .await?;
116 this.update(cx, |this, cx| {
117 this.credentials = Some(credentials);
118 cx.notify();
119 })
120 })
121 }
122
123 fn is_authenticated(&self) -> bool {
124 self.credentials.is_some()
125 }
126
127 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
128 if self.is_authenticated() {
129 return Task::ready(Ok(()));
130 }
131
132 let credentials_provider = <dyn CredentialsProvider>::global(cx);
133 cx.spawn(async move |this, cx| {
134 let (credentials, from_env) =
135 if let Ok(credentials) = std::env::var(ZED_AWS_CREDENTIALS_VAR) {
136 (credentials, true)
137 } else {
138 let (_, credentials) = credentials_provider
139 .read_credentials(AMAZON_AWS_URL, &cx)
140 .await?
141 .ok_or_else(|| AuthenticateError::CredentialsNotFound)?;
142 (
143 String::from_utf8(credentials)
144 .context("invalid {PROVIDER_NAME} credentials")?,
145 false,
146 )
147 };
148
149 let credentials: BedrockCredentials =
150 serde_json::from_str(&credentials).context("failed to parse credentials")?;
151
152 this.update(cx, |this, cx| {
153 this.credentials = Some(credentials);
154 this.credentials_from_env = from_env;
155 cx.notify();
156 })?;
157
158 Ok(())
159 })
160 }
161}
162
/// Language-model provider backed by Amazon Bedrock.
pub struct BedrockLanguageModelProvider {
    /// HTTP client adapted for the AWS SDK; cloned into every model.
    http_client: AwsHttpClient,
    /// Tokio runtime handle used to drive the AWS SDK futures.
    handler: tokio::runtime::Handle,
    /// Authentication state shared with every model and the configuration view.
    state: gpui::Entity<State>,
}
168
169impl BedrockLanguageModelProvider {
170 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
171 let state = cx.new(|cx| State {
172 credentials: None,
173 credentials_from_env: false,
174 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
175 cx.notify();
176 }),
177 });
178
179 let tokio_handle = Tokio::handle(cx);
180
181 let coerced_client = AwsHttpClient::new(http_client.clone(), tokio_handle.clone());
182
183 Self {
184 http_client: coerced_client,
185 handler: tokio_handle.clone(),
186 state,
187 }
188 }
189}
190
191impl LanguageModelProvider for BedrockLanguageModelProvider {
192 fn id(&self) -> LanguageModelProviderId {
193 LanguageModelProviderId(PROVIDER_ID.into())
194 }
195
196 fn name(&self) -> LanguageModelProviderName {
197 LanguageModelProviderName(PROVIDER_NAME.into())
198 }
199
200 fn icon(&self) -> IconName {
201 IconName::AiBedrock
202 }
203
204 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
205 let model = bedrock::Model::default();
206 Some(Arc::new(BedrockModel {
207 id: LanguageModelId::from(model.id().to_string()),
208 model,
209 http_client: self.http_client.clone(),
210 handler: self.handler.clone(),
211 state: self.state.clone(),
212 request_limiter: RateLimiter::new(4),
213 }))
214 }
215
216 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
217 let mut models = BTreeMap::default();
218
219 for model in bedrock::Model::iter() {
220 if !matches!(model, bedrock::Model::Custom { .. }) {
221 models.insert(model.id().to_string(), model);
222 }
223 }
224
225 // Override with available models from settings
226 for model in AllLanguageModelSettings::get_global(cx)
227 .bedrock
228 .available_models
229 .iter()
230 {
231 models.insert(
232 model.name.clone(),
233 bedrock::Model::Custom {
234 name: model.name.clone(),
235 display_name: model.display_name.clone(),
236 max_tokens: model.max_tokens,
237 max_output_tokens: model.max_output_tokens,
238 default_temperature: model.default_temperature,
239 },
240 );
241 }
242
243 models
244 .into_values()
245 .map(|model| {
246 Arc::new(BedrockModel {
247 id: LanguageModelId::from(model.id().to_string()),
248 model,
249 http_client: self.http_client.clone(),
250 handler: self.handler.clone(),
251 state: self.state.clone(),
252 request_limiter: RateLimiter::new(4),
253 }) as Arc<dyn LanguageModel>
254 })
255 .collect()
256 }
257
258 fn is_authenticated(&self, cx: &App) -> bool {
259 self.state.read(cx).is_authenticated()
260 }
261
262 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
263 self.state.update(cx, |state, cx| state.authenticate(cx))
264 }
265
266 fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
267 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
268 .into()
269 }
270
271 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
272 self.state
273 .update(cx, |state, cx| state.reset_credentials(cx))
274 }
275}
276
277impl LanguageModelProviderState for BedrockLanguageModelProvider {
278 type ObservableEntity = State;
279
280 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
281 Some(self.state.clone())
282 }
283}
284
/// A single Bedrock model exposed to Zed as a `LanguageModel`.
struct BedrockModel {
    /// Zed-facing model id (the Bedrock model id as a string).
    id: LanguageModelId,
    /// Underlying Bedrock model description.
    model: Model,
    /// HTTP client handed to the AWS SDK client.
    http_client: AwsHttpClient,
    /// Tokio runtime handle used to run the SDK stream.
    handler: tokio::runtime::Handle,
    /// Shared authentication state read on each request.
    state: gpui::Entity<State>,
    /// Limits concurrent completion requests for this model.
    request_limiter: RateLimiter,
}
293
294impl BedrockModel {
295 fn stream_completion(
296 &self,
297 request: bedrock::Request,
298 cx: &AsyncApp,
299 ) -> Result<
300 BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
301 > {
302 let Ok(Ok((access_key_id, secret_access_key, region))) =
303 cx.read_entity(&self.state, |state, _cx| {
304 if let Some(credentials) = &state.credentials {
305 Ok((
306 credentials.access_key_id.clone(),
307 credentials.secret_access_key.clone(),
308 credentials.region.clone(),
309 ))
310 } else {
311 return Err(anyhow!("Failed to read credentials"));
312 }
313 })
314 else {
315 return Err(anyhow!("App state dropped"));
316 };
317
318 let runtime_client = bedrock_client::Client::from_conf(
319 Config::builder()
320 .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
321 .credentials_provider(Credentials::new(
322 access_key_id,
323 secret_access_key,
324 None,
325 None,
326 "Keychain",
327 ))
328 .region(Region::new(region))
329 .http_client(self.http_client.clone())
330 .build(),
331 );
332
333 let owned_handle = self.handler.clone();
334
335 Ok(async move {
336 let request = bedrock::stream_completion(runtime_client, request, owned_handle);
337 request.await.unwrap_or_else(|e| {
338 futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
339 })
340 }
341 .boxed())
342 }
343}
344
345impl LanguageModel for BedrockModel {
346 fn id(&self) -> LanguageModelId {
347 self.id.clone()
348 }
349
350 fn name(&self) -> LanguageModelName {
351 LanguageModelName::from(self.model.display_name().to_string())
352 }
353
354 fn provider_id(&self) -> LanguageModelProviderId {
355 LanguageModelProviderId(PROVIDER_ID.into())
356 }
357
358 fn provider_name(&self) -> LanguageModelProviderName {
359 LanguageModelProviderName(PROVIDER_NAME.into())
360 }
361
362 fn supports_tools(&self) -> bool {
363 true
364 }
365
366 fn telemetry_id(&self) -> String {
367 format!("bedrock/{}", self.model.id())
368 }
369
370 fn max_token_count(&self) -> usize {
371 self.model.max_token_count()
372 }
373
374 fn max_output_tokens(&self) -> Option<u32> {
375 Some(self.model.max_output_tokens())
376 }
377
378 fn count_tokens(
379 &self,
380 request: LanguageModelRequest,
381 cx: &App,
382 ) -> BoxFuture<'static, Result<usize>> {
383 get_bedrock_tokens(request, cx)
384 }
385
386 fn stream_completion(
387 &self,
388 request: LanguageModelRequest,
389 cx: &AsyncApp,
390 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
391 let request = into_bedrock(
392 request,
393 self.model.id().into(),
394 self.model.default_temperature(),
395 self.model.max_output_tokens(),
396 );
397
398 let owned_handle = self.handler.clone();
399
400 let request = self.stream_completion(request, cx);
401 let future = self.request_limiter.stream(async move {
402 let response = request.map_err(|err| anyhow!(err))?.await;
403 Ok(map_to_language_model_completion_events(
404 response,
405 owned_handle,
406 ))
407 });
408 async move { Ok(future.await?.boxed()) }.boxed()
409 }
410
411 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
412 None
413 }
414}
415
416pub fn into_bedrock(
417 request: LanguageModelRequest,
418 model: String,
419 default_temperature: f32,
420 max_output_tokens: u32,
421) -> bedrock::Request {
422 let mut new_messages: Vec<BedrockMessage> = Vec::new();
423 let mut system_message = String::new();
424
425 for message in request.messages {
426 if message.contents_empty() {
427 continue;
428 }
429
430 match message.role {
431 Role::User | Role::Assistant => {
432 let bedrock_message_content: Vec<BedrockInnerContent> = message
433 .content
434 .into_iter()
435 .filter_map(|content| match content {
436 MessageContent::Text(text) => {
437 if !text.is_empty() {
438 Some(BedrockInnerContent::Text(text))
439 } else {
440 None
441 }
442 }
443 _ => None,
444 })
445 .collect();
446 let bedrock_role = match message.role {
447 Role::User => bedrock::BedrockRole::User,
448 Role::Assistant => bedrock::BedrockRole::Assistant,
449 Role::System => unreachable!("System role should never occur here"),
450 };
451 if let Some(last_message) = new_messages.last_mut() {
452 if last_message.role == bedrock_role {
453 last_message.content.extend(bedrock_message_content);
454 continue;
455 }
456 }
457 new_messages.push(
458 BedrockMessage::builder()
459 .role(bedrock_role)
460 .set_content(Some(bedrock_message_content))
461 .build()
462 .expect("failed to build Bedrock message"),
463 );
464 }
465 Role::System => {
466 if !system_message.is_empty() {
467 system_message.push_str("\n\n");
468 }
469 system_message.push_str(&message.string_contents());
470 }
471 }
472 }
473
474 bedrock::Request {
475 model,
476 messages: new_messages,
477 max_tokens: max_output_tokens,
478 system: Some(system_message),
479 tools: vec![],
480 tool_choice: None,
481 metadata: None,
482 stop_sequences: Vec::new(),
483 temperature: request.temperature.or(Some(default_temperature)),
484 top_k: None,
485 top_p: None,
486 }
487}
488
489// TODO: just call the ConverseOutput.usage() method:
490// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output
491pub fn get_bedrock_tokens(
492 request: LanguageModelRequest,
493 cx: &App,
494) -> BoxFuture<'static, Result<usize>> {
495 cx.background_executor()
496 .spawn(async move {
497 let messages = request.messages;
498 let mut tokens_from_images = 0;
499 let mut string_messages = Vec::with_capacity(messages.len());
500
501 for message in messages {
502 use language_model::MessageContent;
503
504 let mut string_contents = String::new();
505
506 for content in message.content {
507 match content {
508 MessageContent::Text(text) => {
509 string_contents.push_str(&text);
510 }
511 MessageContent::Image(image) => {
512 tokens_from_images += image.estimate_tokens();
513 }
514 MessageContent::ToolUse(_tool_use) => {
515 // TODO: Estimate token usage from tool uses.
516 }
517 MessageContent::ToolResult(tool_result) => {
518 string_contents.push_str(&tool_result.content);
519 }
520 }
521 }
522
523 if !string_contents.is_empty() {
524 string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
525 role: match message.role {
526 Role::User => "user".into(),
527 Role::Assistant => "assistant".into(),
528 Role::System => "system".into(),
529 },
530 content: Some(string_contents),
531 name: None,
532 function_call: None,
533 });
534 }
535 }
536
537 // Tiktoken doesn't yet support these models, so we manually use the
538 // same tokenizer as GPT-4.
539 tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
540 .map(|tokens| tokens + tokens_from_images)
541 })
542 .boxed()
543}
544
/// Translates the raw Bedrock Converse stream into Zed completion events.
///
/// - Text deltas are emitted immediately as `LanguageModelCompletionEvent::Text`.
/// - Tool uses are assembled across events. `ContentBlockStart` registers the
///   tool's id/name by content-block index, tool-use deltas append JSON input
///   fragments, and `ContentBlockStop` emits the finished
///   `LanguageModelCompletionEvent::ToolUse`. Empty input becomes
///   `Value::Null`, and malformed JSON yields an `Err` item.
/// - Stream errors are forwarded as `Err` items.
///
/// Each unfold step is spawned on the Tokio runtime behind `handle`. It pulls
/// events until one produces an output (or the stream ends).
/// `Some((None, state))` is a placeholder step that the trailing `filter_map`
/// drops. If the spawned task fails to join, `log_err` logs it and the stream
/// ends.
pub fn map_to_language_model_completion_events(
    events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
    handle: Handle,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
    /// A tool use whose JSON input is still being streamed.
    struct RawToolUse {
        id: String,
        name: String,
        input_json: String,
    }

    /// Unfold state: the upstream event stream plus in-progress tool uses
    /// keyed by Bedrock content-block index.
    struct State {
        events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
        tool_uses_by_index: HashMap<i32, RawToolUse>,
    }

    futures::stream::unfold(
        State {
            events,
            tool_uses_by_index: HashMap::default(),
        },
        move |mut state: State| {
            let inner_handle = handle.clone();
            async move {
                inner_handle
                    .spawn(async {
                        while let Some(event) = state.events.next().await {
                            match event {
                                Ok(event) => match event {
                                    ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
                                        if let Some(ContentBlockDelta::Text(text_out)) =
                                            cb_delta.delta
                                        {
                                            return Some((
                                                Some(Ok(LanguageModelCompletionEvent::Text(
                                                    text_out,
                                                ))),
                                                state,
                                            ));
                                        } else if let Some(ContentBlockDelta::ToolUse(text_out)) =
                                            cb_delta.delta
                                        {
                                            // Append the JSON fragment to the tool use started
                                            // at this index. Both paths yield a placeholder
                                            // step; a fragment for an unknown index is dropped.
                                            if let Some(tool_use) = state
                                                .tool_uses_by_index
                                                .get_mut(&cb_delta.content_block_index)
                                            {
                                                tool_use.input_json.push_str(text_out.input());
                                                return Some((None, state));
                                            };

                                            return Some((None, state));
                                        } else if cb_delta.delta.is_none() {
                                            return Some((None, state));
                                        }
                                        // Any other delta variant falls through and the loop
                                        // pulls the next event.
                                    }
                                    ConverseStreamOutput::ContentBlockStart(cb_start) => {
                                        if let Some(start) = cb_start.start {
                                            match start {
                                                ContentBlockStart::ToolUse(text_out) => {
                                                    let tool_use = RawToolUse {
                                                        id: text_out.tool_use_id,
                                                        name: text_out.name,
                                                        input_json: String::new(),
                                                    };

                                                    state.tool_uses_by_index.insert(
                                                        cb_start.content_block_index,
                                                        tool_use,
                                                    );
                                                }
                                                _ => {}
                                            }
                                        }
                                    }
                                    ConverseStreamOutput::ContentBlockStop(cb_stop) => {
                                        // A block finished: if it was a tool use, emit it with
                                        // its accumulated input parsed as JSON.
                                        if let Some(tool_use) = state
                                            .tool_uses_by_index
                                            .remove(&cb_stop.content_block_index)
                                        {
                                            return Some((
                                                Some(maybe!({
                                                    Ok(LanguageModelCompletionEvent::ToolUse(
                                                        LanguageModelToolUse {
                                                            id: tool_use.id.into(),
                                                            name: tool_use.name.into(),
                                                            input: if tool_use.input_json.is_empty()
                                                            {
                                                                Value::Null
                                                            } else {
                                                                serde_json::Value::from_str(
                                                                    &tool_use.input_json,
                                                                )
                                                                .map_err(|err| anyhow!(err))?
                                                            },
                                                        },
                                                    ))
                                                })),
                                                state,
                                            ));
                                        }
                                    }
                                    // Other stream events (e.g. message start/stop, metadata)
                                    // are ignored here.
                                    _ => {}
                                },
                                Err(err) => return Some((Some(Err(anyhow!(err))), state)),
                            }
                        }
                        // Upstream exhausted: end the unfold.
                        None
                    })
                    .await
                    .log_err()
                    .flatten()
            }
        },
    )
    // Drop placeholder steps that carried no event.
    .filter_map(|event| async move { event })
}
660
/// Settings panel for entering or resetting Bedrock credentials.
struct ConfigurationView {
    /// Single-line editor for the AWS access key ID.
    access_key_id_editor: Entity<Editor>,
    /// Single-line editor for the AWS secret access key.
    secret_access_key_editor: Entity<Editor>,
    /// Single-line editor for the AWS region.
    region_editor: Entity<Editor>,
    /// Shared provider authentication state.
    state: gpui::Entity<State>,
    /// Initial credential load; `Some` while loading, which shows a loading label.
    load_credentials_task: Option<Task<()>>,
}
668
669impl ConfigurationView {
670 const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
671 const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
672 "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
673 const PLACEHOLDER_REGION: &'static str = "us-east-1";
674
675 fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
676 cx.observe(&state, |_, _, cx| {
677 cx.notify();
678 })
679 .detach();
680
681 let load_credentials_task = Some(cx.spawn({
682 let state = state.clone();
683 async move |this, cx| {
684 if let Some(task) = state
685 .update(cx, |state, cx| state.authenticate(cx))
686 .log_err()
687 {
688 // We don't log an error, because "not signed in" is also an error.
689 let _ = task.await;
690 }
691 this.update(cx, |this, cx| {
692 this.load_credentials_task = None;
693 cx.notify();
694 })
695 .log_err();
696 }
697 }));
698
699 Self {
700 access_key_id_editor: cx.new(|cx| {
701 let mut editor = Editor::single_line(window, cx);
702 editor.set_placeholder_text(Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT, cx);
703 editor
704 }),
705 secret_access_key_editor: cx.new(|cx| {
706 let mut editor = Editor::single_line(window, cx);
707 editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx);
708 editor
709 }),
710 region_editor: cx.new(|cx| {
711 let mut editor = Editor::single_line(window, cx);
712 editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx);
713 editor
714 }),
715 state,
716 load_credentials_task,
717 }
718 }
719
720 fn save_credentials(
721 &mut self,
722 _: &menu::Confirm,
723 _window: &mut Window,
724 cx: &mut Context<Self>,
725 ) {
726 let access_key_id = self
727 .access_key_id_editor
728 .read(cx)
729 .text(cx)
730 .to_string()
731 .trim()
732 .to_string();
733 let secret_access_key = self
734 .secret_access_key_editor
735 .read(cx)
736 .text(cx)
737 .to_string()
738 .trim()
739 .to_string();
740 let region = self
741 .region_editor
742 .read(cx)
743 .text(cx)
744 .to_string()
745 .trim()
746 .to_string();
747
748 let state = self.state.clone();
749 cx.spawn(async move |_, cx| {
750 state
751 .update(cx, |state, cx| {
752 let credentials: BedrockCredentials = BedrockCredentials {
753 access_key_id: access_key_id.clone(),
754 secret_access_key: secret_access_key.clone(),
755 region: region.clone(),
756 };
757
758 state.set_credentials(credentials, cx)
759 })?
760 .await
761 })
762 .detach_and_log_err(cx);
763 }
764
765 fn reset_credentials(&mut self, window: &mut Window, cx: &mut Context<Self>) {
766 self.access_key_id_editor
767 .update(cx, |editor, cx| editor.set_text("", window, cx));
768 self.secret_access_key_editor
769 .update(cx, |editor, cx| editor.set_text("", window, cx));
770 self.region_editor
771 .update(cx, |editor, cx| editor.set_text("", window, cx));
772
773 let state = self.state.clone();
774 cx.spawn(async move |_, cx| {
775 state
776 .update(cx, |state, cx| state.reset_credentials(cx))?
777 .await
778 })
779 .detach_and_log_err(cx);
780 }
781
782 fn make_text_style(&self, cx: &Context<Self>) -> TextStyle {
783 let settings = ThemeSettings::get_global(cx);
784 TextStyle {
785 color: cx.theme().colors().text,
786 font_family: settings.ui_font.family.clone(),
787 font_features: settings.ui_font.features.clone(),
788 font_fallbacks: settings.ui_font.fallbacks.clone(),
789 font_size: rems(0.875).into(),
790 font_weight: settings.ui_font.weight,
791 font_style: FontStyle::Normal,
792 line_height: relative(1.3),
793 background_color: None,
794 underline: None,
795 strikethrough: None,
796 white_space: WhiteSpace::Normal,
797 text_overflow: None,
798 text_align: Default::default(),
799 line_clamp: None,
800 }
801 }
802
803 fn render_aa_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
804 let text_style = self.make_text_style(cx);
805
806 EditorElement::new(
807 &self.access_key_id_editor,
808 EditorStyle {
809 background: cx.theme().colors().editor_background,
810 local_player: cx.theme().players().local(),
811 text: text_style,
812 ..Default::default()
813 },
814 )
815 }
816
817 fn render_sk_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
818 let text_style = self.make_text_style(cx);
819
820 EditorElement::new(
821 &self.secret_access_key_editor,
822 EditorStyle {
823 background: cx.theme().colors().editor_background,
824 local_player: cx.theme().players().local(),
825 text: text_style,
826 ..Default::default()
827 },
828 )
829 }
830
831 fn render_region_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
832 let text_style = self.make_text_style(cx);
833
834 EditorElement::new(
835 &self.region_editor,
836 EditorStyle {
837 background: cx.theme().colors().editor_background,
838 local_player: cx.theme().players().local(),
839 text: text_style,
840 ..Default::default()
841 },
842 )
843 }
844
845 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
846 !self.state.read(cx).is_authenticated()
847 }
848}
849
/// Renders one of three states: a loading label while the initial credential
/// load runs, the credential form while unauthenticated, or a confirmation
/// row with a reset button once credentials are active.
impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // True when the active credentials were loaded from the environment.
        let env_var_set = self.state.read(cx).credentials_from_env;
        let bg_color = cx.theme().colors().editor_background;
        let border_color = cx.theme().colors().border_variant;
        // Common container styling for each input field.
        let input_base_styles = || {
            h_flex()
                .w_full()
                .px_2()
                .py_1()
                .bg(bg_color)
                .border_1()
                .border_color(border_color)
                .rounded_sm()
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            // Unauthenticated: setup instructions plus the three input fields.
            // Pressing Enter (`menu::Confirm`) triggers `save_credentials`.
            v_flex()
                .size_full()
                .on_action(cx.listener(ConfigurationView::save_credentials))
                .child(Label::new("To use Zed's assistant with Bedrock, you need to add the Access Key ID, Secret Access Key and AWS Region. Follow these steps:"))
                .child(
                    List::new()
                        .child(
                            InstructionListItem::new(
                                "Start by",
                                Some("creating a user and security credentials"),
                                Some("https://us-east-1.console.aws.amazon.com/iam/home")
                            )
                        )
                        .child(
                            InstructionListItem::new(
                                "Grant that user permissions according to this documentation:",
                                Some("Prerequisites"),
                                Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html")
                            )
                        )
                        .child(
                            InstructionListItem::new(
                                "Select the models you would like access to:",
                                Some("Bedrock Model Catalog"),
                                Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess")
                            )
                        )
                        .child(
                            InstructionListItem::text_only("Fill the fields below and hit enter to start using the assistant")
                        )
                )
                .child(
                    v_flex()
                        .my_2()
                        .gap_1p5()
                        .child(
                            v_flex()
                                .gap_0p5()
                                .child(Label::new("Access Key ID").size(LabelSize::Small))
                                .child(
                                    input_base_styles().child(self.render_aa_id_editor(cx))
                                )
                        )
                        .child(
                            v_flex()
                                .gap_0p5()
                                .child(Label::new("Secret Access Key").size(LabelSize::Small))
                                .child(
                                    input_base_styles().child(self.render_sk_editor(cx))
                                )
                        )
                        .child(
                            v_flex()
                                .gap_0p5()
                                .child(Label::new("Region").size(LabelSize::Small))
                                .child(
                                    input_base_styles().child(self.render_region_editor(cx))
                                )
                        )
                )
                .child(
                    Label::new(
                        format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."),
                    )
                        .size(LabelSize::Small)
                        .color(Color::Muted),
                )
                .into_any()
        } else {
            // Authenticated: confirmation message and a reset button. The
            // button is disabled for environment credentials, which can only
            // be cleared by unsetting the variables.
            h_flex()
                .size_full()
                .justify_between()
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
                        } else {
                            "Credentials configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-key", "Reset key")
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
                )
                .into_any()
        }
    }
}