assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4mod context;
  5pub mod context_store;
  6mod inline_assistant;
  7mod patch;
  8mod prompt_library;
  9mod prompts;
 10mod slash_command;
 11pub(crate) mod slash_command_picker;
 12pub mod slash_command_settings;
 13mod slash_command_working_set;
 14mod streaming_diff;
 15mod terminal_inline_assistant;
 16
 17use crate::slash_command::project_command::ProjectSlashCommandFeatureFlag;
 18pub use crate::slash_command_working_set::{SlashCommandId, SlashCommandWorkingSet};
 19pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 20use assistant_settings::AssistantSettings;
 21use assistant_slash_command::SlashCommandRegistry;
 22use client::{proto, Client};
 23use command_palette_hooks::CommandPaletteFilter;
 24pub use context::*;
 25pub use context_store::*;
 26use feature_flags::FeatureFlagAppExt;
 27use fs::Fs;
 28use gpui::impl_internal_actions;
 29use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 30pub(crate) use inline_assistant::*;
 31use language_model::{
 32    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 33};
 34pub use patch::*;
 35pub use prompts::PromptBuilder;
 36use prompts::PromptLoadingParams;
 37use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 38use serde::{Deserialize, Serialize};
 39use settings::{Settings, SettingsStore};
 40use slash_command::search_command::SearchSlashCommandFeatureFlag;
 41use slash_command::{
 42    auto_command, cargo_workspace_command, default_command, delta_command, diagnostics_command,
 43    docs_command, fetch_command, file_command, now_command, project_command, prompt_command,
 44    search_command, selection_command, symbols_command, tab_command, terminal_command,
 45};
 46use std::path::PathBuf;
 47use std::sync::Arc;
 48pub(crate) use streaming_diff::*;
 49use util::ResultExt;
 50
 51use crate::slash_command::streaming_example_command;
 52use crate::slash_command_settings::SlashCommandSettings;
 53
 54actions!(
 55    assistant,
 56    [
 57        Assist,
 58        Edit,
 59        Split,
 60        CopyCode,
 61        CycleMessageRole,
 62        QuoteSelection,
 63        InsertIntoEditor,
 64        ToggleFocus,
 65        InsertActivePrompt,
 66        DeployHistory,
 67        DeployPromptLibrary,
 68        ConfirmCommand,
 69        NewContext,
 70        ToggleModelSelector,
 71        CycleNextInlineAssist,
 72        CyclePreviousInlineAssist
 73    ]
 74);
 75
 76#[derive(PartialEq, Clone)]
 77pub enum InsertDraggedFiles {
 78    ProjectPaths(Vec<PathBuf>),
 79    ExternalFiles(Vec<PathBuf>),
 80}
 81
// Register `InsertDraggedFiles` as an internal action in the `assistant`
// namespace. It carries data, unlike the unit actions declared with `actions!`,
// which is presumably why it goes through this macro instead.
impl_internal_actions!(assistant, [InsertDraggedFiles]);

