assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4pub mod assistant_settings;
  5mod context;
  6pub mod context_store;
  7mod inline_assistant;
  8mod model_selector;
  9mod patch;
 10mod prompt_library;
 11mod prompts;
 12mod slash_command;
 13pub(crate) mod slash_command_picker;
 14pub mod slash_command_settings;
 15mod slash_command_working_set;
 16mod streaming_diff;
 17mod terminal_inline_assistant;
 18mod tool_working_set;
 19mod tools;
 20
 21use crate::slash_command::project_command::ProjectSlashCommandFeatureFlag;
 22pub use crate::slash_command_working_set::{SlashCommandId, SlashCommandWorkingSet};
 23pub use crate::tool_working_set::{ToolId, ToolWorkingSet};
 24pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 25use assistant_settings::AssistantSettings;
 26use assistant_slash_command::SlashCommandRegistry;
 27use assistant_tool::ToolRegistry;
 28use client::{proto, Client};
 29use command_palette_hooks::CommandPaletteFilter;
 30pub use context::*;
 31pub use context_store::*;
 32use feature_flags::FeatureFlagAppExt;
 33use fs::Fs;
 34use gpui::impl_actions;
 35use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 36pub(crate) use inline_assistant::*;
 37use language_model::{
 38    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 39};
 40pub(crate) use model_selector::*;
 41pub use patch::*;
 42pub use prompts::PromptBuilder;
 43use prompts::PromptLoadingParams;
 44use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 45use serde::{Deserialize, Serialize};
 46use settings::{update_settings_file, Settings, SettingsStore};
 47use slash_command::search_command::SearchSlashCommandFeatureFlag;
 48use slash_command::{
 49    auto_command, cargo_workspace_command, default_command, delta_command, diagnostics_command,
 50    docs_command, fetch_command, file_command, now_command, project_command, prompt_command,
 51    search_command, selection_command, symbols_command, tab_command, terminal_command,
 52};
 53use std::path::PathBuf;
 54use std::sync::Arc;
 55pub(crate) use streaming_diff::*;
 56use util::ResultExt;
 57
 58use crate::slash_command::streaming_example_command;
 59use crate::slash_command_settings::SlashCommandSettings;
 60
 61actions!(
 62    assistant,
 63    [
 64        Assist,
 65        Edit,
 66        Split,
 67        CopyCode,
 68        CycleMessageRole,
 69        QuoteSelection,
 70        InsertIntoEditor,
 71        ToggleFocus,
 72        InsertActivePrompt,
 73        DeployHistory,
 74        DeployPromptLibrary,
 75        ConfirmCommand,
 76        NewContext,
 77        ToggleModelSelector,
 78        CycleNextInlineAssist,
 79        CyclePreviousInlineAssist
 80    ]
 81);
 82
 83#[derive(PartialEq, Clone, Deserialize)]
 84pub enum InsertDraggedFiles {
 85    ProjectPaths(Vec<PathBuf>),
 86    ExternalFiles(Vec<PathBuf>),
 87}
 88
 89impl_actions!(assistant, [InsertDraggedFiles]);
 90
 91const DEFAULT_CONTEXT_LINES: usize = 50;
 92
 93#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 94pub struct MessageId(clock::Lamport);
 95
 96impl MessageId {
 97    pub fn as_u64(self) -> u64 {
 98        self.0.as_u64()
 99    }
100}
101
102#[derive(Deserialize, Debug)]
103pub struct LanguageModelUsage {
104    pub prompt_tokens: u32,
105    pub completion_tokens: u32,
106    pub total_tokens: u32,
107}
108
109#[derive(Deserialize, Debug)]
110pub struct LanguageModelChoiceDelta {
111    pub index: u32,
112    pub delta: LanguageModelResponseMessage,
113    pub finish_reason: Option<String>,
114}
115
116#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
117pub enum MessageStatus {
118    Pending,
119    Done,
120    Error(SharedString),
121    Canceled,
122}
123
124impl MessageStatus {
125    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
126        match status.variant {
127            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
128            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
129            Some(proto::context_message_status::Variant::Error(error)) => {
130                MessageStatus::Error(error.message.into())
131            }
132            Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
133            None => MessageStatus::Pending,
134        }
135    }
136
137    pub fn to_proto(&self) -> proto::ContextMessageStatus {
138        match self {
139            MessageStatus::Pending => proto::ContextMessageStatus {
140                variant: Some(proto::context_message_status::Variant::Pending(
141                    proto::context_message_status::Pending {},
142                )),
143            },
144            MessageStatus::Done => proto::ContextMessageStatus {
145                variant: Some(proto::context_message_status::Variant::Done(
146                    proto::context_message_status::Done {},
147                )),
148            },
149            MessageStatus::Error(message) => proto::ContextMessageStatus {
150                variant: Some(proto::context_message_status::Variant::Error(
151                    proto::context_message_status::Error {
152                        message: message.to_string(),
153                    },
154                )),
155            },
156            MessageStatus::Canceled => proto::ContextMessageStatus {
157                variant: Some(proto::context_message_status::Variant::Canceled(
158                    proto::context_message_status::Canceled {},
159                )),
160            },
161        }
162    }
163}
164
165/// The state pertaining to the Assistant.
166#[derive(Default)]
167struct Assistant {
168    /// Whether the Assistant is enabled.
169    enabled: bool,
170}
171
172impl Global for Assistant {}
173
174impl Assistant {
175    const NAMESPACE: &'static str = "assistant";
176
177    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
178        if self.enabled == enabled {
179            return;
180        }
181
182        self.enabled = enabled;
183
184        if !enabled {
185            CommandPaletteFilter::update_global(cx, |filter, _cx| {
186                filter.hide_namespace(Self::NAMESPACE);
187            });
188
189            return;
190        }
191
192        CommandPaletteFilter::update_global(cx, |filter, _cx| {
193            filter.show_namespace(Self::NAMESPACE);
194        });
195    }
196}
197
198pub fn init(
199    fs: Arc<dyn Fs>,
200    client: Arc<Client>,
201    stdout_is_a_pty: bool,
202    cx: &mut AppContext,
203) -> Arc<PromptBuilder> {
204    cx.set_global(Assistant::default());
205    AssistantSettings::register(cx);
206    SlashCommandSettings::register(cx);
207
208    // TODO: remove this when 0.148.0 is released.
209    if AssistantSettings::get_global(cx).using_outdated_settings_version {
210        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
211            let fs = fs.clone();
212            |content, cx| {
213                content.update_file(fs, cx);
214            }
215        });
216    }
217
218    cx.spawn(|mut cx| {
219        let client = client.clone();
220        async move {
221            let is_search_slash_command_enabled = cx
222                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
223                .await;
224            let is_project_slash_command_enabled = cx
225                .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
226                .await;
227
228            if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
229                return Ok(());
230            }
231
232            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
233            let semantic_index = SemanticDb::new(
234                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
235                Arc::new(embedding_provider),
236                &mut cx,
237            )
238            .await?;
239
240            cx.update(|cx| cx.set_global(semantic_index))
241        }
242    })
243    .detach();
244
245    context_store::init(&client.clone().into());
246    prompt_library::init(cx);
247    init_language_model_settings(cx);
248    assistant_slash_command::init(cx);
249    assistant_tool::init(cx);
250    assistant_panel::init(cx);
251    context_servers::init(cx);
252
253    let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
254        fs: fs.clone(),
255        repo_path: stdout_is_a_pty
256            .then(|| std::env::current_dir().log_err())
257            .flatten(),
258        cx,
259    }))
260    .log_err()
261    .map(Arc::new)
262    .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
263    register_slash_commands(Some(prompt_builder.clone()), cx);
264    register_tools(cx);
265    inline_assistant::init(
266        fs.clone(),
267        prompt_builder.clone(),
268        client.telemetry().clone(),
269        cx,
270    );
271    terminal_inline_assistant::init(
272        fs.clone(),
273        prompt_builder.clone(),
274        client.telemetry().clone(),
275        cx,
276    );
277    indexed_docs::init(cx);
278
279    CommandPaletteFilter::update_global(cx, |filter, _cx| {
280        filter.hide_namespace(Assistant::NAMESPACE);
281    });
282    Assistant::update_global(cx, |assistant, cx| {
283        let settings = AssistantSettings::get_global(cx);
284
285        assistant.set_enabled(settings.enabled, cx);
286    });
287    cx.observe_global::<SettingsStore>(|cx| {
288        Assistant::update_global(cx, |assistant, cx| {
289            let settings = AssistantSettings::get_global(cx);
290            assistant.set_enabled(settings.enabled, cx);
291        });
292    })
293    .detach();
294
295    prompt_builder
296}
297
298fn init_language_model_settings(cx: &mut AppContext) {
299    update_active_language_model_from_settings(cx);
300
301    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
302        .detach();
303    cx.subscribe(
304        &LanguageModelRegistry::global(cx),
305        |_, event: &language_model::Event, cx| match event {
306            language_model::Event::ProviderStateChanged
307            | language_model::Event::AddedProvider(_)
308            | language_model::Event::RemovedProvider(_) => {
309                update_active_language_model_from_settings(cx);
310            }
311            _ => {}
312        },
313    )
314    .detach();
315}
316
317fn update_active_language_model_from_settings(cx: &mut AppContext) {
318    let settings = AssistantSettings::get_global(cx);
319    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
320    let model_id = LanguageModelId::from(settings.default_model.model.clone());
321    let inline_alternatives = settings
322        .inline_alternatives
323        .iter()
324        .map(|alternative| {
325            (
326                LanguageModelProviderId::from(alternative.provider.clone()),
327                LanguageModelId::from(alternative.model.clone()),
328            )
329        })
330        .collect::<Vec<_>>();
331    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
332        registry.select_active_model(&provider_name, &model_id, cx);
333        registry.select_inline_alternative_models(inline_alternatives, cx);
334    });
335}
336
337fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
338    let slash_command_registry = SlashCommandRegistry::global(cx);
339
340    slash_command_registry.register_command(file_command::FileSlashCommand, true);
341    slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
342    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
343    slash_command_registry.register_command(tab_command::TabSlashCommand, true);
344    slash_command_registry
345        .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
346    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
347    slash_command_registry.register_command(selection_command::SelectionCommand, true);
348    slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
349    slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
350    slash_command_registry.register_command(now_command::NowSlashCommand, false);
351    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
352    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
353    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
354
355    if let Some(prompt_builder) = prompt_builder {
356        cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
357            let slash_command_registry = slash_command_registry.clone();
358            move |is_enabled, _cx| {
359                if is_enabled {
360                    slash_command_registry.register_command(
361                        project_command::ProjectSlashCommand::new(prompt_builder.clone()),
362                        true,
363                    );
364                }
365            }
366        })
367        .detach();
368    }
369
370    cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
371        let slash_command_registry = slash_command_registry.clone();
372        move |is_enabled, _cx| {
373            if is_enabled {
374                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
375                slash_command_registry.register_command(auto_command::AutoCommand, true);
376            }
377        }
378    })
379    .detach();
380
381    cx.observe_flag::<streaming_example_command::StreamingExampleSlashCommandFeatureFlag, _>({
382        let slash_command_registry = slash_command_registry.clone();
383        move |is_enabled, _cx| {
384            if is_enabled {
385                slash_command_registry.register_command(
386                    streaming_example_command::StreamingExampleSlashCommand,
387                    false,
388                );
389            }
390        }
391    })
392    .detach();
393
394    update_slash_commands_from_settings(cx);
395    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
396        .detach();
397
398    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
399        let slash_command_registry = slash_command_registry.clone();
400        move |is_enabled, _cx| {
401            if is_enabled {
402                slash_command_registry.register_command(search_command::SearchSlashCommand, true);
403            }
404        }
405    })
406    .detach();
407}
408
409fn update_slash_commands_from_settings(cx: &mut AppContext) {
410    let slash_command_registry = SlashCommandRegistry::global(cx);
411    let settings = SlashCommandSettings::get_global(cx);
412
413    if settings.docs.enabled {
414        slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
415    } else {
416        slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
417    }
418
419    if settings.cargo_workspace.enabled {
420        slash_command_registry
421            .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
422    } else {
423        slash_command_registry
424            .unregister_command(cargo_workspace_command::CargoWorkspaceSlashCommand);
425    }
426}
427
428fn register_tools(cx: &mut AppContext) {
429    let tool_registry = ToolRegistry::global(cx);
430    tool_registry.register_tool(tools::now_tool::NowTool);
431}
432
/// Formats a token count as a short, human-readable string.
///
/// - `0..=999` is printed verbatim (`999 -> "999"`).
/// - `1000..=9999` is shown in thousands with one decimal, with the remainder rounded
///   half-up to the nearest hundred (`1_234 -> "1.2k"`, `1_050 -> "1.1k"`). A zero
///   decimal is dropped (`1_049 -> "1k"`), and values that round up to the next
///   thousand are shown whole (`9_950 -> "10k"`).
/// - Anything larger is rounded half-up to the nearest thousand (`10_500 -> "11k"`).
///
/// The rounding never adds to `count` directly, so values close to `usize::MAX`
/// cannot overflow (which would panic in debug builds).
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Round the sub-thousand remainder half-up to hundreds. This can yield 10
            // (e.g. 9_950), which carries into the thousands.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{thousands}k"),
                10 => format!("{}k", thousands + 1),
                _ => format!("{thousands}.{hundreds}k"),
            }
        }
        _ => {
            // Same result as `(count + 500) / 1000`, but without the overflowing add.
            let round_up = usize::from(count % 1000 >= 500);
            format!("{}k", count / 1000 + round_up)
        }
    }
}
450
/// Test-only constructor (run at program start via `ctor`).
///
/// It initializes `env_logger`, but only when the `RUST_LOG` environment variable
/// is set to a valid Unicode value, so tests stay quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}