agent_model_selector.rs

  1use agent_settings::AgentSettings;
  2use fs::Fs;
  3use gpui::{Entity, FocusHandle, SharedString};
  4
  5use crate::Thread;
  6use assistant_context_editor::language_model_selector::{
  7    LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
  8};
  9use language_model::{ConfiguredModel, LanguageModelRegistry};
 10use settings::update_settings_file;
 11use std::sync::Arc;
 12use ui::{PopoverMenuHandle, Tooltip, prelude::*};
 13
/// Which model slot an [`AgentModelSelector`] reads from and writes to.
#[derive(Clone)]
pub enum ModelType {
    /// The model configured on the given thread. A selection updates the
    /// thread's configured model and is persisted via `settings.set_model`.
    Default(Entity<Thread>),
    /// The registry's inline-assistant model. A selection is persisted via
    /// `settings.set_inline_assistant_model` (provider id + model id strings).
    InlineAssistant,
}
 19
/// A button that shows the active language model's name and opens a
/// [`LanguageModelSelector`] popover to change it.
pub struct AgentModelSelector {
    /// The picker entity rendered inside the popover.
    selector: Entity<LanguageModelSelector>,
    /// Handle attached to the popover; `toggle` opens/closes it through this.
    menu_handle: PopoverMenuHandle<LanguageModelSelector>,
    /// Passed to the "Change Model" tooltip so the `ToggleModelSelector`
    /// binding is looked up in this focus context.
    focus_handle: FocusHandle,
}
 25
 26impl AgentModelSelector {
 27    pub(crate) fn new(
 28        fs: Arc<dyn Fs>,
 29        menu_handle: PopoverMenuHandle<LanguageModelSelector>,
 30        focus_handle: FocusHandle,
 31        model_type: ModelType,
 32        window: &mut Window,
 33        cx: &mut Context<Self>,
 34    ) -> Self {
 35        Self {
 36            selector: cx.new(move |cx| {
 37                let fs = fs.clone();
 38                LanguageModelSelector::new(
 39                    {
 40                        let model_type = model_type.clone();
 41                        move |cx| match &model_type {
 42                            ModelType::Default(thread) => thread.read(cx).configured_model(),
 43                            ModelType::InlineAssistant => {
 44                                LanguageModelRegistry::read_global(cx).inline_assistant_model()
 45                            }
 46                        }
 47                    },
 48                    move |model, cx| {
 49                        let provider = model.provider_id().0.to_string();
 50                        let model_id = model.id().0.to_string();
 51                        match &model_type {
 52                            ModelType::Default(thread) => {
 53                                thread.update(cx, |thread, cx| {
 54                                    let registry = LanguageModelRegistry::read_global(cx);
 55                                    if let Some(provider) = registry.provider(&model.provider_id())
 56                                    {
 57                                        thread.set_configured_model(
 58                                            Some(ConfiguredModel {
 59                                                provider,
 60                                                model: model.clone(),
 61                                            }),
 62                                            cx,
 63                                        );
 64                                    }
 65                                });
 66                                update_settings_file::<AgentSettings>(
 67                                    fs.clone(),
 68                                    cx,
 69                                    move |settings, _cx| {
 70                                        settings.set_model(model.clone());
 71                                    },
 72                                );
 73                            }
 74                            ModelType::InlineAssistant => {
 75                                update_settings_file::<AgentSettings>(
 76                                    fs.clone(),
 77                                    cx,
 78                                    move |settings, _cx| {
 79                                        settings.set_inline_assistant_model(
 80                                            provider.clone(),
 81                                            model_id.clone(),
 82                                        );
 83                                    },
 84                                );
 85                            }
 86                        }
 87                    },
 88                    window,
 89                    cx,
 90                )
 91            }),
 92            menu_handle,
 93            focus_handle,
 94        }
 95    }
 96
 97    pub fn toggle(&self, window: &mut Window, cx: &mut Context<Self>) {
 98        self.menu_handle.toggle(window, cx);
 99    }
100}
101
102impl Render for AgentModelSelector {
103    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
104        let focus_handle = self.focus_handle.clone();
105
106        let model = self.selector.read(cx).active_model(cx);
107        let model_name = model
108            .map(|model| model.model.name().0)
109            .unwrap_or_else(|| SharedString::from("No model selected"));
110
111        LanguageModelSelectorPopoverMenu::new(
112            self.selector.clone(),
113            Button::new("active-model", model_name)
114                .label_size(LabelSize::Small)
115                .color(Color::Muted)
116                .icon(IconName::ChevronDown)
117                .icon_size(IconSize::XSmall)
118                .icon_position(IconPosition::End)
119                .icon_color(Color::Muted),
120            move |window, cx| {
121                Tooltip::for_action_in(
122                    "Change Model",
123                    &ToggleModelSelector,
124                    &focus_handle,
125                    window,
126                    cx,
127                )
128            },
129            gpui::Corner::BottomRight,
130        )
131        .with_handle(self.menu_handle.clone())
132    }
133}