1use std::pin::Pin;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use crate::ui::InstructionListItem;
6use anyhow::{anyhow, Context as _, Result};
7use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
8use aws_config::Region;
9use aws_credential_types::Credentials;
10use aws_http_client::AwsHttpClient;
11use bedrock::bedrock_client::types::{
12 ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput,
13};
14use bedrock::bedrock_client::{self, Config};
15use bedrock::{
16 value_to_aws_document, BedrockError, BedrockInnerContent, BedrockMessage, BedrockSpecificTool,
17 BedrockStreamingResponse, BedrockTool, BedrockToolChoice, BedrockToolInputSchema, Model,
18};
19use collections::{BTreeMap, HashMap};
20use credentials_provider::CredentialsProvider;
21use editor::{Editor, EditorElement, EditorStyle};
22use futures::{future::BoxFuture, stream::BoxStream, FutureExt, Stream, StreamExt};
23use gpui::{
24 AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
25};
26use gpui_tokio::Tokio;
27use http_client::HttpClient;
28use language_model::{
29 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
30 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
31 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
32 LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role,
33};
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use serde_json::Value;
37use settings::{Settings, SettingsStore};
38use strum::IntoEnumIterator;
39use theme::ThemeSettings;
40use tokio::runtime::Handle;
41use ui::{prelude::*, Icon, IconName, List, Tooltip};
42use util::{maybe, ResultExt};
43
44use crate::AllLanguageModelSettings;
45
/// Stable identifier of this provider (used in settings and telemetry-style ids).
const PROVIDER_ID: &str = "amazon-bedrock";
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: &str = "Amazon Bedrock";
48
49#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
50pub struct BedrockCredentials {
51 pub region: String,
52 pub access_key_id: String,
53 pub secret_access_key: String,
54}
55
56#[derive(Default, Clone, Debug, PartialEq)]
57pub struct AmazonBedrockSettings {
58 pub session_token: Option<String>,
59 pub available_models: Vec<AvailableModel>,
60}
61
62#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
63pub struct AvailableModel {
64 pub name: String,
65 pub display_name: Option<String>,
66 pub max_tokens: usize,
67 pub cache_configuration: Option<LanguageModelCacheConfiguration>,
68 pub max_output_tokens: Option<u32>,
69 pub default_temperature: Option<f32>,
70}
71
/// Base AWS service URL.
///
/// It is only used as the keychain key under which the Bedrock credentials are stored.
const AMAZON_AWS_URL: &str = "https://amazonaws.com";

// Every variable below carries a `ZED_` prefix so that configuring Zed never clobbers
// the user's regular AWS environment (`AWS_ACCESS_KEY_ID`, ...).

