assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4mod context;
  5pub mod context_store;
  6mod inline_assistant;
  7mod patch;
  8mod prompt_library;
  9mod slash_command;
 10pub(crate) mod slash_command_picker;
 11pub mod slash_command_settings;
 12mod streaming_diff;
 13mod terminal_inline_assistant;
 14
 15use std::path::PathBuf;
 16use std::sync::Arc;
 17
 18use ::prompt_library::{PromptBuilder, PromptLoadingParams};
 19use assistant_settings::AssistantSettings;
 20use assistant_slash_command::SlashCommandRegistry;
 21use assistant_slash_commands::{ProjectSlashCommandFeatureFlag, SearchSlashCommandFeatureFlag};
 22use client::{proto, Client};
 23use command_palette_hooks::CommandPaletteFilter;
 24use feature_flags::FeatureFlagAppExt;
 25use fs::Fs;
 26use gpui::impl_internal_actions;
 27use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 28use language_model::{
 29    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 30};
 31use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 32use serde::{Deserialize, Serialize};
 33use settings::{Settings, SettingsStore};
 34use util::ResultExt;
 35
 36pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
 37pub use crate::context::*;
 38pub use crate::context_store::*;
 39pub(crate) use crate::inline_assistant::*;
 40pub use crate::patch::*;
 41use crate::slash_command_settings::SlashCommandSettings;
 42pub(crate) use crate::streaming_diff::*;
 43
