agent_model_selector.rs

  1use crate::{
  2    ModelUsageContext,
  3    language_model_selector::{LanguageModelSelector, language_model_selector},
  4};
  5use agent_settings::AgentSettings;
  6use fs::Fs;
  7use gpui::{Entity, FocusHandle, SharedString};
  8use language_model::{ConfiguredModel, LanguageModelRegistry};
  9use picker::popover_menu::PickerPopoverMenu;
 10use settings::update_settings_file;
 11use std::sync::Arc;
 12use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
 13use zed_actions::agent::ToggleModelSelector;
 14
 15pub struct AgentModelSelector {
 16    selector: Entity<LanguageModelSelector>,
 17    menu_handle: PopoverMenuHandle<LanguageModelSelector>,
 18    focus_handle: FocusHandle,
 19}
 20
 21impl AgentModelSelector {
 22    pub(crate) fn new(
 23        fs: Arc<dyn Fs>,
 24        menu_handle: PopoverMenuHandle<LanguageModelSelector>,
 25        focus_handle: FocusHandle,
 26        model_usage_context: ModelUsageContext,
 27        window: &mut Window,
 28        cx: &mut Context<Self>,
 29    ) -> Self {
 30        Self {
 31            selector: cx.new(move |cx| {
 32                let fs = fs.clone();
 33                language_model_selector(
 34                    {
 35                        let model_context = model_usage_context.clone();
 36                        move |cx| model_context.configured_model(cx)
 37                    },
 38                    move |model, cx| {
 39                        let provider = model.provider_id().0.to_string();
 40                        let model_id = model.id().0.to_string();
 41
 42                        // Authenticate the provider when a model is selected
 43                        let registry = LanguageModelRegistry::read_global(cx);
 44                        if let Some(provider) = registry.provider(&model.provider_id()) {
 45                            provider.authenticate(cx).detach();
 46                        }
 47
 48                        match &model_usage_context {
 49                            ModelUsageContext::Thread(thread) => {
 50                                thread.update(cx, |thread, cx| {
 51                                    let registry = LanguageModelRegistry::read_global(cx);
 52                                    if let Some(provider) = registry.provider(&model.provider_id())
 53                                    {
 54                                        thread.set_configured_model(
 55                                            Some(ConfiguredModel {
 56                                                provider,
 57                                                model: model.clone(),
 58                                            }),
 59                                            cx,
 60                                        );
 61                                    }
 62                                });
 63                                update_settings_file::<AgentSettings>(
 64                                    fs.clone(),
 65                                    cx,
 66                                    move |settings, _cx| {
 67                                        settings.set_model(model.clone());
 68                                    },
 69                                );
 70                            }
 71                            ModelUsageContext::InlineAssistant => {
 72                                update_settings_file::<AgentSettings>(
 73                                    fs.clone(),
 74                                    cx,
 75                                    move |settings, _cx| {
 76                                        settings.set_inline_assistant_model(
 77                                            provider.clone(),
 78                                            model_id.clone(),
 79                                        );
 80                                    },
 81                                );
 82                            }
 83                        }
 84                    },
 85                    window,
 86                    cx,
 87                )
 88            }),
 89            menu_handle,
 90            focus_handle,
 91        }
 92    }
 93
 94    pub fn toggle(&self, window: &mut Window, cx: &mut Context<Self>) {
 95        self.menu_handle.toggle(window, cx);
 96    }
 97}
 98
 99impl Render for AgentModelSelector {
100    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
101        let model = self.selector.read(cx).delegate.active_model(cx);
102        let model_name = model
103            .as_ref()
104            .map(|model| model.model.name().0)
105            .unwrap_or_else(|| SharedString::from("Select a Model"));
106
107        let provider_icon = model.as_ref().map(|model| model.provider.icon());
108
109        let focus_handle = self.focus_handle.clone();
110
111        PickerPopoverMenu::new(
112            self.selector.clone(),
113            ButtonLike::new("active-model")
114                .when_some(provider_icon, |this, icon| {
115                    this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall))
116                })
117                .child(
118                    Label::new(model_name)
119                        .color(Color::Muted)
120                        .size(LabelSize::Small)
121                        .ml_0p5(),
122                )
123                .child(
124                    Icon::new(IconName::ChevronDown)
125                        .color(Color::Muted)
126                        .size(IconSize::XSmall),
127                ),
128            move |window, cx| {
129                Tooltip::for_action_in(
130                    "Change Model",
131                    &ToggleModelSelector,
132                    &focus_handle,
133                    window,
134                    cx,
135                )
136            },
137            gpui::Corner::BottomRight,
138            cx,
139        )
140        .with_handle(self.menu_handle.clone())
141        .render(window, cx)
142    }
143}