assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4mod context_history;
  5mod inline_assistant;
  6pub mod slash_command_settings;
  7mod terminal_inline_assistant;
  8
  9use std::sync::Arc;
 10
 11use assistant_settings::AssistantSettings;
 12use assistant_slash_command::SlashCommandRegistry;
 13use assistant_slash_commands::{ProjectSlashCommandFeatureFlag, SearchSlashCommandFeatureFlag};
 14use client::Client;
 15use command_palette_hooks::CommandPaletteFilter;
 16use feature_flags::FeatureFlagAppExt;
 17use fs::Fs;
 18use gpui::{actions, AppContext, Global, UpdateGlobal};
 19use language_model::{
 20    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 21};
 22use prompt_library::{PromptBuilder, PromptLoadingParams};
 23use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 24use serde::Deserialize;
 25use settings::{Settings, SettingsStore};
 26use util::ResultExt;
 27
 28pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
 29pub(crate) use crate::inline_assistant::*;
 30use crate::slash_command_settings::SlashCommandSettings;
 31
// Declares the assistant's actions via gpui's `actions!` macro; the first argument
// (`assistant`) is the namespace the listed action names are declared under.
// NOTE(review): presumably this is the same namespace that `Assistant::NAMESPACE`
// ("assistant") hides/shows in the command palette — confirm against gpui.
actions!(
    assistant,
    [
        ToggleFocus,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        NewContext,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
 44
/// Default number of context lines.
///
/// Not referenced by any code visible in this file.
/// NOTE(review): presumably consumed by a submodule of this crate — confirm before
/// changing or removing it.
const DEFAULT_CONTEXT_LINES: usize = 50;
 46
 47#[derive(Deserialize, Debug)]
 48pub struct LanguageModelUsage {
 49    pub prompt_tokens: u32,
 50    pub completion_tokens: u32,
 51    pub total_tokens: u32,
 52}
 53
 54#[derive(Deserialize, Debug)]
 55pub struct LanguageModelChoiceDelta {
 56    pub index: u32,
 57    pub delta: LanguageModelResponseMessage,
 58    pub finish_reason: Option<String>,
 59}
 60
 61/// The state pertaining to the Assistant.
 62#[derive(Default)]
 63struct Assistant {
 64    /// Whether the Assistant is enabled.
 65    enabled: bool,
 66}
 67
 68impl Global for Assistant {}
 69
 70impl Assistant {
 71    const NAMESPACE: &'static str = "assistant";
 72
 73    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
 74        if self.enabled == enabled {
 75            return;
 76        }
 77
 78        self.enabled = enabled;
 79
 80        if !enabled {
 81            CommandPaletteFilter::update_global(cx, |filter, _cx| {
 82                filter.hide_namespace(Self::NAMESPACE);
 83            });
 84
 85            return;
 86        }
 87
 88        CommandPaletteFilter::update_global(cx, |filter, _cx| {
 89            filter.show_namespace(Self::NAMESPACE);
 90        });
 91    }
 92}
 93
 94pub fn init(
 95    fs: Arc<dyn Fs>,
 96    client: Arc<Client>,
 97    stdout_is_a_pty: bool,
 98    cx: &mut AppContext,
 99) -> Arc<PromptBuilder> {
100    cx.set_global(Assistant::default());
101    AssistantSettings::register(cx);
102    SlashCommandSettings::register(cx);
103
104    cx.spawn(|mut cx| {
105        let client = client.clone();
106        async move {
107            let is_search_slash_command_enabled = cx
108                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
109                .await;
110            let is_project_slash_command_enabled = cx
111                .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
112                .await;
113
114            if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
115                return Ok(());
116            }
117
118            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
119            let semantic_index = SemanticDb::new(
120                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
121                Arc::new(embedding_provider),
122                &mut cx,
123            )
124            .await?;
125
126            cx.update(|cx| cx.set_global(semantic_index))
127        }
128    })
129    .detach();
130
131    assistant_context_editor::init(client.clone(), cx);
132    prompt_library::init(cx);
133    init_language_model_settings(cx);
134    assistant_slash_command::init(cx);
135    assistant_tool::init(cx);
136    assistant_panel::init(cx);
137    context_server::init(cx);
138
139    let prompt_builder = PromptBuilder::new(Some(PromptLoadingParams {
140        fs: fs.clone(),
141        repo_path: stdout_is_a_pty
142            .then(|| std::env::current_dir().log_err())
143            .flatten(),
144        cx,
145    }))
146    .log_err()
147    .map(Arc::new)
148    .unwrap_or_else(|| Arc::new(PromptBuilder::new(None).unwrap()));
149    register_slash_commands(Some(prompt_builder.clone()), cx);
150    inline_assistant::init(
151        fs.clone(),
152        prompt_builder.clone(),
153        client.telemetry().clone(),
154        cx,
155    );
156    terminal_inline_assistant::init(
157        fs.clone(),
158        prompt_builder.clone(),
159        client.telemetry().clone(),
160        cx,
161    );
162    indexed_docs::init(cx);
163
164    CommandPaletteFilter::update_global(cx, |filter, _cx| {
165        filter.hide_namespace(Assistant::NAMESPACE);
166    });
167    Assistant::update_global(cx, |assistant, cx| {
168        let settings = AssistantSettings::get_global(cx);
169
170        assistant.set_enabled(settings.enabled, cx);
171    });
172    cx.observe_global::<SettingsStore>(|cx| {
173        Assistant::update_global(cx, |assistant, cx| {
174            let settings = AssistantSettings::get_global(cx);
175            assistant.set_enabled(settings.enabled, cx);
176        });
177    })
178    .detach();
179
180    prompt_builder
181}
182
183fn init_language_model_settings(cx: &mut AppContext) {
184    update_active_language_model_from_settings(cx);
185
186    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
187        .detach();
188    cx.subscribe(
189        &LanguageModelRegistry::global(cx),
190        |_, event: &language_model::Event, cx| match event {
191            language_model::Event::ProviderStateChanged
192            | language_model::Event::AddedProvider(_)
193            | language_model::Event::RemovedProvider(_) => {
194                update_active_language_model_from_settings(cx);
195            }
196            _ => {}
197        },
198    )
199    .detach();
200}
201
202fn update_active_language_model_from_settings(cx: &mut AppContext) {
203    let settings = AssistantSettings::get_global(cx);
204    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
205    let model_id = LanguageModelId::from(settings.default_model.model.clone());
206    let inline_alternatives = settings
207        .inline_alternatives
208        .iter()
209        .map(|alternative| {
210            (
211                LanguageModelProviderId::from(alternative.provider.clone()),
212                LanguageModelId::from(alternative.model.clone()),
213            )
214        })
215        .collect::<Vec<_>>();
216    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
217        registry.select_active_model(&provider_name, &model_id, cx);
218        registry.select_inline_alternative_models(inline_alternatives, cx);
219    });
220}
221
222fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
223    let slash_command_registry = SlashCommandRegistry::global(cx);
224
225    slash_command_registry.register_command(assistant_slash_commands::FileSlashCommand, true);
226    slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true);
227    slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true);
228    slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true);
229    slash_command_registry
230        .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
231    slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true);
232    slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true);
233    slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false);
234    slash_command_registry.register_command(assistant_slash_commands::TerminalSlashCommand, true);
235    slash_command_registry.register_command(assistant_slash_commands::NowSlashCommand, false);
236    slash_command_registry
237        .register_command(assistant_slash_commands::DiagnosticsSlashCommand, true);
238    slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true);
239
240    if let Some(prompt_builder) = prompt_builder {
241        cx.observe_flag::<assistant_slash_commands::ProjectSlashCommandFeatureFlag, _>({
242            let slash_command_registry = slash_command_registry.clone();
243            move |is_enabled, _cx| {
244                if is_enabled {
245                    slash_command_registry.register_command(
246                        assistant_slash_commands::ProjectSlashCommand::new(prompt_builder.clone()),
247                        true,
248                    );
249                }
250            }
251        })
252        .detach();
253    }
254
255    cx.observe_flag::<assistant_slash_commands::AutoSlashCommandFeatureFlag, _>({
256        let slash_command_registry = slash_command_registry.clone();
257        move |is_enabled, _cx| {
258            if is_enabled {
259                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
260                slash_command_registry
261                    .register_command(assistant_slash_commands::AutoCommand, true);
262            }
263        }
264    })
265    .detach();
266
267    cx.observe_flag::<assistant_slash_commands::StreamingExampleSlashCommandFeatureFlag, _>({
268        let slash_command_registry = slash_command_registry.clone();
269        move |is_enabled, _cx| {
270            if is_enabled {
271                slash_command_registry.register_command(
272                    assistant_slash_commands::StreamingExampleSlashCommand,
273                    false,
274                );
275            }
276        }
277    })
278    .detach();
279
280    update_slash_commands_from_settings(cx);
281    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
282        .detach();
283
284    cx.observe_flag::<assistant_slash_commands::SearchSlashCommandFeatureFlag, _>({
285        let slash_command_registry = slash_command_registry.clone();
286        move |is_enabled, _cx| {
287            if is_enabled {
288                slash_command_registry
289                    .register_command(assistant_slash_commands::SearchSlashCommand, true);
290            }
291        }
292    })
293    .detach();
294}
295
296fn update_slash_commands_from_settings(cx: &mut AppContext) {
297    let slash_command_registry = SlashCommandRegistry::global(cx);
298    let settings = SlashCommandSettings::get_global(cx);
299
300    if settings.docs.enabled {
301        slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
302    } else {
303        slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
304    }
305
306    if settings.cargo_workspace.enabled {
307        slash_command_registry
308            .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
309    } else {
310        slash_command_registry
311            .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
312    }
313}
314
/// Test-only constructor hook: initializes `env_logger` when the `RUST_LOG`
/// environment variable can be read, and does nothing otherwise.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}