agent_settings.rs

  1mod agent_profile;
  2
  3use std::sync::Arc;
  4
  5use collections::IndexMap;
  6use gpui::{App, Pixels, px};
  7use language_model::LanguageModel;
  8use project::DisableAiSettings;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use settings::{
 12    DefaultAgentView, DockPosition, LanguageModelParameters, LanguageModelSelection,
 13    NotifyWhenAgentWaiting, RegisterSetting, Settings,
 14};
 15
 16pub use crate::agent_profile::*;
 17
/// Prompt text embedded at compile time (via `include_str!`, path relative to this
/// source file) from `prompts/summarize_thread_prompt.txt`.
///
/// NOTE(review): going by the name, presumably sent to a model to summarize an agent
/// thread — confirm at the call sites.
pub const SUMMARIZE_THREAD_PROMPT: &str = include_str!("prompts/summarize_thread_prompt.txt");
/// Prompt text embedded at compile time (via `include_str!`, path relative to this
/// source file) from `prompts/summarize_thread_detailed_prompt.txt`.
///
/// NOTE(review): presumably the more detailed counterpart of `SUMMARIZE_THREAD_PROMPT` —
/// confirm at the call sites.
pub const SUMMARIZE_THREAD_DETAILED_PROMPT: &str =
    include_str!("prompts/summarize_thread_detailed_prompt.txt");
 21
/// Resolved agent settings.
///
/// Built by `Settings::from_settings` from the `agent` section of
/// `settings::SettingsContent`.
#[derive(Clone, Debug, RegisterSetting)]
pub struct AgentSettings {
    /// Raw `enabled` value from settings. Prefer the `enabled(cx)` method, which also
    /// reports `false` when `DisableAiSettings::disable_ai` is set.
    pub enabled: bool,
    pub button: bool,
    pub dock: DockPosition,
    /// Converted from the numeric settings value with `px`.
    pub default_width: Pixels,
    /// Converted from the numeric settings value with `px`.
    pub default_height: Pixels,
    /// Always `Some` when built by `from_settings`, which unwraps the settings value
    /// before wrapping it again.
    pub default_model: Option<LanguageModelSelection>,
    pub inline_assistant_model: Option<LanguageModelSelection>,
    /// `true` when the setting is absent.
    pub inline_assistant_use_streaming_tools: bool,
    pub commit_message_model: Option<LanguageModelSelection>,
    pub thread_summary_model: Option<LanguageModelSelection>,
    /// Empty when the setting is absent.
    pub inline_alternatives: Vec<LanguageModelSelection>,
    pub default_profile: AgentProfileId,
    pub default_view: DefaultAgentView,
    /// Profiles keyed by id; `IndexMap` keeps them in the order they were inserted.
    pub profiles: IndexMap<AgentProfileId, AgentProfileSettings>,
    pub always_allow_tool_actions: bool,
    pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
    pub play_sound_when_agent_done: bool,
    pub single_file_review: bool,
    /// Per-provider / per-model parameter overrides. `temperature_for_model` scans this
    /// list from the end, so later entries take precedence over earlier ones.
    pub model_parameters: Vec<LanguageModelParameters>,
    pub preferred_completion_mode: CompletionMode,
    pub enable_feedback: bool,
    pub expand_edit_card: bool,
    pub expand_terminal_card: bool,
    pub use_modifier_to_send: bool,
    /// Minimum line count for the message editor; `set_message_editor_max_lines`
    /// derives the maximum as twice this value.
    pub message_editor_min_lines: usize,
}
 50
 51impl AgentSettings {
 52    pub fn enabled(&self, cx: &App) -> bool {
 53        self.enabled && !DisableAiSettings::get_global(cx).disable_ai
 54    }
 55
 56    pub fn temperature_for_model(model: &Arc<dyn LanguageModel>, cx: &App) -> Option<f32> {
 57        let settings = Self::get_global(cx);
 58        for setting in settings.model_parameters.iter().rev() {
 59            if let Some(provider) = &setting.provider
 60                && provider.0 != model.provider_id().0
 61            {
 62                continue;
 63            }
 64            if let Some(setting_model) = &setting.model
 65                && *setting_model != model.id().0
 66            {
 67                continue;
 68            }
 69            return setting.temperature;
 70        }
 71        return None;
 72    }
 73
 74    pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
 75        self.inline_assistant_model = Some(LanguageModelSelection {
 76            provider: provider.into(),
 77            model,
 78        });
 79    }
 80
 81    pub fn set_commit_message_model(&mut self, provider: String, model: String) {
 82        self.commit_message_model = Some(LanguageModelSelection {
 83            provider: provider.into(),
 84            model,
 85        });
 86    }
 87
 88    pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
 89        self.thread_summary_model = Some(LanguageModelSelection {
 90            provider: provider.into(),
 91            model,
 92        });
 93    }
 94
 95    pub fn set_message_editor_max_lines(&self) -> usize {
 96        self.message_editor_min_lines * 2
 97    }
 98}
 99
