add_llm_provider_modal.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use collections::HashSet;
  5use fs::Fs;
  6use gpui::{
  7    DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
  8};
  9use language_model::LanguageModelRegistry;
 10use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
 11use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
 12use ui::{
 13    Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
 14    WithScrollbar, prelude::*,
 15};
 16use ui_input::InputField;
 17use workspace::{ModalView, Workspace};
 18
 19fn single_line_input(
 20    label: impl Into<SharedString>,
 21    placeholder: impl Into<SharedString>,
 22    text: Option<&str>,
 23    tab_index: isize,
 24    window: &mut Window,
 25    cx: &mut App,
 26) -> Entity<InputField> {
 27    cx.new(|cx| {
 28        let input = InputField::new(window, cx, placeholder)
 29            .label(label)
 30            .tab_index(tab_index)
 31            .tab_stop(true);
 32
 33        if let Some(text) = text {
 34            input
 35                .editor()
 36                .update(cx, |editor, cx| editor.set_text(text, window, cx));
 37        }
 38        input
 39    })
 40}
 41
 42#[derive(Clone, Copy)]
 43pub enum LlmCompatibleProvider {
 44    OpenAi,
 45}
 46
 47impl LlmCompatibleProvider {
 48    fn name(&self) -> &'static str {
 49        match self {
 50            LlmCompatibleProvider::OpenAi => "OpenAI",
 51        }
 52    }
 53
 54    fn api_url(&self) -> &'static str {
 55        match self {
 56            LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
 57        }
 58    }
 59}
 60
 61struct AddLlmProviderInput {
 62    provider_name: Entity<InputField>,
 63    api_url: Entity<InputField>,
 64    api_key: Entity<InputField>,
 65    models: Vec<ModelInput>,
 66}
 67
 68impl AddLlmProviderInput {
 69    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
 70        let provider_name =
 71            single_line_input("Provider Name", provider.name(), None, 1, window, cx);
 72        let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
 73        let api_key = single_line_input(
 74            "API Key",
 75            "000000000000000000000000000000000000000000000000",
 76            None,
 77            3,
 78            window,
 79            cx,
 80        );
 81
 82        Self {
 83            provider_name,
 84            api_url,
 85            api_key,
 86            models: vec![ModelInput::new(0, window, cx)],
 87        }
 88    }
 89
 90    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
 91        let model_index = self.models.len();
 92        self.models.push(ModelInput::new(model_index, window, cx));
 93    }
 94
 95    fn remove_model(&mut self, index: usize) {
 96        self.models.remove(index);
 97    }
 98}
 99
