assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4pub mod assistant_settings;
  5mod context;
  6pub(crate) mod context_inspector;
  7pub mod context_store;
  8mod inline_assistant;
  9mod model_selector;
 10mod prompt_library;
 11mod prompts;
 12mod slash_command;
 13pub mod slash_command_settings;
 14mod streaming_diff;
 15mod terminal_inline_assistant;
 16
 17pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 18use assistant_settings::AssistantSettings;
 19use assistant_slash_command::SlashCommandRegistry;
 20use client::{proto, Client};
 21use command_palette_hooks::CommandPaletteFilter;
 22pub use context::*;
 23pub use context_store::*;
 24use feature_flags::FeatureFlagAppExt;
 25use fs::Fs;
 26use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 27use indexed_docs::IndexedDocsRegistry;
 28pub(crate) use inline_assistant::*;
 29use language_model::{
 30    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 31};
 32pub(crate) use model_selector::*;
 33pub use prompts::PromptBuilder;
 34use prompts::PromptOverrideContext;
 35use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 36use serde::{Deserialize, Serialize};
 37use settings::{update_settings_file, Settings, SettingsStore};
 38use slash_command::{
 39    default_command, diagnostics_command, docs_command, fetch_command, file_command, now_command,
 40    project_command, prompt_command, search_command, symbols_command, tab_command,
 41    terminal_command, workflow_command,
 42};
 43use std::sync::Arc;
 44pub(crate) use streaming_diff::*;
 45use util::ResultExt;
 46
 47use crate::slash_command_settings::SlashCommandSettings;
 48
// Unit actions declared in the `assistant` namespace. `Assistant::set_enabled` shows or
// hides the `assistant` namespace in the command palette through `CommandPaletteFilter`.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        ShowConfiguration,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugWorkflowSteps
    ]
);
 67
/// Default number of context lines.
/// NOTE(review): not referenced in this file; presumably consumed by a child module via
/// `crate::DEFAULT_CONTEXT_LINES` — confirm the exact meaning at the use site.
const DEFAULT_CONTEXT_LINES: usize = 50;
 69
/// Parameterized `InlineAssist` action, registered in the `assistant` namespace by the
/// `impl_actions!` call below.
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
    // Optional prompt text carried by the action.
    // NOTE(review): presumably used to pre-fill the inline assist prompt — confirm in
    // `inline_assistant`.
    prompt: Option<String>,
}

impl_actions!(assistant, [InlineAssist]);
 76
 77#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 78pub struct MessageId(clock::Lamport);
 79
 80impl MessageId {
 81    pub fn as_u64(self) -> u64 {
 82        self.0.as_u64()
 83    }
 84}
 85
/// Token counters deserialized from a language-model response.
/// NOTE(review): field meanings are assumed from their names; confirm against the
/// provider payload this is parsed from.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// A single choice delta deserialized from a language-model response.
/// NOTE(review): presumably mirrors an OpenAI-style streamed `choices[]` entry — confirm.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    // Index of the choice this delta belongs to (assumed from the field name).
    pub index: u32,
    // Message content carried by this delta.
    pub delta: LanguageModelResponseMessage,
    // Optional finish reason string; `None` when the payload omits it.
    pub finish_reason: Option<String>,
}
 99
/// Status of a message.
///
/// Converted to and from `proto::ContextMessageStatus` by `MessageStatus::from_proto` and
/// `MessageStatus::to_proto`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    // Also the result of `from_proto` when the proto status carries no variant.
    Pending,
    Done,
    // Carries the error message text (sent as the proto `Error.message` field).
    Error(SharedString),
    Canceled,
}
107
108impl MessageStatus {
109    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
110        match status.variant {
111            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
112            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
113            Some(proto::context_message_status::Variant::Error(error)) => {
114                MessageStatus::Error(error.message.into())
115            }
116            Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
117            None => MessageStatus::Pending,
118        }
119    }
120
121    pub fn to_proto(&self) -> proto::ContextMessageStatus {
122        match self {
123            MessageStatus::Pending => proto::ContextMessageStatus {
124                variant: Some(proto::context_message_status::Variant::Pending(
125                    proto::context_message_status::Pending {},
126                )),
127            },
128            MessageStatus::Done => proto::ContextMessageStatus {
129                variant: Some(proto::context_message_status::Variant::Done(
130                    proto::context_message_status::Done {},
131                )),
132            },
133            MessageStatus::Error(message) => proto::ContextMessageStatus {
134                variant: Some(proto::context_message_status::Variant::Error(
135                    proto::context_message_status::Error {
136                        message: message.to_string(),
137                    },
138                )),
139            },
140            MessageStatus::Canceled => proto::ContextMessageStatus {
141                variant: Some(proto::context_message_status::Variant::Canceled(
142                    proto::context_message_status::Canceled {},
143                )),
144            },
145        }
146    }
147}
148
/// Global state for the Assistant, installed by `init` via
/// `cx.set_global(Assistant::default())`.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled. Starts as `false` (from `Default`); `init` keeps
    /// it in sync with `AssistantSettings::enabled`, and `Assistant::set_enabled` shows or
    /// hides the `assistant` command-palette namespace when it changes.
    enabled: bool,
}

