edit_prediction_provider_setup.rs

  1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
  2use edit_prediction::{
  3    ApiKeyState,
  4    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  5    open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
  6    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  7};
  8use edit_prediction_ui::{get_available_providers, set_completion_provider};
  9use gpui::{Entity, ScrollHandle, prelude::*};
 10use language::language_settings::AllLanguageSettings;
 11
 12use settings::Settings as _;
 13use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 14use workspace::AppState;
 15
 16const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
 17const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 18
 19use crate::{
 20    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 21    components::{SettingsInputField, SettingsSectionHeader},
 22};
 23
 24pub(crate) fn render_edit_prediction_setup_page(
 25    settings_window: &SettingsWindow,
 26    scroll_handle: &ScrollHandle,
 27    window: &mut Window,
 28    cx: &mut Context<SettingsWindow>,
 29) -> AnyElement {
 30    let providers = [
 31        Some(render_provider_dropdown(window, cx)),
 32        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 33        Some(
 34            render_api_key_provider(
 35                IconName::Inception,
 36                "Mercury",
 37                ApiKeyDocs::Link {
 38                    dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 39                },
 40                mercury_api_token(cx),
 41                |_cx| MERCURY_CREDENTIALS_URL,
 42                None,
 43                window,
 44                cx,
 45            )
 46            .into_any_element(),
 47        ),
 48        Some(
 49            render_api_key_provider(
 50                IconName::SweepAi,
 51                "Sweep",
 52                ApiKeyDocs::Link {
 53                    dashboard_url: "https://app.sweep.dev/".into(),
 54                },
 55                sweep_api_token(cx),
 56                |_cx| SWEEP_CREDENTIALS_URL,
 57                Some(
 58                    settings_window
 59                        .render_sub_page_items_section(
 60                            sweep_settings().iter().enumerate(),
 61                            true,
 62                            window,
 63                            cx,
 64                        )
 65                        .into_any_element(),
 66                ),
 67                window,
 68                cx,
 69            )
 70            .into_any_element(),
 71        ),
 72        Some(
 73            render_api_key_provider(
 74                IconName::AiMistral,
 75                "Codestral",
 76                ApiKeyDocs::Link {
 77                    dashboard_url: "https://console.mistral.ai/codestral".into(),
 78                },
 79                codestral_api_key_state(cx),
 80                |cx| codestral_api_url(cx),
 81                Some(
 82                    settings_window
 83                        .render_sub_page_items_section(
 84                            codestral_settings().iter().enumerate(),
 85                            true,
 86                            window,
 87                            cx,
 88                        )
 89                        .into_any_element(),
 90                ),
 91                window,
 92                cx,
 93            )
 94            .into_any_element(),
 95        ),
 96        Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
 97        Some(
 98            render_api_key_provider(
 99                IconName::AiOpenAiCompat,
100                "OpenAI Compatible API",
101                ApiKeyDocs::Custom {
102                    message: "The API key sent as Authorization: Bearer {key}.".into(),
103                },
104                open_ai_compatible_api_token(cx),
105                |cx| open_ai_compatible_api_url(cx),
106                Some(
107                    settings_window
108                        .render_sub_page_items_section(
109                            open_ai_compatible_settings().iter().enumerate(),
110                            true,
111                            window,
112                            cx,
113                        )
114                        .into_any_element(),
115                ),
116                window,
117                cx,
118            )
119            .into_any_element(),
120        ),
121    ];
122
123    div()
124        .size_full()
125        .child(
126            v_flex()
127                .id("ep-setup-page")
128                .min_w_0()
129                .size_full()
130                .px_8()
131                .pb_16()
132                .overflow_y_scroll()
133                .track_scroll(&scroll_handle)
134                .children(providers.into_iter().flatten()),
135        )
136        .into_any_element()
137}
138
139fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
140    let current_provider = AllLanguageSettings::get_global(cx)
141        .edit_predictions
142        .provider;
143    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
144
145    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
146        let available_providers = get_available_providers(cx);
147        let fs = <dyn fs::Fs>::global(cx);
148
149        for provider in available_providers {
150            let Some(name) = provider.display_name() else {
151                continue;
152            };
153            let is_current = provider == current_provider;
154
155            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
156                let fs = fs.clone();
157                move |_, cx| {
158                    set_completion_provider(fs.clone(), cx, provider);
159                }
160            });
161        }
162        menu
163    });
164
165    v_flex()
166        .id("provider-selector")
167        .min_w_0()
168        .gap_1p5()
169        .child(SettingsSectionHeader::new("Active Provider").no_padding(true))
170        .child(
171            h_flex()
172                .pt_2p5()
173                .w_full()
174                .min_w_0()
175                .justify_between()
176                .child(
177                    v_flex()
178                        .w_full()
179                        .min_w_0()
180                        .max_w_1_2()
181                        .child(Label::new("Provider"))
182                        .child(
183                            Label::new("Select which provider to use for edit predictions.")
184                                .size(LabelSize::Small)
185                                .color(Color::Muted),
186                        ),
187                )
188                .child(
189                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
190                        .tab_index(0)
191                        .style(DropdownStyle::Outlined),
192                ),
193        )
194        .into_any_element()
195}
196
197enum ApiKeyDocs {
198    Link { dashboard_url: SharedString },
199    Custom { message: SharedString },
200}
201
202fn render_api_key_provider(
203    icon: IconName,
204    title: &'static str,
205    docs: ApiKeyDocs,
206    api_key_state: Entity<ApiKeyState>,
207    current_url: fn(&mut App) -> SharedString,
208    additional_fields: Option<AnyElement>,
209    window: &mut Window,
210    cx: &mut Context<SettingsWindow>,
211) -> impl IntoElement {
212    let weak_page = cx.weak_entity();
213    _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
214        let task = api_key_state.update(cx, |key_state, cx| {
215            key_state.load_if_needed(current_url(cx), |state| state, cx)
216        });
217        cx.spawn(async move |_, cx| {
218            task.await.ok();
219            weak_page
220                .update(cx, |_, cx| {
221                    cx.notify();
222                })
223                .ok();
224        })
225    });
226
227    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
228        (
229            state.has_key(),
230            Some(state.env_var_name().clone()),
231            state.is_from_env_var(),
232        )
233    });
234
235    let write_key = move |api_key: Option<String>, cx: &mut App| {
236        api_key_state
237            .update(cx, |key_state, cx| {
238                let url = current_url(cx);
239                key_state.store(url, api_key, |key_state| key_state, cx)
240            })
241            .detach_and_log_err(cx);
242    };
243
244    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
245    let header = SettingsSectionHeader::new(title)
246        .icon(icon)
247        .no_padding(true);
248    let button_link_label = format!("{} dashboard", title);
249    let description = match docs {
250        ApiKeyDocs::Custom { message } => div().min_w_0().w_full().child(
251            Label::new(message)
252                .size(LabelSize::Small)
253                .color(Color::Muted),
254        ),
255        ApiKeyDocs::Link { dashboard_url } => h_flex()
256            .w_full()
257            .min_w_0()
258            .flex_wrap()
259            .gap_0p5()
260            .child(
261                Label::new("Visit the")
262                    .size(LabelSize::Small)
263                    .color(Color::Muted),
264            )
265            .child(
266                ButtonLink::new(button_link_label, dashboard_url)
267                    .no_icon(true)
268                    .label_size(LabelSize::Small)
269                    .label_color(Color::Muted),
270            )
271            .child(
272                Label::new("to generate an API key.")
273                    .size(LabelSize::Small)
274                    .color(Color::Muted),
275            ),
276    };
277    let configured_card_label = if is_from_env_var {
278        "API Key Set in Environment Variable"
279    } else {
280        "API Key Configured"
281    };
282
283    let container = if has_key {
284        base_container.child(header).child(
285            ConfiguredApiCard::new(configured_card_label)
286                .button_label("Reset Key")
287                .button_tab_index(0)
288                .disabled(is_from_env_var)
289                .when_some(env_var_name, |this, env_var_name| {
290                    this.when(is_from_env_var, |this| {
291                        this.tooltip_label(format!(
292                            "To reset your API key, unset the {} environment variable.",
293                            env_var_name
294                        ))
295                    })
296                })
297                .on_click(move |_, _, cx| {
298                    write_key(None, cx);
299                }),
300        )
301    } else {
302        base_container.child(header).child(
303            h_flex()
304                .pt_2p5()
305                .w_full()
306                .min_w_0()
307                .justify_between()
308                .child(
309                    v_flex()
310                        .w_full()
311                        .min_w_0()
312                        .max_w_1_2()
313                        .child(Label::new("API Key"))
314                        .child(description)
315                        .when_some(env_var_name, |this, env_var_name| {
316                            this.child({
317                                let label = format!(
318                                    "Or set the {} env var and restart Zed.",
319                                    env_var_name.as_ref()
320                                );
321                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
322                            })
323                        }),
324                )
325                .child(
326                    SettingsInputField::new()
327                        .tab_index(0)
328                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
329                        .on_confirm(move |api_key, _window, cx| {
330                            write_key(api_key.filter(|key| !key.is_empty()), cx);
331                        }),
332                ),
333        )
334    };
335
336    container.when_some(additional_fields, |this, additional_fields| {
337        this.child(
338            div()
339                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
340                .px_neg_8()
341                .border_t_1()
342                .border_color(cx.theme().colors().border_variant)
343                .child(additional_fields),
344        )
345    })
346}
347
348fn sweep_settings() -> Box<[SettingsPageItem]> {
349    Box::new([SettingsPageItem::SettingItem(SettingItem {
350        title: "Privacy Mode",
351        description: "When enabled, Sweep will not store edit prediction inputs or outputs. When disabled, Sweep may collect data including buffer contents, diagnostics, file paths, and generated predictions to improve the service.",
352        field: Box::new(SettingField {
353            pick: |settings| {
354                settings
355                    .project
356                    .all_languages
357                    .edit_predictions
358                    .as_ref()?
359                    .sweep
360                    .as_ref()?
361                    .privacy_mode
362                    .as_ref()
363            },
364            write: |settings, value| {
365                settings
366                    .project
367                    .all_languages
368                    .edit_predictions
369                    .get_or_insert_default()
370                    .sweep
371                    .get_or_insert_default()
372                    .privacy_mode = value;
373            },
374            json_path: Some("edit_predictions.sweep.privacy_mode"),
375        }),
376        metadata: None,
377        files: USER,
378    })])
379}
380
381fn render_ollama_provider(
382    settings_window: &SettingsWindow,
383    window: &mut Window,
384    cx: &mut Context<SettingsWindow>,
385) -> impl IntoElement {
386    let ollama_settings = ollama_settings();
387    let additional_fields = settings_window
388        .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
389        .into_any_element();
390
391    v_flex()
392        .id("ollama")
393        .min_w_0()
394        .pt_8()
395        .gap_1p5()
396        .child(
397            SettingsSectionHeader::new("Ollama")
398                .icon(IconName::AiOllama)
399                .no_padding(true),
400        )
401        .child(div().px_neg_8().child(additional_fields))
402}
403
404fn ollama_settings() -> Box<[SettingsPageItem]> {
405    Box::new([
406        SettingsPageItem::SettingItem(SettingItem {
407            title: "API URL",
408            description: "The base URL of your Ollama server.",
409            field: Box::new(SettingField {
410                pick: |settings| {
411                    settings
412                        .project
413                        .all_languages
414                        .edit_predictions
415                        .as_ref()?
416                        .ollama
417                        .as_ref()?
418                        .api_url
419                        .as_ref()
420                },
421                write: |settings, value| {
422                    settings
423                        .project
424                        .all_languages
425                        .edit_predictions
426                        .get_or_insert_default()
427                        .ollama
428                        .get_or_insert_default()
429                        .api_url = value;
430                },
431                json_path: Some("edit_predictions.ollama.api_url"),
432            }),
433            metadata: Some(Box::new(SettingsFieldMetadata {
434                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
435                ..Default::default()
436            })),
437            files: USER,
438        }),
439        SettingsPageItem::SettingItem(SettingItem {
440            title: "Model",
441            description: "The Ollama model to use for edit predictions.",
442            field: Box::new(SettingField {
443                pick: |settings| {
444                    settings
445                        .project
446                        .all_languages
447                        .edit_predictions
448                        .as_ref()?
449                        .ollama
450                        .as_ref()?
451                        .model
452                        .as_ref()
453                },
454                write: |settings, value| {
455                    settings
456                        .project
457                        .all_languages
458                        .edit_predictions
459                        .get_or_insert_default()
460                        .ollama
461                        .get_or_insert_default()
462                        .model = value;
463                },
464                json_path: Some("edit_predictions.ollama.model"),
465            }),
466            metadata: Some(Box::new(SettingsFieldMetadata {
467                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
468                ..Default::default()
469            })),
470            files: USER,
471        }),
472        SettingsPageItem::SettingItem(SettingItem {
473            title: "Prompt Format",
474            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
475            field: Box::new(SettingField {
476                pick: |settings| {
477                    settings
478                        .project
479                        .all_languages
480                        .edit_predictions
481                        .as_ref()?
482                        .ollama
483                        .as_ref()?
484                        .prompt_format
485                        .as_ref()
486                },
487                write: |settings, value| {
488                    settings
489                        .project
490                        .all_languages
491                        .edit_predictions
492                        .get_or_insert_default()
493                        .ollama
494                        .get_or_insert_default()
495                        .prompt_format = value;
496                },
497                json_path: Some("edit_predictions.ollama.prompt_format"),
498            }),
499            files: USER,
500            metadata: None,
501        }),
502        SettingsPageItem::SettingItem(SettingItem {
503            title: "Max Output Tokens",
504            description: "The maximum number of tokens to generate.",
505            field: Box::new(SettingField {
506                pick: |settings| {
507                    settings
508                        .project
509                        .all_languages
510                        .edit_predictions
511                        .as_ref()?
512                        .ollama
513                        .as_ref()?
514                        .max_output_tokens
515                        .as_ref()
516                },
517                write: |settings, value| {
518                    settings
519                        .project
520                        .all_languages
521                        .edit_predictions
522                        .get_or_insert_default()
523                        .ollama
524                        .get_or_insert_default()
525                        .max_output_tokens = value;
526                },
527                json_path: Some("edit_predictions.ollama.max_output_tokens"),
528            }),
529            metadata: None,
530            files: USER,
531        }),
532    ])
533}
534
535fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
536    Box::new([
537        SettingsPageItem::SettingItem(SettingItem {
538            title: "API URL",
539            description: "The URL of your OpenAI-compatible server's completions API.",
540            field: Box::new(SettingField {
541                pick: |settings| {
542                    settings
543                        .project
544                        .all_languages
545                        .edit_predictions
546                        .as_ref()?
547                        .open_ai_compatible_api
548                        .as_ref()?
549                        .api_url
550                        .as_ref()
551                },
552                write: |settings, value| {
553                    settings
554                        .project
555                        .all_languages
556                        .edit_predictions
557                        .get_or_insert_default()
558                        .open_ai_compatible_api
559                        .get_or_insert_default()
560                        .api_url = value;
561                },
562                json_path: Some("edit_predictions.open_ai_compatible_api.api_url"),
563            }),
564            metadata: Some(Box::new(SettingsFieldMetadata {
565                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
566                ..Default::default()
567            })),
568            files: USER,
569        }),
570        SettingsPageItem::SettingItem(SettingItem {
571            title: "Model",
572            description: "The model string to pass to the OpenAI-compatible server.",
573            field: Box::new(SettingField {
574                pick: |settings| {
575                    settings
576                        .project
577                        .all_languages
578                        .edit_predictions
579                        .as_ref()?
580                        .open_ai_compatible_api
581                        .as_ref()?
582                        .model
583                        .as_ref()
584                },
585                write: |settings, value| {
586                    settings
587                        .project
588                        .all_languages
589                        .edit_predictions
590                        .get_or_insert_default()
591                        .open_ai_compatible_api
592                        .get_or_insert_default()
593                        .model = value;
594                },
595                json_path: Some("edit_predictions.open_ai_compatible_api.model"),
596            }),
597            metadata: Some(Box::new(SettingsFieldMetadata {
598                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
599                ..Default::default()
600            })),
601            files: USER,
602        }),
603        SettingsPageItem::SettingItem(SettingItem {
604            title: "Prompt Format",
605            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
606            field: Box::new(SettingField {
607                pick: |settings| {
608                    settings
609                        .project
610                        .all_languages
611                        .edit_predictions
612                        .as_ref()?
613                        .open_ai_compatible_api
614                        .as_ref()?
615                        .prompt_format
616                        .as_ref()
617                },
618                write: |settings, value| {
619                    settings
620                        .project
621                        .all_languages
622                        .edit_predictions
623                        .get_or_insert_default()
624                        .open_ai_compatible_api
625                        .get_or_insert_default()
626                        .prompt_format = value;
627                },
628                json_path: Some("edit_predictions.open_ai_compatible_api.prompt_format"),
629            }),
630            files: USER,
631            metadata: None,
632        }),
633        SettingsPageItem::SettingItem(SettingItem {
634            title: "Max Output Tokens",
635            description: "The maximum number of tokens to generate.",
636            field: Box::new(SettingField {
637                pick: |settings| {
638                    settings
639                        .project
640                        .all_languages
641                        .edit_predictions
642                        .as_ref()?
643                        .open_ai_compatible_api
644                        .as_ref()?
645                        .max_output_tokens
646                        .as_ref()
647                },
648                write: |settings, value| {
649                    settings
650                        .project
651                        .all_languages
652                        .edit_predictions
653                        .get_or_insert_default()
654                        .open_ai_compatible_api
655                        .get_or_insert_default()
656                        .max_output_tokens = value;
657                },
658                json_path: Some("edit_predictions.open_ai_compatible_api.max_output_tokens"),
659            }),
660            metadata: None,
661            files: USER,
662        }),
663    ])
664}
665
666fn codestral_settings() -> Box<[SettingsPageItem]> {
667    Box::new([
668        SettingsPageItem::SettingItem(SettingItem {
669            title: "API URL",
670            description: "The API URL to use for Codestral.",
671            field: Box::new(SettingField {
672                pick: |settings| {
673                    settings
674                        .project
675                        .all_languages
676                        .edit_predictions
677                        .as_ref()?
678                        .codestral
679                        .as_ref()?
680                        .api_url
681                        .as_ref()
682                },
683                write: |settings, value| {
684                    settings
685                        .project
686                        .all_languages
687                        .edit_predictions
688                        .get_or_insert_default()
689                        .codestral
690                        .get_or_insert_default()
691                        .api_url = value;
692                },
693                json_path: Some("edit_predictions.codestral.api_url"),
694            }),
695            metadata: Some(Box::new(SettingsFieldMetadata {
696                placeholder: Some(CODESTRAL_API_URL),
697                ..Default::default()
698            })),
699            files: USER,
700        }),
701        SettingsPageItem::SettingItem(SettingItem {
702            title: "Max Tokens",
703            description: "The maximum number of tokens to generate.",
704            field: Box::new(SettingField {
705                pick: |settings| {
706                    settings
707                        .project
708                        .all_languages
709                        .edit_predictions
710                        .as_ref()?
711                        .codestral
712                        .as_ref()?
713                        .max_tokens
714                        .as_ref()
715                },
716                write: |settings, value| {
717                    settings
718                        .project
719                        .all_languages
720                        .edit_predictions
721                        .get_or_insert_default()
722                        .codestral
723                        .get_or_insert_default()
724                        .max_tokens = value;
725                },
726                json_path: Some("edit_predictions.codestral.max_tokens"),
727            }),
728            metadata: None,
729            files: USER,
730        }),
731        SettingsPageItem::SettingItem(SettingItem {
732            title: "Model",
733            description: "The Codestral model id to use.",
734            field: Box::new(SettingField {
735                pick: |settings| {
736                    settings
737                        .project
738                        .all_languages
739                        .edit_predictions
740                        .as_ref()?
741                        .codestral
742                        .as_ref()?
743                        .model
744                        .as_ref()
745                },
746                write: |settings, value| {
747                    settings
748                        .project
749                        .all_languages
750                        .edit_predictions
751                        .get_or_insert_default()
752                        .codestral
753                        .get_or_insert_default()
754                        .model = value;
755                },
756                json_path: Some("edit_predictions.codestral.model"),
757            }),
758            metadata: Some(Box::new(SettingsFieldMetadata {
759                placeholder: Some("codestral-latest"),
760                ..Default::default()
761            })),
762            files: USER,
763        }),
764    ])
765}
766
767fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
768    let configuration_view = window.use_state(cx, |_, cx| {
769        copilot_ui::ConfigurationView::new(
770            move |cx| {
771                if let Some(app_state) = AppState::global(cx).upgrade() {
772                    copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
773                        .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
774                } else {
775                    false
776                }
777            },
778            copilot_ui::ConfigurationMode::EditPrediction,
779            cx,
780        )
781    });
782
783    Some(
784        v_flex()
785            .id("github-copilot")
786            .min_w_0()
787            .pt_8()
788            .gap_1p5()
789            .child(
790                SettingsSectionHeader::new("GitHub Copilot")
791                    .icon(IconName::Copilot)
792                    .no_padding(true),
793            )
794            .child(configuration_view),
795    )
796}