assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4pub mod assistant_settings;
  5mod context;
  6pub mod context_store;
  7mod inline_assistant;
  8mod model_selector;
  9mod prompt_library;
 10mod prompts;
 11mod slash_command;
 12mod streaming_diff;
 13mod terminal_inline_assistant;
 14
 15pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 16use assistant_settings::AssistantSettings;
 17use assistant_slash_command::SlashCommandRegistry;
 18use client::{proto, Client};
 19use command_palette_hooks::CommandPaletteFilter;
 20pub use context::*;
 21pub use context_store::*;
 22use fs::Fs;
 23use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 24use indexed_docs::IndexedDocsRegistry;
 25pub(crate) use inline_assistant::*;
 26use language_model::{
 27    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 28};
 29pub(crate) use model_selector::*;
 30pub use prompts::PromptBuilder;
 31use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 32use serde::{Deserialize, Serialize};
 33use settings::{update_settings_file, Settings, SettingsStore};
 34use slash_command::{
 35    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 36    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 37    tabs_command, term_command, workflow_command,
 38};
 39use std::sync::Arc;
 40pub(crate) use streaming_diff::*;
 41use util::ResultExt;
 42
// Unit actions in the `assistant` namespace. `Assistant::NAMESPACE` below uses
// the same "assistant" string to show/hide these in the command palette.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        ShowConfiguration,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugEditSteps
    ]
);

// NOTE(review): not referenced in this file; presumably the number of
// surrounding lines included as context by one of the submodules — confirm.
const DEFAULT_CONTEXT_LINES: usize = 20;
 63
 64#[derive(Clone, Default, Deserialize, PartialEq)]
 65pub struct InlineAssist {
 66    prompt: Option<String>,
 67}
 68
 69impl_actions!(assistant, [InlineAssist]);
 70
 71#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 72pub struct MessageId(clock::Lamport);
 73
 74impl MessageId {
 75    pub fn as_u64(self) -> u64 {
 76        self.0.as_u64()
 77    }
 78}
 79
/// Token counts deserialized from a language-model response.
///
/// NOTE(review): presumably `total_tokens == prompt_tokens + completion_tokens`
/// as reported by the provider; nothing here enforces that — confirm.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens attributed to the prompt (input).
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion (output).
    pub completion_tokens: u32,
    /// Total tokens reported for the request.
    pub total_tokens: u32,
}
 86
/// One deserialized choice delta from a language-model response.
///
/// NOTE(review): the shape suggests an incremental chunk of a streamed,
/// OpenAI-style "choices" array — confirm against where this is parsed.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Position of this choice in the response's list of choices.
    pub index: u32,
    /// The message fragment carried by this delta.
    pub delta: LanguageModelResponseMessage,
    /// Why generation stopped for this choice, if the provider reported it.
    pub finish_reason: Option<String>,
}
 93
/// Status of a message in a context.
///
/// Converted to and from `proto::ContextMessageStatus` by `from_proto` /
/// `to_proto`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    /// The message has not completed yet.
    Pending,
    /// The message completed successfully.
    Done,
    /// The message failed; carries the error text.
    Error(SharedString),
}
100
101impl MessageStatus {
102    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
103        match status.variant {
104            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
105            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
106            Some(proto::context_message_status::Variant::Error(error)) => {
107                MessageStatus::Error(error.message.into())
108            }
109            None => MessageStatus::Pending,
110        }
111    }
112
113    pub fn to_proto(&self) -> proto::ContextMessageStatus {
114        match self {
115            MessageStatus::Pending => proto::ContextMessageStatus {
116                variant: Some(proto::context_message_status::Variant::Pending(
117                    proto::context_message_status::Pending {},
118                )),
119            },
120            MessageStatus::Done => proto::ContextMessageStatus {
121                variant: Some(proto::context_message_status::Variant::Done(
122                    proto::context_message_status::Done {},
123                )),
124            },
125            MessageStatus::Error(message) => proto::ContextMessageStatus {
126                variant: Some(proto::context_message_status::Variant::Error(
127                    proto::context_message_status::Error {
128                        message: message.to_string(),
129                    },
130                )),
131            },
132        }
133    }
134}
135
/// The state pertaining to the Assistant.
///
/// Installed as a gpui global (`cx.set_global` in `init`) and mutated through
/// `Assistant::update_global`.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled. `false` by default.
    enabled: bool,
}