/// Default number of context lines (50).
/// NOTE(review): consumers live elsewhere in the crate — confirm what this bounds.
const DEFAULT_CONTEXT_LINES: usize = 50;
 85
 86#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 87pub struct MessageId(clock::Lamport);
 88
 89impl MessageId {
 90    pub fn as_u64(self) -> u64 {
 91        self.0.as_u64()
 92    }
 93}
 94
 95#[derive(Deserialize, Debug)]
 96pub struct LanguageModelUsage {
 97    pub prompt_tokens: u32,
 98    pub completion_tokens: u32,
 99    pub total_tokens: u32,
100}
101
102#[derive(Deserialize, Debug)]
103pub struct LanguageModelChoiceDelta {
104    pub index: u32,
105    pub delta: LanguageModelResponseMessage,
106    pub finish_reason: Option<String>,
107}
108
109#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
110pub enum MessageStatus {
111    Pending,
112    Done,
113    Error(SharedString),
114    Canceled,
115}
116
117impl MessageStatus {
118    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
119        match status.variant {
120            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
121            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
122            Some(proto::context_message_status::Variant::Error(error)) => {
123                MessageStatus::Error(error.message.into())
124            }
125            Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
126            None => MessageStatus::Pending,
127        }
128    }
129
130    pub fn to_proto(&self) -> proto::ContextMessageStatus {
131        match self {
132            MessageStatus::Pending => proto::ContextMessageStatus {
133                variant: Some(proto::context_message_status::Variant::Pending(
134                    proto::context_message_status::Pending {},
135                )),
136            },
137            MessageStatus::Done => proto::ContextMessageStatus {
138                variant: Some(proto::context_message_status::Variant::Done(
139                    proto::context_message_status::Done {},
140                )),
141            },
142            MessageStatus::Error(message) => proto::ContextMessageStatus {
143                variant: Some(proto::context_message_status::Variant::Error(
144                    proto::context_message_status::Error {
145                        message: message.to_string(),
146                    },
147                )),
148            },
149            MessageStatus::Canceled => proto::ContextMessageStatus {
150                variant: Some(proto::context_message_status::Variant::Canceled(
151                    proto::context_message_status::Canceled {},
152                )),
153            },
154        }
155    }
156}
157
158/// The state pertaining to the Assistant.
159#[derive(Default)]
160struct Assistant {
161    /// Whether the Assistant is enabled.
162    enabled: bool,
163}
164
165impl Global for Assistant {}
166
167impl Assistant {
168    const NAMESPACE: &'static str = "assistant";
169
170    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
171        if self.enabled == enabled {
172            return;
173        }
174
175        self.enabled = enabled;
176
177        if !enabled {
178            CommandPaletteFilter::update_global(cx, |filter, _cx| {
179                filter.hide_namespace(Self::NAMESPACE);
180            });
181
182            return;
183        }
184
185        CommandPaletteFilter::update_global(cx, |filter, _cx| {
186            filter.show_namespace(Self::NAMESPACE);
187        });
188    }
189}
190
191pub fn init(
192    fs: Arc<dyn Fs>,
193    client: Arc<Client>,
194    stdout_is_a_pty: bool,
195    cx: &mut AppContext,
196) -> Arc<PromptBuilder> {
197    cx.set_global(Assistant::default());
198    AssistantSettings::register(cx);
199    SlashCommandSettings::register(cx);
200
201    cx.spawn(|mut cx| {
202        let client = client.clone();
203        async move {
204            let is_search_slash_command_enabled = cx
205                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
206                .await;
207            let is_project_slash_command_enabled = cx
208                .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
209                .await;
210
211            if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
212                return Ok(());
213            }
214
215            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
216            let semantic_index = SemanticDb::new(
217                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
218                Arc::new(embedding_provider),
219                &mut cx,
220            )
221            .await?;
222
223            cx.update(|cx| cx.set_global(semantic_index))
224        }
225    })
226    .detach();
227
228    context_store::init(&client.clone().into());
229    prompt_library::init(cx);
230    init_language_model_settings(cx);
231    assistant_slash_command::init(cx);
232    assistant_tool::init(cx);
233    assistant_panel::init(cx);
234    context_server::init(cx);
235
236    let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
237        fs: fs.clone(),
238        repo_path: stdout_is_a_pty
239            .then(|| std::env::current_dir().log_err())
240            .flatten(),
241        cx,
242    }))
243    .log_err()
244    .map(Arc::new)
245    .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
246    register_slash_commands(Some(prompt_builder.clone()), cx);
247    inline_assistant::init(
248        fs.clone(),
249        prompt_builder.clone(),
250        client.telemetry().clone(),
251        cx,
252    );
253    terminal_inline_assistant::init(
254        fs.clone(),
255        prompt_builder.clone(),
256        client.telemetry().clone(),
257        cx,
258    );
259    indexed_docs::init(cx);
260
261    CommandPaletteFilter::update_global(cx, |filter, _cx| {
262        filter.hide_namespace(Assistant::NAMESPACE);
263    });
264    Assistant::update_global(cx, |assistant, cx| {
265        let settings = AssistantSettings::get_global(cx);
266
267        assistant.set_enabled(settings.enabled, cx);
268    });
269    cx.observe_global::<SettingsStore>(|cx| {
270        Assistant::update_global(cx, |assistant, cx| {
271            let settings = AssistantSettings::get_global(cx);
272            assistant.set_enabled(settings.enabled, cx);
273        });
274    })
275    .detach();
276
277    prompt_builder
278}
279
280fn init_language_model_settings(cx: &mut AppContext) {
281    update_active_language_model_from_settings(cx);
282
283    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
284        .detach();
285    cx.subscribe(
286        &LanguageModelRegistry::global(cx),
287        |_, event: &language_model::Event, cx| match event {
288            language_model::Event::ProviderStateChanged
289            | language_model::Event::AddedProvider(_)
290            | language_model::Event::RemovedProvider(_) => {
291                update_active_language_model_from_settings(cx);
292            }
293            _ => {}
294        },
295    )
296    .detach();
297}
298
299fn update_active_language_model_from_settings(cx: &mut AppContext) {
300    let settings = AssistantSettings::get_global(cx);
301    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
302    let model_id = LanguageModelId::from(settings.default_model.model.clone());
303    let inline_alternatives = settings
304        .inline_alternatives
305        .iter()
306        .map(|alternative| {
307            (
308                LanguageModelProviderId::from(alternative.provider.clone()),
309                LanguageModelId::from(alternative.model.clone()),
310            )
311        })
312        .collect::<Vec<_>>();
313    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
314        registry.select_active_model(&provider_name, &model_id, cx);
315        registry.select_inline_alternative_models(inline_alternatives, cx);
316    });
317}
318
319fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
320    let slash_command_registry = SlashCommandRegistry::global(cx);
321
322    slash_command_registry.register_command(file_command::FileSlashCommand, true);
323    slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
324    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
325    slash_command_registry.register_command(tab_command::TabSlashCommand, true);
326    slash_command_registry
327        .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
328    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
329    slash_command_registry.register_command(selection_command::SelectionCommand, true);
330    slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
331    slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
332    slash_command_registry.register_command(now_command::NowSlashCommand, false);
333    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
334    slash_command_registry.register_command(fetch_command::FetchSlashCommand, true);
335
336    if let Some(prompt_builder) = prompt_builder {
337        cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
338            let slash_command_registry = slash_command_registry.clone();
339            move |is_enabled, _cx| {
340                if is_enabled {
341                    slash_command_registry.register_command(
342                        project_command::ProjectSlashCommand::new(prompt_builder.clone()),
343                        true,
344                    );
345                }
346            }
347        })
348        .detach();
349    }
350
351    cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
352        let slash_command_registry = slash_command_registry.clone();
353        move |is_enabled, _cx| {
354            if is_enabled {
355                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
356                slash_command_registry.register_command(auto_command::AutoCommand, true);
357            }
358        }
359    })
360    .detach();
361
362    cx.observe_flag::<streaming_example_command::StreamingExampleSlashCommandFeatureFlag, _>({
363        let slash_command_registry = slash_command_registry.clone();
364        move |is_enabled, _cx| {
365            if is_enabled {
366                slash_command_registry.register_command(
367                    streaming_example_command::StreamingExampleSlashCommand,
368                    false,
369                );
370            }
371        }
372    })
373    .detach();
374
375    update_slash_commands_from_settings(cx);
376    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
377        .detach();
378
379    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
380        let slash_command_registry = slash_command_registry.clone();
381        move |is_enabled, _cx| {
382            if is_enabled {
383                slash_command_registry.register_command(search_command::SearchSlashCommand, true);
384            }
385        }
386    })
387    .detach();
388}
389
390fn update_slash_commands_from_settings(cx: &mut AppContext) {
391    let slash_command_registry = SlashCommandRegistry::global(cx);
392    let settings = SlashCommandSettings::get_global(cx);
393
394    if settings.docs.enabled {
395        slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
396    } else {
397        slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
398    }
399
400    if settings.cargo_workspace.enabled {
401        slash_command_registry
402            .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
403    } else {
404        slash_command_registry
405            .unregister_command(cargo_workspace_command::CargoWorkspaceSlashCommand);
406    }
407}
408
/// Formats a token count compactly for display.
///
/// - Below 1,000 the exact number is returned (`"999"`).
/// - From 1,000 to 9,999 the count is shown in thousands, rounded half-up to one
///   decimal place, with a `.0` dropped (`"1k"`, `"1.5k"`, and `9,950` → `"10k"`).
/// - From 10,000 upward the count is rounded half-up to whole thousands (`"12k"`).
///
/// Works for every `usize`, including values within 500 of `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded to the nearest hundred; at most 1049 / 100, no overflow.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                // e.g. 1950 rounds up to the next whole thousand: "2k", not "1.10k".
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Same result as `(count + 500) / 1000`, but that addition overflows
            // (panic in debug, wrap to "0k" in release) near `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
426
/// Test-only load-time constructor (via `ctor`) that initializes `env_logger`,
/// but only when `RUST_LOG` is set to a valid Unicode value, so test output
/// stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if !rust_log_is_set {
        return;
    }
    env_logger::init();
}