assistant.rs

// On Windows builds, allow `unused` and `dead_code` lints across this crate.
// NOTE(review): presumably because some items are not reached on Windows — confirm.
#![cfg_attr(target_os = "windows", allow(unused, dead_code))]

pub mod assistant_panel;
// Items from this module are re-exported crate-wide via the
// `pub(crate) use crate::inline_assistant::*;` import below.
mod inline_assistant;
pub mod slash_command_settings;
mod terminal_inline_assistant;
  7
  8use std::sync::Arc;
  9
 10use assistant_settings::AssistantSettings;
 11use assistant_slash_command::SlashCommandRegistry;
 12use assistant_slash_commands::{ProjectSlashCommandFeatureFlag, SearchSlashCommandFeatureFlag};
 13use client::Client;
 14use command_palette_hooks::CommandPaletteFilter;
 15use feature_flags::FeatureFlagAppExt;
 16use fs::Fs;
 17use gpui::{actions, AppContext, Global, UpdateGlobal};
 18use language_model::{
 19    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 20};
 21use prompt_library::{PromptBuilder, PromptLoadingParams};
 22use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 23use serde::Deserialize;
 24use settings::{Settings, SettingsStore};
 25use util::ResultExt;
 26
 27pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
 28pub(crate) use crate::inline_assistant::*;
 29use crate::slash_command_settings::SlashCommandSettings;
 30
