edit_prediction_provider_setup.rs

  1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
  2use edit_prediction::{
  3    ApiKeyState,
  4    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  5    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  6};
  7use edit_prediction_ui::{get_available_providers, set_completion_provider};
  8use gpui::{Entity, ScrollHandle, prelude::*};
  9use language::language_settings::AllLanguageSettings;
 10
 11use settings::Settings as _;
 12use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 13use workspace::AppState;
 14
/// Placeholder text for the Ollama "API URL" field on the edit prediction setup page.
const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
/// Placeholder text for the Ollama "Model" field on the edit prediction setup page.
const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 17
 18use crate::{
 19    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 20    components::{SettingsInputField, SettingsSectionHeader},
 21};
 22
 23pub(crate) fn render_edit_prediction_setup_page(
 24    settings_window: &SettingsWindow,
 25    scroll_handle: &ScrollHandle,
 26    window: &mut Window,
 27    cx: &mut Context<SettingsWindow>,
 28) -> AnyElement {
 29    let providers = [
 30        Some(render_provider_dropdown(window, cx)),
 31        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 32        Some(
 33            render_api_key_provider(
 34                IconName::Inception,
 35                "Mercury",
 36                "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 37                mercury_api_token(cx),
 38                |_cx| MERCURY_CREDENTIALS_URL,
 39                None,
 40                window,
 41                cx,
 42            )
 43            .into_any_element(),
 44        ),
 45        Some(
 46            render_api_key_provider(
 47                IconName::SweepAi,
 48                "Sweep",
 49                "https://app.sweep.dev/".into(),
 50                sweep_api_token(cx),
 51                |_cx| SWEEP_CREDENTIALS_URL,
 52                Some(
 53                    settings_window
 54                        .render_sub_page_items_section(
 55                            sweep_settings().iter().enumerate(),
 56                            true,
 57                            window,
 58                            cx,
 59                        )
 60                        .into_any_element(),
 61                ),
 62                window,
 63                cx,
 64            )
 65            .into_any_element(),
 66        ),
 67        Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
 68        Some(
 69            render_api_key_provider(
 70                IconName::AiMistral,
 71                "Codestral",
 72                "https://console.mistral.ai/codestral".into(),
 73                codestral_api_key_state(cx),
 74                |cx| codestral_api_url(cx),
 75                Some(
 76                    settings_window
 77                        .render_sub_page_items_section(
 78                            codestral_settings().iter().enumerate(),
 79                            true,
 80                            window,
 81                            cx,
 82                        )
 83                        .into_any_element(),
 84                ),
 85                window,
 86                cx,
 87            )
 88            .into_any_element(),
 89        ),
 90    ];
 91
 92    div()
 93        .size_full()
 94        .child(
 95            v_flex()
 96                .id("ep-setup-page")
 97                .min_w_0()
 98                .size_full()
 99                .px_8()
100                .pb_16()
101                .overflow_y_scroll()
102                .track_scroll(&scroll_handle)
103                .children(providers.into_iter().flatten()),
104        )
105        .into_any_element()
106}
107
108fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
109    let current_provider = AllLanguageSettings::get_global(cx)
110        .edit_predictions
111        .provider;
112    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
113
114    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
115        let available_providers = get_available_providers(cx);
116        let fs = <dyn fs::Fs>::global(cx);
117
118        for provider in available_providers {
119            let Some(name) = provider.display_name() else {
120                continue;
121            };
122            let is_current = provider == current_provider;
123
124            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
125                let fs = fs.clone();
126                move |_, cx| {
127                    set_completion_provider(fs.clone(), cx, provider);
128                }
129            });
130        }
131        menu
132    });
133
134    v_flex()
135        .id("provider-selector")
136        .min_w_0()
137        .gap_1p5()
138        .child(
139            SettingsSectionHeader::new("Active Provider")
140                .icon(IconName::Sparkle)
141                .no_padding(true),
142        )
143        .child(
144            h_flex()
145                .pt_2p5()
146                .w_full()
147                .justify_between()
148                .child(
149                    v_flex()
150                        .w_full()
151                        .max_w_1_2()
152                        .child(Label::new("Provider"))
153                        .child(
154                            Label::new("Select which provider to use for edit predictions.")
155                                .size(LabelSize::Small)
156                                .color(Color::Muted),
157                        ),
158                )
159                .child(
160                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
161                        .style(DropdownStyle::Outlined),
162                ),
163        )
164        .into_any_element()
165}
166
167fn render_api_key_provider(
168    icon: IconName,
169    title: &'static str,
170    link: SharedString,
171    api_key_state: Entity<ApiKeyState>,
172    current_url: fn(&mut App) -> SharedString,
173    additional_fields: Option<AnyElement>,
174    window: &mut Window,
175    cx: &mut Context<SettingsWindow>,
176) -> impl IntoElement {
177    let weak_page = cx.weak_entity();
178    _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
179        let task = api_key_state.update(cx, |key_state, cx| {
180            key_state.load_if_needed(current_url(cx), |state| state, cx)
181        });
182        cx.spawn(async move |_, cx| {
183            task.await.ok();
184            weak_page
185                .update(cx, |_, cx| {
186                    cx.notify();
187                })
188                .ok();
189        })
190    });
191
192    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
193        (
194            state.has_key(),
195            Some(state.env_var_name().clone()),
196            state.is_from_env_var(),
197        )
198    });
199
200    let write_key = move |api_key: Option<String>, cx: &mut App| {
201        api_key_state
202            .update(cx, |key_state, cx| {
203                let url = current_url(cx);
204                key_state.store(url, api_key, |key_state| key_state, cx)
205            })
206            .detach_and_log_err(cx);
207    };
208
209    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
210    let header = SettingsSectionHeader::new(title)
211        .icon(icon)
212        .no_padding(true);
213    let button_link_label = format!("{} dashboard", title);
214    let description = h_flex()
215        .min_w_0()
216        .gap_0p5()
217        .child(
218            Label::new("Visit the")
219                .size(LabelSize::Small)
220                .color(Color::Muted),
221        )
222        .child(
223            ButtonLink::new(button_link_label, link)
224                .no_icon(true)
225                .label_size(LabelSize::Small)
226                .label_color(Color::Muted),
227        )
228        .child(
229            Label::new("to generate an API key.")
230                .size(LabelSize::Small)
231                .color(Color::Muted),
232        );
233    let configured_card_label = if is_from_env_var {
234        "API Key Set in Environment Variable"
235    } else {
236        "API Key Configured"
237    };
238
239    let container = if has_key {
240        base_container.child(header).child(
241            ConfiguredApiCard::new(configured_card_label)
242                .button_label("Reset Key")
243                .button_tab_index(0)
244                .disabled(is_from_env_var)
245                .when_some(env_var_name, |this, env_var_name| {
246                    this.when(is_from_env_var, |this| {
247                        this.tooltip_label(format!(
248                            "To reset your API key, unset the {} environment variable.",
249                            env_var_name
250                        ))
251                    })
252                })
253                .on_click(move |_, _, cx| {
254                    write_key(None, cx);
255                }),
256        )
257    } else {
258        base_container.child(header).child(
259            h_flex()
260                .pt_2p5()
261                .w_full()
262                .justify_between()
263                .child(
264                    v_flex()
265                        .w_full()
266                        .max_w_1_2()
267                        .child(Label::new("API Key"))
268                        .child(description)
269                        .when_some(env_var_name, |this, env_var_name| {
270                            this.child({
271                                let label = format!(
272                                    "Or set the {} env var and restart Zed.",
273                                    env_var_name.as_ref()
274                                );
275                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
276                            })
277                        }),
278                )
279                .child(
280                    SettingsInputField::new()
281                        .tab_index(0)
282                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
283                        .on_confirm(move |api_key, _window, cx| {
284                            write_key(api_key.filter(|key| !key.is_empty()), cx);
285                        }),
286                ),
287        )
288    };
289
290    container.when_some(additional_fields, |this, additional_fields| {
291        this.child(
292            div()
293                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
294                .px_neg_8()
295                .border_t_1()
296                .border_color(cx.theme().colors().border_variant)
297                .child(additional_fields),
298        )
299    })
300}
301
302fn sweep_settings() -> Box<[SettingsPageItem]> {
303    Box::new([SettingsPageItem::SettingItem(SettingItem {
304        title: "Privacy Mode",
305        description: "When enabled, Sweep will not store edit prediction inputs or outputs. When disabled, Sweep may collect data including buffer contents, diagnostics, file paths, and generated predictions to improve the service.",
306        field: Box::new(SettingField {
307            pick: |settings| {
308                settings
309                    .project
310                    .all_languages
311                    .edit_predictions
312                    .as_ref()?
313                    .sweep
314                    .as_ref()?
315                    .privacy_mode
316                    .as_ref()
317            },
318            write: |settings, value| {
319                settings
320                    .project
321                    .all_languages
322                    .edit_predictions
323                    .get_or_insert_default()
324                    .sweep
325                    .get_or_insert_default()
326                    .privacy_mode = value;
327            },
328            json_path: Some("edit_predictions.sweep.privacy_mode"),
329        }),
330        metadata: None,
331        files: USER,
332    })])
333}
334
335fn render_ollama_provider(
336    settings_window: &SettingsWindow,
337    window: &mut Window,
338    cx: &mut Context<SettingsWindow>,
339) -> impl IntoElement {
340    let ollama_settings = ollama_settings();
341    let additional_fields = settings_window
342        .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
343        .into_any_element();
344
345    v_flex()
346        .id("ollama")
347        .min_w_0()
348        .pt_8()
349        .gap_1p5()
350        .child(
351            SettingsSectionHeader::new("Ollama")
352                .icon(IconName::ZedPredict)
353                .no_padding(true),
354        )
355        .child(
356            Label::new("Configure the local Ollama server and model used for edit predictions.")
357                .size(LabelSize::Small)
358                .color(Color::Muted),
359        )
360        .child(div().px_neg_8().child(additional_fields))
361}
362
363fn ollama_settings() -> Box<[SettingsPageItem]> {
364    Box::new([
365        SettingsPageItem::SettingItem(SettingItem {
366            title: "API URL",
367            description: "The base URL of your Ollama server.",
368            field: Box::new(SettingField {
369                pick: |settings| {
370                    settings
371                        .project
372                        .all_languages
373                        .edit_predictions
374                        .as_ref()?
375                        .ollama
376                        .as_ref()?
377                        .api_url
378                        .as_ref()
379                },
380                write: |settings, value| {
381                    settings
382                        .project
383                        .all_languages
384                        .edit_predictions
385                        .get_or_insert_default()
386                        .ollama
387                        .get_or_insert_default()
388                        .api_url = value;
389                },
390                json_path: Some("edit_predictions.ollama.api_url"),
391            }),
392            metadata: Some(Box::new(SettingsFieldMetadata {
393                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
394                ..Default::default()
395            })),
396            files: USER,
397        }),
398        SettingsPageItem::SettingItem(SettingItem {
399            title: "Model",
400            description: "The Ollama model to use for edit predictions.",
401            field: Box::new(SettingField {
402                pick: |settings| {
403                    settings
404                        .project
405                        .all_languages
406                        .edit_predictions
407                        .as_ref()?
408                        .ollama
409                        .as_ref()?
410                        .model
411                        .as_ref()
412                },
413                write: |settings, value| {
414                    settings
415                        .project
416                        .all_languages
417                        .edit_predictions
418                        .get_or_insert_default()
419                        .ollama
420                        .get_or_insert_default()
421                        .model = value;
422                },
423                json_path: Some("edit_predictions.ollama.model"),
424            }),
425            metadata: Some(Box::new(SettingsFieldMetadata {
426                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
427                ..Default::default()
428            })),
429            files: USER,
430        }),
431        SettingsPageItem::SettingItem(SettingItem {
432            title: "Max Output Tokens",
433            description: "The maximum number of tokens to generate.",
434            field: Box::new(SettingField {
435                pick: |settings| {
436                    settings
437                        .project
438                        .all_languages
439                        .edit_predictions
440                        .as_ref()?
441                        .ollama
442                        .as_ref()?
443                        .max_output_tokens
444                        .as_ref()
445                },
446                write: |settings, value| {
447                    settings
448                        .project
449                        .all_languages
450                        .edit_predictions
451                        .get_or_insert_default()
452                        .ollama
453                        .get_or_insert_default()
454                        .max_output_tokens = value;
455                },
456                json_path: Some("edit_predictions.ollama.max_output_tokens"),
457            }),
458            metadata: None,
459            files: USER,
460        }),
461    ])
462}
463
464fn codestral_settings() -> Box<[SettingsPageItem]> {
465    Box::new([
466        SettingsPageItem::SettingItem(SettingItem {
467            title: "API URL",
468            description: "The API URL to use for Codestral.",
469            field: Box::new(SettingField {
470                pick: |settings| {
471                    settings
472                        .project
473                        .all_languages
474                        .edit_predictions
475                        .as_ref()?
476                        .codestral
477                        .as_ref()?
478                        .api_url
479                        .as_ref()
480                },
481                write: |settings, value| {
482                    settings
483                        .project
484                        .all_languages
485                        .edit_predictions
486                        .get_or_insert_default()
487                        .codestral
488                        .get_or_insert_default()
489                        .api_url = value;
490                },
491                json_path: Some("edit_predictions.codestral.api_url"),
492            }),
493            metadata: Some(Box::new(SettingsFieldMetadata {
494                placeholder: Some(CODESTRAL_API_URL),
495                ..Default::default()
496            })),
497            files: USER,
498        }),
499        SettingsPageItem::SettingItem(SettingItem {
500            title: "Max Tokens",
501            description: "The maximum number of tokens to generate.",
502            field: Box::new(SettingField {
503                pick: |settings| {
504                    settings
505                        .project
506                        .all_languages
507                        .edit_predictions
508                        .as_ref()?
509                        .codestral
510                        .as_ref()?
511                        .max_tokens
512                        .as_ref()
513                },
514                write: |settings, value| {
515                    settings
516                        .project
517                        .all_languages
518                        .edit_predictions
519                        .get_or_insert_default()
520                        .codestral
521                        .get_or_insert_default()
522                        .max_tokens = value;
523                },
524                json_path: Some("edit_predictions.codestral.max_tokens"),
525            }),
526            metadata: None,
527            files: USER,
528        }),
529        SettingsPageItem::SettingItem(SettingItem {
530            title: "Model",
531            description: "The Codestral model id to use.",
532            field: Box::new(SettingField {
533                pick: |settings| {
534                    settings
535                        .project
536                        .all_languages
537                        .edit_predictions
538                        .as_ref()?
539                        .codestral
540                        .as_ref()?
541                        .model
542                        .as_ref()
543                },
544                write: |settings, value| {
545                    settings
546                        .project
547                        .all_languages
548                        .edit_predictions
549                        .get_or_insert_default()
550                        .codestral
551                        .get_or_insert_default()
552                        .model = value;
553                },
554                json_path: Some("edit_predictions.codestral.model"),
555            }),
556            metadata: Some(Box::new(SettingsFieldMetadata {
557                placeholder: Some("codestral-latest"),
558                ..Default::default()
559            })),
560            files: USER,
561        }),
562    ])
563}
564
565fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
566    let configuration_view = window.use_state(cx, |_, cx| {
567        copilot_ui::ConfigurationView::new(
568            move |cx| {
569                if let Some(app_state) = AppState::global(cx).upgrade() {
570                    copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
571                        .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
572                } else {
573                    false
574                }
575            },
576            copilot_ui::ConfigurationMode::EditPrediction,
577            cx,
578        )
579    });
580
581    Some(
582        v_flex()
583            .id("github-copilot")
584            .min_w_0()
585            .pt_8()
586            .gap_1p5()
587            .child(
588                SettingsSectionHeader::new("GitHub Copilot")
589                    .icon(IconName::Copilot)
590                    .no_padding(true),
591            )
592            .child(configuration_view),
593    )
594}