1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
5use gpui::{
6 Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
7};
8use language_model::{
9 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
10 LanguageModelRegistry,
11};
12use ordered_float::OrderedFloat;
13use picker::{Picker, PickerDelegate};
14use ui::prelude::*;
15use zed_actions::agent::OpenSettings;
16
17use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
18
/// Callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that reports the currently active model (if any). The picker uses
/// it to highlight the active row and to choose the initial selection.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] listing language models grouped by provider, with an optional
/// "Recommended" section at the top.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
23
24pub fn language_model_selector(
25 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
26 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
27 popover_styles: bool,
28 focus_handle: FocusHandle,
29 window: &mut Window,
30 cx: &mut Context<LanguageModelSelector>,
31) -> LanguageModelSelector {
32 let delegate = LanguageModelPickerDelegate::new(
33 get_active_model,
34 on_model_changed,
35 popover_styles,
36 focus_handle,
37 window,
38 cx,
39 );
40
41 if popover_styles {
42 Picker::list(delegate, window, cx)
43 .show_scrollbar(true)
44 .width(rems(20.))
45 .max_height(Some(rems(20.).into()))
46 } else {
47 Picker::list(delegate, window, cx).show_scrollbar(true)
48 }
49}
50
51fn all_models(cx: &App) -> GroupedModels {
52 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
53
54 let recommended = providers
55 .iter()
56 .flat_map(|provider| {
57 provider
58 .recommended_models(cx)
59 .into_iter()
60 .map(|model| ModelInfo {
61 model,
62 icon: provider.icon(),
63 })
64 })
65 .collect();
66
67 let all = providers
68 .iter()
69 .flat_map(|provider| {
70 provider
71 .provided_models(cx)
72 .into_iter()
73 .map(|model| ModelInfo {
74 model,
75 icon: provider.icon(),
76 })
77 })
78 .collect();
79
80 GroupedModels::new(all, recommended)
81}
82
/// A language model paired with the icon of the provider that supplies it.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon of the model's provider, rendered next to the model's name.
    icon: IconName,
}
88
/// [`PickerDelegate`] backing the [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// Invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the active model; used to highlight it and to pick the initial selection.
    get_active_model: GetActiveModel,
    /// Every model from every provider; rebuilt when the registry reports a
    /// provider being added, removed, or changing state.
    all_models: Arc<GroupedModels>,
    /// Rows (separators and models) currently displayed for the active query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background authentication of all providers. Stored so the task lives as
    /// long as the delegate (presumably dropping it would cancel it — confirm).
    _authenticate_all_providers_task: Task<()>,
    /// Registry subscription that refreshes `all_models` on provider events.
    _subscriptions: Vec<Subscription>,
    /// Whether the picker is shown in a popover; when `false` no footer is rendered.
    popover_styles: bool,
    /// Focus handle handed to the settings footer.
    focus_handle: FocusHandle,
}
100
101impl LanguageModelPickerDelegate {
102 fn new(
103 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
104 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
105 popover_styles: bool,
106 focus_handle: FocusHandle,
107 window: &mut Window,
108 cx: &mut Context<Picker<Self>>,
109 ) -> Self {
110 let on_model_changed = Arc::new(on_model_changed);
111 let models = all_models(cx);
112 let entries = models.entries();
113
114 Self {
115 on_model_changed,
116 all_models: Arc::new(models),
117 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
118 filtered_entries: entries,
119 get_active_model: Arc::new(get_active_model),
120 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
121 _subscriptions: vec![cx.subscribe_in(
122 &LanguageModelRegistry::global(cx),
123 window,
124 |picker, _, event, window, cx| {
125 match event {
126 language_model::Event::ProviderStateChanged(_)
127 | language_model::Event::AddedProvider(_)
128 | language_model::Event::RemovedProvider(_) => {
129 let query = picker.query(cx);
130 picker.delegate.all_models = Arc::new(all_models(cx));
131 // Update matches will automatically drop the previous task
132 // if we get a provider event again
133 picker.update_matches(query, window, cx)
134 }
135 _ => {}
136 }
137 },
138 )],
139 popover_styles,
140 focus_handle,
141 }
142 }
143
144 fn get_active_model_index(
145 entries: &[LanguageModelPickerEntry],
146 active_model: Option<ConfiguredModel>,
147 ) -> usize {
148 entries
149 .iter()
150 .position(|entry| {
151 if let LanguageModelPickerEntry::Model(model) = entry {
152 active_model
153 .as_ref()
154 .map(|active_model| {
155 active_model.model.id() == model.model.id()
156 && active_model.provider.id() == model.model.provider_id()
157 })
158 .unwrap_or_default()
159 } else {
160 false
161 }
162 })
163 .unwrap_or(0)
164 }
165
166 /// Authenticates all providers in the [`LanguageModelRegistry`].
167 ///
168 /// We do this so that we can populate the language selector with all of the
169 /// models from the configured providers.
170 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
171 let authenticate_all_providers = LanguageModelRegistry::global(cx)
172 .read(cx)
173 .providers()
174 .iter()
175 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
176 .collect::<Vec<_>>();
177
178 cx.spawn(async move |_cx| {
179 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
180 if let Err(err) = authenticate_task.await {
181 if matches!(err, AuthenticateError::CredentialsNotFound) {
182 // Since we're authenticating these providers in the
183 // background for the purposes of populating the
184 // language selector, we don't care about providers
185 // where the credentials are not found.
186 } else {
187 // Some providers have noisy failure states that we
188 // don't want to spam the logs with every time the
189 // language model selector is initialized.
190 //
191 // Ideally these should have more clear failure modes
192 // that we know are safe to ignore here, like what we do
193 // with `CredentialsNotFound` above.
194 match provider_id.0.as_ref() {
195 "lmstudio" | "ollama" => {
196 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
197 //
198 // These fail noisily, so we don't log them.
199 }
200 "copilot_chat" => {
201 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
202 }
203 _ => {
204 log::error!(
205 "Failed to authenticate provider: {}: {err:#}",
206 provider_name.0
207 );
208 }
209 }
210 }
211 }
212 }
213 })
214 }
215
216 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
217 (self.get_active_model)(cx)
218 }
219}
220
/// Models split into a "Recommended" list and per-provider groups.
struct GroupedModels {
    /// Models listed under the "Recommended" header, in the order given.
    recommended: Vec<ModelInfo>,
    /// All models keyed by provider id; `IndexMap` keeps providers in insertion order.
    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
225
226impl GroupedModels {
227 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
228 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
229 for model in all {
230 let provider = model.model.provider_id();
231 if let Some(models) = all_by_provider.get_mut(&provider) {
232 models.push(model);
233 } else {
234 all_by_provider.insert(provider, vec![model]);
235 }
236 }
237
238 Self {
239 recommended,
240 all: all_by_provider,
241 }
242 }
243
244 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
245 let mut entries = Vec::new();
246
247 if !self.recommended.is_empty() {
248 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
249 entries.extend(
250 self.recommended
251 .iter()
252 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
253 );
254 }
255
256 for models in self.all.values() {
257 if models.is_empty() {
258 continue;
259 }
260 entries.push(LanguageModelPickerEntry::Separator(
261 models[0].model.provider_name().0,
262 ));
263 entries.extend(
264 models
265 .iter()
266 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
267 );
268 }
269 entries
270 }
271}
272
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header with the given title.
    Separator(SharedString),
}
277
/// Searches a fixed list of models, either by case-insensitive substring on
/// the model name or fuzzily on `"<provider name>/<model name>"`.
struct ModelMatcher {
    /// Models to search, in their original order.
    models: Vec<ModelInfo>,
    /// Executor used to run, and block on, fuzzy matching.
    bg_executor: BackgroundExecutor,
    /// One fuzzy-match candidate per entry of `models`; each candidate id is
    /// the index of its model in `models`.
    candidates: Vec<StringMatchCandidate>,
}
283
284impl ModelMatcher {
285 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
286 let candidates = Self::make_match_candidates(&models);
287 Self {
288 models,
289 bg_executor,
290 candidates,
291 }
292 }
293
294 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
295 let mut matches = self.bg_executor.block(match_strings(
296 &self.candidates,
297 query,
298 false,
299 true,
300 100,
301 &Default::default(),
302 self.bg_executor.clone(),
303 ));
304
305 let sorting_key = |mat: &StringMatch| {
306 let candidate = &self.candidates[mat.candidate_id];
307 (Reverse(OrderedFloat(mat.score)), candidate.id)
308 };
309 matches.sort_unstable_by_key(sorting_key);
310
311 let matched_models: Vec<_> = matches
312 .into_iter()
313 .map(|mat| self.models[mat.candidate_id].clone())
314 .collect();
315
316 matched_models
317 }
318
319 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
320 self.models
321 .iter()
322 .filter(|m| {
323 m.model
324 .name()
325 .0
326 .to_lowercase()
327 .contains(&query.to_lowercase())
328 })
329 .cloned()
330 .collect::<Vec<_>>()
331 }
332
333 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
334 model_infos
335 .iter()
336 .enumerate()
337 .map(|(index, model)| {
338 StringMatchCandidate::new(
339 index,
340 &format!(
341 "{}/{}",
342 &model.model.provider_name().0,
343 &model.model.name().0
344 ),
345 )
346 })
347 .collect::<Vec<_>>()
348 }
349}
350
351impl PickerDelegate for LanguageModelPickerDelegate {
352 type ListItem = AnyElement;
353
354 fn match_count(&self) -> usize {
355 self.filtered_entries.len()
356 }
357
358 fn selected_index(&self) -> usize {
359 self.selected_index
360 }
361
362 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
363 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
364 cx.notify();
365 }
366
367 fn can_select(
368 &mut self,
369 ix: usize,
370 _window: &mut Window,
371 _cx: &mut Context<Picker<Self>>,
372 ) -> bool {
373 match self.filtered_entries.get(ix) {
374 Some(LanguageModelPickerEntry::Model(_)) => true,
375 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
376 }
377 }
378
379 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
380 "Select a model…".into()
381 }
382
383 fn update_matches(
384 &mut self,
385 query: String,
386 window: &mut Window,
387 cx: &mut Context<Picker<Self>>,
388 ) -> Task<()> {
389 let all_models = self.all_models.clone();
390 let active_model = (self.get_active_model)(cx);
391 let bg_executor = cx.background_executor();
392
393 let language_model_registry = LanguageModelRegistry::global(cx);
394
395 let configured_providers = language_model_registry
396 .read(cx)
397 .providers()
398 .into_iter()
399 .filter(|provider| provider.is_authenticated(cx))
400 .collect::<Vec<_>>();
401
402 let configured_provider_ids = configured_providers
403 .iter()
404 .map(|provider| provider.id())
405 .collect::<Vec<_>>();
406
407 let recommended_models = all_models
408 .recommended
409 .iter()
410 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
411 .cloned()
412 .collect::<Vec<_>>();
413
414 let available_models = all_models
415 .all
416 .values()
417 .flat_map(|models| models.iter())
418 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
419 .cloned()
420 .collect::<Vec<_>>();
421
422 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
423 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
424
425 let recommended = matcher_rec.exact_search(&query);
426 let all = matcher_all.fuzzy_search(&query);
427
428 let filtered_models = GroupedModels::new(all, recommended);
429
430 cx.spawn_in(window, async move |this, cx| {
431 this.update_in(cx, |this, window, cx| {
432 this.delegate.filtered_entries = filtered_models.entries();
433 // Finds the currently selected model in the list
434 let new_index =
435 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
436 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
437 cx.notify();
438 })
439 .ok();
440 })
441 }
442
443 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
444 if let Some(LanguageModelPickerEntry::Model(model_info)) =
445 self.filtered_entries.get(self.selected_index)
446 {
447 let model = model_info.model.clone();
448 (self.on_model_changed)(model.clone(), cx);
449
450 let current_index = self.selected_index;
451 self.set_selected_index(current_index, window, cx);
452
453 cx.emit(DismissEvent);
454 }
455 }
456
457 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
458 cx.emit(DismissEvent);
459 }
460
461 fn render_match(
462 &self,
463 ix: usize,
464 is_focused: bool,
465 _: &mut Window,
466 cx: &mut Context<Picker<Self>>,
467 ) -> Option<Self::ListItem> {
468 match self.filtered_entries.get(ix)? {
469 LanguageModelPickerEntry::Separator(title) => {
470 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
471 }
472 LanguageModelPickerEntry::Model(model_info) => {
473 let active_model = (self.get_active_model)(cx);
474 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
475 let active_model_id = active_model.map(|m| m.model.id());
476
477 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
478 && Some(model_info.model.id()) == active_model_id;
479
480 Some(
481 ModelSelectorListItem::new(ix, model_info.model.name().0)
482 .is_focused(is_focused)
483 .is_selected(is_selected)
484 .icon(model_info.icon)
485 .into_any_element(),
486 )
487 }
488 }
489 }
490
491 fn render_footer(
492 &self,
493 _window: &mut Window,
494 _cx: &mut Context<Picker<Self>>,
495 ) -> Option<gpui::AnyElement> {
496 if !self.popover_styles {
497 return None;
498 }
499
500 let focus_handle = self.focus_handle.clone();
501
502 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal model whose id and name are the same string, and whose provider
    /// id and provider name are the same string.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, model name)` pairs.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` is exactly `expected`, in order. Each expected
    /// entry is a `"provider/name"` telemetry id. On failure both full lists are
    /// printed, not just their lengths.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        let actual = result
            .iter()
            .map(|info| info.model.telemetry_id())
            .collect::<Vec<_>>();
        assert_eq!(actual, expected, "matched models differ from expected");
    }

    /// Renders picker rows as strings: separators as `"# <title>"`, models as
    /// their telemetry id.
    fn entry_labels(entries: &[LanguageModelPickerEntry]) -> Vec<String> {
        entries
            .iter()
            .map(|entry| match entry {
                LanguageModelPickerEntry::Separator(title) => format!("# {title}"),
                LanguageModelPickerEntry::Model(info) => info.model.telemetry_id(),
            })
            .collect()
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "all"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_entries_are_grouped_by_provider(_cx: &mut TestAppContext) {
        // Models from the same provider are grouped even when interleaved, and
        // providers keep the order of their first appearance.
        let grouped_models = GroupedModels::new(
            create_models(vec![
                ("zed", "claude"),
                ("copilot", "o3"),
                ("zed", "gemini"),
            ]),
            create_models(vec![("zed", "claude")]),
        );
        assert_eq!(
            entry_labels(&grouped_models.entries()),
            vec![
                "# Recommended",
                "zed/claude",
                "# zed",
                "zed/claude",
                "zed/gemini",
                "# copilot",
                "copilot/o3",
            ],
        );

        // Without recommended models there is no "Recommended" header.
        let grouped_models = GroupedModels::new(create_models(vec![("zed", "claude")]), Vec::new());
        assert_eq!(
            entry_labels(&grouped_models.entries()),
            vec!["# zed", "zed/claude"],
        );
    }
}