/// Environment variable holding the IAM access key id.
const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID";
/// Environment variable holding the IAM secret access key.
const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY";
/// Environment variable holding the AWS region.
const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION";
/// Environment variable holding a JSON-encoded `BedrockCredentials` value.
const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS";
83
84pub struct State {
85 credentials: Option<BedrockCredentials>,
86 credentials_from_env: bool,
87 _subscription: Subscription,
88}
89
90impl State {
91 fn reset_credentials(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
92 let credentials_provider = <dyn CredentialsProvider>::global(cx);
93 cx.spawn(|this, mut cx| async move {
94 credentials_provider
95 .delete_credentials(AMAZON_AWS_URL, &cx)
96 .await
97 .log_err();
98 this.update(&mut cx, |this, cx| {
99 this.credentials = None;
100 this.credentials_from_env = false;
101 cx.notify();
102 })
103 })
104 }
105
106 fn set_credentials(
107 &mut self,
108 credentials: BedrockCredentials,
109 cx: &mut Context<Self>,
110 ) -> Task<Result<()>> {
111 let credentials_provider = <dyn CredentialsProvider>::global(cx);
112 cx.spawn(|this, mut cx| async move {
113 credentials_provider
114 .write_credentials(
115 AMAZON_AWS_URL,
116 "Bearer",
117 &serde_json::to_vec(&credentials)?,
118 &cx,
119 )
120 .await?;
121 this.update(&mut cx, |this, cx| {
122 this.credentials = Some(credentials);
123 cx.notify();
124 })
125 })
126 }
127
128 fn is_authenticated(&self) -> bool {
129 self.credentials.is_some()
130 }
131
132 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
133 if self.is_authenticated() {
134 return Task::ready(Ok(()));
135 }
136
137 let credentials_provider = <dyn CredentialsProvider>::global(cx);
138 cx.spawn(|this, mut cx| async move {
139 let (credentials, from_env) =
140 if let Ok(credentials) = std::env::var(ZED_AWS_CREDENTIALS_VAR) {
141 (credentials, true)
142 } else {
143 let (_, credentials) = credentials_provider
144 .read_credentials(AMAZON_AWS_URL, &cx)
145 .await?
146 .ok_or_else(|| AuthenticateError::CredentialsNotFound)?;
147 (
148 String::from_utf8(credentials)
149 .context("invalid {PROVIDER_NAME} credentials")?,
150 false,
151 )
152 };
153
154 let credentials: BedrockCredentials =
155 serde_json::from_str(&credentials).context("failed to parse credentials")?;
156
157 this.update(&mut cx, |this, cx| {
158 this.credentials = Some(credentials);
159 this.credentials_from_env = from_env;
160 cx.notify();
161 })?;
162
163 Ok(())
164 })
165 }
166}
167
168pub struct BedrockLanguageModelProvider {
169 http_client: AwsHttpClient,
170 handler: tokio::runtime::Handle,
171 state: gpui::Entity<State>,
172}
173
174impl BedrockLanguageModelProvider {
175 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
176 let state = cx.new(|cx| State {
177 credentials: None,
178 credentials_from_env: false,
179 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
180 cx.notify();
181 }),
182 });
183
184 let tokio_handle = Tokio::handle(cx);
185
186 let coerced_client = AwsHttpClient::new(http_client.clone(), tokio_handle.clone());
187
188 Self {
189 http_client: coerced_client,
190 handler: tokio_handle.clone(),
191 state,
192 }
193 }
194}
195
196impl LanguageModelProvider for BedrockLanguageModelProvider {
197 fn id(&self) -> LanguageModelProviderId {
198 LanguageModelProviderId(PROVIDER_ID.into())
199 }
200
201 fn name(&self) -> LanguageModelProviderName {
202 LanguageModelProviderName(PROVIDER_NAME.into())
203 }
204
205 fn icon(&self) -> IconName {
206 IconName::AiBedrock
207 }
208
209 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
210 let model = bedrock::Model::default();
211 Some(Arc::new(BedrockModel {
212 id: LanguageModelId::from(model.id().to_string()),
213 model,
214 http_client: self.http_client.clone(),
215 handler: self.handler.clone(),
216 state: self.state.clone(),
217 request_limiter: RateLimiter::new(4),
218 }))
219 }
220
221 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
222 let mut models = BTreeMap::default();
223
224 for model in bedrock::Model::iter() {
225 if !matches!(model, bedrock::Model::Custom { .. }) {
226 models.insert(model.id().to_string(), model);
227 }
228 }
229
230 // Override with available models from settings
231 for model in AllLanguageModelSettings::get_global(cx)
232 .bedrock
233 .available_models
234 .iter()
235 {
236 models.insert(
237 model.name.clone(),
238 bedrock::Model::Custom {
239 name: model.name.clone(),
240 display_name: model.display_name.clone(),
241 max_tokens: model.max_tokens,
242 max_output_tokens: model.max_output_tokens,
243 default_temperature: model.default_temperature,
244 },
245 );
246 }
247
248 models
249 .into_values()
250 .map(|model| {
251 Arc::new(BedrockModel {
252 id: LanguageModelId::from(model.id().to_string()),
253 model,
254 http_client: self.http_client.clone(),
255 handler: self.handler.clone(),
256 state: self.state.clone(),
257 request_limiter: RateLimiter::new(4),
258 }) as Arc<dyn LanguageModel>
259 })
260 .collect()
261 }
262
263 fn is_authenticated(&self, cx: &App) -> bool {
264 self.state.read(cx).is_authenticated()
265 }
266
267 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
268 self.state.update(cx, |state, cx| state.authenticate(cx))
269 }
270
271 fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
272 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
273 .into()
274 }
275
276 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
277 self.state
278 .update(cx, |state, cx| state.reset_credentials(cx))
279 }
280}
281
282impl LanguageModelProviderState for BedrockLanguageModelProvider {
283 type ObservableEntity = State;
284
285 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
286 Some(self.state.clone())
287 }
288}
289
290struct BedrockModel {
291 id: LanguageModelId,
292 model: Model,
293 http_client: AwsHttpClient,
294 handler: tokio::runtime::Handle,
295 state: gpui::Entity<State>,
296 request_limiter: RateLimiter,
297}
298
299impl BedrockModel {
300 fn stream_completion(
301 &self,
302 request: bedrock::Request,
303 cx: &AsyncApp,
304 ) -> Result<
305 BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
306 > {
307 let Ok(Ok((access_key_id, secret_access_key, region))) =
308 cx.read_entity(&self.state, |state, _cx| {
309 if let Some(credentials) = &state.credentials {
310 Ok((
311 credentials.access_key_id.clone(),
312 credentials.secret_access_key.clone(),
313 credentials.region.clone(),
314 ))
315 } else {
316 return Err(anyhow!("Failed to read credentials"));
317 }
318 })
319 else {
320 return Err(anyhow!("App state dropped"));
321 };
322
323 let runtime_client = bedrock_client::Client::from_conf(
324 Config::builder()
325 .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
326 .credentials_provider(Credentials::new(
327 access_key_id,
328 secret_access_key,
329 None,
330 None,
331 "Keychain",
332 ))
333 .region(Region::new(region))
334 .http_client(self.http_client.clone())
335 .build(),
336 );
337
338 let owned_handle = self.handler.clone();
339
340 Ok(async move {
341 let request = bedrock::stream_completion(runtime_client, request, owned_handle);
342 request.await.unwrap_or_else(|e| {
343 futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
344 })
345 }
346 .boxed())
347 }
348}
349
350impl LanguageModel for BedrockModel {
351 fn id(&self) -> LanguageModelId {
352 self.id.clone()
353 }
354
355 fn name(&self) -> LanguageModelName {
356 LanguageModelName::from(self.model.display_name().to_string())
357 }
358
359 fn provider_id(&self) -> LanguageModelProviderId {
360 LanguageModelProviderId(PROVIDER_ID.into())
361 }
362
363 fn provider_name(&self) -> LanguageModelProviderName {
364 LanguageModelProviderName(PROVIDER_NAME.into())
365 }
366
367 fn telemetry_id(&self) -> String {
368 format!("bedrock/{}", self.model.id())
369 }
370
371 fn max_token_count(&self) -> usize {
372 self.model.max_token_count()
373 }
374
375 fn max_output_tokens(&self) -> Option<u32> {
376 Some(self.model.max_output_tokens())
377 }
378
379 fn count_tokens(
380 &self,
381 request: LanguageModelRequest,
382 cx: &App,
383 ) -> BoxFuture<'static, Result<usize>> {
384 get_bedrock_tokens(request, cx)
385 }
386
387 fn stream_completion(
388 &self,
389 request: LanguageModelRequest,
390 cx: &AsyncApp,
391 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
392 let request = into_bedrock(
393 request,
394 self.model.id().into(),
395 self.model.default_temperature(),
396 self.model.max_output_tokens(),
397 );
398
399 let owned_handle = self.handler.clone();
400
401 let request = self.stream_completion(request, cx);
402 let future = self.request_limiter.stream(async move {
403 let response = request.map_err(|e| anyhow!(e)).unwrap().await;
404 Ok(map_to_language_model_completion_events(
405 response,
406 owned_handle,
407 ))
408 });
409 async move { Ok(future.await?.boxed()) }.boxed()
410 }
411
412 fn use_any_tool(
413 &self,
414 request: LanguageModelRequest,
415 name: String,
416 description: String,
417 schema: Value,
418 _cx: &AsyncApp,
419 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
420 let mut request = into_bedrock(
421 request,
422 self.model.id().into(),
423 self.model.default_temperature(),
424 self.model.max_output_tokens(),
425 );
426
427 request.tool_choice = Some(BedrockToolChoice::Tool(
428 BedrockSpecificTool::builder()
429 .name(name.clone())
430 .build()
431 .unwrap(),
432 ));
433
434 request.tools = vec![BedrockTool::builder()
435 .name(name.clone())
436 .description(description.clone())
437 .input_schema(BedrockToolInputSchema::Json(value_to_aws_document(&schema)))
438 .build()
439 .unwrap()];
440
441 let handle = self.handler.clone();
442
443 let request = self.stream_completion(request, _cx);
444 self.request_limiter
445 .run(async move {
446 let response = request.map_err(|e| anyhow!(e)).unwrap().await;
447 Ok(extract_tool_args_from_events(name, response, handle)
448 .await?
449 .boxed())
450 })
451 .boxed()
452 }
453
454 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
455 None
456 }
457}
458
459pub fn into_bedrock(
460 request: LanguageModelRequest,
461 model: String,
462 default_temperature: f32,
463 max_output_tokens: u32,
464) -> bedrock::Request {
465 let mut new_messages: Vec<BedrockMessage> = Vec::new();
466 let mut system_message = String::new();
467
468 for message in request.messages {
469 if message.contents_empty() {
470 continue;
471 }
472
473 match message.role {
474 Role::User | Role::Assistant => {
475 let bedrock_message_content: Vec<BedrockInnerContent> = message
476 .content
477 .into_iter()
478 .filter_map(|content| match content {
479 MessageContent::Text(text) => {
480 if !text.is_empty() {
481 Some(BedrockInnerContent::Text(text))
482 } else {
483 None
484 }
485 }
486 _ => None,
487 })
488 .collect();
489 let bedrock_role = match message.role {
490 Role::User => bedrock::BedrockRole::User,
491 Role::Assistant => bedrock::BedrockRole::Assistant,
492 Role::System => unreachable!("System role should never occur here"),
493 };
494 if let Some(last_message) = new_messages.last_mut() {
495 if last_message.role == bedrock_role {
496 last_message.content.extend(bedrock_message_content);
497 continue;
498 }
499 }
500 new_messages.push(
501 BedrockMessage::builder()
502 .role(bedrock_role)
503 .set_content(Some(bedrock_message_content))
504 .build()
505 .expect("failed to build Bedrock message"),
506 );
507 }
508 Role::System => {
509 if !system_message.is_empty() {
510 system_message.push_str("\n\n");
511 }
512 system_message.push_str(&message.string_contents());
513 }
514 }
515 }
516
517 bedrock::Request {
518 model,
519 messages: new_messages,
520 max_tokens: max_output_tokens,
521 system: Some(system_message),
522 tools: vec![],
523 tool_choice: None,
524 metadata: None,
525 stop_sequences: Vec::new(),
526 temperature: request.temperature.or(Some(default_temperature)),
527 top_k: None,
528 top_p: None,
529 }
530}
531
532// TODO: just call the ConverseOutput.usage() method:
533// https://docs.rs/aws-sdk-bedrockruntime/latest/aws_sdk_bedrockruntime/operation/converse/struct.ConverseOutput.html#method.output
534pub fn get_bedrock_tokens(
535 request: LanguageModelRequest,
536 cx: &App,
537) -> BoxFuture<'static, Result<usize>> {
538 cx.background_executor()
539 .spawn(async move {
540 let messages = request.messages;
541 let mut tokens_from_images = 0;
542 let mut string_messages = Vec::with_capacity(messages.len());
543
544 for message in messages {
545 use language_model::MessageContent;
546
547 let mut string_contents = String::new();
548
549 for content in message.content {
550 match content {
551 MessageContent::Text(text) => {
552 string_contents.push_str(&text);
553 }
554 MessageContent::Image(image) => {
555 tokens_from_images += image.estimate_tokens();
556 }
557 MessageContent::ToolUse(_tool_use) => {
558 // TODO: Estimate token usage from tool uses.
559 }
560 MessageContent::ToolResult(tool_result) => {
561 string_contents.push_str(&tool_result.content);
562 }
563 }
564 }
565
566 if !string_contents.is_empty() {
567 string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
568 role: match message.role {
569 Role::User => "user".into(),
570 Role::Assistant => "assistant".into(),
571 Role::System => "system".into(),
572 },
573 content: Some(string_contents),
574 name: None,
575 function_call: None,
576 });
577 }
578 }
579
580 // Tiktoken doesn't yet support these models, so we manually use the
581 // same tokenizer as GPT-4.
582 tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
583 .map(|tokens| tokens + tokens_from_images)
584 })
585 .boxed()
586}
587
588pub async fn extract_tool_args_from_events(
589 name: String,
590 mut events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
591 handle: Handle,
592) -> Result<impl Send + Stream<Item = Result<String>>> {
593 handle
594 .spawn(async move {
595 let mut tool_use_index = None;
596 while let Some(event) = events.next().await {
597 if let BedrockStreamingResponse::ContentBlockStart(ContentBlockStartEvent {
598 content_block_index,
599 start,
600 ..
601 }) = event?
602 {
603 match start {
604 None => {
605 continue;
606 }
607 Some(start) => match start.as_tool_use() {
608 Ok(tool_use) => {
609 if name == tool_use.name {
610 tool_use_index = Some(content_block_index);
611 break;
612 }
613 }
614 Err(err) => {
615 return Err(anyhow!("Failed to parse tool use event: {:?}", err));
616 }
617 },
618 }
619 }
620 }
621
622 let Some(tool_use_index) = tool_use_index else {
623 return Err(anyhow!("Tool is not used"));
624 };
625
626 Ok(events.filter_map(move |event| {
627 let result = match event {
628 Err(_err) => None,
629 Ok(output) => match output.clone() {
630 BedrockStreamingResponse::ContentBlockDelta(inner) => {
631 match inner.clone().delta {
632 Some(ContentBlockDelta::ToolUse(tool_use)) => {
633 if inner.content_block_index == tool_use_index {
634 Some(Ok(tool_use.input))
635 } else {
636 None
637 }
638 }
639 _ => None,
640 }
641 }
642 _ => None,
643 },
644 };
645
646 async move { result }
647 }))
648 })
649 .await?
650}
651
652pub fn map_to_language_model_completion_events(
653 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
654 handle: Handle,
655) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
656 struct RawToolUse {
657 id: String,
658 name: String,
659 input_json: String,
660 }
661
662 struct State {
663 events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
664 tool_uses_by_index: HashMap<i32, RawToolUse>,
665 }
666
667 futures::stream::unfold(
668 State {
669 events,
670 tool_uses_by_index: HashMap::default(),
671 },
672 move |mut state: State| {
673 let inner_handle = handle.clone();
674 async move {
675 inner_handle
676 .spawn(async {
677 while let Some(event) = state.events.next().await {
678 match event {
679 Ok(event) => match event {
680 ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
681 if let Some(ContentBlockDelta::Text(text_out)) =
682 cb_delta.delta
683 {
684 return Some((
685 Some(Ok(LanguageModelCompletionEvent::Text(
686 text_out,
687 ))),
688 state,
689 ));
690 } else if let Some(ContentBlockDelta::ToolUse(text_out)) =
691 cb_delta.delta
692 {
693 if let Some(tool_use) = state
694 .tool_uses_by_index
695 .get_mut(&cb_delta.content_block_index)
696 {
697 tool_use.input_json.push_str(text_out.input());
698 return Some((None, state));
699 };
700
701 return Some((None, state));
702 } else if cb_delta.delta.is_none() {
703 return Some((None, state));
704 }
705 }
706 ConverseStreamOutput::ContentBlockStart(cb_start) => {
707 if let Some(start) = cb_start.start {
708 match start {
709 ContentBlockStart::ToolUse(text_out) => {
710 let tool_use = RawToolUse {
711 id: text_out.tool_use_id,
712 name: text_out.name,
713 input_json: String::new(),
714 };
715
716 state.tool_uses_by_index.insert(
717 cb_start.content_block_index,
718 tool_use,
719 );
720 }
721 _ => {}
722 }
723 }
724 }
725 ConverseStreamOutput::ContentBlockStop(cb_stop) => {
726 if let Some(tool_use) = state
727 .tool_uses_by_index
728 .remove(&cb_stop.content_block_index)
729 {
730 return Some((
731 Some(maybe!({
732 Ok(LanguageModelCompletionEvent::ToolUse(
733 LanguageModelToolUse {
734 id: tool_use.id.into(),
735 name: tool_use.name.into(),
736 input: if tool_use.input_json.is_empty()
737 {
738 Value::Null
739 } else {
740 serde_json::Value::from_str(
741 &tool_use.input_json,
742 )
743 .map_err(|err| anyhow!(err))?
744 },
745 },
746 ))
747 })),
748 state,
749 ));
750 }
751 }
752 _ => {}
753 },
754 Err(err) => return Some((Some(Err(anyhow!(err))), state)),
755 }
756 }
757 None
758 })
759 .await
760 .unwrap()
761 }
762 },
763 )
764 .filter_map(|event| async move { event })
765}
766
767struct ConfigurationView {
768 access_key_id_editor: Entity<Editor>,
769 secret_access_key_editor: Entity<Editor>,
770 region_editor: Entity<Editor>,
771 state: gpui::Entity<State>,
772 load_credentials_task: Option<Task<()>>,
773}
774
775impl ConfigurationView {
776 const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX";
777 const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str =
778 "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX";
779 const PLACEHOLDER_REGION: &'static str = "us-east-1";
780
781 fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
782 cx.observe(&state, |_, _, cx| {
783 cx.notify();
784 })
785 .detach();
786
787 let load_credentials_task = Some(cx.spawn({
788 let state = state.clone();
789 |this, mut cx| async move {
790 if let Some(task) = state
791 .update(&mut cx, |state, cx| state.authenticate(cx))
792 .log_err()
793 {
794 // We don't log an error, because "not signed in" is also an error.
795 let _ = task.await;
796 }
797 this.update(&mut cx, |this, cx| {
798 this.load_credentials_task = None;
799 cx.notify();
800 })
801 .log_err();
802 }
803 }));
804
805 Self {
806 access_key_id_editor: cx.new(|cx| {
807 let mut editor = Editor::single_line(window, cx);
808 editor.set_placeholder_text(Self::PLACEHOLDER_ACCESS_KEY_ID_TEXT, cx);
809 editor
810 }),
811 secret_access_key_editor: cx.new(|cx| {
812 let mut editor = Editor::single_line(window, cx);
813 editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx);
814 editor
815 }),
816 region_editor: cx.new(|cx| {
817 let mut editor = Editor::single_line(window, cx);
818 editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx);
819 editor
820 }),
821 state,
822 load_credentials_task,
823 }
824 }
825
826 fn save_credentials(
827 &mut self,
828 _: &menu::Confirm,
829 _window: &mut Window,
830 cx: &mut Context<Self>,
831 ) {
832 let access_key_id = self
833 .access_key_id_editor
834 .read(cx)
835 .text(cx)
836 .to_string()
837 .trim()
838 .to_string();
839 let secret_access_key = self
840 .secret_access_key_editor
841 .read(cx)
842 .text(cx)
843 .to_string()
844 .trim()
845 .to_string();
846 let region = self
847 .region_editor
848 .read(cx)
849 .text(cx)
850 .to_string()
851 .trim()
852 .to_string();
853
854 let state = self.state.clone();
855 cx.spawn(|_, mut cx| async move {
856 state
857 .update(&mut cx, |state, cx| {
858 let credentials: BedrockCredentials = BedrockCredentials {
859 access_key_id: access_key_id.clone(),
860 secret_access_key: secret_access_key.clone(),
861 region: region.clone(),
862 };
863
864 state.set_credentials(credentials, cx)
865 })?
866 .await
867 })
868 .detach_and_log_err(cx);
869 }
870
871 fn reset_credentials(&mut self, window: &mut Window, cx: &mut Context<Self>) {
872 self.access_key_id_editor
873 .update(cx, |editor, cx| editor.set_text("", window, cx));
874 self.secret_access_key_editor
875 .update(cx, |editor, cx| editor.set_text("", window, cx));
876 self.region_editor
877 .update(cx, |editor, cx| editor.set_text("", window, cx));
878
879 let state = self.state.clone();
880 cx.spawn(|_, mut cx| async move {
881 state
882 .update(&mut cx, |state, cx| state.reset_credentials(cx))?
883 .await
884 })
885 .detach_and_log_err(cx);
886 }
887
888 fn make_text_style(&self, cx: &Context<Self>) -> TextStyle {
889 let settings = ThemeSettings::get_global(cx);
890 TextStyle {
891 color: cx.theme().colors().text,
892 font_family: settings.ui_font.family.clone(),
893 font_features: settings.ui_font.features.clone(),
894 font_fallbacks: settings.ui_font.fallbacks.clone(),
895 font_size: rems(0.875).into(),
896 font_weight: settings.ui_font.weight,
897 font_style: FontStyle::Normal,
898 line_height: relative(1.3),
899 background_color: None,
900 underline: None,
901 strikethrough: None,
902 white_space: WhiteSpace::Normal,
903 text_overflow: None,
904 text_align: Default::default(),
905 line_clamp: None,
906 }
907 }
908
909 fn render_aa_id_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
910 let text_style = self.make_text_style(cx);
911
912 EditorElement::new(
913 &self.access_key_id_editor,
914 EditorStyle {
915 background: cx.theme().colors().editor_background,
916 local_player: cx.theme().players().local(),
917 text: text_style,
918 ..Default::default()
919 },
920 )
921 }
922
923 fn render_sk_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
924 let text_style = self.make_text_style(cx);
925
926 EditorElement::new(
927 &self.secret_access_key_editor,
928 EditorStyle {
929 background: cx.theme().colors().editor_background,
930 local_player: cx.theme().players().local(),
931 text: text_style,
932 ..Default::default()
933 },
934 )
935 }
936
937 fn render_region_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
938 let text_style = self.make_text_style(cx);
939
940 EditorElement::new(
941 &self.region_editor,
942 EditorStyle {
943 background: cx.theme().colors().editor_background,
944 local_player: cx.theme().players().local(),
945 text: text_style,
946 ..Default::default()
947 },
948 )
949 }
950
951 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
952 !self.state.read(cx).is_authenticated()
953 }
954}
955
956impl Render for ConfigurationView {
957 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
958 let env_var_set = self.state.read(cx).credentials_from_env;
959 let bg_color = cx.theme().colors().editor_background;
960 let border_color = cx.theme().colors().border_variant;
961 let input_base_styles = || {
962 h_flex()
963 .w_full()
964 .px_2()
965 .py_1()
966 .bg(bg_color)
967 .border_1()
968 .border_color(border_color)
969 .rounded_md()
970 };
971
972 if self.load_credentials_task.is_some() {
973 div().child(Label::new("Loading credentials...")).into_any()
974 } else if self.should_render_editor(cx) {
975 v_flex()
976 .size_full()
977 .on_action(cx.listener(ConfigurationView::save_credentials))
978 .child(Label::new("To use Zed's assistant with Bedrock, you need to add the Access Key ID, Secret Access Key and AWS Region. Follow these steps:"))
979 .child(
980 List::new()
981 .child(
982 InstructionListItem::new(
983 "Start by",
984 Some("creating a user and security credentials"),
985 Some("https://us-east-1.console.aws.amazon.com/iam/home")
986 )
987 )
988 .child(
989 InstructionListItem::new(
990 "Grant that user permissions according to this documentation:",
991 Some("Prerequisites"),
992 Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html")
993 )
994 )
995 .child(
996 InstructionListItem::new(
997 "Select the models you would like access to:",
998 Some("Bedrock Model Catalog"),
999 Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess")
1000 )
1001 )
1002 .child(
1003 InstructionListItem::text_only("Fill the fields below and hit enter to start using the assistant")
1004 )
1005 )
1006 .child(
1007 v_flex()
1008 .my_2()
1009 .gap_1p5()
1010 .child(
1011 v_flex()
1012 .gap_0p5()
1013 .child(Label::new("Access Key ID").size(LabelSize::Small))
1014 .child(
1015 input_base_styles().child(self.render_aa_id_editor(cx))
1016 )
1017 )
1018 .child(
1019 v_flex()
1020 .gap_0p5()
1021 .child(Label::new("Secret Access Key").size(LabelSize::Small))
1022 .child(
1023 input_base_styles().child(self.render_sk_editor(cx))
1024 )
1025 )
1026 .child(
1027 v_flex()
1028 .gap_0p5()
1029 .child(Label::new("Region").size(LabelSize::Small))
1030 .child(
1031 input_base_styles().child(self.render_region_editor(cx))
1032 )
1033 )
1034 )
1035 .child(
1036 Label::new(
1037 format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."),
1038 )
1039 .size(LabelSize::Small)
1040 .color(Color::Muted),
1041 )
1042 .into_any()
1043 } else {
1044 h_flex()
1045 .size_full()
1046 .justify_between()
1047 .child(
1048 h_flex()
1049 .gap_1()
1050 .child(Icon::new(IconName::Check).color(Color::Success))
1051 .child(Label::new(if env_var_set {
1052 format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.")
1053 } else {
1054 "Credentials configured.".to_string()
1055 })),
1056 )
1057 .child(
1058 Button::new("reset-key", "Reset key")
1059 .icon(Some(IconName::Trash))
1060 .icon_size(IconSize::Small)
1061 .icon_position(IconPosition::Start)
1062 .disabled(env_var_set)
1063 .when(env_var_set, |this| {
1064 this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables.")))
1065 })
1066 .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))),
1067 )
1068 .into_any()
1069 }
1070 }
1071}