edit_prediction_provider_setup.rs

  1use edit_prediction::{
  2    ApiKeyState,
  3    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  4    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  5};
  6use edit_prediction_ui::{get_available_providers, set_completion_provider};
  7use gpui::{Entity, ScrollHandle, prelude::*};
  8use language::language_settings::AllLanguageSettings;
  9use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
 10use settings::Settings as _;
 11use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 12use workspace::AppState;
 13
 14const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
 15const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 16
 17use crate::{
 18    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 19    components::{SettingsInputField, SettingsSectionHeader},
 20};
 21
 22pub(crate) fn render_edit_prediction_setup_page(
 23    settings_window: &SettingsWindow,
 24    scroll_handle: &ScrollHandle,
 25    window: &mut Window,
 26    cx: &mut Context<SettingsWindow>,
 27) -> AnyElement {
 28    let providers = [
 29        Some(render_provider_dropdown(window, cx)),
 30        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 31        Some(
 32            render_api_key_provider(
 33                IconName::Inception,
 34                "Mercury",
 35                "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 36                mercury_api_token(cx),
 37                |_cx| MERCURY_CREDENTIALS_URL,
 38                None,
 39                window,
 40                cx,
 41            )
 42            .into_any_element(),
 43        ),
 44        Some(
 45            render_api_key_provider(
 46                IconName::SweepAi,
 47                "Sweep",
 48                "https://app.sweep.dev/".into(),
 49                sweep_api_token(cx),
 50                |_cx| SWEEP_CREDENTIALS_URL,
 51                Some(
 52                    settings_window
 53                        .render_sub_page_items_section(
 54                            sweep_settings().iter().enumerate(),
 55                            true,
 56                            window,
 57                            cx,
 58                        )
 59                        .into_any_element(),
 60                ),
 61                window,
 62                cx,
 63            )
 64            .into_any_element(),
 65        ),
 66        Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
 67        Some(
 68            render_api_key_provider(
 69                IconName::AiMistral,
 70                "Codestral",
 71                "https://console.mistral.ai/codestral".into(),
 72                codestral_api_key(cx),
 73                |cx| language_models::MistralLanguageModelProvider::api_url(cx),
 74                Some(
 75                    settings_window
 76                        .render_sub_page_items_section(
 77                            codestral_settings().iter().enumerate(),
 78                            true,
 79                            window,
 80                            cx,
 81                        )
 82                        .into_any_element(),
 83                ),
 84                window,
 85                cx,
 86            )
 87            .into_any_element(),
 88        ),
 89    ];
 90
 91    div()
 92        .size_full()
 93        .child(
 94            v_flex()
 95                .id("ep-setup-page")
 96                .min_w_0()
 97                .size_full()
 98                .px_8()
 99                .pb_16()
100                .overflow_y_scroll()
101                .track_scroll(&scroll_handle)
102                .children(providers.into_iter().flatten()),
103        )
104        .into_any_element()
105}
106
107fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
108    let current_provider = AllLanguageSettings::get_global(cx)
109        .edit_predictions
110        .provider;
111    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
112
113    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
114        let available_providers = get_available_providers(cx);
115        let fs = <dyn fs::Fs>::global(cx);
116
117        for provider in available_providers {
118            let Some(name) = provider.display_name() else {
119                continue;
120            };
121            let is_current = provider == current_provider;
122
123            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
124                let fs = fs.clone();
125                move |_, cx| {
126                    set_completion_provider(fs.clone(), cx, provider);
127                }
128            });
129        }
130        menu
131    });
132
133    v_flex()
134        .id("provider-selector")
135        .min_w_0()
136        .gap_1p5()
137        .child(
138            SettingsSectionHeader::new("Active Provider")
139                .icon(IconName::Sparkle)
140                .no_padding(true),
141        )
142        .child(
143            h_flex()
144                .pt_2p5()
145                .w_full()
146                .justify_between()
147                .child(
148                    v_flex()
149                        .w_full()
150                        .max_w_1_2()
151                        .child(Label::new("Provider"))
152                        .child(
153                            Label::new("Select which provider to use for edit predictions.")
154                                .size(LabelSize::Small)
155                                .color(Color::Muted),
156                        ),
157                )
158                .child(
159                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
160                        .style(DropdownStyle::Outlined),
161                ),
162        )
163        .into_any_element()
164}
165
166fn render_api_key_provider(
167    icon: IconName,
168    title: &'static str,
169    link: SharedString,
170    api_key_state: Entity<ApiKeyState>,
171    current_url: fn(&mut App) -> SharedString,
172    additional_fields: Option<AnyElement>,
173    window: &mut Window,
174    cx: &mut Context<SettingsWindow>,
175) -> impl IntoElement {
176    let weak_page = cx.weak_entity();
177    _ = window.use_keyed_state(title, cx, |_, cx| {
178        let task = api_key_state.update(cx, |key_state, cx| {
179            key_state.load_if_needed(current_url(cx), |state| state, cx)
180        });
181        cx.spawn(async move |_, cx| {
182            task.await.ok();
183            weak_page
184                .update(cx, |_, cx| {
185                    cx.notify();
186                })
187                .ok();
188        })
189    });
190
191    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
192        (
193            state.has_key(),
194            Some(state.env_var_name().clone()),
195            state.is_from_env_var(),
196        )
197    });
198
199    let write_key = move |api_key: Option<String>, cx: &mut App| {
200        api_key_state
201            .update(cx, |key_state, cx| {
202                let url = current_url(cx);
203                key_state.store(url, api_key, |key_state| key_state, cx)
204            })
205            .detach_and_log_err(cx);
206    };
207
208    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
209    let header = SettingsSectionHeader::new(title)
210        .icon(icon)
211        .no_padding(true);
212    let button_link_label = format!("{} dashboard", title);
213    let description = h_flex()
214        .min_w_0()
215        .gap_0p5()
216        .child(
217            Label::new("Visit the")
218                .size(LabelSize::Small)
219                .color(Color::Muted),
220        )
221        .child(
222            ButtonLink::new(button_link_label, link)
223                .no_icon(true)
224                .label_size(LabelSize::Small)
225                .label_color(Color::Muted),
226        )
227        .child(
228            Label::new("to generate an API key.")
229                .size(LabelSize::Small)
230                .color(Color::Muted),
231        );
232    let configured_card_label = if is_from_env_var {
233        "API Key Set in Environment Variable"
234    } else {
235        "API Key Configured"
236    };
237
238    let container = if has_key {
239        base_container.child(header).child(
240            ConfiguredApiCard::new(configured_card_label)
241                .button_label("Reset Key")
242                .button_tab_index(0)
243                .disabled(is_from_env_var)
244                .when_some(env_var_name, |this, env_var_name| {
245                    this.when(is_from_env_var, |this| {
246                        this.tooltip_label(format!(
247                            "To reset your API key, unset the {} environment variable.",
248                            env_var_name
249                        ))
250                    })
251                })
252                .on_click(move |_, _, cx| {
253                    write_key(None, cx);
254                }),
255        )
256    } else {
257        base_container.child(header).child(
258            h_flex()
259                .pt_2p5()
260                .w_full()
261                .justify_between()
262                .child(
263                    v_flex()
264                        .w_full()
265                        .max_w_1_2()
266                        .child(Label::new("API Key"))
267                        .child(description)
268                        .when_some(env_var_name, |this, env_var_name| {
269                            this.child({
270                                let label = format!(
271                                    "Or set the {} env var and restart Zed.",
272                                    env_var_name.as_ref()
273                                );
274                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
275                            })
276                        }),
277                )
278                .child(
279                    SettingsInputField::new()
280                        .tab_index(0)
281                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
282                        .on_confirm(move |api_key, _window, cx| {
283                            write_key(api_key.filter(|key| !key.is_empty()), cx);
284                        }),
285                ),
286        )
287    };
288
289    container.when_some(additional_fields, |this, additional_fields| {
290        this.child(
291            div()
292                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
293                .px_neg_8()
294                .border_t_1()
295                .border_color(cx.theme().colors().border_variant)
296                .child(additional_fields),
297        )
298    })
299}
300
301fn sweep_settings() -> Box<[SettingsPageItem]> {
302    Box::new([SettingsPageItem::SettingItem(SettingItem {
303        title: "Privacy Mode",
304        description: "When enabled, Sweep will not store edit prediction inputs or outputs. When disabled, Sweep may collect data including buffer contents, diagnostics, file paths, and generated predictions to improve the service.",
305        field: Box::new(SettingField {
306            pick: |settings| {
307                settings
308                    .project
309                    .all_languages
310                    .edit_predictions
311                    .as_ref()?
312                    .sweep
313                    .as_ref()?
314                    .privacy_mode
315                    .as_ref()
316            },
317            write: |settings, value| {
318                settings
319                    .project
320                    .all_languages
321                    .edit_predictions
322                    .get_or_insert_default()
323                    .sweep
324                    .get_or_insert_default()
325                    .privacy_mode = value;
326            },
327            json_path: Some("edit_predictions.sweep.privacy_mode"),
328        }),
329        metadata: None,
330        files: USER,
331    })])
332}
333
334fn render_ollama_provider(
335    settings_window: &SettingsWindow,
336    window: &mut Window,
337    cx: &mut Context<SettingsWindow>,
338) -> impl IntoElement {
339    let ollama_settings = ollama_settings();
340    let additional_fields = settings_window
341        .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
342        .into_any_element();
343
344    v_flex()
345        .id("ollama")
346        .min_w_0()
347        .pt_8()
348        .gap_1p5()
349        .child(
350            SettingsSectionHeader::new("Ollama")
351                .icon(IconName::ZedPredict)
352                .no_padding(true),
353        )
354        .child(
355            Label::new("Configure the local Ollama server and model used for edit predictions.")
356                .size(LabelSize::Small)
357                .color(Color::Muted),
358        )
359        .child(div().px_neg_8().child(additional_fields))
360}
361
362fn ollama_settings() -> Box<[SettingsPageItem]> {
363    Box::new([
364        SettingsPageItem::SettingItem(SettingItem {
365            title: "API URL",
366            description: "The base URL of your Ollama server.",
367            field: Box::new(SettingField {
368                pick: |settings| {
369                    settings
370                        .project
371                        .all_languages
372                        .edit_predictions
373                        .as_ref()?
374                        .ollama
375                        .as_ref()?
376                        .api_url
377                        .as_ref()
378                },
379                write: |settings, value| {
380                    settings
381                        .project
382                        .all_languages
383                        .edit_predictions
384                        .get_or_insert_default()
385                        .ollama
386                        .get_or_insert_default()
387                        .api_url = value;
388                },
389                json_path: Some("edit_predictions.ollama.api_url"),
390            }),
391            metadata: Some(Box::new(SettingsFieldMetadata {
392                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
393                ..Default::default()
394            })),
395            files: USER,
396        }),
397        SettingsPageItem::SettingItem(SettingItem {
398            title: "Model",
399            description: "The Ollama model to use for edit predictions.",
400            field: Box::new(SettingField {
401                pick: |settings| {
402                    settings
403                        .project
404                        .all_languages
405                        .edit_predictions
406                        .as_ref()?
407                        .ollama
408                        .as_ref()?
409                        .model
410                        .as_ref()
411                },
412                write: |settings, value| {
413                    settings
414                        .project
415                        .all_languages
416                        .edit_predictions
417                        .get_or_insert_default()
418                        .ollama
419                        .get_or_insert_default()
420                        .model = value;
421                },
422                json_path: Some("edit_predictions.ollama.model"),
423            }),
424            metadata: Some(Box::new(SettingsFieldMetadata {
425                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
426                ..Default::default()
427            })),
428            files: USER,
429        }),
430        SettingsPageItem::SettingItem(SettingItem {
431            title: "Max Output Tokens",
432            description: "The maximum number of tokens to generate.",
433            field: Box::new(SettingField {
434                pick: |settings| {
435                    settings
436                        .project
437                        .all_languages
438                        .edit_predictions
439                        .as_ref()?
440                        .ollama
441                        .as_ref()?
442                        .max_output_tokens
443                        .as_ref()
444                },
445                write: |settings, value| {
446                    settings
447                        .project
448                        .all_languages
449                        .edit_predictions
450                        .get_or_insert_default()
451                        .ollama
452                        .get_or_insert_default()
453                        .max_output_tokens = value;
454                },
455                json_path: Some("edit_predictions.ollama.max_output_tokens"),
456            }),
457            metadata: None,
458            files: USER,
459        }),
460    ])
461}
462
463fn codestral_settings() -> Box<[SettingsPageItem]> {
464    Box::new([
465        SettingsPageItem::SettingItem(SettingItem {
466            title: "API URL",
467            description: "The API URL to use for Codestral.",
468            field: Box::new(SettingField {
469                pick: |settings| {
470                    settings
471                        .project
472                        .all_languages
473                        .edit_predictions
474                        .as_ref()?
475                        .codestral
476                        .as_ref()?
477                        .api_url
478                        .as_ref()
479                },
480                write: |settings, value| {
481                    settings
482                        .project
483                        .all_languages
484                        .edit_predictions
485                        .get_or_insert_default()
486                        .codestral
487                        .get_or_insert_default()
488                        .api_url = value;
489                },
490                json_path: Some("edit_predictions.codestral.api_url"),
491            }),
492            metadata: Some(Box::new(SettingsFieldMetadata {
493                placeholder: Some(CODESTRAL_API_URL),
494                ..Default::default()
495            })),
496            files: USER,
497        }),
498        SettingsPageItem::SettingItem(SettingItem {
499            title: "Max Tokens",
500            description: "The maximum number of tokens to generate.",
501            field: Box::new(SettingField {
502                pick: |settings| {
503                    settings
504                        .project
505                        .all_languages
506                        .edit_predictions
507                        .as_ref()?
508                        .codestral
509                        .as_ref()?
510                        .max_tokens
511                        .as_ref()
512                },
513                write: |settings, value| {
514                    settings
515                        .project
516                        .all_languages
517                        .edit_predictions
518                        .get_or_insert_default()
519                        .codestral
520                        .get_or_insert_default()
521                        .max_tokens = value;
522                },
523                json_path: Some("edit_predictions.codestral.max_tokens"),
524            }),
525            metadata: None,
526            files: USER,
527        }),
528        SettingsPageItem::SettingItem(SettingItem {
529            title: "Model",
530            description: "The Codestral model id to use.",
531            field: Box::new(SettingField {
532                pick: |settings| {
533                    settings
534                        .project
535                        .all_languages
536                        .edit_predictions
537                        .as_ref()?
538                        .codestral
539                        .as_ref()?
540                        .model
541                        .as_ref()
542                },
543                write: |settings, value| {
544                    settings
545                        .project
546                        .all_languages
547                        .edit_predictions
548                        .get_or_insert_default()
549                        .codestral
550                        .get_or_insert_default()
551                        .model = value;
552                },
553                json_path: Some("edit_predictions.codestral.model"),
554            }),
555            metadata: Some(Box::new(SettingsFieldMetadata {
556                placeholder: Some("codestral-latest"),
557                ..Default::default()
558            })),
559            files: USER,
560        }),
561    ])
562}
563
564fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
565    let configuration_view = window.use_state(cx, |_, cx| {
566        copilot_ui::ConfigurationView::new(
567            move |cx| {
568                if let Some(app_state) = AppState::global(cx).upgrade() {
569                    copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
570                        .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
571                } else {
572                    false
573                }
574            },
575            copilot_ui::ConfigurationMode::EditPrediction,
576            cx,
577        )
578    });
579
580    Some(
581        v_flex()
582            .id("github-copilot")
583            .min_w_0()
584            .pt_8()
585            .gap_1p5()
586            .child(
587                SettingsSectionHeader::new("GitHub Copilot")
588                    .icon(IconName::Copilot)
589                    .no_padding(true),
590            )
591            .child(configuration_view),
592    )
593}