add_llm_provider_modal.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashSet;
  5use fs::Fs;
  6use gpui::{
  7    DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
  8};
  9use language_model::LanguageModelRegistry;
 10use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
 11use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
 12use ui::{
 13    Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
 14    WithScrollbar, prelude::*,
 15};
 16use ui_input::InputField;
 17use workspace::{ModalView, Workspace};
 18
 19fn single_line_input(
 20    label: impl Into<SharedString>,
 21    placeholder: impl Into<SharedString>,
 22    text: Option<&str>,
 23    tab_index: isize,
 24    window: &mut Window,
 25    cx: &mut App,
 26) -> Entity<InputField> {
 27    cx.new(|cx| {
 28        let input = InputField::new(window, cx, placeholder)
 29            .label(label)
 30            .tab_index(tab_index)
 31            .tab_stop(true);
 32
 33        if let Some(text) = text {
 34            input
 35                .editor()
 36                .update(cx, |editor, cx| editor.set_text(text, window, cx));
 37        }
 38        input
 39    })
 40}
 41
 42#[derive(Clone, Copy)]
 43pub enum LlmCompatibleProvider {
 44    OpenAi,
 45}
 46
 47impl LlmCompatibleProvider {
 48    fn name(&self) -> &'static str {
 49        match self {
 50            LlmCompatibleProvider::OpenAi => "OpenAI",
 51        }
 52    }
 53
 54    fn api_url(&self) -> &'static str {
 55        match self {
 56            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
 57        }
 58    }
 59}
 60
 61struct AddLlmProviderInput {
 62    provider_name: Entity<InputField>,
 63    api_url: Entity<InputField>,
 64    api_key: Entity<InputField>,
 65    models: Vec<ModelInput>,
 66}
 67
 68impl AddLlmProviderInput {
 69    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
 70        let provider_name =
 71            single_line_input("Provider Name", provider.name(), None, 1, window, cx);
 72        let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
 73        let api_key = single_line_input(
 74            "API Key",
 75            "000000000000000000000000000000000000000000000000",
 76            None,
 77            3,
 78            window,
 79            cx,
 80        );
 81
 82        Self {
 83            provider_name,
 84            api_url,
 85            api_key,
 86            models: vec![ModelInput::new(0, window, cx)],
 87        }
 88    }
 89
 90    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
 91        let model_index = self.models.len();
 92        self.models.push(ModelInput::new(model_index, window, cx));
 93    }
 94
 95    fn remove_model(&mut self, index: usize) {
 96        self.models.remove(index);
 97    }
 98}
 99
