agent_model_selector.rs

  1use agent_settings::AgentSettings;
  2use fs::Fs;
  3use gpui::{Entity, FocusHandle, SharedString};
  4use picker::popover_menu::PickerPopoverMenu;
  5
  6use crate::Thread;
  7use assistant_context_editor::language_model_selector::{
  8    LanguageModelSelector, ToggleModelSelector, language_model_selector,
  9};
 10use language_model::{ConfiguredModel, LanguageModelRegistry};
 11use settings::update_settings_file;
 12use std::sync::Arc;
 13use ui::{PopoverMenuHandle, Tooltip, prelude::*};
 14
 15#[derive(Clone)]
 16pub enum ModelType {
 17    Default(Entity<Thread>),
 18    InlineAssistant,
 19}
 20
 21pub struct AgentModelSelector {
 22    selector: Entity<LanguageModelSelector>,
 23    menu_handle: PopoverMenuHandle<LanguageModelSelector>,
 24    focus_handle: FocusHandle,
 25}
 26
 27impl AgentModelSelector {
 28    pub(crate) fn new(
 29        fs: Arc<dyn Fs>,
 30        menu_handle: PopoverMenuHandle<LanguageModelSelector>,
 31        focus_handle: FocusHandle,
 32        model_type: ModelType,
 33        window: &mut Window,
 34        cx: &mut Context<Self>,
 35    ) -> Self {
 36        Self {
 37            selector: cx.new(move |cx| {
 38                let fs = fs.clone();
 39                language_model_selector(
 40                    {
 41                        let model_type = model_type.clone();
 42                        move |cx| match &model_type {
 43                            ModelType::Default(thread) => thread.read(cx).configured_model(),
 44                            ModelType::InlineAssistant => {
 45                                LanguageModelRegistry::read_global(cx).inline_assistant_model()
 46                            }
 47                        }
 48                    },
 49                    move |model, cx| {
 50                        let provider = model.provider_id().0.to_string();
 51                        let model_id = model.id().0.to_string();
 52                        match &model_type {
 53                            ModelType::Default(thread) => {
 54                                thread.update(cx, |thread, cx| {
 55                                    let registry = LanguageModelRegistry::read_global(cx);
 56                                    if let Some(provider) = registry.provider(&model.provider_id())
 57                                    {
 58                                        thread.set_configured_model(
 59                                            Some(ConfiguredModel {
 60                                                provider,
 61                                                model: model.clone(),
 62                                            }),
 63                                            cx,
 64                                        );
 65                                    }
 66                                });
 67                                update_settings_file::<AgentSettings>(
 68                                    fs.clone(),
 69                                    cx,
 70                                    move |settings, _cx| {
 71                                        settings.set_model(model.clone());
 72                                    },
 73                                );
 74                            }
 75                            ModelType::InlineAssistant => {
 76                                update_settings_file::<AgentSettings>(
 77                                    fs.clone(),
 78                                    cx,
 79                                    move |settings, _cx| {
 80                                        settings.set_inline_assistant_model(
 81                                            provider.clone(),
 82                                            model_id.clone(),
 83                                        );
 84                                    },
 85                                );
 86                            }
 87                        }
 88                    },
 89                    window,
 90                    cx,
 91                )
 92            }),
 93            menu_handle,
 94            focus_handle,
 95        }
 96    }
 97
 98    pub fn toggle(&self, window: &mut Window, cx: &mut Context<Self>) {
 99        self.menu_handle.toggle(window, cx);
100    }
101}
102
103impl Render for AgentModelSelector {
104    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
105        let focus_handle = self.focus_handle.clone();
106
107        let model = self.selector.read(cx).delegate.active_model(cx);
108        let model_name = model
109            .map(|model| model.model.name().0)
110            .unwrap_or_else(|| SharedString::from("No model selected"));
111        PickerPopoverMenu::new(
112            self.selector.clone(),
113            Button::new("active-model", model_name)
114                .label_size(LabelSize::Small)
115                .color(Color::Muted)
116                .icon(IconName::ChevronDown)
117                .icon_size(IconSize::XSmall)
118                .icon_position(IconPosition::End)
119                .icon_color(Color::Muted),
120            move |window, cx| {
121                Tooltip::for_action_in(
122                    "Change Model",
123                    &ToggleModelSelector,
124                    &focus_handle,
125                    window,
126                    cx,
127                )
128            },
129            gpui::Corner::BottomRight,
130            cx,
131        )
132        .with_handle(self.menu_handle.clone())
133        .render(window, cx)
134    }
135}