add_llm_provider_modal.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashSet;
  5use fs::Fs;
  6use gpui::{
  7    DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
  8};
  9use language_model::LanguageModelRegistry;
 10use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
 11use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
 12use ui::{
 13    Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
 14    WithScrollbar, prelude::*,
 15};
 16use ui_input::InputField;
 17use workspace::{ModalView, Workspace};
 18
 19fn single_line_input(
 20    label: impl Into<SharedString>,
 21    placeholder: &str,
 22    text: Option<&str>,
 23    tab_index: isize,
 24    window: &mut Window,
 25    cx: &mut App,
 26) -> Entity<InputField> {
 27    cx.new(|cx| {
 28        let input = InputField::new(window, cx, placeholder)
 29            .label(label)
 30            .tab_index(tab_index)
 31            .tab_stop(true);
 32
 33        if let Some(text) = text {
 34            input.set_text(text, window, cx);
 35        }
 36        input
 37    })
 38}
 39
 40#[derive(Clone, Copy)]
 41pub enum LlmCompatibleProvider {
 42    OpenAi,
 43}
 44
 45impl LlmCompatibleProvider {
 46    fn name(&self) -> &'static str {
 47        match self {
 48            LlmCompatibleProvider::OpenAi => "OpenAI",
 49        }
 50    }
 51
 52    fn api_url(&self) -> &'static str {
 53        match self {
 54            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
 55        }
 56    }
 57}
 58
 59struct AddLlmProviderInput {
 60    provider_name: Entity<InputField>,
 61    api_url: Entity<InputField>,
 62    api_key: Entity<InputField>,
 63    models: Vec<ModelInput>,
 64}
 65
 66impl AddLlmProviderInput {
 67    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
 68        let provider_name =
 69            single_line_input("Provider Name", provider.name(), None, 1, window, cx);
 70        let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
 71        let api_key = cx.new(|cx| {
 72            InputField::new(
 73                window,
 74                cx,
 75                "000000000000000000000000000000000000000000000000",
 76            )
 77            .label("API Key")
 78            .tab_index(3)
 79            .tab_stop(true)
 80            .masked(true)
 81        });
 82
 83        Self {
 84            provider_name,
 85            api_url,
 86            api_key,
 87            models: vec![ModelInput::new(0, window, cx)],
 88        }
 89    }
 90
 91    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
 92        let model_index = self.models.len();
 93        self.models.push(ModelInput::new(model_index, window, cx));
 94    }
 95
 96    fn remove_model(&mut self, index: usize) {
 97        self.models.remove(index);
 98    }
 99}
