assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4pub mod assistant_settings;
  5mod context;
  6pub mod context_store;
  7mod inline_assistant;
  8mod model_selector;
  9mod prompt_library;
 10mod prompts;
 11mod slash_command;
 12pub(crate) mod slash_command_picker;
 13pub mod slash_command_settings;
 14mod streaming_diff;
 15mod terminal_inline_assistant;
 16mod workflow;
 17
 18pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 19use assistant_settings::AssistantSettings;
 20use assistant_slash_command::SlashCommandRegistry;
 21use client::{proto, Client};
 22use command_palette_hooks::CommandPaletteFilter;
 23pub use context::*;
 24use context_servers::ContextServerRegistry;
 25pub use context_store::*;
 26use feature_flags::FeatureFlagAppExt;
 27use fs::Fs;
 28use gpui::Context as _;
 29use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 30use indexed_docs::IndexedDocsRegistry;
 31pub(crate) use inline_assistant::*;
 32use language_model::{
 33    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 34};
 35pub(crate) use model_selector::*;
 36pub use prompts::PromptBuilder;
 37use prompts::PromptLoadingParams;
 38use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 39use serde::{Deserialize, Serialize};
 40use settings::{update_settings_file, Settings, SettingsStore};
 41use slash_command::{
 42    context_server_command, default_command, diagnostics_command, docs_command, fetch_command,
 43    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 44    tab_command, terminal_command, workflow_command,
 45};
 46use std::sync::Arc;
 47pub(crate) use streaming_diff::*;
 48use util::ResultExt;
 49pub use workflow::*;
 50
 51use crate::slash_command_settings::SlashCommandSettings;
 52
 53actions!(
 54    assistant,
 55    [
 56        Assist,
 57        Split,
 58        CycleMessageRole,
 59        QuoteSelection,
 60        InsertIntoEditor,
 61        ToggleFocus,
 62        InsertActivePrompt,
 63        DeployHistory,
 64        DeployPromptLibrary,
 65        ConfirmCommand,
 66        ToggleModelSelector,
 67    ]
 68);
 69
 70const DEFAULT_CONTEXT_LINES: usize = 50;
 71
 72#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 73pub struct MessageId(clock::Lamport);
 74
 75impl MessageId {
 76    pub fn as_u64(self) -> u64 {
 77        self.0.as_u64()
 78    }
 79}
 80
 81#[derive(Deserialize, Debug)]
 82pub struct LanguageModelUsage {
 83    pub prompt_tokens: u32,
 84    pub completion_tokens: u32,
 85    pub total_tokens: u32,
 86}
 87
 88#[derive(Deserialize, Debug)]
 89pub struct LanguageModelChoiceDelta {
 90    pub index: u32,
 91    pub delta: LanguageModelResponseMessage,
 92    pub finish_reason: Option<String>,
 93}
 94
 95#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 96pub enum MessageStatus {
 97    Pending,
 98    Done,
 99    Error(SharedString),
