assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4pub mod assistant_settings;
  5mod context;
  6pub mod context_store;
  7mod inline_assistant;
  8mod model_selector;
  9mod prompt_library;
 10mod prompts;
 11mod slash_command;
 12mod streaming_diff;
 13mod terminal_inline_assistant;
 14
 15pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 16use assistant_settings::AssistantSettings;
 17use assistant_slash_command::SlashCommandRegistry;
 18use client::{proto, Client};
 19use command_palette_hooks::CommandPaletteFilter;
 20pub use context::*;
 21pub use context_store::*;
 22use feature_flags::FeatureFlagAppExt;
 23use fs::Fs;
 24use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 25use indexed_docs::IndexedDocsRegistry;
 26pub(crate) use inline_assistant::*;
 27use language_model::{
 28    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 29};
 30pub(crate) use model_selector::*;
 31pub use prompts::PromptBuilder;
 32use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 33use serde::{Deserialize, Serialize};
 34use settings::{update_settings_file, Settings, SettingsStore};
 35use slash_command::{
 36    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 37    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 38    tabs_command, term_command, workflow_command,
 39};
 40use std::sync::Arc;
 41pub(crate) use streaming_diff::*;
 42use util::ResultExt;
 43
// Actions declared in the `assistant` namespace via gpui's `actions!` macro.
// The parameterized `InlineAssist` action is registered separately with `impl_actions!`.
// NOTE(review): `Assistant::NAMESPACE` is the same "assistant" string and is used to
// show/hide that namespace in the command palette — presumably covering these actions;
// confirm against gpui's action namespacing.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        ShowConfiguration,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugWorkflowSteps
    ]
);
 62
/// Default line count of 20, presumably the number of surrounding lines of context.
///
/// NOTE(review): not referenced in this file; presumably consumed by a submodule via
/// `crate::DEFAULT_CONTEXT_LINES` — confirm what "context lines" means at the use site.
const DEFAULT_CONTEXT_LINES: usize = 20;
 64
 65#[derive(Clone, Default, Deserialize, PartialEq)]
 66pub struct InlineAssist {
 67    prompt: Option<String>,
 68}
 69
 70impl_actions!(assistant, [InlineAssist]);
 71
 72#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 73pub struct MessageId(clock::Lamport);
 74
 75impl MessageId {
 76    pub fn as_u64(self) -> u64 {
 77        self.0.as_u64()
 78    }
 79}
 80
 81#[derive(Deserialize, Debug)]
 82pub struct LanguageModelUsage {
 83    pub prompt_tokens: u32,
 84    pub completion_tokens: u32,
 85    pub total_tokens: u32,
 86}
 87
 88#[derive(Deserialize, Debug)]
 89pub struct LanguageModelChoiceDelta {
 90    pub index: u32,
 91    pub delta: LanguageModelResponseMessage,
 92    pub finish_reason: Option<String>,
 93}
 94
 95#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 96pub enum MessageStatus {
 97    Pending,
 98    Done,
 99    Error(SharedString),
