assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4pub mod assistant_settings;
  5mod context;
  6pub mod context_store;
  7mod inline_assistant;
  8mod model_selector;
  9mod patch;
 10mod prompt_library;
 11mod prompts;
 12mod slash_command;
 13pub(crate) mod slash_command_picker;
 14pub mod slash_command_settings;
 15mod streaming_diff;
 16mod terminal_inline_assistant;
 17mod tools;
 18
 19pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 20use assistant_settings::AssistantSettings;
 21use assistant_slash_command::SlashCommandRegistry;
 22use assistant_tool::ToolRegistry;
 23use client::{proto, Client};
 24use command_palette_hooks::CommandPaletteFilter;
 25pub use context::*;
 26use context_servers::ContextServerRegistry;
 27pub use context_store::*;
 28use feature_flags::FeatureFlagAppExt;
 29use fs::Fs;
 30use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 31use gpui::{impl_actions, Context as _};
 32use indexed_docs::IndexedDocsRegistry;
 33pub(crate) use inline_assistant::*;
 34use language_model::{
 35    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 36};
 37pub(crate) use model_selector::*;
 38pub use patch::*;
 39pub use prompts::PromptBuilder;
 40use prompts::PromptLoadingParams;
 41use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 42use serde::{Deserialize, Serialize};
 43use settings::{update_settings_file, Settings, SettingsStore};
 44use slash_command::search_command::SearchSlashCommandFeatureFlag;
 45use slash_command::{
 46    auto_command, cargo_workspace_command, context_server_command, default_command, delta_command,
 47    diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command,
 48    prompt_command, search_command, selection_command, symbols_command, tab_command,
 49    terminal_command,
 50};
 51use std::path::PathBuf;
 52use std::sync::Arc;
 53pub(crate) use streaming_diff::*;
 54use util::ResultExt;
 55
 56use crate::slash_command::streaming_example_command;
 57use crate::slash_command_settings::SlashCommandSettings;
 58