// Marker impl that allows `Assistant` to be stored as a gpui global.
impl Global for Assistant {}
144
145impl Assistant {
146    const NAMESPACE: &'static str = "assistant";
147
148    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
149        if self.enabled == enabled {
150            return;
151        }
152
153        self.enabled = enabled;
154
155        if !enabled {
156            CommandPaletteFilter::update_global(cx, |filter, _cx| {
157                filter.hide_namespace(Self::NAMESPACE);
158            });
159
160            return;
161        }
162
163        CommandPaletteFilter::update_global(cx, |filter, _cx| {
164            filter.show_namespace(Self::NAMESPACE);
165        });
166    }
167}
168
169pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
170    cx.set_global(Assistant::default());
171    AssistantSettings::register(cx);
172
173    // TODO: remove this when 0.148.0 is released.
174    if AssistantSettings::get_global(cx).using_outdated_settings_version {
175        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
176            let fs = fs.clone();
177            |content, cx| {
178                content.update_file(fs, cx);
179            }
180        });
181    }
182
183    cx.spawn(|mut cx| {
184        let client = client.clone();
185        async move {
186            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
187            let semantic_index = SemanticIndex::new(
188                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
189                Arc::new(embedding_provider),
190                &mut cx,
191            )
192            .await?;
193            cx.update(|cx| cx.set_global(semantic_index))
194        }
195    })
196    .detach();
197
198    context_store::init(&client);
199    prompt_library::init(cx);
200    init_language_model_settings(cx);
201    assistant_slash_command::init(cx);
202    assistant_panel::init(cx);
203
204    let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
205        .log_err()
206        .map(Arc::new)
207        .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
208    register_slash_commands(Some(prompt_builder.clone()), cx);
209    inline_assistant::init(
210        fs.clone(),
211        prompt_builder.clone(),
212        client.telemetry().clone(),
213        cx,
214    );
215    terminal_inline_assistant::init(
216        fs.clone(),
217        prompt_builder.clone(),
218        client.telemetry().clone(),
219        cx,
220    );
221    IndexedDocsRegistry::init_global(cx);
222
223    CommandPaletteFilter::update_global(cx, |filter, _cx| {
224        filter.hide_namespace(Assistant::NAMESPACE);
225    });
226    Assistant::update_global(cx, |assistant, cx| {
227        let settings = AssistantSettings::get_global(cx);
228
229        assistant.set_enabled(settings.enabled, cx);
230    });
231    cx.observe_global::<SettingsStore>(|cx| {
232        Assistant::update_global(cx, |assistant, cx| {
233            let settings = AssistantSettings::get_global(cx);
234            assistant.set_enabled(settings.enabled, cx);
235        });
236    })
237    .detach();
238
239    prompt_builder
240}
241
242fn init_language_model_settings(cx: &mut AppContext) {
243    update_active_language_model_from_settings(cx);
244
245    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
246        .detach();
247    cx.subscribe(
248        &LanguageModelRegistry::global(cx),
249        |_, event: &language_model::Event, cx| match event {
250            language_model::Event::ProviderStateChanged
251            | language_model::Event::AddedProvider(_)
252            | language_model::Event::RemovedProvider(_) => {
253                update_active_language_model_from_settings(cx);
254            }
255            _ => {}
256        },
257    )
258    .detach();
259}
260
261fn update_active_language_model_from_settings(cx: &mut AppContext) {
262    let settings = AssistantSettings::get_global(cx);
263    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
264    let model_id = LanguageModelId::from(settings.default_model.model.clone());
265    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
266        registry.select_active_model(&provider_name, &model_id, cx);
267    });
268}
269
270fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
271    let slash_command_registry = SlashCommandRegistry::global(cx);
272    slash_command_registry.register_command(file_command::FileSlashCommand, true);
273    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
274    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
275    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
276    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
277    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
278    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
279    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
280    slash_command_registry.register_command(term_command::TermSlashCommand, true);
281    slash_command_registry.register_command(now_command::NowSlashCommand, true);
282    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
283    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
284    if let Some(prompt_builder) = prompt_builder {
285        slash_command_registry.register_command(
286            workflow_command::WorkflowSlashCommand::new(prompt_builder),
287            true,
288        );
289    }
290    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
291}
292
/// Formats a token count for compact display.
///
/// - `0..=999` is shown verbatim (`"999"`).
/// - `1_000..=9_999` is rounded half-up to the nearest hundred and shown with
///   one decimal (`"1.5k"`). A `.0` is dropped (`"2k"`), and when the hundreds
///   round up to ten the thousands carry (`9_950` → `"10k"`).
/// - Anything larger is rounded half-up to the nearest thousand (`"12k"`).
///
/// The arithmetic cannot overflow, so every `usize`, including `usize::MAX`,
/// is accepted.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Tenths of a thousand, rounded half-up; may reach 10 (e.g. 1_950).
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Round half-up without computing `count + 500`, which overflows
            // (and panics in debug builds) for counts near `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
310
/// Test-only constructor that installs `env_logger` before tests run, but only
/// when the `RUST_LOG` environment variable is set (and valid Unicode).
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}