edit_prediction_provider_setup.rs

  1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
  2use edit_prediction::{
  3    ApiKeyState,
  4    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  5    open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
  6    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  7};
  8use edit_prediction_ui::{get_available_providers, set_completion_provider};
  9use gpui::{Entity, ScrollHandle, prelude::*};
 10use language::language_settings::AllLanguageSettings;
 11
 12use settings::Settings as _;
 13use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 14use workspace::AppState;
 15
 16const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
 17const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 18
 19use crate::{
 20    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 21    components::{SettingsInputField, SettingsSectionHeader},
 22};
 23
 24pub(crate) fn render_edit_prediction_setup_page(
 25    settings_window: &SettingsWindow,
 26    scroll_handle: &ScrollHandle,
 27    window: &mut Window,
 28    cx: &mut Context<SettingsWindow>,
 29) -> AnyElement {
 30    let providers = [
 31        Some(render_provider_dropdown(window, cx)),
 32        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 33        Some(
 34            render_api_key_provider(
 35                IconName::Inception,
 36                "Mercury",
 37                ApiKeyDocs::Link {
 38                    dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 39                },
 40                mercury_api_token(cx),
 41                |_cx| MERCURY_CREDENTIALS_URL,
 42                None,
 43                window,
 44                cx,
 45            )
 46            .into_any_element(),
 47        ),
 48        Some(
 49            render_api_key_provider(
 50                IconName::SweepAi,
 51                "Sweep",
 52                ApiKeyDocs::Link {
 53                    dashboard_url: "https://app.sweep.dev/".into(),
 54                },
 55                sweep_api_token(cx),
 56                |_cx| SWEEP_CREDENTIALS_URL,
 57                Some(
 58                    settings_window
 59                        .render_sub_page_items_section(
 60                            sweep_settings().iter().enumerate(),
 61                            true,
 62                            window,
 63                            cx,
 64                        )
 65                        .into_any_element(),
 66                ),
 67                window,
 68                cx,
 69            )
 70            .into_any_element(),
 71        ),
 72        Some(
 73            render_api_key_provider(
 74                IconName::AiMistral,
 75                "Codestral",
 76                ApiKeyDocs::Link {
 77                    dashboard_url: "https://console.mistral.ai/codestral".into(),
 78                },
 79                codestral_api_key_state(cx),
 80                |cx| codestral_api_url(cx),
 81                Some(
 82                    settings_window
 83                        .render_sub_page_items_section(
 84                            codestral_settings().iter().enumerate(),
 85                            true,
 86                            window,
 87                            cx,
 88                        )
 89                        .into_any_element(),
 90                ),
 91                window,
 92                cx,
 93            )
 94            .into_any_element(),
 95        ),
 96        Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
 97        Some(
 98            render_api_key_provider(
 99                IconName::AiOpenAiCompat,
100                "OpenAI Compatible API",
101                ApiKeyDocs::Custom {
102                    message: "Set an API key here. It will be sent as Authorization: Bearer {key}."
103                        .into(),
104                },
105                open_ai_compatible_api_token(cx),
106                |cx| open_ai_compatible_api_url(cx),
107                Some(
108                    settings_window
109                        .render_sub_page_items_section(
110                            open_ai_compatible_settings().iter().enumerate(),
111                            true,
112                            window,
113                            cx,
114                        )
115                        .into_any_element(),
116                ),
117                window,
118                cx,
119            )
120            .into_any_element(),
121        ),
122    ];
123
124    div()
125        .size_full()
126        .child(
127            v_flex()
128                .id("ep-setup-page")
129                .min_w_0()
130                .size_full()
131                .px_8()
132                .pb_16()
133                .overflow_y_scroll()
134                .track_scroll(&scroll_handle)
135                .children(providers.into_iter().flatten()),
136        )
137        .into_any_element()
138}
139
140fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
141    let current_provider = AllLanguageSettings::get_global(cx)
142        .edit_predictions
143        .provider;
144    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
145
146    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
147        let available_providers = get_available_providers(cx);
148        let fs = <dyn fs::Fs>::global(cx);
149
150        for provider in available_providers {
151            let Some(name) = provider.display_name() else {
152                continue;
153            };
154            let is_current = provider == current_provider;
155
156            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
157                let fs = fs.clone();
158                move |_, cx| {
159                    set_completion_provider(fs.clone(), cx, provider);
160                }
161            });
162        }
163        menu
164    });
165
166    v_flex()
167        .id("provider-selector")
168        .min_w_0()
169        .gap_1p5()
170        .child(SettingsSectionHeader::new("Active Provider").no_padding(true))
171        .child(
172            h_flex()
173                .pt_2p5()
174                .w_full()
175                .justify_between()
176                .child(
177                    v_flex()
178                        .w_full()
179                        .max_w_1_2()
180                        .child(Label::new("Provider"))
181                        .child(
182                            Label::new("Select which provider to use for edit predictions.")
183                                .size(LabelSize::Small)
184                                .color(Color::Muted),
185                        ),
186                )
187                .child(
188                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
189                        .tab_index(0)
190                        .style(DropdownStyle::Outlined),
191                ),
192        )
193        .into_any_element()
194}
195
196enum ApiKeyDocs {
197    Link { dashboard_url: SharedString },
198    Custom { message: SharedString },
199}
200
201fn render_api_key_provider(
202    icon: IconName,
203    title: &'static str,
204    docs: ApiKeyDocs,
205    api_key_state: Entity<ApiKeyState>,
206    current_url: fn(&mut App) -> SharedString,
207    additional_fields: Option<AnyElement>,
208    window: &mut Window,
209    cx: &mut Context<SettingsWindow>,
210) -> impl IntoElement {
211    let weak_page = cx.weak_entity();
212    _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
213        let task = api_key_state.update(cx, |key_state, cx| {
214            key_state.load_if_needed(current_url(cx), |state| state, cx)
215        });
216        cx.spawn(async move |_, cx| {
217            task.await.ok();
218            weak_page
219                .update(cx, |_, cx| {
220                    cx.notify();
221                })
222                .ok();
223        })
224    });
225
226    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
227        (
228            state.has_key(),
229            Some(state.env_var_name().clone()),
230            state.is_from_env_var(),
231        )
232    });
233
234    let write_key = move |api_key: Option<String>, cx: &mut App| {
235        api_key_state
236            .update(cx, |key_state, cx| {
237                let url = current_url(cx);
238                key_state.store(url, api_key, |key_state| key_state, cx)
239            })
240            .detach_and_log_err(cx);
241    };
242
243    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
244    let header = SettingsSectionHeader::new(title)
245        .icon(icon)
246        .no_padding(true);
247    let button_link_label = format!("{} dashboard", title);
248    let description = match docs {
249        ApiKeyDocs::Custom { message } => h_flex().min_w_0().gap_0p5().child(
250            Label::new(message)
251                .size(LabelSize::Small)
252                .color(Color::Muted),
253        ),
254        ApiKeyDocs::Link { dashboard_url } => h_flex()
255            .min_w_0()
256            .gap_0p5()
257            .child(
258                Label::new("Visit the")
259                    .size(LabelSize::Small)
260                    .color(Color::Muted),
261            )
262            .child(
263                ButtonLink::new(button_link_label, dashboard_url)
264                    .no_icon(true)
265                    .label_size(LabelSize::Small)
266                    .label_color(Color::Muted),
267            )
268            .child(
269                Label::new("to generate an API key.")
270                    .size(LabelSize::Small)
271                    .color(Color::Muted),
272            ),
273    };
274    let configured_card_label = if is_from_env_var {
275        "API Key Set in Environment Variable"
276    } else {
277        "API Key Configured"
278    };
279
280    let container = if has_key {
281        base_container.child(header).child(
282            ConfiguredApiCard::new(configured_card_label)
283                .button_label("Reset Key")
284                .button_tab_index(0)
285                .disabled(is_from_env_var)
286                .when_some(env_var_name, |this, env_var_name| {
287                    this.when(is_from_env_var, |this| {
288                        this.tooltip_label(format!(
289                            "To reset your API key, unset the {} environment variable.",
290                            env_var_name
291                        ))
292                    })
293                })
294                .on_click(move |_, _, cx| {
295                    write_key(None, cx);
296                }),
297        )
298    } else {
299        base_container.child(header).child(
300            h_flex()
301                .pt_2p5()
302                .w_full()
303                .justify_between()
304                .child(
305                    v_flex()
306                        .w_full()
307                        .max_w_1_2()
308                        .child(Label::new("API Key"))
309                        .child(description)
310                        .when_some(env_var_name, |this, env_var_name| {
311                            this.child({
312                                let label = format!(
313                                    "Or set the {} env var and restart Zed.",
314                                    env_var_name.as_ref()
315                                );
316                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
317                            })
318                        }),
319                )
320                .child(
321                    SettingsInputField::new()
322                        .tab_index(0)
323                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
324                        .on_confirm(move |api_key, _window, cx| {
325                            write_key(api_key.filter(|key| !key.is_empty()), cx);
326                        }),
327                ),
328        )
329    };
330
331    container.when_some(additional_fields, |this, additional_fields| {
332        this.child(
333            div()
334                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
335                .px_neg_8()
336                .border_t_1()
337                .border_color(cx.theme().colors().border_variant)
338                .child(additional_fields),
339        )
340    })
341}
342
343fn sweep_settings() -> Box<[SettingsPageItem]> {
344    Box::new([SettingsPageItem::SettingItem(SettingItem {
345        title: "Privacy Mode",
346        description: "When enabled, Sweep will not store edit prediction inputs or outputs. When disabled, Sweep may collect data including buffer contents, diagnostics, file paths, and generated predictions to improve the service.",
347        field: Box::new(SettingField {
348            pick: |settings| {
349                settings
350                    .project
351                    .all_languages
352                    .edit_predictions
353                    .as_ref()?
354                    .sweep
355                    .as_ref()?
356                    .privacy_mode
357                    .as_ref()
358            },
359            write: |settings, value| {
360                settings
361                    .project
362                    .all_languages
363                    .edit_predictions
364                    .get_or_insert_default()
365                    .sweep
366                    .get_or_insert_default()
367                    .privacy_mode = value;
368            },
369            json_path: Some("edit_predictions.sweep.privacy_mode"),
370        }),
371        metadata: None,
372        files: USER,
373    })])
374}
375
376fn render_ollama_provider(
377    settings_window: &SettingsWindow,
378    window: &mut Window,
379    cx: &mut Context<SettingsWindow>,
380) -> impl IntoElement {
381    let ollama_settings = ollama_settings();
382    let additional_fields = settings_window
383        .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
384        .into_any_element();
385
386    v_flex()
387        .id("ollama")
388        .min_w_0()
389        .pt_8()
390        .gap_1p5()
391        .child(
392            SettingsSectionHeader::new("Ollama")
393                .icon(IconName::AiOllama)
394                .no_padding(true),
395        )
396        .child(div().px_neg_8().child(additional_fields))
397}
398
399fn ollama_settings() -> Box<[SettingsPageItem]> {
400    Box::new([
401        SettingsPageItem::SettingItem(SettingItem {
402            title: "API URL",
403            description: "The base URL of your Ollama server.",
404            field: Box::new(SettingField {
405                pick: |settings| {
406                    settings
407                        .project
408                        .all_languages
409                        .edit_predictions
410                        .as_ref()?
411                        .ollama
412                        .as_ref()?
413                        .api_url
414                        .as_ref()
415                },
416                write: |settings, value| {
417                    settings
418                        .project
419                        .all_languages
420                        .edit_predictions
421                        .get_or_insert_default()
422                        .ollama
423                        .get_or_insert_default()
424                        .api_url = value;
425                },
426                json_path: Some("edit_predictions.ollama.api_url"),
427            }),
428            metadata: Some(Box::new(SettingsFieldMetadata {
429                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
430                ..Default::default()
431            })),
432            files: USER,
433        }),
434        SettingsPageItem::SettingItem(SettingItem {
435            title: "Model",
436            description: "The Ollama model to use for edit predictions.",
437            field: Box::new(SettingField {
438                pick: |settings| {
439                    settings
440                        .project
441                        .all_languages
442                        .edit_predictions
443                        .as_ref()?
444                        .ollama
445                        .as_ref()?
446                        .model
447                        .as_ref()
448                },
449                write: |settings, value| {
450                    settings
451                        .project
452                        .all_languages
453                        .edit_predictions
454                        .get_or_insert_default()
455                        .ollama
456                        .get_or_insert_default()
457                        .model = value;
458                },
459                json_path: Some("edit_predictions.ollama.model"),
460            }),
461            metadata: Some(Box::new(SettingsFieldMetadata {
462                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
463                ..Default::default()
464            })),
465            files: USER,
466        }),
467        SettingsPageItem::SettingItem(SettingItem {
468            title: "Prompt Format",
469            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name",
470            field: Box::new(SettingField {
471                pick: |settings| {
472                    settings
473                        .project
474                        .all_languages
475                        .edit_predictions
476                        .as_ref()?
477                        .ollama
478                        .as_ref()?
479                        .prompt_format
480                        .as_ref()
481                },
482                write: |settings, value| {
483                    settings
484                        .project
485                        .all_languages
486                        .edit_predictions
487                        .get_or_insert_default()
488                        .ollama
489                        .get_or_insert_default()
490                        .prompt_format = value;
491                },
492                json_path: Some("edit_predictions.ollama.prompt_format"),
493            }),
494            files: USER,
495            metadata: None,
496        }),
497        SettingsPageItem::SettingItem(SettingItem {
498            title: "Max Output Tokens",
499            description: "The maximum number of tokens to generate.",
500            field: Box::new(SettingField {
501                pick: |settings| {
502                    settings
503                        .project
504                        .all_languages
505                        .edit_predictions
506                        .as_ref()?
507                        .ollama
508                        .as_ref()?
509                        .max_output_tokens
510                        .as_ref()
511                },
512                write: |settings, value| {
513                    settings
514                        .project
515                        .all_languages
516                        .edit_predictions
517                        .get_or_insert_default()
518                        .ollama
519                        .get_or_insert_default()
520                        .max_output_tokens = value;
521                },
522                json_path: Some("edit_predictions.ollama.max_output_tokens"),
523            }),
524            metadata: None,
525            files: USER,
526        }),
527    ])
528}
529
530fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
531    Box::new([
532        SettingsPageItem::SettingItem(SettingItem {
533            title: "API URL",
534            description: "The URL of your OpenAI-compatible server's completions API.",
535            field: Box::new(SettingField {
536                pick: |settings| {
537                    settings
538                        .project
539                        .all_languages
540                        .edit_predictions
541                        .as_ref()?
542                        .open_ai_compatible_api
543                        .as_ref()?
544                        .api_url
545                        .as_ref()
546                },
547                write: |settings, value| {
548                    settings
549                        .project
550                        .all_languages
551                        .edit_predictions
552                        .get_or_insert_default()
553                        .open_ai_compatible_api
554                        .get_or_insert_default()
555                        .api_url = value;
556                },
557                json_path: Some("edit_predictions.open_ai_compatible_api.api_url"),
558            }),
559            metadata: Some(Box::new(SettingsFieldMetadata {
560                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
561                ..Default::default()
562            })),
563            files: USER,
564        }),
565        SettingsPageItem::SettingItem(SettingItem {
566            title: "Model",
567            description: "The model string to pass to the OpenAI-compatible server.",
568            field: Box::new(SettingField {
569                pick: |settings| {
570                    settings
571                        .project
572                        .all_languages
573                        .edit_predictions
574                        .as_ref()?
575                        .open_ai_compatible_api
576                        .as_ref()?
577                        .model
578                        .as_ref()
579                },
580                write: |settings, value| {
581                    settings
582                        .project
583                        .all_languages
584                        .edit_predictions
585                        .get_or_insert_default()
586                        .open_ai_compatible_api
587                        .get_or_insert_default()
588                        .model = value;
589                },
590                json_path: Some("edit_predictions.open_ai_compatible_api.model"),
591            }),
592            metadata: Some(Box::new(SettingsFieldMetadata {
593                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
594                ..Default::default()
595            })),
596            files: USER,
597        }),
598        SettingsPageItem::SettingItem(SettingItem {
599            title: "Prompt Format",
600            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name",
601            field: Box::new(SettingField {
602                pick: |settings| {
603                    settings
604                        .project
605                        .all_languages
606                        .edit_predictions
607                        .as_ref()?
608                        .open_ai_compatible_api
609                        .as_ref()?
610                        .prompt_format
611                        .as_ref()
612                },
613                write: |settings, value| {
614                    settings
615                        .project
616                        .all_languages
617                        .edit_predictions
618                        .get_or_insert_default()
619                        .open_ai_compatible_api
620                        .get_or_insert_default()
621                        .prompt_format = value;
622                },
623                json_path: Some("edit_predictions.open_ai_compatible_api.prompt_format"),
624            }),
625            files: USER,
626            metadata: None,
627        }),
628        SettingsPageItem::SettingItem(SettingItem {
629            title: "Max Output Tokens",
630            description: "The maximum number of tokens to generate.",
631            field: Box::new(SettingField {
632                pick: |settings| {
633                    settings
634                        .project
635                        .all_languages
636                        .edit_predictions
637                        .as_ref()?
638                        .open_ai_compatible_api
639                        .as_ref()?
640                        .max_output_tokens
641                        .as_ref()
642                },
643                write: |settings, value| {
644                    settings
645                        .project
646                        .all_languages
647                        .edit_predictions
648                        .get_or_insert_default()
649                        .open_ai_compatible_api
650                        .get_or_insert_default()
651                        .max_output_tokens = value;
652                },
653                json_path: Some("edit_predictions.open_ai_compatible_api.max_output_tokens"),
654            }),
655            metadata: None,
656            files: USER,
657        }),
658    ])
659}
660
661fn codestral_settings() -> Box<[SettingsPageItem]> {
662    Box::new([
663        SettingsPageItem::SettingItem(SettingItem {
664            title: "API URL",
665            description: "The API URL to use for Codestral.",
666            field: Box::new(SettingField {
667                pick: |settings| {
668                    settings
669                        .project
670                        .all_languages
671                        .edit_predictions
672                        .as_ref()?
673                        .codestral
674                        .as_ref()?
675                        .api_url
676                        .as_ref()
677                },
678                write: |settings, value| {
679                    settings
680                        .project
681                        .all_languages
682                        .edit_predictions
683                        .get_or_insert_default()
684                        .codestral
685                        .get_or_insert_default()
686                        .api_url = value;
687                },
688                json_path: Some("edit_predictions.codestral.api_url"),
689            }),
690            metadata: Some(Box::new(SettingsFieldMetadata {
691                placeholder: Some(CODESTRAL_API_URL),
692                ..Default::default()
693            })),
694            files: USER,
695        }),
696        SettingsPageItem::SettingItem(SettingItem {
697            title: "Max Tokens",
698            description: "The maximum number of tokens to generate.",
699            field: Box::new(SettingField {
700                pick: |settings| {
701                    settings
702                        .project
703                        .all_languages
704                        .edit_predictions
705                        .as_ref()?
706                        .codestral
707                        .as_ref()?
708                        .max_tokens
709                        .as_ref()
710                },
711                write: |settings, value| {
712                    settings
713                        .project
714                        .all_languages
715                        .edit_predictions
716                        .get_or_insert_default()
717                        .codestral
718                        .get_or_insert_default()
719                        .max_tokens = value;
720                },
721                json_path: Some("edit_predictions.codestral.max_tokens"),
722            }),
723            metadata: None,
724            files: USER,
725        }),
726        SettingsPageItem::SettingItem(SettingItem {
727            title: "Model",
728            description: "The Codestral model id to use.",
729            field: Box::new(SettingField {
730                pick: |settings| {
731                    settings
732                        .project
733                        .all_languages
734                        .edit_predictions
735                        .as_ref()?
736                        .codestral
737                        .as_ref()?
738                        .model
739                        .as_ref()
740                },
741                write: |settings, value| {
742                    settings
743                        .project
744                        .all_languages
745                        .edit_predictions
746                        .get_or_insert_default()
747                        .codestral
748                        .get_or_insert_default()
749                        .model = value;
750                },
751                json_path: Some("edit_predictions.codestral.model"),
752            }),
753            metadata: Some(Box::new(SettingsFieldMetadata {
754                placeholder: Some("codestral-latest"),
755                ..Default::default()
756            })),
757            files: USER,
758        }),
759    ])
760}
761
762fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
763    let configuration_view = window.use_state(cx, |_, cx| {
764        copilot_ui::ConfigurationView::new(
765            move |cx| {
766                if let Some(app_state) = AppState::global(cx).upgrade() {
767                    copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
768                        .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
769                } else {
770                    false
771                }
772            },
773            copilot_ui::ConfigurationMode::EditPrediction,
774            cx,
775        )
776    });
777
778    Some(
779        v_flex()
780            .id("github-copilot")
781            .min_w_0()
782            .pt_8()
783            .gap_1p5()
784            .child(
785                SettingsSectionHeader::new("GitHub Copilot")
786                    .icon(IconName::Copilot)
787                    .no_padding(true),
788            )
789            .child(configuration_view),
790    )
791}