1use std::sync::Arc;
2
3use crate::assistant_settings::AssistantSettings;
4use fs::Fs;
5use gpui::SharedString;
6use language_model::LanguageModelRegistry;
7use settings::update_settings_file;
8use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
9
10#[derive(IntoElement)]
11pub struct ModelSelector<T: PopoverTrigger> {
12 handle: Option<PopoverMenuHandle<ContextMenu>>,
13 fs: Arc<dyn Fs>,
14 trigger: T,
15 info_text: Option<SharedString>,
16}
17
18impl<T: PopoverTrigger> ModelSelector<T> {
19 pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
20 ModelSelector {
21 handle: None,
22 fs,
23 trigger,
24 info_text: None,
25 }
26 }
27
28 pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
29 self.handle = Some(handle);
30 self
31 }
32
33 pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
34 self.info_text = Some(text.into());
35 self
36 }
37}
38
39impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
40 fn render(self, _: &mut WindowContext) -> impl IntoElement {
41 let mut menu = PopoverMenu::new("model-switcher");
42 if let Some(handle) = self.handle {
43 menu = menu.with_handle(handle);
44 }
45
46 let info_text = self.info_text.clone();
47
48 menu.menu(move |cx| {
49 ContextMenu::build(cx, |mut menu, cx| {
50 if let Some(info_text) = info_text.clone() {
51 menu = menu
52 .custom_row(move |_cx| {
53 Label::new(info_text.clone())
54 .color(Color::Muted)
55 .into_any_element()
56 })
57 .separator();
58 }
59
60 for (index, provider) in LanguageModelRegistry::global(cx)
61 .read(cx)
62 .providers()
63 .into_iter()
64 .enumerate()
65 {
66 if index > 0 {
67 menu = menu.separator();
68 }
69 menu = menu.header(provider.name().0);
70
71 let available_models = provider.provided_models(cx);
72 if available_models.is_empty() {
73 menu = menu.custom_entry(
74 {
75 move |_| {
76 h_flex()
77 .w_full()
78 .gap_1()
79 .child(Icon::new(IconName::Settings))
80 .child(Label::new("Configure"))
81 .into_any()
82 }
83 },
84 {
85 let provider = provider.clone();
86 move |cx| {
87 LanguageModelRegistry::global(cx).update(
88 cx,
89 |completion_provider, cx| {
90 completion_provider
91 .set_active_provider(Some(provider.clone()), cx);
92 },
93 );
94 }
95 },
96 );
97 }
98
99 let selected_provider = LanguageModelRegistry::read_global(cx)
100 .active_provider()
101 .map(|m| m.id());
102 let selected_model = LanguageModelRegistry::read_global(cx)
103 .active_model()
104 .map(|m| m.id());
105
106 for available_model in available_models {
107 menu = menu.custom_entry(
108 {
109 let id = available_model.id();
110 let provider_id = available_model.provider_id();
111 let model_name = available_model.name().0.clone();
112 let selected_model = selected_model.clone();
113 let selected_provider = selected_provider.clone();
114 move |_| {
115 h_flex()
116 .w_full()
117 .justify_between()
118 .child(Label::new(model_name.clone()))
119 .when(
120 selected_model.as_ref() == Some(&id)
121 && selected_provider.as_ref() == Some(&provider_id),
122 |this| this.child(Icon::new(IconName::Check)),
123 )
124 .into_any()
125 }
126 },
127 {
128 let fs = self.fs.clone();
129 let model = available_model.clone();
130 move |cx| {
131 let model = model.clone();
132 update_settings_file::<AssistantSettings>(
133 fs.clone(),
134 cx,
135 move |settings, _| settings.set_model(model),
136 );
137 }
138 },
139 );
140 }
141 }
142 menu
143 })
144 .into()
145 })
146 .trigger(self.trigger)
147 .attach(gpui::AnchorCorner::BottomLeft)
148 }
149}