assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4mod context;
  5mod context_editor;
  6mod context_history;
  7pub mod context_store;
  8mod inline_assistant;
  9mod patch;
 10mod slash_command;
 11pub(crate) mod slash_command_picker;
 12pub mod slash_command_settings;
 13mod terminal_inline_assistant;
 14
 15use std::path::PathBuf;
 16use std::sync::Arc;
 17
 18use assistant_settings::AssistantSettings;
 19use assistant_slash_command::SlashCommandRegistry;
 20use assistant_slash_commands::{ProjectSlashCommandFeatureFlag, SearchSlashCommandFeatureFlag};
 21use client::{proto, Client};
 22use command_palette_hooks::CommandPaletteFilter;
 23use feature_flags::FeatureFlagAppExt;
 24use fs::Fs;
 25use gpui::impl_internal_actions;
 26use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 27use language_model::{
 28    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 29};
 30use prompt_library::{PromptBuilder, PromptLoadingParams};
 31use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 32use serde::{Deserialize, Serialize};
 33use settings::{Settings, SettingsStore};
 34use util::ResultExt;
 35
 36pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
 37pub use crate::context::*;
 38pub use crate::context_store::*;
 39pub(crate) use crate::inline_assistant::*;
 40pub use crate::patch::*;
 41use crate::slash_command_settings::SlashCommandSettings;
 42
 // Declares the assistant's actions through gpui's `actions!` macro, all under
 // the `assistant` namespace (the same string as `Assistant::NAMESPACE`, which
 // is what the command palette filter hides or shows as the assistant is toggled).
 // NOTE(review): presumably each name becomes a payload-free action type —
 // confirm against the gpui macro definition.
 43actions!(
 44    assistant,
 45    [
 46        Assist,
 47        Edit,
 48        Split,
 49        CopyCode,
 50        CycleMessageRole,
 51        QuoteSelection,
 52        InsertIntoEditor,
 53        ToggleFocus,
 54        InsertActivePrompt,
 55        DeployHistory,
 56        DeployPromptLibrary,
 57        ConfirmCommand,
 58        NewContext,
 59        ToggleModelSelector,
 60        CycleNextInlineAssist,
 61        CyclePreviousInlineAssist
 62    ]
 63);
 64
 /// Action payload listing the paths of files to insert, tagged by origin.
 ///
 /// Both variants carry a plain list of paths; the variant itself is the only
 /// thing that distinguishes project paths from external files.
 #[derive(Clone, PartialEq)]
 pub enum InsertDraggedFiles {
     /// Paths tagged as belonging to the project.
     ProjectPaths(Vec<PathBuf>),
     /// Paths tagged as external files.
     ExternalFiles(Vec<PathBuf>),
 }
 70
 // Registers `InsertDraggedFiles` as an action in the `assistant` namespace via
 // gpui's `impl_internal_actions!` macro (it carries data, so it is declared
 // separately from the names listed in `actions!`).
 // NOTE(review): presumably "internal" means the action is not deserialized
 // from keymaps — confirm against the gpui macro docs.
 71impl_internal_actions!(assistant, [InsertDraggedFiles]);
 72
 // Default line count of 50. Not referenced in this file; presumably consumed
 // by the child modules declared above — TODO confirm where and what the
 // "context lines" surround.
 73const DEFAULT_CONTEXT_LINES: usize = 50;
 74
 75#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 76pub struct MessageId(clock::Lamport);
 77
 78impl MessageId {
 79    pub fn as_u64(self) -> u64 {
 80        self.0.as_u64()
 81    }
 82}
 83
 84#[derive(Deserialize, Debug)]
 85pub struct LanguageModelUsage {
 86    pub prompt_tokens: u32,
 87    pub completion_tokens: u32,
 88    pub total_tokens: u32,
 89}
 90
 91#[derive(Deserialize, Debug)]
 92pub struct LanguageModelChoiceDelta {
 93    pub index: u32,
 94    pub delta: LanguageModelResponseMessage,
 95    pub finish_reason: Option<String>,
 96}
 97
 98#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 99pub enum MessageStatus {
