model_selector.rs

  1use std::sync::Arc;
  2
  3use crate::assistant_settings::AssistantSettings;
  4use fs::Fs;
  5use gpui::SharedString;
  6use language_model::{LanguageModelAvailability, LanguageModelRegistry};
  7use proto::Plan;
  8use settings::update_settings_file;
  9use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 10
 11#[derive(IntoElement)]
 12pub struct ModelSelector<T: PopoverTrigger> {
 13    handle: Option<PopoverMenuHandle<ContextMenu>>,
 14    fs: Arc<dyn Fs>,
 15    trigger: T,
 16    info_text: Option<SharedString>,
 17}
 18
 19impl<T: PopoverTrigger> ModelSelector<T> {
 20    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 21        ModelSelector {
 22            handle: None,
 23            fs,
 24            trigger,
 25            info_text: None,
 26        }
 27    }
 28
 29    pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
 30        self.handle = Some(handle);
 31        self
 32    }
 33
 34    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
 35        self.info_text = Some(text.into());
 36        self
 37    }
 38}
 39
 40impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
 41    fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
 42        let mut menu = PopoverMenu::new("model-switcher");
 43        if let Some(handle) = self.handle {
 44            menu = menu.with_handle(handle);
 45        }
 46
 47        let info_text = self.info_text.clone();
 48
 49        menu.menu(move |cx| {
 50            ContextMenu::build(cx, |mut menu, cx| {
 51                if let Some(info_text) = info_text.clone() {
 52                    menu = menu
 53                        .custom_row(move |_cx| {
 54                            Label::new(info_text.clone())
 55                                .color(Color::Muted)
 56                                .into_any_element()
 57                        })
 58                        .separator();
 59                }
 60
 61                for (index, provider) in LanguageModelRegistry::global(cx)
 62                    .read(cx)
 63                    .providers()
 64                    .into_iter()
 65                    .enumerate()
 66                {
 67                    let provider_icon = provider.icon();
 68                    let provider_name = provider.name().0.clone();
 69
 70                    if index > 0 {
 71                        menu = menu.separator();
 72                    }
 73                    menu = menu.custom_row(move |_| {
 74                        h_flex()
 75                            .pb_1()
 76                            .gap_1p5()
 77                            .w_full()
 78                            .child(
 79                                Icon::new(provider_icon)
 80                                    .color(Color::Muted)
 81                                    .size(IconSize::Small),
 82                            )
 83                            .child(Label::new(provider_name.clone()))
 84                            .into_any_element()
 85                    });
 86
 87                    let available_models = provider.provided_models(cx);
 88                    if available_models.is_empty() {
 89                        menu = menu.custom_entry(
 90                            {
 91                                move |_| {
 92                                    h_flex()
 93                                        .w_full()
 94                                        .gap_1()
 95                                        .child(Icon::new(IconName::Settings))
 96                                        .child(Label::new("Configure"))
 97                                        .into_any()
 98                                }
 99                            },
100                            {
101                                let provider = provider.clone();
102                                move |cx| {
103                                    LanguageModelRegistry::global(cx).update(
104                                        cx,
105                                        |completion_provider, cx| {
106                                            completion_provider
107                                                .set_active_provider(Some(provider.clone()), cx);
108                                        },
109                                    );
110                                }
111                            },
112                        );
113                    }
114
115                    let selected_provider = LanguageModelRegistry::read_global(cx)
116                        .active_provider()
117                        .map(|m| m.id());
118                    let selected_model = LanguageModelRegistry::read_global(cx)
119                        .active_model()
120                        .map(|m| m.id());
121
122                    for available_model in available_models {
123                        menu = menu.custom_entry(
124                            {
125                                let id = available_model.id();
126                                let provider_id = available_model.provider_id();
127                                let model_name = available_model.name().0.clone();
128                                let availability = available_model.availability();
129                                let selected_model = selected_model.clone();
130                                let selected_provider = selected_provider.clone();
131                                move |cx| {
132                                    h_flex()
133                                        .w_full()
134                                        .justify_between()
135                                        .font_buffer(cx)
136                                        .min_w(px(260.))
137                                        .child(
138                                            h_flex()
139                                                .gap_2()
140                                                .child(Label::new(model_name.clone()))
141                                                .children(match availability {
142                                                    LanguageModelAvailability::Public => None,
143                                                    LanguageModelAvailability::RequiresPlan(
144                                                        Plan::Free,
145                                                    ) => None,
146                                                    LanguageModelAvailability::RequiresPlan(
147                                                        Plan::ZedPro,
148                                                    ) => Some(
149                                                        Label::new("Pro")
150                                                            .size(LabelSize::XSmall)
151                                                            .color(Color::Muted),
152                                                    ),
153                                                }),
154                                        )
155                                        .child(div().when(
156                                            selected_model.as_ref() == Some(&id)
157                                                && selected_provider.as_ref() == Some(&provider_id),
158                                            |this| {
159                                                this.child(
160                                                    Icon::new(IconName::Check)
161                                                        .color(Color::Accent)
162                                                        .size(IconSize::Small),
163                                                )
164                                            },
165                                        ))
166                                        .into_any()
167                                }
168                            },
169                            {
170                                let fs = self.fs.clone();
171                                let model = available_model.clone();
172                                move |cx| {
173                                    let model = model.clone();
174                                    update_settings_file::<AssistantSettings>(
175                                        fs.clone(),
176                                        cx,
177                                        move |settings, _| settings.set_model(model),
178                                    );
179                                }
180                            },
181                        );
182                    }
183                }
184                menu
185            })
186            .into()
187        })
188        .trigger(self.trigger)
189        .attach(gpui::AnchorCorner::BottomLeft)
190    }
191}