1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
2use edit_prediction::{
3 ApiKeyState,
4 mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
5 open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
6};
7use edit_prediction_ui::{get_available_providers, set_completion_provider};
8use gpui::{Entity, ScrollHandle, prelude::*};
9use language::language_settings::AllLanguageSettings;
10
11use settings::Settings as _;
12use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
13use workspace::AppState;
14use zed_credentials_provider::global as global_credentials_provider;
15
/// Placeholder text for the "API URL" field of the Ollama settings section.
/// Also reused as the placeholder of the OpenAI-compatible API "API URL" field.
const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
/// Placeholder text for the "Model" field of the Ollama settings section.
/// Also reused as the placeholder of the OpenAI-compatible API "Model" field.
const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
18
19use crate::{
20 SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
21 components::{SettingsInputField, SettingsSectionHeader},
22};
23
24pub(crate) fn render_edit_prediction_setup_page(
25 settings_window: &SettingsWindow,
26 scroll_handle: &ScrollHandle,
27 window: &mut Window,
28 cx: &mut Context<SettingsWindow>,
29) -> AnyElement {
30 let providers = [
31 Some(render_provider_dropdown(window, cx)),
32 render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
33 Some(
34 render_api_key_provider(
35 IconName::Inception,
36 "Mercury",
37 ApiKeyDocs::Link {
38 dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
39 },
40 mercury_api_token(cx),
41 |_cx| MERCURY_CREDENTIALS_URL,
42 None,
43 window,
44 cx,
45 )
46 .into_any_element(),
47 ),
48 Some(
49 render_api_key_provider(
50 IconName::AiMistral,
51 "Codestral",
52 ApiKeyDocs::Link {
53 dashboard_url: "https://console.mistral.ai/codestral".into(),
54 },
55 codestral_api_key_state(cx),
56 |cx| codestral_api_url(cx),
57 Some(
58 settings_window
59 .render_sub_page_items_section(
60 codestral_settings().iter().enumerate(),
61 true,
62 window,
63 cx,
64 )
65 .into_any_element(),
66 ),
67 window,
68 cx,
69 )
70 .into_any_element(),
71 ),
72 Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
73 Some(
74 render_api_key_provider(
75 IconName::AiOpenAiCompat,
76 "OpenAI Compatible API",
77 ApiKeyDocs::Custom {
78 message: "The API key sent as Authorization: Bearer {key}.".into(),
79 },
80 open_ai_compatible_api_token(cx),
81 |cx| open_ai_compatible_api_url(cx),
82 Some(
83 settings_window
84 .render_sub_page_items_section(
85 open_ai_compatible_settings().iter().enumerate(),
86 true,
87 window,
88 cx,
89 )
90 .into_any_element(),
91 ),
92 window,
93 cx,
94 )
95 .into_any_element(),
96 ),
97 ];
98
99 div()
100 .size_full()
101 .child(
102 v_flex()
103 .id("ep-setup-page")
104 .min_w_0()
105 .size_full()
106 .px_8()
107 .pb_16()
108 .overflow_y_scroll()
109 .track_scroll(&scroll_handle)
110 .children(providers.into_iter().flatten()),
111 )
112 .into_any_element()
113}
114
115fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
116 let current_provider = AllLanguageSettings::get_global(cx)
117 .edit_predictions
118 .provider;
119 let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
120
121 let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
122 let available_providers = get_available_providers(cx);
123 let fs = <dyn fs::Fs>::global(cx);
124
125 for provider in available_providers {
126 let Some(name) = provider.display_name() else {
127 continue;
128 };
129 let is_current = provider == current_provider;
130
131 menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
132 let fs = fs.clone();
133 move |_, cx| {
134 set_completion_provider(fs.clone(), cx, provider);
135 }
136 });
137 }
138 menu
139 });
140
141 v_flex()
142 .id("provider-selector")
143 .min_w_0()
144 .gap_1p5()
145 .child(SettingsSectionHeader::new("Active Provider").no_padding(true))
146 .child(
147 h_flex()
148 .pt_2p5()
149 .w_full()
150 .min_w_0()
151 .justify_between()
152 .child(
153 v_flex()
154 .w_full()
155 .min_w_0()
156 .max_w_1_2()
157 .child(Label::new("Provider"))
158 .child(
159 Label::new("Select which provider to use for edit predictions.")
160 .size(LabelSize::Small)
161 .color(Color::Muted),
162 ),
163 )
164 .child(
165 DropdownMenu::new("provider-dropdown", current_provider_name, menu)
166 .tab_index(0)
167 .style(DropdownStyle::Outlined),
168 ),
169 )
170 .into_any_element()
171}
172
/// Help text shown next to the API key input of a provider section,
/// telling the user where to obtain a key.
enum ApiKeyDocs {
    /// Rendered as "Visit the <title> dashboard to generate an API key.",
    /// where "<title> dashboard" is a link to `dashboard_url`.
    Link { dashboard_url: SharedString },
    /// Rendered verbatim as small, muted help text.
    Custom { message: SharedString },
}
177
178fn render_api_key_provider(
179 icon: IconName,
180 title: &'static str,
181 docs: ApiKeyDocs,
182 api_key_state: Entity<ApiKeyState>,
183 current_url: fn(&mut App) -> SharedString,
184 additional_fields: Option<AnyElement>,
185 window: &mut Window,
186 cx: &mut Context<SettingsWindow>,
187) -> impl IntoElement {
188 let weak_page = cx.weak_entity();
189 let credentials_provider = global_credentials_provider(cx);
190 _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
191 let task = api_key_state.update(cx, |key_state, cx| {
192 key_state.load_if_needed(
193 current_url(cx),
194 |state| state,
195 credentials_provider.clone(),
196 cx,
197 )
198 });
199 cx.spawn(async move |_, cx| {
200 task.await.ok();
201 weak_page
202 .update(cx, |_, cx| {
203 cx.notify();
204 })
205 .ok();
206 })
207 });
208
209 let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
210 (
211 state.has_key(),
212 Some(state.env_var_name().clone()),
213 state.is_from_env_var(),
214 )
215 });
216
217 let write_key = move |api_key: Option<String>, cx: &mut App| {
218 let credentials_provider = global_credentials_provider(cx);
219 api_key_state
220 .update(cx, |key_state, cx| {
221 let url = current_url(cx);
222 key_state.store(
223 url,
224 api_key,
225 |key_state| key_state,
226 credentials_provider,
227 cx,
228 )
229 })
230 .detach_and_log_err(cx);
231 };
232
233 let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
234 let header = SettingsSectionHeader::new(title)
235 .icon(icon)
236 .no_padding(true);
237 let button_link_label = format!("{} dashboard", title);
238 let description = match docs {
239 ApiKeyDocs::Custom { message } => div().min_w_0().w_full().child(
240 Label::new(message)
241 .size(LabelSize::Small)
242 .color(Color::Muted),
243 ),
244 ApiKeyDocs::Link { dashboard_url } => h_flex()
245 .w_full()
246 .min_w_0()
247 .flex_wrap()
248 .gap_0p5()
249 .child(
250 Label::new("Visit the")
251 .size(LabelSize::Small)
252 .color(Color::Muted),
253 )
254 .child(
255 ButtonLink::new(button_link_label, dashboard_url)
256 .no_icon(true)
257 .label_size(LabelSize::Small)
258 .label_color(Color::Muted),
259 )
260 .child(
261 Label::new("to generate an API key.")
262 .size(LabelSize::Small)
263 .color(Color::Muted),
264 ),
265 };
266 let configured_card_label = if is_from_env_var {
267 "API Key Set in Environment Variable"
268 } else {
269 "API Key Configured"
270 };
271
272 let container = if has_key {
273 base_container.child(header).child(
274 ConfiguredApiCard::new(configured_card_label)
275 .button_label("Reset Key")
276 .button_tab_index(0)
277 .disabled(is_from_env_var)
278 .when_some(env_var_name, |this, env_var_name| {
279 this.when(is_from_env_var, |this| {
280 this.tooltip_label(format!(
281 "To reset your API key, unset the {} environment variable.",
282 env_var_name
283 ))
284 })
285 })
286 .on_click(move |_, _, cx| {
287 write_key(None, cx);
288 }),
289 )
290 } else {
291 base_container.child(header).child(
292 h_flex()
293 .pt_2p5()
294 .w_full()
295 .min_w_0()
296 .justify_between()
297 .child(
298 v_flex()
299 .w_full()
300 .min_w_0()
301 .max_w_1_2()
302 .child(Label::new("API Key"))
303 .child(description)
304 .when_some(env_var_name, |this, env_var_name| {
305 this.child({
306 let label = format!(
307 "Or set the {} env var and restart Zed.",
308 env_var_name.as_ref()
309 );
310 Label::new(label).size(LabelSize::Small).color(Color::Muted)
311 })
312 }),
313 )
314 .child(
315 SettingsInputField::new()
316 .tab_index(0)
317 .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
318 .on_confirm(move |api_key, _window, cx| {
319 write_key(api_key.filter(|key| !key.is_empty()), cx);
320 }),
321 ),
322 )
323 };
324
325 container.when_some(additional_fields, |this, additional_fields| {
326 this.child(
327 div()
328 .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
329 .px_neg_8()
330 .border_t_1()
331 .border_color(cx.theme().colors().border_variant)
332 .child(additional_fields),
333 )
334 })
335}
336
337fn render_ollama_provider(
338 settings_window: &SettingsWindow,
339 window: &mut Window,
340 cx: &mut Context<SettingsWindow>,
341) -> impl IntoElement {
342 let ollama_settings = ollama_settings();
343 let additional_fields = settings_window
344 .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
345 .into_any_element();
346
347 v_flex()
348 .id("ollama")
349 .min_w_0()
350 .pt_8()
351 .gap_1p5()
352 .child(
353 SettingsSectionHeader::new("Ollama")
354 .icon(IconName::AiOllama)
355 .no_padding(true),
356 )
357 .child(div().px_neg_8().child(additional_fields))
358}
359
360fn ollama_settings() -> Box<[SettingsPageItem]> {
361 Box::new([
362 SettingsPageItem::SettingItem(SettingItem {
363 title: "API URL",
364 description: "The base URL of your Ollama server.",
365 field: Box::new(SettingField {
366 pick: |settings| {
367 settings
368 .project
369 .all_languages
370 .edit_predictions
371 .as_ref()?
372 .ollama
373 .as_ref()?
374 .api_url
375 .as_ref()
376 },
377 write: |settings, value| {
378 settings
379 .project
380 .all_languages
381 .edit_predictions
382 .get_or_insert_default()
383 .ollama
384 .get_or_insert_default()
385 .api_url = value;
386 },
387 json_path: Some("edit_predictions.ollama.api_url"),
388 }),
389 metadata: Some(Box::new(SettingsFieldMetadata {
390 placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
391 ..Default::default()
392 })),
393 files: USER,
394 }),
395 SettingsPageItem::SettingItem(SettingItem {
396 title: "Model",
397 description: "The Ollama model to use for edit predictions.",
398 field: Box::new(SettingField {
399 pick: |settings| {
400 settings
401 .project
402 .all_languages
403 .edit_predictions
404 .as_ref()?
405 .ollama
406 .as_ref()?
407 .model
408 .as_ref()
409 },
410 write: |settings, value| {
411 settings
412 .project
413 .all_languages
414 .edit_predictions
415 .get_or_insert_default()
416 .ollama
417 .get_or_insert_default()
418 .model = value;
419 },
420 json_path: Some("edit_predictions.ollama.model"),
421 }),
422 metadata: Some(Box::new(SettingsFieldMetadata {
423 placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
424 ..Default::default()
425 })),
426 files: USER,
427 }),
428 SettingsPageItem::SettingItem(SettingItem {
429 title: "Prompt Format",
430 description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
431 field: Box::new(SettingField {
432 pick: |settings| {
433 settings
434 .project
435 .all_languages
436 .edit_predictions
437 .as_ref()?
438 .ollama
439 .as_ref()?
440 .prompt_format
441 .as_ref()
442 },
443 write: |settings, value| {
444 settings
445 .project
446 .all_languages
447 .edit_predictions
448 .get_or_insert_default()
449 .ollama
450 .get_or_insert_default()
451 .prompt_format = value;
452 },
453 json_path: Some("edit_predictions.ollama.prompt_format"),
454 }),
455 files: USER,
456 metadata: None,
457 }),
458 SettingsPageItem::SettingItem(SettingItem {
459 title: "Max Output Tokens",
460 description: "The maximum number of tokens to generate.",
461 field: Box::new(SettingField {
462 pick: |settings| {
463 settings
464 .project
465 .all_languages
466 .edit_predictions
467 .as_ref()?
468 .ollama
469 .as_ref()?
470 .max_output_tokens
471 .as_ref()
472 },
473 write: |settings, value| {
474 settings
475 .project
476 .all_languages
477 .edit_predictions
478 .get_or_insert_default()
479 .ollama
480 .get_or_insert_default()
481 .max_output_tokens = value;
482 },
483 json_path: Some("edit_predictions.ollama.max_output_tokens"),
484 }),
485 metadata: None,
486 files: USER,
487 }),
488 ])
489}
490
491fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
492 Box::new([
493 SettingsPageItem::SettingItem(SettingItem {
494 title: "API URL",
495 description: "The URL of your OpenAI-compatible server's completions API.",
496 field: Box::new(SettingField {
497 pick: |settings| {
498 settings
499 .project
500 .all_languages
501 .edit_predictions
502 .as_ref()?
503 .open_ai_compatible_api
504 .as_ref()?
505 .api_url
506 .as_ref()
507 },
508 write: |settings, value| {
509 settings
510 .project
511 .all_languages
512 .edit_predictions
513 .get_or_insert_default()
514 .open_ai_compatible_api
515 .get_or_insert_default()
516 .api_url = value;
517 },
518 json_path: Some("edit_predictions.open_ai_compatible_api.api_url"),
519 }),
520 metadata: Some(Box::new(SettingsFieldMetadata {
521 placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
522 ..Default::default()
523 })),
524 files: USER,
525 }),
526 SettingsPageItem::SettingItem(SettingItem {
527 title: "Model",
528 description: "The model string to pass to the OpenAI-compatible server.",
529 field: Box::new(SettingField {
530 pick: |settings| {
531 settings
532 .project
533 .all_languages
534 .edit_predictions
535 .as_ref()?
536 .open_ai_compatible_api
537 .as_ref()?
538 .model
539 .as_ref()
540 },
541 write: |settings, value| {
542 settings
543 .project
544 .all_languages
545 .edit_predictions
546 .get_or_insert_default()
547 .open_ai_compatible_api
548 .get_or_insert_default()
549 .model = value;
550 },
551 json_path: Some("edit_predictions.open_ai_compatible_api.model"),
552 }),
553 metadata: Some(Box::new(SettingsFieldMetadata {
554 placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
555 ..Default::default()
556 })),
557 files: USER,
558 }),
559 SettingsPageItem::SettingItem(SettingItem {
560 title: "Prompt Format",
561 description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
562 field: Box::new(SettingField {
563 pick: |settings| {
564 settings
565 .project
566 .all_languages
567 .edit_predictions
568 .as_ref()?
569 .open_ai_compatible_api
570 .as_ref()?
571 .prompt_format
572 .as_ref()
573 },
574 write: |settings, value| {
575 settings
576 .project
577 .all_languages
578 .edit_predictions
579 .get_or_insert_default()
580 .open_ai_compatible_api
581 .get_or_insert_default()
582 .prompt_format = value;
583 },
584 json_path: Some("edit_predictions.open_ai_compatible_api.prompt_format"),
585 }),
586 files: USER,
587 metadata: None,
588 }),
589 SettingsPageItem::SettingItem(SettingItem {
590 title: "Max Output Tokens",
591 description: "The maximum number of tokens to generate.",
592 field: Box::new(SettingField {
593 pick: |settings| {
594 settings
595 .project
596 .all_languages
597 .edit_predictions
598 .as_ref()?
599 .open_ai_compatible_api
600 .as_ref()?
601 .max_output_tokens
602 .as_ref()
603 },
604 write: |settings, value| {
605 settings
606 .project
607 .all_languages
608 .edit_predictions
609 .get_or_insert_default()
610 .open_ai_compatible_api
611 .get_or_insert_default()
612 .max_output_tokens = value;
613 },
614 json_path: Some("edit_predictions.open_ai_compatible_api.max_output_tokens"),
615 }),
616 metadata: None,
617 files: USER,
618 }),
619 ])
620}
621
622fn codestral_settings() -> Box<[SettingsPageItem]> {
623 Box::new([
624 SettingsPageItem::SettingItem(SettingItem {
625 title: "API URL",
626 description: "The API URL to use for Codestral.",
627 field: Box::new(SettingField {
628 pick: |settings| {
629 settings
630 .project
631 .all_languages
632 .edit_predictions
633 .as_ref()?
634 .codestral
635 .as_ref()?
636 .api_url
637 .as_ref()
638 },
639 write: |settings, value| {
640 settings
641 .project
642 .all_languages
643 .edit_predictions
644 .get_or_insert_default()
645 .codestral
646 .get_or_insert_default()
647 .api_url = value;
648 },
649 json_path: Some("edit_predictions.codestral.api_url"),
650 }),
651 metadata: Some(Box::new(SettingsFieldMetadata {
652 placeholder: Some(CODESTRAL_API_URL),
653 ..Default::default()
654 })),
655 files: USER,
656 }),
657 SettingsPageItem::SettingItem(SettingItem {
658 title: "Max Tokens",
659 description: "The maximum number of tokens to generate.",
660 field: Box::new(SettingField {
661 pick: |settings| {
662 settings
663 .project
664 .all_languages
665 .edit_predictions
666 .as_ref()?
667 .codestral
668 .as_ref()?
669 .max_tokens
670 .as_ref()
671 },
672 write: |settings, value| {
673 settings
674 .project
675 .all_languages
676 .edit_predictions
677 .get_or_insert_default()
678 .codestral
679 .get_or_insert_default()
680 .max_tokens = value;
681 },
682 json_path: Some("edit_predictions.codestral.max_tokens"),
683 }),
684 metadata: None,
685 files: USER,
686 }),
687 SettingsPageItem::SettingItem(SettingItem {
688 title: "Model",
689 description: "The Codestral model id to use.",
690 field: Box::new(SettingField {
691 pick: |settings| {
692 settings
693 .project
694 .all_languages
695 .edit_predictions
696 .as_ref()?
697 .codestral
698 .as_ref()?
699 .model
700 .as_ref()
701 },
702 write: |settings, value| {
703 settings
704 .project
705 .all_languages
706 .edit_predictions
707 .get_or_insert_default()
708 .codestral
709 .get_or_insert_default()
710 .model = value;
711 },
712 json_path: Some("edit_predictions.codestral.model"),
713 }),
714 metadata: Some(Box::new(SettingsFieldMetadata {
715 placeholder: Some("codestral-latest"),
716 ..Default::default()
717 })),
718 files: USER,
719 }),
720 ])
721}
722
723fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
724 let configuration_view = window.use_state(cx, |_, cx| {
725 copilot_ui::ConfigurationView::new(
726 move |cx| {
727 let app_state = AppState::global(cx);
728 copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
729 .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
730 },
731 copilot_ui::ConfigurationMode::EditPrediction,
732 cx,
733 )
734 });
735
736 Some(
737 v_flex()
738 .id("github-copilot")
739 .min_w_0()
740 .pt_8()
741 .gap_1p5()
742 .child(
743 SettingsSectionHeader::new("GitHub Copilot")
744 .icon(IconName::Copilot)
745 .no_padding(true),
746 )
747 .child(configuration_view),
748 )
749}