100    Pending,
101    Done,
102    Error(SharedString),
103    Canceled,
104}
105
106impl MessageStatus {
107    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
108        match status.variant {
109            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
110            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
111            Some(proto::context_message_status::Variant::Error(error)) => {
112                MessageStatus::Error(error.message.into())
113            }
114            Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
115            None => MessageStatus::Pending,
116        }
117    }
118
119    pub fn to_proto(&self) -> proto::ContextMessageStatus {
120        match self {
121            MessageStatus::Pending => proto::ContextMessageStatus {
122                variant: Some(proto::context_message_status::Variant::Pending(
123                    proto::context_message_status::Pending {},
124                )),
125            },
126            MessageStatus::Done => proto::ContextMessageStatus {
127                variant: Some(proto::context_message_status::Variant::Done(
128                    proto::context_message_status::Done {},
129                )),
130            },
131            MessageStatus::Error(message) => proto::ContextMessageStatus {
132                variant: Some(proto::context_message_status::Variant::Error(
133                    proto::context_message_status::Error {
134                        message: message.to_string(),
135                    },
136                )),
137            },
138            MessageStatus::Canceled => proto::ContextMessageStatus {
139                variant: Some(proto::context_message_status::Variant::Canceled(
140                    proto::context_message_status::Canceled {},
141                )),
142            },
143        }
144    }
145}
146
147/// The state pertaining to the Assistant.
148#[derive(Default)]
149struct Assistant {
150    /// Whether the Assistant is enabled.
151    enabled: bool,
152}
153
154impl Global for Assistant {}
155
156impl Assistant {
157    const NAMESPACE: &'static str = "assistant";
158
159    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
160        if self.enabled == enabled {
161            return;
162        }
163
164        self.enabled = enabled;
165
166        if !enabled {
167            CommandPaletteFilter::update_global(cx, |filter, _cx| {
168                filter.hide_namespace(Self::NAMESPACE);
169            });
170
171            return;
172        }
173
174        CommandPaletteFilter::update_global(cx, |filter, _cx| {
175            filter.show_namespace(Self::NAMESPACE);
176        });
177    }
178}
179
180pub fn init(
181    fs: Arc<dyn Fs>,
182    client: Arc<Client>,
183    stdout_is_a_pty: bool,
184    cx: &mut AppContext,
185) -> Arc<PromptBuilder> {
186    cx.set_global(Assistant::default());
187    AssistantSettings::register(cx);
188    SlashCommandSettings::register(cx);
189
190    cx.spawn(|mut cx| {
191        let client = client.clone();
192        async move {
193            let is_search_slash_command_enabled = cx
194                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
195                .await;
196            let is_project_slash_command_enabled = cx
197                .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
198                .await;
199
200            if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
201                return Ok(());
202            }
203
204            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
205            let semantic_index = SemanticDb::new(
206                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
207                Arc::new(embedding_provider),
208                &mut cx,
209            )
210            .await?;
211
212            cx.update(|cx| cx.set_global(semantic_index))
213        }
214    })
215    .detach();
216
217    context_store::init(&client.clone().into());
218    prompt_library::init(cx);
219    init_language_model_settings(cx);
220    assistant_slash_command::init(cx);
221    assistant_tool::init(cx);
222    assistant_panel::init(cx);
223    context_server::init(cx);
224
225    let prompt_builder = PromptBuilder::new(Some(PromptLoadingParams {
226        fs: fs.clone(),
227        repo_path: stdout_is_a_pty
228            .then(|| std::env::current_dir().log_err())
229            .flatten(),
230        cx,
231    }))
232    .log_err()
233    .map(Arc::new)
234    .unwrap_or_else(|| Arc::new(PromptBuilder::new(None).unwrap()));
235    register_slash_commands(Some(prompt_builder.clone()), cx);
236    inline_assistant::init(
237        fs.clone(),
238        prompt_builder.clone(),
239        client.telemetry().clone(),
240        cx,
241    );
242    terminal_inline_assistant::init(
243        fs.clone(),
244        prompt_builder.clone(),
245        client.telemetry().clone(),
246        cx,
247    );
248    indexed_docs::init(cx);
249
250    CommandPaletteFilter::update_global(cx, |filter, _cx| {
251        filter.hide_namespace(Assistant::NAMESPACE);
252    });
253    Assistant::update_global(cx, |assistant, cx| {
254        let settings = AssistantSettings::get_global(cx);
255
256        assistant.set_enabled(settings.enabled, cx);
257    });
258    cx.observe_global::<SettingsStore>(|cx| {
259        Assistant::update_global(cx, |assistant, cx| {
260            let settings = AssistantSettings::get_global(cx);
261            assistant.set_enabled(settings.enabled, cx);
262        });
263    })
264    .detach();
265
266    prompt_builder
267}
268
269fn init_language_model_settings(cx: &mut AppContext) {
270    update_active_language_model_from_settings(cx);
271
272    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
273        .detach();
274    cx.subscribe(
275        &LanguageModelRegistry::global(cx),
276        |_, event: &language_model::Event, cx| match event {
277            language_model::Event::ProviderStateChanged
278            | language_model::Event::AddedProvider(_)
279            | language_model::Event::RemovedProvider(_) => {
280                update_active_language_model_from_settings(cx);
281            }
282            _ => {}
283        },
284    )
285    .detach();
286}
287
288fn update_active_language_model_from_settings(cx: &mut AppContext) {
289    let settings = AssistantSettings::get_global(cx);
290    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
291    let model_id = LanguageModelId::from(settings.default_model.model.clone());
292    let inline_alternatives = settings
293        .inline_alternatives
294        .iter()
295        .map(|alternative| {
296            (
297                LanguageModelProviderId::from(alternative.provider.clone()),
298                LanguageModelId::from(alternative.model.clone()),
299            )
300        })
301        .collect::<Vec<_>>();
302    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
303        registry.select_active_model(&provider_name, &model_id, cx);
304        registry.select_inline_alternative_models(inline_alternatives, cx);
305    });
306}
307
308fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
309    let slash_command_registry = SlashCommandRegistry::global(cx);
310
311    slash_command_registry.register_command(assistant_slash_commands::FileSlashCommand, true);
312    slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true);
313    slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true);
314    slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true);
315    slash_command_registry
316        .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
317    slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true);
318    slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true);
319    slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false);
320    slash_command_registry.register_command(assistant_slash_commands::TerminalSlashCommand, true);
321    slash_command_registry.register_command(assistant_slash_commands::NowSlashCommand, false);
322    slash_command_registry
323        .register_command(assistant_slash_commands::DiagnosticsSlashCommand, true);
324    slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true);
325
326    if let Some(prompt_builder) = prompt_builder {
327        cx.observe_flag::<assistant_slash_commands::ProjectSlashCommandFeatureFlag, _>({
328            let slash_command_registry = slash_command_registry.clone();
329            move |is_enabled, _cx| {
330                if is_enabled {
331                    slash_command_registry.register_command(
332                        assistant_slash_commands::ProjectSlashCommand::new(prompt_builder.clone()),
333                        true,
334                    );
335                }
336            }
337        })
338        .detach();
339    }
340
341    cx.observe_flag::<assistant_slash_commands::AutoSlashCommandFeatureFlag, _>({
342        let slash_command_registry = slash_command_registry.clone();
343        move |is_enabled, _cx| {
344            if is_enabled {
345                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
346                slash_command_registry
347                    .register_command(assistant_slash_commands::AutoCommand, true);
348            }
349        }
350    })
351    .detach();
352
353    cx.observe_flag::<assistant_slash_commands::StreamingExampleSlashCommandFeatureFlag, _>({
354        let slash_command_registry = slash_command_registry.clone();
355        move |is_enabled, _cx| {
356            if is_enabled {
357                slash_command_registry.register_command(
358                    assistant_slash_commands::StreamingExampleSlashCommand,
359                    false,
360                );
361            }
362        }
363    })
364    .detach();
365
366    update_slash_commands_from_settings(cx);
367    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
368        .detach();
369
370    cx.observe_flag::<assistant_slash_commands::SearchSlashCommandFeatureFlag, _>({
371        let slash_command_registry = slash_command_registry.clone();
372        move |is_enabled, _cx| {
373            if is_enabled {
374                slash_command_registry
375                    .register_command(assistant_slash_commands::SearchSlashCommand, true);
376            }
377        }
378    })
379    .detach();
380}
381
382fn update_slash_commands_from_settings(cx: &mut AppContext) {
383    let slash_command_registry = SlashCommandRegistry::global(cx);
384    let settings = SlashCommandSettings::get_global(cx);
385
386    if settings.docs.enabled {
387        slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
388    } else {
389        slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
390    }
391
392    if settings.cargo_workspace.enabled {
393        slash_command_registry
394            .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
395    } else {
396        slash_command_registry
397            .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
398    }
399}
400
/// Formats a token count as a short, human-readable string.
///
/// * Counts below 1000 are printed verbatim (`"999"`).
/// * Counts from 1000 to 9999 are rounded to the nearest hundred and printed
///   with at most one decimal place (`"1k"`, `"1.1k"`, `"10k"`).
/// * Larger counts are rounded to the nearest thousand (`"11k"`, `"1235k"`).
///
/// Halves round up. The function never overflows, even for counts close to
/// `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded to the nearest hundred; lies in 0..=10.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                // Rounding carried into the next thousand (e.g. 1950 -> "2k").
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Equivalent to `(count + 500) / 1000`, but written so the addition
            // cannot overflow when `count` is within 500 of `usize::MAX`.
            let round_up = usize::from(count % 1000 >= 500);
            format!("{}k", count / 1000 + round_up)
        }
    }
}
418
/// Test-only constructor: initializes `env_logger` when the `RUST_LOG`
/// environment variable is set to a valid Unicode value, and does nothing
/// otherwise.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}