assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4pub mod assistant_settings;
  5mod context;
  6pub mod context_store;
  7mod inline_assistant;
  8mod model_selector;
  9mod patch;
 10mod prompt_library;
 11mod prompts;
 12mod slash_command;
 13pub(crate) mod slash_command_picker;
 14pub mod slash_command_settings;
 15mod streaming_diff;
 16mod terminal_inline_assistant;
 17mod tools;
 18
 19pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 20use assistant_settings::AssistantSettings;
 21use assistant_slash_command::SlashCommandRegistry;
 22use assistant_tool::ToolRegistry;
 23use client::{proto, Client};
 24use command_palette_hooks::CommandPaletteFilter;
 25pub use context::*;
 26use context_servers::ContextServerRegistry;
 27pub use context_store::*;
 28use feature_flags::FeatureFlagAppExt;
 29use fs::Fs;
 30use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 31use gpui::{impl_actions, Context as _};
 32use indexed_docs::IndexedDocsRegistry;
 33pub(crate) use inline_assistant::*;
 34use language_model::{
 35    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 36};
 37pub(crate) use model_selector::*;
 38pub use patch::*;
 39pub use prompts::PromptBuilder;
 40use prompts::PromptLoadingParams;
 41use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 42use serde::{Deserialize, Serialize};
 43use settings::{update_settings_file, Settings, SettingsStore};
 44use slash_command::{
 45    auto_command, cargo_workspace_command, context_server_command, default_command, delta_command,
 46    diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command,
 47    prompt_command, search_command, selection_command, symbols_command, tab_command,
 48    terminal_command,
 49};
 50use std::path::PathBuf;
 51use std::sync::Arc;
 52pub(crate) use streaming_diff::*;
 53use util::ResultExt;
 54
 55use crate::slash_command::streaming_example_command;
 56use crate::slash_command_settings::SlashCommandSettings;
 57