impl Global for Assistant {}
157
158impl Assistant {
159    const NAMESPACE: &'static str = "assistant";
160
161    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
162        if self.enabled == enabled {
163            return;
164        }
165
166        self.enabled = enabled;
167
168        if !enabled {
169            CommandPaletteFilter::update_global(cx, |filter, _cx| {
170                filter.hide_namespace(Self::NAMESPACE);
171            });
172
173            return;
174        }
175
176        CommandPaletteFilter::update_global(cx, |filter, _cx| {
177            filter.show_namespace(Self::NAMESPACE);
178        });
179    }
180}
181
182pub fn init(
183    fs: Arc<dyn Fs>,
184    client: Arc<Client>,
185    dev_mode: bool,
186    cx: &mut AppContext,
187) -> Arc<PromptBuilder> {
188    cx.set_global(Assistant::default());
189    AssistantSettings::register(cx);
190    SlashCommandSettings::register(cx);
191
192    // TODO: remove this when 0.148.0 is released.
193    if AssistantSettings::get_global(cx).using_outdated_settings_version {
194        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
195            let fs = fs.clone();
196            |content, cx| {
197                content.update_file(fs, cx);
198            }
199        });
200    }
201
202    cx.spawn(|mut cx| {
203        let client = client.clone();
204        async move {
205            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
206            let semantic_index = SemanticIndex::new(
207                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
208                Arc::new(embedding_provider),
209                &mut cx,
210            )
211            .await?;
212            cx.update(|cx| cx.set_global(semantic_index))
213        }
214    })
215    .detach();
216
217    context_store::init(&client);
218    prompt_library::init(cx);
219    init_language_model_settings(cx);
220    assistant_slash_command::init(cx);
221    assistant_panel::init(cx);
222
223    let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
224        dev_mode,
225        fs: fs.clone(),
226        cx,
227    }))
228    .log_err()
229    .map(Arc::new)
230    .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
231    register_slash_commands(Some(prompt_builder.clone()), cx);
232    inline_assistant::init(
233        fs.clone(),
234        prompt_builder.clone(),
235        client.telemetry().clone(),
236        cx,
237    );
238    terminal_inline_assistant::init(
239        fs.clone(),
240        prompt_builder.clone(),
241        client.telemetry().clone(),
242        cx,
243    );
244    IndexedDocsRegistry::init_global(cx);
245
246    CommandPaletteFilter::update_global(cx, |filter, _cx| {
247        filter.hide_namespace(Assistant::NAMESPACE);
248    });
249    Assistant::update_global(cx, |assistant, cx| {
250        let settings = AssistantSettings::get_global(cx);
251
252        assistant.set_enabled(settings.enabled, cx);
253    });
254    cx.observe_global::<SettingsStore>(|cx| {
255        Assistant::update_global(cx, |assistant, cx| {
256            let settings = AssistantSettings::get_global(cx);
257            assistant.set_enabled(settings.enabled, cx);
258        });
259    })
260    .detach();
261
262    prompt_builder
263}
264
265fn init_language_model_settings(cx: &mut AppContext) {
266    update_active_language_model_from_settings(cx);
267
268    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
269        .detach();
270    cx.subscribe(
271        &LanguageModelRegistry::global(cx),
272        |_, event: &language_model::Event, cx| match event {
273            language_model::Event::ProviderStateChanged
274            | language_model::Event::AddedProvider(_)
275            | language_model::Event::RemovedProvider(_) => {
276                update_active_language_model_from_settings(cx);
277            }
278            _ => {}
279        },
280    )
281    .detach();
282}
283
284fn update_active_language_model_from_settings(cx: &mut AppContext) {
285    let settings = AssistantSettings::get_global(cx);
286    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
287    let model_id = LanguageModelId::from(settings.default_model.model.clone());
288    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
289        registry.select_active_model(&provider_name, &model_id, cx);
290    });
291}
292
293fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
294    let slash_command_registry = SlashCommandRegistry::global(cx);
295    slash_command_registry.register_command(file_command::FileSlashCommand, true);
296    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
297    slash_command_registry.register_command(tab_command::TabSlashCommand, true);
298    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
299    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
300    slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
301    slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
302    slash_command_registry.register_command(now_command::NowSlashCommand, false);
303    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
304
305    if let Some(prompt_builder) = prompt_builder {
306        slash_command_registry.register_command(
307            workflow_command::WorkflowSlashCommand::new(prompt_builder),
308            true,
309        );
310    }
311    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
312
313    update_slash_commands_from_settings(cx);
314    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
315        .detach();
316
317    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
318        let slash_command_registry = slash_command_registry.clone();
319        move |is_enabled, _cx| {
320            if is_enabled {
321                slash_command_registry.register_command(search_command::SearchSlashCommand, true);
322            }
323        }
324    })
325    .detach();
326}
327
328fn update_slash_commands_from_settings(cx: &mut AppContext) {
329    let slash_command_registry = SlashCommandRegistry::global(cx);
330    let settings = SlashCommandSettings::get_global(cx);
331
332    if settings.docs.enabled {
333        slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
334    } else {
335        slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
336    }
337
338    if settings.project.enabled {
339        slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
340    } else {
341        slash_command_registry.unregister_command(project_command::ProjectSlashCommand);
342    }
343}
344
/// Formats a token count as a short, human-readable string.
///
/// - Counts below 1,000 are printed verbatim (`"999"`).
/// - Counts from 1,000 to 9,999 are printed in thousands with one decimal digit, rounded
///   half-up to the nearest hundred. A zero decimal is dropped (`"1k"`, `"1.2k"`), and
///   values that round up to 10,000 print as `"10k"`.
/// - Larger counts are rounded half-up to the nearest thousand (`"12k"`).
///
/// Rounding never adds to `count` directly, so counts near `usize::MAX` cannot overflow.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded half-up to the nearest hundred; always in 0..=10.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Same result as `(count + 500) / 1000`, but without the addition that would
            // overflow for counts within 500 of `usize::MAX`.
            let round_up = usize::from(count % 1000 >= 500);
            format!("{}k", count / 1000 + round_up)
        }
    }
}
362
/// Test-only hook, registered with `ctor` so it runs automatically when the test binary
/// starts. It initializes `env_logger` only when `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if !rust_log_is_set {
        return;
    }
    env_logger::init();
}