// Unit actions in the `assistant` namespace. The namespace name matches
// `Assistant::NAMESPACE` ("assistant"), which `Assistant::set_enabled` passes
// to the `CommandPaletteFilter` to show or hide this namespace as a whole.
actions!(
    assistant,
    [
        Assist,
        Edit,
        Split,
        CopyCode,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        NewContext,
        ToggleModelSelector,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
 80
/// Action carrying file paths that were dragged onto the assistant.
///
/// It derives `Deserialize` and is registered with `impl_actions!`, so unlike
/// the unit actions above it carries data.
// NOTE(review): the handler lives elsewhere; presumably it inserts the files'
// contents into the assistant context — confirm.
#[derive(PartialEq, Clone, Deserialize)]
pub enum InsertDraggedFiles {
    /// Paths dragged from within the project (presumably from a worktree — confirm).
    ProjectPaths(Vec<PathBuf>),
    /// Paths dragged in from outside the project (presumably from the OS — confirm).
    ExternalFiles(Vec<PathBuf>),
}

impl_actions!(assistant, [InsertDraggedFiles]);
 88
/// Default line count used for surrounding context.
// NOTE(review): not referenced in this file; presumably used by a child module
// as the number of lines of surrounding context to include — confirm.
const DEFAULT_CONTEXT_LINES: usize = 50;
 90
 91#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 92pub struct MessageId(clock::Lamport);
 93
 94impl MessageId {
 95    pub fn as_u64(self) -> u64 {
 96        self.0.as_u64()
 97    }
 98}
 99
// NOTE(review): not referenced in this file; the field descriptions below are
// inferred from their names — confirm against the response format that is
// deserialized into this type.
/// Token-usage figures deserialized from a language-model response.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the generated completion.
    pub completion_tokens: u32,
    /// Overall token count; presumably `prompt_tokens + completion_tokens` (not enforced here).
    pub total_tokens: u32,
}
106
// NOTE(review): not referenced in this file; this looks like one streamed
// "choice" chunk of a chat-completion style response — confirm against callers.
/// One incremental piece of a streamed language-model response.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// Message content carried by this delta.
    pub delta: LanguageModelResponseMessage,
    /// Provider-supplied reason generation finished, if present.
    // NOTE(review): presumably only set on the final delta of a choice — confirm.
    pub finish_reason: Option<String>,
}
113
114#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
115pub enum MessageStatus {
116    Pending,
117    Done,
118    Error(SharedString),
119    Canceled,
120}
121
122impl MessageStatus {
123    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
124        match status.variant {
125            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
126            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
127            Some(proto::context_message_status::Variant::Error(error)) => {
128                MessageStatus::Error(error.message.into())
129            }
130            Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
131            None => MessageStatus::Pending,
132        }
133    }
134
135    pub fn to_proto(&self) -> proto::ContextMessageStatus {
136        match self {
137            MessageStatus::Pending => proto::ContextMessageStatus {
138                variant: Some(proto::context_message_status::Variant::Pending(
139                    proto::context_message_status::Pending {},
140                )),
141            },
142            MessageStatus::Done => proto::ContextMessageStatus {
143                variant: Some(proto::context_message_status::Variant::Done(
144                    proto::context_message_status::Done {},
145                )),
146            },
147            MessageStatus::Error(message) => proto::ContextMessageStatus {
148                variant: Some(proto::context_message_status::Variant::Error(
149                    proto::context_message_status::Error {
150                        message: message.to_string(),
151                    },
152                )),
153            },
154            MessageStatus::Canceled => proto::ContextMessageStatus {
155                variant: Some(proto::context_message_status::Variant::Canceled(
156                    proto::context_message_status::Canceled {},
157                )),
158            },
159        }
160    }
161}
162
163/// The state pertaining to the Assistant.
164#[derive(Default)]
165struct Assistant {
166    /// Whether the Assistant is enabled.
167    enabled: bool,
168}
169
170impl Global for Assistant {}
171
172impl Assistant {
173    const NAMESPACE: &'static str = "assistant";
174
175    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
176        if self.enabled == enabled {
177            return;
178        }
179
180        self.enabled = enabled;
181
182        if !enabled {
183            CommandPaletteFilter::update_global(cx, |filter, _cx| {
184                filter.hide_namespace(Self::NAMESPACE);
185            });
186
187            return;
188        }
189
190        CommandPaletteFilter::update_global(cx, |filter, _cx| {
191            filter.show_namespace(Self::NAMESPACE);
192        });
193    }
194}
195
196pub fn init(
197    fs: Arc<dyn Fs>,
198    client: Arc<Client>,
199    stdout_is_a_pty: bool,
200    cx: &mut AppContext,
201) -> Arc<PromptBuilder> {
202    cx.set_global(Assistant::default());
203    AssistantSettings::register(cx);
204    SlashCommandSettings::register(cx);
205
206    // TODO: remove this when 0.148.0 is released.
207    if AssistantSettings::get_global(cx).using_outdated_settings_version {
208        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
209            let fs = fs.clone();
210            |content, cx| {
211                content.update_file(fs, cx);
212            }
213        });
214    }
215
216    if cx.has_flag::<SearchSlashCommandFeatureFlag>() {
217        cx.spawn(|mut cx| {
218            let client = client.clone();
219            async move {
220                let embedding_provider = CloudEmbeddingProvider::new(client.clone());
221                let semantic_index = SemanticDb::new(
222                    paths::embeddings_dir().join("semantic-index-db.0.mdb"),
223                    Arc::new(embedding_provider),
224                    &mut cx,
225                )
226                .await?;
227
228                cx.update(|cx| cx.set_global(semantic_index))
229            }
230        })
231        .detach();
232    }
233
234    context_store::init(&client.clone().into());
235    prompt_library::init(cx);
236    init_language_model_settings(cx);
237    assistant_slash_command::init(cx);
238    assistant_tool::init(cx);
239    assistant_panel::init(cx);
240    context_servers::init(cx);
241
242    let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
243        fs: fs.clone(),
244        repo_path: stdout_is_a_pty
245            .then(|| std::env::current_dir().log_err())
246            .flatten(),
247        cx,
248    }))
249    .log_err()
250    .map(Arc::new)
251    .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
252    register_slash_commands(Some(prompt_builder.clone()), cx);
253    register_tools(cx);
254    inline_assistant::init(
255        fs.clone(),
256        prompt_builder.clone(),
257        client.telemetry().clone(),
258        cx,
259    );
260    terminal_inline_assistant::init(
261        fs.clone(),
262        prompt_builder.clone(),
263        client.telemetry().clone(),
264        cx,
265    );
266    IndexedDocsRegistry::init_global(cx);
267
268    CommandPaletteFilter::update_global(cx, |filter, _cx| {
269        filter.hide_namespace(Assistant::NAMESPACE);
270    });
271    Assistant::update_global(cx, |assistant, cx| {
272        let settings = AssistantSettings::get_global(cx);
273
274        assistant.set_enabled(settings.enabled, cx);
275    });
276    cx.observe_global::<SettingsStore>(|cx| {
277        Assistant::update_global(cx, |assistant, cx| {
278            let settings = AssistantSettings::get_global(cx);
279            assistant.set_enabled(settings.enabled, cx);
280        });
281    })
282    .detach();
283
284    register_context_server_handlers(cx);
285
286    prompt_builder
287}
288
/// Subscribes to events from the global `ContextServerManager` and keeps the
/// global slash-command and tool registries in step with running servers.
///
/// * `ServerStarted` runs two passes over the started server, each spawning a
///   detached task:
///   * If the server advertises the `Prompts` capability, each prompt accepted
///     by `context_server_command::acceptable_prompt` is registered as a slash
///     command.
///   * If it advertises `Tools`, each listed tool is registered as a
///     `ContextServerTool`.
///
///   Every registration is also recorded in the `ContextServerRegistry` under
///   the server's id.
/// * `ServerStopped` reads the names recorded in the `ContextServerRegistry`
///   for that server. It unregisters each command and tool from its global
///   registry and removes the record.
fn register_context_server_handlers(cx: &mut AppContext) {
    cx.subscribe(
        &context_servers::manager::ContextServerManager::global(cx),
        |manager, event, cx| match event {
            context_servers::manager::Event::ServerStarted { server_id } => {
                // Pass 1: server prompts -> slash commands.
                cx.update_model(
                    &manager,
                    |manager: &mut context_servers::manager::ContextServerManager, cx| {
                        let slash_command_registry = SlashCommandRegistry::global(cx);
                        let context_server_registry = ContextServerRegistry::global(cx);
                        if let Some(server) = manager.get_server(server_id) {
                            cx.spawn(|_, _| async move {
                                // Without a protocol client nothing is registered.
                                // NOTE(review): presumably the client is set once the
                                // server has finished initializing — confirm.
                                let Some(protocol) = server.client.read().clone() else {
                                    return;
                                };

                                if protocol.capable(context_servers::protocol::ServerCapability::Prompts) {
                                    // `log_err` turns a listing failure into `None`,
                                    // in which case no prompts are registered.
                                    if let Some(prompts) = protocol.list_prompts().await.log_err() {
                                        for prompt in prompts
                                            .into_iter()
                                            .filter(context_server_command::acceptable_prompt)
                                        {
                                            log::info!(
                                                "registering context server command: {:?}",
                                                prompt.name
                                            );
                                            // Record the name so `ServerStopped` can undo it.
                                            context_server_registry.register_command(
                                                server.id.clone(),
                                                prompt.name.as_str(),
                                            );
                                            slash_command_registry.register_command(
                                                context_server_command::ContextServerSlashCommand::new(
                                                    &server, prompt,
                                                ),
                                                true,
                                            );
                                        }
                                    }
                                }
                            })
                            .detach();
                        }
                    },
                );

                // Pass 2: server tools -> tool registry.
                cx.update_model(
                    &manager,
                    |manager: &mut context_servers::manager::ContextServerManager, cx| {
                        let tool_registry = ToolRegistry::global(cx);
                        let context_server_registry = ContextServerRegistry::global(cx);
                        if let Some(server) = manager.get_server(server_id) {
                            cx.spawn(|_, _| async move {
                                let Some(protocol) = server.client.read().clone() else {
                                    return;
                                };

                                if protocol.capable(context_servers::protocol::ServerCapability::Tools) {
                                    // NOTE(review): this local `tools` shares its name with
                                    // the `tools` module; the `tools::context_server_tool`
                                    // path below still resolves to the module.
                                    if let Some(tools) = protocol.list_tools().await.log_err() {
                                        for tool in tools.tools {
                                            log::info!(
                                                "registering context server tool: {:?}",
                                                tool.name
                                            );
                                            // Record the name so `ServerStopped` can undo it.
                                            context_server_registry.register_tool(
                                                server.id.clone(),
                                                tool.name.as_str(),
                                            );
                                            tool_registry.register_tool(
                                                tools::context_server_tool::ContextServerTool::new(
                                                    server.id.clone(),
                                                    tool
                                                ),
                                            );
                                        }
                                    }
                                }
                            })
                            .detach();
                        }
                    },
                );
            }
            context_servers::manager::Event::ServerStopped { server_id } => {
                let slash_command_registry = SlashCommandRegistry::global(cx);
                let context_server_registry = ContextServerRegistry::global(cx);
                // Undo the slash-command registrations recorded for this server.
                if let Some(commands) = context_server_registry.get_commands(server_id) {
                    for command_name in commands {
                        slash_command_registry.unregister_command_by_name(&command_name);
                        context_server_registry.unregister_command(&server_id, &command_name);
                    }
                }

                // Undo the tool registrations recorded for this server.
                if let Some(tools) = context_server_registry.get_tools(server_id) {
                    let tool_registry = ToolRegistry::global(cx);
                    for tool_name in tools {
                        tool_registry.unregister_tool_by_name(&tool_name);
                        context_server_registry.unregister_tool(&server_id, &tool_name);
                    }
                }
            }
        },
    )
    .detach();
}
393
394fn init_language_model_settings(cx: &mut AppContext) {
395    update_active_language_model_from_settings(cx);
396
397    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
398        .detach();
399    cx.subscribe(
400        &LanguageModelRegistry::global(cx),
401        |_, event: &language_model::Event, cx| match event {
402            language_model::Event::ProviderStateChanged
403            | language_model::Event::AddedProvider(_)
404            | language_model::Event::RemovedProvider(_) => {
405                update_active_language_model_from_settings(cx);
406            }
407            _ => {}
408        },
409    )
410    .detach();
411}
412
413fn update_active_language_model_from_settings(cx: &mut AppContext) {
414    let settings = AssistantSettings::get_global(cx);
415    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
416    let model_id = LanguageModelId::from(settings.default_model.model.clone());
417    let inline_alternatives = settings
418        .inline_alternatives
419        .iter()
420        .map(|alternative| {
421            (
422                LanguageModelProviderId::from(alternative.provider.clone()),
423                LanguageModelId::from(alternative.model.clone()),
424            )
425        })
426        .collect::<Vec<_>>();
427    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
428        registry.select_active_model(&provider_name, &model_id, cx);
429        registry.select_inline_alternative_models(inline_alternatives, cx);
430    });
431}
432
433fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
434    let slash_command_registry = SlashCommandRegistry::global(cx);
435
436    slash_command_registry.register_command(file_command::FileSlashCommand, true);
437    slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
438    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
439    slash_command_registry.register_command(tab_command::TabSlashCommand, true);
440    slash_command_registry
441        .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
442    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
443    slash_command_registry.register_command(selection_command::SelectionCommand, true);
444    slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
445    slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
446    slash_command_registry.register_command(now_command::NowSlashCommand, false);
447    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
448    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
449    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
450
451    if let Some(prompt_builder) = prompt_builder {
452        cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
453            let slash_command_registry = slash_command_registry.clone();
454            move |is_enabled, _cx| {
455                if is_enabled {
456                    slash_command_registry.register_command(
457                        project_command::ProjectSlashCommand::new(prompt_builder.clone()),
458                        true,
459                    );
460                }
461            }
462        })
463        .detach();
464    }
465
466    cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
467        let slash_command_registry = slash_command_registry.clone();
468        move |is_enabled, _cx| {
469            if is_enabled {
470                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
471                slash_command_registry.register_command(auto_command::AutoCommand, true);
472            }
473        }
474    })
475    .detach();
476
477    cx.observe_flag::<streaming_example_command::StreamingExampleSlashCommandFeatureFlag, _>({
478        let slash_command_registry = slash_command_registry.clone();
479        move |is_enabled, _cx| {
480            if is_enabled {
481                slash_command_registry.register_command(
482                    streaming_example_command::StreamingExampleSlashCommand,
483                    false,
484                );
485            }
486        }
487    })
488    .detach();
489
490    update_slash_commands_from_settings(cx);
491    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
492        .detach();
493
494    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
495        let slash_command_registry = slash_command_registry.clone();
496        move |is_enabled, _cx| {
497            if is_enabled {
498                slash_command_registry.register_command(search_command::SearchSlashCommand, true);
499            }
500        }
501    })
502    .detach();
503}
504
505fn update_slash_commands_from_settings(cx: &mut AppContext) {
506    let slash_command_registry = SlashCommandRegistry::global(cx);
507    let settings = SlashCommandSettings::get_global(cx);
508
509    if settings.docs.enabled {
510        slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
511    } else {
512        slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
513    }
514
515    if settings.cargo_workspace.enabled {
516        slash_command_registry
517            .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
518    } else {
519        slash_command_registry
520            .unregister_command(cargo_workspace_command::CargoWorkspaceSlashCommand);
521    }
522}
523
524fn register_tools(cx: &mut AppContext) {
525    let tool_registry = ToolRegistry::global(cx);
526    tool_registry.register_tool(tools::now_tool::NowTool);
527}
528
/// Formats a token count compactly for display.
///
/// * Below 1,000 the exact number is shown (`999` → `"999"`).
/// * From 1,000 to 9,999 the count is rounded half-up to one decimal place of
///   thousands, dropping a `.0` (`1_049` → `"1k"`, `1_050` → `"1.1k"`,
///   `9_950` → `"10k"`).
/// * From 10,000 upward it is rounded half-up to whole thousands
///   (`10_499` → `"10k"`, `10_500` → `"11k"`).
///
/// Rounding never adds to `count`, so inputs near `usize::MAX` cannot overflow.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                // e.g. 1_950 rounds up to the next whole thousand.
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Same result as `(count + 500) / 1000`, but `count + 500` would
            // overflow for counts within 500 of `usize::MAX`.
            let thousands = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", thousands)
        }
    }
}
546
/// Test builds only: at load time (via `ctor`), installs `env_logger` if
/// `RUST_LOG` is set to a valid Unicode value; otherwise does nothing.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}