edit_prediction_provider_setup.rs

  1use edit_prediction::{
  2    ApiKeyState, MercuryFeatureFlag, SweepFeatureFlag,
  3    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  4    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  5};
  6use edit_prediction_ui::{get_available_providers, set_completion_provider};
  7use feature_flags::FeatureFlagAppExt as _;
  8use gpui::{Entity, ScrollHandle, prelude::*};
  9use language::language_settings::AllLanguageSettings;
 10use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
 11use settings::Settings as _;
 12use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*};
 13use workspace::AppState;
 14
 15use crate::{
 16    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 17    components::{SettingsInputField, SettingsSectionHeader},
 18};
 19
 20pub(crate) fn render_edit_prediction_setup_page(
 21    settings_window: &SettingsWindow,
 22    scroll_handle: &ScrollHandle,
 23    window: &mut Window,
 24    cx: &mut Context<SettingsWindow>,
 25) -> AnyElement {
 26    let providers = [
 27        Some(render_provider_dropdown(window, cx)),
 28        render_github_copilot_provider(window, cx).map(IntoElement::into_any_element),
 29        cx.has_flag::<MercuryFeatureFlag>().then(|| {
 30            render_api_key_provider(
 31                IconName::Inception,
 32                "Mercury",
 33                "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 34                mercury_api_token(cx),
 35                |_cx| MERCURY_CREDENTIALS_URL,
 36                None,
 37                window,
 38                cx,
 39            )
 40            .into_any_element()
 41        }),
 42        cx.has_flag::<SweepFeatureFlag>().then(|| {
 43            render_api_key_provider(
 44                IconName::SweepAi,
 45                "Sweep",
 46                "https://app.sweep.dev/".into(),
 47                sweep_api_token(cx),
 48                |_cx| SWEEP_CREDENTIALS_URL,
 49                None,
 50                window,
 51                cx,
 52            )
 53            .into_any_element()
 54        }),
 55        Some(
 56            render_api_key_provider(
 57                IconName::AiMistral,
 58                "Codestral",
 59                "https://console.mistral.ai/codestral".into(),
 60                codestral_api_key(cx),
 61                |cx| language_models::MistralLanguageModelProvider::api_url(cx),
 62                Some(
 63                    settings_window
 64                        .render_sub_page_items_section(
 65                            codestral_settings().iter().enumerate(),
 66                            window,
 67                            cx,
 68                        )
 69                        .into_any_element(),
 70                ),
 71                window,
 72                cx,
 73            )
 74            .into_any_element(),
 75        ),
 76    ];
 77
 78    div()
 79        .size_full()
 80        .child(
 81            v_flex()
 82                .id("ep-setup-page")
 83                .min_w_0()
 84                .size_full()
 85                .px_8()
 86                .pb_16()
 87                .overflow_y_scroll()
 88                .track_scroll(&scroll_handle)
 89                .children(providers.into_iter().flatten()),
 90        )
 91        .into_any_element()
 92}
 93
 94fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
 95    let current_provider = AllLanguageSettings::get_global(cx)
 96        .edit_predictions
 97        .provider;
 98    let current_provider_name = current_provider.display_name().unwrap_or("No provider set");
 99
