1use std::sync::Arc;
2
3use anyhow::Result;
4use collections::HashSet;
5use fs::Fs;
6use gpui::{
7 DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
8};
9use language_model::LanguageModelRegistry;
10use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
11use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
12use ui::{
13 Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
14 WithScrollbar, prelude::*,
15};
16use ui_input::InputField;
17use workspace::{ModalView, Workspace};
18
19fn single_line_input(
20 label: impl Into<SharedString>,
21 placeholder: impl Into<SharedString>,
22 text: Option<&str>,
23 tab_index: isize,
24 window: &mut Window,
25 cx: &mut App,
26) -> Entity<InputField> {
27 cx.new(|cx| {
28 let input = InputField::new(window, cx, placeholder)
29 .label(label)
30 .tab_index(tab_index)
31 .tab_stop(true);
32
33 if let Some(text) = text {
34 input
35 .editor()
36 .update(cx, |editor, cx| editor.set_text(text, window, cx));
37 }
38 input
39 })
40}
41
/// The flavors of OpenAI-compatible providers the modal can configure.
/// Currently only OpenAI itself is offered as a preset.
#[derive(Clone, Copy)]
pub enum LlmCompatibleProvider {
    OpenAi,
}
46
47impl LlmCompatibleProvider {
48 fn name(&self) -> &'static str {
49 match self {
50 LlmCompatibleProvider::OpenAi => "OpenAI",
51 }
52 }
53
54 fn api_url(&self) -> &'static str {
55 match self {
56 LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
57 }
58 }
59}
60
/// The complete set of form inputs backing the "Add LLM Provider" modal.
struct AddLlmProviderInput {
    // Display name of the provider; also used as the settings-file key.
    provider_name: Entity<InputField>,
    api_url: Entity<InputField>,
    api_key: Entity<InputField>,
    // One entry per model row; starts with a single empty row.
    models: Vec<ModelInput>,
}
67
68impl AddLlmProviderInput {
69 fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
70 let provider_name =
71 single_line_input("Provider Name", provider.name(), None, 1, window, cx);
72 let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
73 let api_key = single_line_input(
74 "API Key",
75 "000000000000000000000000000000000000000000000000",
76 None,
77 3,
78 window,
79 cx,
80 );
81
82 Self {
83 provider_name,
84 api_url,
85 api_key,
86 models: vec![ModelInput::new(0, window, cx)],
87 }
88 }
89
90 fn add_model(&mut self, window: &mut Window, cx: &mut App) {
91 let model_index = self.models.len();
92 self.models.push(ModelInput::new(model_index, window, cx));
93 }
94
95 fn remove_model(&mut self, index: usize) {
96 self.models.remove(index);
97 }
98}
99
/// Checkbox states mirroring [`ModelCapabilities`], one toggle per
/// capability flag exposed in the model form.
struct ModelCapabilityToggles {
    pub supports_tools: ToggleState,
    pub supports_images: ToggleState,
    pub supports_parallel_tool_calls: ToggleState,
    pub supports_prompt_cache_key: ToggleState,
    pub supports_chat_completions: ToggleState,
}
107
/// Form state for a single model row inside the modal.
struct ModelInput {
    name: Entity<InputField>,
    max_completion_tokens: Entity<InputField>,
    max_output_tokens: Entity<InputField>,
    max_tokens: Entity<InputField>,
    // Capability checkboxes; seeded from `ModelCapabilities::default()`.
    capabilities: ModelCapabilityToggles,
}
115
116impl ModelInput {
117 fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
118 let base_tab_index = (3 + (model_index * 4)) as isize;
119
120 let model_name = single_line_input(
121 "Model Name",
122 "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
123 None,
124 base_tab_index + 1,
125 window,
126 cx,
127 );
128 let max_completion_tokens = single_line_input(
129 "Max Completion Tokens",
130 "200000",
131 Some("200000"),
132 base_tab_index + 2,
133 window,
134 cx,
135 );
136 let max_output_tokens = single_line_input(
137 "Max Output Tokens",
138 "Max Output Tokens",
139 Some("32000"),
140 base_tab_index + 3,
141 window,
142 cx,
143 );
144 let max_tokens = single_line_input(
145 "Max Tokens",
146 "Max Tokens",
147 Some("200000"),
148 base_tab_index + 4,
149 window,
150 cx,
151 );
152
153 let ModelCapabilities {
154 tools,
155 images,
156 parallel_tool_calls,
157 prompt_cache_key,
158 chat_completions,
159 } = ModelCapabilities::default();
160
161 Self {
162 name: model_name,
163 max_completion_tokens,
164 max_output_tokens,
165 max_tokens,
166 capabilities: ModelCapabilityToggles {
167 supports_tools: tools.into(),
168 supports_images: images.into(),
169 supports_parallel_tool_calls: parallel_tool_calls.into(),
170 supports_prompt_cache_key: prompt_cache_key.into(),
171 supports_chat_completions: chat_completions.into(),
172 },
173 }
174 }
175
176 fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
177 let name = self.name.read(cx).text(cx);
178 if name.is_empty() {
179 return Err(SharedString::from("Model Name cannot be empty"));
180 }
181 Ok(AvailableModel {
182 name,
183 display_name: None,
184 max_completion_tokens: Some(
185 self.max_completion_tokens
186 .read(cx)
187 .text(cx)
188 .parse::<u64>()
189 .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
190 ),
191 max_output_tokens: Some(
192 self.max_output_tokens
193 .read(cx)
194 .text(cx)
195 .parse::<u64>()
196 .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
197 ),
198 max_tokens: self
199 .max_tokens
200 .read(cx)
201 .text(cx)
202 .parse::<u64>()
203 .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
204 capabilities: ModelCapabilities {
205 tools: self.capabilities.supports_tools.selected(),
206 images: self.capabilities.supports_images.selected(),
207 parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
208 prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
209 chat_completions: self.capabilities.supports_chat_completions.selected(),
210 },
211 })
212 }
213}
214
/// Validates the form and, if everything checks out, stores the API key in
/// the system keychain and persists the provider under
/// `language_models.openai_compatible` in the settings file.
///
/// Validation runs in a fixed order (provider name, name conflict, API URL,
/// API key, then each model in turn) and the first failure short-circuits
/// with a user-facing message; nothing is written unless the whole form is
/// valid.
fn save_provider_to_settings(
    input: &AddLlmProviderInput,
    cx: &mut App,
) -> Task<Result<(), SharedString>> {
    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
    if provider_name.is_empty() {
        return Task::ready(Err("Provider Name cannot be empty".into()));
    }

    // Reject names colliding with any registered provider's id or display
    // name, since the name doubles as the settings key.
    if LanguageModelRegistry::read_global(cx)
        .providers()
        .iter()
        .any(|provider| {
            provider.id().0.as_ref() == provider_name.as_ref()
                || provider.name().0.as_ref() == provider_name.as_ref()
        })
    {
        return Task::ready(Err(
            "Provider Name is already taken by another provider".into()
        ));
    }

    let api_url = input.api_url.read(cx).text(cx);
    if api_url.is_empty() {
        return Task::ready(Err("API URL cannot be empty".into()));
    }

    let api_key = input.api_key.read(cx).text(cx);
    if api_key.is_empty() {
        return Task::ready(Err("API Key cannot be empty".into()));
    }

    // Parse every model row up front so credentials are only written for a
    // fully valid form; model names must be unique within the provider.
    let mut models = Vec::new();
    let mut model_names: HashSet<String> = HashSet::default();
    for model in &input.models {
        match model.parse(cx) {
            Ok(model) => {
                if !model_names.insert(model.name.clone()) {
                    return Task::ready(Err("Model Names must be unique".into()));
                }
                models.push(model)
            }
            Err(err) => return Task::ready(Err(err)),
        }
    }

    let fs = <dyn Fs>::global(cx);
    // The API key goes to the keychain (keyed by API URL, username "Bearer"),
    // never into the settings file itself.
    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
    cx.spawn(async move |cx| {
        task.await
            .map_err(|_| SharedString::from("Failed to write API key to keychain"))?;
        // NOTE(review): the result of `cx.update` is discarded here — if the
        // app context is gone, the settings write is silently skipped even
        // though the key was already stored; confirm this is intentional.
        cx.update(|cx| {
            update_settings_file(fs, cx, |settings, _cx| {
                settings
                    .language_models
                    .get_or_insert_default()
                    .openai_compatible
                    .get_or_insert_default()
                    .insert(
                        provider_name,
                        OpenAiCompatibleSettingsContent {
                            api_url,
                            available_models: models,
                        },
                    );
            });
        });
        Ok(())
    })
}
285
/// Modal dialog for configuring a new OpenAI-compatible LLM provider.
pub struct AddLlmProviderModal {
    provider: LlmCompatibleProvider,
    input: AddLlmProviderInput,
    scroll_handle: ScrollHandle,
    focus_handle: FocusHandle,
    // Most recent validation/save error, surfaced in a warning banner.
    last_error: Option<SharedString>,
}
293
294impl AddLlmProviderModal {
295 pub fn toggle(
296 provider: LlmCompatibleProvider,
297 workspace: &mut Workspace,
298 window: &mut Window,
299 cx: &mut Context<Workspace>,
300 ) {
301 workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
302 }
303
304 fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
305 Self {
306 input: AddLlmProviderInput::new(provider, window, cx),
307 provider,
308 last_error: None,
309 focus_handle: cx.focus_handle(),
310 scroll_handle: ScrollHandle::new(),
311 }
312 }
313
314 fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
315 let task = save_provider_to_settings(&self.input, cx);
316 cx.spawn(async move |this, cx| {
317 let result = task.await;
318 this.update(cx, |this, cx| match result {
319 Ok(_) => {
320 cx.emit(DismissEvent);
321 }
322 Err(error) => {
323 this.last_error = Some(error);
324 cx.notify();
325 }
326 })
327 })
328 .detach_and_log_err(cx);
329 }
330
331 fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
332 cx.emit(DismissEvent);
333 }
334
335 fn render_model_section(&self, cx: &mut Context<Self>) -> impl IntoElement {
336 v_flex()
337 .mt_1()
338 .gap_2()
339 .child(
340 h_flex()
341 .justify_between()
342 .child(Label::new("Models").size(LabelSize::Small))
343 .child(
344 Button::new("add-model", "Add Model")
345 .icon(IconName::Plus)
346 .icon_position(IconPosition::Start)
347 .icon_size(IconSize::XSmall)
348 .icon_color(Color::Muted)
349 .label_size(LabelSize::Small)
350 .on_click(cx.listener(|this, _, window, cx| {
351 this.input.add_model(window, cx);
352 cx.notify();
353 })),
354 ),
355 )
356 .children(
357 self.input
358 .models
359 .iter()
360 .enumerate()
361 .map(|(ix, _)| self.render_model(ix, cx)),
362 )
363 }
364
365 fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
366 let has_more_than_one_model = self.input.models.len() > 1;
367 let model = &self.input.models[ix];
368
369 v_flex()
370 .p_2()
371 .gap_2()
372 .rounded_sm()
373 .border_1()
374 .border_dashed()
375 .border_color(cx.theme().colors().border.opacity(0.6))
376 .bg(cx.theme().colors().element_active.opacity(0.15))
377 .child(model.name.clone())
378 .child(
379 h_flex()
380 .gap_2()
381 .child(model.max_completion_tokens.clone())
382 .child(model.max_output_tokens.clone()),
383 )
384 .child(model.max_tokens.clone())
385 .child(
386 v_flex()
387 .gap_1()
388 .child(
389 Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
390 .label("Supports tools")
391 .on_click(cx.listener(move |this, checked, _window, cx| {
392 this.input.models[ix].capabilities.supports_tools = *checked;
393 cx.notify();
394 })),
395 )
396 .child(
397 Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
398 .label("Supports images")
399 .on_click(cx.listener(move |this, checked, _window, cx| {
400 this.input.models[ix].capabilities.supports_images = *checked;
401 cx.notify();
402 })),
403 )
404 .child(
405 Checkbox::new(
406 ("supports-parallel-tool-calls", ix),
407 model.capabilities.supports_parallel_tool_calls,
408 )
409 .label("Supports parallel_tool_calls")
410 .on_click(cx.listener(
411 move |this, checked, _window, cx| {
412 this.input.models[ix]
413 .capabilities
414 .supports_parallel_tool_calls = *checked;
415 cx.notify();
416 },
417 )),
418 )
419 .child(
420 Checkbox::new(
421 ("supports-prompt-cache-key", ix),
422 model.capabilities.supports_prompt_cache_key,
423 )
424 .label("Supports prompt_cache_key")
425 .on_click(cx.listener(
426 move |this, checked, _window, cx| {
427 this.input.models[ix].capabilities.supports_prompt_cache_key =
428 *checked;
429 cx.notify();
430 },
431 )),
432 )
433 .child(
434 Checkbox::new(
435 ("supports-chat-completions", ix),
436 model.capabilities.supports_chat_completions,
437 )
438 .label("Supports /chat/completions")
439 .on_click(cx.listener(
440 move |this, checked, _window, cx| {
441 this.input.models[ix].capabilities.supports_chat_completions =
442 *checked;
443 cx.notify();
444 },
445 )),
446 ),
447 )
448 .when(has_more_than_one_model, |this| {
449 this.child(
450 Button::new(("remove-model", ix), "Remove Model")
451 .icon(IconName::Trash)
452 .icon_position(IconPosition::Start)
453 .icon_size(IconSize::XSmall)
454 .icon_color(Color::Muted)
455 .label_size(LabelSize::Small)
456 .style(ButtonStyle::Outlined)
457 .full_width()
458 .on_click(cx.listener(move |this, _, _window, cx| {
459 this.input.remove_model(ix);
460 cx.notify();
461 })),
462 )
463 })
464 }
465
466 fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, cx: &mut Context<Self>) {
467 window.focus_next(cx);
468 }
469
470 fn on_tab_prev(
471 &mut self,
472 _: &menu::SelectPrevious,
473 window: &mut Window,
474 cx: &mut Context<Self>,
475 ) {
476 window.focus_prev(cx);
477 }
478}
479
// Emitted when the modal should close (cancel, or a successful save).
impl EventEmitter<DismissEvent> for AddLlmProviderModal {}

