model_selector.rs

  1use std::sync::Arc;
  2
  3use crate::{assistant_settings::AssistantSettings, ShowConfiguration};
  4use fs::Fs;
  5use gpui::{Action, SharedString};
  6use language_model::{LanguageModelAvailability, LanguageModelRegistry};
  7use proto::Plan;
  8use settings::update_settings_file;
  9use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 10
 11#[derive(IntoElement)]
 12pub struct ModelSelector<T: PopoverTrigger> {
 13    handle: Option<PopoverMenuHandle<ContextMenu>>,
 14    fs: Arc<dyn Fs>,
 15    trigger: T,
 16    info_text: Option<SharedString>,
 17}
 18
 19impl<T: PopoverTrigger> ModelSelector<T> {
 20    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 21        ModelSelector {
 22            handle: None,
 23            fs,
 24            trigger,
 25            info_text: None,
 26        }
 27    }
 28
 29    pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
 30        self.handle = Some(handle);
 31        self
 32    }
 33
 34    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
 35        self.info_text = Some(text.into());
 36        self
 37    }
 38}
 39
 40impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
 41    fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
 42        let mut menu = PopoverMenu::new("model-switcher");
 43        if let Some(handle) = self.handle {
 44            menu = menu.with_handle(handle);
 45        }
 46
 47        let info_text = self.info_text.clone();
 48
 49        menu.menu(move |cx| {
 50            ContextMenu::build(cx, |mut menu, cx| {
 51                if let Some(info_text) = info_text.clone() {
 52                    menu = menu
 53                        .custom_row(move |_cx| {
 54                            Label::new(info_text.clone())
 55                                .color(Color::Muted)
 56                                .into_any_element()
 57                        })
 58                        .separator();
 59                }
 60
 61                for (index, provider) in LanguageModelRegistry::global(cx)
 62                    .read(cx)
 63                    .providers()
 64                    .into_iter()
 65                    .enumerate()
 66                {
 67                    let provider_icon = provider.icon();
 68                    let provider_name = provider.name().0.clone();
 69
 70                    if index > 0 {
 71                        menu = menu.separator();
 72                    }
 73                    menu = menu.custom_row(move |_| {
 74                        h_flex()
 75                            .pb_1()
 76                            .gap_1p5()
 77                            .w_full()
 78                            .child(
 79                                Icon::new(provider_icon)
 80                                    .color(Color::Muted)
 81                                    .size(IconSize::Small),
 82                            )
 83                            .child(Label::new(provider_name.clone()))
 84                            .into_any_element()
 85                    });
 86
 87                    let available_models = provider.provided_models(cx);
 88                    if available_models.is_empty() {
 89                        menu = menu.custom_entry(
 90                            {
 91                                move |_| {
 92                                    h_flex()
 93                                        .w_full()
 94                                        .gap_1()
 95                                        .child(Icon::new(IconName::Settings))
 96                                        .child(Label::new("Configure"))
 97                                        .into_any()
 98                                }
 99                            },
100                            {
101                                |cx| {
102                                    cx.dispatch_action(ShowConfiguration.boxed_clone());
103                                }
104                            },
105                        );
106                    }
107
108                    let selected_provider = LanguageModelRegistry::read_global(cx)
109                        .active_provider()
110                        .map(|m| m.id());
111                    let selected_model = LanguageModelRegistry::read_global(cx)
112                        .active_model()
113                        .map(|m| m.id());
114
115                    for available_model in available_models {
116                        menu = menu.custom_entry(
117                            {
118                                let id = available_model.id();
119                                let provider_id = available_model.provider_id();
120                                let model_name = available_model.name().0.clone();
121                                let availability = available_model.availability();
122                                let selected_model = selected_model.clone();
123                                let selected_provider = selected_provider.clone();
124                                move |cx| {
125                                    h_flex()
126                                        .w_full()
127                                        .justify_between()
128                                        .font_buffer(cx)
129                                        .min_w(px(260.))
130                                        .child(
131                                            h_flex()
132                                                .gap_2()
133                                                .child(Label::new(model_name.clone()))
134                                                .children(match availability {
135                                                    LanguageModelAvailability::Public => None,
136                                                    LanguageModelAvailability::RequiresPlan(
137                                                        Plan::Free,
138                                                    ) => None,
139                                                    LanguageModelAvailability::RequiresPlan(
140                                                        Plan::ZedPro,
141                                                    ) => Some(
142                                                        Label::new("Pro")
143                                                            .size(LabelSize::XSmall)
144                                                            .color(Color::Muted),
145                                                    ),
146                                                }),
147                                        )
148                                        .child(div().when(
149                                            selected_model.as_ref() == Some(&id)
150                                                && selected_provider.as_ref() == Some(&provider_id),
151                                            |this| {
152                                                this.child(
153                                                    Icon::new(IconName::Check)
154                                                        .color(Color::Accent)
155                                                        .size(IconSize::Small),
156                                                )
157                                            },
158                                        ))
159                                        .into_any()
160                                }
161                            },
162                            {
163                                let fs = self.fs.clone();
164                                let model = available_model.clone();
165                                move |cx| {
166                                    let model = model.clone();
167                                    update_settings_file::<AssistantSettings>(
168                                        fs.clone(),
169                                        cx,
170                                        move |settings, _| settings.set_model(model),
171                                    );
172                                }
173                            },
174                        );
175                    }
176                }
177                menu
178            })
179            .into()
180        })
181        .trigger(self.trigger)
182        .attach(gpui::AnchorCorner::BottomLeft)
183    }
184}