edit_prediction_provider_setup.rs

  1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
  2use edit_prediction::{
  3    ApiKeyState,
  4    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  5    open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
  6};
  7use edit_prediction_ui::{get_available_providers, set_completion_provider};
  8use gpui::{Entity, ScrollHandle, prelude::*};
  9use language::language_settings::AllLanguageSettings;
 10
 11use settings::Settings as _;
 12use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 13use workspace::AppState;
 14use zed_credentials_provider::global as global_credentials_provider;
 15
/// Placeholder text for the Ollama "API URL" setting; also reused as the
/// placeholder of the OpenAI-compatible "API URL" setting.
const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
/// Placeholder text for the Ollama "Model" setting; also reused as the
/// placeholder of the OpenAI-compatible "Model" setting.
const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 18
 19use crate::{
 20    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 21    components::{SettingsInputField, SettingsSectionHeader},
 22};
 23
 24pub(crate) fn render_edit_prediction_setup_page(
 25    settings_window: &SettingsWindow,
 26    scroll_handle: &ScrollHandle,
 27    window: &mut Window,
 28    cx: &mut Context<SettingsWindow>,
 29) -> AnyElement {
 30    let providers = [
 31        Some(render_provider_dropdown(window, cx)),
 32        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 33        Some(
 34            render_api_key_provider(
 35                IconName::Inception,
 36                "Mercury",
 37                ApiKeyDocs::Link {
 38                    dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 39                },
 40                mercury_api_token(cx),
 41                |_cx| MERCURY_CREDENTIALS_URL,
 42                None,
 43                window,
 44                cx,
 45            )
 46            .into_any_element(),
 47        ),
 48        Some(
 49            render_api_key_provider(
 50                IconName::AiMistral,
 51                "Codestral",
 52                ApiKeyDocs::Link {
 53                    dashboard_url: "https://console.mistral.ai/codestral".into(),
 54                },
 55                codestral_api_key_state(cx),
 56                |cx| codestral_api_url(cx),
 57                Some(
 58                    settings_window
 59                        .render_sub_page_items_section(
 60                            codestral_settings().iter().enumerate(),
 61                            true,
 62                            window,
 63                            cx,
 64                        )
 65                        .into_any_element(),
 66                ),
 67                window,
 68                cx,
 69            )
 70            .into_any_element(),
 71        ),
 72        Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
 73        Some(
 74            render_api_key_provider(
 75                IconName::AiOpenAiCompat,
 76                "OpenAI Compatible API",
 77                ApiKeyDocs::Custom {
 78                    message: "The API key sent as Authorization: Bearer {key}.".into(),
 79                },
 80                open_ai_compatible_api_token(cx),
 81                |cx| open_ai_compatible_api_url(cx),
 82                Some(
 83                    settings_window
 84                        .render_sub_page_items_section(
 85                            open_ai_compatible_settings().iter().enumerate(),
 86                            true,
 87                            window,
 88                            cx,
 89                        )
 90                        .into_any_element(),
 91                ),
 92                window,
 93                cx,
 94            )
 95            .into_any_element(),
 96        ),
 97    ];
 98
 99    div()
100        .size_full()
101        .child(
102            v_flex()
103                .id("ep-setup-page")
104                .min_w_0()
105                .size_full()
106                .px_8()
107                .pb_16()
108                .overflow_y_scroll()
109                .track_scroll(&scroll_handle)
110                .children(providers.into_iter().flatten()),
111        )
112        .into_any_element()
113}
114
115fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
116    let current_provider = AllLanguageSettings::get_global(cx)
117        .edit_predictions
118        .provider;
119    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
120
121    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
122        let available_providers = get_available_providers(cx);
123        let fs = <dyn fs::Fs>::global(cx);
124
125        for provider in available_providers {
126            let Some(name) = provider.display_name() else {
127                continue;
128            };
129            let is_current = provider == current_provider;
130
131            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
132                let fs = fs.clone();
133                move |_, cx| {
134                    set_completion_provider(fs.clone(), cx, provider);
135                }
136            });
137        }
138        menu
139    });
140
141    v_flex()
142        .id("provider-selector")
143        .min_w_0()
144        .gap_1p5()
145        .child(SettingsSectionHeader::new("Active Provider").no_padding(true))
146        .child(
147            h_flex()
148                .pt_2p5()
149                .w_full()
150                .min_w_0()
151                .justify_between()
152                .child(
153                    v_flex()
154                        .w_full()
155                        .min_w_0()
156                        .max_w_1_2()
157                        .child(Label::new("Provider"))
158                        .child(
159                            Label::new("Select which provider to use for edit predictions.")
160                                .size(LabelSize::Small)
161                                .color(Color::Muted),
162                        ),
163                )
164                .child(
165                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
166                        .tab_index(0)
167                        .style(DropdownStyle::Outlined),
168                ),
169        )
170        .into_any_element()
171}
172
/// Tells the user of an API-key provider section where to obtain a key.
enum ApiKeyDocs {
    /// Renders "Visit the <title> dashboard to generate an API key.", where the
    /// dashboard text is a link to `dashboard_url`.
    Link { dashboard_url: SharedString },
    /// Renders `message` as small, muted text.
    Custom { message: SharedString },
}
177
178fn render_api_key_provider(
179    icon: IconName,
180    title: &'static str,
181    docs: ApiKeyDocs,
182    api_key_state: Entity<ApiKeyState>,
183    current_url: fn(&mut App) -> SharedString,
184    additional_fields: Option<AnyElement>,
185    window: &mut Window,
186    cx: &mut Context<SettingsWindow>,
187) -> impl IntoElement {
188    let weak_page = cx.weak_entity();
189    let credentials_provider = global_credentials_provider(cx);
190    _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
191        let task = api_key_state.update(cx, |key_state, cx| {
192            key_state.load_if_needed(
193                current_url(cx),
194                |state| state,
195                credentials_provider.clone(),
196                cx,
197            )
198        });
199        cx.spawn(async move |_, cx| {
200            task.await.ok();
201            weak_page
202                .update(cx, |_, cx| {
203                    cx.notify();
204                })
205                .ok();
206        })
207    });
208
209    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
210        (
211            state.has_key(),
212            Some(state.env_var_name().clone()),
213            state.is_from_env_var(),
214        )
215    });
216
217    let write_key = move |api_key: Option<String>, cx: &mut App| {
218        let credentials_provider = global_credentials_provider(cx);
219        api_key_state
220            .update(cx, |key_state, cx| {
221                let url = current_url(cx);
222                key_state.store(
223                    url,
224                    api_key,
225                    |key_state| key_state,
226                    credentials_provider,
227                    cx,
228                )
229            })
230            .detach_and_log_err(cx);
231    };
232
233    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
234    let header = SettingsSectionHeader::new(title)
235        .icon(icon)
236        .no_padding(true);
237    let button_link_label = format!("{} dashboard", title);
238    let description = match docs {
239        ApiKeyDocs::Custom { message } => div().min_w_0().w_full().child(
240            Label::new(message)
241                .size(LabelSize::Small)
242                .color(Color::Muted),
243        ),
244        ApiKeyDocs::Link { dashboard_url } => h_flex()
245            .w_full()
246            .min_w_0()
247            .flex_wrap()
248            .gap_0p5()
249            .child(
250                Label::new("Visit the")
251                    .size(LabelSize::Small)
252                    .color(Color::Muted),
253            )
254            .child(
255                ButtonLink::new(button_link_label, dashboard_url)
256                    .no_icon(true)
257                    .label_size(LabelSize::Small)
258                    .label_color(Color::Muted),
259            )
260            .child(
261                Label::new("to generate an API key.")
262                    .size(LabelSize::Small)
263                    .color(Color::Muted),
264            ),
265    };
266    let configured_card_label = if is_from_env_var {
267        "API Key Set in Environment Variable"
268    } else {
269        "API Key Configured"
270    };
271
272    let container = if has_key {
273        base_container.child(header).child(
274            ConfiguredApiCard::new(configured_card_label)
275                .button_label("Reset Key")
276                .button_tab_index(0)
277                .disabled(is_from_env_var)
278                .when_some(env_var_name, |this, env_var_name| {
279                    this.when(is_from_env_var, |this| {
280                        this.tooltip_label(format!(
281                            "To reset your API key, unset the {} environment variable.",
282                            env_var_name
283                        ))
284                    })
285                })
286                .on_click(move |_, _, cx| {
287                    write_key(None, cx);
288                }),
289        )
290    } else {
291        base_container.child(header).child(
292            h_flex()
293                .pt_2p5()
294                .w_full()
295                .min_w_0()
296                .justify_between()
297                .child(
298                    v_flex()
299                        .w_full()
300                        .min_w_0()
301                        .max_w_1_2()
302                        .child(Label::new("API Key"))
303                        .child(description)
304                        .when_some(env_var_name, |this, env_var_name| {
305                            this.child({
306                                let label = format!(
307                                    "Or set the {} env var and restart Zed.",
308                                    env_var_name.as_ref()
309                                );
310                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
311                            })
312                        }),
313                )
314                .child(
315                    SettingsInputField::new()
316                        .tab_index(0)
317                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
318                        .on_confirm(move |api_key, _window, cx| {
319                            write_key(api_key.filter(|key| !key.is_empty()), cx);
320                        }),
321                ),
322        )
323    };
324
325    container.when_some(additional_fields, |this, additional_fields| {
326        this.child(
327            div()
328                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
329                .px_neg_8()
330                .border_t_1()
331                .border_color(cx.theme().colors().border_variant)
332                .child(additional_fields),
333        )
334    })
335}
336
337fn render_ollama_provider(
338    settings_window: &SettingsWindow,
339    window: &mut Window,
340    cx: &mut Context<SettingsWindow>,
341) -> impl IntoElement {
342    let ollama_settings = ollama_settings();
343    let additional_fields = settings_window
344        .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
345        .into_any_element();
346
347    v_flex()
348        .id("ollama")
349        .min_w_0()
350        .pt_8()
351        .gap_1p5()
352        .child(
353            SettingsSectionHeader::new("Ollama")
354                .icon(IconName::AiOllama)
355                .no_padding(true),
356        )
357        .child(div().px_neg_8().child(additional_fields))
358}
359
360fn ollama_settings() -> Box<[SettingsPageItem]> {
361    Box::new([
362        SettingsPageItem::SettingItem(SettingItem {
363            title: "API URL",
364            description: "The base URL of your Ollama server.",
365            field: Box::new(SettingField {
366                pick: |settings| {
367                    settings
368                        .project
369                        .all_languages
370                        .edit_predictions
371                        .as_ref()?
372                        .ollama
373                        .as_ref()?
374                        .api_url
375                        .as_ref()
376                },
377                write: |settings, value| {
378                    settings
379                        .project
380                        .all_languages
381                        .edit_predictions
382                        .get_or_insert_default()
383                        .ollama
384                        .get_or_insert_default()
385                        .api_url = value;
386                },
387                json_path: Some("edit_predictions.ollama.api_url"),
388            }),
389            metadata: Some(Box::new(SettingsFieldMetadata {
390                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
391                ..Default::default()
392            })),
393            files: USER,
394        }),
395        SettingsPageItem::SettingItem(SettingItem {
396            title: "Model",
397            description: "The Ollama model to use for edit predictions.",
398            field: Box::new(SettingField {
399                pick: |settings| {
400                    settings
401                        .project
402                        .all_languages
403                        .edit_predictions
404                        .as_ref()?
405                        .ollama
406                        .as_ref()?
407                        .model
408                        .as_ref()
409                },
410                write: |settings, value| {
411                    settings
412                        .project
413                        .all_languages
414                        .edit_predictions
415                        .get_or_insert_default()
416                        .ollama
417                        .get_or_insert_default()
418                        .model = value;
419                },
420                json_path: Some("edit_predictions.ollama.model"),
421            }),
422            metadata: Some(Box::new(SettingsFieldMetadata {
423                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
424                ..Default::default()
425            })),
426            files: USER,
427        }),
428        SettingsPageItem::SettingItem(SettingItem {
429            title: "Prompt Format",
430            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
431            field: Box::new(SettingField {
432                pick: |settings| {
433                    settings
434                        .project
435                        .all_languages
436                        .edit_predictions
437                        .as_ref()?
438                        .ollama
439                        .as_ref()?
440                        .prompt_format
441                        .as_ref()
442                },
443                write: |settings, value| {
444                    settings
445                        .project
446                        .all_languages
447                        .edit_predictions
448                        .get_or_insert_default()
449                        .ollama
450                        .get_or_insert_default()
451                        .prompt_format = value;
452                },
453                json_path: Some("edit_predictions.ollama.prompt_format"),
454            }),
455            files: USER,
456            metadata: None,
457        }),
458        SettingsPageItem::SettingItem(SettingItem {
459            title: "Max Output Tokens",
460            description: "The maximum number of tokens to generate.",
461            field: Box::new(SettingField {
462                pick: |settings| {
463                    settings
464                        .project
465                        .all_languages
466                        .edit_predictions
467                        .as_ref()?
468                        .ollama
469                        .as_ref()?
470                        .max_output_tokens
471                        .as_ref()
472                },
473                write: |settings, value| {
474                    settings
475                        .project
476                        .all_languages
477                        .edit_predictions
478                        .get_or_insert_default()
479                        .ollama
480                        .get_or_insert_default()
481                        .max_output_tokens = value;
482                },
483                json_path: Some("edit_predictions.ollama.max_output_tokens"),
484            }),
485            metadata: None,
486            files: USER,
487        }),
488    ])
489}
490
491fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
492    Box::new([
493        SettingsPageItem::SettingItem(SettingItem {
494            title: "API URL",
495            description: "The URL of your OpenAI-compatible server's completions API.",
496            field: Box::new(SettingField {
497                pick: |settings| {
498                    settings
499                        .project
500                        .all_languages
501                        .edit_predictions
502                        .as_ref()?
503                        .open_ai_compatible_api
504                        .as_ref()?
505                        .api_url
506                        .as_ref()
507                },
508                write: |settings, value| {
509                    settings
510                        .project
511                        .all_languages
512                        .edit_predictions
513                        .get_or_insert_default()
514                        .open_ai_compatible_api
515                        .get_or_insert_default()
516                        .api_url = value;
517                },
518                json_path: Some("edit_predictions.open_ai_compatible_api.api_url"),
519            }),
520            metadata: Some(Box::new(SettingsFieldMetadata {
521                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
522                ..Default::default()
523            })),
524            files: USER,
525        }),
526        SettingsPageItem::SettingItem(SettingItem {
527            title: "Model",
528            description: "The model string to pass to the OpenAI-compatible server.",
529            field: Box::new(SettingField {
530                pick: |settings| {
531                    settings
532                        .project
533                        .all_languages
534                        .edit_predictions
535                        .as_ref()?
536                        .open_ai_compatible_api
537                        .as_ref()?
538                        .model
539                        .as_ref()
540                },
541                write: |settings, value| {
542                    settings
543                        .project
544                        .all_languages
545                        .edit_predictions
546                        .get_or_insert_default()
547                        .open_ai_compatible_api
548                        .get_or_insert_default()
549                        .model = value;
550                },
551                json_path: Some("edit_predictions.open_ai_compatible_api.model"),
552            }),
553            metadata: Some(Box::new(SettingsFieldMetadata {
554                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
555                ..Default::default()
556            })),
557            files: USER,
558        }),
559        SettingsPageItem::SettingItem(SettingItem {
560            title: "Prompt Format",
561            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
562            field: Box::new(SettingField {
563                pick: |settings| {
564                    settings
565                        .project
566                        .all_languages
567                        .edit_predictions
568                        .as_ref()?
569                        .open_ai_compatible_api
570                        .as_ref()?
571                        .prompt_format
572                        .as_ref()
573                },
574                write: |settings, value| {
575                    settings
576                        .project
577                        .all_languages
578                        .edit_predictions
579                        .get_or_insert_default()
580                        .open_ai_compatible_api
581                        .get_or_insert_default()
582                        .prompt_format = value;
583                },
584                json_path: Some("edit_predictions.open_ai_compatible_api.prompt_format"),
585            }),
586            files: USER,
587            metadata: None,
588        }),
589        SettingsPageItem::SettingItem(SettingItem {
590            title: "Max Output Tokens",
591            description: "The maximum number of tokens to generate.",
592            field: Box::new(SettingField {
593                pick: |settings| {
594                    settings
595                        .project
596                        .all_languages
597                        .edit_predictions
598                        .as_ref()?
599                        .open_ai_compatible_api
600                        .as_ref()?
601                        .max_output_tokens
602                        .as_ref()
603                },
604                write: |settings, value| {
605                    settings
606                        .project
607                        .all_languages
608                        .edit_predictions
609                        .get_or_insert_default()
610                        .open_ai_compatible_api
611                        .get_or_insert_default()
612                        .max_output_tokens = value;
613                },
614                json_path: Some("edit_predictions.open_ai_compatible_api.max_output_tokens"),
615            }),
616            metadata: None,
617            files: USER,
618        }),
619    ])
620}
621
622fn codestral_settings() -> Box<[SettingsPageItem]> {
623    Box::new([
624        SettingsPageItem::SettingItem(SettingItem {
625            title: "API URL",
626            description: "The API URL to use for Codestral.",
627            field: Box::new(SettingField {
628                pick: |settings| {
629                    settings
630                        .project
631                        .all_languages
632                        .edit_predictions
633                        .as_ref()?
634                        .codestral
635                        .as_ref()?
636                        .api_url
637                        .as_ref()
638                },
639                write: |settings, value| {
640                    settings
641                        .project
642                        .all_languages
643                        .edit_predictions
644                        .get_or_insert_default()
645                        .codestral
646                        .get_or_insert_default()
647                        .api_url = value;
648                },
649                json_path: Some("edit_predictions.codestral.api_url"),
650            }),
651            metadata: Some(Box::new(SettingsFieldMetadata {
652                placeholder: Some(CODESTRAL_API_URL),
653                ..Default::default()
654            })),
655            files: USER,
656        }),
657        SettingsPageItem::SettingItem(SettingItem {
658            title: "Max Tokens",
659            description: "The maximum number of tokens to generate.",
660            field: Box::new(SettingField {
661                pick: |settings| {
662                    settings
663                        .project
664                        .all_languages
665                        .edit_predictions
666                        .as_ref()?
667                        .codestral
668                        .as_ref()?
669                        .max_tokens
670                        .as_ref()
671                },
672                write: |settings, value| {
673                    settings
674                        .project
675                        .all_languages
676                        .edit_predictions
677                        .get_or_insert_default()
678                        .codestral
679                        .get_or_insert_default()
680                        .max_tokens = value;
681                },
682                json_path: Some("edit_predictions.codestral.max_tokens"),
683            }),
684            metadata: None,
685            files: USER,
686        }),
687        SettingsPageItem::SettingItem(SettingItem {
688            title: "Model",
689            description: "The Codestral model id to use.",
690            field: Box::new(SettingField {
691                pick: |settings| {
692                    settings
693                        .project
694                        .all_languages
695                        .edit_predictions
696                        .as_ref()?
697                        .codestral
698                        .as_ref()?
699                        .model
700                        .as_ref()
701                },
702                write: |settings, value| {
703                    settings
704                        .project
705                        .all_languages
706                        .edit_predictions
707                        .get_or_insert_default()
708                        .codestral
709                        .get_or_insert_default()
710                        .model = value;
711                },
712                json_path: Some("edit_predictions.codestral.model"),
713            }),
714            metadata: Some(Box::new(SettingsFieldMetadata {
715                placeholder: Some("codestral-latest"),
716                ..Default::default()
717            })),
718            files: USER,
719        }),
720    ])
721}
722
723fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
724    let configuration_view = window.use_state(cx, |_, cx| {
725        copilot_ui::ConfigurationView::new(
726            move |cx| {
727                let app_state = AppState::global(cx);
728                copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
729                    .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
730            },
731            copilot_ui::ConfigurationMode::EditPrediction,
732            cx,
733        )
734    });
735
736    Some(
737        v_flex()
738            .id("github-copilot")
739            .min_w_0()
740            .pt_8()
741            .gap_1p5()
742            .child(
743                SettingsSectionHeader::new("GitHub Copilot")
744                    .icon(IconName::Copilot)
745                    .no_padding(true),
746            )
747            .child(configuration_view),
748    )
749}