model_selector.rs

  1use std::sync::Arc;
  2
  3use crate::assistant_settings::AssistantSettings;
  4use fs::Fs;
  5use gpui::SharedString;
  6use language_model::LanguageModelRegistry;
  7use settings::update_settings_file;
  8use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
  9
 10#[derive(IntoElement)]
 11pub struct ModelSelector<T: PopoverTrigger> {
 12    handle: Option<PopoverMenuHandle<ContextMenu>>,
 13    fs: Arc<dyn Fs>,
 14    trigger: T,
 15    info_text: Option<SharedString>,
 16}
 17
 18impl<T: PopoverTrigger> ModelSelector<T> {
 19    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 20        ModelSelector {
 21            handle: None,
 22            fs,
 23            trigger,
 24            info_text: None,
 25        }
 26    }
 27
 28    pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
 29        self.handle = Some(handle);
 30        self
 31    }
 32
 33    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
 34        self.info_text = Some(text.into());
 35        self
 36    }
 37}
 38
 39impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
 40    fn render(self, _: &mut WindowContext) -> impl IntoElement {
 41        let mut menu = PopoverMenu::new("model-switcher");
 42        if let Some(handle) = self.handle {
 43            menu = menu.with_handle(handle);
 44        }
 45
 46        let info_text = self.info_text.clone();
 47
 48        menu.menu(move |cx| {
 49            ContextMenu::build(cx, |mut menu, cx| {
 50                if let Some(info_text) = info_text.clone() {
 51                    menu = menu
 52                        .custom_row(move |_cx| {
 53                            Label::new(info_text.clone())
 54                                .color(Color::Muted)
 55                                .into_any_element()
 56                        })
 57                        .separator();
 58                }
 59
 60                for (index, provider) in LanguageModelRegistry::global(cx)
 61                    .read(cx)
 62                    .providers()
 63                    .into_iter()
 64                    .enumerate()
 65                {
 66                    if index > 0 {
 67                        menu = menu.separator();
 68                    }
 69                    menu = menu.header(provider.name().0);
 70
 71                    let available_models = provider.provided_models(cx);
 72                    if available_models.is_empty() {
 73                        menu = menu.custom_entry(
 74                            {
 75                                move |_| {
 76                                    h_flex()
 77                                        .w_full()
 78                                        .gap_1()
 79                                        .child(Icon::new(IconName::Settings))
 80                                        .child(Label::new("Configure"))
 81                                        .into_any()
 82                                }
 83                            },
 84                            {
 85                                let provider = provider.clone();
 86                                move |cx| {
 87                                    LanguageModelRegistry::global(cx).update(
 88                                        cx,
 89                                        |completion_provider, cx| {
 90                                            completion_provider
 91                                                .set_active_provider(Some(provider.clone()), cx);
 92                                        },
 93                                    );
 94                                }
 95                            },
 96                        );
 97                    }
 98
 99                    let selected_provider = LanguageModelRegistry::read_global(cx)
100                        .active_provider()
101                        .map(|m| m.id());
102                    let selected_model = LanguageModelRegistry::read_global(cx)
103                        .active_model()
104                        .map(|m| m.id());
105
106                    for available_model in available_models {
107                        menu = menu.custom_entry(
108                            {
109                                let id = available_model.id();
110                                let provider_id = available_model.provider_id();
111                                let model_name = available_model.name().0.clone();
112                                let selected_model = selected_model.clone();
113                                let selected_provider = selected_provider.clone();
114                                move |_| {
115                                    h_flex()
116                                        .w_full()
117                                        .justify_between()
118                                        .child(Label::new(model_name.clone()))
119                                        .when(
120                                            selected_model.as_ref() == Some(&id)
121                                                && selected_provider.as_ref() == Some(&provider_id),
122                                            |this| this.child(Icon::new(IconName::Check)),
123                                        )
124                                        .into_any()
125                                }
126                            },
127                            {
128                                let fs = self.fs.clone();
129                                let model = available_model.clone();
130                                move |cx| {
131                                    let model = model.clone();
132                                    update_settings_file::<AssistantSettings>(
133                                        fs.clone(),
134                                        cx,
135                                        move |settings, _| settings.set_model(model),
136                                    );
137                                }
138                            },
139                        );
140                    }
141                }
142                menu
143            })
144            .into()
145        })
146        .trigger(self.trigger)
147        .attach(gpui::AnchorCorner::BottomLeft)
148    }
149}