edit_prediction_provider_setup.rs

  1use edit_prediction::{
  2    ApiKeyState, EditPredictionStore, MercuryFeatureFlag, SweepFeatureFlag,
  3    mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
  4    sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
  5};
  6use feature_flags::FeatureFlagAppExt as _;
  7use gpui::{Entity, ScrollHandle, prelude::*};
  8use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
  9use project::Project;
 10use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*};
 11
 12use crate::{
 13    SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
 14    components::{SettingsInputField, SettingsSectionHeader},
 15};
 16
 17pub struct EditPredictionSetupPage {
 18    settings_window: Entity<SettingsWindow>,
 19    scroll_handle: ScrollHandle,
 20}
 21
 22impl EditPredictionSetupPage {
 23    pub fn new(settings_window: Entity<SettingsWindow>) -> Self {
 24        Self {
 25            settings_window,
 26            scroll_handle: ScrollHandle::new(),
 27        }
 28    }
 29}
 30
 31impl Render for EditPredictionSetupPage {
 32    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 33        let settings_window = self.settings_window.clone();
 34        let project = settings_window
 35            .read(cx)
 36            .original_window
 37            .as_ref()
 38            .and_then(|window| {
 39                window
 40                    .read_with(cx, |workspace, _| workspace.project().clone())
 41                    .ok()
 42            });
 43        let providers = [
 44            project.and_then(|project| {
 45                Some(render_github_copilot_provider(project, window, cx)?.into_any_element())
 46            }),
 47            cx.has_flag::<MercuryFeatureFlag>().then(|| {
 48                render_api_key_provider(
 49                    IconName::Inception,
 50                    "Mercury",
 51                    "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
 52                    mercury_api_token(cx),
 53                    |_cx| MERCURY_CREDENTIALS_URL,
 54                    None,
 55                    window,
 56                    cx,
 57                )
 58                .into_any_element()
 59            }),
 60            cx.has_flag::<SweepFeatureFlag>().then(|| {
 61                render_api_key_provider(
 62                    IconName::SweepAi,
 63                    "Sweep",
 64                    "https://app.sweep.dev/".into(),
 65                    sweep_api_token(cx),
 66                    |_cx| SWEEP_CREDENTIALS_URL,
 67                    None,
 68                    window,
 69                    cx,
 70                )
 71                .into_any_element()
 72            }),
 73            Some(
 74                render_api_key_provider(
 75                    IconName::AiMistral,
 76                    "Codestral",
 77                    "https://console.mistral.ai/codestral".into(),
 78                    codestral_api_key(cx),
 79                    |cx| language_models::MistralLanguageModelProvider::api_url(cx),
 80                    Some(settings_window.update(cx, |settings_window, cx| {
 81                        let codestral_settings = codestral_settings();
 82                        settings_window
 83                            .render_sub_page_items_section(
 84                                codestral_settings.iter().enumerate(),
 85                                None,
 86                                window,
 87                                cx,
 88                            )
 89                            .into_any_element()
 90                    })),
 91                    window,
 92                    cx,
 93                )
 94                .into_any_element(),
 95            ),
 96        ];
 97
 98        div()
 99            .size_full()
100            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
101            .child(
102                v_flex()
103                    .id("ep-setup-page")
104                    .min_w_0()
105                    .size_full()
106                    .px_8()
107                    .pb_16()
108                    .overflow_y_scroll()
109                    .track_scroll(&self.scroll_handle)
110                    .children(providers.into_iter().flatten()),
111            )
112    }
113}
114
115fn render_api_key_provider(
116    icon: IconName,
117    title: &'static str,
118    link: SharedString,
119    api_key_state: Entity<ApiKeyState>,
120    current_url: fn(&mut App) -> SharedString,
121    additional_fields: Option<AnyElement>,
122    window: &mut Window,
123    cx: &mut Context<EditPredictionSetupPage>,
124) -> impl IntoElement {
125    let weak_page = cx.weak_entity();
126    _ = window.use_keyed_state(title, cx, |_, cx| {
127        let task = api_key_state.update(cx, |key_state, cx| {
128            key_state.load_if_needed(current_url(cx), |state| state, cx)
129        });
130        cx.spawn(async move |_, cx| {
131            task.await.ok();
132            weak_page
133                .update(cx, |_, cx| {
134                    cx.notify();
135                })
136                .ok();
137        })
138    });
139
140    let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
141        (
142            state.has_key(),
143            Some(state.env_var_name().clone()),
144            state.is_from_env_var(),
145        )
146    });
147
148    let write_key = move |api_key: Option<String>, cx: &mut App| {
149        api_key_state
150            .update(cx, |key_state, cx| {
151                let url = current_url(cx);
152                key_state.store(url, api_key, |key_state| key_state, cx)
153            })
154            .detach_and_log_err(cx);
155    };
156
157    let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
158    let header = SettingsSectionHeader::new(title)
159        .icon(icon)
160        .no_padding(true);
161    let button_link_label = format!("{} dashboard", title);
162    let description = h_flex()
163        .min_w_0()
164        .gap_0p5()
165        .child(
166            Label::new("Visit the")
167                .size(LabelSize::Small)
168                .color(Color::Muted),
169        )
170        .child(
171            ButtonLink::new(button_link_label, link)
172                .no_icon(true)
173                .label_size(LabelSize::Small)
174                .label_color(Color::Muted),
175        )
176        .child(
177            Label::new("to generate an API key.")
178                .size(LabelSize::Small)
179                .color(Color::Muted),
180        );
181    let configured_card_label = if is_from_env_var {
182        "API Key Set in Environment Variable"
183    } else {
184        "API Key Configured"
185    };
186
187    let container = if has_key {
188        base_container.child(header).child(
189            ConfiguredApiCard::new(configured_card_label)
190                .button_label("Reset Key")
191                .button_tab_index(0)
192                .disabled(is_from_env_var)
193                .when_some(env_var_name, |this, env_var_name| {
194                    this.when(is_from_env_var, |this| {
195                        this.tooltip_label(format!(
196                            "To reset your API key, unset the {} environment variable.",
197                            env_var_name
198                        ))
199                    })
200                })
201                .on_click(move |_, _, cx| {
202                    write_key(None, cx);
203                }),
204        )
205    } else {
206        base_container.child(header).child(
207            h_flex()
208                .pt_2p5()
209                .w_full()
210                .justify_between()
211                .child(
212                    v_flex()
213                        .w_full()
214                        .max_w_1_2()
215                        .child(Label::new("API Key"))
216                        .child(description)
217                        .when_some(env_var_name, |this, env_var_name| {
218                            this.child({
219                                let label = format!(
220                                    "Or set the {} env var and restart Zed.",
221                                    env_var_name.as_ref()
222                                );
223                                Label::new(label).size(LabelSize::Small).color(Color::Muted)
224                            })
225                        }),
226                )
227                .child(
228                    SettingsInputField::new()
229                        .tab_index(0)
230                        .with_placeholder("xxxxxxxxxxxxxxxxxxxx")
231                        .on_confirm(move |api_key, cx| {
232                            write_key(api_key.filter(|key| !key.is_empty()), cx);
233                        }),
234                ),
235        )
236    };
237
238    container.when_some(additional_fields, |this, additional_fields| {
239        this.child(
240            div()
241                .map(|this| if has_key { this.mt_1() } else { this.mt_4() })
242                .px_neg_8()
243                .border_t_1()
244                .border_color(cx.theme().colors().border_variant)
245                .child(additional_fields),
246        )
247    })
248}
249
250fn codestral_settings() -> Box<[SettingsPageItem]> {
251    Box::new([
252        SettingsPageItem::SettingItem(SettingItem {
253            title: "API URL",
254            description: "The API URL to use for Codestral.",
255            field: Box::new(SettingField {
256                pick: |settings| {
257                    settings
258                        .project
259                        .all_languages
260                        .edit_predictions
261                        .as_ref()?
262                        .codestral
263                        .as_ref()?
264                        .api_url
265                        .as_ref()
266                },
267                write: |settings, value| {
268                    settings
269                        .project
270                        .all_languages
271                        .edit_predictions
272                        .get_or_insert_default()
273                        .codestral
274                        .get_or_insert_default()
275                        .api_url = value;
276                },
277                json_path: Some("edit_predictions.codestral.api_url"),
278            }),
279            metadata: Some(Box::new(SettingsFieldMetadata {
280                placeholder: Some(CODESTRAL_API_URL),
281                ..Default::default()
282            })),
283            files: USER,
284        }),
285        SettingsPageItem::SettingItem(SettingItem {
286            title: "Max Tokens",
287            description: "The maximum number of tokens to generate.",
288            field: Box::new(SettingField {
289                pick: |settings| {
290                    settings
291                        .project
292                        .all_languages
293                        .edit_predictions
294                        .as_ref()?
295                        .codestral
296                        .as_ref()?
297                        .max_tokens
298                        .as_ref()
299                },
300                write: |settings, value| {
301                    settings
302                        .project
303                        .all_languages
304                        .edit_predictions
305                        .get_or_insert_default()
306                        .codestral
307                        .get_or_insert_default()
308                        .max_tokens = value;
309                },
310                json_path: Some("edit_predictions.codestral.max_tokens"),
311            }),
312            metadata: None,
313            files: USER,
314        }),
315        SettingsPageItem::SettingItem(SettingItem {
316            title: "Model",
317            description: "The Codestral model id to use.",
318            field: Box::new(SettingField {
319                pick: |settings| {
320                    settings
321                        .project
322                        .all_languages
323                        .edit_predictions
324                        .as_ref()?
325                        .codestral
326                        .as_ref()?
327                        .model
328                        .as_ref()
329                },
330                write: |settings, value| {
331                    settings
332                        .project
333                        .all_languages
334                        .edit_predictions
335                        .get_or_insert_default()
336                        .codestral
337                        .get_or_insert_default()
338                        .model = value;
339                },
340                json_path: Some("edit_predictions.codestral.model"),
341            }),
342            metadata: Some(Box::new(SettingsFieldMetadata {
343                placeholder: Some("codestral-latest"),
344                ..Default::default()
345            })),
346            files: USER,
347        }),
348    ])
349}
350
351fn render_github_copilot_provider(
352    project: Entity<Project>,
353    window: &mut Window,
354    cx: &mut App,
355) -> Option<impl IntoElement> {
356    let copilot = EditPredictionStore::try_global(cx)?
357        .read(cx)
358        .copilot_for_project(&project);
359    let configuration_view = window.use_state(cx, |_, cx| {
360        copilot_ui::ConfigurationView::new(
361            move |cx| {
362                copilot
363                    .as_ref()
364                    .is_some_and(|copilot| copilot.read(cx).is_authenticated())
365            },
366            copilot_ui::ConfigurationMode::EditPrediction,
367            cx,
368        )
369    });
370
371    Some(
372        v_flex()
373            .id("github-copilot")
374            .min_w_0()
375            .gap_1p5()
376            .child(
377                SettingsSectionHeader::new("GitHub Copilot")
378                    .icon(IconName::Copilot)
379                    .no_padding(true),
380            )
381            .child(configuration_view),
382    )
383}