100struct ModelCapabilityToggles {
101    pub supports_tools: ToggleState,
102    pub supports_images: ToggleState,
103    pub supports_parallel_tool_calls: ToggleState,
104    pub supports_prompt_cache_key: ToggleState,
105    pub supports_chat_completions: ToggleState,
106}
107
108struct ModelInput {
109    name: Entity<InputField>,
110    max_completion_tokens: Entity<InputField>,
111    max_output_tokens: Entity<InputField>,
112    max_tokens: Entity<InputField>,
113    capabilities: ModelCapabilityToggles,
114}
115
116impl ModelInput {
117    fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
118        let base_tab_index = (3 + (model_index * 4)) as isize;
119
120        let model_name = single_line_input(
121            "Model Name",
122            "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
123            None,
124            base_tab_index + 1,
125            window,
126            cx,
127        );
128        let max_completion_tokens = single_line_input(
129            "Max Completion Tokens",
130            "200000",
131            Some("200000"),
132            base_tab_index + 2,
133            window,
134            cx,
135        );
136        let max_output_tokens = single_line_input(
137            "Max Output Tokens",
138            "Max Output Tokens",
139            Some("32000"),
140            base_tab_index + 3,
141            window,
142            cx,
143        );
144        let max_tokens = single_line_input(
145            "Max Tokens",
146            "Max Tokens",
147            Some("200000"),
148            base_tab_index + 4,
149            window,
150            cx,
151        );
152
153        let ModelCapabilities {
154            tools,
155            images,
156            parallel_tool_calls,
157            prompt_cache_key,
158            chat_completions,
159        } = ModelCapabilities::default();
160
161        Self {
162            name: model_name,
163            max_completion_tokens,
164            max_output_tokens,
165            max_tokens,
166            capabilities: ModelCapabilityToggles {
167                supports_tools: tools.into(),
168                supports_images: images.into(),
169                supports_parallel_tool_calls: parallel_tool_calls.into(),
170                supports_prompt_cache_key: prompt_cache_key.into(),
171                supports_chat_completions: chat_completions.into(),
172            },
173        }
174    }
175
176    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
177        let name = self.name.read(cx).text(cx);
178        if name.is_empty() {
179            return Err(SharedString::from("Model Name cannot be empty"));
180        }
181        Ok(AvailableModel {
182            name,
183            display_name: None,
184            max_completion_tokens: Some(
185                self.max_completion_tokens
186                    .read(cx)
187                    .text(cx)
188                    .parse::<u64>()
189                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
190            ),
191            max_output_tokens: Some(
192                self.max_output_tokens
193                    .read(cx)
194                    .text(cx)
195                    .parse::<u64>()
196                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
197            ),
198            max_tokens: self
199                .max_tokens
200                .read(cx)
201                .text(cx)
202                .parse::<u64>()
203                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
204            capabilities: ModelCapabilities {
205                tools: self.capabilities.supports_tools.selected(),
206                images: self.capabilities.supports_images.selected(),
207                parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
208                prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
209                chat_completions: self.capabilities.supports_chat_completions.selected(),
210            },
211        })
212    }
213}
214
215fn save_provider_to_settings(
216    input: &AddLlmProviderInput,
217    cx: &mut App,
218) -> Task<Result<(), SharedString>> {
219    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
220    if provider_name.is_empty() {
221        return Task::ready(Err("Provider Name cannot be empty".into()));
222    }
223
224    if LanguageModelRegistry::read_global(cx)
225        .providers()
226        .iter()
227        .any(|provider| {
228            provider.id().0.as_ref() == provider_name.as_ref()
229                || provider.name().0.as_ref() == provider_name.as_ref()
230        })
231    {
232        return Task::ready(Err(
233            "Provider Name is already taken by another provider".into()
234        ));
235    }
236
237    let api_url = input.api_url.read(cx).text(cx);
238    if api_url.is_empty() {
239        return Task::ready(Err("API URL cannot be empty".into()));
240    }
241
242    let api_key = input.api_key.read(cx).text(cx);
243    if api_key.is_empty() {
244        return Task::ready(Err("API Key cannot be empty".into()));
245    }
246
247    let mut models = Vec::new();
248    let mut model_names: HashSet<String> = HashSet::default();
249    for model in &input.models {
250        match model.parse(cx) {
251            Ok(model) => {
252                if !model_names.insert(model.name.clone()) {
253                    return Task::ready(Err("Model Names must be unique".into()));
254                }
255                models.push(model)
256            }
257            Err(err) => return Task::ready(Err(err)),
258        }
259    }
260
261    let fs = <dyn Fs>::global(cx);
262    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
263    cx.spawn(async move |cx| {
264        task.await
265            .map_err(|_| SharedString::from("Failed to write API key to keychain"))?;
266        cx.update(|cx| {
267            update_settings_file(fs, cx, |settings, _cx| {
268                settings
269                    .language_models
270                    .get_or_insert_default()
271                    .openai_compatible
272                    .get_or_insert_default()
273                    .insert(
274                        provider_name,
275                        OpenAiCompatibleSettingsContent {
276                            api_url,
277                            available_models: models,
278                        },
279                    );
280            });
281        });
282        Ok(())
283    })
284}
285
286pub struct AddLlmProviderModal {
287    provider: LlmCompatibleProvider,
288    input: AddLlmProviderInput,
289    scroll_handle: ScrollHandle,
290    focus_handle: FocusHandle,
291    last_error: Option<SharedString>,
292}
293
294impl AddLlmProviderModal {
295    pub fn toggle(
296        provider: LlmCompatibleProvider,
297        workspace: &mut Workspace,
298        window: &mut Window,
299        cx: &mut Context<Workspace>,
300    ) {
301        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
302    }
303
304    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
305        Self {
306            input: AddLlmProviderInput::new(provider, window, cx),
307            provider,
308            last_error: None,
309            focus_handle: cx.focus_handle(),
310            scroll_handle: ScrollHandle::new(),
311        }
312    }
313
314    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
315        let task = save_provider_to_settings(&self.input, cx);
316        cx.spawn(async move |this, cx| {
317            let result = task.await;
318            this.update(cx, |this, cx| match result {
319                Ok(_) => {
320                    cx.emit(DismissEvent);
321                }
322                Err(error) => {
323                    this.last_error = Some(error);
324                    cx.notify();
325                }
326            })
327        })
328        .detach_and_log_err(cx);
329    }
330
331    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
332        cx.emit(DismissEvent);
333    }
334
335    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
336        v_flex()
337            .mt_1()
338            .gap_2()
339            .child(
340                h_flex()
341                    .justify_between()
342                    .child(Label::new("Models").size(LabelSize::Small))
343                    .child(
344                        Button::new("add-model", "Add Model")
345                            .icon(IconName::Plus)
346                            .icon_position(IconPosition::Start)
347                            .icon_size(IconSize::XSmall)
348                            .icon_color(Color::Muted)
349                            .label_size(LabelSize::Small)
350                            .on_click(cx.listener(|this, _, window, cx| {
351                                this.input.add_model(window, cx);
352                                cx.notify();
353                            })),
354                    ),
355            )
356            .children(
357                self.input
358                    .models
359                    .iter()
360                    .enumerate()
361                    .map(|(ix, _)| self.render_model(ix, cx)),
362            )
363    }
364
365    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
366        let has_more_than_one_model = self.input.models.len() > 1;
367        let model = &self.input.models[ix];
368
369        v_flex()
370            .p_2()
371            .gap_2()
372            .rounded_sm()
373            .border_1()
374            .border_dashed()
375            .border_color(cx.theme().colors().border.opacity(0.6))
376            .bg(cx.theme().colors().element_active.opacity(0.15))
377            .child(model.name.clone())
378            .child(
379                h_flex()
380                    .gap_2()
381                    .child(model.max_completion_tokens.clone())
382                    .child(model.max_output_tokens.clone()),
383            )
384            .child(model.max_tokens.clone())
385            .child(
386                v_flex()
387                    .gap_1()
388                    .child(
389                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
390                            .label("Supports tools")
391                            .on_click(cx.listener(move |this, checked, _window, cx| {
392                                this.input.models[ix].capabilities.supports_tools = *checked;
393                                cx.notify();
394                            })),
395                    )
396                    .child(
397                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
398                            .label("Supports images")
399                            .on_click(cx.listener(move |this, checked, _window, cx| {
400                                this.input.models[ix].capabilities.supports_images = *checked;
401                                cx.notify();
402                            })),
403                    )
404                    .child(
405                        Checkbox::new(
406                            ("supports-parallel-tool-calls", ix),
407                            model.capabilities.supports_parallel_tool_calls,
408                        )
409                        .label("Supports parallel_tool_calls")
410                        .on_click(cx.listener(
411                            move |this, checked, _window, cx| {
412                                this.input.models[ix]
413                                    .capabilities
414                                    .supports_parallel_tool_calls = *checked;
415                                cx.notify();
416                            },
417                        )),
418                    )
419                    .child(
420                        Checkbox::new(
421                            ("supports-prompt-cache-key", ix),
422                            model.capabilities.supports_prompt_cache_key,
423                        )
424                        .label("Supports prompt_cache_key")
425                        .on_click(cx.listener(
426                            move |this, checked, _window, cx| {
427                                this.input.models[ix].capabilities.supports_prompt_cache_key =
428                                    *checked;
429                                cx.notify();
430                            },
431                        )),
432                    )
433                    .child(
434                        Checkbox::new(
435                            ("supports-chat-completions", ix),
436                            model.capabilities.supports_chat_completions,
437                        )
438                        .label("Supports /chat/completions")
439                        .on_click(cx.listener(
440                            move |this, checked, _window, cx| {
441                                this.input.models[ix].capabilities.supports_chat_completions =
442                                    *checked;
443                                cx.notify();
444                            },
445                        )),
446                    ),
447            )
448            .when(has_more_than_one_model, |this| {
449                this.child(
450                    Button::new(("remove-model", ix), "Remove Model")
451                        .icon(IconName::Trash)
452                        .icon_position(IconPosition::Start)
453                        .icon_size(IconSize::XSmall)
454                        .icon_color(Color::Muted)
455                        .label_size(LabelSize::Small)
456                        .style(ButtonStyle::Outlined)
457                        .full_width()
458                        .on_click(cx.listener(move |this, _, _window, cx| {
459                            this.input.remove_model(ix);
460                            cx.notify();
461                        })),
462                )
463            })
464    }
465
466    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
467        window.focus_next(cx);
468    }
469
470    fn on_tab_prev(
471        &mut self,
472        _: &menu::SelectPrevious,
473        window: &mut Window,
474        cx: &mut Context<Self>,
475    ) {
476        window.focus_prev(cx);
477    }
478}
479
480impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
481
482impl Focusable for AddLlmProviderModal {
483    fn focus_handle(&self, _cx: &App) -> FocusHandle {
484        self.focus_handle.clone()
485    }
486}
487
488impl ModalView for AddLlmProviderModal {}
489
490impl Render for AddLlmProviderModal {
491    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
492        let focus_handle = self.focus_handle(cx);
493
494        let window_size = window.viewport_size();
495        let rem_size = window.rem_size();
496        let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
497
498        let modal_max_height = if is_large_window {
499            rems_from_px(450.)
500        } else {
501            rems_from_px(200.)
502        };
503
504        v_flex()
505            .id("add-llm-provider-modal")
506            .key_context("AddLlmProviderModal")
507            .w(rems(34.))
508            .elevation_3(cx)
509            .on_action(cx.listener(Self::cancel))
510            .on_action(cx.listener(Self::on_tab))
511            .on_action(cx.listener(Self::on_tab_prev))
512            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
513                this.focus_handle(cx).focus(window, cx);
514            }))
515            .child(
516                Modal::new("configure-context-server", None)
517                    .header(ModalHeader::new().headline("Add LLM Provider").description(
518                        match self.provider {
519                            LlmCompatibleProvider::OpenAi => {
520                                "This provider will use an OpenAI compatible API."
521                            }
522                        },
523                    ))
524                    .when_some(self.last_error.clone(), |this, error| {
525                        this.section(
526                            Section::new().child(
527                                Banner::new()
528                                    .severity(Severity::Warning)
529                                    .child(div().text_xs().child(error)),
530                            ),
531                        )
532                    })
533                    .child(
534                        div()
535                            .size_full()
536                            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
537                            .child(
538                                v_flex()
539                                    .id("modal_content")
540                                    .size_full()
541                                    .tab_group()
542                                    .max_h(modal_max_height)
543                                    .pl_3()
544                                    .pr_4()
545                                    .gap_2()
546                                    .overflow_y_scroll()
547                                    .track_scroll(&self.scroll_handle)
548                                    .child(self.input.provider_name.clone())
549                                    .child(self.input.api_url.clone())
550                                    .child(self.input.api_key.clone())
551                                    .child(self.render_model_section(cx)),
552                            ),
553                    )
554                    .footer(
555                        ModalFooter::new().end_slot(
556                            h_flex()
557                                .gap_1()
558                                .child(
559                                    Button::new("cancel", "Cancel")
560                                        .key_binding(
561                                            KeyBinding::for_action_in(
562                                                &menu::Cancel,
563                                                &focus_handle,
564                                                cx,
565                                            )
566                                            .map(|kb| kb.size(rems_from_px(12.))),
567                                        )
568                                        .on_click(cx.listener(|this, _event, window, cx| {
569                                            this.cancel(&menu::Cancel, window, cx)
570                                        })),
571                                )
572                                .child(
573                                    Button::new("save-server", "Save Provider")
574                                        .key_binding(
575                                            KeyBinding::for_action_in(
576                                                &menu::Confirm,
577                                                &focus_handle,
578                                                cx,
579                                            )
580                                            .map(|kb| kb.size(rems_from_px(12.))),
581                                        )
582                                        .on_click(cx.listener(|this, _event, window, cx| {
583                                            this.confirm(&menu::Confirm, window, cx)
584                                        })),
585                                ),
586                        ),
587                    ),
588            )
589    }
590}
591
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;

    /// One model row for `save_provider_validation_errors`, in the order
    /// `(name, max_tokens, max_completion_tokens, max_output_tokens)`.
    type ModelRow<'a> = (&'a str, &'a str, &'a str, &'a str);

    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        // (provider name, API URL, API key, models, expected validation error)
        let cases: Vec<(&str, &str, &str, Vec<ModelRow<'_>>, &str)> = vec![
            ("", "someurl", "somekey", vec![], "Provider Name cannot be empty"),
            ("someprovider", "", "somekey", vec![], "API URL cannot be empty"),
            ("someprovider", "someurl", "", vec![], "API Key cannot be empty"),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                "Model Name cannot be empty",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                "Max Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                "Max Completion Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                "Max Output Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![
                    ("somemodel", "200000", "200000", "32000"),
                    ("somemodel", "200000", "200000", "32000"),
                ],
                "Model Names must be unique",
            ),
        ];

        for (provider_name, api_url, api_key, models, expected) in cases {
            let error =
                save_provider_validation_errors(provider_name, api_url, api_key, models, cx).await;
            assert_eq!(error, Some(SharedString::from(expected)));
        }
    }

    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            let fake_provider = Arc::new(FakeLanguageModelProvider::new(
                LanguageModelProviderId::new("someprovider"),
                LanguageModelProviderName::new("Some Provider"),
            ));
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(fake_provider, cx);
            });
        });

        let error = save_provider_validation_errors(
            "someprovider",
            "someurl",
            "someapikey",
            vec![("somemodel", "200000", "200000", "32000")],
            cx,
        )
        .await;
        assert_eq!(
            error,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = named_model_input("somemodel", window, cx);

            let toggles = &model_input.capabilities;
            assert_eq!(toggles.supports_tools, ToggleState::Selected);
            assert_eq!(toggles.supports_images, ToggleState::Unselected);
            assert_eq!(toggles.supports_parallel_tool_calls, ToggleState::Unselected);
            assert_eq!(toggles.supports_prompt_cache_key, ToggleState::Unselected);
            assert_eq!(toggles.supports_chat_completions, ToggleState::Selected);

            let parsed_model = model_input.parse(cx).unwrap();
            let caps = &parsed_model.capabilities;
            assert!(caps.tools);
            assert!(!caps.images);
            assert!(!caps.parallel_tool_calls);
            assert!(!caps.prompt_cache_key);
            assert!(caps.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = named_model_input("somemodel", window, cx);
            model_input.capabilities = ModelCapabilityToggles {
                supports_tools: ToggleState::Unselected,
                supports_images: ToggleState::Unselected,
                supports_parallel_tool_calls: ToggleState::Unselected,
                supports_prompt_cache_key: ToggleState::Unselected,
                supports_chat_completions: ToggleState::Unselected,
            };

            let parsed_model = model_input.parse(cx).unwrap();
            let caps = &parsed_model.capabilities;
            assert!(!caps.tools);
            assert!(!caps.images);
            assert!(!caps.parallel_tool_calls);
            assert!(!caps.prompt_cache_key);
            assert!(!caps.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = named_model_input("somemodel", window, cx);
            model_input.capabilities = ModelCapabilityToggles {
                supports_tools: ToggleState::Selected,
                supports_images: ToggleState::Unselected,
                supports_parallel_tool_calls: ToggleState::Selected,
                supports_prompt_cache_key: ToggleState::Unselected,
                supports_chat_completions: ToggleState::Selected,
            };

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(parsed_model.name, "somemodel");
            let caps = &parsed_model.capabilities;
            assert!(caps.tools);
            assert!(!caps.images);
            assert!(caps.parallel_tool_calls);
            assert!(!caps.prompt_cache_key);
            assert!(caps.chat_completions);
        });
    }

    /// Installs the globals the modal code relies on (settings, theme, language
    /// model settings, fake fs), then opens a test workspace window and returns its
    /// visual context.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme::init(theme::LoadThemes::JustBase, cx);

            language_model::init_settings(cx);
        });

        let fake_fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fake_fs.clone(), cx));
        let project = Project::test(fake_fs, [path!("/dir").as_ref()], cx).await;
        let (_, visual_cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        visual_cx
    }

    /// Replaces the text of `input`'s editor with `text`.
    fn set_text(input: &Entity<InputField>, text: &str, window: &mut Window, cx: &mut App) {
        input.update(cx, |input, cx| {
            input.editor().update(cx, |editor, cx| {
                editor.set_text(text, window, cx);
            });
        });
    }

    /// Creates the first model's inputs with its name field set to `name`.
    fn named_model_input(name: &str, window: &mut Window, cx: &mut App) -> ModelInput {
        let model_input = ModelInput::new(0, window, cx);
        set_text(&model_input.name, name, window, cx);
        model_input
    }

    /// Fills a fresh `AddLlmProviderInput` with the given values, runs
    /// `save_provider_to_settings`, and returns the error it produced, if any.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<ModelRow<'_>>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        let save_task = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            for (field, text) in [
                (&input.provider_name, provider_name),
                (&input.api_url, api_url),
                (&input.api_key, api_key),
            ] {
                set_text(field, text, window, cx);
            }

            for (ix, &(name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                if ix >= input.models.len() {
                    input.models.push(ModelInput::new(ix, window, cx));
                }
                let model = &input.models[ix];
                set_text(&model.name, name, window, cx);
                set_text(&model.max_tokens, max_tokens, window, cx);
                set_text(&model.max_completion_tokens, max_completion_tokens, window, cx);
                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
            }

            save_provider_to_settings(&input, cx)
        });

        save_task.await.err()
    }
}