assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod context;
  4pub mod context_store;
  5mod inline_assistant;
  6mod model_selector;
  7mod prompt_library;
  8mod prompts;
  9mod slash_command;
 10mod streaming_diff;
 11mod terminal_inline_assistant;
 12
 13pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 14use assistant_settings::AssistantSettings;
 15use assistant_slash_command::SlashCommandRegistry;
 16use client::{proto, Client};
 17use command_palette_hooks::CommandPaletteFilter;
 18pub use context::*;
 19pub use context_store::*;
 20use fs::Fs;
 21use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 22use indexed_docs::IndexedDocsRegistry;
 23pub(crate) use inline_assistant::*;
 24use language_model::{
 25    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 26};
 27pub(crate) use model_selector::*;
 28use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 29use serde::{Deserialize, Serialize};
 30use settings::{update_settings_file, Settings, SettingsStore};
 31use slash_command::{
 32    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 33    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 34    tabs_command, term_command,
 35};
 36use std::sync::Arc;
 37pub(crate) use streaming_diff::*;
 38
// Unit actions declared under the `assistant` namespace. This is the same
// namespace string that `Assistant::NAMESPACE` hides or shows in the command
// palette filter, depending on whether the assistant is enabled.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugEditSteps
    ]
);
 57
/// Action payload for starting an inline assist.
///
/// Unlike the unit actions declared via `actions!`, this action carries data
/// and is deserializable, so it is declared through `impl_actions!` instead.
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
    /// Optional prompt text; `None` when the action is dispatched without one.
    /// NOTE(review): presumably used to pre-fill the inline assist prompt — confirm in `inline_assistant`.
    prompt: Option<String>,
}

// Declares `InlineAssist` as an action under the `assistant` namespace.
impl_actions!(assistant, [InlineAssist]);
 64
 65#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 66pub struct MessageId(clock::Lamport);
 67
 68impl MessageId {
 69    pub fn as_u64(self) -> u64 {
 70        self.0.as_u64()
 71    }
 72}
 73
/// Token counts associated with a language model request.
///
/// NOTE(review): presumably mirrors the `usage` object of a completion API
/// response — confirm against the code that deserializes it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Number of tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total token count. NOTE(review): presumably prompt + completion — confirm with the producer.
    pub total_tokens: u32,
}
 80
