1use std::sync::Arc;
2
3use anyhow::Result;
4use collections::HashSet;
5use fs::Fs;
6use gpui::{
7 DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
8};
9use language_model::LanguageModelRegistry;
10use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
11use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
12use ui::{
13 Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
14 WithScrollbar, prelude::*,
15};
16use ui_input::InputField;
17use workspace::{ModalView, Workspace};
18
19fn single_line_input(
20 label: impl Into<SharedString>,
21 placeholder: impl Into<SharedString>,
22 text: Option<&str>,
23 tab_index: isize,
24 window: &mut Window,
25 cx: &mut App,
26) -> Entity<InputField> {
27 cx.new(|cx| {
28 let input = InputField::new(window, cx, placeholder)
29 .label(label)
30 .tab_index(tab_index)
31 .tab_stop(true);
32
33 if let Some(text) = text {
34 input
35 .editor()
36 .update(cx, |editor, cx| editor.set_text(text, window, cx));
37 }
38 input
39 })
40}
41
/// The provider flavors that can be added through this modal; all of them
/// speak an OpenAI-compatible API.
#[derive(Clone, Copy)]
pub enum LlmCompatibleProvider {
    OpenAi,
}

impl LlmCompatibleProvider {
    /// Display name, used as the placeholder for the "Provider Name" field.
    fn name(&self) -> &'static str {
        match *self {
            Self::OpenAi => "OpenAI",
        }
    }

    /// Default endpoint, used as the placeholder for the "API URL" field.
    fn api_url(&self) -> &'static str {
        match *self {
            Self::OpenAi => "https://api.openai.com/v1",
        }
    }
}
60
/// Form state for the "Add LLM Provider" modal: one input field per
/// provider-level setting plus a list of per-model sub-forms.
struct AddLlmProviderInput {
    // User-visible provider name; also becomes the key under which the
    // provider is stored in the openai_compatible settings map.
    provider_name: Entity<InputField>,
    api_url: Entity<InputField>,
    api_key: Entity<InputField>,
    // Always starts with one entry (see `AddLlmProviderInput::new`).
    models: Vec<ModelInput>,
}
67
68impl AddLlmProviderInput {
69 fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
70 let provider_name =
71 single_line_input("Provider Name", provider.name(), None, 1, window, cx);
72 let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
73 let api_key = single_line_input(
74 "API Key",
75 "000000000000000000000000000000000000000000000000",
76 None,
77 3,
78 window,
79 cx,
80 );
81
82 Self {
83 provider_name,
84 api_url,
85 api_key,
86 models: vec![ModelInput::new(0, window, cx)],
87 }
88 }
89
90 fn add_model(&mut self, window: &mut Window, cx: &mut App) {
91 let model_index = self.models.len();
92 self.models.push(ModelInput::new(model_index, window, cx));
93 }
94
95 fn remove_model(&mut self, index: usize) {
96 self.models.remove(index);
97 }
98}
99
/// Checkbox state for each optional capability a configured model may
/// support. Mirrors the boolean fields of [`ModelCapabilities`], converted
/// back via `ToggleState::selected()` when the form is parsed.
struct ModelCapabilityToggles {
    pub supports_tools: ToggleState,
    pub supports_images: ToggleState,
    pub supports_parallel_tool_calls: ToggleState,
    pub supports_prompt_cache_key: ToggleState,
    pub supports_chat_completions: ToggleState,
}
107
/// Form state for a single model entry within the provider form: a name,
/// three token-limit fields, and the capability checkboxes.
struct ModelInput {
    name: Entity<InputField>,
    max_completion_tokens: Entity<InputField>,
    max_output_tokens: Entity<InputField>,
    max_tokens: Entity<InputField>,
    capabilities: ModelCapabilityToggles,
}
115
116impl ModelInput {
117 fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
118 let base_tab_index = (3 + (model_index * 4)) as isize;
119
120 let model_name = single_line_input(
121 "Model Name",
122 "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
123 None,
124 base_tab_index + 1,
125 window,
126 cx,
127 );
128 let max_completion_tokens = single_line_input(
129 "Max Completion Tokens",
130 "200000",
131 Some("200000"),
132 base_tab_index + 2,
133 window,
134 cx,
135 );
136 let max_output_tokens = single_line_input(
137 "Max Output Tokens",
138 "Max Output Tokens",
139 Some("32000"),
140 base_tab_index + 3,
141 window,
142 cx,
143 );
144 let max_tokens = single_line_input(
145 "Max Tokens",
146 "Max Tokens",
147 Some("200000"),
148 base_tab_index + 4,
149 window,
150 cx,
151 );
152
153 let ModelCapabilities {
154 tools,
155 images,
156 parallel_tool_calls,
157 prompt_cache_key,
158 chat_completions,
159 } = ModelCapabilities::default();
160
161 Self {
162 name: model_name,
163 max_completion_tokens,
164 max_output_tokens,
165 max_tokens,
166 capabilities: ModelCapabilityToggles {
167 supports_tools: tools.into(),
168 supports_images: images.into(),
169 supports_parallel_tool_calls: parallel_tool_calls.into(),
170 supports_prompt_cache_key: prompt_cache_key.into(),
171 supports_chat_completions: chat_completions.into(),
172 },
173 }
174 }
175
176 fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
177 let name = self.name.read(cx).text(cx);
178 if name.is_empty() {
179 return Err(SharedString::from("Model Name cannot be empty"));
180 }
181 Ok(AvailableModel {
182 name,
183 display_name: None,
184 max_completion_tokens: Some(
185 self.max_completion_tokens
186 .read(cx)
187 .text(cx)
188 .parse::<u64>()
189 .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
190 ),
191 max_output_tokens: Some(
192 self.max_output_tokens
193 .read(cx)
194 .text(cx)
195 .parse::<u64>()
196 .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
197 ),
198 max_tokens: self
199 .max_tokens
200 .read(cx)
201 .text(cx)
202 .parse::<u64>()
203 .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
204 capabilities: ModelCapabilities {
205 tools: self.capabilities.supports_tools.selected(),
206 images: self.capabilities.supports_images.selected(),
207 parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
208 prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
209 chat_completions: self.capabilities.supports_chat_completions.selected(),
210 },
211 })
212 }
213}
214
215fn save_provider_to_settings(
216 input: &AddLlmProviderInput,
217 cx: &mut App,
218) -> Task<Result<(), SharedString>> {
219 let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
220 if provider_name.is_empty() {
221 return Task::ready(Err("Provider Name cannot be empty".into()));
222 }
223
224 if LanguageModelRegistry::read_global(cx)
225 .providers()
226 .iter()
227 .any(|provider| {
228 provider.id().0.as_ref() == provider_name.as_ref()
229 || provider.name().0.as_ref() == provider_name.as_ref()
230 })
231 {
232 return Task::ready(Err(
233 "Provider Name is already taken by another provider".into()
234 ));
235 }
236
237 let api_url = input.api_url.read(cx).text(cx);
238 if api_url.is_empty() {
239 return Task::ready(Err("API URL cannot be empty".into()));
240 }
241
242 let api_key = input.api_key.read(cx).text(cx);
243 if api_key.is_empty() {
244 return Task::ready(Err("API Key cannot be empty".into()));
245 }
246
247 let mut models = Vec::new();
248 let mut model_names: HashSet<String> = HashSet::default();
249 for model in &input.models {
250 match model.parse(cx) {
251 Ok(model) => {
252 if !model_names.insert(model.name.clone()) {
253 return Task::ready(Err("Model Names must be unique".into()));
254 }
255 models.push(model)
256 }
257 Err(err) => return Task::ready(Err(err)),
258 }
259 }
260
261 let fs = <dyn Fs>::global(cx);
262 let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
263 cx.spawn(async move |cx| {
264 task.await
265 .map_err(|_| "Failed to write API key to keychain")?;
266 cx.update(|cx| {
267 update_settings_file(fs, cx, |settings, _cx| {
268 settings
269 .language_models
270 .get_or_insert_default()
271 .openai_compatible
272 .get_or_insert_default()
273 .insert(
274 provider_name,
275 OpenAiCompatibleSettingsContent {
276 api_url,
277 available_models: models,
278 },
279 );
280 });
281 })
282 .ok();
283 Ok(())
284 })
285}
286
/// Modal dialog that lets the user register a new OpenAI-compatible LLM
/// provider together with its models.
pub struct AddLlmProviderModal {
    provider: LlmCompatibleProvider,
    input: AddLlmProviderInput,
    scroll_handle: ScrollHandle,
    focus_handle: FocusHandle,
    // Most recent validation/save error; rendered in a warning banner.
    last_error: Option<SharedString>,
}
294
impl AddLlmProviderModal {
    /// Toggles this modal on the given workspace for the chosen provider.
    pub fn toggle(
        provider: LlmCompatibleProvider,
        workspace: &mut Workspace,
        window: &mut Window,
        cx: &mut Context<Workspace>,
    ) {
        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
    }

    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
        Self {
            input: AddLlmProviderInput::new(provider, window, cx),
            provider,
            last_error: None,
            focus_handle: cx.focus_handle(),
            scroll_handle: ScrollHandle::new(),
        }
    }

    /// Validates and saves the provider. Dismisses the modal on success;
    /// otherwise stores the error so the next render shows it in the banner.
    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
        let task = save_provider_to_settings(&self.input, cx);
        cx.spawn(async move |this, cx| {
            let result = task.await;
            this.update(cx, |this, cx| match result {
                Ok(_) => {
                    cx.emit(DismissEvent);
                }
                Err(error) => {
                    this.last_error = Some(error);
                    cx.notify();
                }
            })
        })
        .detach_and_log_err(cx);
    }

    /// Dismisses the modal without saving.
    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
        cx.emit(DismissEvent);
    }

    /// Renders the "Models" header (with the "Add Model" button) followed by
    /// one sub-form per configured model.
    fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
        v_flex()
            .mt_1()
            .gap_2()
            .child(
                h_flex()
                    .justify_between()
                    .child(Label::new("Models").size(LabelSize::Small))
                    .child(
                        Button::new("add-model", "Add Model")
                            .icon(IconName::Plus)
                            .icon_position(IconPosition::Start)
                            .icon_size(IconSize::XSmall)
                            .icon_color(Color::Muted)
                            .label_size(LabelSize::Small)
                            .on_click(cx.listener(|this, _, window, cx| {
                                this.input.add_model(window, cx);
                                cx.notify();
                            })),
                    ),
            )
            .children(
                self.input
                    .models
                    .iter()
                    .enumerate()
                    .map(|(ix, _)| self.render_model(ix, cx)),
            )
    }

    /// Renders the sub-form for the model at index `ix`: name and token-limit
    /// inputs, capability checkboxes, and — when there is more than one
    /// model — a "Remove Model" button.
    ///
    /// Checkbox/button listeners capture `ix` to mutate the matching entry in
    /// `self.input.models`.
    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
        // The last remaining model cannot be removed.
        let has_more_than_one_model = self.input.models.len() > 1;
        let model = &self.input.models[ix];

        v_flex()
            .p_2()
            .gap_2()
            .rounded_sm()
            .border_1()
            .border_dashed()
            .border_color(cx.theme().colors().border.opacity(0.6))
            .bg(cx.theme().colors().element_active.opacity(0.15))
            .child(model.name.clone())
            .child(
                h_flex()
                    .gap_2()
                    .child(model.max_completion_tokens.clone())
                    .child(model.max_output_tokens.clone()),
            )
            .child(model.max_tokens.clone())
            .child(
                v_flex()
                    .gap_1()
                    .child(
                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
                            .label("Supports tools")
                            .on_click(cx.listener(move |this, checked, _window, cx| {
                                this.input.models[ix].capabilities.supports_tools = *checked;
                                cx.notify();
                            })),
                    )
                    .child(
                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
                            .label("Supports images")
                            .on_click(cx.listener(move |this, checked, _window, cx| {
                                this.input.models[ix].capabilities.supports_images = *checked;
                                cx.notify();
                            })),
                    )
                    .child(
                        Checkbox::new(
                            ("supports-parallel-tool-calls", ix),
                            model.capabilities.supports_parallel_tool_calls,
                        )
                        .label("Supports parallel_tool_calls")
                        .on_click(cx.listener(
                            move |this, checked, _window, cx| {
                                this.input.models[ix]
                                    .capabilities
                                    .supports_parallel_tool_calls = *checked;
                                cx.notify();
                            },
                        )),
                    )
                    .child(
                        Checkbox::new(
                            ("supports-prompt-cache-key", ix),
                            model.capabilities.supports_prompt_cache_key,
                        )
                        .label("Supports prompt_cache_key")
                        .on_click(cx.listener(
                            move |this, checked, _window, cx| {
                                this.input.models[ix].capabilities.supports_prompt_cache_key =
                                    *checked;
                                cx.notify();
                            },
                        )),
                    )
                    .child(
                        Checkbox::new(
                            ("supports-chat-completions", ix),
                            model.capabilities.supports_chat_completions,
                        )
                        .label("Supports /chat/completions")
                        .on_click(cx.listener(
                            move |this, checked, _window, cx| {
                                this.input.models[ix].capabilities.supports_chat_completions =
                                    *checked;
                                cx.notify();
                            },
                        )),
                    ),
            )
            .when(has_more_than_one_model, |this| {
                this.child(
                    Button::new(("remove-model", ix), "Remove Model")
                        .icon(IconName::Trash)
                        .icon_position(IconPosition::Start)
                        .icon_size(IconSize::XSmall)
                        .icon_color(Color::Muted)
                        .label_size(LabelSize::Small)
                        .style(ButtonStyle::Outlined)
                        .full_width()
                        .on_click(cx.listener(move |this, _, _window, cx| {
                            this.input.remove_model(ix);
                            cx.notify();
                        })),
                )
            })
    }

    /// Moves keyboard focus to the next tab stop.
    fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
        window.focus_next(cx);
    }

    /// Moves keyboard focus to the previous tab stop.
    fn on_tab_prev(
        &mut self,
        _: &menu::SelectPrevious,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        window.focus_prev(cx);
    }
}
480
// Emitting `DismissEvent` is how the modal asks the workspace to close it.
impl EventEmitter<DismissEvent> for AddLlmProviderModal {}