// Action types for the assistant, all registered under the `assistant`
// namespace. This is the same namespace that `Assistant::NAMESPACE` hides or
// shows in the command palette depending on whether the assistant is enabled.
actions!(
    assistant,
    [
        Assist,
        Edit,
        Split,
        CopyCode,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        NewContext,
        ToggleModelSelector,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
 79
/// Action carrying a list of dragged file paths.
///
/// The variant records where the paths came from. NOTE(review): presumably
/// the handler inserts these files into the assistant context — confirm at
/// the action's handler.
#[derive(PartialEq, Clone, Deserialize)]
pub enum InsertDraggedFiles {
    /// Paths dragged from within the project. NOTE(review): whether these are
    /// project-relative or absolute is not visible here — confirm.
    ProjectPaths(Vec<PathBuf>),
    /// Files dragged in from outside the project (assumption based on the name).
    ExternalFiles(Vec<PathBuf>),
}

// Registers `InsertDraggedFiles` as a deserializable action in the
// `assistant` namespace.
impl_actions!(assistant, [InsertDraggedFiles]);
 87
/// Default number of context lines.
// NOTE(review): not referenced in this file; presumably consumed by a child
// module via `super::` — confirm its exact meaning at the use site.
const DEFAULT_CONTEXT_LINES: usize = 50;
 89
 90#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 91pub struct MessageId(clock::Lamport);
 92
 93impl MessageId {
 94    pub fn as_u64(self) -> u64 {
 95        self.0.as_u64()
 96    }
 97}
 98
/// Token-usage counts deserialized from a language-model response.
// NOTE(review): field names presumably mirror an OpenAI-style `usage` object
// — confirm against the provider payloads that deserialize into this type.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens attributed to the prompt (input).
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion (output).
    pub completion_tokens: u32,
    /// Total tokens as reported by the provider; presumably
    /// `prompt_tokens + completion_tokens`, but that is not enforced here.
    pub total_tokens: u32,
}
105
/// One incremental "choice" chunk of a language-model response.
// NOTE(review): shape presumably mirrors an OpenAI-style streaming
// `choices[]` entry — confirm against the providers that produce it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// Partial message content carried by this chunk.
    pub delta: LanguageModelResponseMessage,
    /// Why generation stopped, when the provider reports it; `None` otherwise
    /// (presumably `None` on all but the final chunk — confirm).
    pub finish_reason: Option<String>,
}
112
/// Status of an assistant message.
///
/// Converted to and from `proto::ContextMessageStatus` by
/// `MessageStatus::from_proto` / `MessageStatus::to_proto`.
// Variants derive `Serialize`/`Deserialize`. Non-self-describing serde formats
// encode enums by variant index, so avoid reordering them.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    /// Not finished yet. This is also the status used when a proto message
    /// carries no variant.
    Pending,
    /// Finished successfully.
    Done,
    /// Failed; carries the error message text.
    Error(SharedString),
    /// Canceled before completion.
    Canceled,
}
120
121impl MessageStatus {
122    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
123        match status.variant {
124            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
125            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
126            Some(proto::context_message_status::Variant::Error(error)) => {
127                MessageStatus::Error(error.message.into())
128            }
129            Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
130            None => MessageStatus::Pending,
131        }
132    }
133
134    pub fn to_proto(&self) -> proto::ContextMessageStatus {
135        match self {
136            MessageStatus::Pending => proto::ContextMessageStatus {
137                variant: Some(proto::context_message_status::Variant::Pending(
138                    proto::context_message_status::Pending {},
139                )),
140            },
141            MessageStatus::Done => proto::ContextMessageStatus {
142                variant: Some(proto::context_message_status::Variant::Done(
143                    proto::context_message_status::Done {},
144                )),
145            },
146            MessageStatus::Error(message) => proto::ContextMessageStatus {
147                variant: Some(proto::context_message_status::Variant::Error(
148                    proto::context_message_status::Error {
149                        message: message.to_string(),
150                    },
151                )),
152            },
153            MessageStatus::Canceled => proto::ContextMessageStatus {
154                variant: Some(proto::context_message_status::Variant::Canceled(
155                    proto::context_message_status::Canceled {},
156                )),
157            },
158        }
159    }
160}
161
162/// The state pertaining to the Assistant.
163#[derive(Default)]
164struct Assistant {
165    /// Whether the Assistant is enabled.
166    enabled: bool,
167}
168
169impl Global for Assistant {}
170
171impl Assistant {
172    const NAMESPACE: &'static str = "assistant";
173
174    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
175        if self.enabled == enabled {
176            return;
177        }
178
179        self.enabled = enabled;
180
181        if !enabled {
182            CommandPaletteFilter::update_global(cx, |filter, _cx| {
183                filter.hide_namespace(Self::NAMESPACE);
184            });
185
186            return;
187        }
188
189        CommandPaletteFilter::update_global(cx, |filter, _cx| {
190            filter.show_namespace(Self::NAMESPACE);
191        });
192    }
193}
194
/// Initializes the assistant crate and returns the shared [`PromptBuilder`].
///
/// This function:
/// - registers the assistant's settings and globals, slash commands, tools,
///   and inline assistants;
/// - subscribes to settings and context-server events.
///
/// The prompt builder is created with `fs` and, when `stdout_is_a_pty` is
/// true, the current working directory as `repo_path`. If that fails, the
/// error is logged and a builder from `PromptBuilder::new(None)` is used
/// instead.
///
/// # Panics
///
/// Panics if the fallback `PromptBuilder::new(None)` also returns an error.
pub fn init(
    fs: Arc<dyn Fs>,
    client: Arc<Client>,
    stdout_is_a_pty: bool,
    cx: &mut AppContext,
) -> Arc<PromptBuilder> {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);
    SlashCommandSettings::register(cx);

    // TODO: remove this when 0.148.0 is released.
    // When the loaded assistant settings use an outdated version, rewrite the
    // settings file via `update_file`.
    if AssistantSettings::get_global(cx).using_outdated_settings_version {
        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
            let fs = fs.clone();
            |content, cx| {
                content.update_file(fs, cx);
            }
        });
    }

    // Open the semantic index DB under `paths::embeddings_dir()` in the
    // background and install it as a global once it is ready. A failure in
    // `SemanticDb::new` ends the task via `?`.
    // NOTE(review): the task is only `.detach()`ed, so that error is not
    // logged here — confirm this is intended.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticDb::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;

            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    context_store::init(&client.clone().into());
    prompt_library::init(cx);
    init_language_model_settings(cx);
    assistant_slash_command::init(cx);
    assistant_tool::init(cx);
    assistant_panel::init(cx);
    context_servers::init(cx);

    // `repo_path` is only supplied when stdout is a PTY (presumably a
    // developer running from a checkout — confirm in `prompts`). A failure to
    // read the current dir is logged and treated as `None`.
    let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
        fs: fs.clone(),
        repo_path: stdout_is_a_pty
            .then(|| std::env::current_dir().log_err())
            .flatten(),
        cx,
    }))
    .log_err()
    .map(Arc::new)
    .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
    register_slash_commands(Some(prompt_builder.clone()), cx);
    register_tools(cx);
    inline_assistant::init(
        fs.clone(),
        prompt_builder.clone(),
        client.telemetry().clone(),
        cx,
    );
    terminal_inline_assistant::init(
        fs.clone(),
        prompt_builder.clone(),
        client.telemetry().clone(),
        cx,
    );
    IndexedDocsRegistry::init_global(cx);

    // Hide the namespace explicitly first. `Assistant::default()` starts with
    // `enabled == false`, so `set_enabled(false, ..)` would return early
    // without ever hiding it.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply the `enabled` setting whenever the settings store changes.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();

    register_context_server_handlers(cx);

    prompt_builder
}
285
/// Keeps the slash-command and tool registries in sync with the context
/// servers managed by the global `ContextServerManager`.
///
/// - On `ServerStarted`, two detached tasks are spawned:
///   - The first, if the server advertises the `Prompts` capability, lists
///     its prompts and registers each one accepted by
///     `context_server_command::acceptable_prompt` as a
///     `ContextServerSlashCommand`.
///   - The second, if the server advertises `Tools`, lists its tools and
///     registers each as a `ContextServerTool`.
///
///   Every registered name is also recorded in the `ContextServerRegistry`
///   under the server's id.
/// - On `ServerStopped`, the names recorded for that server are removed from
///   the slash-command/tool registries and from the `ContextServerRegistry`.
///
/// Errors from listing prompts or tools are logged (`log_err`) and otherwise
/// ignored.
fn register_context_server_handlers(cx: &mut AppContext) {
    cx.subscribe(
        &context_servers::manager::ContextServerManager::global(cx),
        |manager, event, cx| match event {
            context_servers::manager::Event::ServerStarted { server_id } => {
                // Prompts -> slash commands.
                cx.update_model(
                    &manager,
                    |manager: &mut context_servers::manager::ContextServerManager, cx| {
                        let slash_command_registry = SlashCommandRegistry::global(cx);
                        let context_server_registry = ContextServerRegistry::global(cx);
                        if let Some(server) = manager.get_server(server_id) {
                            cx.spawn(|_, _| async move {
                                // No protocol client available for this server: nothing to register.
                                let Some(protocol) = server.client.read().clone() else {
                                    return;
                                };

                                if protocol.capable(context_servers::protocol::ServerCapability::Prompts) {
                                    if let Some(prompts) = protocol.list_prompts().await.log_err() {
                                        for prompt in prompts
                                            .into_iter()
                                            .filter(context_server_command::acceptable_prompt)
                                        {
                                            log::info!(
                                                "registering context server command: {:?}",
                                                prompt.name
                                            );
                                            // Record the name so `ServerStopped` can unregister it.
                                            context_server_registry.register_command(
                                                server.id.clone(),
                                                prompt.name.as_str(),
                                            );
                                            slash_command_registry.register_command(
                                                context_server_command::ContextServerSlashCommand::new(
                                                    &server, prompt,
                                                ),
                                                true,
                                            );
                                        }
                                    }
                                }
                            })
                            .detach();
                        }
                    },
                );

                // Tools -> tool registry.
                cx.update_model(
                    &manager,
                    |manager: &mut context_servers::manager::ContextServerManager, cx| {
                        let tool_registry = ToolRegistry::global(cx);
                        let context_server_registry = ContextServerRegistry::global(cx);
                        if let Some(server) = manager.get_server(server_id) {
                            cx.spawn(|_, _| async move {
                                // No protocol client available for this server: nothing to register.
                                let Some(protocol) = server.client.read().clone() else {
                                    return;
                                };

                                if protocol.capable(context_servers::protocol::ServerCapability::Tools) {
                                    if let Some(tools) = protocol.list_tools().await.log_err() {
                                        for tool in tools.tools {
                                            log::info!(
                                                "registering context server tool: {:?}",
                                                tool.name
                                            );
                                            // Record the name so `ServerStopped` can unregister it.
                                            context_server_registry.register_tool(
                                                server.id.clone(),
                                                tool.name.as_str(),
                                            );
                                            // `tools::...` below resolves to the crate's `tools`
                                            // module, not the local `tools` binding above:
                                            // path prefixes are looked up in the type/module
                                            // namespace, which locals do not occupy.
                                            tool_registry.register_tool(
                                                tools::context_server_tool::ContextServerTool::new(
                                                    server.id.clone(),
                                                    tool
                                                ),
                                            );
                                        }
                                    }
                                }
                            })
                            .detach();
                        }
                    },
                );
            }
            context_servers::manager::Event::ServerStopped { server_id } => {
                let slash_command_registry = SlashCommandRegistry::global(cx);
                let context_server_registry = ContextServerRegistry::global(cx);
                // Undo every slash command recorded for this server.
                if let Some(commands) = context_server_registry.get_commands(server_id) {
                    for command_name in commands {
                        slash_command_registry.unregister_command_by_name(&command_name);
                        context_server_registry.unregister_command(&server_id, &command_name);
                    }
                }

                // Undo every tool recorded for this server.
                if let Some(tools) = context_server_registry.get_tools(server_id) {
                    let tool_registry = ToolRegistry::global(cx);
                    for tool_name in tools {
                        tool_registry.unregister_tool_by_name(&tool_name);
                        context_server_registry.unregister_tool(&server_id, &tool_name);
                    }
                }
            }
        },
    )
    .detach();
}
390
391fn init_language_model_settings(cx: &mut AppContext) {
392    update_active_language_model_from_settings(cx);
393
394    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
395        .detach();
396    cx.subscribe(
397        &LanguageModelRegistry::global(cx),
398        |_, event: &language_model::Event, cx| match event {
399            language_model::Event::ProviderStateChanged
400            | language_model::Event::AddedProvider(_)
401            | language_model::Event::RemovedProvider(_) => {
402                update_active_language_model_from_settings(cx);
403            }
404            _ => {}
405        },
406    )
407    .detach();
408}
409
410fn update_active_language_model_from_settings(cx: &mut AppContext) {
411    let settings = AssistantSettings::get_global(cx);
412    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
413    let model_id = LanguageModelId::from(settings.default_model.model.clone());
414    let inline_alternatives = settings
415        .inline_alternatives
416        .iter()
417        .map(|alternative| {
418            (
419                LanguageModelProviderId::from(alternative.provider.clone()),
420                LanguageModelId::from(alternative.model.clone()),
421            )
422        })
423        .collect::<Vec<_>>();
424    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
425        registry.select_active_model(&provider_name, &model_id, cx);
426        registry.select_inline_alternative_models(inline_alternatives, cx);
427    });
428}
429
430fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
431    let slash_command_registry = SlashCommandRegistry::global(cx);
432
433    slash_command_registry.register_command(file_command::FileSlashCommand, true);
434    slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
435    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
436    slash_command_registry.register_command(tab_command::TabSlashCommand, true);
437    slash_command_registry
438        .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
439    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
440    slash_command_registry.register_command(selection_command::SelectionCommand, true);
441    slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
442    slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
443    slash_command_registry.register_command(now_command::NowSlashCommand, false);
444    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
445    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
446    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
447
448    if let Some(prompt_builder) = prompt_builder {
449        cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
450            let slash_command_registry = slash_command_registry.clone();
451            move |is_enabled, _cx| {
452                if is_enabled {
453                    slash_command_registry.register_command(
454                        project_command::ProjectSlashCommand::new(prompt_builder.clone()),
455                        true,
456                    );
457                }
458            }
459        })
460        .detach();
461    }
462
463    cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
464        let slash_command_registry = slash_command_registry.clone();
465        move |is_enabled, _cx| {
466            if is_enabled {
467                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
468                slash_command_registry.register_command(auto_command::AutoCommand, true);
469            }
470        }
471    })
472    .detach();
473
474    cx.observe_flag::<streaming_example_command::StreamingExampleSlashCommandFeatureFlag, _>({
475        let slash_command_registry = slash_command_registry.clone();
476        move |is_enabled, _cx| {
477            if is_enabled {
478                slash_command_registry.register_command(
479                    streaming_example_command::StreamingExampleSlashCommand,
480                    false,
481                );
482            }
483        }
484    })
485    .detach();
486
487    update_slash_commands_from_settings(cx);
488    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
489        .detach();
490
491    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
492        let slash_command_registry = slash_command_registry.clone();
493        move |is_enabled, _cx| {
494            if is_enabled {
495                slash_command_registry.register_command(search_command::SearchSlashCommand, true);
496            }
497        }
498    })
499    .detach();
500}
501
502fn update_slash_commands_from_settings(cx: &mut AppContext) {
503    let slash_command_registry = SlashCommandRegistry::global(cx);
504    let settings = SlashCommandSettings::get_global(cx);
505
506    if settings.docs.enabled {
507        slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
508    } else {
509        slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
510    }
511
512    if settings.cargo_workspace.enabled {
513        slash_command_registry
514            .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
515    } else {
516        slash_command_registry
517            .unregister_command(cargo_workspace_command::CargoWorkspaceSlashCommand);
518    }
519}
520
521fn register_tools(cx: &mut AppContext) {
522    let tool_registry = ToolRegistry::global(cx);
523    tool_registry.register_tool(tools::now_tool::NowTool);
524}
525
/// Formats a token count compactly for display.
///
/// - `0..=999`: the exact number (`"999"`).
/// - `1000..=9999`: thousands with one decimal, rounded half-up to the
///   nearest hundred, with a trailing `.0` omitted (`"1k"`, `"1.5k"`,
///   `"9.9k"`). Values that round up to ten thousand become `"10k"`.
/// - `10000` and above: rounded half-up to the nearest thousand (`"10k"`,
///   `"128k"`).
///
/// The arithmetic cannot overflow, so very large inputs (up to `usize::MAX`)
/// are formatted instead of panicking in debug builds.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded half-up to whole hundreds; result is in 0..=10.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Equivalent to `(count + 500) / 1000`, but without the addition,
            // which overflows for counts within 500 of `usize::MAX`.
            let thousands = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", thousands)
        }
    }
}
543
/// Test-only logger setup, run at load time via `ctor`.
///
/// Initializes `env_logger` only when `RUST_LOG` is set to a valid Unicode
/// value; otherwise logging stays uninitialized.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_set = std::env::var("RUST_LOG").is_ok();
    if !rust_log_set {
        return;
    }
    env_logger::init();
}