1use std::{cmp::Reverse, sync::Arc};
2
3use agent_settings::AgentSettings;
4use collections::{HashMap, HashSet, IndexMap};
5use futures::{StreamExt, channel::mpsc};
6use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
7use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
8use language_model::{
9 AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
10 LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
11};
12use ordered_float::OrderedFloat;
13use picker::{Picker, PickerDelegate};
14use settings::Settings;
15use ui::prelude::*;
16use zed_actions::agent::OpenSettings;
17
18use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
19
/// Called with the model the user picked in the selector.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Returns the model currently in use, if any.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
/// Called with a model and its requested favorite state (`true` = favorite)
/// when the favorite toggle on a model row is clicked.
type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &App) + 'static>;

/// A list picker over the available language models.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
25
26pub fn language_model_selector(
27 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
29 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
30 popover_styles: bool,
31 focus_handle: FocusHandle,
32 window: &mut Window,
33 cx: &mut Context<LanguageModelSelector>,
34) -> LanguageModelSelector {
35 let delegate = LanguageModelPickerDelegate::new(
36 get_active_model,
37 on_model_changed,
38 on_toggle_favorite,
39 popover_styles,
40 focus_handle,
41 window,
42 cx,
43 );
44
45 if popover_styles {
46 Picker::list(delegate, window, cx)
47 .show_scrollbar(true)
48 .width(rems(20.))
49 .max_height(Some(rems(20.).into()))
50 } else {
51 Picker::list(delegate, window, cx).show_scrollbar(true)
52 }
53}
54
55fn all_models(cx: &App) -> GroupedModels {
56 let lm_registry = LanguageModelRegistry::global(cx).read(cx);
57 let providers = lm_registry.visible_providers();
58
59 let mut favorites_index = FavoritesIndex::default();
60
61 for sel in &AgentSettings::get_global(cx).favorite_models {
62 favorites_index
63 .entry(sel.provider.0.clone().into())
64 .or_default()
65 .insert(sel.model.clone().into());
66 }
67
68 let recommended = providers
69 .iter()
70 .flat_map(|provider| {
71 provider
72 .recommended_models(cx)
73 .into_iter()
74 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
75 })
76 .collect();
77
78 let all: Vec<ModelInfo> = providers
79 .iter()
80 .flat_map(|provider| {
81 provider
82 .provided_models(cx)
83 .into_iter()
84 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
85 })
86 .collect();
87
88 GroupedModels::new(all, recommended)
89}
90
91type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
92
93#[derive(Clone)]
94struct ModelInfo {
95 model: Arc<dyn LanguageModel>,
96 icon: IconOrSvg,
97 is_favorite: bool,
98}
99
100impl ModelInfo {
101 fn new(
102 provider: &dyn LanguageModelProvider,
103 model: Arc<dyn LanguageModel>,
104 favorites_index: &FavoritesIndex,
105 ) -> Self {
106 let is_favorite = favorites_index
107 .get(&provider.id())
108 .map_or(false, |set| set.contains(&model.id()));
109
110 Self {
111 model,
112 icon: provider.icon(),
113 is_favorite,
114 }
115 }
116}
117
/// [`PickerDelegate`] that lists language models grouped into favorite,
/// recommended, and per-provider sections.
pub struct LanguageModelPickerDelegate {
    /// Called when the user confirms a model or cycles favorites.
    on_model_changed: OnModelChanged,
    /// Returns the model currently in use; used to highlight and pre-select it.
    get_active_model: GetActiveModel,
    /// Called when a row's favorite toggle is clicked.
    on_toggle_favorite: OnToggleFavorite,
    /// Every known model, rebuilt whenever the registry's providers change.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown, after applying the search query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    // The tasks below are held only so they are not dropped while the
    // delegate is alive.
    _authenticate_all_providers_task: Task<()>,
    _refresh_models_task: Task<()>,
    /// Whether the picker is shown as a popover (enables the settings footer).
    popover_styles: bool,
    focus_handle: FocusHandle,
}
130
131impl LanguageModelPickerDelegate {
132 fn new(
133 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
134 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
135 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
136 popover_styles: bool,
137 focus_handle: FocusHandle,
138 window: &mut Window,
139 cx: &mut Context<Picker<Self>>,
140 ) -> Self {
141 let on_model_changed = Arc::new(on_model_changed);
142 let models = all_models(cx);
143 let entries = models.entries();
144
145 Self {
146 on_model_changed,
147 all_models: Arc::new(models),
148 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
149 filtered_entries: entries,
150 get_active_model: Arc::new(get_active_model),
151 on_toggle_favorite: Arc::new(on_toggle_favorite),
152 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
153 _refresh_models_task: {
154 // Create a channel to signal when models need refreshing
155 let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
156
157 // Subscribe to registry events and send refresh signals through the channel
158 let registry = LanguageModelRegistry::global(cx);
159 cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
160 language_model::Event::ProviderStateChanged(_)
161 | language_model::Event::AddedProvider(_)
162 | language_model::Event::RemovedProvider(_)
163 | language_model::Event::ProvidersChanged => {
164 refresh_tx.unbounded_send(()).ok();
165 }
166 language_model::Event::DefaultModelChanged
167 | language_model::Event::InlineAssistantModelChanged
168 | language_model::Event::CommitMessageModelChanged
169 | language_model::Event::ThreadSummaryModelChanged => {}
170 })
171 .detach();
172
173 // Spawn a task that listens for refresh signals and updates the picker
174 cx.spawn_in(window, async move |this, cx| {
175 while let Some(()) = refresh_rx.next().await {
176 if this
177 .update_in(cx, |picker, window, cx| {
178 picker.delegate.all_models = Arc::new(all_models(cx));
179 picker.refresh(window, cx);
180 })
181 .is_err()
182 {
183 // Picker was dropped, exit the loop
184 break;
185 }
186 }
187 })
188 },
189 popover_styles,
190 focus_handle,
191 }
192 }
193
194 fn get_active_model_index(
195 entries: &[LanguageModelPickerEntry],
196 active_model: Option<ConfiguredModel>,
197 ) -> usize {
198 entries
199 .iter()
200 .position(|entry| {
201 if let LanguageModelPickerEntry::Model(model) = entry {
202 active_model
203 .as_ref()
204 .map(|active_model| {
205 active_model.model.id() == model.model.id()
206 && active_model.provider.id() == model.model.provider_id()
207 })
208 .unwrap_or_default()
209 } else {
210 false
211 }
212 })
213 .unwrap_or(0)
214 }
215
216 /// Authenticates all providers in the [`LanguageModelRegistry`].
217 ///
218 /// We do this so that we can populate the language selector with all of the
219 /// models from the configured providers.
220 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
221 let authenticate_all_providers = LanguageModelRegistry::global(cx)
222 .read(cx)
223 .visible_providers()
224 .iter()
225 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
226 .collect::<Vec<_>>();
227
228 cx.spawn(async move |_cx| {
229 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
230 if let Err(err) = authenticate_task.await {
231 if matches!(err, AuthenticateError::CredentialsNotFound) {
232 // Since we're authenticating these providers in the
233 // background for the purposes of populating the
234 // language selector, we don't care about providers
235 // where the credentials are not found.
236 } else {
237 // Some providers have noisy failure states that we
238 // don't want to spam the logs with every time the
239 // language model selector is initialized.
240 //
241 // Ideally these should have more clear failure modes
242 // that we know are safe to ignore here, like what we do
243 // with `CredentialsNotFound` above.
244 match provider_id.0.as_ref() {
245 "lmstudio" | "ollama" => {
246 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
247 //
248 // These fail noisily, so we don't log them.
249 }
250 "copilot_chat" => {
251 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
252 }
253 _ => {
254 log::error!(
255 "Failed to authenticate provider: {}: {err:#}",
256 provider_name.0
257 );
258 }
259 }
260 }
261 }
262 }
263 })
264 }
265
266 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
267 (self.get_active_model)(cx)
268 }
269
270 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
271 if self.all_models.favorites.is_empty() {
272 return;
273 }
274
275 let active_model = (self.get_active_model)(cx);
276 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
277 let active_model_id = active_model.as_ref().map(|m| m.model.id());
278
279 let current_index = self
280 .all_models
281 .favorites
282 .iter()
283 .position(|info| {
284 Some(info.model.provider_id()) == active_provider_id
285 && Some(info.model.id()) == active_model_id
286 })
287 .unwrap_or(usize::MAX);
288
289 let next_index = if current_index == usize::MAX {
290 0
291 } else {
292 (current_index + 1) % self.all_models.favorites.len()
293 };
294
295 let next_model = self.all_models.favorites[next_index].model.clone();
296
297 (self.on_model_changed)(next_model, cx);
298
299 // Align the picker selection with the newly-active model
300 let new_index =
301 Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
302 self.set_selected_index(new_index, window, cx);
303 }
304}
305
306struct GroupedModels {
307 favorites: Vec<ModelInfo>,
308 recommended: Vec<ModelInfo>,
309 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
310}
311
312impl GroupedModels {
313 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
314 let favorites = all
315 .iter()
316 .filter(|info| info.is_favorite)
317 .cloned()
318 .collect();
319
320 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
321 for model in all {
322 let provider = model.model.provider_id();
323 if let Some(models) = all_by_provider.get_mut(&provider) {
324 models.push(model);
325 } else {
326 all_by_provider.insert(provider, vec![model]);
327 }
328 }
329
330 Self {
331 favorites,
332 recommended,
333 all: all_by_provider,
334 }
335 }
336
337 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
338 let mut entries = Vec::new();
339
340 if !self.favorites.is_empty() {
341 entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
342 for info in &self.favorites {
343 entries.push(LanguageModelPickerEntry::Model(info.clone()));
344 }
345 }
346
347 if !self.recommended.is_empty() {
348 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
349 for info in &self.recommended {
350 entries.push(LanguageModelPickerEntry::Model(info.clone()));
351 }
352 }
353
354 for models in self.all.values() {
355 if models.is_empty() {
356 continue;
357 }
358 entries.push(LanguageModelPickerEntry::Separator(
359 models[0].model.provider_name().0,
360 ));
361 for info in models {
362 entries.push(LanguageModelPickerEntry::Model(info.clone()));
363 }
364 }
365
366 entries
367 }
368}
369
370enum LanguageModelPickerEntry {
371 Model(ModelInfo),
372 Separator(SharedString),
373}
374
375struct ModelMatcher {
376 models: Vec<ModelInfo>,
377 bg_executor: BackgroundExecutor,
378 candidates: Vec<StringMatchCandidate>,
379}
380
381impl ModelMatcher {
382 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
383 let candidates = Self::make_match_candidates(&models);
384 Self {
385 models,
386 bg_executor,
387 candidates,
388 }
389 }
390
391 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
392 let mut matches = self.bg_executor.block(match_strings(
393 &self.candidates,
394 query,
395 false,
396 true,
397 100,
398 &Default::default(),
399 self.bg_executor.clone(),
400 ));
401
402 let sorting_key = |mat: &StringMatch| {
403 let candidate = &self.candidates[mat.candidate_id];
404 (Reverse(OrderedFloat(mat.score)), candidate.id)
405 };
406 matches.sort_unstable_by_key(sorting_key);
407
408 let matched_models: Vec<_> = matches
409 .into_iter()
410 .map(|mat| self.models[mat.candidate_id].clone())
411 .collect();
412
413 matched_models
414 }
415
416 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
417 self.models
418 .iter()
419 .filter(|m| {
420 m.model
421 .name()
422 .0
423 .to_lowercase()
424 .contains(&query.to_lowercase())
425 })
426 .cloned()
427 .collect::<Vec<_>>()
428 }
429
430 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
431 model_infos
432 .iter()
433 .enumerate()
434 .map(|(index, model)| {
435 StringMatchCandidate::new(
436 index,
437 &format!(
438 "{}/{}",
439 &model.model.provider_name().0,
440 &model.model.name().0
441 ),
442 )
443 })
444 .collect::<Vec<_>>()
445 }
446}
447
448impl PickerDelegate for LanguageModelPickerDelegate {
449 type ListItem = AnyElement;
450
451 fn match_count(&self) -> usize {
452 self.filtered_entries.len()
453 }
454
455 fn selected_index(&self) -> usize {
456 self.selected_index
457 }
458
459 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
460 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
461 cx.notify();
462 }
463
464 fn can_select(
465 &mut self,
466 ix: usize,
467 _window: &mut Window,
468 _cx: &mut Context<Picker<Self>>,
469 ) -> bool {
470 match self.filtered_entries.get(ix) {
471 Some(LanguageModelPickerEntry::Model(_)) => true,
472 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
473 }
474 }
475
476 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
477 "Select a model…".into()
478 }
479
480 fn update_matches(
481 &mut self,
482 query: String,
483 window: &mut Window,
484 cx: &mut Context<Picker<Self>>,
485 ) -> Task<()> {
486 let all_models = self.all_models.clone();
487 let active_model = (self.get_active_model)(cx);
488 let bg_executor = cx.background_executor();
489
490 let language_model_registry = LanguageModelRegistry::global(cx);
491
492 let configured_providers = language_model_registry
493 .read(cx)
494 .visible_providers()
495 .into_iter()
496 .filter(|provider| provider.is_authenticated(cx))
497 .collect::<Vec<_>>();
498
499 let configured_provider_ids = configured_providers
500 .iter()
501 .map(|provider| provider.id())
502 .collect::<Vec<_>>();
503
504 let recommended_models = all_models
505 .recommended
506 .iter()
507 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
508 .cloned()
509 .collect::<Vec<_>>();
510
511 let available_models = all_models
512 .all
513 .values()
514 .flat_map(|models| models.iter())
515 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
516 .cloned()
517 .collect::<Vec<_>>();
518
519 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
520 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
521
522 let recommended = matcher_rec.exact_search(&query);
523 let all = matcher_all.fuzzy_search(&query);
524
525 let filtered_models = GroupedModels::new(all, recommended);
526
527 cx.spawn_in(window, async move |this, cx| {
528 this.update_in(cx, |this, window, cx| {
529 this.delegate.filtered_entries = filtered_models.entries();
530 // Finds the currently selected model in the list
531 let new_index =
532 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
533 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
534 cx.notify();
535 })
536 .ok();
537 })
538 }
539
540 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
541 if let Some(LanguageModelPickerEntry::Model(model_info)) =
542 self.filtered_entries.get(self.selected_index)
543 {
544 let model = model_info.model.clone();
545 (self.on_model_changed)(model.clone(), cx);
546
547 let current_index = self.selected_index;
548 self.set_selected_index(current_index, window, cx);
549
550 cx.emit(DismissEvent);
551 }
552 }
553
554 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
555 cx.emit(DismissEvent);
556 }
557
558 fn render_match(
559 &self,
560 ix: usize,
561 selected: bool,
562 _: &mut Window,
563 cx: &mut Context<Picker<Self>>,
564 ) -> Option<Self::ListItem> {
565 match self.filtered_entries.get(ix)? {
566 LanguageModelPickerEntry::Separator(title) => {
567 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
568 }
569 LanguageModelPickerEntry::Model(model_info) => {
570 let active_model = (self.get_active_model)(cx);
571 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
572 let active_model_id = active_model.map(|m| m.model.id());
573
574 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
575 && Some(model_info.model.id()) == active_model_id;
576
577 let is_favorite = model_info.is_favorite;
578 let handle_action_click = {
579 let model = model_info.model.clone();
580 let on_toggle_favorite = self.on_toggle_favorite.clone();
581 move |cx: &App| on_toggle_favorite(model.clone(), !is_favorite, cx)
582 };
583
584 Some(
585 ModelSelectorListItem::new(ix, model_info.model.name().0)
586 .map(|this| match &model_info.icon {
587 IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
588 IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
589 })
590 .is_selected(is_selected)
591 .is_focused(selected)
592 .is_favorite(is_favorite)
593 .on_toggle_favorite(handle_action_click)
594 .into_any_element(),
595 )
596 }
597 }
598 }
599
600 fn render_footer(
601 &self,
602 _window: &mut Window,
603 _cx: &mut Context<Picker<Self>>,
604 ) -> Option<gpui::AnyElement> {
605 let focus_handle = self.focus_handle.clone();
606
607 if !self.popover_styles {
608 return None;
609 }
610
611 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
612 }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` used to build `ModelInfo` fixtures. Only the
    /// identity accessors are meaningful; the request methods are unimplemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both the model name and id, and `provider` as both
        /// the provider id and provider name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "provider/name"; the tests below use this string to identify models.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds non-favorite fixtures from `(provider, name)` pairs.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, vec![])
    }

    /// Builds fixtures from `(provider, name)` pairs, marking as favorite
    /// those whose pair also appears in `favorites`.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let is_favorite = favorites
                    .iter()
                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
                ModelInfo {
                    model: Arc::new(TestLanguageModel::new(name, provider)),
                    icon: IconOrSvg::Icon(IconName::Ai),
                    is_favorite,
                }
            })
            .collect()
    }

    /// Asserts that `result` contains exactly the `expected` models, in order,
    /// compared by their "provider/name" telemetry id.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // The "Favorite" section comes first when any favorite exists.
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // Without favorites, the list starts with the "Recommended" section.
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped_models.favorites.is_empty());
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // Every row for zed/claude (in any section) is a favorite; no other row is.
        for entry in &entries {
            if let LanguageModelPickerEntry::Model(info) = entry {
                if info.model.telemetry_id() == "zed/claude" {
                    assert!(info.is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(
                        !info.is_favorite,
                        "{} should not be a favorite",
                        info.model.telemetry_id()
                    );
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());

        let all_models = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorites,
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Favorites are listed in their own section and in their provider's group.
        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }
}