edit_prediction_provider_setup.rs

  1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
  2use edit_prediction::{
  3    ApiKeyState,
  4    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  5    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  6};
  7use edit_prediction_ui::{get_available_providers, set_completion_provider};
  8use gpui::{Entity, ScrollHandle, prelude::*};
  9use language::language_settings::AllLanguageSettings;
 10
 11use settings::Settings as _;
 12use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 13use workspace::AppState;
 14
 15const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
 16const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 17
 18use crate::{
 19    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 20    components::{SettingsInputField, SettingsSectionHeader},
 21};
 22
 23pub(crate) fn render_edit_prediction_setup_page(
 24    settings_window: &SettingsWindow,
 25    scroll_handle: &ScrollHandle,
 26    window: &mut Window,
 27    cx: &mut Context<SettingsWindow>,
 28) -> AnyElement {
 29    let providers = [
 30        Some(render_provider_dropdown(window, cx)),
 31        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 32        Some(
 33            render_api_key_provider(
 34                IconName::Inception,
 35                "Mercury",
 36                "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 37                mercury_api_token(cx),
 38                |_cx| MERCURY_CREDENTIALS_URL,
 39                None,
 40                window,
 41                cx,
 42            )
 43            .into_any_element(),
 44        ),
 45        Some(
 46            render_api_key_provider(
 47                IconName::SweepAi,
 48                "Sweep",
 49                "https://app.sweep.dev/".into(),
 50                sweep_api_token(cx),
 51                |_cx| SWEEP_CREDENTIALS_URL,
 52                Some(
 53                    settings_window
 54                        .render_sub_page_items_section(
 55                            sweep_settings().iter().enumerate(),
 56                            true,
 57                            window,
 58                            cx,
 59                        )
 60                        .into_any_element(),
 61                ),
 62                window,
 63                cx,
 64            )
 65            .into_any_element(),
 66        ),
 67        Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
 68        Some(
 69            render_api_key_provider(
 70                IconName::AiMistral,
 71                "Codestral",
 72                "https://console.mistral.ai/codestral".into(),
 73                codestral_api_key_state(cx),
 74                |cx| codestral_api_url(cx),
 75                Some(
 76                    settings_window
 77                        .render_sub_page_items_section(
 78                            codestral_settings().iter().enumerate(),
 79                            true,
 80                            window,
 81                            cx,
 82                        )
 83                        .into_any_element(),
 84                ),
 85                window,
 86                cx,
 87            )
 88            .into_any_element(),
 89        ),
 90    ];
 91
 92    div()
 93        .size_full()
 94        .child(
 95            v_flex()
 96                .id("ep-setup-page")
 97                .min_w_0()
 98                .size_full()
 99                .px_8()
100                .pb_16()
101                .overflow_y_scroll()
102                .track_scroll(&scroll_handle)
103                .children(providers.into_iter().flatten()),
104        )
105        .into_any_element()
106}
107
108fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
109    let current_provider = AllLanguageSettings::get_global(cx)
110        .edit_predictions
111        .provider;
112    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
113
114    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
115        let available_providers = get_available_providers(cx);
116        let fs = <dyn fs::Fs>::global(cx);
117
118        for provider in available_providers {
119            let Some(name) = provider.display_name() else {
120                continue;
121            };
122            let is_current = provider == current_provider;
123
124            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
125                let fs = fs.clone();
126                move |_, cx| {
127                    set_completion_provider(fs.clone(), cx, provider);
128                }
129            });
130        }
131        menu
132    });
133
134    v_flex()
135        .id("provider-selector")
136        .min_w_0()
137        .gap_1p5()
138        .child(SettingsSectionHeader::new("Active Provider").no_padding(true))
139        .child(
140            h_flex()
141                .pt_2p5()
142                .w_full()
143                .justify_between()
144                .child(
145                    v_flex()
146                        .w_full()
147                        .max_w_1_2()
148                        .child(Label::new("Provider"))
149                        .child(
150                            Label::new("Select which provider to use for edit predictions.")
151                                .size(LabelSize::Small)
152                                .color(Color::Muted),
153                        ),
154                )
155                .child(
156                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
157                        .tab_index(0)
158                        .style(DropdownStyle::Outlined),
159                ),
160        )
161        .into_any_element()
162}
163
164fn render_api_key_provider(
165    icon: IconName,
166    title: &'static str,
167    link: SharedString,
168    api_key_state: Entity<ApiKeyState>,
169    current_url: fn(&mut App) -> SharedString,
170    additional_fields: Option<AnyElement>,
171    window: &mut Window,
172    cx: &mut Context<SettingsWindow>,
173) -> impl IntoElement {
174    let weak_page = cx.weak_entity();
175    _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
176        let task = api_key_state.update(cx, |key_state, cx| {
177            key_state.load_if_needed(current_url(cx), |state| state, cx)
178        });
179        cx.spawn(async move |_, cx| {
180            task.await.ok();
181            weak_page
182                .update(cx, |_, cx| {
183                    cx.notify();
184                })
185                .ok();
186        })
187    });
188
189    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
190        (
191            state.has_key(),
192            Some(state.env_var_name().clone()),
193            state.is_from_env_var(),
194        )
195    });
196
197    let write_key = move |api_key: Option<String>, cx: &mut App| {
198        api_key_state
199            .update(cx, |key_state, cx| {
200                let url = current_url(cx);
201                key_state.store(url, api_key, |key_state| key_state, cx)
202            })
203            .detach_and_log_err(cx);
204    };
205
206    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
207    let header = SettingsSectionHeader::new(title)
208        .icon(icon)
209        .no_padding(true);
210    let button_link_label = format!("{} dashboard", title);
211    let description = h_flex()
212        .min_w_0()
213        .gap_0p5()
214        .child(
215            Label::new("Visit the")
216                .size(LabelSize::Small)
217                .color(Color::Muted),
218        )
219        .child(
220            ButtonLink::new(button_link_label, link)
221                .no_icon(true)
222                .label_size(LabelSize::Small)
223                .label_color(Color::Muted),
224        )
225        .child(
226            Label::new("to generate an API key.")
227                .size(LabelSize::Small)
228                .color(Color::Muted),
229        );
230    let configured_card_label = if is_from_env_var {
231        "API Key Set in Environment Variable"
232    } else {
233        "API Key Configured"
234    };
235
236    let container = if has_key {
237        base_container.child(header).child(
238            ConfiguredApiCard::new(configured_card_label)
239                .button_label("Reset Key")
240                .button_tab_index(0)
241                .disabled(is_from_env_var)
242                .when_some(env_var_name, |this, env_var_name| {
243                    this.when(is_from_env_var, |this| {
244                        this.tooltip_label(format!(
245                            "To reset your API key, unset the {} environment variable.",
246                            env_var_name
247                        ))
248                    })
249                })
250                .on_click(move |_, _, cx| {
251                    write_key(None, cx);
252                }),
253        )
254    } else {
255        base_container.child(header).child(
256            h_flex()
257                .pt_2p5()
258                .w_full()
259                .justify_between()
260                .child(
261                    v_flex()
262                        .w_full()
263                        .max_w_1_2()
264                        .child(Label::new("API Key"))
265                        .child(description)
266                        .when_some(env_var_name, |this, env_var_name| {
267                            this.child({
268                                let label = format!(
269                                    "Or set the {} env var and restart Zed.",
270                                    env_var_name.as_ref()
271                                );
272                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
273                            })
274                        }),
275                )
276                .child(
277                    SettingsInputField::new()
278                        .tab_index(0)
279                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
280                        .on_confirm(move |api_key, _window, cx| {
281                            write_key(api_key.filter(|key| !key.is_empty()), cx);
282                        }),
283                ),
284        )
285    };
286
287    container.when_some(additional_fields, |this, additional_fields| {
288        this.child(
289            div()
290                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
291                .px_neg_8()
292                .border_t_1()
293                .border_color(cx.theme().colors().border_variant)
294                .child(additional_fields),
295        )
296    })
297}
298
299fn sweep_settings() -> Box<[SettingsPageItem]> {
300    Box::new([SettingsPageItem::SettingItem(SettingItem {
301        title: "Privacy Mode",
302        description: "When enabled, Sweep will not store edit prediction inputs or outputs. When disabled, Sweep may collect data including buffer contents, diagnostics, file paths, and generated predictions to improve the service.",
303        field: Box::new(SettingField {
304            pick: |settings| {
305                settings
306                    .project
307                    .all_languages
308                    .edit_predictions
309                    .as_ref()?
310                    .sweep
311                    .as_ref()?
312                    .privacy_mode
313                    .as_ref()
314            },
315            write: |settings, value| {
316                settings
317                    .project
318                    .all_languages
319                    .edit_predictions
320                    .get_or_insert_default()
321                    .sweep
322                    .get_or_insert_default()
323                    .privacy_mode = value;
324            },
325            json_path: Some("edit_predictions.sweep.privacy_mode"),
326        }),
327        metadata: None,
328        files: USER,
329    })])
330}
331
332fn render_ollama_provider(
333    settings_window: &SettingsWindow,
334    window: &mut Window,
335    cx: &mut Context<SettingsWindow>,
336) -> impl IntoElement {
337    let ollama_settings = ollama_settings();
338    let additional_fields = settings_window
339        .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
340        .into_any_element();
341
342    v_flex()
343        .id("ollama")
344        .min_w_0()
345        .pt_8()
346        .gap_1p5()
347        .child(
348            SettingsSectionHeader::new("Ollama")
349                .icon(IconName::AiOllama)
350                .no_padding(true),
351        )
352        .child(div().px_neg_8().child(additional_fields))
353}
354
355fn ollama_settings() -> Box<[SettingsPageItem]> {
356    Box::new([
357        SettingsPageItem::SettingItem(SettingItem {
358            title: "API URL",
359            description: "The base URL of your Ollama server.",
360            field: Box::new(SettingField {
361                pick: |settings| {
362                    settings
363                        .project
364                        .all_languages
365                        .edit_predictions
366                        .as_ref()?
367                        .ollama
368                        .as_ref()?
369                        .api_url
370                        .as_ref()
371                },
372                write: |settings, value| {
373                    settings
374                        .project
375                        .all_languages
376                        .edit_predictions
377                        .get_or_insert_default()
378                        .ollama
379                        .get_or_insert_default()
380                        .api_url = value;
381                },
382                json_path: Some("edit_predictions.ollama.api_url"),
383            }),
384            metadata: Some(Box::new(SettingsFieldMetadata {
385                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
386                ..Default::default()
387            })),
388            files: USER,
389        }),
390        SettingsPageItem::SettingItem(SettingItem {
391            title: "Model",
392            description: "The Ollama model to use for edit predictions.",
393            field: Box::new(SettingField {
394                pick: |settings| {
395                    settings
396                        .project
397                        .all_languages
398                        .edit_predictions
399                        .as_ref()?
400                        .ollama
401                        .as_ref()?
402                        .model
403                        .as_ref()
404                },
405                write: |settings, value| {
406                    settings
407                        .project
408                        .all_languages
409                        .edit_predictions
410                        .get_or_insert_default()
411                        .ollama
412                        .get_or_insert_default()
413                        .model = value;
414                },
415                json_path: Some("edit_predictions.ollama.model"),
416            }),
417            metadata: Some(Box::new(SettingsFieldMetadata {
418                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
419                ..Default::default()
420            })),
421            files: USER,
422        }),
423        SettingsPageItem::SettingItem(SettingItem {
424            title: "Max Output Tokens",
425            description: "The maximum number of tokens to generate.",
426            field: Box::new(SettingField {
427                pick: |settings| {
428                    settings
429                        .project
430                        .all_languages
431                        .edit_predictions
432                        .as_ref()?
433                        .ollama
434                        .as_ref()?
435                        .max_output_tokens
436                        .as_ref()
437                },
438                write: |settings, value| {
439                    settings
440                        .project
441                        .all_languages
442                        .edit_predictions
443                        .get_or_insert_default()
444                        .ollama
445                        .get_or_insert_default()
446                        .max_output_tokens = value;
447                },
448                json_path: Some("edit_predictions.ollama.max_output_tokens"),
449            }),
450            metadata: None,
451            files: USER,
452        }),
453    ])
454}
455
456fn codestral_settings() -> Box<[SettingsPageItem]> {
457    Box::new([
458        SettingsPageItem::SettingItem(SettingItem {
459            title: "API URL",
460            description: "The API URL to use for Codestral.",
461            field: Box::new(SettingField {
462                pick: |settings| {
463                    settings
464                        .project
465                        .all_languages
466                        .edit_predictions
467                        .as_ref()?
468                        .codestral
469                        .as_ref()?
470                        .api_url
471                        .as_ref()
472                },
473                write: |settings, value| {
474                    settings
475                        .project
476                        .all_languages
477                        .edit_predictions
478                        .get_or_insert_default()
479                        .codestral
480                        .get_or_insert_default()
481                        .api_url = value;
482                },
483                json_path: Some("edit_predictions.codestral.api_url"),
484            }),
485            metadata: Some(Box::new(SettingsFieldMetadata {
486                placeholder: Some(CODESTRAL_API_URL),
487                ..Default::default()
488            })),
489            files: USER,
490        }),
491        SettingsPageItem::SettingItem(SettingItem {
492            title: "Max Tokens",
493            description: "The maximum number of tokens to generate.",
494            field: Box::new(SettingField {
495                pick: |settings| {
496                    settings
497                        .project
498                        .all_languages
499                        .edit_predictions
500                        .as_ref()?
501                        .codestral
502                        .as_ref()?
503                        .max_tokens
504                        .as_ref()
505                },
506                write: |settings, value| {
507                    settings
508                        .project
509                        .all_languages
510                        .edit_predictions
511                        .get_or_insert_default()
512                        .codestral
513                        .get_or_insert_default()
514                        .max_tokens = value;
515                },
516                json_path: Some("edit_predictions.codestral.max_tokens"),
517            }),
518            metadata: None,
519            files: USER,
520        }),
521        SettingsPageItem::SettingItem(SettingItem {
522            title: "Model",
523            description: "The Codestral model id to use.",
524            field: Box::new(SettingField {
525                pick: |settings| {
526                    settings
527                        .project
528                        .all_languages
529                        .edit_predictions
530                        .as_ref()?
531                        .codestral
532                        .as_ref()?
533                        .model
534                        .as_ref()
535                },
536                write: |settings, value| {
537                    settings
538                        .project
539                        .all_languages
540                        .edit_predictions
541                        .get_or_insert_default()
542                        .codestral
543                        .get_or_insert_default()
544                        .model = value;
545                },
546                json_path: Some("edit_predictions.codestral.model"),
547            }),
548            metadata: Some(Box::new(SettingsFieldMetadata {
549                placeholder: Some("codestral-latest"),
550                ..Default::default()
551            })),
552            files: USER,
553        }),
554    ])
555}
556
557fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
558    let configuration_view = window.use_state(cx, |_, cx| {
559        copilot_ui::ConfigurationView::new(
560            move |cx| {
561                if let Some(app_state) = AppState::global(cx).upgrade() {
562                    copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
563                        .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
564                } else {
565                    false
566                }
567            },
568            copilot_ui::ConfigurationMode::EditPrediction,
569            cx,
570        )
571    });
572
573    Some(
574        v_flex()
575            .id("github-copilot")
576            .min_w_0()
577            .pt_8()
578            .gap_1p5()
579            .child(
580                SettingsSectionHeader::new("GitHub Copilot")
581                    .icon(IconName::Copilot)
582                    .no_padding(true),
583            )
584            .child(configuration_view),
585    )
586}