add_llm_provider_modal.rs

use std::sync::Arc;

use anyhow::Result;
use collections::HashSet;
use fs::Fs;
use gpui::{
    DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
};
use language_model::LanguageModelRegistry;
use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
use ui::{
    Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
    WithScrollbar, prelude::*,
};
use ui_input::InputField;
use workspace::{ModalView, Workspace};

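/// Builds a labeled, single-line [`InputField`] with the given placeholder,
/// optional initial text, and tab index.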
fn single_line_input(
    label: impl Into<SharedString>,
    placeholder: &str,
    text: Option<&str>,
    tab_index: isize,
    window: &mut Window,
    cx: &mut App,
) -> Entity<InputField> {
    cx.new(|cx| {
        let input = InputField::new(window, cx, placeholder)
            .label(label)
            .tab_index(tab_index)
            .tab_stop(true);

        if let Some(text) = text {
            input.set_text(text, window, cx);
        }
        input
    })
}

#[derive(Clone, Copy)]
pub enum LlmCompatibleProvider {
    OpenAi,
}

impl LlmCompatibleProvider {
    fn name(&self) -> &'static str {
        match self {
            LlmCompatibleProvider::OpenAi => "OpenAI",
        }
    }

    fn api_url(&self) -> &'static str {
        match self {
            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
        }
    }
}

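/// Editor state backing the provider form: the provider-level fields plus one
/// [`ModelInput`] per model row.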
struct AddLlmProviderInput {
    provider_name: Entity<InputField>,
    api_url: Entity<InputField>,
    api_key: Entity<InputField>,
    models: Vec<ModelInput>,
}

impl AddLlmProviderInput {
    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
        let provider_name =
            single_line_input("Provider Name", provider.name(), None, 1, window, cx);
        let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
        let api_key = single_line_input(
            "API Key",
            "000000000000000000000000000000000000000000000000",
            None,
            3,
            window,
            cx,
        );

        Self {
            provider_name,
            api_url,
            api_key,
            models: vec![ModelInput::new(0, window, cx)],
        }
    }

    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
        let model_index = self.models.len();
        self.models.push(ModelInput::new(model_index, window, cx));
    }

    fn remove_model(&mut self, index: usize) {
        self.models.remove(index);
    }
}

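/// Checkbox state for each [`ModelCapabilities`] flag, kept as [`ToggleState`]
/// so it can be bound directly to the capability checkboxes.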
struct ModelCapabilityToggles {
    pub supports_tools: ToggleState,
    pub supports_images: ToggleState,
    pub supports_parallel_tool_calls: ToggleState,
    pub supports_prompt_cache_key: ToggleState,
    pub supports_chat_completions: ToggleState,
}

struct ModelInput {
    name: Entity<InputField>,
    max_completion_tokens: Entity<InputField>,
    max_output_tokens: Entity<InputField>,
    max_tokens: Entity<InputField>,
    capabilities: ModelCapabilityToggles,
}

impl ModelInput {
    fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
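        // Provider-level fields occupy tab indices 1-3; each model row adds four
        // inputs, so model `model_index` starts at 3 + model_index * 4.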
        let base_tab_index = (3 + (model_index * 4)) as isize;

        let model_name = single_line_input(
            "Model Name",
            "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
            None,
            base_tab_index + 1,
            window,
            cx,
        );
        let max_completion_tokens = single_line_input(
            "Max Completion Tokens",
            "200000",
            Some("200000"),
            base_tab_index + 2,
            window,
            cx,
        );
        let max_output_tokens = single_line_input(
            "Max Output Tokens",
            "Max Output Tokens",
            Some("32000"),
            base_tab_index + 3,
            window,
            cx,
        );
        let max_tokens = single_line_input(
            "Max Tokens",
            "Max Tokens",
            Some("200000"),
            base_tab_index + 4,
            window,
            cx,
        );

        let ModelCapabilities {
            tools,
            images,
            parallel_tool_calls,
            prompt_cache_key,
            chat_completions,
        } = ModelCapabilities::default();

        Self {
            name: model_name,
            max_completion_tokens,
            max_output_tokens,
            max_tokens,
            capabilities: ModelCapabilityToggles {
                supports_tools: tools.into(),
                supports_images: images.into(),
                supports_parallel_tool_calls: parallel_tool_calls.into(),
                supports_prompt_cache_key: prompt_cache_key.into(),
                supports_chat_completions: chat_completions.into(),
            },
        }
    }

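    /// Validates this row, returning an [`AvailableModel`] or a user-facing
    /// error if the name is empty or a token field is not a number.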
    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
        let name = self.name.read(cx).text(cx);
        if name.is_empty() {
            return Err(SharedString::from("Model Name cannot be empty"));
        }
        Ok(AvailableModel {
            name,
            display_name: None,
            max_completion_tokens: Some(
                self.max_completion_tokens
                    .read(cx)
                    .text(cx)
                    .parse::<u64>()
                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
            ),
            max_output_tokens: Some(
                self.max_output_tokens
                    .read(cx)
                    .text(cx)
                    .parse::<u64>()
                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
            ),
            max_tokens: self
                .max_tokens
                .read(cx)
                .text(cx)
                .parse::<u64>()
                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
            capabilities: ModelCapabilities {
                tools: self.capabilities.supports_tools.selected(),
                images: self.capabilities.supports_images.selected(),
                parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
                prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
                chat_completions: self.capabilities.supports_chat_completions.selected(),
            },
        })
    }
}

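/// Validates the whole form (non-empty and unique provider name, non-empty API
/// URL and key, unique and well-formed models), writes the API key to the
/// keychain keyed by the API URL, and persists the provider under
/// `language_models.openai_compatible` in the settings file.
///
/// As a rough sketch only (key names assumed from the struct fields of
/// `OpenAiCompatibleSettingsContent`; the exact serialized form may differ),
/// the saved settings entry looks roughly like:
///
/// ```json
/// {
///   "language_models": {
///     "openai_compatible": {
///       "My Provider": {
///         "api_url": "https://api.openai.com/v1",
///         "available_models": [{ "name": "gpt-4o", "max_tokens": 200000 }]
///       }
///     }
///   }
/// }
/// ```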
fn save_provider_to_settings(
    input: &AddLlmProviderInput,
    cx: &mut App,
) -> Task<Result<(), SharedString>> {
    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
    if provider_name.is_empty() {
        return Task::ready(Err("Provider Name cannot be empty".into()));
    }

    if LanguageModelRegistry::read_global(cx)
        .providers()
        .iter()
        .any(|provider| {
            provider.id().0.as_ref() == provider_name.as_ref()
                || provider.name().0.as_ref() == provider_name.as_ref()
        })
    {
        return Task::ready(Err(
            "Provider Name is already taken by another provider".into()
        ));
    }

    let api_url = input.api_url.read(cx).text(cx);
    if api_url.is_empty() {
        return Task::ready(Err("API URL cannot be empty".into()));
    }

    let api_key = input.api_key.read(cx).text(cx);
    if api_key.is_empty() {
        return Task::ready(Err("API Key cannot be empty".into()));
    }

    let mut models = Vec::new();
    let mut model_names: HashSet<String> = HashSet::default();
    for model in &input.models {
        match model.parse(cx) {
            Ok(model) => {
                if !model_names.insert(model.name.clone()) {
                    return Task::ready(Err("Model Names must be unique".into()));
                }
                models.push(model)
            }
            Err(err) => return Task::ready(Err(err)),
        }
    }

    let fs = <dyn Fs>::global(cx);
    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
    cx.spawn(async move |cx| {
        task.await
            .map_err(|_| SharedString::from("Failed to write API key to keychain"))?;
        cx.update(|cx| {
            update_settings_file(fs, cx, |settings, _cx| {
                settings
                    .language_models
                    .get_or_insert_default()
                    .openai_compatible
                    .get_or_insert_default()
                    .insert(
                        provider_name,
                        OpenAiCompatibleSettingsContent {
                            api_url,
                            available_models: models,
                        },
                    );
            });
        });
        Ok(())
    })
}

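/// Workspace modal wrapping [`AddLlmProviderInput`]; validation failures from
/// saving are stored in `last_error` and rendered as a warning banner.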
pub struct AddLlmProviderModal {
    provider: LlmCompatibleProvider,
    input: AddLlmProviderInput,
    scroll_handle: ScrollHandle,
    focus_handle: FocusHandle,
    last_error: Option<SharedString>,
}

impl AddLlmProviderModal {
    pub fn toggle(
        provider: LlmCompatibleProvider,
        workspace: &mut Workspace,
        window: &mut Window,
        cx: &mut Context<Workspace>,
    ) {
        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
    }

    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
        Self {
            input: AddLlmProviderInput::new(provider, window, cx),
            provider,
            last_error: None,
            focus_handle: cx.focus_handle(),
            scroll_handle: ScrollHandle::new(),
        }
    }

    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
        let task = save_provider_to_settings(&self.input, cx);
        cx.spawn(async move |this, cx| {
            let result = task.await;
            this.update(cx, |this, cx| match result {
                Ok(_) => {
                    cx.emit(DismissEvent);
                }
                Err(error) => {
                    this.last_error = Some(error);
                    cx.notify();
                }
            })
        })
        .detach_and_log_err(cx);
    }

    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
        cx.emit(DismissEvent);
    }

    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
        v_flex()
            .mt_1()
            .gap_2()
            .child(
                h_flex()
                    .justify_between()
                    .child(Label::new("Models").size(LabelSize::Small))
                    .child(
                        Button::new("add-model", "Add Model")
                            .icon(IconName::Plus)
                            .icon_position(IconPosition::Start)
                            .icon_size(IconSize::XSmall)
                            .icon_color(Color::Muted)
                            .label_size(LabelSize::Small)
                            .on_click(cx.listener(|this, _, window, cx| {
                                this.input.add_model(window, cx);
                                cx.notify();
                            })),
                    ),
            )
            .children(
                self.input
                    .models
                    .iter()
                    .enumerate()
                    .map(|(ix, _)| self.render_model(ix, cx)),
            )
    }

    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
        let has_more_than_one_model = self.input.models.len() > 1;
        let model = &self.input.models[ix];

        v_flex()
            .p_2()
            .gap_2()
            .rounded_sm()
            .border_1()
            .border_dashed()
            .border_color(cx.theme().colors().border.opacity(0.6))
            .bg(cx.theme().colors().element_active.opacity(0.15))
            .child(model.name.clone())
            .child(
                h_flex()
                    .gap_2()
                    .child(model.max_completion_tokens.clone())
                    .child(model.max_output_tokens.clone()),
            )
            .child(model.max_tokens.clone())
            .child(
                v_flex()
                    .gap_1()
                    .child(
                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
                            .label("Supports tools")
                            .on_click(cx.listener(move |this, checked, _window, cx| {
                                this.input.models[ix].capabilities.supports_tools = *checked;
                                cx.notify();
                            })),
                    )
                    .child(
                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
                            .label("Supports images")
                            .on_click(cx.listener(move |this, checked, _window, cx| {
                                this.input.models[ix].capabilities.supports_images = *checked;
                                cx.notify();
                            })),
                    )
                    .child(
                        Checkbox::new(
                            ("supports-parallel-tool-calls", ix),
                            model.capabilities.supports_parallel_tool_calls,
                        )
                        .label("Supports parallel_tool_calls")
                        .on_click(cx.listener(
                            move |this, checked, _window, cx| {
                                this.input.models[ix]
                                    .capabilities
                                    .supports_parallel_tool_calls = *checked;
                                cx.notify();
                            },
                        )),
                    )
                    .child(
                        Checkbox::new(
                            ("supports-prompt-cache-key", ix),
                            model.capabilities.supports_prompt_cache_key,
                        )
                        .label("Supports prompt_cache_key")
                        .on_click(cx.listener(
                            move |this, checked, _window, cx| {
                                this.input.models[ix].capabilities.supports_prompt_cache_key =
                                    *checked;
                                cx.notify();
                            },
                        )),
                    )
                    .child(
                        Checkbox::new(
                            ("supports-chat-completions", ix),
                            model.capabilities.supports_chat_completions,
                        )
                        .label("Supports /chat/completions")
                        .on_click(cx.listener(
                            move |this, checked, _window, cx| {
                                this.input.models[ix].capabilities.supports_chat_completions =
                                    *checked;
                                cx.notify();
                            },
                        )),
                    ),
            )
            .when(has_more_than_one_model, |this| {
                this.child(
                    Button::new(("remove-model", ix), "Remove Model")
                        .icon(IconName::Trash)
                        .icon_position(IconPosition::Start)
                        .icon_size(IconSize::XSmall)
                        .icon_color(Color::Muted)
                        .label_size(LabelSize::Small)
                        .style(ButtonStyle::Outlined)
                        .full_width()
                        .on_click(cx.listener(move |this, _, _window, cx| {
                            this.input.remove_model(ix);
                            cx.notify();
                        })),
                )
            })
    }

    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
        window.focus_next(cx);
    }

    fn on_tab_prev(
        &mut self,
        _: &menu::SelectPrevious,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        window.focus_prev(cx);
    }
}

impl EventEmitter<DismissEvent> for AddLlmProviderModal {}

impl Focusable for AddLlmProviderModal {
    fn focus_handle(&self, _cx: &App) -> FocusHandle {
        self.focus_handle.clone()
    }
}

impl ModalView for AddLlmProviderModal {}

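// Layout: modal header, an optional warning banner for `last_error`, a
// scrollable tab group containing the provider and model inputs, and a footer
// with Cancel / Save Provider buttons bound to `menu::Cancel` / `menu::Confirm`.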
impl Render for AddLlmProviderModal {
    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
        let focus_handle = self.focus_handle(cx);

        let window_size = window.viewport_size();
        let rem_size = window.rem_size();
        let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;

        let modal_max_height = if is_large_window {
            rems_from_px(450.)
        } else {
            rems_from_px(200.)
        };

        v_flex()
            .id("add-llm-provider-modal")
            .key_context("AddLlmProviderModal")
            .w(rems(34.))
            .elevation_3(cx)
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::on_tab))
            .on_action(cx.listener(Self::on_tab_prev))
            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
                this.focus_handle(cx).focus(window, cx);
            }))
            .child(
                Modal::new("configure-context-server", None)
                    .header(ModalHeader::new().headline("Add LLM Provider").description(
                        match self.provider {
                            LlmCompatibleProvider::OpenAi => {
                                "This provider will use an OpenAI compatible API."
                            }
                        },
                    ))
                    .when_some(self.last_error.clone(), |this, error| {
                        this.section(
                            Section::new().child(
                                Banner::new()
                                    .severity(Severity::Warning)
                                    .child(div().text_xs().child(error)),
                            ),
                        )
                    })
                    .child(
                        div()
                            .size_full()
                            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
                            .child(
                                v_flex()
                                    .id("modal_content")
                                    .size_full()
                                    .tab_group()
                                    .max_h(modal_max_height)
                                    .pl_3()
                                    .pr_4()
                                    .gap_2()
                                    .overflow_y_scroll()
                                    .track_scroll(&self.scroll_handle)
                                    .child(self.input.provider_name.clone())
                                    .child(self.input.api_url.clone())
                                    .child(self.input.api_key.clone())
                                    .child(self.render_model_section(cx)),
                            ),
                    )
                    .footer(
                        ModalFooter::new().end_slot(
                            h_flex()
                                .gap_1()
                                .child(
                                    Button::new("cancel", "Cancel")
                                        .key_binding(
                                            KeyBinding::for_action_in(
                                                &menu::Cancel,
                                                &focus_handle,
                                                cx,
                                            )
                                            .map(|kb| kb.size(rems_from_px(12.))),
                                        )
                                        .on_click(cx.listener(|this, _event, window, cx| {
                                            this.cancel(&menu::Cancel, window, cx)
                                        })),
                                )
                                .child(
                                    Button::new("save-server", "Save Provider")
                                        .key_binding(
                                            KeyBinding::for_action_in(
                                                &menu::Confirm,
                                                &focus_handle,
                                                cx,
                                            )
                                            .map(|kb| kb.size(rems_from_px(12.))),
                                        )
                                        .on_click(cx.listener(|this, _event, window, cx| {
                                            this.confirm(&menu::Confirm, window, cx)
                                        })),
                                ),
                        ),
                    ),
            )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;

    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        assert_eq!(
            save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await,
            Some("Provider Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await,
            Some("API URL cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await,
            Some("API Key cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Model Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                cx,
            )
            .await,
            Some("Max Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                cx,
            )
            .await,
            Some("Max Completion Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                cx,
            )
            .await,
            Some("Max Output Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![
                    ("somemodel", "200000", "200000", "32000"),
                    ("somemodel", "200000", "200000", "32000"),
                ],
                cx,
            )
            .await,
            Some("Model Names must be unique".into())
        );
    }

    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(
                    Arc::new(FakeLanguageModelProvider::new(
                        LanguageModelProviderId::new("someprovider"),
                        LanguageModelProviderName::new("Some Provider"),
                    )),
                    cx,
                );
            });
        });

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "someapikey",
                vec![("somemodel", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.set_text("somemodel", window, cx);
            });
            assert_eq!(
                model_input.capabilities.supports_tools,
                ToggleState::Selected
            );
            assert_eq!(
                model_input.capabilities.supports_images,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_parallel_tool_calls,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_prompt_cache_key,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_chat_completions,
                ToggleState::Selected
            );

            let parsed_model = model_input.parse(cx).unwrap();
            assert!(parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(!parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(parsed_model.capabilities.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.set_text("somemodel", window, cx);
            });

            model_input.capabilities.supports_tools = ToggleState::Unselected;
            model_input.capabilities.supports_images = ToggleState::Unselected;
            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Unselected;
            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
            model_input.capabilities.supports_chat_completions = ToggleState::Unselected;

            let parsed_model = model_input.parse(cx).unwrap();
            assert!(!parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(!parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(!parsed_model.capabilities.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.set_text("somemodel", window, cx);
            });

            model_input.capabilities.supports_tools = ToggleState::Selected;
            model_input.capabilities.supports_images = ToggleState::Unselected;
            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Selected;
            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
            model_input.capabilities.supports_chat_completions = ToggleState::Selected;

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(parsed_model.name, "somemodel");
            assert!(parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(parsed_model.capabilities.chat_completions);
        });
    }

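    // Shared scaffolding: installs a test settings store, theme, language model
    // settings, and editor state, then opens a workspace window over a `FakeFs`
    // project.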
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            theme::init(theme::LoadThemes::JustBase, cx);

            language_model::init_settings(cx);
            editor::init(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (_, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        cx
    }

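    // Fills an `AddLlmProviderInput` with the given values and returns the error
    // produced by `save_provider_to_settings`, if any. Each model tuple is
    // `(name, max_tokens, max_completion_tokens, max_output_tokens)`.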
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn set_text(input: &Entity<InputField>, text: &str, window: &mut Window, cx: &mut App) {
            input.update(cx, |input, cx| {
                input.set_text(text, window, cx);
            });
        }

        let task = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            set_text(&input.provider_name, provider_name, window, cx);
            set_text(&input.api_url, api_url, window, cx);
            set_text(&input.api_key, api_key, window, cx);

            for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                if i >= input.models.len() {
                    input.models.push(ModelInput::new(i, window, cx));
                }
                let model = &mut input.models[i];
                set_text(&model.name, name, window, cx);
                set_text(&model.max_tokens, max_tokens, window, cx);
                set_text(
                    &model.max_completion_tokens,
                    max_completion_tokens,
                    window,
                    cx,
                );
                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
            }
            save_provider_to_settings(&input, cx)
        });

        task.await.err()
    }
}