edit_prediction_provider_setup.rs

  1use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
  2use edit_prediction::{
  3    ApiKeyState,
  4    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  5    open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
  6};
  7use edit_prediction_ui::{get_available_providers, set_completion_provider};
  8use gpui::{Entity, ScrollHandle, prelude::*};
  9use language::language_settings::AllLanguageSettings;
 10
 11use settings::Settings as _;
 12use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 13use workspace::AppState;
 14
 15const OLLAMA_API_URL_PLACEHOLDER: &str = "http://localhost:11434";
 16const OLLAMA_MODEL_PLACEHOLDER: &str = "qwen2.5-coder:3b-base";
 17
 18use crate::{
 19    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 20    components::{SettingsInputField, SettingsSectionHeader},
 21};
 22
 23pub(crate) fn render_edit_prediction_setup_page(
 24    settings_window: &SettingsWindow,
 25    scroll_handle: &ScrollHandle,
 26    window: &mut Window,
 27    cx: &mut Context<SettingsWindow>,
 28) -> AnyElement {
 29    let providers = [
 30        Some(render_provider_dropdown(window, cx)),
 31        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 32        Some(
 33            render_api_key_provider(
 34                IconName::Inception,
 35                "Mercury",
 36                ApiKeyDocs::Link {
 37                    dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 38                },
 39                mercury_api_token(cx),
 40                |_cx| MERCURY_CREDENTIALS_URL,
 41                None,
 42                window,
 43                cx,
 44            )
 45            .into_any_element(),
 46        ),
 47        Some(
 48            render_api_key_provider(
 49                IconName::AiMistral,
 50                "Codestral",
 51                ApiKeyDocs::Link {
 52                    dashboard_url: "https://console.mistral.ai/codestral".into(),
 53                },
 54                codestral_api_key_state(cx),
 55                |cx| codestral_api_url(cx),
 56                Some(
 57                    settings_window
 58                        .render_sub_page_items_section(
 59                            codestral_settings().iter().enumerate(),
 60                            true,
 61                            window,
 62                            cx,
 63                        )
 64                        .into_any_element(),
 65                ),
 66                window,
 67                cx,
 68            )
 69            .into_any_element(),
 70        ),
 71        Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
 72        Some(
 73            render_api_key_provider(
 74                IconName::AiOpenAiCompat,
 75                "OpenAI Compatible API",
 76                ApiKeyDocs::Custom {
 77                    message: "The API key sent as Authorization: Bearer {key}.".into(),
 78                },
 79                open_ai_compatible_api_token(cx),
 80                |cx| open_ai_compatible_api_url(cx),
 81                Some(
 82                    settings_window
 83                        .render_sub_page_items_section(
 84                            open_ai_compatible_settings().iter().enumerate(),
 85                            true,
 86                            window,
 87                            cx,
 88                        )
 89                        .into_any_element(),
 90                ),
 91                window,
 92                cx,
 93            )
 94            .into_any_element(),
 95        ),
 96    ];
 97
 98    div()
 99        .size_full()
100        .child(
101            v_flex()
102                .id("ep-setup-page")
103                .min_w_0()
104                .size_full()
105                .px_8()
106                .pb_16()
107                .overflow_y_scroll()
108                .track_scroll(&scroll_handle)
109                .children(providers.into_iter().flatten()),
110        )
111        .into_any_element()
112}
113
114fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
115    let current_provider = AllLanguageSettings::get_global(cx)
116        .edit_predictions
117        .provider;
118    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
119
120    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
121        let available_providers = get_available_providers(cx);
122        let fs = <dyn fs::Fs>::global(cx);
123
124        for provider in available_providers {
125            let Some(name) = provider.display_name() else {
126                continue;
127            };
128            let is_current = provider == current_provider;
129
130            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
131                let fs = fs.clone();
132                move |_, cx| {
133                    set_completion_provider(fs.clone(), cx, provider);
134                }
135            });
136        }
137        menu
138    });
139
140    v_flex()
141        .id("provider-selector")
142        .min_w_0()
143        .gap_1p5()
144        .child(SettingsSectionHeader::new("Active Provider").no_padding(true))
145        .child(
146            h_flex()
147                .pt_2p5()
148                .w_full()
149                .min_w_0()
150                .justify_between()
151                .child(
152                    v_flex()
153                        .w_full()
154                        .min_w_0()
155                        .max_w_1_2()
156                        .child(Label::new("Provider"))
157                        .child(
158                            Label::new("Select which provider to use for edit predictions.")
159                                .size(LabelSize::Small)
160                                .color(Color::Muted),
161                        ),
162                )
163                .child(
164                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
165                        .tab_index(0)
166                        .style(DropdownStyle::Outlined),
167                ),
168        )
169        .into_any_element()
170}
171
172enum ApiKeyDocs {
173    Link { dashboard_url: SharedString },
174    Custom { message: SharedString },
175}
176
177fn render_api_key_provider(
178    icon: IconName,
179    title: &'static str,
180    docs: ApiKeyDocs,
181    api_key_state: Entity<ApiKeyState>,
182    current_url: fn(&mut App) -> SharedString,
183    additional_fields: Option<AnyElement>,
184    window: &mut Window,
185    cx: &mut Context<SettingsWindow>,
186) -> impl IntoElement {
187    let weak_page = cx.weak_entity();
188    let credentials_provider = zed_credentials_provider::global(cx);
189    _ = window.use_keyed_state(current_url(cx), cx, |_, cx| {
190        let task = api_key_state.update(cx, |key_state, cx| {
191            key_state.load_if_needed(
192                current_url(cx),
193                |state| state,
194                credentials_provider.clone(),
195                cx,
196            )
197        });
198        cx.spawn(async move |_, cx| {
199            task.await.ok();
200            weak_page
201                .update(cx, |_, cx| {
202                    cx.notify();
203                })
204                .ok();
205        })
206    });
207
208    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
209        (
210            state.has_key(),
211            Some(state.env_var_name().clone()),
212            state.is_from_env_var(),
213        )
214    });
215
216    let write_key = move |api_key: Option<String>, cx: &mut App| {
217        let credentials_provider = zed_credentials_provider::global(cx);
218        api_key_state
219            .update(cx, |key_state, cx| {
220                let url = current_url(cx);
221                key_state.store(
222                    url,
223                    api_key,
224                    |key_state| key_state,
225                    credentials_provider,
226                    cx,
227                )
228            })
229            .detach_and_log_err(cx);
230    };
231
232    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
233    let header = SettingsSectionHeader::new(title)
234        .icon(icon)
235        .no_padding(true);
236    let button_link_label = format!("{} dashboard", title);
237    let description = match docs {
238        ApiKeyDocs::Custom { message } => div().min_w_0().w_full().child(
239            Label::new(message)
240                .size(LabelSize::Small)
241                .color(Color::Muted),
242        ),
243        ApiKeyDocs::Link { dashboard_url } => h_flex()
244            .w_full()
245            .min_w_0()
246            .flex_wrap()
247            .gap_0p5()
248            .child(
249                Label::new("Visit the")
250                    .size(LabelSize::Small)
251                    .color(Color::Muted),
252            )
253            .child(
254                ButtonLink::new(button_link_label, dashboard_url)
255                    .no_icon(true)
256                    .label_size(LabelSize::Small)
257                    .label_color(Color::Muted),
258            )
259            .child(
260                Label::new("to generate an API key.")
261                    .size(LabelSize::Small)
262                    .color(Color::Muted),
263            ),
264    };
265    let configured_card_label = if is_from_env_var {
266        "API Key Set in Environment Variable"
267    } else {
268        "API Key Configured"
269    };
270
271    let container = if has_key {
272        base_container.child(header).child(
273            ConfiguredApiCard::new(configured_card_label)
274                .button_label("Reset Key")
275                .button_tab_index(0)
276                .disabled(is_from_env_var)
277                .when_some(env_var_name, |this, env_var_name| {
278                    this.when(is_from_env_var, |this| {
279                        this.tooltip_label(format!(
280                            "To reset your API key, unset the {} environment variable.",
281                            env_var_name
282                        ))
283                    })
284                })
285                .on_click(move |_, _, cx| {
286                    write_key(None, cx);
287                }),
288        )
289    } else {
290        base_container.child(header).child(
291            h_flex()
292                .pt_2p5()
293                .w_full()
294                .min_w_0()
295                .justify_between()
296                .child(
297                    v_flex()
298                        .w_full()
299                        .min_w_0()
300                        .max_w_1_2()
301                        .child(Label::new("API Key"))
302                        .child(description)
303                        .when_some(env_var_name, |this, env_var_name| {
304                            this.child({
305                                let label = format!(
306                                    "Or set the {} env var and restart Zed.",
307                                    env_var_name.as_ref()
308                                );
309                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
310                            })
311                        }),
312                )
313                .child(
314                    SettingsInputField::new()
315                        .tab_index(0)
316                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
317                        .on_confirm(move |api_key, _window, cx| {
318                            write_key(api_key.filter(|key| !key.is_empty()), cx);
319                        }),
320                ),
321        )
322    };
323
324    container.when_some(additional_fields, |this, additional_fields| {
325        this.child(
326            div()
327                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
328                .px_neg_8()
329                .border_t_1()
330                .border_color(cx.theme().colors().border_variant)
331                .child(additional_fields),
332        )
333    })
334}
335
336fn render_ollama_provider(
337    settings_window: &SettingsWindow,
338    window: &mut Window,
339    cx: &mut Context<SettingsWindow>,
340) -> impl IntoElement {
341    let ollama_settings = ollama_settings();
342    let additional_fields = settings_window
343        .render_sub_page_items_section(ollama_settings.iter().enumerate(), true, window, cx)
344        .into_any_element();
345
346    v_flex()
347        .id("ollama")
348        .min_w_0()
349        .pt_8()
350        .gap_1p5()
351        .child(
352            SettingsSectionHeader::new("Ollama")
353                .icon(IconName::AiOllama)
354                .no_padding(true),
355        )
356        .child(div().px_neg_8().child(additional_fields))
357}
358
359fn ollama_settings() -> Box<[SettingsPageItem]> {
360    Box::new([
361        SettingsPageItem::SettingItem(SettingItem {
362            title: "API URL",
363            description: "The base URL of your Ollama server.",
364            field: Box::new(SettingField {
365                pick: |settings| {
366                    settings
367                        .project
368                        .all_languages
369                        .edit_predictions
370                        .as_ref()?
371                        .ollama
372                        .as_ref()?
373                        .api_url
374                        .as_ref()
375                },
376                write: |settings, value| {
377                    settings
378                        .project
379                        .all_languages
380                        .edit_predictions
381                        .get_or_insert_default()
382                        .ollama
383                        .get_or_insert_default()
384                        .api_url = value;
385                },
386                json_path: Some("edit_predictions.ollama.api_url"),
387            }),
388            metadata: Some(Box::new(SettingsFieldMetadata {
389                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
390                ..Default::default()
391            })),
392            files: USER,
393        }),
394        SettingsPageItem::SettingItem(SettingItem {
395            title: "Model",
396            description: "The Ollama model to use for edit predictions.",
397            field: Box::new(SettingField {
398                pick: |settings| {
399                    settings
400                        .project
401                        .all_languages
402                        .edit_predictions
403                        .as_ref()?
404                        .ollama
405                        .as_ref()?
406                        .model
407                        .as_ref()
408                },
409                write: |settings, value| {
410                    settings
411                        .project
412                        .all_languages
413                        .edit_predictions
414                        .get_or_insert_default()
415                        .ollama
416                        .get_or_insert_default()
417                        .model = value;
418                },
419                json_path: Some("edit_predictions.ollama.model"),
420            }),
421            metadata: Some(Box::new(SettingsFieldMetadata {
422                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
423                ..Default::default()
424            })),
425            files: USER,
426        }),
427        SettingsPageItem::SettingItem(SettingItem {
428            title: "Prompt Format",
429            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
430            field: Box::new(SettingField {
431                pick: |settings| {
432                    settings
433                        .project
434                        .all_languages
435                        .edit_predictions
436                        .as_ref()?
437                        .ollama
438                        .as_ref()?
439                        .prompt_format
440                        .as_ref()
441                },
442                write: |settings, value| {
443                    settings
444                        .project
445                        .all_languages
446                        .edit_predictions
447                        .get_or_insert_default()
448                        .ollama
449                        .get_or_insert_default()
450                        .prompt_format = value;
451                },
452                json_path: Some("edit_predictions.ollama.prompt_format"),
453            }),
454            files: USER,
455            metadata: None,
456        }),
457        SettingsPageItem::SettingItem(SettingItem {
458            title: "Max Output Tokens",
459            description: "The maximum number of tokens to generate.",
460            field: Box::new(SettingField {
461                pick: |settings| {
462                    settings
463                        .project
464                        .all_languages
465                        .edit_predictions
466                        .as_ref()?
467                        .ollama
468                        .as_ref()?
469                        .max_output_tokens
470                        .as_ref()
471                },
472                write: |settings, value| {
473                    settings
474                        .project
475                        .all_languages
476                        .edit_predictions
477                        .get_or_insert_default()
478                        .ollama
479                        .get_or_insert_default()
480                        .max_output_tokens = value;
481                },
482                json_path: Some("edit_predictions.ollama.max_output_tokens"),
483            }),
484            metadata: None,
485            files: USER,
486        }),
487    ])
488}
489
490fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
491    Box::new([
492        SettingsPageItem::SettingItem(SettingItem {
493            title: "API URL",
494            description: "The URL of your OpenAI-compatible server's completions API.",
495            field: Box::new(SettingField {
496                pick: |settings| {
497                    settings
498                        .project
499                        .all_languages
500                        .edit_predictions
501                        .as_ref()?
502                        .open_ai_compatible_api
503                        .as_ref()?
504                        .api_url
505                        .as_ref()
506                },
507                write: |settings, value| {
508                    settings
509                        .project
510                        .all_languages
511                        .edit_predictions
512                        .get_or_insert_default()
513                        .open_ai_compatible_api
514                        .get_or_insert_default()
515                        .api_url = value;
516                },
517                json_path: Some("edit_predictions.open_ai_compatible_api.api_url"),
518            }),
519            metadata: Some(Box::new(SettingsFieldMetadata {
520                placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
521                ..Default::default()
522            })),
523            files: USER,
524        }),
525        SettingsPageItem::SettingItem(SettingItem {
526            title: "Model",
527            description: "The model string to pass to the OpenAI-compatible server.",
528            field: Box::new(SettingField {
529                pick: |settings| {
530                    settings
531                        .project
532                        .all_languages
533                        .edit_predictions
534                        .as_ref()?
535                        .open_ai_compatible_api
536                        .as_ref()?
537                        .model
538                        .as_ref()
539                },
540                write: |settings, value| {
541                    settings
542                        .project
543                        .all_languages
544                        .edit_predictions
545                        .get_or_insert_default()
546                        .open_ai_compatible_api
547                        .get_or_insert_default()
548                        .model = value;
549                },
550                json_path: Some("edit_predictions.open_ai_compatible_api.model"),
551            }),
552            metadata: Some(Box::new(SettingsFieldMetadata {
553                placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
554                ..Default::default()
555            })),
556            files: USER,
557        }),
558        SettingsPageItem::SettingItem(SettingItem {
559            title: "Prompt Format",
560            description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
561            field: Box::new(SettingField {
562                pick: |settings| {
563                    settings
564                        .project
565                        .all_languages
566                        .edit_predictions
567                        .as_ref()?
568                        .open_ai_compatible_api
569                        .as_ref()?
570                        .prompt_format
571                        .as_ref()
572                },
573                write: |settings, value| {
574                    settings
575                        .project
576                        .all_languages
577                        .edit_predictions
578                        .get_or_insert_default()
579                        .open_ai_compatible_api
580                        .get_or_insert_default()
581                        .prompt_format = value;
582                },
583                json_path: Some("edit_predictions.open_ai_compatible_api.prompt_format"),
584            }),
585            files: USER,
586            metadata: None,
587        }),
588        SettingsPageItem::SettingItem(SettingItem {
589            title: "Max Output Tokens",
590            description: "The maximum number of tokens to generate.",
591            field: Box::new(SettingField {
592                pick: |settings| {
593                    settings
594                        .project
595                        .all_languages
596                        .edit_predictions
597                        .as_ref()?
598                        .open_ai_compatible_api
599                        .as_ref()?
600                        .max_output_tokens
601                        .as_ref()
602                },
603                write: |settings, value| {
604                    settings
605                        .project
606                        .all_languages
607                        .edit_predictions
608                        .get_or_insert_default()
609                        .open_ai_compatible_api
610                        .get_or_insert_default()
611                        .max_output_tokens = value;
612                },
613                json_path: Some("edit_predictions.open_ai_compatible_api.max_output_tokens"),
614            }),
615            metadata: None,
616            files: USER,
617        }),
618    ])
619}
620
621fn codestral_settings() -> Box<[SettingsPageItem]> {
622    Box::new([
623        SettingsPageItem::SettingItem(SettingItem {
624            title: "API URL",
625            description: "The API URL to use for Codestral.",
626            field: Box::new(SettingField {
627                pick: |settings| {
628                    settings
629                        .project
630                        .all_languages
631                        .edit_predictions
632                        .as_ref()?
633                        .codestral
634                        .as_ref()?
635                        .api_url
636                        .as_ref()
637                },
638                write: |settings, value| {
639                    settings
640                        .project
641                        .all_languages
642                        .edit_predictions
643                        .get_or_insert_default()
644                        .codestral
645                        .get_or_insert_default()
646                        .api_url = value;
647                },
648                json_path: Some("edit_predictions.codestral.api_url"),
649            }),
650            metadata: Some(Box::new(SettingsFieldMetadata {
651                placeholder: Some(CODESTRAL_API_URL),
652                ..Default::default()
653            })),
654            files: USER,
655        }),
656        SettingsPageItem::SettingItem(SettingItem {
657            title: "Max Tokens",
658            description: "The maximum number of tokens to generate.",
659            field: Box::new(SettingField {
660                pick: |settings| {
661                    settings
662                        .project
663                        .all_languages
664                        .edit_predictions
665                        .as_ref()?
666                        .codestral
667                        .as_ref()?
668                        .max_tokens
669                        .as_ref()
670                },
671                write: |settings, value| {
672                    settings
673                        .project
674                        .all_languages
675                        .edit_predictions
676                        .get_or_insert_default()
677                        .codestral
678                        .get_or_insert_default()
679                        .max_tokens = value;
680                },
681                json_path: Some("edit_predictions.codestral.max_tokens"),
682            }),
683            metadata: None,
684            files: USER,
685        }),
686        SettingsPageItem::SettingItem(SettingItem {
687            title: "Model",
688            description: "The Codestral model id to use.",
689            field: Box::new(SettingField {
690                pick: |settings| {
691                    settings
692                        .project
693                        .all_languages
694                        .edit_predictions
695                        .as_ref()?
696                        .codestral
697                        .as_ref()?
698                        .model
699                        .as_ref()
700                },
701                write: |settings, value| {
702                    settings
703                        .project
704                        .all_languages
705                        .edit_predictions
706                        .get_or_insert_default()
707                        .codestral
708                        .get_or_insert_default()
709                        .model = value;
710                },
711                json_path: Some("edit_predictions.codestral.model"),
712            }),
713            metadata: Some(Box::new(SettingsFieldMetadata {
714                placeholder: Some("codestral-latest"),
715                ..Default::default()
716            })),
717            files: USER,
718        }),
719    ])
720}
721
722fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
723    let configuration_view = window.use_state(cx, |_, cx| {
724        copilot_ui::ConfigurationView::new(
725            move |cx| {
726                let app_state = AppState::global(cx);
727                copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
728                    .is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
729            },
730            copilot_ui::ConfigurationMode::EditPrediction,
731            cx,
732        )
733    });
734
735    Some(
736        v_flex()
737            .id("github-copilot")
738            .min_w_0()
739            .pt_8()
740            .gap_1p5()
741            .child(
742                SettingsSectionHeader::new("GitHub Copilot")
743                    .icon(IconName::Copilot)
744                    .no_padding(true),
745            )
746            .child(configuration_view),
747    )
748}