model_selector.rs
1use std::sync::Arc;
2
3use crate::assistant_settings::AssistantSettings;
4use fs::Fs;
5use gpui::SharedString;
6use language_model::{LanguageModelAvailability, LanguageModelRegistry};
7use proto::Plan;
8use settings::update_settings_file;
9use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
10
11#[derive(IntoElement)]
12pub struct ModelSelector<T: PopoverTrigger> {
13 handle: Option<PopoverMenuHandle<ContextMenu>>,
14 fs: Arc<dyn Fs>,
15 trigger: T,
16 info_text: Option<SharedString>,
17}
18
19impl<T: PopoverTrigger> ModelSelector<T> {
20 pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
21 ModelSelector {
22 handle: None,
23 fs,
24 trigger,
25 info_text: None,
26 }
27 }
28
29 pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
30 self.handle = Some(handle);
31 self
32 }
33
34 pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
35 self.info_text = Some(text.into());
36 self
37 }
38}
39
40impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
41 fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
42 let mut menu = PopoverMenu::new("model-switcher");
43 if let Some(handle) = self.handle {
44 menu = menu.with_handle(handle);
45 }
46
47 let info_text = self.info_text.clone();
48
49 menu.menu(move |cx| {
50 ContextMenu::build(cx, |mut menu, cx| {
51 if let Some(info_text) = info_text.clone() {
52 menu = menu
53 .custom_row(move |_cx| {
54 Label::new(info_text.clone())
55 .color(Color::Muted)
56 .into_any_element()
57 })
58 .separator();
59 }
60
61 for (index, provider) in LanguageModelRegistry::global(cx)
62 .read(cx)
63 .providers()
64 .into_iter()
65 .enumerate()
66 {
67 let provider_icon = provider.icon();
68 let provider_name = provider.name().0.clone();
69
70 if index > 0 {
71 menu = menu.separator();
72 }
73 menu = menu.custom_row(move |_| {
74 h_flex()
75 .pb_1()
76 .gap_1p5()
77 .w_full()
78 .child(
79 Icon::new(provider_icon)
80 .color(Color::Muted)
81 .size(IconSize::Small),
82 )
83 .child(Label::new(provider_name.clone()))
84 .into_any_element()
85 });
86
87 let available_models = provider.provided_models(cx);
88 if available_models.is_empty() {
89 menu = menu.custom_entry(
90 {
91 move |_| {
92 h_flex()
93 .w_full()
94 .gap_1()
95 .child(Icon::new(IconName::Settings))
96 .child(Label::new("Configure"))
97 .into_any()
98 }
99 },
100 {
101 let provider = provider.clone();
102 move |cx| {
103 LanguageModelRegistry::global(cx).update(
104 cx,
105 |completion_provider, cx| {
106 completion_provider
107 .set_active_provider(Some(provider.clone()), cx);
108 },
109 );
110 }
111 },
112 );
113 }
114
115 let selected_provider = LanguageModelRegistry::read_global(cx)
116 .active_provider()
117 .map(|m| m.id());
118 let selected_model = LanguageModelRegistry::read_global(cx)
119 .active_model()
120 .map(|m| m.id());
121
122 for available_model in available_models {
123 menu = menu.custom_entry(
124 {
125 let id = available_model.id();
126 let provider_id = available_model.provider_id();
127 let model_name = available_model.name().0.clone();
128 let availability = available_model.availability();
129 let selected_model = selected_model.clone();
130 let selected_provider = selected_provider.clone();
131 move |cx| {
132 h_flex()
133 .w_full()
134 .justify_between()
135 .font_buffer(cx)
136 .min_w(px(260.))
137 .child(
138 h_flex()
139 .gap_2()
140 .child(Label::new(model_name.clone()))
141 .children(match availability {
142 LanguageModelAvailability::Public => None,
143 LanguageModelAvailability::RequiresPlan(
144 Plan::Free,
145 ) => None,
146 LanguageModelAvailability::RequiresPlan(
147 Plan::ZedPro,
148 ) => Some(
149 Label::new("Pro")
150 .size(LabelSize::XSmall)
151 .color(Color::Muted),
152 ),
153 }),
154 )
155 .child(div().when(
156 selected_model.as_ref() == Some(&id)
157 && selected_provider.as_ref() == Some(&provider_id),
158 |this| {
159 this.child(
160 Icon::new(IconName::Check)
161 .color(Color::Accent)
162 .size(IconSize::Small),
163 )
164 },
165 ))
166 .into_any()
167 }
168 },
169 {
170 let fs = self.fs.clone();
171 let model = available_model.clone();
172 move |cx| {
173 let model = model.clone();
174 update_settings_file::<AssistantSettings>(
175 fs.clone(),
176 cx,
177 move |settings, _| settings.set_model(model),
178 );
179 }
180 },
181 );
182 }
183 }
184 menu
185 })
186 .into()
187 })
188 .trigger(self.trigger)
189 .attach(gpui::AnchorCorner::BottomLeft)
190 }
191}