model_selector.rs

  1use std::sync::Arc;
  2
  3use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
  4use fs::Fs;
  5use language_model::LanguageModelRegistry;
  6use settings::update_settings_file;
  7use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
  8
  9#[derive(IntoElement)]
 10pub struct ModelSelector<T: PopoverTrigger> {
 11    handle: Option<PopoverMenuHandle<ContextMenu>>,
 12    fs: Arc<dyn Fs>,
 13    trigger: T,
 14}
 15
 16impl<T: PopoverTrigger> ModelSelector<T> {
 17    pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
 18        ModelSelector {
 19            handle: None,
 20            fs,
 21            trigger,
 22        }
 23    }
 24
 25    pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
 26        self.handle = Some(handle);
 27        self
 28    }
 29}
 30
 31impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
 32    fn render(self, _: &mut WindowContext) -> impl IntoElement {
 33        let mut menu = PopoverMenu::new("model-switcher");
 34        if let Some(handle) = self.handle {
 35            menu = menu.with_handle(handle);
 36        }
 37
 38        menu.menu(move |cx| {
 39            ContextMenu::build(cx, |mut menu, cx| {
 40                for (index, provider) in LanguageModelRegistry::global(cx)
 41                    .read(cx)
 42                    .providers()
 43                    .enumerate()
 44                {
 45                    if index > 0 {
 46                        menu = menu.separator();
 47                    }
 48                    menu = menu.header(provider.name().0);
 49
 50                    let available_models = provider.provided_models(cx);
 51                    if available_models.is_empty() {
 52                        menu = menu.custom_entry(
 53                            {
 54                                move |_| {
 55                                    h_flex()
 56                                        .w_full()
 57                                        .gap_1()
 58                                        .child(Icon::new(IconName::Settings))
 59                                        .child(Label::new("Configure"))
 60                                        .into_any()
 61                                }
 62                            },
 63                            {
 64                                let provider = provider.id();
 65                                move |cx| {
 66                                    LanguageModelCompletionProvider::global(cx).update(
 67                                        cx,
 68                                        |completion_provider, cx| {
 69                                            completion_provider
 70                                                .set_active_provider(provider.clone(), cx)
 71                                        },
 72                                    );
 73                                }
 74                            },
 75                        );
 76                    }
 77
 78                    let selected_model = LanguageModelCompletionProvider::read_global(cx)
 79                        .active_model()
 80                        .map(|m| m.id());
 81                    let selected_provider = LanguageModelCompletionProvider::read_global(cx)
 82                        .active_provider()
 83                        .map(|m| m.id());
 84
 85                    for available_model in available_models {
 86                        menu = menu.custom_entry(
 87                            {
 88                                let id = available_model.id();
 89                                let provider_id = available_model.provider_id();
 90                                let model_name = available_model.name().0.clone();
 91                                let selected_model = selected_model.clone();
 92                                let selected_provider = selected_provider.clone();
 93                                move |_| {
 94                                    h_flex()
 95                                        .w_full()
 96                                        .justify_between()
 97                                        .child(Label::new(model_name.clone()))
 98                                        .when(
 99                                            selected_model.as_ref() == Some(&id)
100                                                && selected_provider.as_ref() == Some(&provider_id),
101                                            |this| this.child(Icon::new(IconName::Check)),
102                                        )
103                                        .into_any()
104                                }
105                            },
106                            {
107                                let fs = self.fs.clone();
108                                let model = available_model.clone();
109                                move |cx| {
110                                    let model = model.clone();
111                                    update_settings_file::<AssistantSettings>(
112                                        fs.clone(),
113                                        cx,
114                                        move |settings, _| settings.set_model(model),
115                                    );
116                                }
117                            },
118                        );
119                    }
120                }
121                menu
122            })
123            .into()
124        })
125        .trigger(self.trigger)
126        .attach(gpui::AnchorCorner::BottomLeft)
127    }
128}