edit_prediction_provider_setup.rs

  1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
  2use edit_prediction::{
  3    ApiKeyState,
  4    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  5    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  6};
  7use edit_prediction_ui::{get_available_providers, set_completion_provider};
  8use gpui::{Entity, ScrollHandle, prelude::*};
  9use language::language_settings::AllLanguageSettings;
 10
 11use settings::Settings as _;
 12use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 13use workspace::AppState;
 14
 15const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
 16const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 17
 18use crate::{
 19    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 20    components::{SettingsInputField, SettingsSectionHeader},
 21};
 22
 23pub(crate) fn render_edit_prediction_setup_page(
 24    settings_window: &SettingsWindow,
 25    scroll_handle: &ScrollHandle,
 26    window: &mut Window,
 27    cx: &mut Context<SettingsWindow>,
 28) -> AnyElement {
 29    let providers = [
 30        Some(render_provider_dropdown(window, cx)),
 31        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 32        Some(
 33            render_api_key_provider(
 34                IconName::Inception,
 35                "Mercury",
 36                "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 37                mercury_api_token(cx),
 38                |_cx| MERCURY_CREDENTIALS_URL,
 39                None,
 40                window,
 41                cx,
 42            )
 43            .into_any_element(),
 44        ),
 45        Some(
 46            render_api_key_provider(
 47                IconName::SweepAi,
 48                "Sweep",
 49                "https://app.sweep.dev/".into(),
 50                sweep_api_token(cx),
 51                |_cx| SWEEP_CREDENTIALS_URL,
 52                Some(
 53                    settings_window
 54                        .render_sub_page_items_section(
 55                            sweep_settings().iter().enumerate(),
 56                            true,
 57                            window,
 58                            cx,
 59                        )
 60                        .into_any_element(),
 61                ),
 62                window,
 63                cx,
 64            )
 65            .into_any_element(),
 66        ),
 67        Some(
 68            render_api_key_provider(
 69                IconName::AiMistral,
 70                "Codestral",
 71                "https://console.mistral.ai/codestral".into(),
 72                codestral_api_key_state(cx),
 73                |cx| codestral_api_url(cx),
 74                Some(
 75                    settings_window
 76                        .render_sub_page_items_section(
 77                            codestral_settings().iter().enumerate(),
 78                            true,
 79                            window,
 80                            cx,
 81                        )
 82                        .into_any_element(),
 83                ),
 84                window,
 85                cx,
 86            )
 87            .into_any_element(),
 88        ),
 89        Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
 90        Some(render_open_ai_compatible_provider(settings_window, window, cx).into_any_element()),
 91    ];
 92
 93    div()
 94        .size_full()
 95        .child(
 96            v_flex()
 97                .id("ep-setup-page")
 98                .min_w_0()
 99                .size_full()
100                .px_8()
101                .pb_16()
102                .overflow_y_scroll()
103                .track_scroll(&scroll_handle)
104                .children(providers.into_iter().flatten()),
105        )
106        .into_any_element()
107}
108
109fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
110    let current_provider = AllLanguageSettings::get_global(cx)
111        .edit_predictions
112        .provider;
113    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
114
115    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
116        let available_providers = get_available_providers(cx);
117        let fs = <dyn fs::Fs>::global(cx);
118
119        for provider in available_providers {
120            let Some(name) = provider.display_name() else {
121                continue;
122            };
123            let is_current = provider == current_provider;
124
125            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
126                let fs = fs.clone();
127                move |_, cx| {
128                    set_completion_provider(fs.clone(), cx, provider);
129                }
130            });
131        }
132        menu
133    });
134
135    v_flex()
136        .id("provider-selector")
137        .min_w_0()
138        .gap_1p5()
139        .child(SettingsSectionHeader::new("Active Provider").no_padding(true))
140        .child(
141            h_flex()
142                .pt_2p5()
143                .w_full()
144                .justify_between()
145                .child(
146                    v_flex()
147                        .w_full()
148                        .max_w_1_2()
149                        .child(Label::new("Provider"))
150                        .child(
151                            Label::new("Select which provider to use for edit predictions.")
152                                .size(LabelSize::Small)
153                                .color(Color::Muted),
154                        ),
155                )
156                .child(
157                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
158                        .tab_index(0)
159                        .style(DropdownStyle::Outlined),
160                ),
161        )
162        .into_any_element()
163}
164
165fn render_api_key_provider(
166    icon: IconName,
167    title: &'static str,
168    link: SharedString,
169    api_key_state: Entity<ApiKeyState>,
170    current_url: fn(&mut App) -> SharedString,
171    additional_fields: Option<AnyElement>,
172    window: &mut Window,
173    cx: &mut Context<SettingsWindow>,
174) -> impl IntoElement {
175    let weak_page = cx.weak_entity();
176    _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
177        let task = api_key_state.update(cx, |key_state, cx| {
178            key_state.load_if_needed(current_url(cx), |state| state, cx)
179        });
180        cx.spawn(async move |_, cx| {
181            task.await.ok();
182            weak_page
183                .update(cx, |_, cx| {
184                    cx.notify();
185                })
186                .ok();
187        })
188    });
189
190    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
191        (
192            state.has_key(),
193            Some(state.env_var_name().clone()),
194            state.is_from_env_var(),
195        )
196    });
197
198    let write_key = move |api_key: Option<String>, cx: &mut App| {
199        api_key_state
200            .update(cx, |key_state, cx| {
201                let url = current_url(cx);
202                key_state.store(url, api_key, |key_state| key_state, cx)
203            })
204            .detach_and_log_err(cx);
205    };
206
207    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
208    let header = SettingsSectionHeader::new(title)
209        .icon(icon)
210        .no_padding(true);
211    let button_link_label = format!("{} dashboard", title);
212    let description = h_flex()
213        .min_w_0()
214        .gap_0p5()
215        .child(
216            Label::new("Visit the")
217                .size(LabelSize::Small)
218                .color(Color::Muted),
219        )
220        .child(
221            ButtonLink::new(button_link_label, link)
222                .no_icon(true)
223                .label_size(LabelSize::Small)
224                .label_color(Color::Muted),
225        )
226        .child(
227            Label::new("to generate an API key.")
228                .size(LabelSize::Small)
229                .color(Color::Muted),
230        );
231    let configured_card_label = if is_from_env_var {
232        "API Key Set in Environment Variable"
233    } else {
234        "API Key Configured"
235    };
236
237    let container = if has_key {
238        base_container.child(header).child(
239            ConfiguredApiCard::new(configured_card_label)
240                .button_label("Reset Key")
241                .button_tab_index(0)
242                .disabled(is_from_env_var)
243                .when_some(env_var_name, |this, env_var_name| {
244                    this.when(is_from_env_var, |this| {
245                        this.tooltip_label(format!(
246                            "To reset your API key, unset the {} environment variable.",
247                            env_var_name
248                        ))
249                    })
250                })
251                .on_click(move |_, _, cx| {
252                    write_key(None, cx);
253                }),
254        )
255    } else {
256        base_container.child(header).child(
257            h_flex()
258                .pt_2p5()
259                .w_full()
260                .justify_between()
261                .child(
262                    v_flex()
263                        .w_full()
264                        .max_w_1_2()
265                        .child(Label::new("API Key"))
266                        .child(description)
267                        .when_some(env_var_name, |this, env_var_name| {
268                            this.child({
269                                let label = format!(
270                                    "Or set the {} env var and restart Zed.",
271                                    env_var_name.as_ref()
272                                );
273                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
274                            })
275                        }),
276                )
277                .child(
278                    SettingsInputField::new()
279                        .tab_index(0)
280                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
281                        .on_confirm(move |api_key, _window, cx| {
282                            write_key(api_key.filter(|key| !key.is_empty()), cx);
283                        }),
284                ),
285        )
286    };
287
288    container.when_some(additional_fields, |this, additional_fields| {
289        this.child(
290            div()
291                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
292                .px_neg_8()
293                .border_t_1()
294                .border_color(cx.theme().colors().border_variant)
295                .child(additional_fields),
296        )
297    })
298}
299
300fn sweep_settings() -> Box<[SettingsPageItem]> {
301    Box::new([SettingsPageItem::SettingItem(SettingItem {
302        title: "Privacy Mode",
303        description: "When enabled, Sweep will not store edit prediction inputs or outputs. When disabled, Sweep may collect data including buffer contents, diagnostics, file paths, and generated predictions to improve the service.",
304        field: Box::new(SettingField {
305            pick: |settings| {
306                settings
307                    .project
308                    .all_languages
309                    .edit_predictions
310                    .as_ref()?
311                    .sweep
312                    .as_ref()?
313                    .privacy_mode
314                    .as_ref()
315            },
316            write: |settings, value| {
317                settings
318                    .project
319                    .all_languages
320                    .edit_predictions
321                    .get_or_insert_default()
322                    .sweep
323                    .get_or_insert_default()
324                    .privacy_mode = value;
325            },
326            json_path: Some("edit_predictions.sweep.privacy_mode"),
327        }),
328        metadata: None,
329        files: USER,
330    })])
331}
332
333fn render_ollama_provider(
334    settings_window: &SettingsWindow,
335    window: &mut Window,
336    cx: &mut Context<SettingsWindow>,
337) -> impl IntoElement {
338    let ollama_settings = ollama_settings();
339    let additional_fields = settings_window
340        .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
341        .into_any_element();
342
343    v_flex()
344        .id("ollama")
345        .min_w_0()
346        .pt_8()
347        .gap_1p5()
348        .child(
349            SettingsSectionHeader::new("Ollama")
350                .icon(IconName::AiOllama)
351                .no_padding(true),
352        )
353        .child(div().px_neg_8().child(additional_fields))
354}
355
356fn ollama_settings() -> Box<[SettingsPageItem]> {
357    Box::new([
358        SettingsPageItem::SettingItem(SettingItem {
359            title: "API URL",
360            description: "The base URL of your Ollama server.",
361            field: Box::new(SettingField {
362                pick: |settings| {
363                    settings
364                        .project
365                        .all_languages
366                        .edit_predictions
367                        .as_ref()?
368                        .ollama
369                        .as_ref()?
370                        .api_url
371                        .as_ref()
372                },
373                write: |settings, value| {
374                    settings
375                        .project
376                        .all_languages
377                        .edit_predictions
378                        .get_or_insert_default()
379                        .ollama
380                        .get_or_insert_default()
381                        .api_url = value;
382                },
383                json_path: Some("edit_predictions.ollama.api_url"),
384            }),
385            metadata: Some(Box::new(SettingsFieldMetadata {
386                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
387                ..Default::default()
388            })),
389            files: USER,
390        }),
391        SettingsPageItem::SettingItem(SettingItem {
392            title: "Model",
393            description: "The Ollama model to use for edit predictions.",
394            field: Box::new(SettingField {
395                pick: |settings| {
396                    settings
397                        .project
398                        .all_languages
399                        .edit_predictions
400                        .as_ref()?
401                        .ollama
402                        .as_ref()?
403                        .model
404                        .as_ref()
405                },
406                write: |settings, value| {
407                    settings
408                        .project
409                        .all_languages
410                        .edit_predictions
411                        .get_or_insert_default()
412                        .ollama
413                        .get_or_insert_default()
414                        .model = value;
415                },
416                json_path: Some("edit_predictions.ollama.model"),
417            }),
418            metadata: Some(Box::new(SettingsFieldMetadata {
419                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
420                ..Default::default()
421            })),
422            files: USER,
423        }),
424        SettingsPageItem::SettingItem(SettingItem {
425            title: "Prompt Format",
426            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name",
427            field: Box::new(SettingField {
428                pick: |settings| {
429                    settings
430                        .project
431                        .all_languages
432                        .edit_predictions
433                        .as_ref()?
434                        .ollama
435                        .as_ref()?
436                        .prompt_format
437                        .as_ref()
438                },
439                write: |settings, value| {
440                    settings
441                        .project
442                        .all_languages
443                        .edit_predictions
444                        .get_or_insert_default()
445                        .ollama
446                        .get_or_insert_default()
447                        .prompt_format = value;
448                },
449                json_path: Some("edit_predictions.ollama.prompt_format"),
450            }),
451            files: USER,
452            metadata: None,
453        }),
454        SettingsPageItem::SettingItem(SettingItem {
455            title: "Max Output Tokens",
456            description: "The maximum number of tokens to generate.",
457            field: Box::new(SettingField {
458                pick: |settings| {
459                    settings
460                        .project
461                        .all_languages
462                        .edit_predictions
463                        .as_ref()?
464                        .ollama
465                        .as_ref()?
466                        .max_output_tokens
467                        .as_ref()
468                },
469                write: |settings, value| {
470                    settings
471                        .project
472                        .all_languages
473                        .edit_predictions
474                        .get_or_insert_default()
475                        .ollama
476                        .get_or_insert_default()
477                        .max_output_tokens = value;
478                },
479                json_path: Some("edit_predictions.ollama.max_output_tokens"),
480            }),
481            metadata: None,
482            files: USER,
483        }),
484    ])
485}
486
487fn render_open_ai_compatible_provider(
488    settings_window: &SettingsWindow,
489    window: &mut Window,
490    cx: &mut Context<SettingsWindow>,
491) -> impl IntoElement {
492    let open_ai_compatible_settings = open_ai_compatible_settings();
493    let additional_fields = settings_window
494        .render_sub_page_items_section(
495            open_ai_compatible_settings.iter().enumerate(),
496            true,
497            window,
498            cx,
499        )
500        .into_any_element();
501
502    v_flex()
503        .id("open-ai-compatible")
504        .min_w_0()
505        .pt_8()
506        .gap_1p5()
507        .child(
508            SettingsSectionHeader::new("OpenAI Compatible API")
509                .icon(IconName::AiOpenAiCompat)
510                .no_padding(true),
511        )
512        .child(div().px_neg_8().child(additional_fields))
513}
514
515fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
516    Box::new([
517        SettingsPageItem::SettingItem(SettingItem {
518            title: "API URL",
519            description: "The base URL of your OpenAI-compatible server.",
520            field: Box::new(SettingField {
521                pick: |settings| {
522                    settings
523                        .project
524                        .all_languages
525                        .edit_predictions
526                        .as_ref()?
527                        .open_ai_compatible_api
528                        .as_ref()?
529                        .api_url
530                        .as_ref()
531                },
532                write: |settings, value| {
533                    settings
534                        .project
535                        .all_languages
536                        .edit_predictions
537                        .get_or_insert_default()
538                        .open_ai_compatible_api
539                        .get_or_insert_default()
540                        .api_url = value;
541                },
542                json_path: Some("edit_predictions.open_ai_compatible_api.api_url"),
543            }),
544            metadata: Some(Box::new(SettingsFieldMetadata {
545                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
546                ..Default::default()
547            })),
548            files: USER,
549        }),
550        SettingsPageItem::SettingItem(SettingItem {
551            title: "Model",
552            description: "The model string to pass to the OpenAI-compatible server.",
553            field: Box::new(SettingField {
554                pick: |settings| {
555                    settings
556                        .project
557                        .all_languages
558                        .edit_predictions
559                        .as_ref()?
560                        .open_ai_compatible_api
561                        .as_ref()?
562                        .model
563                        .as_ref()
564                },
565                write: |settings, value| {
566                    settings
567                        .project
568                        .all_languages
569                        .edit_predictions
570                        .get_or_insert_default()
571                        .open_ai_compatible_api
572                        .get_or_insert_default()
573                        .model = value;
574                },
575                json_path: Some("edit_predictions.open_ai_compatible_api.model"),
576            }),
577            metadata: Some(Box::new(SettingsFieldMetadata {
578                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
579                ..Default::default()
580            })),
581            files: USER,
582        }),
583        SettingsPageItem::SettingItem(SettingItem {
584            title: "Prompt Format",
585            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name",
586            field: Box::new(SettingField {
587                pick: |settings| {
588                    settings
589                        .project
590                        .all_languages
591                        .edit_predictions
592                        .as_ref()?
593                        .open_ai_compatible_api
594                        .as_ref()?
595                        .prompt_format
596                        .as_ref()
597                },
598                write: |settings, value| {
599                    settings
600                        .project
601                        .all_languages
602                        .edit_predictions
603                        .get_or_insert_default()
604                        .open_ai_compatible_api
605                        .get_or_insert_default()
606                        .prompt_format = value;
607                },
608                json_path: Some("edit_predictions.open_ai_compatible_api.prompt_format"),
609            }),
610            files: USER,
611            metadata: None,
612        }),
613        SettingsPageItem::SettingItem(SettingItem {
614            title: "Max Output Tokens",
615            description: "The maximum number of tokens to generate.",
616            field: Box::new(SettingField {
617                pick: |settings| {
618                    settings
619                        .project
620                        .all_languages
621                        .edit_predictions
622                        .as_ref()?
623                        .open_ai_compatible_api
624                        .as_ref()?
625                        .max_output_tokens
626                        .as_ref()
627                },
628                write: |settings, value| {
629                    settings
630                        .project
631                        .all_languages
632                        .edit_predictions
633                        .get_or_insert_default()
634                        .open_ai_compatible_api
635                        .get_or_insert_default()
636                        .max_output_tokens = value;
637                },
638                json_path: Some("edit_predictions.open_ai_compatible_api.max_output_tokens"),
639            }),
640            metadata: None,
641            files: USER,
642        }),
643    ])
644}
645
646fn codestral_settings() -> Box<[SettingsPageItem]> {
647    Box::new([
648        SettingsPageItem::SettingItem(SettingItem {
649            title: "API URL",
650            description: "The API URL to use for Codestral.",
651            field: Box::new(SettingField {
652                pick: |settings| {
653                    settings
654                        .project
655                        .all_languages
656                        .edit_predictions
657                        .as_ref()?
658                        .codestral
659                        .as_ref()?
660                        .api_url
661                        .as_ref()
662                },
663                write: |settings, value| {
664                    settings
665                        .project
666                        .all_languages
667                        .edit_predictions
668                        .get_or_insert_default()
669                        .codestral
670                        .get_or_insert_default()
671                        .api_url = value;
672                },
673                json_path: Some("edit_predictions.codestral.api_url"),
674            }),
675            metadata: Some(Box::new(SettingsFieldMetadata {
676                placeholder: Some(CODESTRAL_API_URL),
677                ..Default::default()
678            })),
679            files: USER,
680        }),
681        SettingsPageItem::SettingItem(SettingItem {
682            title: "Max Tokens",
683            description: "The maximum number of tokens to generate.",
684            field: Box::new(SettingField {
685                pick: |settings| {
686                    settings
687                        .project
688                        .all_languages
689                        .edit_predictions
690                        .as_ref()?
691                        .codestral
692                        .as_ref()?
693                        .max_tokens
694                        .as_ref()
695                },
696                write: |settings, value| {
697                    settings
698                        .project
699                        .all_languages
700                        .edit_predictions
701                        .get_or_insert_default()
702                        .codestral
703                        .get_or_insert_default()
704                        .max_tokens = value;
705                },
706                json_path: Some("edit_predictions.codestral.max_tokens"),
707            }),
708            metadata: None,
709            files: USER,
710        }),
711        SettingsPageItem::SettingItem(SettingItem {
712            title: "Model",
713            description: "The Codestral model id to use.",
714            field: Box::new(SettingField {
715                pick: |settings| {
716                    settings
717                        .project
718                        .all_languages
719                        .edit_predictions
720                        .as_ref()?
721                        .codestral
722                        .as_ref()?
723                        .model
724                        .as_ref()
725                },
726                write: |settings, value| {
727                    settings
728                        .project
729                        .all_languages
730                        .edit_predictions
731                        .get_or_insert_default()
732                        .codestral
733                        .get_or_insert_default()
734                        .model = value;
735                },
736                json_path: Some("edit_predictions.codestral.model"),
737            }),
738            metadata: Some(Box::new(SettingsFieldMetadata {
739                placeholder: Some("codestral-latest"),
740                ..Default::default()
741            })),
742            files: USER,
743        }),
744    ])
745}
746
747fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
748    let configuration_view = window.use_state(cx, |_, cx| {
749        copilot_ui::ConfigurationView::new(
750            move |cx| {
751                if let Some(app_state) = AppState::global(cx).upgrade() {
752                    copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
753                        .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
754                } else {
755                    false
756                }
757            },
758            copilot_ui::ConfigurationMode::EditPrediction,
759            cx,
760        )
761    });
762
763    Some(
764        v_flex()
765            .id("github-copilot")
766            .min_w_0()
767            .pt_8()
768            .gap_1p5()
769            .child(
770                SettingsSectionHeader::new("GitHub Copilot")
771                    .icon(IconName::Copilot)
772                    .no_padding(true),
773            )
774            .child(configuration_view),
775    )
776}