100    Canceled,
101}
102
103impl MessageStatus {
104    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
105        match status.variant {
106            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
107            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
108            Some(proto::context_message_status::Variant::Error(error)) => {
109                MessageStatus::Error(error.message.into())
110            }
111            Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
112            None => MessageStatus::Pending,
113        }
114    }
115
116    pub fn to_proto(&self) -> proto::ContextMessageStatus {
117        match self {
118            MessageStatus::Pending => proto::ContextMessageStatus {
119                variant: Some(proto::context_message_status::Variant::Pending(
120                    proto::context_message_status::Pending {},
121                )),
122            },
123            MessageStatus::Done => proto::ContextMessageStatus {
124                variant: Some(proto::context_message_status::Variant::Done(
125                    proto::context_message_status::Done {},
126                )),
127            },
128            MessageStatus::Error(message) => proto::ContextMessageStatus {
129                variant: Some(proto::context_message_status::Variant::Error(
130                    proto::context_message_status::Error {
131                        message: message.to_string(),
132                    },
133                )),
134            },
135            MessageStatus::Canceled => proto::ContextMessageStatus {
136                variant: Some(proto::context_message_status::Variant::Canceled(
137                    proto::context_message_status::Canceled {},
138                )),
139            },
140        }
141    }
142}
143
144/// The state pertaining to the Assistant.
145#[derive(Default)]
146struct Assistant {
147    /// Whether the Assistant is enabled.
148    enabled: bool,
149}
150
151impl Global for Assistant {}
152
153impl Assistant {
154    const NAMESPACE: &'static str = "assistant";
155
156    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
157        if self.enabled == enabled {
158            return;
159        }
160
161        self.enabled = enabled;
162
163        if !enabled {
164            CommandPaletteFilter::update_global(cx, |filter, _cx| {
165                filter.hide_namespace(Self::NAMESPACE);
166            });
167
168            return;
169        }
170
171        CommandPaletteFilter::update_global(cx, |filter, _cx| {
172            filter.show_namespace(Self::NAMESPACE);
173        });
174    }
175}
176
177pub fn init(
178    fs: Arc<dyn Fs>,
179    client: Arc<Client>,
180    stdout_is_a_pty: bool,
181    cx: &mut AppContext,
182) -> Arc<PromptBuilder> {
183    cx.set_global(Assistant::default());
184    AssistantSettings::register(cx);
185    SlashCommandSettings::register(cx);
186
187    // TODO: remove this when 0.148.0 is released.
188    if AssistantSettings::get_global(cx).using_outdated_settings_version {
189        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
190            let fs = fs.clone();
191            |content, cx| {
192                content.update_file(fs, cx);
193            }
194        });
195    }
196
197    cx.spawn(|mut cx| {
198        let client = client.clone();
199        async move {
200            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
201            let semantic_index = SemanticIndex::new(
202                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
203                Arc::new(embedding_provider),
204                &mut cx,
205            )
206            .await?;
207            cx.update(|cx| cx.set_global(semantic_index))
208        }
209    })
210    .detach();
211
212    context_store::init(&client);
213    prompt_library::init(cx);
214    init_language_model_settings(cx);
215    assistant_slash_command::init(cx);
216    assistant_panel::init(cx);
217    context_servers::init(cx);
218
219    let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
220        fs: fs.clone(),
221        repo_path: stdout_is_a_pty
222            .then(|| std::env::current_dir().log_err())
223            .flatten(),
224        cx,
225    }))
226    .log_err()
227    .map(Arc::new)
228    .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
229    register_slash_commands(Some(prompt_builder.clone()), cx);
230    inline_assistant::init(
231        fs.clone(),
232        prompt_builder.clone(),
233        client.telemetry().clone(),
234        cx,
235    );
236    terminal_inline_assistant::init(
237        fs.clone(),
238        prompt_builder.clone(),
239        client.telemetry().clone(),
240        cx,
241    );
242    IndexedDocsRegistry::init_global(cx);
243
244    CommandPaletteFilter::update_global(cx, |filter, _cx| {
245        filter.hide_namespace(Assistant::NAMESPACE);
246    });
247    Assistant::update_global(cx, |assistant, cx| {
248        let settings = AssistantSettings::get_global(cx);
249
250        assistant.set_enabled(settings.enabled, cx);
251    });
252    cx.observe_global::<SettingsStore>(|cx| {
253        Assistant::update_global(cx, |assistant, cx| {
254            let settings = AssistantSettings::get_global(cx);
255            assistant.set_enabled(settings.enabled, cx);
256        });
257    })
258    .detach();
259
260    register_context_server_handlers(cx);
261
262    prompt_builder
263}
264
265fn register_context_server_handlers(cx: &mut AppContext) {
266    cx.subscribe(
267        &context_servers::manager::ContextServerManager::global(cx),
268        |manager, event, cx| match event {
269            context_servers::manager::Event::ServerStarted { server_id } => {
270                cx.update_model(
271                    &manager,
272                    |manager: &mut context_servers::manager::ContextServerManager, cx| {
273                        let slash_command_registry = SlashCommandRegistry::global(cx);
274                        let context_server_registry = ContextServerRegistry::global(cx);
275                        if let Some(server) = manager.get_server(server_id) {
276                            cx.spawn(|_, _| async move {
277                                let Some(protocol) = server.client.read().clone() else {
278                                    return;
279                                };
280
281                                if let Some(prompts) = protocol.list_prompts().await.log_err() {
282                                    for prompt in prompts
283                                        .into_iter()
284                                        .filter(context_server_command::acceptable_prompt)
285                                    {
286                                        log::info!(
287                                            "registering context server command: {:?}",
288                                            prompt.name
289                                        );
290                                        context_server_registry.register_command(
291                                            server.id.clone(),
292                                            prompt.name.as_str(),
293                                        );
294                                        slash_command_registry.register_command(
295                                            context_server_command::ContextServerSlashCommand::new(
296                                                &server, prompt,
297                                            ),
298                                            true,
299                                        );
300                                    }
301                                }
302                            })
303                            .detach();
304                        }
305                    },
306                );
307            }
308            context_servers::manager::Event::ServerStopped { server_id } => {
309                let slash_command_registry = SlashCommandRegistry::global(cx);
310                let context_server_registry = ContextServerRegistry::global(cx);
311                if let Some(commands) = context_server_registry.get_commands(server_id) {
312                    for command_name in commands {
313                        slash_command_registry.unregister_command_by_name(&command_name);
314                        context_server_registry.unregister_command(&server_id, &command_name);
315                    }
316                }
317            }
318        },
319    )
320    .detach();
321}
322
323fn init_language_model_settings(cx: &mut AppContext) {
324    update_active_language_model_from_settings(cx);
325
326    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
327        .detach();
328    cx.subscribe(
329        &LanguageModelRegistry::global(cx),
330        |_, event: &language_model::Event, cx| match event {
331            language_model::Event::ProviderStateChanged
332            | language_model::Event::AddedProvider(_)
333            | language_model::Event::RemovedProvider(_) => {
334                update_active_language_model_from_settings(cx);
335            }
336            _ => {}
337        },
338    )
339    .detach();
340}
341
342fn update_active_language_model_from_settings(cx: &mut AppContext) {
343    let settings = AssistantSettings::get_global(cx);
344    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
345    let model_id = LanguageModelId::from(settings.default_model.model.clone());
346    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
347        registry.select_active_model(&provider_name, &model_id, cx);
348    });
349}
350
351fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
352    let slash_command_registry = SlashCommandRegistry::global(cx);
353    slash_command_registry.register_command(file_command::FileSlashCommand, true);
354    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
355    slash_command_registry.register_command(tab_command::TabSlashCommand, true);
356    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
357    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
358    slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
359    slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
360    slash_command_registry.register_command(now_command::NowSlashCommand, false);
361    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
362
363    if let Some(prompt_builder) = prompt_builder {
364        slash_command_registry.register_command(
365            workflow_command::WorkflowSlashCommand::new(prompt_builder),
366            true,
367        );
368    }
369    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
370
371    update_slash_commands_from_settings(cx);
372    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
373        .detach();
374
375    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
376        let slash_command_registry = slash_command_registry.clone();
377        move |is_enabled, _cx| {
378            if is_enabled {
379                slash_command_registry.register_command(search_command::SearchSlashCommand, true);
380            }
381        }
382    })
383    .detach();
384}
385
386fn update_slash_commands_from_settings(cx: &mut AppContext) {
387    let slash_command_registry = SlashCommandRegistry::global(cx);
388    let settings = SlashCommandSettings::get_global(cx);
389
390    if settings.docs.enabled {
391        slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
392    } else {
393        slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
394    }
395
396    if settings.project.enabled {
397        slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
398    } else {
399        slash_command_registry.unregister_command(project_command::ProjectSlashCommand);
400    }
401}
402
/// Formats a token count as a short, human-readable string.
///
/// * Counts below 1,000 are printed verbatim (`999` -> `"999"`).
/// * Counts from 1,000 to 9,999 are rounded half-up to the nearest hundred and
///   shown with one decimal, omitting a `.0` (`1_234` -> `"1.2k"`,
///   `1_950` -> `"2k"`, `9_999` -> `"10k"`).
/// * Larger counts are rounded half-up to the nearest thousand
///   (`10_500` -> `"11k"`).
///
/// Never overflows, including for `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded to the nearest hundred; 10 means it rounded up
            // into the next whole thousand.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Same result as `(count + 500) / 1000`, but the addition could
            // overflow when `count` is within 500 of `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
420
/// Test-only logger setup, run at program load via `ctor`.
///
/// `env_logger` is initialized only when `RUST_LOG` is present and valid
/// Unicode (`std::env::var` returns `Ok`), so test output stays quiet otherwise.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_present = std::env::var("RUST_LOG").is_ok();
    if rust_log_present {
        env_logger::init();
    }
}