/// A single deserialized choice entry carrying a `LanguageModelResponseMessage` delta.
///
/// NOTE(review): presumably one element of a streamed completion chunk — confirm
/// against the code that deserializes it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of this choice.
    pub index: u32,
    /// The message delta for this choice.
    pub delta: LanguageModelResponseMessage,
    /// Reason the choice finished, if one was provided.
    pub finish_reason: Option<String>,
}
 87
 88#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 89pub enum MessageStatus {
 90    Pending,
 91    Done,
 92    Error(SharedString),
 93}
 94
 95impl MessageStatus {
 96    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
 97        match status.variant {
 98            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
 99            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
100            Some(proto::context_message_status::Variant::Error(error)) => {
101                MessageStatus::Error(error.message.into())
102            }
103            None => MessageStatus::Pending,
104        }
105    }
106
107    pub fn to_proto(&self) -> proto::ContextMessageStatus {
108        match self {
109            MessageStatus::Pending => proto::ContextMessageStatus {
110                variant: Some(proto::context_message_status::Variant::Pending(
111                    proto::context_message_status::Pending {},
112                )),
113            },
114            MessageStatus::Done => proto::ContextMessageStatus {
115                variant: Some(proto::context_message_status::Variant::Done(
116                    proto::context_message_status::Done {},
117                )),
118            },
119            MessageStatus::Error(message) => proto::ContextMessageStatus {
120                variant: Some(proto::context_message_status::Variant::Error(
121                    proto::context_message_status::Error {
122                        message: message.to_string(),
123                    },
124                )),
125            },
126        }
127    }
128}
129
130/// The state pertaining to the Assistant.
131#[derive(Default)]
132struct Assistant {
133    /// Whether the Assistant is enabled.
134    enabled: bool,
135}
136
137impl Global for Assistant {}
138
139impl Assistant {
140    const NAMESPACE: &'static str = "assistant";
141
142    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
143        if self.enabled == enabled {
144            return;
145        }
146
147        self.enabled = enabled;
148
149        if !enabled {
150            CommandPaletteFilter::update_global(cx, |filter, _cx| {
151                filter.hide_namespace(Self::NAMESPACE);
152            });
153
154            return;
155        }
156
157        CommandPaletteFilter::update_global(cx, |filter, _cx| {
158            filter.show_namespace(Self::NAMESPACE);
159        });
160    }
161}
162
163pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
164    cx.set_global(Assistant::default());
165    AssistantSettings::register(cx);
166
167    // TODO: remove this when 0.148.0 is released.
168    if AssistantSettings::get_global(cx).using_outdated_settings_version {
169        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
170            let fs = fs.clone();
171            |content, cx| {
172                content.update_file(fs, cx);
173            }
174        });
175    }
176
177    cx.spawn(|mut cx| {
178        let client = client.clone();
179        async move {
180            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
181            let semantic_index = SemanticIndex::new(
182                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
183                Arc::new(embedding_provider),
184                &mut cx,
185            )
186            .await?;
187            cx.update(|cx| cx.set_global(semantic_index))
188        }
189    })
190    .detach();
191
192    context_store::init(&client);
193    prompt_library::init(cx);
194    init_language_model_settings(cx);
195    assistant_slash_command::init(cx);
196    register_slash_commands(cx);
197    assistant_panel::init(cx);
198    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
199    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
200    IndexedDocsRegistry::init_global(cx);
201
202    CommandPaletteFilter::update_global(cx, |filter, _cx| {
203        filter.hide_namespace(Assistant::NAMESPACE);
204    });
205    Assistant::update_global(cx, |assistant, cx| {
206        let settings = AssistantSettings::get_global(cx);
207
208        assistant.set_enabled(settings.enabled, cx);
209    });
210    cx.observe_global::<SettingsStore>(|cx| {
211        Assistant::update_global(cx, |assistant, cx| {
212            let settings = AssistantSettings::get_global(cx);
213            assistant.set_enabled(settings.enabled, cx);
214        });
215    })
216    .detach();
217}
218
219fn init_language_model_settings(cx: &mut AppContext) {
220    update_active_language_model_from_settings(cx);
221
222    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
223        .detach();
224    cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
225        update_active_language_model_from_settings(cx)
226    })
227    .detach();
228}
229
230fn update_active_language_model_from_settings(cx: &mut AppContext) {
231    let settings = AssistantSettings::get_global(cx);
232    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
233    let model_id = LanguageModelId::from(settings.default_model.model.clone());
234    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
235        registry.select_active_model(&provider_name, &model_id, cx);
236    });
237}
238
239fn register_slash_commands(cx: &mut AppContext) {
240    let slash_command_registry = SlashCommandRegistry::global(cx);
241    slash_command_registry.register_command(file_command::FileSlashCommand, true);
242    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
243    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
244    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
245    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
246    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
247    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
248    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
249    slash_command_registry.register_command(term_command::TermSlashCommand, true);
250    slash_command_registry.register_command(now_command::NowSlashCommand, true);
251    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
252    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
253    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
254}
255
/// Formats a token count as a short, human-readable string.
///
/// - Counts below 1000 are printed verbatim (`"999"`).
/// - Counts from 1000 through 9999 are rounded to the nearest hundred and
///   printed with at most one decimal digit (`"1k"`, `"1.1k"`, `"10k"`).
/// - Larger counts are rounded to the nearest thousand (`"11k"`, `"123k"`).
///
/// Never overflows, even for counts close to `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Tenths of a thousand, rounded to nearest; ranges over 0..=10.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                // The remainder rounded up to a full thousand.
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Same result as `(count + 500) / 1000`, but the addition cannot
            // overflow when `count` is within 500 of `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
273
/// Test-only constructor that initializes `env_logger` when `RUST_LOG` is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Logging in tests is opt-in: skip initialization unless `RUST_LOG` is
    // present and readable.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}