100struct ModelCapabilityToggles {
101    pub supports_tools: ToggleState,
102    pub supports_images: ToggleState,
103    pub supports_parallel_tool_calls: ToggleState,
104    pub supports_prompt_cache_key: ToggleState,
105    pub supports_chat_completions: ToggleState,
106}
107
108struct ModelInput {
109    name: Entity<InputField>,
110    max_completion_tokens: Entity<InputField>,
111    max_output_tokens: Entity<InputField>,
112    max_tokens: Entity<InputField>,
113    capabilities: ModelCapabilityToggles,
114}
115
116impl ModelInput {
117    fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
118        let base_tab_index = (3 + (model_index * 4)) as isize;
119
120        let model_name = single_line_input(
121            "Model Name",
122            "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
123            None,
124            base_tab_index + 1,
125            window,
126            cx,
127        );
128        let max_completion_tokens = single_line_input(
129            "Max Completion Tokens",
130            "200000",
131            Some("200000"),
132            base_tab_index + 2,
133            window,
134            cx,
135        );
136        let max_output_tokens = single_line_input(
137            "Max Output Tokens",
138            "Max Output Tokens",
139            Some("32000"),
140            base_tab_index + 3,
141            window,
142            cx,
143        );
144        let max_tokens = single_line_input(
145            "Max Tokens",
146            "Max Tokens",
147            Some("200000"),
148            base_tab_index + 4,
149            window,
150            cx,
151        );
152
153        let ModelCapabilities {
154            tools,
155            images,
156            parallel_tool_calls,
157            prompt_cache_key,
158            chat_completions,
159        } = ModelCapabilities::default();
160
161        Self {
162            name: model_name,
163            max_completion_tokens,
164            max_output_tokens,
165            max_tokens,
166            capabilities: ModelCapabilityToggles {
167                supports_tools: tools.into(),
168                supports_images: images.into(),
169                supports_parallel_tool_calls: parallel_tool_calls.into(),
170                supports_prompt_cache_key: prompt_cache_key.into(),
171                supports_chat_completions: chat_completions.into(),
172            },
173        }
174    }
175
176    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
177        let name = self.name.read(cx).text(cx);
178        if name.is_empty() {
179            return Err(SharedString::from("Model Name cannot be empty"));
180        }
181        Ok(AvailableModel {
182            name,
183            display_name: None,
184            max_completion_tokens: Some(
185                self.max_completion_tokens
186                    .read(cx)
187                    .text(cx)
188                    .parse::<u64>()
189                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
190            ),
191            max_output_tokens: Some(
192                self.max_output_tokens
193                    .read(cx)
194                    .text(cx)
195                    .parse::<u64>()
196                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
197            ),
198            max_tokens: self
199                .max_tokens
200                .read(cx)
201                .text(cx)
202                .parse::<u64>()
203                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
204            capabilities: ModelCapabilities {
205                tools: self.capabilities.supports_tools.selected(),
206                images: self.capabilities.supports_images.selected(),
207                parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
208                prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
209                chat_completions: self.capabilities.supports_chat_completions.selected(),
210            },
211        })
212    }
213}
214
215fn save_provider_to_settings(
216    input: &AddLlmProviderInput,
217    cx: &mut App,
218) -> Task<Result<(), SharedString>> {
219    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
220    if provider_name.is_empty() {
221        return Task::ready(Err("Provider Name cannot be empty".into()));
222    }
223
224    if LanguageModelRegistry::read_global(cx)
225        .providers()
226        .iter()
227        .any(|provider| {
228            provider.id().0.as_ref() == provider_name.as_ref()
229                || provider.name().0.as_ref() == provider_name.as_ref()
230        })
231    {
232        return Task::ready(Err(
233            "Provider Name is already taken by another provider".into()
234        ));
235    }
236
237    let api_url = input.api_url.read(cx).text(cx);
238    if api_url.is_empty() {
239        return Task::ready(Err("API URL cannot be empty".into()));
240    }
241
242    let api_key = input.api_key.read(cx).text(cx);
243    if api_key.is_empty() {
244        return Task::ready(Err("API Key cannot be empty".into()));
245    }
246
247    let mut models = Vec::new();
248    let mut model_names: HashSet<String> = HashSet::default();
249    for model in &input.models {
250        match model.parse(cx) {
251            Ok(model) => {
252                if !model_names.insert(model.name.clone()) {
253                    return Task::ready(Err("Model Names must be unique".into()));
254                }
255                models.push(model)
256            }
257            Err(err) => return Task::ready(Err(err)),
258        }
259    }
260
261    let fs = <dyn Fs>::global(cx);
262    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
263    cx.spawn(async move |cx| {
264        task.await
265            .map_err(|_| "Failed to write API key to keychain")?;
266        cx.update(|cx| {
267            update_settings_file(fs, cx, |settings, _cx| {
268                settings
269                    .language_models
270                    .get_or_insert_default()
271                    .openai_compatible
272                    .get_or_insert_default()
273                    .insert(
274                        provider_name,
275                        OpenAiCompatibleSettingsContent {
276                            api_url,
277                            available_models: models,
278                        },
279                    );
280            });
281        })
282        .ok();
283        Ok(())
284    })
285}
286
287pub struct AddLlmProviderModal {
288    provider: LlmCompatibleProvider,
289    input: AddLlmProviderInput,
290    scroll_handle: ScrollHandle,
291    focus_handle: FocusHandle,
292    last_error: Option<SharedString>,
293}
294
295impl AddLlmProviderModal {
296    pub fn toggle(
297        provider: LlmCompatibleProvider,
298        workspace: &mut Workspace,
299        window: &mut Window,
300        cx: &mut Context<Workspace>,
301    ) {
302        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
303    }
304
305    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
306        Self {
307            input: AddLlmProviderInput::new(provider, window, cx),
308            provider,
309            last_error: None,
310            focus_handle: cx.focus_handle(),
311            scroll_handle: ScrollHandle::new(),
312        }
313    }
314
315    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
316        let task = save_provider_to_settings(&self.input, cx);
317        cx.spawn(async move |this, cx| {
318            let result = task.await;
319            this.update(cx, |this, cx| match result {
320                Ok(_) => {
321                    cx.emit(DismissEvent);
322                }
323                Err(error) => {
324                    this.last_error = Some(error);
325                    cx.notify();
326                }
327            })
328        })
329        .detach_and_log_err(cx);
330    }
331
332    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
333        cx.emit(DismissEvent);
334    }
335
336    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
337        v_flex()
338            .mt_1()
339            .gap_2()
340            .child(
341                h_flex()
342                    .justify_between()
343                    .child(Label::new("Models").size(LabelSize::Small))
344                    .child(
345                        Button::new("add-model", "Add Model")
346                            .icon(IconName::Plus)
347                            .icon_position(IconPosition::Start)
348                            .icon_size(IconSize::XSmall)
349                            .icon_color(Color::Muted)
350                            .label_size(LabelSize::Small)
351                            .on_click(cx.listener(|this, _, window, cx| {
352                                this.input.add_model(window, cx);
353                                cx.notify();
354                            })),
355                    ),
356            )
357            .children(
358                self.input
359                    .models
360                    .iter()
361                    .enumerate()
362                    .map(|(ix, _)| self.render_model(ix, cx)),
363            )
364    }
365
366    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
367        let has_more_than_one_model = self.input.models.len() > 1;
368        let model = &self.input.models[ix];
369
370        v_flex()
371            .p_2()
372            .gap_2()
373            .rounded_sm()
374            .border_1()
375            .border_dashed()
376            .border_color(cx.theme().colors().border.opacity(0.6))
377            .bg(cx.theme().colors().element_active.opacity(0.15))
378            .child(model.name.clone())
379            .child(
380                h_flex()
381                    .gap_2()
382                    .child(model.max_completion_tokens.clone())
383                    .child(model.max_output_tokens.clone()),
384            )
385            .child(model.max_tokens.clone())
386            .child(
387                v_flex()
388                    .gap_1()
389                    .child(
390                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
391                            .label("Supports tools")
392                            .on_click(cx.listener(move |this, checked, _window, cx| {
393                                this.input.models[ix].capabilities.supports_tools = *checked;
394                                cx.notify();
395                            })),
396                    )
397                    .child(
398                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
399                            .label("Supports images")
400                            .on_click(cx.listener(move |this, checked, _window, cx| {
401                                this.input.models[ix].capabilities.supports_images = *checked;
402                                cx.notify();
403                            })),
404                    )
405                    .child(
406                        Checkbox::new(
407                            ("supports-parallel-tool-calls", ix),
408                            model.capabilities.supports_parallel_tool_calls,
409                        )
410                        .label("Supports parallel_tool_calls")
411                        .on_click(cx.listener(
412                            move |this, checked, _window, cx| {
413                                this.input.models[ix]
414                                    .capabilities
415                                    .supports_parallel_tool_calls = *checked;
416                                cx.notify();
417                            },
418                        )),
419                    )
420                    .child(
421                        Checkbox::new(
422                            ("supports-prompt-cache-key", ix),
423                            model.capabilities.supports_prompt_cache_key,
424                        )
425                        .label("Supports prompt_cache_key")
426                        .on_click(cx.listener(
427                            move |this, checked, _window, cx| {
428                                this.input.models[ix].capabilities.supports_prompt_cache_key =
429                                    *checked;
430                                cx.notify();
431                            },
432                        )),
433                    )
434                    .child(
435                        Checkbox::new(
436                            ("supports-chat-completions", ix),
437                            model.capabilities.supports_chat_completions,
438                        )
439                        .label("Supports /chat/completions")
440                        .on_click(cx.listener(
441                            move |this, checked, _window, cx| {
442                                this.input.models[ix].capabilities.supports_chat_completions =
443                                    *checked;
444                                cx.notify();
445                            },
446                        )),
447                    ),
448            )
449            .when(has_more_than_one_model, |this| {
450                this.child(
451                    Button::new(("remove-model", ix), "Remove Model")
452                        .icon(IconName::Trash)
453                        .icon_position(IconPosition::Start)
454                        .icon_size(IconSize::XSmall)
455                        .icon_color(Color::Muted)
456                        .label_size(LabelSize::Small)
457                        .style(ButtonStyle::Outlined)
458                        .full_width()
459                        .on_click(cx.listener(move |this, _, _window, cx| {
460                            this.input.remove_model(ix);
461                            cx.notify();
462                        })),
463                )
464            })
465    }
466
467    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
468        window.focus_next(cx);
469    }
470
471    fn on_tab_prev(
472        &mut self,
473        _: &menu::SelectPrevious,
474        window: &mut Window,
475        cx: &mut Context<Self>,
476    ) {
477        window.focus_prev(cx);
478    }
479}
480
481impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
482
483impl Focusable for AddLlmProviderModal {
484    fn focus_handle(&self, _cx: &App) -> FocusHandle {
485        self.focus_handle.clone()
486    }
487}
488
489impl ModalView for AddLlmProviderModal {}
490
491impl Render for AddLlmProviderModal {
492    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
493        let focus_handle = self.focus_handle(cx);
494
495        let window_size = window.viewport_size();
496        let rem_size = window.rem_size();
497        let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
498
499        let modal_max_height = if is_large_window {
500            rems_from_px(450.)
501        } else {
502            rems_from_px(200.)
503        };
504
505        v_flex()
506            .id("add-llm-provider-modal")
507            .key_context("AddLlmProviderModal")
508            .w(rems(34.))
509            .elevation_3(cx)
510            .on_action(cx.listener(Self::cancel))
511            .on_action(cx.listener(Self::on_tab))
512            .on_action(cx.listener(Self::on_tab_prev))
513            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
514                this.focus_handle(cx).focus(window, cx);
515            }))
516            .child(
517                Modal::new("configure-context-server", None)
518                    .header(ModalHeader::new().headline("Add LLM Provider").description(
519                        match self.provider {
520                            LlmCompatibleProvider::OpenAi => {
521                                "This provider will use an OpenAI compatible API."
522                            }
523                        },
524                    ))
525                    .when_some(self.last_error.clone(), |this, error| {
526                        this.section(
527                            Section::new().child(
528                                Banner::new()
529                                    .severity(Severity::Warning)
530                                    .child(div().text_xs().child(error)),
531                            ),
532                        )
533                    })
534                    .child(
535                        div()
536                            .size_full()
537                            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
538                            .child(
539                                v_flex()
540                                    .id("modal_content")
541                                    .size_full()
542                                    .tab_group()
543                                    .max_h(modal_max_height)
544                                    .pl_3()
545                                    .pr_4()
546                                    .gap_2()
547                                    .overflow_y_scroll()
548                                    .track_scroll(&self.scroll_handle)
549                                    .child(self.input.provider_name.clone())
550                                    .child(self.input.api_url.clone())
551                                    .child(self.input.api_key.clone())
552                                    .child(self.render_model_section(cx)),
553                            ),
554                    )
555                    .footer(
556                        ModalFooter::new().end_slot(
557                            h_flex()
558                                .gap_1()
559                                .child(
560                                    Button::new("cancel", "Cancel")
561                                        .key_binding(
562                                            KeyBinding::for_action_in(
563                                                &menu::Cancel,
564                                                &focus_handle,
565                                                cx,
566                                            )
567                                            .map(|kb| kb.size(rems_from_px(12.))),
568                                        )
569                                        .on_click(cx.listener(|this, _event, window, cx| {
570                                            this.cancel(&menu::Cancel, window, cx)
571                                        })),
572                                )
573                                .child(
574                                    Button::new("save-server", "Save Provider")
575                                        .key_binding(
576                                            KeyBinding::for_action_in(
577                                                &menu::Confirm,
578                                                &focus_handle,
579                                                cx,
580                                            )
581                                            .map(|kb| kb.size(rems_from_px(12.))),
582                                        )
583                                        .on_click(cx.listener(|this, _event, window, cx| {
584                                            this.confirm(&menu::Confirm, window, cx)
585                                        })),
586                                ),
587                        ),
588                    ),
589            )
590    }
591}
592
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;

    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        assert_eq!(
            save_provider_validation_errors("", "someurl", "somekey", vec![], cx).await,
            Some("Provider Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "", "somekey", vec![], cx).await,
            Some("API URL cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "someurl", "", vec![], cx).await,
            Some("API Key cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Model Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                cx,
            )
            .await,
            Some("Max Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                cx,
            )
            .await,
            Some("Max Completion Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                cx,
            )
            .await,
            Some("Max Output Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![
                    ("somemodel", "200000", "200000", "32000"),
                    ("somemodel", "200000", "200000", "32000"),
                ],
                cx,
            )
            .await,
            Some("Model Names must be unique".into())
        );
    }

    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(
                    Arc::new(FakeLanguageModelProvider::new(
                        LanguageModelProviderId::new("someprovider"),
                        LanguageModelProviderName::new("Some Provider"),
                    )),
                    cx,
                );
            });
        });

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "someapikey",
                vec![("somemodel", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text("somemodel", window, cx);
                });
            });
            assert_eq!(
                model_input.capabilities.supports_tools,
                ToggleState::Selected
            );
            assert_eq!(
                model_input.capabilities.supports_images,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_parallel_tool_calls,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_prompt_cache_key,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_chat_completions,
                ToggleState::Selected
            );

            let parsed_model = model_input.parse(cx).unwrap();
            assert!(parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(!parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(parsed_model.capabilities.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text("somemodel", window, cx);
                });
            });

            model_input.capabilities.supports_tools = ToggleState::Unselected;
            model_input.capabilities.supports_images = ToggleState::Unselected;
            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Unselected;
            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
            model_input.capabilities.supports_chat_completions = ToggleState::Unselected;

            let parsed_model = model_input.parse(cx).unwrap();
            assert!(!parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(!parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(!parsed_model.capabilities.chat_completions);
        });
    }

    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text("somemodel", window, cx);
                });
            });

            model_input.capabilities.supports_tools = ToggleState::Selected;
            model_input.capabilities.supports_images = ToggleState::Unselected;
            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Selected;
            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
            model_input.capabilities.supports_chat_completions = ToggleState::Selected;

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(parsed_model.name, "somemodel");
            assert!(parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(parsed_model.capabilities.chat_completions);
        });
    }

    /// Installs a test settings store, base theme, language-model settings and a
    /// global `FakeFs`, then opens a workspace window on a test project.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            theme::init(theme::LoadThemes::JustBase, cx);

            language_model::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (_, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        cx
    }

    /// Fills a fresh form with the given values, runs `save_provider_to_settings`
    /// and returns its error, if any.
    ///
    /// Each `models` tuple is `(name, max_tokens, max_completion_tokens,
    /// max_output_tokens)`. Note that this is NOT the order in which the fields
    /// appear in the UI.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn set_text(input: &Entity<InputField>, text: &str, window: &mut Window, cx: &mut App) {
            input.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text(text, window, cx);
                });
            });
        }

        let task = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            set_text(&input.provider_name, provider_name, window, cx);
            set_text(&input.api_url, api_url, window, cx);
            set_text(&input.api_key, api_key, window, cx);

            for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                if i >= input.models.len() {
                    input.models.push(ModelInput::new(i, window, cx));
                }
                // Only shared access is needed: the fields are entities updated via `cx`.
                let model = &input.models[i];
                set_text(&model.name, name, window, cx);
                set_text(&model.max_tokens, max_tokens, window, cx);
                set_text(
                    &model.max_completion_tokens,
                    max_completion_tokens,
                    window,
                    cx,
                );
                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
            }
            save_provider_to_settings(&input, cx)
        });

        task.await.err()
    }
}