// Actions declared in the `assistant` namespace. Their command-palette visibility is toggled as
// a group by `Assistant::set_enabled`, which shows/hides `Assistant::NAMESPACE` ("assistant").
actions!(
    assistant,
    [
        Assist,
        Edit,
        Split,
        CopyCode,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        NewContext,
        ToggleModelSelector,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
 65
 66#[derive(PartialEq, Clone)]
 67pub enum InsertDraggedFiles {
 68    ProjectPaths(Vec<PathBuf>),
 69    ExternalFiles(Vec<PathBuf>),
 70}
 71
// Unlike the actions declared with `actions!`, `InsertDraggedFiles` carries data
// (`Vec<PathBuf>` payloads) and is registered in the `assistant` namespace via
// `impl_internal_actions!`.
// NOTE(review): "internal" presumably means the action cannot be built from a keymap/JSON —
// confirm against gpui's macro definition.
impl_internal_actions!(assistant, [InsertDraggedFiles]);
 73
/// Default number of context lines (50).
///
/// NOTE(review): not referenced in this file — presumably consumed elsewhere in this crate;
/// confirm there what the "context lines" surround.
const DEFAULT_CONTEXT_LINES: usize = 50;
 75
 76#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 77pub struct MessageId(clock::Lamport);
 78
 79impl MessageId {
 80    pub fn as_u64(self) -> u64 {
 81        self.0.as_u64()
 82    }
 83}
 84
/// Deserializable record of token counts.
///
/// NOTE(review): not referenced in this file; field meanings are inferred from their names
/// (prompt / completion / total token counts) — confirm against the response format that
/// produces it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
 91
/// Deserializable record holding a choice `index`, a `delta` message, and an optional
/// `finish_reason`.
///
/// NOTE(review): not referenced in this file; presumably one incremental "choice" of a streamed
/// completion response — confirm against the provider that produces it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}
 98
 99#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
100pub enum MessageStatus {
101    Pending,
102    Done,
103    Error(SharedString),
104    Canceled,
105}
106
107impl MessageStatus {
108    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
109        match status.variant {
110            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
111            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
112            Some(proto::context_message_status::Variant::Error(error)) => {
113                MessageStatus::Error(error.message.into())
114            }
115            Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
116            None => MessageStatus::Pending,
117        }
118    }
119
120    pub fn to_proto(&self) -> proto::ContextMessageStatus {
121        match self {
122            MessageStatus::Pending => proto::ContextMessageStatus {
123                variant: Some(proto::context_message_status::Variant::Pending(
124                    proto::context_message_status::Pending {},
125                )),
126            },
127            MessageStatus::Done => proto::ContextMessageStatus {
128                variant: Some(proto::context_message_status::Variant::Done(
129                    proto::context_message_status::Done {},
130                )),
131            },
132            MessageStatus::Error(message) => proto::ContextMessageStatus {
133                variant: Some(proto::context_message_status::Variant::Error(
134                    proto::context_message_status::Error {
135                        message: message.to_string(),
136                    },
137                )),
138            },
139            MessageStatus::Canceled => proto::ContextMessageStatus {
140                variant: Some(proto::context_message_status::Variant::Canceled(
141                    proto::context_message_status::Canceled {},
142                )),
143            },
144        }
145    }
146}
147
148/// The state pertaining to the Assistant.
149#[derive(Default)]
150struct Assistant {
151    /// Whether the Assistant is enabled.
152    enabled: bool,
153}
154
155impl Global for Assistant {}
156
157impl Assistant {
158    const NAMESPACE: &'static str = "assistant";
159
160    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
161        if self.enabled == enabled {
162            return;
163        }
164
165        self.enabled = enabled;
166
167        if !enabled {
168            CommandPaletteFilter::update_global(cx, |filter, _cx| {
169                filter.hide_namespace(Self::NAMESPACE);
170            });
171
172            return;
173        }
174
175        CommandPaletteFilter::update_global(cx, |filter, _cx| {
176            filter.show_namespace(Self::NAMESPACE);
177        });
178    }
179}
180
181pub fn init(
182    fs: Arc<dyn Fs>,
183    client: Arc<Client>,
184    stdout_is_a_pty: bool,
185    cx: &mut AppContext,
186) -> Arc<PromptBuilder> {
187    cx.set_global(Assistant::default());
188    AssistantSettings::register(cx);
189    SlashCommandSettings::register(cx);
190
191    cx.spawn(|mut cx| {
192        let client = client.clone();
193        async move {
194            let is_search_slash_command_enabled = cx
195                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
196                .await;
197            let is_project_slash_command_enabled = cx
198                .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
199                .await;
200
201            if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
202                return Ok(());
203            }
204
205            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
206            let semantic_index = SemanticDb::new(
207                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
208                Arc::new(embedding_provider),
209                &mut cx,
210            )
211            .await?;
212
213            cx.update(|cx| cx.set_global(semantic_index))
214        }
215    })
216    .detach();
217
218    context_store::init(&client.clone().into());
219    ::prompt_library::init(cx);
220    init_language_model_settings(cx);
221    assistant_slash_command::init(cx);
222    assistant_tool::init(cx);
223    assistant_panel::init(cx);
224    context_server::init(cx);
225
226    let prompt_builder = PromptBuilder::new(Some(PromptLoadingParams {
227        fs: fs.clone(),
228        repo_path: stdout_is_a_pty
229            .then(|| std::env::current_dir().log_err())
230            .flatten(),
231        cx,
232    }))
233    .log_err()
234    .map(Arc::new)
235    .unwrap_or_else(|| Arc::new(PromptBuilder::new(None).unwrap()));
236    register_slash_commands(Some(prompt_builder.clone()), cx);
237    inline_assistant::init(
238        fs.clone(),
239        prompt_builder.clone(),
240        client.telemetry().clone(),
241        cx,
242    );
243    terminal_inline_assistant::init(
244        fs.clone(),
245        prompt_builder.clone(),
246        client.telemetry().clone(),
247        cx,
248    );
249    indexed_docs::init(cx);
250
251    CommandPaletteFilter::update_global(cx, |filter, _cx| {
252        filter.hide_namespace(Assistant::NAMESPACE);
253    });
254    Assistant::update_global(cx, |assistant, cx| {
255        let settings = AssistantSettings::get_global(cx);
256
257        assistant.set_enabled(settings.enabled, cx);
258    });
259    cx.observe_global::<SettingsStore>(|cx| {
260        Assistant::update_global(cx, |assistant, cx| {
261            let settings = AssistantSettings::get_global(cx);
262            assistant.set_enabled(settings.enabled, cx);
263        });
264    })
265    .detach();
266
267    prompt_builder
268}
269
270fn init_language_model_settings(cx: &mut AppContext) {
271    update_active_language_model_from_settings(cx);
272
273    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
274        .detach();
275    cx.subscribe(
276        &LanguageModelRegistry::global(cx),
277        |_, event: &language_model::Event, cx| match event {
278            language_model::Event::ProviderStateChanged
279            | language_model::Event::AddedProvider(_)
280            | language_model::Event::RemovedProvider(_) => {
281                update_active_language_model_from_settings(cx);
282            }
283            _ => {}
284        },
285    )
286    .detach();
287}
288
289fn update_active_language_model_from_settings(cx: &mut AppContext) {
290    let settings = AssistantSettings::get_global(cx);
291    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
292    let model_id = LanguageModelId::from(settings.default_model.model.clone());
293    let inline_alternatives = settings
294        .inline_alternatives
295        .iter()
296        .map(|alternative| {
297            (
298                LanguageModelProviderId::from(alternative.provider.clone()),
299                LanguageModelId::from(alternative.model.clone()),
300            )
301        })
302        .collect::<Vec<_>>();
303    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
304        registry.select_active_model(&provider_name, &model_id, cx);
305        registry.select_inline_alternative_models(inline_alternatives, cx);
306    });
307}
308
309fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
310    let slash_command_registry = SlashCommandRegistry::global(cx);
311
312    slash_command_registry.register_command(assistant_slash_commands::FileSlashCommand, true);
313    slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true);
314    slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true);
315    slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true);
316    slash_command_registry
317        .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
318    slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true);
319    slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true);
320    slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false);
321    slash_command_registry.register_command(assistant_slash_commands::TerminalSlashCommand, true);
322    slash_command_registry.register_command(assistant_slash_commands::NowSlashCommand, false);
323    slash_command_registry
324        .register_command(assistant_slash_commands::DiagnosticsSlashCommand, true);
325    slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true);
326
327    if let Some(prompt_builder) = prompt_builder {
328        cx.observe_flag::<assistant_slash_commands::ProjectSlashCommandFeatureFlag, _>({
329            let slash_command_registry = slash_command_registry.clone();
330            move |is_enabled, _cx| {
331                if is_enabled {
332                    slash_command_registry.register_command(
333                        assistant_slash_commands::ProjectSlashCommand::new(prompt_builder.clone()),
334                        true,
335                    );
336                }
337            }
338        })
339        .detach();
340    }
341
342    cx.observe_flag::<assistant_slash_commands::AutoSlashCommandFeatureFlag, _>({
343        let slash_command_registry = slash_command_registry.clone();
344        move |is_enabled, _cx| {
345            if is_enabled {
346                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
347                slash_command_registry
348                    .register_command(assistant_slash_commands::AutoCommand, true);
349            }
350        }
351    })
352    .detach();
353
354    cx.observe_flag::<assistant_slash_commands::StreamingExampleSlashCommandFeatureFlag, _>({
355        let slash_command_registry = slash_command_registry.clone();
356        move |is_enabled, _cx| {
357            if is_enabled {
358                slash_command_registry.register_command(
359                    assistant_slash_commands::StreamingExampleSlashCommand,
360                    false,
361                );
362            }
363        }
364    })
365    .detach();
366
367    update_slash_commands_from_settings(cx);
368    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
369        .detach();
370
371    cx.observe_flag::<assistant_slash_commands::SearchSlashCommandFeatureFlag, _>({
372        let slash_command_registry = slash_command_registry.clone();
373        move |is_enabled, _cx| {
374            if is_enabled {
375                slash_command_registry
376                    .register_command(assistant_slash_commands::SearchSlashCommand, true);
377            }
378        }
379    })
380    .detach();
381}
382
383fn update_slash_commands_from_settings(cx: &mut AppContext) {
384    let slash_command_registry = SlashCommandRegistry::global(cx);
385    let settings = SlashCommandSettings::get_global(cx);
386
387    if settings.docs.enabled {
388        slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
389    } else {
390        slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
391    }
392
393    if settings.cargo_workspace.enabled {
394        slash_command_registry
395            .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
396    } else {
397        slash_command_registry
398            .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
399    }
400}
401
/// Formats a token count compactly for display.
///
/// - Below 1000, the exact number is shown (`999` → `"999"`).
/// - From 1000 to 9999, the count is shown in thousands with one decimal, rounded half-up to
///   the nearest hundred. A `.0` is omitted (`1000` → `"1k"`, `1234` → `"1.2k"`,
///   `9999` → `"10k"`).
/// - From 10000 upward, the count is rounded half-up to the nearest thousand
///   (`10_500` → `"11k"`, `1_000_000` → `"1000k"`).
///
/// Never panics, including for values near `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded half-up to the nearest hundred; 10 means it carried into the
            // next thousand.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Same result as `(count + 500) / 1000`, but that addition overflows (and panics
            // in debug builds) for counts within 500 of `usize::MAX`.
            let rounded_thousands = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded_thousands)
        }
    }
}
419
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Test builds only: install `env_logger` when `RUST_LOG` is set to a Unicode value, so test
    // output stays quiet by default.
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if rust_log_requested {
        env_logger::init();
    }
}