impl Focusable for AddLlmProviderModal {
    fn focus_handle(&self, _cx: &App) -> FocusHandle {
        self.focus_handle.clone()
    }
}

impl ModalView for AddLlmProviderModal {}
489
490impl Render for AddLlmProviderModal {
491 fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
492 let focus_handle = self.focus_handle(cx);
493
494 let window_size = window.viewport_size();
495 let rem_size = window.rem_size();
496 let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
497
498 let modal_max_height = if is_large_window {
499 rems_from_px(450.)
500 } else {
501 rems_from_px(200.)
502 };
503
504 v_flex()
505 .id("add-llm-provider-modal")
506 .key_context("AddLlmProviderModal")
507 .w(rems(34.))
508 .elevation_3(cx)
509 .on_action(cx.listener(Self::cancel))
510 .on_action(cx.listener(Self::on_tab))
511 .on_action(cx.listener(Self::on_tab_prev))
512 .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
513 this.focus_handle(cx).focus(window, cx);
514 }))
515 .child(
516 Modal::new("configure-context-server", None)
517 .header(ModalHeader::new().headline("Add LLM Provider").description(
518 match self.provider {
519 LlmCompatibleProvider::OpenAi => {
520 "This provider will use an OpenAI compatible API."
521 }
522 },
523 ))
524 .when_some(self.last_error.clone(), |this, error| {
525 this.section(
526 Section::new().child(
527 Banner::new()
528 .severity(Severity::Warning)
529 .child(div().text_xs().child(error)),
530 ),
531 )
532 })
533 .child(
534 div()
535 .size_full()
536 .vertical_scrollbar_for(&self.scroll_handle, window, cx)
537 .child(
538 v_flex()
539 .id("modal_content")
540 .size_full()
541 .tab_group()
542 .max_h(modal_max_height)
543 .pl_3()
544 .pr_4()
545 .gap_2()
546 .overflow_y_scroll()
547 .track_scroll(&self.scroll_handle)
548 .child(self.input.provider_name.clone())
549 .child(self.input.api_url.clone())
550 .child(self.input.api_key.clone())
551 .child(self.render_model_section(cx)),
552 ),
553 )
554 .footer(
555 ModalFooter::new().end_slot(
556 h_flex()
557 .gap_1()
558 .child(
559 Button::new("cancel", "Cancel")
560 .key_binding(
561 KeyBinding::for_action_in(
562 &menu::Cancel,
563 &focus_handle,
564 cx,
565 )
566 .map(|kb| kb.size(rems_from_px(12.))),
567 )
568 .on_click(cx.listener(|this, _event, window, cx| {
569 this.cancel(&menu::Cancel, window, cx)
570 })),
571 )
572 .child(
573 Button::new("save-server", "Save Provider")
574 .key_binding(
575 KeyBinding::for_action_in(
576 &menu::Confirm,
577 &focus_handle,
578 cx,
579 )
580 .map(|kb| kb.size(rems_from_px(12.))),
581 )
582 .on_click(cx.listener(|this, _event, window, cx| {
583 this.confirm(&menu::Confirm, window, cx)
584 })),
585 ),
586 ),
587 ),
588 )
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::{TestAppContext, VisualTestContext};
    use language_model::{
        LanguageModelProviderId, LanguageModelProviderName,
        fake_provider::FakeLanguageModelProvider,
    };
    use project::Project;
    use settings::SettingsStore;
    use util::path;

    // Each invalid form configuration should short-circuit with its specific
    // user-facing validation message, in the order the fields are checked.
    #[gpui::test]
    async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        assert_eq!(
            save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await,
            Some("Provider Name cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await,
            Some("API URL cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await,
            Some("API Key cannot be empty".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Model Name cannot be empty".into())
        );

        // Model tuples are (name, max_tokens, max_completion_tokens,
        // max_output_tokens) — see save_provider_validation_errors below.
        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "abc", "200000", "32000")],
                cx,
            )
            .await,
            Some("Max Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "abc", "32000")],
                cx,
            )
            .await,
            Some("Max Completion Tokens must be a number".into())
        );

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![("somemodel", "200000", "200000", "abc")],
                cx,
            )
            .await,
            Some("Max Output Tokens must be a number".into())
        );

        // Two rows with the same model name must be rejected.
        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "somekey",
                vec![
                    ("somemodel", "200000", "200000", "32000"),
                    ("somemodel", "200000", "200000", "32000"),
                ],
                cx,
            )
            .await,
            Some("Model Names must be unique".into())
        );
    }

    // A provider whose name matches an already-registered provider id must
    // be rejected even when all other fields are valid.
    #[gpui::test]
    async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|_window, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(
                    Arc::new(FakeLanguageModelProvider::new(
                        LanguageModelProviderId::new("someprovider"),
                        LanguageModelProviderName::new("Some Provider"),
                    )),
                    cx,
                );
            });
        });

        assert_eq!(
            save_provider_validation_errors(
                "someprovider",
                "someurl",
                "someapikey",
                vec![("somemodel", "200000", "200000", "32000")],
                cx,
            )
            .await,
            Some("Provider Name is already taken by another provider".into())
        );
    }

    // A fresh ModelInput should mirror ModelCapabilities::default() in both
    // its toggle states and the parsed capabilities.
    #[gpui::test]
    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text("somemodel", window, cx);
                });
            });
            assert_eq!(
                model_input.capabilities.supports_tools,
                ToggleState::Selected
            );
            assert_eq!(
                model_input.capabilities.supports_images,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_parallel_tool_calls,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_prompt_cache_key,
                ToggleState::Unselected
            );
            assert_eq!(
                model_input.capabilities.supports_chat_completions,
                ToggleState::Selected
            );

            let parsed_model = model_input.parse(cx).unwrap();
            assert!(parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(!parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(parsed_model.capabilities.chat_completions);
        });
    }

    // Deselecting every toggle should parse into all-false capabilities.
    #[gpui::test]
    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text("somemodel", window, cx);
                });
            });

            model_input.capabilities.supports_tools = ToggleState::Unselected;
            model_input.capabilities.supports_images = ToggleState::Unselected;
            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Unselected;
            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
            model_input.capabilities.supports_chat_completions = ToggleState::Unselected;

            let parsed_model = model_input.parse(cx).unwrap();
            assert!(!parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(!parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(!parsed_model.capabilities.chat_completions);
        });
    }

    // A mixed selection of toggles should round-trip through parse().
    #[gpui::test]
    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
        let cx = setup_test(cx).await;

        cx.update(|window, cx| {
            let mut model_input = ModelInput::new(0, window, cx);
            model_input.name.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text("somemodel", window, cx);
                });
            });

            model_input.capabilities.supports_tools = ToggleState::Selected;
            model_input.capabilities.supports_images = ToggleState::Unselected;
            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Selected;
            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
            model_input.capabilities.supports_chat_completions = ToggleState::Selected;

            let parsed_model = model_input.parse(cx).unwrap();
            assert_eq!(parsed_model.name, "somemodel");
            assert!(parsed_model.capabilities.tools);
            assert!(!parsed_model.capabilities.images);
            assert!(parsed_model.capabilities.parallel_tool_calls);
            assert!(!parsed_model.capabilities.prompt_cache_key);
            assert!(parsed_model.capabilities.chat_completions);
        });
    }

    // Boots a minimal app: settings store, theme, language-model settings, a
    // fake filesystem, and a test workspace window to run updates against.
    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            theme::init(theme::LoadThemes::JustBase, cx);

            language_model::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        let (_, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        cx
    }

    // Fills the form with the given values, runs save_provider_to_settings,
    // and returns the validation error (None on success). Model tuples are
    // (name, max_tokens, max_completion_tokens, max_output_tokens).
    async fn save_provider_validation_errors(
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        models: Vec<(&str, &str, &str, &str)>,
        cx: &mut VisualTestContext,
    ) -> Option<SharedString> {
        fn set_text(input: &Entity<InputField>, text: &str, window: &mut Window, cx: &mut App) {
            input.update(cx, |input, cx| {
                input.editor().update(cx, |editor, cx| {
                    editor.set_text(text, window, cx);
                });
            });
        }

        let task = cx.update(|window, cx| {
            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
            set_text(&input.provider_name, provider_name, window, cx);
            set_text(&input.api_url, api_url, window, cx);
            set_text(&input.api_key, api_key, window, cx);

            for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
                models.iter().enumerate()
            {
                // The form starts with one empty row; grow it as needed.
                if i >= input.models.len() {
                    input.models.push(ModelInput::new(i, window, cx));
                }
                let model = &mut input.models[i];
                set_text(&model.name, name, window, cx);
                set_text(&model.max_tokens, max_tokens, window, cx);
                set_text(
                    &model.max_completion_tokens,
                    max_completion_tokens,
                    window,
                    cx,
                );
                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
            }
            save_provider_to_settings(&input, cx)
        });

        task.await.err()
    }
}