assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3mod assistant_configuration;
  4pub mod assistant_panel;
  5mod inline_assistant;
  6pub mod slash_command_settings;
  7mod terminal_inline_assistant;
  8
  9use std::sync::Arc;
 10
 11use assistant_settings::AssistantSettings;
 12use assistant_slash_command::SlashCommandRegistry;
 13use client::Client;
 14use command_palette_hooks::CommandPaletteFilter;
 15use feature_flags::FeatureFlagAppExt;
 16use fs::Fs;
 17use gpui::{App, Global, ReadGlobal, UpdateGlobal, actions};
 18use language_model::{
 19    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 20};
 21use prompt_store::PromptBuilder;
 22use serde::Deserialize;
 23use settings::{Settings, SettingsStore};
 24
 25pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
 26pub(crate) use crate::inline_assistant::*;
 27use crate::slash_command_settings::SlashCommandSettings;
 28
// Actions exposed by this crate, declared through gpui's `actions!` macro with
// `assistant` as the first argument.
// NOTE(review): that first argument presumably corresponds to
// `Assistant::NAMESPACE` ("assistant"), the namespace this file shows/hides in
// the command palette filter — confirm against the macro's documentation.
actions!(
    assistant,
    [
        InsertActivePrompt,
        DeployHistory,
        NewChat,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
 39
/// Default number of context lines.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// read by one of the child modules declared above — confirm where and what
/// "context" means there.
const DEFAULT_CONTEXT_LINES: usize = 50;
 41
/// Token counts that can be deserialized (via serde) from a payload with
/// `prompt_tokens`, `completion_tokens` and `total_tokens` fields.
///
/// NOTE(review): presumably the usage section of a language model response —
/// confirm against the code that deserializes it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Token count attributed to the prompt (meaning inferred from the name).
    pub prompt_tokens: u32,
    /// Token count attributed to the completion (meaning inferred from the name).
    pub completion_tokens: u32,
    /// Total token count; presumably prompt + completion — TODO confirm.
    pub total_tokens: u32,
}
 48
/// A deserializable record pairing a choice index with a
/// [`LanguageModelResponseMessage`] delta and an optional finish reason.
///
/// NOTE(review): presumably one incremental chunk of a streamed language model
/// response — confirm against the code that deserializes it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to (meaning inferred from the name).
    pub index: u32,
    /// The message content carried by this delta.
    pub delta: LanguageModelResponseMessage,
    /// Reason the choice finished, when one is present; `None` otherwise.
    pub finish_reason: Option<String>,
}
 55