100    let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| {
101        let available_providers = get_available_providers(cx);
102        let fs = <dyn fs::Fs>::global(cx);
103
104        for provider in available_providers {
105            let Some(name) = provider.display_name() else {
106                continue;
107            };
108            let is_current = provider == current_provider;
109
110            menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, {
111                let fs = fs.clone();
112                move |_, cx| {
113                    set_completion_provider(fs.clone(), cx, provider);
114                }
115            });
116        }
117        menu
118    });
119
120    v_flex()
121        .id("provider-selector")
122        .min_w_0()
123        .gap_1p5()
124        .child(
125            SettingsSectionHeader::new("Active Provider")
126                .icon(IconName::Sparkle)
127                .no_padding(true),
128        )
129        .child(
130            h_flex()
131                .pt_2p5()
132                .w_full()
133                .justify_between()
134                .child(
135                    v_flex()
136                        .w_full()
137                        .max_w_1_2()
138                        .child(Label::new("Provider"))
139                        .child(
140                            Label::new("Select which provider to use for edit predictions.")
141                                .size(LabelSize::Small)
142                                .color(Color::Muted),
143                        ),
144                )
145                .child(
146                    DropdownMenu::new("provider-dropdown", current_provider_name, menu)
147                        .style(DropdownStyle::Outlined),
148                ),
149        )
150        .into_any_element()
151}
152
153fn render_api_key_provider(
154    icon: IconName,
155    title: &'static str,
156    link: SharedString,
157    api_key_state: Entity<ApiKeyState>,
158    current_url: fn(&mut App) -> SharedString,
159    additional_fields: Option<AnyElement>,
160    window: &mut Window,
161    cx: &mut Context<SettingsWindow>,
162) -> impl IntoElement {
163    let weak_page = cx.weak_entity();
164    _ = window.use_keyed_state(title, cx, |_, cx| {
165        let task = api_key_state.update(cx, |key_state, cx| {
166            key_state.load_if_needed(current_url(cx), |state| state, cx)
167        });
168        cx.spawn(async move |_, cx| {
169            task.await.ok();
170            weak_page
171                .update(cx, |_, cx| {
172                    cx.notify();
173                })
174                .ok();
175        })
176    });
177
178    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
179        (
180            state.has_key(),
181            Some(state.env_var_name().clone()),
182            state.is_from_env_var(),
183        )
184    });
185
186    let write_key = move |api_key: Option<String>, cx: &mut App| {
187        api_key_state
188            .update(cx, |key_state, cx| {
189                let url = current_url(cx);
190                key_state.store(url, api_key, |key_state| key_state, cx)
191            })
192            .detach_and_log_err(cx);
193    };
194
195    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
196    let header = SettingsSectionHeader::new(title)
197        .icon(icon)
198        .no_padding(true);
199    let button_link_label = format!("{} dashboard", title);
200    let description = h_flex()
201        .min_w_0()
202        .gap_0p5()
203        .child(
204            Label::new("Visit the")
205                .size(LabelSize::Small)
206                .color(Color::Muted),
207        )
208        .child(
209            ButtonLink::new(button_link_label, link)
210                .no_icon(true)
211                .label_size(LabelSize::Small)
212                .label_color(Color::Muted),
213        )
214        .child(
215            Label::new("to generate an API key.")
216                .size(LabelSize::Small)
217                .color(Color::Muted),
218        );
219    let configured_card_label = if is_from_env_var {
220        "API Key Set in Environment Variable"
221    } else {
222        "API Key Configured"
223    };
224
225    let container = if has_key {
226        base_container.child(header).child(
227            ConfiguredApiCard::new(configured_card_label)
228                .button_label("Reset Key")
229                .button_tab_index(0)
230                .disabled(is_from_env_var)
231                .when_some(env_var_name, |this, env_var_name| {
232                    this.when(is_from_env_var, |this| {
233                        this.tooltip_label(format!(
234                            "To reset your API key, unset the {} environment variable.",
235                            env_var_name
236                        ))
237                    })
238                })
239                .on_click(move |_, _, cx| {
240                    write_key(None, cx);
241                }),
242        )
243    } else {
244        base_container.child(header).child(
245            h_flex()
246                .pt_2p5()
247                .w_full()
248                .justify_between()
249                .child(
250                    v_flex()
251                        .w_full()
252                        .max_w_1_2()
253                        .child(Label::new("API Key"))
254                        .child(description)
255                        .when_some(env_var_name, |this, env_var_name| {
256                            this.child({
257                                let label = format!(
258                                    "Or set the {} env var and restart Zed.",
259                                    env_var_name.as_ref()
260                                );
261                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
262                            })
263                        }),
264                )
265                .child(
266                    SettingsInputField::new()
267                        .tab_index(0)
268                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
269                        .on_confirm(move |api_key, _window, cx| {
270                            write_key(api_key.filter(|key| !key.is_empty()), cx);
271                        }),
272                ),
273        )
274    };
275
276    container.when_some(additional_fields, |this, additional_fields| {
277        this.child(
278            div()
279                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
280                .px_neg_8()
281                .border_t_1()
282                .border_color(cx.theme().colors().border_variant)
283                .child(additional_fields),
284        )
285    })
286}
287
288fn codestral_settings() -> Box<[SettingsPageItem]> {
289    Box::new([
290        SettingsPageItem::SettingItem(SettingItem {
291            title: "API URL",
292            description: "The API URL to use for Codestral.",
293            field: Box::new(SettingField {
294                pick: |settings| {
295                    settings
296                        .project
297                        .all_languages
298                        .edit_predictions
299                        .as_ref()?
300                        .codestral
301                        .as_ref()?
302                        .api_url
303                        .as_ref()
304                },
305                write: |settings, value| {
306                    settings
307                        .project
308                        .all_languages
309                        .edit_predictions
310                        .get_or_insert_default()
311                        .codestral
312                        .get_or_insert_default()
313                        .api_url = value;
314                },
315                json_path: Some("edit_predictions.codestral.api_url"),
316            }),
317            metadata: Some(Box::new(SettingsFieldMetadata {
318                placeholder: Some(CODESTRAL_API_URL),
319                ..Default::default()
320            })),
321            files: USER,
322        }),
323        SettingsPageItem::SettingItem(SettingItem {
324            title: "Max Tokens",
325            description: "The maximum number of tokens to generate.",
326            field: Box::new(SettingField {
327                pick: |settings| {
328                    settings
329                        .project
330                        .all_languages
331                        .edit_predictions
332                        .as_ref()?
333                        .codestral
334                        .as_ref()?
335                        .max_tokens
336                        .as_ref()
337                },
338                write: |settings, value| {
339                    settings
340                        .project
341                        .all_languages
342                        .edit_predictions
343                        .get_or_insert_default()
344                        .codestral
345                        .get_or_insert_default()
346                        .max_tokens = value;
347                },
348                json_path: Some("edit_predictions.codestral.max_tokens"),
349            }),
350            metadata: None,
351            files: USER,
352        }),
353        SettingsPageItem::SettingItem(SettingItem {
354            title: "Model",
355            description: "The Codestral model id to use.",
356            field: Box::new(SettingField {
357                pick: |settings| {
358                    settings
359                        .project
360                        .all_languages
361                        .edit_predictions
362                        .as_ref()?
363                        .codestral
364                        .as_ref()?
365                        .model
366                        .as_ref()
367                },
368                write: |settings, value| {
369                    settings
370                        .project
371                        .all_languages
372                        .edit_predictions
373                        .get_or_insert_default()
374                        .codestral
375                        .get_or_insert_default()
376                        .model = value;
377                },
378                json_path: Some("edit_predictions.codestral.model"),
379            }),
380            metadata: Some(Box::new(SettingsFieldMetadata {
381                placeholder: Some("codestral-latest"),
382                ..Default::default()
383            })),
384            files: USER,
385        }),
386    ])
387}
388
389fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option<impl IntoElement> {
390    let configuration_view = window.use_state(cx, |_, cx| {
391        copilot_ui::ConfigurationView::new(
392            move |cx| {
393                if let Some(app_state) = AppState::global(cx).upgrade() {
394                    let copilot = copilot::GlobalCopilotAuth::get_or_init(app_state, cx);
395                    copilot.0.read(cx).is_authenticated()
396                } else {
397                    false
398                }
399            },
400            copilot_ui::ConfigurationMode::EditPrediction,
401            cx,
402        )
403    });
404
405    Some(
406        v_flex()
407            .id("github-copilot")
408            .min_w_0()
409            .pt_8()
410            .gap_1p5()
411            .child(
412                SettingsSectionHeader::new("GitHub Copilot")
413                    .icon(IconName::Copilot)
414                    .no_padding(true),
415            )
416            .child(configuration_view),
417    )
418}