edit_prediction_provider_setup.rs

  1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
  2use edit_prediction::{
  3    ApiKeyState,
  4    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  5    open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
  6};
  7use edit_prediction_ui::{get_available_providers, set_completion_provider};
  8use gpui::{Entity, ScrollHandle, prelude::*};
  9use language::language_settings::AllLanguageSettings;
 10
 11use settings::Settings as _;
 12use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 13use workspace::AppState;
 14
 15const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
 16const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 17
 18use crate::{
 19    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 20    components::{SettingsInputField, SettingsSectionHeader},
 21};
 22
 23pub(crate) fn render_edit_prediction_setup_page(
 24    settings_window: &SettingsWindow,
 25    scroll_handle: &ScrollHandle,
 26    window: &mut Window,
 27    cx: &mut Context<SettingsWindow>,
 28) -> AnyElement {
 29    let providers = [
 30        Some(render_provider_dropdown(window, cx)),
 31        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 32        Some(
 33            render_api_key_provider(
 34                IconName::Inception,
 35                "Mercury",
 36                ApiKeyDocs::Link {
 37                    dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 38                },
 39                mercury_api_token(cx),
 40                |_cx| MERCURY_CREDENTIALS_URL,
 41                None,
 42                window,
 43                cx,
 44            )
 45            .into_any_element(),
 46        ),
 47        Some(
 48            render_api_key_provider(
 49                IconName::AiMistral,
 50                "Codestral",
 51                ApiKeyDocs::Link {
 52                    dashboard_url: "https://console.mistral.ai/codestral".into(),
 53                },
 54                codestral_api_key_state(cx),
 55                |cx| codestral_api_url(cx),
 56                Some(
 57                    settings_window
 58                        .render_sub_page_items_section(
 59                            codestral_settings().iter().enumerate(),
 60                            true,
 61                            window,
 62                            cx,
 63                        )
 64                        .into_any_element(),
 65                ),
 66                window,
 67                cx,
 68            )
 69            .into_any_element(),
 70        ),
 71        Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
 72        Some(
 73            render_api_key_provider(
 74                IconName::AiOpenAiCompat,
 75                "OpenAI Compatible API",
 76                ApiKeyDocs::Custom {
 77                    message: "The API key sent as Authorization: Bearer {key}.".into(),
 78                },
 79                open_ai_compatible_api_token(cx),
 80                |cx| open_ai_compatible_api_url(cx),
 81                Some(
 82                    settings_window
 83                        .render_sub_page_items_section(
 84                            open_ai_compatible_settings().iter().enumerate(),
 85                            true,
 86                            window,
 87                            cx,
 88                        )
 89                        .into_any_element(),
 90                ),
 91                window,
 92                cx,
 93            )
 94            .into_any_element(),
 95        ),
 96    ];
 97
 98    div()
 99        .size_full()
100        .child(
101            v_flex()
102                .id("ep-setup-page")
103                .min_w_0()
104                .size_full()
105                .px_8()
106                .pb_16()
107                .overflow_y_scroll()
108                .track_scroll(&scroll_handle)
109                .children(providers.into_iter().flatten()),
110        )
111        .into_any_element()
112}
113
114fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
115    let current_provider = AllLanguageSettings::get_global(cx)
116        .edit_predictions
117        .provider;
118    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
119
120    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
121        let available_providers = get_available_providers(cx);
122        let fs = <dyn fs::Fs>::global(cx);
123
124        for provider in available_providers {
125            let Some(name) = provider.display_name() else {
126                continue;
127            };
128            let is_current = provider == current_provider;
129
130            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
131                let fs = fs.clone();
132                move |_, cx| {
133                    set_completion_provider(fs.clone(), cx, provider);
134                }
135            });
136        }
137        menu
138    });
139
140    v_flex()
141        .id("provider-selector")
142        .min_w_0()
143        .gap_1p5()
144        .child(SettingsSectionHeader::new("Active Provider").no_padding(true))
145        .child(
146            h_flex()
147                .pt_2p5()
148                .w_full()
149                .min_w_0()
150                .justify_between()
151                .child(
152                    v_flex()
153                        .w_full()
154                        .min_w_0()
155                        .max_w_1_2()
156                        .child(Label::new("Provider"))
157                        .child(
158                            Label::new("Select which provider to use for edit predictions.")
159                                .size(LabelSize::Small)
160                                .color(Color::Muted),
161                        ),
162                )
163                .child(
164                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
165                        .tab_index(0)
166                        .style(DropdownStyle::Outlined),
167                ),
168        )
169        .into_any_element()
170}
171
172enum ApiKeyDocs {
173    Link { dashboard_url: SharedString },
174    Custom { message: SharedString },
175}
176
177fn render_api_key_provider(
178    icon: IconName,
179    title: &'static str,
180    docs: ApiKeyDocs,
181    api_key_state: Entity<ApiKeyState>,
182    current_url: fn(&mut App) -> SharedString,
183    additional_fields: Option<AnyElement>,
184    window: &mut Window,
185    cx: &mut Context<SettingsWindow>,
186) -> impl IntoElement {
187    let weak_page = cx.weak_entity();
188    _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
189        let task = api_key_state.update(cx, |key_state, cx| {
190            key_state.load_if_needed(current_url(cx), |state| state, cx)
191        });
192        cx.spawn(async move |_, cx| {
193            task.await.ok();
194            weak_page
195                .update(cx, |_, cx| {
196                    cx.notify();
197                })
198                .ok();
199        })
200    });
201
202    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
203        (
204            state.has_key(),
205            Some(state.env_var_name().clone()),
206            state.is_from_env_var(),
207        )
208    });
209
210    let write_key = move |api_key: Option<String>, cx: &mut App| {
211        api_key_state
212            .update(cx, |key_state, cx| {
213                let url = current_url(cx);
214                key_state.store(url, api_key, |key_state| key_state, cx)
215            })
216            .detach_and_log_err(cx);
217    };
218
219    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
220    let header = SettingsSectionHeader::new(title)
221        .icon(icon)
222        .no_padding(true);
223    let button_link_label = format!("{} dashboard", title);
224    let description = match docs {
225        ApiKeyDocs::Custom { message } => div().min_w_0().w_full().child(
226            Label::new(message)
227                .size(LabelSize::Small)
228                .color(Color::Muted),
229        ),
230        ApiKeyDocs::Link { dashboard_url } => h_flex()
231            .w_full()
232            .min_w_0()
233            .flex_wrap()
234            .gap_0p5()
235            .child(
236                Label::new("Visit the")
237                    .size(LabelSize::Small)
238                    .color(Color::Muted),
239            )
240            .child(
241                ButtonLink::new(button_link_label, dashboard_url)
242                    .no_icon(true)
243                    .label_size(LabelSize::Small)
244                    .label_color(Color::Muted),
245            )
246            .child(
247                Label::new("to generate an API key.")
248                    .size(LabelSize::Small)
249                    .color(Color::Muted),
250            ),
251    };
252    let configured_card_label = if is_from_env_var {
253        "API Key Set in Environment Variable"
254    } else {
255        "API Key Configured"
256    };
257
258    let container = if has_key {
259        base_container.child(header).child(
260            ConfiguredApiCard::new(configured_card_label)
261                .button_label("Reset Key")
262                .button_tab_index(0)
263                .disabled(is_from_env_var)
264                .when_some(env_var_name, |this, env_var_name| {
265                    this.when(is_from_env_var, |this| {
266                        this.tooltip_label(format!(
267                            "To reset your API key, unset the {} environment variable.",
268                            env_var_name
269                        ))
270                    })
271                })
272                .on_click(move |_, _, cx| {
273                    write_key(None, cx);
274                }),
275        )
276    } else {
277        base_container.child(header).child(
278            h_flex()
279                .pt_2p5()
280                .w_full()
281                .min_w_0()
282                .justify_between()
283                .child(
284                    v_flex()
285                        .w_full()
286                        .min_w_0()
287                        .max_w_1_2()
288                        .child(Label::new("API Key"))
289                        .child(description)
290                        .when_some(env_var_name, |this, env_var_name| {
291                            this.child({
292                                let label = format!(
293                                    "Or set the {} env var and restart Zed.",
294                                    env_var_name.as_ref()
295                                );
296                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
297                            })
298                        }),
299                )
300                .child(
301                    SettingsInputField::new()
302                        .tab_index(0)
303                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
304                        .on_confirm(move |api_key, _window, cx| {
305                            write_key(api_key.filter(|key| !key.is_empty()), cx);
306                        }),
307                ),
308        )
309    };
310
311    container.when_some(additional_fields, |this, additional_fields| {
312        this.child(
313            div()
314                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
315                .px_neg_8()
316                .border_t_1()
317                .border_color(cx.theme().colors().border_variant)
318                .child(additional_fields),
319        )
320    })
321}
322
323fn render_ollama_provider(
324    settings_window: &SettingsWindow,
325    window: &mut Window,
326    cx: &mut Context<SettingsWindow>,
327) -> impl IntoElement {
328    let ollama_settings = ollama_settings();
329    let additional_fields = settings_window
330        .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
331        .into_any_element();
332
333    v_flex()
334        .id("ollama")
335        .min_w_0()
336        .pt_8()
337        .gap_1p5()
338        .child(
339            SettingsSectionHeader::new("Ollama")
340                .icon(IconName::AiOllama)
341                .no_padding(true),
342        )
343        .child(div().px_neg_8().child(additional_fields))
344}
345
346fn ollama_settings() -> Box<[SettingsPageItem]> {
347    Box::new([
348        SettingsPageItem::SettingItem(SettingItem {
349            title: "API URL",
350            description: "The base URL of your Ollama server.",
351            field: Box::new(SettingField {
352                pick: |settings| {
353                    settings
354                        .project
355                        .all_languages
356                        .edit_predictions
357                        .as_ref()?
358                        .ollama
359                        .as_ref()?
360                        .api_url
361                        .as_ref()
362                },
363                write: |settings, value| {
364                    settings
365                        .project
366                        .all_languages
367                        .edit_predictions
368                        .get_or_insert_default()
369                        .ollama
370                        .get_or_insert_default()
371                        .api_url = value;
372                },
373                json_path: Some("edit_predictions.ollama.api_url"),
374            }),
375            metadata: Some(Box::new(SettingsFieldMetadata {
376                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
377                ..Default::default()
378            })),
379            files: USER,
380        }),
381        SettingsPageItem::SettingItem(SettingItem {
382            title: "Model",
383            description: "The Ollama model to use for edit predictions.",
384            field: Box::new(SettingField {
385                pick: |settings| {
386                    settings
387                        .project
388                        .all_languages
389                        .edit_predictions
390                        .as_ref()?
391                        .ollama
392                        .as_ref()?
393                        .model
394                        .as_ref()
395                },
396                write: |settings, value| {
397                    settings
398                        .project
399                        .all_languages
400                        .edit_predictions
401                        .get_or_insert_default()
402                        .ollama
403                        .get_or_insert_default()
404                        .model = value;
405                },
406                json_path: Some("edit_predictions.ollama.model"),
407            }),
408            metadata: Some(Box::new(SettingsFieldMetadata {
409                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
410                ..Default::default()
411            })),
412            files: USER,
413        }),
414        SettingsPageItem::SettingItem(SettingItem {
415            title: "Prompt Format",
416            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
417            field: Box::new(SettingField {
418                pick: |settings| {
419                    settings
420                        .project
421                        .all_languages
422                        .edit_predictions
423                        .as_ref()?
424                        .ollama
425                        .as_ref()?
426                        .prompt_format
427                        .as_ref()
428                },
429                write: |settings, value| {
430                    settings
431                        .project
432                        .all_languages
433                        .edit_predictions
434                        .get_or_insert_default()
435                        .ollama
436                        .get_or_insert_default()
437                        .prompt_format = value;
438                },
439                json_path: Some("edit_predictions.ollama.prompt_format"),
440            }),
441            files: USER,
442            metadata: None,
443        }),
444        SettingsPageItem::SettingItem(SettingItem {
445            title: "Max Output Tokens",
446            description: "The maximum number of tokens to generate.",
447            field: Box::new(SettingField {
448                pick: |settings| {
449                    settings
450                        .project
451                        .all_languages
452                        .edit_predictions
453                        .as_ref()?
454                        .ollama
455                        .as_ref()?
456                        .max_output_tokens
457                        .as_ref()
458                },
459                write: |settings, value| {
460                    settings
461                        .project
462                        .all_languages
463                        .edit_predictions
464                        .get_or_insert_default()
465                        .ollama
466                        .get_or_insert_default()
467                        .max_output_tokens = value;
468                },
469                json_path: Some("edit_predictions.ollama.max_output_tokens"),
470            }),
471            metadata: None,
472            files: USER,
473        }),
474    ])
475}
476
477fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
478    Box::new([
479        SettingsPageItem::SettingItem(SettingItem {
480            title: "API URL",
481            description: "The URL of your OpenAI-compatible server's completions API.",
482            field: Box::new(SettingField {
483                pick: |settings| {
484                    settings
485                        .project
486                        .all_languages
487                        .edit_predictions
488                        .as_ref()?
489                        .open_ai_compatible_api
490                        .as_ref()?
491                        .api_url
492                        .as_ref()
493                },
494                write: |settings, value| {
495                    settings
496                        .project
497                        .all_languages
498                        .edit_predictions
499                        .get_or_insert_default()
500                        .open_ai_compatible_api
501                        .get_or_insert_default()
502                        .api_url = value;
503                },
504                json_path: Some("edit_predictions.open_ai_compatible_api.api_url"),
505            }),
506            metadata: Some(Box::new(SettingsFieldMetadata {
507                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
508                ..Default::default()
509            })),
510            files: USER,
511        }),
512        SettingsPageItem::SettingItem(SettingItem {
513            title: "Model",
514            description: "The model string to pass to the OpenAI-compatible server.",
515            field: Box::new(SettingField {
516                pick: |settings| {
517                    settings
518                        .project
519                        .all_languages
520                        .edit_predictions
521                        .as_ref()?
522                        .open_ai_compatible_api
523                        .as_ref()?
524                        .model
525                        .as_ref()
526                },
527                write: |settings, value| {
528                    settings
529                        .project
530                        .all_languages
531                        .edit_predictions
532                        .get_or_insert_default()
533                        .open_ai_compatible_api
534                        .get_or_insert_default()
535                        .model = value;
536                },
537                json_path: Some("edit_predictions.open_ai_compatible_api.model"),
538            }),
539            metadata: Some(Box::new(SettingsFieldMetadata {
540                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
541                ..Default::default()
542            })),
543            files: USER,
544        }),
545        SettingsPageItem::SettingItem(SettingItem {
546            title: "Prompt Format",
547            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
548            field: Box::new(SettingField {
549                pick: |settings| {
550                    settings
551                        .project
552                        .all_languages
553                        .edit_predictions
554                        .as_ref()?
555                        .open_ai_compatible_api
556                        .as_ref()?
557                        .prompt_format
558                        .as_ref()
559                },
560                write: |settings, value| {
561                    settings
562                        .project
563                        .all_languages
564                        .edit_predictions
565                        .get_or_insert_default()
566                        .open_ai_compatible_api
567                        .get_or_insert_default()
568                        .prompt_format = value;
569                },
570                json_path: Some("edit_predictions.open_ai_compatible_api.prompt_format"),
571            }),
572            files: USER,
573            metadata: None,
574        }),
575        SettingsPageItem::SettingItem(SettingItem {
576            title: "Max Output Tokens",
577            description: "The maximum number of tokens to generate.",
578            field: Box::new(SettingField {
579                pick: |settings| {
580                    settings
581                        .project
582                        .all_languages
583                        .edit_predictions
584                        .as_ref()?
585                        .open_ai_compatible_api
586                        .as_ref()?
587                        .max_output_tokens
588                        .as_ref()
589                },
590                write: |settings, value| {
591                    settings
592                        .project
593                        .all_languages
594                        .edit_predictions
595                        .get_or_insert_default()
596                        .open_ai_compatible_api
597                        .get_or_insert_default()
598                        .max_output_tokens = value;
599                },
600                json_path: Some("edit_predictions.open_ai_compatible_api.max_output_tokens"),
601            }),
602            metadata: None,
603            files: USER,
604        }),
605    ])
606}
607
608fn codestral_settings() -> Box<[SettingsPageItem]> {
609    Box::new([
610        SettingsPageItem::SettingItem(SettingItem {
611            title: "API URL",
612            description: "The API URL to use for Codestral.",
613            field: Box::new(SettingField {
614                pick: |settings| {
615                    settings
616                        .project
617                        .all_languages
618                        .edit_predictions
619                        .as_ref()?
620                        .codestral
621                        .as_ref()?
622                        .api_url
623                        .as_ref()
624                },
625                write: |settings, value| {
626                    settings
627                        .project
628                        .all_languages
629                        .edit_predictions
630                        .get_or_insert_default()
631                        .codestral
632                        .get_or_insert_default()
633                        .api_url = value;
634                },
635                json_path: Some("edit_predictions.codestral.api_url"),
636            }),
637            metadata: Some(Box::new(SettingsFieldMetadata {
638                placeholder: Some(CODESTRAL_API_URL),
639                ..Default::default()
640            })),
641            files: USER,
642        }),
643        SettingsPageItem::SettingItem(SettingItem {
644            title: "Max Tokens",
645            description: "The maximum number of tokens to generate.",
646            field: Box::new(SettingField {
647                pick: |settings| {
648                    settings
649                        .project
650                        .all_languages
651                        .edit_predictions
652                        .as_ref()?
653                        .codestral
654                        .as_ref()?
655                        .max_tokens
656                        .as_ref()
657                },
658                write: |settings, value| {
659                    settings
660                        .project
661                        .all_languages
662                        .edit_predictions
663                        .get_or_insert_default()
664                        .codestral
665                        .get_or_insert_default()
666                        .max_tokens = value;
667                },
668                json_path: Some("edit_predictions.codestral.max_tokens"),
669            }),
670            metadata: None,
671            files: USER,
672        }),
673        SettingsPageItem::SettingItem(SettingItem {
674            title: "Model",
675            description: "The Codestral model id to use.",
676            field: Box::new(SettingField {
677                pick: |settings| {
678                    settings
679                        .project
680                        .all_languages
681                        .edit_predictions
682                        .as_ref()?
683                        .codestral
684                        .as_ref()?
685                        .model
686                        .as_ref()
687                },
688                write: |settings, value| {
689                    settings
690                        .project
691                        .all_languages
692                        .edit_predictions
693                        .get_or_insert_default()
694                        .codestral
695                        .get_or_insert_default()
696                        .model = value;
697                },
698                json_path: Some("edit_predictions.codestral.model"),
699            }),
700            metadata: Some(Box::new(SettingsFieldMetadata {
701                placeholder: Some("codestral-latest"),
702                ..Default::default()
703            })),
704            files: USER,
705        }),
706    ])
707}
708
709fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
710    let configuration_view = window.use_state(cx, |_, cx| {
711        copilot_ui::ConfigurationView::new(
712            move |cx| {
713                if let Some(app_state) = AppState::global(cx).upgrade() {
714                    copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
715                        .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
716                } else {
717                    false
718                }
719            },
720            copilot_ui::ConfigurationMode::EditPrediction,
721            cx,
722        )
723    });
724
725    Some(
726        v_flex()
727            .id("github-copilot")
728            .min_w_0()
729            .pt_8()
730            .gap_1p5()
731            .child(
732                SettingsSectionHeader::new("GitHub Copilot")
733                    .icon(IconName::Copilot)
734                    .no_padding(true),
735            )
736            .child(configuration_view),
737    )
738}