/// The state pertaining to the Assistant.
///
/// Stored as a gpui global (see the `Global` impl below). The derived
/// `Default` yields `enabled == false`, so the Assistant starts out disabled
/// until `set_enabled` is called with `true`.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Marker impl that lets `Assistant` be stored in and read from the app's globals.
impl Global for Assistant {}
 64
 65impl Assistant {
 66    const NAMESPACE: &'static str = "assistant";
 67
 68    fn set_enabled(&mut self, enabled: bool, cx: &mut App) {
 69        if self.enabled == enabled {
 70            return;
 71        }
 72
 73        self.enabled = enabled;
 74
 75        if !enabled {
 76            CommandPaletteFilter::update_global(cx, |filter, _cx| {
 77                filter.hide_namespace(Self::NAMESPACE);
 78            });
 79
 80            return;
 81        }
 82
 83        CommandPaletteFilter::update_global(cx, |filter, _cx| {
 84            filter.show_namespace(Self::NAMESPACE);
 85        });
 86    }
 87
 88    pub fn enabled(cx: &App) -> bool {
 89        Self::global(cx).enabled
 90    }
 91}
 92
 93pub fn init(
 94    fs: Arc<dyn Fs>,
 95    client: Arc<Client>,
 96    prompt_builder: Arc<PromptBuilder>,
 97    cx: &mut App,
 98) {
 99    cx.set_global(Assistant::default());
100    AssistantSettings::register(cx);
101    SlashCommandSettings::register(cx);
102
103    assistant_context_editor::init(client.clone(), cx);
104    prompt_library::init(cx);
105    init_language_model_settings(cx);
106    assistant_slash_command::init(cx);
107    assistant_tool::init(cx);
108    assistant_panel::init(cx);
109    context_server::init(cx);
110
111    register_slash_commands(cx);
112    inline_assistant::init(
113        fs.clone(),
114        prompt_builder.clone(),
115        client.telemetry().clone(),
116        cx,
117    );
118    terminal_inline_assistant::init(
119        fs.clone(),
120        prompt_builder.clone(),
121        client.telemetry().clone(),
122        cx,
123    );
124    indexed_docs::init(cx);
125
126    CommandPaletteFilter::update_global(cx, |filter, _cx| {
127        filter.hide_namespace(Assistant::NAMESPACE);
128    });
129    Assistant::update_global(cx, |assistant, cx| {
130        let settings = AssistantSettings::get_global(cx);
131
132        assistant.set_enabled(settings.enabled, cx);
133    });
134    cx.observe_global::<SettingsStore>(|cx| {
135        Assistant::update_global(cx, |assistant, cx| {
136            let settings = AssistantSettings::get_global(cx);
137            assistant.set_enabled(settings.enabled, cx);
138        });
139    })
140    .detach();
141}
142
143fn init_language_model_settings(cx: &mut App) {
144    update_active_language_model_from_settings(cx);
145
146    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
147        .detach();
148    cx.subscribe(
149        &LanguageModelRegistry::global(cx),
150        |_, event: &language_model::Event, cx| match event {
151            language_model::Event::ProviderStateChanged
152            | language_model::Event::AddedProvider(_)
153            | language_model::Event::RemovedProvider(_) => {
154                update_active_language_model_from_settings(cx);
155            }
156            _ => {}
157        },
158    )
159    .detach();
160}
161
162fn update_active_language_model_from_settings(cx: &mut App) {
163    let settings = AssistantSettings::get_global(cx);
164    // Default model - used as fallback
165    let active_model_provider_name =
166        LanguageModelProviderId::from(settings.default_model.provider.clone());
167    let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
168
169    // Inline assistant model
170    let inline_assistant_model = settings
171        .inline_assistant_model
172        .as_ref()
173        .unwrap_or(&settings.default_model);
174    let inline_assistant_provider_name =
175        LanguageModelProviderId::from(inline_assistant_model.provider.clone());
176    let inline_assistant_model_id = LanguageModelId::from(inline_assistant_model.model.clone());
177
178    // Commit message model
179    let commit_message_model = settings
180        .commit_message_model
181        .as_ref()
182        .unwrap_or(&settings.default_model);
183    let commit_message_provider_name =
184        LanguageModelProviderId::from(commit_message_model.provider.clone());
185    let commit_message_model_id = LanguageModelId::from(commit_message_model.model.clone());
186
187    // Thread summary model
188    let thread_summary_model = settings
189        .thread_summary_model
190        .as_ref()
191        .unwrap_or(&settings.default_model);
192    let thread_summary_provider_name =
193        LanguageModelProviderId::from(thread_summary_model.provider.clone());
194    let thread_summary_model_id = LanguageModelId::from(thread_summary_model.model.clone());
195
196    let inline_alternatives = settings
197        .inline_alternatives
198        .iter()
199        .map(|alternative| {
200            (
201                LanguageModelProviderId::from(alternative.provider.clone()),
202                LanguageModelId::from(alternative.model.clone()),
203            )
204        })
205        .collect::<Vec<_>>();
206
207    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
208        // Set the default model
209        registry.select_default_model(&active_model_provider_name, &active_model_id, cx);
210
211        // Set the specific models
212        registry.select_inline_assistant_model(
213            &inline_assistant_provider_name,
214            &inline_assistant_model_id,
215            cx,
216        );
217        registry.select_commit_message_model(
218            &commit_message_provider_name,
219            &commit_message_model_id,
220            cx,
221        );
222        registry.select_thread_summary_model(
223            &thread_summary_provider_name,
224            &thread_summary_model_id,
225            cx,
226        );
227
228        // Set the alternatives
229        registry.select_inline_alternative_models(inline_alternatives, cx);
230    });
231}
232
233fn register_slash_commands(cx: &mut App) {
234    let slash_command_registry = SlashCommandRegistry::global(cx);
235
236    slash_command_registry.register_command(assistant_slash_commands::FileSlashCommand, true);
237    slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true);
238    slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true);
239    slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true);
240    slash_command_registry
241        .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
242    slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true);
243    slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true);
244    slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false);
245    slash_command_registry.register_command(assistant_slash_commands::TerminalSlashCommand, true);
246    slash_command_registry.register_command(assistant_slash_commands::NowSlashCommand, false);
247    slash_command_registry
248        .register_command(assistant_slash_commands::DiagnosticsSlashCommand, true);
249    slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true);
250
251    cx.observe_flag::<assistant_slash_commands::StreamingExampleSlashCommandFeatureFlag, _>({
252        let slash_command_registry = slash_command_registry.clone();
253        move |is_enabled, _cx| {
254            if is_enabled {
255                slash_command_registry.register_command(
256                    assistant_slash_commands::StreamingExampleSlashCommand,
257                    false,
258                );
259            }
260        }
261    })
262    .detach();
263
264    update_slash_commands_from_settings(cx);
265    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
266        .detach();
267}
268
269fn update_slash_commands_from_settings(cx: &mut App) {
270    let slash_command_registry = SlashCommandRegistry::global(cx);
271    let settings = SlashCommandSettings::get_global(cx);
272
273    if settings.docs.enabled {
274        slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
275    } else {
276        slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
277    }
278
279    if settings.cargo_workspace.enabled {
280        slash_command_registry
281            .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
282    } else {
283        slash_command_registry
284            .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
285    }
286}
287
/// Test-only constructor hook: initializes `env_logger`, but only when the
/// `RUST_LOG` environment variable can be read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log = std::env::var("RUST_LOG");
    if rust_log.is_err() {
        return;
    }
    env_logger::init();
}