/// Completion mode selected in the agent settings (`preferred_completion_mode`).
///
/// Uses snake_case names in serialized form; `"max"` is also accepted for `burn`.
// Converted to `cloud_llm_client::CompletionMode` by the `From` impl below, where
// `Burn` corresponds to that enum's `Max` variant.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
    /// The default mode.
    #[default]
    Normal,
    /// Also deserializes from `"max"`.
    #[serde(alias = "max")]
    Burn,
}
108
109impl From<CompletionMode> for cloud_llm_client::CompletionMode {
110    fn from(value: CompletionMode) -> Self {
111        match value {
112            CompletionMode::Normal => cloud_llm_client::CompletionMode::Normal,
113            CompletionMode::Burn => cloud_llm_client::CompletionMode::Max,
114        }
115    }
116}
117
118impl From<settings::CompletionMode> for CompletionMode {
119    fn from(value: settings::CompletionMode) -> Self {
120        match value {
121            settings::CompletionMode::Normal => CompletionMode::Normal,
122            settings::CompletionMode::Burn => CompletionMode::Burn,
123        }
124    }
125}
126
/// Identifier of an agent profile; used as the key of `AgentSettings::profiles` and as
/// `AgentSettings::default_profile`. The `Default` value is `"write"`.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentProfileId(pub Arc<str>);
129
130impl AgentProfileId {
131    pub fn as_str(&self) -> &str {
132        &self.0
133    }
134}
135
136impl std::fmt::Display for AgentProfileId {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(f, "{}", self.0)
139    }
140}
141
142impl Default for AgentProfileId {
143    fn default() -> Self {
144        Self("write".into())
145    }
146}
147
impl Settings for AgentSettings {
    /// Builds `AgentSettings` from the `agent` section of the settings content.
    ///
    /// # Panics
    ///
    /// Panics if `content.agent` is `None`, or if any field read below with `.unwrap()`
    /// is unset. NOTE(review): this presumably relies on the bundled default settings
    /// always supplying those fields — confirm against the defaults file.
    fn from_settings(content: &settings::SettingsContent) -> Self {
        // Clone the whole section so individual fields can be moved out below.
        let agent = content.agent.clone().unwrap();
        Self {
            enabled: agent.enabled.unwrap(),
            button: agent.button.unwrap(),
            dock: agent.dock.unwrap(),
            default_width: px(agent.default_width.unwrap()),
            default_height: px(agent.default_height.unwrap()),
            // NOTE(review): unwrapped and re-wrapped, so this is never `None` here even
            // though the field is an `Option`; confirm that is intended.
            default_model: Some(agent.default_model.unwrap()),
            inline_assistant_model: agent.inline_assistant_model,
            // Unlike most fields, a missing value falls back to `true` instead of
            // panicking.
            inline_assistant_use_streaming_tools: agent
                .inline_assistant_use_streaming_tools
                .unwrap_or(true),
            commit_message_model: agent.commit_message_model,
            thread_summary_model: agent.thread_summary_model,
            // A missing list becomes an empty `Vec`.
            inline_alternatives: agent.inline_alternatives.unwrap_or_default(),
            default_profile: AgentProfileId(agent.default_profile.unwrap()),
            default_view: agent.default_view.unwrap(),
            // Wrap each key as an `AgentProfileId` and convert each value into
            // `AgentProfileSettings`; `IndexMap` keeps the iteration order of the source.
            profiles: agent
                .profiles
                .unwrap()
                .into_iter()
                .map(|(key, val)| (AgentProfileId(key), val.into()))
                .collect(),
            always_allow_tool_actions: agent.always_allow_tool_actions.unwrap(),
            notify_when_agent_waiting: agent.notify_when_agent_waiting.unwrap(),
            play_sound_when_agent_done: agent.play_sound_when_agent_done.unwrap(),
            single_file_review: agent.single_file_review.unwrap(),
            model_parameters: agent.model_parameters,
            preferred_completion_mode: agent.preferred_completion_mode.unwrap().into(),
            enable_feedback: agent.enable_feedback.unwrap(),
            expand_edit_card: agent.expand_edit_card.unwrap(),
            expand_terminal_card: agent.expand_terminal_card.unwrap(),
            use_modifier_to_send: agent.use_modifier_to_send.unwrap(),
            message_editor_min_lines: agent.message_editor_min_lines.unwrap(),
        }
    }
}