agent: Introduce `ModelUsageContext` (#32076)

Bennet Bo Fenner created

This PR is a refactor of the existing `ModelType` in
`agent_model_selector`.

In #31848 we also need to know which context we are operating in, to
check if the configured model has image support.
In order to deduplicate the logic needed, I introduced a new type called
`ModelUsageContext` which can be used throughout the agent crate


Release Notes:

- N/A

Change summary

crates/agent/src/agent.rs                | 23 +++++++++++++++++++++--
crates/agent/src/agent_model_selector.rs | 25 +++++++------------------
crates/agent/src/inline_prompt_editor.rs |  8 ++++----
crates/agent/src/message_editor.rs       |  8 ++++----
4 files changed, 36 insertions(+), 28 deletions(-)

Detailed changes

crates/agent/src/agent.rs 🔗

@@ -33,9 +33,11 @@ use assistant_slash_command::SlashCommandRegistry;
 use client::Client;
 use feature_flags::FeatureFlagAppExt as _;
 use fs::Fs;
-use gpui::{App, actions, impl_actions};
+use gpui::{App, Entity, actions, impl_actions};
 use language::LanguageRegistry;
-use language_model::{LanguageModelId, LanguageModelProviderId, LanguageModelRegistry};
+use language_model::{
+    ConfiguredModel, LanguageModelId, LanguageModelProviderId, LanguageModelRegistry,
+};
 use prompt_store::PromptBuilder;
 use schemars::JsonSchema;
 use serde::Deserialize;
@@ -115,6 +117,23 @@ impl ManageProfiles {
 
 impl_actions!(agent, [NewThread, ManageProfiles]);
 
+#[derive(Clone)]
+pub(crate) enum ModelUsageContext {
+    Thread(Entity<Thread>),
+    InlineAssistant,
+}
+
+impl ModelUsageContext {
+    pub fn configured_model(&self, cx: &App) -> Option<ConfiguredModel> {
+        match self {
+            Self::Thread(thread) => thread.read(cx).configured_model(),
+            Self::InlineAssistant => {
+                LanguageModelRegistry::read_global(cx).inline_assistant_model()
+            }
+        }
+    }
+}
+
 /// Initializes the `agent` crate.
 pub fn init(
     fs: Arc<dyn Fs>,

crates/agent/src/agent_model_selector.rs 🔗

@@ -3,7 +3,7 @@ use fs::Fs;
 use gpui::{Entity, FocusHandle, SharedString};
 use picker::popover_menu::PickerPopoverMenu;
 
-use crate::Thread;
+use crate::ModelUsageContext;
 use assistant_context_editor::language_model_selector::{
     LanguageModelSelector, ToggleModelSelector, language_model_selector,
 };
@@ -12,12 +12,6 @@ use settings::update_settings_file;
 use std::sync::Arc;
 use ui::{PopoverMenuHandle, Tooltip, prelude::*};
 
-#[derive(Clone)]
-pub enum ModelType {
-    Default(Entity<Thread>),
-    InlineAssistant,
-}
-
 pub struct AgentModelSelector {
     selector: Entity<LanguageModelSelector>,
     menu_handle: PopoverMenuHandle<LanguageModelSelector>,
@@ -29,7 +23,7 @@ impl AgentModelSelector {
         fs: Arc<dyn Fs>,
         menu_handle: PopoverMenuHandle<LanguageModelSelector>,
         focus_handle: FocusHandle,
-        model_type: ModelType,
+        model_usage_context: ModelUsageContext,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -38,19 +32,14 @@ impl AgentModelSelector {
                 let fs = fs.clone();
                 language_model_selector(
                     {
-                        let model_type = model_type.clone();
-                        move |cx| match &model_type {
-                            ModelType::Default(thread) => thread.read(cx).configured_model(),
-                            ModelType::InlineAssistant => {
-                                LanguageModelRegistry::read_global(cx).inline_assistant_model()
-                            }
-                        }
+                        let model_context = model_usage_context.clone();
+                        move |cx| model_context.configured_model(cx)
                     },
                     move |model, cx| {
                         let provider = model.provider_id().0.to_string();
                         let model_id = model.id().0.to_string();
-                        match &model_type {
-                            ModelType::Default(thread) => {
+                        match &model_usage_context {
+                            ModelUsageContext::Thread(thread) => {
                                 thread.update(cx, |thread, cx| {
                                     let registry = LanguageModelRegistry::read_global(cx);
                                     if let Some(provider) = registry.provider(&model.provider_id())
@@ -72,7 +61,7 @@ impl AgentModelSelector {
                                     },
                                 );
                             }
-                            ModelType::InlineAssistant => {
+                            ModelUsageContext::InlineAssistant => {
                                 update_settings_file::<AgentSettings>(
                                     fs.clone(),
                                     cx,

crates/agent/src/inline_prompt_editor.rs 🔗

@@ -1,4 +1,4 @@
-use crate::agent_model_selector::{AgentModelSelector, ModelType};
+use crate::agent_model_selector::AgentModelSelector;
 use crate::buffer_codegen::BufferCodegen;
 use crate::context::ContextCreasesAddon;
 use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
@@ -7,7 +7,7 @@ use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
 use crate::message_editor::{extract_message_creases, insert_message_creases};
 use crate::terminal_codegen::TerminalCodegen;
 use crate::thread_store::{TextThreadStore, ThreadStore};
-use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist};
+use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext};
 use crate::{RemoveAllContext, ToggleContextPicker};
 use assistant_context_editor::language_model_selector::ToggleModelSelector;
 use client::ErrorExt;
@@ -930,7 +930,7 @@ impl PromptEditor<BufferCodegen> {
                     fs,
                     model_selector_menu_handle,
                     prompt_editor.focus_handle(cx),
-                    ModelType::InlineAssistant,
+                    ModelUsageContext::InlineAssistant,
                     window,
                     cx,
                 )
@@ -1101,7 +1101,7 @@ impl PromptEditor<TerminalCodegen> {
                     fs,
                     model_selector_menu_handle.clone(),
                     prompt_editor.focus_handle(cx),
-                    ModelType::InlineAssistant,
+                    ModelUsageContext::InlineAssistant,
                     window,
                     cx,
                 )

crates/agent/src/message_editor.rs 🔗

@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
 use std::rc::Rc;
 use std::sync::Arc;
 
-use crate::agent_model_selector::{AgentModelSelector, ModelType};
+use crate::agent_model_selector::AgentModelSelector;
 use crate::context::{AgentContextKey, ContextCreasesAddon, ContextLoadResult, load_context};
 use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
 use crate::ui::{
@@ -52,8 +52,8 @@ use crate::thread::{MessageCrease, Thread, TokenUsageRatio};
 use crate::thread_store::{TextThreadStore, ThreadStore};
 use crate::{
     ActiveThread, AgentDiffPane, Chat, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll,
-    NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode, ToggleContextPicker,
-    ToggleProfileSelector, register_agent_preview,
+    ModelUsageContext, NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode,
+    ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
 };
 
 #[derive(RegisterComponent)]
@@ -197,7 +197,7 @@ impl MessageEditor {
                 fs.clone(),
                 model_selector_menu_handle,
                 editor.focus_handle(cx),
-                ModelType::Default(thread.clone()),
+                ModelUsageContext::Thread(thread.clone()),
                 window,
                 cx,
             )