1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
2use edit_prediction::{
3 ApiKeyState,
4 mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
5 open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
6};
7use edit_prediction_ui::{get_available_providers, set_completion_provider};
8use gpui::{Entity, ScrollHandle, prelude::*};
9use language::language_settings::AllLanguageSettings;
10
11use settings::Settings as _;
12use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
13use workspace::AppState;
14
15const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
16const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
17
18use crate::{
19 SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
20 components::{SettingsInputField, SettingsSectionHeader},
21};
22
23pub(crate) fn render_edit_prediction_setup_page(
24 settings_window: &SettingsWindow,
25 scroll_handle: &ScrollHandle,
26 window: &mut Window,
27 cx: &mut Context<SettingsWindow>,
28) -> AnyElement {
29 let providers = [
30 Some(render_provider_dropdown(window, cx)),
31 render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
32 Some(
33 render_api_key_provider(
34 IconName::Inception,
35 "Mercury",
36 ApiKeyDocs::Link {
37 dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
38 },
39 mercury_api_token(cx),
40 |_cx| MERCURY_CREDENTIALS_URL,
41 None,
42 window,
43 cx,
44 )
45 .into_any_element(),
46 ),
47 Some(
48 render_api_key_provider(
49 IconName::AiMistral,
50 "Codestral",
51 ApiKeyDocs::Link {
52 dashboard_url: "https://console.mistral.ai/codestral".into(),
53 },
54 codestral_api_key_state(cx),
55 |cx| codestral_api_url(cx),
56 Some(
57 settings_window
58 .render_sub_page_items_section(
59 codestral_settings().iter().enumerate(),
60 true,
61 window,
62 cx,
63 )
64 .into_any_element(),
65 ),
66 window,
67 cx,
68 )
69 .into_any_element(),
70 ),
71 Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
72 Some(
73 render_api_key_provider(
74 IconName::AiOpenAiCompat,
75 "OpenAI Compatible API",
76 ApiKeyDocs::Custom {
77 message: "The API key sent as Authorization: Bearer {key}.".into(),
78 },
79 open_ai_compatible_api_token(cx),
80 |cx| open_ai_compatible_api_url(cx),
81 Some(
82 settings_window
83 .render_sub_page_items_section(
84 open_ai_compatible_settings().iter().enumerate(),
85 true,
86 window,
87 cx,
88 )
89 .into_any_element(),
90 ),
91 window,
92 cx,
93 )
94 .into_any_element(),
95 ),
96 ];
97
98 div()
99 .size_full()
100 .child(
101 v_flex()
102 .id("ep-setup-page")
103 .min_w_0()
104 .size_full()
105 .px_8()
106 .pb_16()
107 .overflow_y_scroll()
108 .track_scroll(&scroll_handle)
109 .children(providers.into_iter().flatten()),
110 )
111 .into_any_element()
112}
113
114fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
115 let current_provider = AllLanguageSettings::get_global(cx)
116 .edit_predictions
117 .provider;
118 let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
119
120 let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
121 let available_providers = get_available_providers(cx);
122 let fs = <dyn fs::Fs>::global(cx);
123
124 for provider in available_providers {
125 let Some(name) = provider.display_name() else {
126 continue;
127 };
128 let is_current = provider == current_provider;
129
130 menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
131 let fs = fs.clone();
132 move |_, cx| {
133 set_completion_provider(fs.clone(), cx, provider);
134 }
135 });
136 }
137 menu
138 });
139
140 v_flex()
141 .id("provider-selector")
142 .min_w_0()
143 .gap_1p5()
144 .child(SettingsSectionHeader::new("Active Provider").no_padding(true))
145 .child(
146 h_flex()
147 .pt_2p5()
148 .w_full()
149 .min_w_0()
150 .justify_between()
151 .child(
152 v_flex()
153 .w_full()
154 .min_w_0()
155 .max_w_1_2()
156 .child(Label::new("Provider"))
157 .child(
158 Label::new("Select which provider to use for edit predictions.")
159 .size(LabelSize::Small)
160 .color(Color::Muted),
161 ),
162 )
163 .child(
164 DropdownMenu::new("provider-dropdown", current_provider_name, menu)
165 .tab_index(0)
166 .style(DropdownStyle::Outlined),
167 ),
168 )
169 .into_any_element()
170}
171
/// How an API-key provider section explains where to obtain a key.
enum ApiKeyDocs {
    /// Render a link to the provider's dashboard where a key can be generated.
    Link { dashboard_url: SharedString },
    /// Render a free-form explanatory message with no external link.
    Custom { message: SharedString },
}
176
/// Renders a configuration section for a provider that authenticates with an
/// API key (Mercury, Codestral, OpenAI-compatible).
///
/// Depending on whether a key is already stored for the provider's current
/// URL, shows either a "configured" card with a reset button or an input
/// field to enter a key. `additional_fields`, when present, is appended below
/// as a bordered sub-section.
///
/// `current_url` is a plain `fn` pointer (not a closure) so it can be
/// re-invoked from the `'static` closures created below.
fn render_api_key_provider(
    icon: IconName,
    title: &'static str,
    docs: ApiKeyDocs,
    api_key_state: Entity<ApiKeyState>,
    current_url: fn(&mut App) -> SharedString,
    additional_fields: Option<AnyElement>,
    window: &mut Window,
    cx: &mut Context<SettingsWindow>,
) -> impl IntoElement {
    let weak_page = cx.weak_entity();
    // Keyed on the current URL so a URL change re-triggers the key load; the
    // spawned task re-notifies the page once the asynchronous load completes.
    _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
        let task = api_key_state.update(cx, |key_state, cx| {
            key_state.load_if_needed(current_url(cx), |state| state, cx)
        });
        cx.spawn(async move |_, cx| {
            task.await.ok();
            weak_page
                .update(cx, |_, cx| {
                    cx.notify();
                })
                .ok();
        })
    });

    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
        (
            state.has_key(),
            Some(state.env_var_name().clone()),
            state.is_from_env_var(),
        )
    });

    // Stores (`Some`) or clears (`None`) the key for the provider's current URL.
    let write_key = move |api_key: Option<String>, cx: &mut App| {
        api_key_state
            .update(cx, |key_state, cx| {
                let url = current_url(cx);
                key_state.store(url, api_key, |key_state| key_state, cx)
            })
            .detach_and_log_err(cx);
    };

    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
    let header = SettingsSectionHeader::new(title)
        .icon(icon)
        .no_padding(true);
    let button_link_label = format!("{} dashboard", title);
    // "Where do I get a key?" blurb: either a dashboard link or a custom message.
    let description = match docs {
        ApiKeyDocs::Custom { message } => div().min_w_0().w_full().child(
            Label::new(message)
                .size(LabelSize::Small)
                .color(Color::Muted),
        ),
        ApiKeyDocs::Link { dashboard_url } => h_flex()
            .w_full()
            .min_w_0()
            .flex_wrap()
            .gap_0p5()
            .child(
                Label::new("Visit the")
                    .size(LabelSize::Small)
                    .color(Color::Muted),
            )
            .child(
                ButtonLink::new(button_link_label, dashboard_url)
                    .no_icon(true)
                    .label_size(LabelSize::Small)
                    .label_color(Color::Muted),
            )
            .child(
                Label::new("to generate an API key.")
                    .size(LabelSize::Small)
                    .color(Color::Muted),
            ),
    };
    let configured_card_label = if is_from_env_var {
        "API Key Set in Environment Variable"
    } else {
        "API Key Configured"
    };

    let container = if has_key {
        // Key present: show the configured card. Resetting is disabled when
        // the key comes from an env var, since we can't unset that here; the
        // tooltip explains how to do it instead.
        base_container.child(header).child(
            ConfiguredApiCard::new(configured_card_label)
                .button_label("Reset Key")
                .button_tab_index(0)
                .disabled(is_from_env_var)
                .when_some(env_var_name, |this, env_var_name| {
                    this.when(is_from_env_var, |this| {
                        this.tooltip_label(format!(
                            "To reset your API key, unset the {} environment variable.",
                            env_var_name
                        ))
                    })
                })
                .on_click(move |_, _, cx| {
                    write_key(None, cx);
                }),
        )
    } else {
        // No key yet: show the docs blurb plus an input field; confirming a
        // non-empty value stores it (empty input is treated as "no key").
        base_container.child(header).child(
            h_flex()
                .pt_2p5()
                .w_full()
                .min_w_0()
                .justify_between()
                .child(
                    v_flex()
                        .w_full()
                        .min_w_0()
                        .max_w_1_2()
                        .child(Label::new("API Key"))
                        .child(description)
                        .when_some(env_var_name, |this, env_var_name| {
                            this.child({
                                let label = format!(
                                    "Or set the {} env var and restart Zed.",
                                    env_var_name.as_ref()
                                );
                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
                            })
                        }),
                )
                .child(
                    SettingsInputField::new()
                        .tab_index(0)
                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
                        .on_confirm(move |api_key, _window, cx| {
                            write_key(api_key.filter(|key| !key.is_empty()), cx);
                        }),
                ),
        )
    };

    // Provider-specific extra settings, separated by a full-width top border.
    container.when_some(additional_fields, |this, additional_fields| {
        this.child(
            div()
                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
                .px_neg_8()
                .border_t_1()
                .border_color(cx.theme().colors().border_variant)
                .child(additional_fields),
        )
    })
}
322
323fn render_ollama_provider(
324 settings_window: &SettingsWindow,
325 window: &mut Window,
326 cx: &mut Context<SettingsWindow>,
327) -> impl IntoElement {
328 let ollama_settings = ollama_settings();
329 let additional_fields = settings_window
330 .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
331 .into_any_element();
332
333 v_flex()
334 .id("ollama")
335 .min_w_0()
336 .pt_8()
337 .gap_1p5()
338 .child(
339 SettingsSectionHeader::new("Ollama")
340 .icon(IconName::AiOllama)
341 .no_padding(true),
342 )
343 .child(div().px_neg_8().child(additional_fields))
344}
345
346fn ollama_settings() -> Box<[SettingsPageItem]> {
347 Box::new([
348 SettingsPageItem::SettingItem(SettingItem {
349 title: "API URL",
350 description: "The base URL of your Ollama server.",
351 field: Box::new(SettingField {
352 pick: |settings| {
353 settings
354 .project
355 .all_languages
356 .edit_predictions
357 .as_ref()?
358 .ollama
359 .as_ref()?
360 .api_url
361 .as_ref()
362 },
363 write: |settings, value| {
364 settings
365 .project
366 .all_languages
367 .edit_predictions
368 .get_or_insert_default()
369 .ollama
370 .get_or_insert_default()
371 .api_url = value;
372 },
373 json_path: Some("edit_predictions.ollama.api_url"),
374 }),
375 metadata: Some(Box::new(SettingsFieldMetadata {
376 placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
377 ..Default::default()
378 })),
379 files: USER,
380 }),
381 SettingsPageItem::SettingItem(SettingItem {
382 title: "Model",
383 description: "The Ollama model to use for edit predictions.",
384 field: Box::new(SettingField {
385 pick: |settings| {
386 settings
387 .project
388 .all_languages
389 .edit_predictions
390 .as_ref()?
391 .ollama
392 .as_ref()?
393 .model
394 .as_ref()
395 },
396 write: |settings, value| {
397 settings
398 .project
399 .all_languages
400 .edit_predictions
401 .get_or_insert_default()
402 .ollama
403 .get_or_insert_default()
404 .model = value;
405 },
406 json_path: Some("edit_predictions.ollama.model"),
407 }),
408 metadata: Some(Box::new(SettingsFieldMetadata {
409 placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
410 ..Default::default()
411 })),
412 files: USER,
413 }),
414 SettingsPageItem::SettingItem(SettingItem {
415 title: "Prompt Format",
416 description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
417 field: Box::new(SettingField {
418 pick: |settings| {
419 settings
420 .project
421 .all_languages
422 .edit_predictions
423 .as_ref()?
424 .ollama
425 .as_ref()?
426 .prompt_format
427 .as_ref()
428 },
429 write: |settings, value| {
430 settings
431 .project
432 .all_languages
433 .edit_predictions
434 .get_or_insert_default()
435 .ollama
436 .get_or_insert_default()
437 .prompt_format = value;
438 },
439 json_path: Some("edit_predictions.ollama.prompt_format"),
440 }),
441 files: USER,
442 metadata: None,
443 }),
444 SettingsPageItem::SettingItem(SettingItem {
445 title: "Max Output Tokens",
446 description: "The maximum number of tokens to generate.",
447 field: Box::new(SettingField {
448 pick: |settings| {
449 settings
450 .project
451 .all_languages
452 .edit_predictions
453 .as_ref()?
454 .ollama
455 .as_ref()?
456 .max_output_tokens
457 .as_ref()
458 },
459 write: |settings, value| {
460 settings
461 .project
462 .all_languages
463 .edit_predictions
464 .get_or_insert_default()
465 .ollama
466 .get_or_insert_default()
467 .max_output_tokens = value;
468 },
469 json_path: Some("edit_predictions.ollama.max_output_tokens"),
470 }),
471 metadata: None,
472 files: USER,
473 }),
474 ])
475}
476
477fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
478 Box::new([
479 SettingsPageItem::SettingItem(SettingItem {
480 title: "API URL",
481 description: "The URL of your OpenAI-compatible server's completions API.",
482 field: Box::new(SettingField {
483 pick: |settings| {
484 settings
485 .project
486 .all_languages
487 .edit_predictions
488 .as_ref()?
489 .open_ai_compatible_api
490 .as_ref()?
491 .api_url
492 .as_ref()
493 },
494 write: |settings, value| {
495 settings
496 .project
497 .all_languages
498 .edit_predictions
499 .get_or_insert_default()
500 .open_ai_compatible_api
501 .get_or_insert_default()
502 .api_url = value;
503 },
504 json_path: Some("edit_predictions.open_ai_compatible_api.api_url"),
505 }),
506 metadata: Some(Box::new(SettingsFieldMetadata {
507 placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
508 ..Default::default()
509 })),
510 files: USER,
511 }),
512 SettingsPageItem::SettingItem(SettingItem {
513 title: "Model",
514 description: "The model string to pass to the OpenAI-compatible server.",
515 field: Box::new(SettingField {
516 pick: |settings| {
517 settings
518 .project
519 .all_languages
520 .edit_predictions
521 .as_ref()?
522 .open_ai_compatible_api
523 .as_ref()?
524 .model
525 .as_ref()
526 },
527 write: |settings, value| {
528 settings
529 .project
530 .all_languages
531 .edit_predictions
532 .get_or_insert_default()
533 .open_ai_compatible_api
534 .get_or_insert_default()
535 .model = value;
536 },
537 json_path: Some("edit_predictions.open_ai_compatible_api.model"),
538 }),
539 metadata: Some(Box::new(SettingsFieldMetadata {
540 placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
541 ..Default::default()
542 })),
543 files: USER,
544 }),
545 SettingsPageItem::SettingItem(SettingItem {
546 title: "Prompt Format",
547 description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
548 field: Box::new(SettingField {
549 pick: |settings| {
550 settings
551 .project
552 .all_languages
553 .edit_predictions
554 .as_ref()?
555 .open_ai_compatible_api
556 .as_ref()?
557 .prompt_format
558 .as_ref()
559 },
560 write: |settings, value| {
561 settings
562 .project
563 .all_languages
564 .edit_predictions
565 .get_or_insert_default()
566 .open_ai_compatible_api
567 .get_or_insert_default()
568 .prompt_format = value;
569 },
570 json_path: Some("edit_predictions.open_ai_compatible_api.prompt_format"),
571 }),
572 files: USER,
573 metadata: None,
574 }),
575 SettingsPageItem::SettingItem(SettingItem {
576 title: "Max Output Tokens",
577 description: "The maximum number of tokens to generate.",
578 field: Box::new(SettingField {
579 pick: |settings| {
580 settings
581 .project
582 .all_languages
583 .edit_predictions
584 .as_ref()?
585 .open_ai_compatible_api
586 .as_ref()?
587 .max_output_tokens
588 .as_ref()
589 },
590 write: |settings, value| {
591 settings
592 .project
593 .all_languages
594 .edit_predictions
595 .get_or_insert_default()
596 .open_ai_compatible_api
597 .get_or_insert_default()
598 .max_output_tokens = value;
599 },
600 json_path: Some("edit_predictions.open_ai_compatible_api.max_output_tokens"),
601 }),
602 metadata: None,
603 files: USER,
604 }),
605 ])
606}
607
608fn codestral_settings() -> Box<[SettingsPageItem]> {
609 Box::new([
610 SettingsPageItem::SettingItem(SettingItem {
611 title: "API URL",
612 description: "The API URL to use for Codestral.",
613 field: Box::new(SettingField {
614 pick: |settings| {
615 settings
616 .project
617 .all_languages
618 .edit_predictions
619 .as_ref()?
620 .codestral
621 .as_ref()?
622 .api_url
623 .as_ref()
624 },
625 write: |settings, value| {
626 settings
627 .project
628 .all_languages
629 .edit_predictions
630 .get_or_insert_default()
631 .codestral
632 .get_or_insert_default()
633 .api_url = value;
634 },
635 json_path: Some("edit_predictions.codestral.api_url"),
636 }),
637 metadata: Some(Box::new(SettingsFieldMetadata {
638 placeholder: Some(CODESTRAL_API_URL),
639 ..Default::default()
640 })),
641 files: USER,
642 }),
643 SettingsPageItem::SettingItem(SettingItem {
644 title: "Max Tokens",
645 description: "The maximum number of tokens to generate.",
646 field: Box::new(SettingField {
647 pick: |settings| {
648 settings
649 .project
650 .all_languages
651 .edit_predictions
652 .as_ref()?
653 .codestral
654 .as_ref()?
655 .max_tokens
656 .as_ref()
657 },
658 write: |settings, value| {
659 settings
660 .project
661 .all_languages
662 .edit_predictions
663 .get_or_insert_default()
664 .codestral
665 .get_or_insert_default()
666 .max_tokens = value;
667 },
668 json_path: Some("edit_predictions.codestral.max_tokens"),
669 }),
670 metadata: None,
671 files: USER,
672 }),
673 SettingsPageItem::SettingItem(SettingItem {
674 title: "Model",
675 description: "The Codestral model id to use.",
676 field: Box::new(SettingField {
677 pick: |settings| {
678 settings
679 .project
680 .all_languages
681 .edit_predictions
682 .as_ref()?
683 .codestral
684 .as_ref()?
685 .model
686 .as_ref()
687 },
688 write: |settings, value| {
689 settings
690 .project
691 .all_languages
692 .edit_predictions
693 .get_or_insert_default()
694 .codestral
695 .get_or_insert_default()
696 .model = value;
697 },
698 json_path: Some("edit_predictions.codestral.model"),
699 }),
700 metadata: Some(Box::new(SettingsFieldMetadata {
701 placeholder: Some("codestral-latest"),
702 ..Default::default()
703 })),
704 files: USER,
705 }),
706 ])
707}
708
709fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
710 let configuration_view = window.use_state(cx, |_, cx| {
711 copilot_ui::ConfigurationView::new(
712 move |cx| {
713 if let Some(app_state) = AppState::global(cx).upgrade() {
714 copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
715 .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
716 } else {
717 false
718 }
719 },
720 copilot_ui::ConfigurationMode::EditPrediction,
721 cx,
722 )
723 });
724
725 Some(
726 v_flex()
727 .id("github-copilot")
728 .min_w_0()
729 .pt_8()
730 .gap_1p5()
731 .child(
732 SettingsSectionHeader::new("GitHub Copilot")
733 .icon(IconName::Copilot)
734 .no_padding(true),
735 )
736 .child(configuration_view),
737 )
738}