1use std::{cmp::Reverse, sync::Arc};
2
3use agent_settings::AgentSettings;
4use collections::{HashMap, HashSet, IndexMap};
5use futures::{StreamExt, channel::mpsc};
6use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
7use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
8use language_model::{
9 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProvider,
10 LanguageModelProviderId, LanguageModelRegistry,
11};
12use ordered_float::OrderedFloat;
13use picker::{Picker, PickerDelegate};
14use settings::Settings;
15use ui::prelude::*;
16use zed_actions::agent::OpenSettings;
17
18use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
19
20type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
21type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
22type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &App) + 'static>;
23
24pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
25
26pub fn language_model_selector(
27 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
28 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
29 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
30 popover_styles: bool,
31 focus_handle: FocusHandle,
32 window: &mut Window,
33 cx: &mut Context<LanguageModelSelector>,
34) -> LanguageModelSelector {
35 let delegate = LanguageModelPickerDelegate::new(
36 get_active_model,
37 on_model_changed,
38 on_toggle_favorite,
39 popover_styles,
40 focus_handle,
41 window,
42 cx,
43 );
44
45 if popover_styles {
46 Picker::list(delegate, window, cx)
47 .show_scrollbar(true)
48 .width(rems(20.))
49 .max_height(Some(rems(20.).into()))
50 } else {
51 Picker::list(delegate, window, cx).show_scrollbar(true)
52 }
53}
54
55fn all_models(cx: &App) -> GroupedModels {
56 let lm_registry = LanguageModelRegistry::global(cx).read(cx);
57 let providers = lm_registry.visible_providers();
58
59 let mut favorites_index = FavoritesIndex::default();
60
61 for sel in &AgentSettings::get_global(cx).favorite_models {
62 favorites_index
63 .entry(sel.provider.0.clone().into())
64 .or_default()
65 .insert(sel.model.clone().into());
66 }
67
68 let recommended = providers
69 .iter()
70 .flat_map(|provider| {
71 provider
72 .recommended_models(cx)
73 .into_iter()
74 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
75 })
76 .collect();
77
78 let all: Vec<ModelInfo> = providers
79 .iter()
80 .flat_map(|provider| {
81 provider
82 .provided_models(cx)
83 .into_iter()
84 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
85 })
86 .collect();
87
88 GroupedModels::new(all, recommended)
89}
90
91type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
92
93#[derive(Clone)]
94enum ProviderIcon {
95 Name(IconName),
96 Path(SharedString),
97}
98
99impl ProviderIcon {
100 fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
101 if let Some(path) = provider.icon_path() {
102 Self::Path(path)
103 } else {
104 Self::Name(provider.icon())
105 }
106 }
107}
108
109#[derive(Clone)]
110struct ModelInfo {
111 model: Arc<dyn LanguageModel>,
112 icon: ProviderIcon,
113 is_favorite: bool,
114}
115
116impl ModelInfo {
117 fn new(
118 provider: &dyn LanguageModelProvider,
119 model: Arc<dyn LanguageModel>,
120 favorites_index: &FavoritesIndex,
121 ) -> Self {
122 let is_favorite = favorites_index
123 .get(&provider.id())
124 .map_or(false, |set| set.contains(&model.id()));
125
126 Self {
127 model,
128 icon: ProviderIcon::from_provider(provider),
129 is_favorite,
130 }
131 }
132}
133
/// The [`PickerDelegate`] that backs [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// Invoked when a model is confirmed or cycled to.
    on_model_changed: OnModelChanged,
    /// Queried whenever the active model is needed (selection, rendering, cycling).
    get_active_model: GetActiveModel,
    /// Invoked when a row's favorite toggle is clicked.
    on_toggle_favorite: OnToggleFavorite,
    /// Unfiltered model groups; replaced when the registry reports provider changes.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (separators and models) for the active query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries`.
    selected_index: usize,
    // The two tasks below are never read; presumably they are stored so they are not
    // dropped (and cancelled) early. Field order also fixes the order they are dropped in.
    _authenticate_all_providers_task: Task<()>,
    _refresh_models_task: Task<()>,
    /// Enables the compact size constraints and the settings footer.
    popover_styles: bool,
    /// Passed to the footer when `popover_styles` is set.
    focus_handle: FocusHandle,
}
146
147impl LanguageModelPickerDelegate {
148 fn new(
149 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
150 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
151 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
152 popover_styles: bool,
153 focus_handle: FocusHandle,
154 window: &mut Window,
155 cx: &mut Context<Picker<Self>>,
156 ) -> Self {
157 let on_model_changed = Arc::new(on_model_changed);
158 let models = all_models(cx);
159 let entries = models.entries();
160
161 Self {
162 on_model_changed,
163 all_models: Arc::new(models),
164 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
165 filtered_entries: entries,
166 get_active_model: Arc::new(get_active_model),
167 on_toggle_favorite: Arc::new(on_toggle_favorite),
168 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
169 _refresh_models_task: {
170 // Create a channel to signal when models need refreshing
171 let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
172
173 // Subscribe to registry events and send refresh signals through the channel
174 let registry = LanguageModelRegistry::global(cx);
175 cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
176 language_model::Event::ProviderStateChanged(_)
177 | language_model::Event::AddedProvider(_)
178 | language_model::Event::RemovedProvider(_)
179 | language_model::Event::ProvidersChanged => {
180 refresh_tx.unbounded_send(()).ok();
181 }
182 language_model::Event::DefaultModelChanged
183 | language_model::Event::InlineAssistantModelChanged
184 | language_model::Event::CommitMessageModelChanged
185 | language_model::Event::ThreadSummaryModelChanged => {}
186 })
187 .detach();
188
189 // Spawn a task that listens for refresh signals and updates the picker
190 cx.spawn_in(window, async move |this, cx| {
191 while let Some(()) = refresh_rx.next().await {
192 if this
193 .update_in(cx, |picker, window, cx| {
194 picker.delegate.all_models = Arc::new(all_models(cx));
195 picker.refresh(window, cx);
196 })
197 .is_err()
198 {
199 // Picker was dropped, exit the loop
200 break;
201 }
202 }
203 })
204 },
205 popover_styles,
206 focus_handle,
207 }
208 }
209
210 fn get_active_model_index(
211 entries: &[LanguageModelPickerEntry],
212 active_model: Option<ConfiguredModel>,
213 ) -> usize {
214 entries
215 .iter()
216 .position(|entry| {
217 if let LanguageModelPickerEntry::Model(model) = entry {
218 active_model
219 .as_ref()
220 .map(|active_model| {
221 active_model.model.id() == model.model.id()
222 && active_model.provider.id() == model.model.provider_id()
223 })
224 .unwrap_or_default()
225 } else {
226 false
227 }
228 })
229 .unwrap_or(0)
230 }
231
232 /// Authenticates all providers in the [`LanguageModelRegistry`].
233 ///
234 /// We do this so that we can populate the language selector with all of the
235 /// models from the configured providers.
236 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
237 let authenticate_all_providers = LanguageModelRegistry::global(cx)
238 .read(cx)
239 .providers()
240 .iter()
241 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
242 .collect::<Vec<_>>();
243
244 cx.spawn(async move |_cx| {
245 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
246 if let Err(err) = authenticate_task.await {
247 if matches!(err, AuthenticateError::CredentialsNotFound) {
248 // Since we're authenticating these providers in the
249 // background for the purposes of populating the
250 // language selector, we don't care about providers
251 // where the credentials are not found.
252 } else {
253 // Some providers have noisy failure states that we
254 // don't want to spam the logs with every time the
255 // language model selector is initialized.
256 //
257 // Ideally these should have more clear failure modes
258 // that we know are safe to ignore here, like what we do
259 // with `CredentialsNotFound` above.
260 match provider_id.0.as_ref() {
261 "lmstudio" | "ollama" => {
262 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
263 //
264 // These fail noisily, so we don't log them.
265 }
266 "copilot_chat" => {
267 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
268 }
269 _ => {
270 log::error!(
271 "Failed to authenticate provider: {}: {err:#}",
272 provider_name.0
273 );
274 }
275 }
276 }
277 }
278 }
279 })
280 }
281
282 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
283 (self.get_active_model)(cx)
284 }
285
286 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
287 if self.all_models.favorites.is_empty() {
288 return;
289 }
290
291 let active_model = (self.get_active_model)(cx);
292 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
293 let active_model_id = active_model.as_ref().map(|m| m.model.id());
294
295 let current_index = self
296 .all_models
297 .favorites
298 .iter()
299 .position(|info| {
300 Some(info.model.provider_id()) == active_provider_id
301 && Some(info.model.id()) == active_model_id
302 })
303 .unwrap_or(usize::MAX);
304
305 let next_index = if current_index == usize::MAX {
306 0
307 } else {
308 (current_index + 1) % self.all_models.favorites.len()
309 };
310
311 let next_model = self.all_models.favorites[next_index].model.clone();
312
313 (self.on_model_changed)(next_model, cx);
314
315 // Align the picker selection with the newly-active model
316 let new_index =
317 Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
318 self.set_selected_index(new_index, window, cx);
319 }
320}
321
322struct GroupedModels {
323 favorites: Vec<ModelInfo>,
324 recommended: Vec<ModelInfo>,
325 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
326}
327
328impl GroupedModels {
329 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
330 let favorites = all
331 .iter()
332 .filter(|info| info.is_favorite)
333 .cloned()
334 .collect();
335
336 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
337 for model in all {
338 let provider = model.model.provider_id();
339 if let Some(models) = all_by_provider.get_mut(&provider) {
340 models.push(model);
341 } else {
342 all_by_provider.insert(provider, vec![model]);
343 }
344 }
345
346 Self {
347 favorites,
348 recommended,
349 all: all_by_provider,
350 }
351 }
352
353 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
354 let mut entries = Vec::new();
355
356 if !self.favorites.is_empty() {
357 entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
358 for info in &self.favorites {
359 entries.push(LanguageModelPickerEntry::Model(info.clone()));
360 }
361 }
362
363 if !self.recommended.is_empty() {
364 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
365 for info in &self.recommended {
366 entries.push(LanguageModelPickerEntry::Model(info.clone()));
367 }
368 }
369
370 for models in self.all.values() {
371 if models.is_empty() {
372 continue;
373 }
374 entries.push(LanguageModelPickerEntry::Separator(
375 models[0].model.provider_name().0,
376 ));
377 for info in models {
378 entries.push(LanguageModelPickerEntry::Model(info.clone()));
379 }
380 }
381
382 entries
383 }
384}
385
386enum LanguageModelPickerEntry {
387 Model(ModelInfo),
388 Separator(SharedString),
389}
390
391struct ModelMatcher {
392 models: Vec<ModelInfo>,
393 bg_executor: BackgroundExecutor,
394 candidates: Vec<StringMatchCandidate>,
395}
396
397impl ModelMatcher {
398 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
399 let candidates = Self::make_match_candidates(&models);
400 Self {
401 models,
402 bg_executor,
403 candidates,
404 }
405 }
406
407 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
408 let mut matches = self.bg_executor.block(match_strings(
409 &self.candidates,
410 query,
411 false,
412 true,
413 100,
414 &Default::default(),
415 self.bg_executor.clone(),
416 ));
417
418 let sorting_key = |mat: &StringMatch| {
419 let candidate = &self.candidates[mat.candidate_id];
420 (Reverse(OrderedFloat(mat.score)), candidate.id)
421 };
422 matches.sort_unstable_by_key(sorting_key);
423
424 let matched_models: Vec<_> = matches
425 .into_iter()
426 .map(|mat| self.models[mat.candidate_id].clone())
427 .collect();
428
429 matched_models
430 }
431
432 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
433 self.models
434 .iter()
435 .filter(|m| {
436 m.model
437 .name()
438 .0
439 .to_lowercase()
440 .contains(&query.to_lowercase())
441 })
442 .cloned()
443 .collect::<Vec<_>>()
444 }
445
446 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
447 model_infos
448 .iter()
449 .enumerate()
450 .map(|(index, model)| {
451 StringMatchCandidate::new(
452 index,
453 &format!(
454 "{}/{}",
455 &model.model.provider_name().0,
456 &model.model.name().0
457 ),
458 )
459 })
460 .collect::<Vec<_>>()
461 }
462}
463
464impl PickerDelegate for LanguageModelPickerDelegate {
465 type ListItem = AnyElement;
466
467 fn match_count(&self) -> usize {
468 self.filtered_entries.len()
469 }
470
471 fn selected_index(&self) -> usize {
472 self.selected_index
473 }
474
475 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
476 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
477 cx.notify();
478 }
479
480 fn can_select(
481 &mut self,
482 ix: usize,
483 _window: &mut Window,
484 _cx: &mut Context<Picker<Self>>,
485 ) -> bool {
486 match self.filtered_entries.get(ix) {
487 Some(LanguageModelPickerEntry::Model(_)) => true,
488 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
489 }
490 }
491
492 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
493 "Select a model…".into()
494 }
495
496 fn update_matches(
497 &mut self,
498 query: String,
499 window: &mut Window,
500 cx: &mut Context<Picker<Self>>,
501 ) -> Task<()> {
502 let all_models = self.all_models.clone();
503 let active_model = (self.get_active_model)(cx);
504 let bg_executor = cx.background_executor();
505
506 let language_model_registry = LanguageModelRegistry::global(cx);
507
508 let configured_providers = language_model_registry
509 .read(cx)
510 .visible_providers()
511 .into_iter()
512 .filter(|provider| provider.is_authenticated(cx))
513 .collect::<Vec<_>>();
514
515 let configured_provider_ids = configured_providers
516 .iter()
517 .map(|provider| provider.id())
518 .collect::<Vec<_>>();
519
520 let recommended_models = all_models
521 .recommended
522 .iter()
523 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
524 .cloned()
525 .collect::<Vec<_>>();
526
527 let available_models = all_models
528 .all
529 .values()
530 .flat_map(|models| models.iter())
531 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
532 .cloned()
533 .collect::<Vec<_>>();
534
535 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
536 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
537
538 let recommended = matcher_rec.exact_search(&query);
539 let all = matcher_all.fuzzy_search(&query);
540
541 let filtered_models = GroupedModels::new(all, recommended);
542
543 cx.spawn_in(window, async move |this, cx| {
544 this.update_in(cx, |this, window, cx| {
545 this.delegate.filtered_entries = filtered_models.entries();
546 // Finds the currently selected model in the list
547 let new_index =
548 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
549 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
550 cx.notify();
551 })
552 .ok();
553 })
554 }
555
556 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
557 if let Some(LanguageModelPickerEntry::Model(model_info)) =
558 self.filtered_entries.get(self.selected_index)
559 {
560 let model = model_info.model.clone();
561 (self.on_model_changed)(model.clone(), cx);
562
563 let current_index = self.selected_index;
564 self.set_selected_index(current_index, window, cx);
565
566 cx.emit(DismissEvent);
567 }
568 }
569
570 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
571 cx.emit(DismissEvent);
572 }
573
574 fn render_match(
575 &self,
576 ix: usize,
577 selected: bool,
578 _: &mut Window,
579 cx: &mut Context<Picker<Self>>,
580 ) -> Option<Self::ListItem> {
581 match self.filtered_entries.get(ix)? {
582 LanguageModelPickerEntry::Separator(title) => {
583 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
584 }
585 LanguageModelPickerEntry::Model(model_info) => {
586 let active_model = (self.get_active_model)(cx);
587 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
588 let active_model_id = active_model.map(|m| m.model.id());
589
590 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
591 && Some(model_info.model.id()) == active_model_id;
592
593 let is_favorite = model_info.is_favorite;
594 let handle_action_click = {
595 let model = model_info.model.clone();
596 let on_toggle_favorite = self.on_toggle_favorite.clone();
597 move |cx: &App| on_toggle_favorite(model.clone(), !is_favorite, cx)
598 };
599
600 Some(
601 ModelSelectorListItem::new(ix, model_info.model.name().0)
602 .map(|this| match &model_info.icon {
603 ProviderIcon::Name(icon_name) => this.icon(*icon_name),
604 ProviderIcon::Path(icon_path) => this.icon_path(icon_path.clone()),
605 })
606 .is_selected(is_selected)
607 .is_focused(selected)
608 .is_favorite(is_favorite)
609 .on_toggle_favorite(handle_action_click)
610 .into_any_element(),
611 )
612 }
613 }
614 }
615
616 fn render_footer(
617 &self,
618 _window: &mut Window,
619 _cx: &mut Context<Picker<Self>>,
620 ) -> Option<gpui::AnyElement> {
621 let focus_handle = self.focus_handle.clone();
622
623 if !self.popover_styles {
624 return None;
625 }
626
627 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
628 }
629}
630
/// Tests for model grouping, favorites, and the exact/fuzzy matchers.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] fixture.
    ///
    /// `new(name, provider)` uses `name` as both the model id and the model name. It uses
    /// `provider` as both the provider id and the provider name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    // Only the identity/metadata methods matter for these tests. Token counting and
    // completion are never called and panic via `unimplemented!()`.
    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // Produces "<provider id>/<model name>"; `assert_models_eq` compares against this.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, none of them favorited.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, vec![])
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs.
    ///
    /// A model is flagged as a favorite when its pair also appears in `favorites`. Every
    /// model gets the `IconName::Ai` icon.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let is_favorite = favorites
                    .iter()
                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
                ModelInfo {
                    model: Arc::new(TestLanguageModel::new(name, provider)),
                    icon: ProviderIcon::Name(IconName::Ai),
                    is_favorite,
                }
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models in `expected`, in the same order.
    ///
    /// Each model is identified by its `"<provider>/<name>"` telemetry id.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Matches keep the original model order, and matching ignores case.
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search: "z4n" matches only "zed/gpt-4.1-nano"
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // The favorites section comes first when it is non-empty.
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // With no favorites, the recommended section is first.
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped_models.favorites.is_empty());
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // Every row for zed/claude, in any section, carries the favorite flag; no other row does.
        for entry in &entries {
            if let LanguageModelPickerEntry::Model(info) = entry {
                if info.model.telemetry_id() == "zed/claude" {
                    assert!(info.is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(
                        !info.is_favorite,
                        "{} should not be a favorite",
                        info.model.telemetry_id()
                    );
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());

        let all_models = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorites,
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Favorites are listed in their own section and still appear in the provider groups.
        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }
}