// Actions in the `assistant` namespace. `Assistant::set_enabled` hides or shows
// the `assistant` namespace in the command palette depending on whether the
// Assistant is enabled in `AssistantSettings`.
actions!(
    assistant,
    [
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        NewContext,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
 42
/// Default number of context lines.
///
/// NOTE(review): not referenced in this file; presumably consumed by a child
/// module (e.g. the inline assistant) — confirm its exact meaning there.
const DEFAULT_CONTEXT_LINES: usize = 50;
 44
 45#[derive(Deserialize, Debug)]
 46pub struct LanguageModelUsage {
 47    pub prompt_tokens: u32,
 48    pub completion_tokens: u32,
 49    pub total_tokens: u32,
 50}
 51
 52#[derive(Deserialize, Debug)]
 53pub struct LanguageModelChoiceDelta {
 54    pub index: u32,
 55    pub delta: LanguageModelResponseMessage,
 56    pub finish_reason: Option<String>,
 57}
 58
 59/// The state pertaining to the Assistant.
 60#[derive(Default)]
 61struct Assistant {
 62    /// Whether the Assistant is enabled.
 63    enabled: bool,
 64}
 65
// Allows `Assistant` to be stored as a gpui global, so it can be installed with
// `cx.set_global` and mutated through `Assistant::update_global` in `init`.
impl Global for Assistant {}
 67
 68impl Assistant {
 69    const NAMESPACE: &'static str = "assistant";
 70
 71    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
 72        if self.enabled == enabled {
 73            return;
 74        }
 75
 76        self.enabled = enabled;
 77
 78        if !enabled {
 79            CommandPaletteFilter::update_global(cx, |filter, _cx| {
 80                filter.hide_namespace(Self::NAMESPACE);
 81            });
 82
 83            return;
 84        }
 85
 86        CommandPaletteFilter::update_global(cx, |filter, _cx| {
 87            filter.show_namespace(Self::NAMESPACE);
 88        });
 89    }
 90}
 91
 92pub fn init(
 93    fs: Arc<dyn Fs>,
 94    client: Arc<Client>,
 95    stdout_is_a_pty: bool,
 96    cx: &mut AppContext,
 97) -> Arc<PromptBuilder> {
 98    cx.set_global(Assistant::default());
 99    AssistantSettings::register(cx);
100    SlashCommandSettings::register(cx);
101
102    cx.spawn(|mut cx| {
103        let client = client.clone();
104        async move {
105            let is_search_slash_command_enabled = cx
106                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
107                .await;
108            let is_project_slash_command_enabled = cx
109                .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
110                .await;
111
112            if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
113                return Ok(());
114            }
115
116            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
117            let semantic_index = SemanticDb::new(
118                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
119                Arc::new(embedding_provider),
120                &mut cx,
121            )
122            .await?;
123
124            cx.update(|cx| cx.set_global(semantic_index))
125        }
126    })
127    .detach();
128
129    assistant_context_editor::init(client.clone(), cx);
130    prompt_library::init(cx);
131    init_language_model_settings(cx);
132    assistant_slash_command::init(cx);
133    assistant_tool::init(cx);
134    assistant_panel::init(cx);
135    context_server::init(cx);
136
137    let prompt_builder = PromptBuilder::new(Some(PromptLoadingParams {
138        fs: fs.clone(),
139        repo_path: stdout_is_a_pty
140            .then(|| std::env::current_dir().log_err())
141            .flatten(),
142        cx,
143    }))
144    .log_err()
145    .map(Arc::new)
146    .unwrap_or_else(|| Arc::new(PromptBuilder::new(None).unwrap()));
147    register_slash_commands(Some(prompt_builder.clone()), cx);
148    inline_assistant::init(
149        fs.clone(),
150        prompt_builder.clone(),
151        client.telemetry().clone(),
152        cx,
153    );
154    terminal_inline_assistant::init(
155        fs.clone(),
156        prompt_builder.clone(),
157        client.telemetry().clone(),
158        cx,
159    );
160    indexed_docs::init(cx);
161
162    CommandPaletteFilter::update_global(cx, |filter, _cx| {
163        filter.hide_namespace(Assistant::NAMESPACE);
164    });
165    Assistant::update_global(cx, |assistant, cx| {
166        let settings = AssistantSettings::get_global(cx);
167
168        assistant.set_enabled(settings.enabled, cx);
169    });
170    cx.observe_global::<SettingsStore>(|cx| {
171        Assistant::update_global(cx, |assistant, cx| {
172            let settings = AssistantSettings::get_global(cx);
173            assistant.set_enabled(settings.enabled, cx);
174        });
175    })
176    .detach();
177
178    prompt_builder
179}
180
181fn init_language_model_settings(cx: &mut AppContext) {
182    update_active_language_model_from_settings(cx);
183
184    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
185        .detach();
186    cx.subscribe(
187        &LanguageModelRegistry::global(cx),
188        |_, event: &language_model::Event, cx| match event {
189            language_model::Event::ProviderStateChanged
190            | language_model::Event::AddedProvider(_)
191            | language_model::Event::RemovedProvider(_) => {
192                update_active_language_model_from_settings(cx);
193            }
194            _ => {}
195        },
196    )
197    .detach();
198}
199
200fn update_active_language_model_from_settings(cx: &mut AppContext) {
201    let settings = AssistantSettings::get_global(cx);
202    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
203    let model_id = LanguageModelId::from(settings.default_model.model.clone());
204    let inline_alternatives = settings
205        .inline_alternatives
206        .iter()
207        .map(|alternative| {
208            (
209                LanguageModelProviderId::from(alternative.provider.clone()),
210                LanguageModelId::from(alternative.model.clone()),
211            )
212        })
213        .collect::<Vec<_>>();
214    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
215        registry.select_active_model(&provider_name, &model_id, cx);
216        registry.select_inline_alternative_models(inline_alternatives, cx);
217    });
218}
219
220fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
221    let slash_command_registry = SlashCommandRegistry::global(cx);
222
223    slash_command_registry.register_command(assistant_slash_commands::FileSlashCommand, true);
224    slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true);
225    slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true);
226    slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true);
227    slash_command_registry
228        .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
229    slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true);
230    slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true);
231    slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false);
232    slash_command_registry.register_command(assistant_slash_commands::TerminalSlashCommand, true);
233    slash_command_registry.register_command(assistant_slash_commands::NowSlashCommand, false);
234    slash_command_registry
235        .register_command(assistant_slash_commands::DiagnosticsSlashCommand, true);
236    slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true);
237
238    if let Some(prompt_builder) = prompt_builder {
239        cx.observe_flag::<assistant_slash_commands::ProjectSlashCommandFeatureFlag, _>({
240            let slash_command_registry = slash_command_registry.clone();
241            move |is_enabled, _cx| {
242                if is_enabled {
243                    slash_command_registry.register_command(
244                        assistant_slash_commands::ProjectSlashCommand::new(prompt_builder.clone()),
245                        true,
246                    );
247                }
248            }
249        })
250        .detach();
251    }
252
253    cx.observe_flag::<assistant_slash_commands::AutoSlashCommandFeatureFlag, _>({
254        let slash_command_registry = slash_command_registry.clone();
255        move |is_enabled, _cx| {
256            if is_enabled {
257                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
258                slash_command_registry
259                    .register_command(assistant_slash_commands::AutoCommand, true);
260            }
261        }
262    })
263    .detach();
264
265    cx.observe_flag::<assistant_slash_commands::StreamingExampleSlashCommandFeatureFlag, _>({
266        let slash_command_registry = slash_command_registry.clone();
267        move |is_enabled, _cx| {
268            if is_enabled {
269                slash_command_registry.register_command(
270                    assistant_slash_commands::StreamingExampleSlashCommand,
271                    false,
272                );
273            }
274        }
275    })
276    .detach();
277
278    update_slash_commands_from_settings(cx);
279    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
280        .detach();
281
282    cx.observe_flag::<assistant_slash_commands::SearchSlashCommandFeatureFlag, _>({
283        let slash_command_registry = slash_command_registry.clone();
284        move |is_enabled, _cx| {
285            if is_enabled {
286                slash_command_registry
287                    .register_command(assistant_slash_commands::SearchSlashCommand, true);
288            }
289        }
290    })
291    .detach();
292}
293
294fn update_slash_commands_from_settings(cx: &mut AppContext) {
295    let slash_command_registry = SlashCommandRegistry::global(cx);
296    let settings = SlashCommandSettings::get_global(cx);
297
298    if settings.docs.enabled {
299        slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
300    } else {
301        slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
302    }
303
304    if settings.cargo_workspace.enabled {
305        slash_command_registry
306            .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
307    } else {
308        slash_command_registry
309            .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
310    }
311}
312
/// Test-only hook, run via `ctor::ctor`, that initializes `env_logger` when the
/// `RUST_LOG` environment variable is set (and valid Unicode).
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}