100}
101
102impl MessageStatus {
103    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
104        match status.variant {
105            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
106            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
107            Some(proto::context_message_status::Variant::Error(error)) => {
108                MessageStatus::Error(error.message.into())
109            }
110            None => MessageStatus::Pending,
111        }
112    }
113
114    pub fn to_proto(&self) -> proto::ContextMessageStatus {
115        match self {
116            MessageStatus::Pending => proto::ContextMessageStatus {
117                variant: Some(proto::context_message_status::Variant::Pending(
118                    proto::context_message_status::Pending {},
119                )),
120            },
121            MessageStatus::Done => proto::ContextMessageStatus {
122                variant: Some(proto::context_message_status::Variant::Done(
123                    proto::context_message_status::Done {},
124                )),
125            },
126            MessageStatus::Error(message) => proto::ContextMessageStatus {
127                variant: Some(proto::context_message_status::Variant::Error(
128                    proto::context_message_status::Error {
129                        message: message.to_string(),
130                    },
131                )),
132            },
133        }
134    }
135}
136
137/// The state pertaining to the Assistant.
138#[derive(Default)]
139struct Assistant {
140    /// Whether the Assistant is enabled.
141    enabled: bool,
142}
143
144impl Global for Assistant {}
145
146impl Assistant {
147    const NAMESPACE: &'static str = "assistant";
148
149    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
150        if self.enabled == enabled {
151            return;
152        }
153
154        self.enabled = enabled;
155
156        if !enabled {
157            CommandPaletteFilter::update_global(cx, |filter, _cx| {
158                filter.hide_namespace(Self::NAMESPACE);
159            });
160
161            return;
162        }
163
164        CommandPaletteFilter::update_global(cx, |filter, _cx| {
165            filter.show_namespace(Self::NAMESPACE);
166        });
167    }
168}
169
170pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
171    cx.set_global(Assistant::default());
172    AssistantSettings::register(cx);
173
174    // TODO: remove this when 0.148.0 is released.
175    if AssistantSettings::get_global(cx).using_outdated_settings_version {
176        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
177            let fs = fs.clone();
178            |content, cx| {
179                content.update_file(fs, cx);
180            }
181        });
182    }
183
184    cx.spawn(|mut cx| {
185        let client = client.clone();
186        async move {
187            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
188            let semantic_index = SemanticIndex::new(
189                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
190                Arc::new(embedding_provider),
191                &mut cx,
192            )
193            .await?;
194            cx.update(|cx| cx.set_global(semantic_index))
195        }
196    })
197    .detach();
198
199    context_store::init(&client);
200    prompt_library::init(cx);
201    init_language_model_settings(cx);
202    assistant_slash_command::init(cx);
203    assistant_panel::init(cx);
204
205    let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
206        .log_err()
207        .map(Arc::new)
208        .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
209    register_slash_commands(Some(prompt_builder.clone()), cx);
210    inline_assistant::init(
211        fs.clone(),
212        prompt_builder.clone(),
213        client.telemetry().clone(),
214        cx,
215    );
216    terminal_inline_assistant::init(
217        fs.clone(),
218        prompt_builder.clone(),
219        client.telemetry().clone(),
220        cx,
221    );
222    IndexedDocsRegistry::init_global(cx);
223
224    CommandPaletteFilter::update_global(cx, |filter, _cx| {
225        filter.hide_namespace(Assistant::NAMESPACE);
226    });
227    Assistant::update_global(cx, |assistant, cx| {
228        let settings = AssistantSettings::get_global(cx);
229
230        assistant.set_enabled(settings.enabled, cx);
231    });
232    cx.observe_global::<SettingsStore>(|cx| {
233        Assistant::update_global(cx, |assistant, cx| {
234            let settings = AssistantSettings::get_global(cx);
235            assistant.set_enabled(settings.enabled, cx);
236        });
237    })
238    .detach();
239
240    prompt_builder
241}
242
243fn init_language_model_settings(cx: &mut AppContext) {
244    update_active_language_model_from_settings(cx);
245
246    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
247        .detach();
248    cx.subscribe(
249        &LanguageModelRegistry::global(cx),
250        |_, event: &language_model::Event, cx| match event {
251            language_model::Event::ProviderStateChanged
252            | language_model::Event::AddedProvider(_)
253            | language_model::Event::RemovedProvider(_) => {
254                update_active_language_model_from_settings(cx);
255            }
256            _ => {}
257        },
258    )
259    .detach();
260}
261
262fn update_active_language_model_from_settings(cx: &mut AppContext) {
263    let settings = AssistantSettings::get_global(cx);
264    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
265    let model_id = LanguageModelId::from(settings.default_model.model.clone());
266    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
267        registry.select_active_model(&provider_name, &model_id, cx);
268    });
269}
270
271fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
272    let slash_command_registry = SlashCommandRegistry::global(cx);
273    slash_command_registry.register_command(file_command::FileSlashCommand, true);
274    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
275    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
276    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
277    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
278    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
279    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
280    slash_command_registry.register_command(term_command::TermSlashCommand, true);
281    slash_command_registry.register_command(now_command::NowSlashCommand, true);
282    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
283    if let Some(prompt_builder) = prompt_builder {
284        slash_command_registry.register_command(
285            workflow_command::WorkflowSlashCommand::new(prompt_builder),
286            true,
287        );
288    }
289    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
290
291    cx.observe_flag::<docs_command::DocsSlashCommandFeatureFlag, _>({
292        let slash_command_registry = slash_command_registry.clone();
293        move |is_enabled, _cx| {
294            if is_enabled {
295                slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
296            }
297        }
298    })
299    .detach();
300    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
301        let slash_command_registry = slash_command_registry.clone();
302        move |is_enabled, _cx| {
303            if is_enabled {
304                slash_command_registry.register_command(search_command::SearchSlashCommand, true);
305            }
306        }
307    })
308    .detach();
309}
310
/// Formats a token count as a short, human-readable string.
///
/// - Counts below 1,000 are printed verbatim: `999` → `"999"`.
/// - Counts from 1,000 to 9,999 are shown in thousands with one decimal place, rounded
///   half-up to the nearest hundred. A `.0` is dropped, and rounding may carry into the
///   next thousand: `1000` → `"1k"`, `1234` → `"1.2k"`, `9950` → `"10k"`.
/// - Counts of 10,000 and above are rounded half-up to the nearest thousand:
///   `10_499` → `"10k"`, `10_500` → `"11k"`.
///
/// Never overflows, including for `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Round the remainder to the nearest hundred; a result of 10 carries over.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Same value as `(count + 500) / 1000`, but without the addition that
            // overflows for counts within 500 of `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
328
/// Test-only logger setup, run at load time through `ctor`.
///
/// `env_logger` is initialized only when `RUST_LOG` is set to a valid Unicode value,
/// so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_configured = std::env::var("RUST_LOG").is_ok();
    if rust_log_configured {
        env_logger::init();
    }
}