1use std::{cmp::Reverse, sync::Arc};
2
3use cloud_llm_client::Plan;
4use collections::{HashSet, IndexMap};
5use feature_flags::ZedProFeatureFlag;
6use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
7use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
8use language_model::{
9 ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
10};
11use ordered_float::OrderedFloat;
12use picker::{Picker, PickerDelegate};
13use ui::{ListItem, ListItemSpacing, prelude::*};
14
15const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
16
17type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
18type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
19
20pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
21
22pub fn language_model_selector(
23 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
24 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
25 window: &mut Window,
26 cx: &mut Context<LanguageModelSelector>,
27) -> LanguageModelSelector {
28 let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
29 Picker::list(delegate, window, cx)
30 .show_scrollbar(true)
31 .width(rems(20.))
32 .max_height(Some(rems(20.).into()))
33}
34
35fn all_models(cx: &App) -> GroupedModels {
36 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
37
38 let recommended = providers
39 .iter()
40 .flat_map(|provider| {
41 provider
42 .recommended_models(cx)
43 .into_iter()
44 .map(|model| ModelInfo {
45 model,
46 icon: provider.icon(),
47 })
48 })
49 .collect();
50
51 let other = providers
52 .iter()
53 .flat_map(|provider| {
54 provider
55 .provided_models(cx)
56 .into_iter()
57 .map(|model| ModelInfo {
58 model,
59 icon: provider.icon(),
60 })
61 })
62 .collect();
63
64 GroupedModels::new(other, recommended)
65}
66
67#[derive(Clone)]
68struct ModelInfo {
69 model: Arc<dyn LanguageModel>,
70 icon: IconName,
71}
72
73pub struct LanguageModelPickerDelegate {
74 on_model_changed: OnModelChanged,
75 get_active_model: GetActiveModel,
76 all_models: Arc<GroupedModels>,
77 filtered_entries: Vec<LanguageModelPickerEntry>,
78 selected_index: usize,
79 _subscriptions: Vec<Subscription>,
80}
81
82impl LanguageModelPickerDelegate {
83 fn new(
84 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
85 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
86 window: &mut Window,
87 cx: &mut Context<Picker<Self>>,
88 ) -> Self {
89 let on_model_changed = Arc::new(on_model_changed);
90 let models = all_models(cx);
91 let entries = models.entries();
92
93 Self {
94 on_model_changed,
95 all_models: Arc::new(models),
96 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
97 filtered_entries: entries,
98 get_active_model: Arc::new(get_active_model),
99 _subscriptions: vec![cx.subscribe_in(
100 &LanguageModelRegistry::global(cx),
101 window,
102 |picker, _, event, window, cx| {
103 match event {
104 language_model::Event::ProviderStateChanged(_)
105 | language_model::Event::AddedProvider(_)
106 | language_model::Event::RemovedProvider(_) => {
107 let query = picker.query(cx);
108 picker.delegate.all_models = Arc::new(all_models(cx));
109 // Update matches will automatically drop the previous task
110 // if we get a provider event again
111 picker.update_matches(query, window, cx)
112 }
113 _ => {}
114 }
115 },
116 )],
117 }
118 }
119
120 fn get_active_model_index(
121 entries: &[LanguageModelPickerEntry],
122 active_model: Option<ConfiguredModel>,
123 ) -> usize {
124 entries
125 .iter()
126 .position(|entry| {
127 if let LanguageModelPickerEntry::Model(model) = entry {
128 active_model
129 .as_ref()
130 .map(|active_model| {
131 active_model.model.id() == model.model.id()
132 && active_model.provider.id() == model.model.provider_id()
133 })
134 .unwrap_or_default()
135 } else {
136 false
137 }
138 })
139 .unwrap_or(0)
140 }
141
142 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
143 (self.get_active_model)(cx)
144 }
145}
146
147struct GroupedModels {
148 recommended: Vec<ModelInfo>,
149 other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
150}
151
152impl GroupedModels {
153 pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
154 let recommended_ids = recommended
155 .iter()
156 .map(|info| (info.model.provider_id(), info.model.id()))
157 .collect::<HashSet<_>>();
158
159 let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
160 for model in other {
161 if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
162 continue;
163 }
164
165 let provider = model.model.provider_id();
166 if let Some(models) = other_by_provider.get_mut(&provider) {
167 models.push(model);
168 } else {
169 other_by_provider.insert(provider, vec![model]);
170 }
171 }
172
173 Self {
174 recommended,
175 other: other_by_provider,
176 }
177 }
178
179 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
180 let mut entries = Vec::new();
181
182 if !self.recommended.is_empty() {
183 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
184 entries.extend(
185 self.recommended
186 .iter()
187 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
188 );
189 }
190
191 for models in self.other.values() {
192 if models.is_empty() {
193 continue;
194 }
195 entries.push(LanguageModelPickerEntry::Separator(
196 models[0].model.provider_name().0,
197 ));
198 entries.extend(
199 models
200 .iter()
201 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
202 );
203 }
204 entries
205 }
206
207 fn model_infos(&self) -> Vec<ModelInfo> {
208 let other = self
209 .other
210 .values()
211 .flat_map(|model| model.iter())
212 .cloned()
213 .collect::<Vec<_>>();
214 self.recommended
215 .iter()
216 .chain(&other)
217 .cloned()
218 .collect::<Vec<_>>()
219 }
220}
221
222enum LanguageModelPickerEntry {
223 Model(ModelInfo),
224 Separator(SharedString),
225}
226
227struct ModelMatcher {
228 models: Vec<ModelInfo>,
229 bg_executor: BackgroundExecutor,
230 candidates: Vec<StringMatchCandidate>,
231}
232
233impl ModelMatcher {
234 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
235 let candidates = Self::make_match_candidates(&models);
236 Self {
237 models,
238 bg_executor,
239 candidates,
240 }
241 }
242
243 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
244 let mut matches = self.bg_executor.block(match_strings(
245 &self.candidates,
246 query,
247 false,
248 true,
249 100,
250 &Default::default(),
251 self.bg_executor.clone(),
252 ));
253
254 let sorting_key = |mat: &StringMatch| {
255 let candidate = &self.candidates[mat.candidate_id];
256 (Reverse(OrderedFloat(mat.score)), candidate.id)
257 };
258 matches.sort_unstable_by_key(sorting_key);
259
260 let matched_models: Vec<_> = matches
261 .into_iter()
262 .map(|mat| self.models[mat.candidate_id].clone())
263 .collect();
264
265 matched_models
266 }
267
268 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
269 self.models
270 .iter()
271 .filter(|m| {
272 m.model
273 .name()
274 .0
275 .to_lowercase()
276 .contains(&query.to_lowercase())
277 })
278 .cloned()
279 .collect::<Vec<_>>()
280 }
281
282 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
283 model_infos
284 .iter()
285 .enumerate()
286 .map(|(index, model)| {
287 StringMatchCandidate::new(
288 index,
289 &format!(
290 "{}/{}",
291 &model.model.provider_name().0,
292 &model.model.name().0
293 ),
294 )
295 })
296 .collect::<Vec<_>>()
297 }
298}
299
300impl PickerDelegate for LanguageModelPickerDelegate {
301 type ListItem = AnyElement;
302
303 fn match_count(&self) -> usize {
304 self.filtered_entries.len()
305 }
306
307 fn selected_index(&self) -> usize {
308 self.selected_index
309 }
310
311 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
312 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
313 cx.notify();
314 }
315
316 fn can_select(
317 &mut self,
318 ix: usize,
319 _window: &mut Window,
320 _cx: &mut Context<Picker<Self>>,
321 ) -> bool {
322 match self.filtered_entries.get(ix) {
323 Some(LanguageModelPickerEntry::Model(_)) => true,
324 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
325 }
326 }
327
328 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
329 "Select a model…".into()
330 }
331
332 fn update_matches(
333 &mut self,
334 query: String,
335 window: &mut Window,
336 cx: &mut Context<Picker<Self>>,
337 ) -> Task<()> {
338 let all_models = self.all_models.clone();
339 let active_model = (self.get_active_model)(cx);
340 let bg_executor = cx.background_executor();
341
342 let language_model_registry = LanguageModelRegistry::global(cx);
343
344 let configured_providers = language_model_registry
345 .read(cx)
346 .providers()
347 .into_iter()
348 .filter(|provider| provider.is_authenticated(cx))
349 .collect::<Vec<_>>();
350
351 let configured_provider_ids = configured_providers
352 .iter()
353 .map(|provider| provider.id())
354 .collect::<Vec<_>>();
355
356 let recommended_models = all_models
357 .recommended
358 .iter()
359 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
360 .cloned()
361 .collect::<Vec<_>>();
362
363 let available_models = all_models
364 .model_infos()
365 .iter()
366 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
367 .cloned()
368 .collect::<Vec<_>>();
369
370 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
371 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
372
373 let recommended = matcher_rec.exact_search(&query);
374 let all = matcher_all.fuzzy_search(&query);
375
376 let filtered_models = GroupedModels::new(all, recommended);
377
378 cx.spawn_in(window, async move |this, cx| {
379 this.update_in(cx, |this, window, cx| {
380 this.delegate.filtered_entries = filtered_models.entries();
381 // Finds the currently selected model in the list
382 let new_index =
383 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
384 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
385 cx.notify();
386 })
387 .ok();
388 })
389 }
390
391 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
392 if let Some(LanguageModelPickerEntry::Model(model_info)) =
393 self.filtered_entries.get(self.selected_index)
394 {
395 let model = model_info.model.clone();
396 (self.on_model_changed)(model.clone(), cx);
397
398 let current_index = self.selected_index;
399 self.set_selected_index(current_index, window, cx);
400
401 cx.emit(DismissEvent);
402 }
403 }
404
405 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
406 cx.emit(DismissEvent);
407 }
408
409 fn render_match(
410 &self,
411 ix: usize,
412 selected: bool,
413 _: &mut Window,
414 cx: &mut Context<Picker<Self>>,
415 ) -> Option<Self::ListItem> {
416 match self.filtered_entries.get(ix)? {
417 LanguageModelPickerEntry::Separator(title) => Some(
418 div()
419 .px_2()
420 .pb_1()
421 .when(ix > 1, |this| {
422 this.mt_1()
423 .pt_2()
424 .border_t_1()
425 .border_color(cx.theme().colors().border_variant)
426 })
427 .child(
428 Label::new(title)
429 .size(LabelSize::XSmall)
430 .color(Color::Muted),
431 )
432 .into_any_element(),
433 ),
434 LanguageModelPickerEntry::Model(model_info) => {
435 let active_model = (self.get_active_model)(cx);
436 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
437 let active_model_id = active_model.map(|m| m.model.id());
438
439 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
440 && Some(model_info.model.id()) == active_model_id;
441
442 let model_icon_color = if is_selected {
443 Color::Accent
444 } else {
445 Color::Muted
446 };
447
448 Some(
449 ListItem::new(ix)
450 .inset(true)
451 .spacing(ListItemSpacing::Sparse)
452 .toggle_state(selected)
453 .start_slot(
454 Icon::new(model_info.icon)
455 .color(model_icon_color)
456 .size(IconSize::Small),
457 )
458 .child(
459 h_flex()
460 .w_full()
461 .pl_0p5()
462 .gap_1p5()
463 .w(px(240.))
464 .child(Label::new(model_info.model.name().0).truncate()),
465 )
466 .end_slot(div().pr_3().when(is_selected, |this| {
467 this.child(
468 Icon::new(IconName::Check)
469 .color(Color::Accent)
470 .size(IconSize::Small),
471 )
472 }))
473 .into_any_element(),
474 )
475 }
476 }
477 }
478
479 fn render_footer(
480 &self,
481 _: &mut Window,
482 cx: &mut Context<Picker<Self>>,
483 ) -> Option<gpui::AnyElement> {
484 use feature_flags::FeatureFlagAppExt;
485
486 let plan = Plan::ZedPro;
487
488 Some(
489 h_flex()
490 .w_full()
491 .border_t_1()
492 .border_color(cx.theme().colors().border_variant)
493 .p_1()
494 .gap_4()
495 .justify_between()
496 .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
497 this.child(match plan {
498 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
499 .icon(IconName::ZedAssistant)
500 .icon_size(IconSize::Small)
501 .icon_color(Color::Muted)
502 .icon_position(IconPosition::Start)
503 .on_click(|_, window, cx| {
504 window
505 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
506 }),
507 Plan::ZedFree | Plan::ZedProTrial => Button::new(
508 "try-pro",
509 if plan == Plan::ZedProTrial {
510 "Upgrade to Pro"
511 } else {
512 "Try Pro"
513 },
514 )
515 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
516 })
517 })
518 .child(
519 Button::new("configure", "Configure")
520 .icon(IconName::Settings)
521 .icon_size(IconSize::Small)
522 .icon_color(Color::Muted)
523 .icon_position(IconPosition::Start)
524 .on_click(|_, window, cx| {
525 window.dispatch_action(
526 zed_actions::agent::OpenSettings.boxed_clone(),
527 cx,
528 );
529 }),
530 )
531 .into_any(),
532 )
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal model whose id equals its name; only metadata is implemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, model_name)` pairs.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts `result` lists exactly the `"provider/name"` ids in `expected`, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Only the exact (provider, model) pair is excluded; the same model name
        // offered by a different provider stays in "other".
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }

    #[gpui::test]
    fn test_entries_and_model_infos_order(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"),
            ("zed", "gemini"),
            ("copilot", "o3"),
            ("zed", "gpt-4.1"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Groups keep first-seen provider order; a later model from an earlier
        // provider joins that provider's existing group.
        let labels = grouped_models
            .entries()
            .iter()
            .map(|entry| match entry {
                LanguageModelPickerEntry::Separator(title) => format!("# {title}"),
                LanguageModelPickerEntry::Model(info) => info.model.telemetry_id(),
            })
            .collect::<Vec<_>>();
        assert_eq!(
            labels,
            vec![
                "# Recommended",
                "zed/claude",
                "# zed",
                "zed/gemini",
                "zed/gpt-4.1",
                "# copilot",
                "copilot/o3",
            ]
        );

        // `model_infos` lists recommended models first, then each provider group.
        assert_models_eq(
            grouped_models.model_infos(),
            vec!["zed/claude", "zed/gemini", "zed/gpt-4.1", "copilot/o3"],
        );
    }
}