impl Focusable for AddLlmProviderModal {
    fn focus_handle(&self, _cx: &App) -> FocusHandle {
        self.focus_handle.clone()
    }
}

// Marker impl required by `Workspace::toggle_modal`.
impl ModalView for AddLlmProviderModal {}
490
impl Render for AddLlmProviderModal {
    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
        let focus_handle = self.focus_handle(cx);

        // Size the scrollable body relative to the window so the modal fits
        // on small displays (threshold: 600px-equivalent in rems).
        let window_size = window.viewport_size();
        let rem_size = window.rem_size();
        let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;

        let modal_max_height = if is_large_window {
            rems_from_px(450.)
        } else {
            rems_from_px(200.)
        };

        v_flex()
            .id("add-llm-provider-modal")
            .key_context("AddLlmProviderModal")
            .w(rems(34.))
            .elevation_3(cx)
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::on_tab))
            .on_action(cx.listener(Self::on_tab_prev))
            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
                this.focus_handle(cx).focus(window, cx);
            }))
            .child(
                Modal::new("configure-context-server", None)
                    .header(ModalHeader::new().headline("Add LLM Provider").description(
                        match self.provider {
                            LlmCompatibleProvider::OpenAi => {
                                "This provider will use an OpenAI compatible API."
                            }
                        },
                    ))
                    // Show the most recent validation/save error, if any.
                    .when_some(self.last_error.clone(), |this, error| {
                        this.section(
                            Section::new().child(
                                Banner::new()
                                    .severity(Severity::Warning)
                                    .child(div().text_xs().child(error)),
                            ),
                        )
                    })
                    .child(
                        div()
                            .size_full()
                            .vertical_scrollbar_for(&self.scroll_handle, window, cx)
                            .child(
                                // Scrollable form body: provider fields first,
                                // then the per-model section.
                                v_flex()
                                    .id("modal_content")
                                    .size_full()
                                    .tab_group()
                                    .max_h(modal_max_height)
                                    .pl_3()
                                    .pr_4()
                                    .gap_2()
                                    .overflow_y_scroll()
                                    .track_scroll(&self.scroll_handle)
                                    .child(self.input.provider_name.clone())
                                    .child(self.input.api_url.clone())
                                    .child(self.input.api_key.clone())
                                    .child(self.render_model_section(cx)),
                            ),
                    )
                    .footer(
                        ModalFooter::new().end_slot(
                            h_flex()
                                .gap_1()
                                .child(
                                    Button::new("cancel", "Cancel")
                                        .key_binding(
                                            KeyBinding::for_action_in(
                                                &menu::Cancel,
                                                &focus_handle,
                                                cx,
                                            )
                                            .map(|kb| kb.size(rems_from_px(12.))),
                                        )
                                        .on_click(cx.listener(|this, _event, window, cx| {
                                            this.cancel(&menu::Cancel, window, cx)
                                        })),
                                )
                                .child(
                                    Button::new("save-server", "Save Provider")
                                        .key_binding(
                                            KeyBinding::for_action_in(
                                                &menu::Confirm,
                                                &focus_handle,
                                                cx,
                                            )
                                            .map(|kb| kb.size(rems_from_px(12.))),
                                        )
                                        .on_click(cx.listener(|this, _event, window, cx| {
                                            this.confirm(&menu::Confirm, window, cx)
                                        })),
                                ),
                        ),
                    ),
            )
    }
}
592
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;

    // Each invalid combination should short-circuit with its specific
    // user-facing message. Note the model tuple order is
    // (name, max_tokens, max_completion_tokens, max_output_tokens) —
    // see `save_provider_validation_errors`.
    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        assert_eq!(
            save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await,
            Some("Provider Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await,
            Some("API URL cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await,
            Some("API Key cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Model Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                cx,
            )
            .await,
            Some("Max Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                cx,
            )
            .await,
            Some("Max Completion Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                cx,
            )
            .await,
            Some("Max Output Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![
                    ("somemodel", "200000", "200000", "32000"),
                    ("somemodel", "200000", "200000", "32000"),
                ],
                cx,
            )
            .await,
            Some("Model Names must be unique".into())
        );
    }

    // Saving under a name already registered in the LanguageModelRegistry
    // must be rejected.
    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(
                    Arc::new(FakeLanguageModelProvider::new(
                        LanguageModelProviderId::new("someprovider"),
                        LanguageModelProviderName::new("Some Provider"),
                    )),
                    cx,
                );
            });
        });

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "someapikey",
                vec![("somemodel", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    // A fresh ModelInput's toggles must reflect ModelCapabilities::default()
    // and round-trip through parse().
    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text("somemodel", window, cx);
                });
            });
            assert_eq!(
                model_input.capabilities.supports_tools,
                ToggleState::Selected
            );
            assert_eq!(
                model_input.capabilities.supports_images,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_parallel_tool_calls,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_prompt_cache_key,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_chat_completions,
                ToggleState::Selected
            );

            let parsed_model = model_input.parse(cx).unwrap();
            assert!(parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(!parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(parsed_model.capabilities.chat_completions);
        });
    }

    // Deselecting every toggle must parse to all-false capabilities.
    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text("somemodel", window, cx);
                });
            });

            model_input.capabilities.supports_tools = ToggleState::Unselected;
            model_input.capabilities.supports_images = ToggleState::Unselected;
            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Unselected;
            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
            model_input.capabilities.supports_chat_completions = ToggleState::Unselected;

            let parsed_model = model_input.parse(cx).unwrap();
            assert!(!parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(!parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(!parsed_model.capabilities.chat_completions);
        });
    }

    // A mixed selection must map toggle-for-toggle onto the parsed model.
    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text("somemodel", window, cx);
                });
            });

            model_input.capabilities.supports_tools = ToggleState::Selected;
            model_input.capabilities.supports_images = ToggleState::Unselected;
            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Selected;
            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
            model_input.capabilities.supports_chat_completions = ToggleState::Selected;

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(parsed_model.name, "somemodel");
            assert!(parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(parsed_model.capabilities.chat_completions);
        });
    }

    // Boots a minimal app: settings store, theme, language-model settings,
    // a fake filesystem registered as the Fs global, and a test workspace
    // whose VisualTestContext the tests drive.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            theme::init(theme::LoadThemes::JustBase, cx);

            language_model::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (_, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        cx
    }

    // Fills an AddLlmProviderInput form with the given strings (model tuples
    // are (name, max_tokens, max_completion_tokens, max_output_tokens)),
    // runs save_provider_to_settings, and returns the error message, if any.
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn set_text(input: &Entity<InputField>, text: &str, window: &mut Window, cx: &mut App) {
            input.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text(text, window, cx);
                });
            });
        }

        let task = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            set_text(&input.provider_name, provider_name, window, cx);
            set_text(&input.api_url, api_url, window, cx);
            set_text(&input.api_key, api_key, window, cx);

            for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                // The form starts with one empty model; grow it as needed.
                if i >= input.models.len() {
                    input.models.push(ModelInput::new(i, window, cx));
                }
                let model = &mut input.models[i];
                set_text(&model.name, name, window, cx);
                set_text(&model.max_tokens, max_tokens, window, cx);
                set_text(
                    &model.max_completion_tokens,
                    max_completion_tokens,
                    window,
                    cx,
                );
                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
            }
            save_provider_to_settings(&input, cx)
        });

        task.await.err()
    }
}