assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod context;
  4pub mod context_store;
  5mod inline_assistant;
  6mod model_selector;
  7mod prompt_library;
  8mod prompts;
  9mod slash_command;
 10mod streaming_diff;
 11mod terminal_inline_assistant;
 12
 13pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 14use assistant_settings::AssistantSettings;
 15use assistant_slash_command::SlashCommandRegistry;
 16use client::{proto, Client};
 17use command_palette_hooks::CommandPaletteFilter;
 18pub use context::*;
 19pub use context_store::*;
 20use fs::Fs;
 21use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 22use indexed_docs::IndexedDocsRegistry;
 23pub(crate) use inline_assistant::*;
 24use language_model::{
 25    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 26};
 27pub(crate) use model_selector::*;
 28pub use prompts::PromptBuilder;
 29use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 30use serde::{Deserialize, Serialize};
 31use settings::{update_settings_file, Settings, SettingsStore};
 32use slash_command::{
 33    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 34    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 35    tabs_command, term_command, workflow_command,
 36};
 37use std::sync::Arc;
 38pub(crate) use streaming_diff::*;
 39use util::ResultExt;
 40
// Actions declared here by name only (no fields) under the `assistant` namespace.
// That namespace string is the same as `Assistant::NAMESPACE`, which
// `Assistant::set_enabled` hides from / shows in the command palette via
// `CommandPaletteFilter`.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        ShowConfiguration,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugEditSteps
    ]
);
 59
/// Default number of context lines (20).
///
/// NOTE(review): not referenced in this file; presumably the number of surrounding
/// lines included as context by sibling modules — confirm at the use sites.
const DEFAULT_CONTEXT_LINES: usize = 20;
 61
/// Action payload for starting an inline assist, registered under the `assistant`
/// namespace via `impl_actions!` below.
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
    /// Optional prompt carried by the action. Because this is an `Option`, a missing
    /// `prompt` key deserializes to `None`.
    // NOTE(review): presumably used to pre-fill the inline assist prompt — confirm in
    // `inline_assistant`.
    prompt: Option<String>,
}

impl_actions!(assistant, [InlineAssist]);
 68
 69#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 70pub struct MessageId(clock::Lamport);
 71
 72impl MessageId {
 73    pub fn as_u64(self) -> u64 {
 74        self.0.as_u64()
 75    }
 76}
 77
/// Token usage counts deserialized from a model response.
///
/// NOTE(review): presumably reported by the language-model provider alongside a
/// completion — confirm against the API that produces this payload.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Prompt token count.
    pub prompt_tokens: u32,
    /// Completion token count.
    pub completion_tokens: u32,
    /// Total token count; presumably `prompt_tokens + completion_tokens`, but nothing
    /// here enforces that.
    pub total_tokens: u32,
}
 84
