1use std::sync::Arc;
2
3use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
4use fs::Fs;
5use gpui::SharedString;
6use language_model::LanguageModelRegistry;
7use settings::update_settings_file;
8use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
9
/// A popover menu for choosing which language model the assistant uses.
///
/// `trigger` is the element that opens the menu. When rendered, the menu lists
/// every provider in the global `LanguageModelRegistry` along with its models.
/// Picking a model writes it to the `AssistantSettings` file through `fs`.
#[derive(IntoElement)]
pub struct ModelSelector<T: PopoverTrigger> {
    /// Optional handle attached to the popover menu. Presumably it lets callers
    /// control the menu from outside (TODO confirm against call sites).
    handle: Option<PopoverMenuHandle<ContextMenu>>,
    /// File system used when persisting the chosen model via `update_settings_file`.
    fs: Arc<dyn Fs>,
    /// Element rendered as the menu's trigger.
    trigger: T,
    /// Optional muted text shown as the first menu row, followed by a separator.
    info_text: Option<SharedString>,
}
17
18impl<T: PopoverTrigger> ModelSelector<T> {
19 pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
20 ModelSelector {
21 handle: None,
22 fs,
23 trigger,
24 info_text: None,
25 }
26 }
27
28 pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
29 self.handle = Some(handle);
30 self
31 }
32
33 pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
34 self.info_text = Some(text.into());
35 self
36 }
37}
38
39impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
40 fn render(self, _: &mut WindowContext) -> impl IntoElement {
41 let mut menu = PopoverMenu::new("model-switcher");
42 if let Some(handle) = self.handle {
43 menu = menu.with_handle(handle);
44 }
45
46 let info_text = self.info_text.clone();
47
48 menu.menu(move |cx| {
49 ContextMenu::build(cx, |mut menu, cx| {
50 if let Some(info_text) = info_text.clone() {
51 menu = menu
52 .custom_row(move |_cx| {
53 Label::new(info_text.clone())
54 .color(Color::Muted)
55 .into_any_element()
56 })
57 .separator();
58 }
59
60 for (index, provider) in LanguageModelRegistry::global(cx)
61 .read(cx)
62 .providers()
63 .enumerate()
64 {
65 if index > 0 {
66 menu = menu.separator();
67 }
68 menu = menu.header(provider.name().0);
69
70 let available_models = provider.provided_models(cx);
71 if available_models.is_empty() {
72 menu = menu.custom_entry(
73 {
74 move |_| {
75 h_flex()
76 .w_full()
77 .gap_1()
78 .child(Icon::new(IconName::Settings))
79 .child(Label::new("Configure"))
80 .into_any()
81 }
82 },
83 {
84 let provider = provider.id();
85 move |cx| {
86 LanguageModelCompletionProvider::global(cx).update(
87 cx,
88 |completion_provider, cx| {
89 completion_provider
90 .set_active_provider(provider.clone(), cx)
91 },
92 );
93 }
94 },
95 );
96 }
97
98 let selected_model = LanguageModelCompletionProvider::read_global(cx)
99 .active_model()
100 .map(|m| m.id());
101 let selected_provider = LanguageModelCompletionProvider::read_global(cx)
102 .active_provider()
103 .map(|m| m.id());
104
105 for available_model in available_models {
106 menu = menu.custom_entry(
107 {
108 let id = available_model.id();
109 let provider_id = available_model.provider_id();
110 let model_name = available_model.name().0.clone();
111 let selected_model = selected_model.clone();
112 let selected_provider = selected_provider.clone();
113 move |_| {
114 h_flex()
115 .w_full()
116 .justify_between()
117 .child(Label::new(model_name.clone()))
118 .when(
119 selected_model.as_ref() == Some(&id)
120 && selected_provider.as_ref() == Some(&provider_id),
121 |this| this.child(Icon::new(IconName::Check)),
122 )
123 .into_any()
124 }
125 },
126 {
127 let fs = self.fs.clone();
128 let model = available_model.clone();
129 move |cx| {
130 let model = model.clone();
131 update_settings_file::<AssistantSettings>(
132 fs.clone(),
133 cx,
134 move |settings, _| settings.set_model(model),
135 );
136 }
137 },
138 );
139 }
140 }
141 menu
142 })
143 .into()
144 })
145 .trigger(self.trigger)
146 .attach(gpui::AnchorCorner::BottomLeft)
147 }
148}