100
101struct ModelCapabilityToggles {
102    pub supports_tools: ToggleState,
103    pub supports_images: ToggleState,
104    pub supports_parallel_tool_calls: ToggleState,
105    pub supports_prompt_cache_key: ToggleState,
106    pub supports_chat_completions: ToggleState,
107}
108
109struct ModelInput {
110    name: Entity<InputField>,
111    max_completion_tokens: Entity<InputField>,
112    max_output_tokens: Entity<InputField>,
113    max_tokens: Entity<InputField>,
114    capabilities: ModelCapabilityToggles,
115}
116
117impl ModelInput {
118    fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
119        let base_tab_index = (3 + (model_index * 4)) as isize;
120
121        let model_name = single_line_input(
122            "Model Name",
123            "e.g. gpt-5, claude-opus-4, gemini-2.5-pro",
124            None,
125            base_tab_index + 1,
126            window,
127            cx,
128        );
129        let max_completion_tokens = single_line_input(
130            "Max Completion Tokens",
131            "200000",
132            Some("200000"),
133            base_tab_index + 2,
134            window,
135            cx,
136        );
137        let max_output_tokens = single_line_input(
138            "Max Output Tokens",
139            "Max Output Tokens",
140            Some("32000"),
141            base_tab_index + 3,
142            window,
143            cx,
144        );
145        let max_tokens = single_line_input(
146            "Max Tokens",
147            "Max Tokens",
148            Some("200000"),
149            base_tab_index + 4,
150            window,
151            cx,
152        );
153
154        let ModelCapabilities {
155            tools,
156            images,
157            parallel_tool_calls,
158            prompt_cache_key,
159            chat_completions,
160        } = ModelCapabilities::default();
161
162        Self {
163            name: model_name,
164            max_completion_tokens,
165            max_output_tokens,
166            max_tokens,
167            capabilities: ModelCapabilityToggles {
168                supports_tools: tools.into(),
169                supports_images: images.into(),
170                supports_parallel_tool_calls: parallel_tool_calls.into(),
171                supports_prompt_cache_key: prompt_cache_key.into(),
172                supports_chat_completions: chat_completions.into(),
173            },
174        }
175    }
176
177    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
178        let name = self.name.read(cx).text(cx);
179        if name.is_empty() {
180            return Err(SharedString::from("Model Name cannot be empty"));
181        }
182        Ok(AvailableModel {
183            name,
184            display_name: None,
185            max_completion_tokens: Some(
186                self.max_completion_tokens
187                    .read(cx)
188                    .text(cx)
189                    .parse::<u64>()
190                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
191            ),
192            max_output_tokens: Some(
193                self.max_output_tokens
194                    .read(cx)
195                    .text(cx)
196                    .parse::<u64>()
197                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
198            ),
199            max_tokens: self
200                .max_tokens
201                .read(cx)
202                .text(cx)
203                .parse::<u64>()
204                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
205            capabilities: ModelCapabilities {
206                tools: self.capabilities.supports_tools.selected(),
207                images: self.capabilities.supports_images.selected(),
208                parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
209                prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
210                chat_completions: self.capabilities.supports_chat_completions.selected(),
211            },
212        })
213    }
214}
215
216fn save_provider_to_settings(
217    input: &AddLlmProviderInput,
218    cx: &mut App,
219) -> Task<Result<(), SharedString>> {
220    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
221    if provider_name.is_empty() {
222        return Task::ready(Err("Provider Name cannot be empty".into()));
223    }
224
225    if LanguageModelRegistry::read_global(cx)
226        .providers()
227        .iter()
228        .any(|provider| {
229            provider.id().0.as_ref() == provider_name.as_ref()
230                || provider.name().0.as_ref() == provider_name.as_ref()
231        })
232    {
233        return Task::ready(Err(
234            "Provider Name is already taken by another provider".into()
235        ));
236    }
237
238    let api_url = input.api_url.read(cx).text(cx);
239    if api_url.is_empty() {
240        return Task::ready(Err("API URL cannot be empty".into()));
241    }
242
243    let api_key = input.api_key.read(cx).text(cx);
244    if api_key.is_empty() {
245        return Task::ready(Err("API Key cannot be empty".into()));
246    }
247
248    let mut models = Vec::new();
249    let mut model_names: HashSet<String> = HashSet::default();
250    for model in &input.models {
251        match model.parse(cx) {
252            Ok(model) => {
253                if !model_names.insert(model.name.clone()) {
254                    return Task::ready(Err("Model Names must be unique".into()));
255                }
256                models.push(model)
257            }
258            Err(err) => return Task::ready(Err(err)),
259        }
260    }
261
262    let fs = <dyn Fs>::global(cx);
263    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
264    cx.spawn(async move |cx| {
265        task.await
266            .map_err(|_| SharedString::from("Failed to write API key to keychain"))?;
267        cx.update(|cx| {
268            update_settings_file(fs, cx, |settings, _cx| {
269                settings
270                    .language_models
271                    .get_or_insert_default()
272                    .openai_compatible
273                    .get_or_insert_default()
274                    .insert(
275                        provider_name,
276                        OpenAiCompatibleSettingsContent {
277                            api_url,
278                            available_models: models,
279                        },
280                    );
281            });
282        });
283        Ok(())
284    })
285}
286
287pub struct AddLlmProviderModal {
288    provider: LlmCompatibleProvider,
289    input: AddLlmProviderInput,
290    scroll_handle: ScrollHandle,
291    focus_handle: FocusHandle,
292    last_error: Option<SharedString>,
293}
294
295impl AddLlmProviderModal {
296    pub fn toggle(
297        provider: LlmCompatibleProvider,
298        workspace: &mut Workspace,
299        window: &mut Window,
300        cx: &mut Context<Workspace>,
301    ) {
302        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
303    }
304
305    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
306        Self {
307            input: AddLlmProviderInput::new(provider, window, cx),
308            provider,
309            last_error: None,
310            focus_handle: cx.focus_handle(),
311            scroll_handle: ScrollHandle::new(),
312        }
313    }
314
315    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
316        let task = save_provider_to_settings(&self.input, cx);
317        cx.spawn(async move |this, cx| {
318            let result = task.await;
319            this.update(cx, |this, cx| match result {
320                Ok(_) => {
321                    cx.emit(DismissEvent);
322                }
323                Err(error) => {
324                    this.last_error = Some(error);
325                    cx.notify();
326                }
327            })
328        })
329        .detach_and_log_err(cx);
330    }
331
332    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
333        cx.emit(DismissEvent);
334    }
335
336    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
337        v_flex()
338            .mt_1()
339            .gap_2()
340            .child(
341                h_flex()
342                    .justify_between()
343                    .child(Label::new("Models").size(LabelSize::Small))
344                    .child(
345                        Button::new("add-model", "Add Model")
346                            .start_icon(
347                                Icon::new(IconName::Plus)
348                                    .size(IconSize::XSmall)
349                                    .color(Color::Muted),
350                            )
351                            .label_size(LabelSize::Small)
352                            .on_click(cx.listener(|this, _, window, cx| {
353                                this.input.add_model(window, cx);
354                                cx.notify();
355                            })),
356                    ),
357            )
358            .children(
359                self.input
360                    .models
361                    .iter()
362                    .enumerate()
363                    .map(|(ix, _)| self.render_model(ix, cx)),
364            )
365    }
366
367    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
368        let has_more_than_one_model = self.input.models.len() > 1;
369        let model = &self.input.models[ix];
370
371        v_flex()
372            .p_2()
373            .gap_2()
374            .rounded_sm()
375            .border_1()
376            .border_dashed()
377            .border_color(cx.theme().colors().border.opacity(0.6))
378            .bg(cx.theme().colors().element_active.opacity(0.15))
379            .child(model.name.clone())
380            .child(
381                h_flex()
382                    .gap_2()
383                    .child(model.max_completion_tokens.clone())
384                    .child(model.max_output_tokens.clone()),
385            )
386            .child(model.max_tokens.clone())
387            .child(
388                v_flex()
389                    .gap_1()
390                    .child(
391                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
392                            .label("Supports tools")
393                            .on_click(cx.listener(move |this, checked, _window, cx| {
394                                this.input.models[ix].capabilities.supports_tools = *checked;
395                                cx.notify();
396                            })),
397                    )
398                    .child(
399                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
400                            .label("Supports images")
401                            .on_click(cx.listener(move |this, checked, _window, cx| {
402                                this.input.models[ix].capabilities.supports_images = *checked;
403                                cx.notify();
404                            })),
405                    )
406                    .child(
407                        Checkbox::new(
408                            ("supports-parallel-tool-calls", ix),
409                            model.capabilities.supports_parallel_tool_calls,
410                        )
411                        .label("Supports parallel_tool_calls")
412                        .on_click(cx.listener(
413                            move |this, checked, _window, cx| {
414                                this.input.models[ix]
415                                    .capabilities
416                                    .supports_parallel_tool_calls = *checked;
417                                cx.notify();
418                            },
419                        )),
420                    )
421                    .child(
422                        Checkbox::new(
423                            ("supports-prompt-cache-key", ix),
424                            model.capabilities.supports_prompt_cache_key,
425                        )
426                        .label("Supports prompt_cache_key")
427                        .on_click(cx.listener(
428                            move |this, checked, _window, cx| {
429                                this.input.models[ix].capabilities.supports_prompt_cache_key =
430                                    *checked;
431                                cx.notify();
432                            },
433                        )),
434                    )
435                    .child(
436                        Checkbox::new(
437                            ("supports-chat-completions", ix),
438                            model.capabilities.supports_chat_completions,
439                        )
440                        .label("Supports /chat/completions")
441                        .on_click(cx.listener(
442                            move |this, checked, _window, cx| {
443                                this.input.models[ix].capabilities.supports_chat_completions =
444                                    *checked;
445                                cx.notify();
446                            },
447                        )),
448                    ),
449            )
450            .when(has_more_than_one_model, |this| {
451                this.child(
452                    Button::new(("remove-model", ix), "Remove Model")
453                        .start_icon(
454                            Icon::new(IconName::Trash)
455                                .size(IconSize::XSmall)
456                                .color(Color::Muted),
457                        )
458                        .label_size(LabelSize::Small)
459                        .style(ButtonStyle::Outlined)
460                        .full_width()
461                        .on_click(cx.listener(move |this, _, _window, cx| {
462                            this.input.remove_model(ix);
463                            cx.notify();
464                        })),
465                )
466            })
467    }
468
469    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
470        window.focus_next(cx);
471    }
472
473    fn on_tab_prev(
474        &mut self,
475        _: &menu::SelectPrevious,
476        window: &mut Window,
477        cx: &mut Context<Self>,
478    ) {
479        window.focus_prev(cx);
480    }
481}
482
483impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
484
485impl Focusable for AddLlmProviderModal {
486    fn focus_handle(&self, _cx: &App) -> FocusHandle {
487        self.focus_handle.clone()
488    }
489}
490
491impl ModalView for AddLlmProviderModal {}
492
493impl Render for AddLlmProviderModal {
494    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
495        let focus_handle = self.focus_handle(cx);
496
497        let window_size = window.viewport_size();
498        let rem_size = window.rem_size();
499        let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
500
501        let modal_max_height = if is_large_window {
502            rems_from_px(450.)
503        } else {
504            rems_from_px(200.)
505        };
506
507        v_flex()
508            .id("add-llm-provider-modal")
509            .key_context("AddLlmProviderModal")
510            .w(rems(34.))
511            .elevation_3(cx)
512            .on_action(cx.listener(Self::cancel))
513            .on_action(cx.listener(Self::on_tab))
514            .on_action(cx.listener(Self::on_tab_prev))
515            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
516                this.focus_handle(cx).focus(window, cx);
517            }))
518            .child(
519                Modal::new("configure-context-server", None)
520                    .header(ModalHeader::new().headline("Add LLM Provider").description(
521                        match self.provider {
522                            LlmCompatibleProvider::OpenAi => {
523                                "This provider will use an OpenAI compatible API."
524                            }
525                        },
526                    ))
527                    .when_some(self.last_error.clone(), |this, error| {
528                        this.section(
529                            Section::new().child(
530                                Banner::new()
531                                    .severity(Severity::Warning)
532                                    .child(div().text_xs().child(error)),
533                            ),
534                        )
535                    })
536                    .child(
537                        div()
538                            .size_full()
539                            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
540                            .child(
541                                v_flex()
542                                    .id("modal_content")
543                                    .size_full()
544                                    .tab_group()
545                                    .max_h(modal_max_height)
546                                    .pl_3()
547                                    .pr_4()
548                                    .pb_2()
549                                    .gap_2()
550                                    .overflow_y_scroll()
551                                    .track_scroll(&self.scroll_handle)
552                                    .child(self.input.provider_name.clone())
553                                    .child(self.input.api_url.clone())
554                                    .child(self.input.api_key.clone())
555                                    .child(self.render_model_section(cx)),
556                            ),
557                    )
558                    .footer(
559                        ModalFooter::new().end_slot(
560                            h_flex()
561                                .gap_1()
562                                .child(
563                                    Button::new("cancel", "Cancel")
564                                        .key_binding(
565                                            KeyBinding::for_action_in(
566                                                &menu::Cancel,
567                                                &focus_handle,
568                                                cx,
569                                            )
570                                            .map(|kb| kb.size(rems_from_px(12.))),
571                                        )
572                                        .on_click(cx.listener(|this, _event, window, cx| {
573                                            this.cancel(&menu::Cancel, window, cx)
574                                        })),
575                                )
576                                .child(
577                                    Button::new("save-server", "Save Provider")
578                                        .key_binding(
579                                            KeyBinding::for_action_in(
580                                                &menu::Confirm,
581                                                &focus_handle,
582                                                cx,
583                                            )
584                                            .map(|kb| kb.size(rems_from_px(12.))),
585                                        )
586                                        .on_click(cx.listener(|this, _event, window, cx| {
587                                            this.confirm(&menu::Confirm, window, cx)
588                                        })),
589                                ),
590                        ),
591                    ),
592            )
593    }
594}
595
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;
    use workspace::MultiWorkspace;

    /// One model entry for `save_provider_validation_errors`, laid out as
    /// `(name, max_tokens, max_completion_tokens, max_output_tokens)`.
    type ModelRow = (&'static str, &'static str, &'static str, &'static str);

    /// A model entry in which every field passes validation.
    const VALID_MODEL: ModelRow = ("somemodel", "200000", "200000", "32000");

    /// Flattens parsed capabilities into
    /// `[tools, images, parallel_tool_calls, prompt_cache_key, chat_completions]`.
    fn capability_flags(capabilities: &ModelCapabilities) -> [bool; 5] {
        [
            capabilities.tools,
            capabilities.images,
            capabilities.parallel_tool_calls,
            capabilities.prompt_cache_key,
            capabilities.chat_completions,
        ]
    }

    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        // (provider name, API URL, API key, models, expected error)
        let cases: [(&str, &str, &str, Vec<ModelRow>, &'static str); 8] = [
            ("", "someurl", "somekey", vec![], "Provider Name cannot be empty"),
            ("someprovider", "", "somekey", vec![], "API URL cannot be empty"),
            ("someprovider", "someurl", "", vec![], "API Key cannot be empty"),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                "Model Name cannot be empty",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                "Max Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                "Max Completion Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                "Max Output Tokens must be a number",
            ),
            (
                "someprovider",
                "someurl",
                "somekey",
                vec![VALID_MODEL, VALID_MODEL],
                "Model Names must be unique",
            ),
        ];

        for (provider_name, api_url, api_key, models, expected) in cases {
            let error =
                save_provider_validation_errors(provider_name, api_url, api_key, models, cx).await;
            assert_eq!(error, Some(SharedString::from(expected)));
        }
    }

    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(
                    Arc::new(FakeLanguageModelProvider::new(
                        LanguageModelProviderId::new("someprovider"),
                        LanguageModelProviderName::new("Some Provider"),
                    )),
                    cx,
                );
            });
        });

        let error = save_provider_validation_errors(
            "someprovider",
            "someurl",
            "someapikey",
            vec![VALID_MODEL],
            cx,
        )
        .await;

        assert_eq!(
            error,
            Some(SharedString::from(
                "Provider Name is already taken by another provider"
            ))
        );
    }

    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.set_text("somemodel", window, cx);
            });

            let toggles = &model_input.capabilities;
            assert_eq!(
                [
                    toggles.supports_tools,
                    toggles.supports_images,
                    toggles.supports_parallel_tool_calls,
                    toggles.supports_prompt_cache_key,
                    toggles.supports_chat_completions,
                ],
                [
                    ToggleState::Selected,
                    ToggleState::Unselected,
                    ToggleState::Unselected,
                    ToggleState::Unselected,
                    ToggleState::Selected,
                ]
            );

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(
                capability_flags(&parsed_model.capabilities),
                [true, false, false, false, true]
            );
        });
    }

    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.set_text("somemodel", window, cx);
            });

            let toggles = &mut model_input.capabilities;
            for toggle in [
                &mut toggles.supports_tools,
                &mut toggles.supports_images,
                &mut toggles.supports_parallel_tool_calls,
                &mut toggles.supports_prompt_cache_key,
                &mut toggles.supports_chat_completions,
            ] {
                *toggle = ToggleState::Unselected;
            }

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(capability_flags(&parsed_model.capabilities), [false; 5]);
        });
    }

    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.set_text("somemodel", window, cx);
            });

            let toggles = &mut model_input.capabilities;
            toggles.supports_tools = ToggleState::Selected;
            toggles.supports_images = ToggleState::Unselected;
            toggles.supports_parallel_tool_calls = ToggleState::Selected;
            toggles.supports_prompt_cache_key = ToggleState::Unselected;
            toggles.supports_chat_completions = ToggleState::Selected;

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(parsed_model.name, "somemodel");
            assert_eq!(
                capability_flags(&parsed_model.capabilities),
                [true, false, true, false, true]
            );
        });
    }

    /// Installs the globals these tests use and opens a test window.
    ///
    /// Installed globals: settings store, base theme, language-model settings,
    /// editor and a fake filesystem.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            theme_settings::init(theme::LoadThemes::JustBase, cx);

            language_model::init_settings(cx);
            editor::init(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let _workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        cx
    }

    /// Fills a fresh form with the given values and returns the error reported
    /// by `save_provider_to_settings`, or `None` if it succeeded.
    ///
    /// Each entry in `models` is `(name, max_tokens, max_completion_tokens,
    /// max_output_tokens)`. Model inputs are appended as needed beyond the one
    /// the form starts with.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn fill(field: &Entity<InputField>, value: &str, window: &mut Window, cx: &mut App) {
            field.update(cx, |input_field, cx| {
                input_field.set_text(value, window, cx);
            });
        }

        let save = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            fill(&input.provider_name, provider_name, window, cx);
            fill(&input.api_url, api_url, window, cx);
            fill(&input.api_key, api_key, window, cx);

            for (index, &(name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                if index >= input.models.len() {
                    input.models.push(ModelInput::new(index, window, cx));
                }
                let model = &input.models[index];
                fill(&model.name, name, window, cx);
                fill(&model.max_tokens, max_tokens, window, cx);
                fill(&model.max_completion_tokens, max_completion_tokens, window, cx);
                fill(&model.max_output_tokens, max_output_tokens, window, cx);
            }

            save_provider_to_settings(&input, cx)
        });

        save.await.err()
    }
}