/// A single choice delta deserialized from a model response.
///
/// NOTE(review): the field names resemble OpenAI-style streaming chunks; confirm
/// against the provider that produces this payload.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// Incremental message content for this choice.
    pub delta: LanguageModelResponseMessage,
    /// Reason the choice finished, if present; `None` otherwise (presumably while the
    /// response is still streaming — confirm).
    pub finish_reason: Option<String>,
}
 91
 92#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 93pub enum MessageStatus {
 94    Pending,
 95    Done,
 96    Error(SharedString),
 97}
 98
 99impl MessageStatus {
100    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
101        match status.variant {
102            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
103            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
104            Some(proto::context_message_status::Variant::Error(error)) => {
105                MessageStatus::Error(error.message.into())
106            }
107            None => MessageStatus::Pending,
108        }
109    }
110
111    pub fn to_proto(&self) -> proto::ContextMessageStatus {
112        match self {
113            MessageStatus::Pending => proto::ContextMessageStatus {
114                variant: Some(proto::context_message_status::Variant::Pending(
115                    proto::context_message_status::Pending {},
116                )),
117            },
118            MessageStatus::Done => proto::ContextMessageStatus {
119                variant: Some(proto::context_message_status::Variant::Done(
120                    proto::context_message_status::Done {},
121                )),
122            },
123            MessageStatus::Error(message) => proto::ContextMessageStatus {
124                variant: Some(proto::context_message_status::Variant::Error(
125                    proto::context_message_status::Error {
126                        message: message.to_string(),
127                    },
128                )),
129            },
130        }
131    }
132}
133
134/// The state pertaining to the Assistant.
135#[derive(Default)]
136struct Assistant {
137    /// Whether the Assistant is enabled.
138    enabled: bool,
139}
140
141impl Global for Assistant {}
142
143impl Assistant {
144    const NAMESPACE: &'static str = "assistant";
145
146    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
147        if self.enabled == enabled {
148            return;
149        }
150
151        self.enabled = enabled;
152
153        if !enabled {
154            CommandPaletteFilter::update_global(cx, |filter, _cx| {
155                filter.hide_namespace(Self::NAMESPACE);
156            });
157
158            return;
159        }
160
161        CommandPaletteFilter::update_global(cx, |filter, _cx| {
162            filter.show_namespace(Self::NAMESPACE);
163        });
164    }
165}
166
167pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
168    cx.set_global(Assistant::default());
169    AssistantSettings::register(cx);
170
171    // TODO: remove this when 0.148.0 is released.
172    if AssistantSettings::get_global(cx).using_outdated_settings_version {
173        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
174            let fs = fs.clone();
175            |content, cx| {
176                content.update_file(fs, cx);
177            }
178        });
179    }
180
181    cx.spawn(|mut cx| {
182        let client = client.clone();
183        async move {
184            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
185            let semantic_index = SemanticIndex::new(
186                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
187                Arc::new(embedding_provider),
188                &mut cx,
189            )
190            .await?;
191            cx.update(|cx| cx.set_global(semantic_index))
192        }
193    })
194    .detach();
195
196    context_store::init(&client);
197    prompt_library::init(cx);
198    init_language_model_settings(cx);
199    assistant_slash_command::init(cx);
200    assistant_panel::init(cx);
201
202    let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
203        .log_err()
204        .map(Arc::new)
205        .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
206    register_slash_commands(Some(prompt_builder.clone()), cx);
207    inline_assistant::init(
208        fs.clone(),
209        prompt_builder.clone(),
210        client.telemetry().clone(),
211        cx,
212    );
213    terminal_inline_assistant::init(
214        fs.clone(),
215        prompt_builder.clone(),
216        client.telemetry().clone(),
217        cx,
218    );
219    IndexedDocsRegistry::init_global(cx);
220
221    CommandPaletteFilter::update_global(cx, |filter, _cx| {
222        filter.hide_namespace(Assistant::NAMESPACE);
223    });
224    Assistant::update_global(cx, |assistant, cx| {
225        let settings = AssistantSettings::get_global(cx);
226
227        assistant.set_enabled(settings.enabled, cx);
228    });
229    cx.observe_global::<SettingsStore>(|cx| {
230        Assistant::update_global(cx, |assistant, cx| {
231            let settings = AssistantSettings::get_global(cx);
232            assistant.set_enabled(settings.enabled, cx);
233        });
234    })
235    .detach();
236
237    prompt_builder
238}
239
240fn init_language_model_settings(cx: &mut AppContext) {
241    update_active_language_model_from_settings(cx);
242
243    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
244        .detach();
245    cx.subscribe(
246        &LanguageModelRegistry::global(cx),
247        |_, event: &language_model::Event, cx| match event {
248            language_model::Event::ProviderStateChanged
249            | language_model::Event::AddedProvider(_)
250            | language_model::Event::RemovedProvider(_) => {
251                update_active_language_model_from_settings(cx);
252            }
253            _ => {}
254        },
255    )
256    .detach();
257}
258
259fn update_active_language_model_from_settings(cx: &mut AppContext) {
260    let settings = AssistantSettings::get_global(cx);
261    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
262    let model_id = LanguageModelId::from(settings.default_model.model.clone());
263    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
264        registry.select_active_model(&provider_name, &model_id, cx);
265    });
266}
267
268fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
269    let slash_command_registry = SlashCommandRegistry::global(cx);
270    slash_command_registry.register_command(file_command::FileSlashCommand, true);
271    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
272    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
273    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
274    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
275    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
276    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
277    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
278    slash_command_registry.register_command(term_command::TermSlashCommand, true);
279    slash_command_registry.register_command(now_command::NowSlashCommand, true);
280    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
281    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
282    if let Some(prompt_builder) = prompt_builder {
283        slash_command_registry.register_command(
284            workflow_command::WorkflowSlashCommand::new(prompt_builder),
285            true,
286        );
287    }
288    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
289}
290
/// Formats a token count as a short, human-readable string.
///
/// - Counts below 1,000 are printed as-is (`999` → `"999"`).
/// - Counts from 1,000 to 9,999 are rounded to the nearest hundred and shown in
///   thousands with one decimal place, omitting a `.0` (`1_000` → `"1k"`,
///   `1_050` → `"1.1k"`). A value that rounds up to the next thousand is shown as that
///   thousand (`9_950` → `"10k"`).
/// - Larger counts are rounded to the nearest thousand (`10_500` → `"11k"`).
///
/// Halves round up in both rounded ranges. Never panics, even at `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded to the nearest hundred; always within 0..=10 here.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Same result as `(count + 500) / 1000`, but that sum overflows for counts near
            // `usize::MAX` (panicking in debug builds, wrapping to "0k" in release builds).
            let rounded_thousands = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded_thousands)
        }
    }
}
308
/// Test builds only: initializes `env_logger` at binary load time (via `ctor`), but only
/// when the `RUST_LOG` environment variable is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}