model_selector.rs

  1use std::sync::Arc;
  2
  3use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
  4use fs::Fs;
  5use gpui::SharedString;
  6use language_model::LanguageModelRegistry;
  7use settings::update_settings_file;
  8use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
  9
 10#[derive(IntoElement)]
 11pub struct ModelSelector<T: PopoverTrigger> {
 12    handle: Option<PopoverMenuHandle<ContextMenu>>,
 13    fs: Arc<dyn Fs>,
 14    trigger: T,
 15    info_text: Option<SharedString>,
 16}
 17
 18impl<T: PopoverTrigger> ModelSelector<T> {
 19    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 20        ModelSelector {
 21            handle: None,
 22            fs,
 23            trigger,
 24            info_text: None,
 25        }
 26    }
 27
 28    pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
 29        self.handle = Some(handle);
 30        self
 31    }
 32
 33    pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
 34        self.info_text = Some(text.into());
 35        self
 36    }
 37}
 38
 39impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
 40    fn render(self, _: &mut WindowContext) -> impl IntoElement {
 41        let mut menu = PopoverMenu::new("model-switcher");
 42        if let Some(handle) = self.handle {
 43            menu = menu.with_handle(handle);
 44        }
 45
 46        let info_text = self.info_text.clone();
 47
 48        menu.menu(move |cx| {
 49            ContextMenu::build(cx, |mut menu, cx| {
 50                if let Some(info_text) = info_text.clone() {
 51                    menu = menu
 52                        .custom_row(move |_cx| {
 53                            Label::new(info_text.clone())
 54                                .color(Color::Muted)
 55                                .into_any_element()
 56                        })
 57                        .separator();
 58                }
 59
 60                for (index, provider) in LanguageModelRegistry::global(cx)
 61                    .read(cx)
 62                    .providers()
 63                    .enumerate()
 64                {
 65                    if index > 0 {
 66                        menu = menu.separator();
 67                    }
 68                    menu = menu.header(provider.name().0);
 69
 70                    let available_models = provider.provided_models(cx);
 71                    if available_models.is_empty() {
 72                        menu = menu.custom_entry(
 73                            {
 74                                move |_| {
 75                                    h_flex()
 76                                        .w_full()
 77                                        .gap_1()
 78                                        .child(Icon::new(IconName::Settings))
 79                                        .child(Label::new("Configure"))
 80                                        .into_any()
 81                                }
 82                            },
 83                            {
 84                                let provider = provider.id();
 85                                move |cx| {
 86                                    LanguageModelCompletionProvider::global(cx).update(
 87                                        cx,
 88                                        |completion_provider, cx| {
 89                                            completion_provider
 90                                                .set_active_provider(provider.clone(), cx)
 91                                        },
 92                                    );
 93                                }
 94                            },
 95                        );
 96                    }
 97
 98                    let selected_model = LanguageModelCompletionProvider::read_global(cx)
 99                        .active_model()
100                        .map(|m| m.id());
101                    let selected_provider = LanguageModelCompletionProvider::read_global(cx)
102                        .active_provider()
103                        .map(|m| m.id());
104
105                    for available_model in available_models {
106                        menu = menu.custom_entry(
107                            {
108                                let id = available_model.id();
109                                let provider_id = available_model.provider_id();
110                                let model_name = available_model.name().0.clone();
111                                let selected_model = selected_model.clone();
112                                let selected_provider = selected_provider.clone();
113                                move |_| {
114                                    h_flex()
115                                        .w_full()
116                                        .justify_between()
117                                        .child(Label::new(model_name.clone()))
118                                        .when(
119                                            selected_model.as_ref() == Some(&id)
120                                                && selected_provider.as_ref() == Some(&provider_id),
121                                            |this| this.child(Icon::new(IconName::Check)),
122                                        )
123                                        .into_any()
124                                }
125                            },
126                            {
127                                let fs = self.fs.clone();
128                                let model = available_model.clone();
129                                move |cx| {
130                                    let model = model.clone();
131                                    update_settings_file::<AssistantSettings>(
132                                        fs.clone(),
133                                        cx,
134                                        move |settings, _| settings.set_model(model),
135                                    );
136                                }
137                            },
138                        );
139                    }
140                }
141                menu
142            })
143            .into()
144        })
145        .trigger(self.trigger)
146        .attach(gpui::AnchorCorner::BottomLeft)
147    }
148}