assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4pub mod assistant_settings;
  5mod context;
  6pub(crate) mod context_inspector;
  7pub mod context_store;
  8mod inline_assistant;
  9mod model_selector;
 10mod prompt_library;
 11mod prompts;
 12mod slash_command;
 13mod streaming_diff;
 14mod terminal_inline_assistant;
 15
 16pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 17use assistant_settings::AssistantSettings;
 18use assistant_slash_command::SlashCommandRegistry;
 19use client::{proto, Client};
 20use command_palette_hooks::CommandPaletteFilter;
 21pub use context::*;
 22pub use context_store::*;
 23use feature_flags::FeatureFlagAppExt;
 24use fs::Fs;
 25use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 26use indexed_docs::IndexedDocsRegistry;
 27pub(crate) use inline_assistant::*;
 28use language_model::{
 29    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 30};
 31pub(crate) use model_selector::*;
 32pub use prompts::PromptBuilder;
 33use prompts::PromptOverrideContext;
 34use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 35use serde::{Deserialize, Serialize};
 36use settings::{update_settings_file, Settings, SettingsStore};
 37use slash_command::{
 38    default_command, diagnostics_command, docs_command, fetch_command, file_command, now_command,
 39    project_command, prompt_command, search_command, symbols_command, tabs_command,
 40    terminal_command, workflow_command,
 41};
 42use std::sync::Arc;
 43pub(crate) use streaming_diff::*;
 44use util::ResultExt;
 45
 // Declares the assistant's actions through gpui's `actions!` macro. The first macro
 // argument, `assistant`, spells the same name as `Assistant::NAMESPACE`, which is
 // shown or hidden in the command palette filter — presumably that is what ties these
 // actions to the assistant's enabled setting (confirm against the macro's expansion).
 46actions!(
 47    assistant,
 48    [
 49        Assist,
 50        Split,
 51        CycleMessageRole,
 52        QuoteSelection,
 53        InsertIntoEditor,
 54        ToggleFocus,
 55        InsertActivePrompt,
 56        ShowConfiguration,
 57        DeployHistory,
 58        DeployPromptLibrary,
 59        ConfirmCommand,
 60        ToggleModelSelector,
 61        DebugWorkflowSteps
 62    ]
 63);
 64
 // Default number of context lines (20). Nothing in this file reads the constant;
 // presumably a submodule consumes it — TODO confirm what the lines are context for.
 65const DEFAULT_CONTEXT_LINES: usize = 20;
 66
 67#[derive(Clone, Default, Deserialize, PartialEq)]
 68pub struct InlineAssist {
 69    prompt: Option<String>,
 70}
 71
 72impl_actions!(assistant, [InlineAssist]);
 73
 74#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 75pub struct MessageId(clock::Lamport);
 76
 77impl MessageId {
 78    pub fn as_u64(self) -> u64 {
 79        self.0.as_u64()
 80    }
 81}
 82
 83#[derive(Deserialize, Debug)]
 84pub struct LanguageModelUsage {
 85    pub prompt_tokens: u32,
 86    pub completion_tokens: u32,
 87    pub total_tokens: u32,
 88}
 89
 90#[derive(Deserialize, Debug)]
 91pub struct LanguageModelChoiceDelta {
 92    pub index: u32,
 93    pub delta: LanguageModelResponseMessage,
 94    pub finish_reason: Option<String>,
 95}
 96
 97#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 98pub enum MessageStatus {
 99    Pending,
100    Done,
101    Error(SharedString),
102}
103
104impl MessageStatus {
105    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
106        match status.variant {
107            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
108            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
109            Some(proto::context_message_status::Variant::Error(error)) => {
110                MessageStatus::Error(error.message.into())
111            }
112            None => MessageStatus::Pending,
113        }
114    }
115
116    pub fn to_proto(&self) -> proto::ContextMessageStatus {
117        match self {
118            MessageStatus::Pending => proto::ContextMessageStatus {
119                variant: Some(proto::context_message_status::Variant::Pending(
120                    proto::context_message_status::Pending {},
121                )),
122            },
123            MessageStatus::Done => proto::ContextMessageStatus {
124                variant: Some(proto::context_message_status::Variant::Done(
125                    proto::context_message_status::Done {},
126                )),
127            },
128            MessageStatus::Error(message) => proto::ContextMessageStatus {
129                variant: Some(proto::context_message_status::Variant::Error(
130                    proto::context_message_status::Error {
131                        message: message.to_string(),
132                    },
133                )),
134            },
135        }
136    }
137}
138
139/// The state pertaining to the Assistant.
140#[derive(Default)]
141struct Assistant {
142    /// Whether the Assistant is enabled.
143    enabled: bool,
144}
145
146impl Global for Assistant {}
147
148impl Assistant {
149    const NAMESPACE: &'static str = "assistant";
150
151    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
152        if self.enabled == enabled {
153            return;
154        }
155
156        self.enabled = enabled;
157
158        if !enabled {
159            CommandPaletteFilter::update_global(cx, |filter, _cx| {
160                filter.hide_namespace(Self::NAMESPACE);
161            });
162
163            return;
164        }
165
166        CommandPaletteFilter::update_global(cx, |filter, _cx| {
167            filter.show_namespace(Self::NAMESPACE);
168        });
169    }
170}
171
172pub fn init(
173    fs: Arc<dyn Fs>,
174    client: Arc<Client>,
175    dev_mode: bool,
176    cx: &mut AppContext,
177) -> Arc<PromptBuilder> {
178    cx.set_global(Assistant::default());
179    AssistantSettings::register(cx);
180
181    // TODO: remove this when 0.148.0 is released.
182    if AssistantSettings::get_global(cx).using_outdated_settings_version {
183        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
184            let fs = fs.clone();
185            |content, cx| {
186                content.update_file(fs, cx);
187            }
188        });
189    }
190
191    cx.spawn(|mut cx| {
192        let client = client.clone();
193        async move {
194            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
195            let semantic_index = SemanticIndex::new(
196                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
197                Arc::new(embedding_provider),
198                &mut cx,
199            )
200            .await?;
201            cx.update(|cx| cx.set_global(semantic_index))
202        }
203    })
204    .detach();
205
206    context_store::init(&client);
207    prompt_library::init(cx);
208    init_language_model_settings(cx);
209    assistant_slash_command::init(cx);
210    assistant_panel::init(cx);
211
212    let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
213        dev_mode,
214        fs: fs.clone(),
215        cx,
216    }))
217    .log_err()
218    .map(Arc::new)
219    .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
220    register_slash_commands(Some(prompt_builder.clone()), cx);
221    inline_assistant::init(
222        fs.clone(),
223        prompt_builder.clone(),
224        client.telemetry().clone(),
225        cx,
226    );
227    terminal_inline_assistant::init(
228        fs.clone(),
229        prompt_builder.clone(),
230        client.telemetry().clone(),
231        cx,
232    );
233    IndexedDocsRegistry::init_global(cx);
234
235    CommandPaletteFilter::update_global(cx, |filter, _cx| {
236        filter.hide_namespace(Assistant::NAMESPACE);
237    });
238    Assistant::update_global(cx, |assistant, cx| {
239        let settings = AssistantSettings::get_global(cx);
240
241        assistant.set_enabled(settings.enabled, cx);
242    });
243    cx.observe_global::<SettingsStore>(|cx| {
244        Assistant::update_global(cx, |assistant, cx| {
245            let settings = AssistantSettings::get_global(cx);
246            assistant.set_enabled(settings.enabled, cx);
247        });
248    })
249    .detach();
250
251    prompt_builder
252}
253
254fn init_language_model_settings(cx: &mut AppContext) {
255    update_active_language_model_from_settings(cx);
256
257    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
258        .detach();
259    cx.subscribe(
260        &LanguageModelRegistry::global(cx),
261        |_, event: &language_model::Event, cx| match event {
262            language_model::Event::ProviderStateChanged
263            | language_model::Event::AddedProvider(_)
264            | language_model::Event::RemovedProvider(_) => {
265                update_active_language_model_from_settings(cx);
266            }
267            _ => {}
268        },
269    )
270    .detach();
271}
272
273fn update_active_language_model_from_settings(cx: &mut AppContext) {
274    let settings = AssistantSettings::get_global(cx);
275    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
276    let model_id = LanguageModelId::from(settings.default_model.model.clone());
277    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
278        registry.select_active_model(&provider_name, &model_id, cx);
279    });
280}
281
282fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
283    let slash_command_registry = SlashCommandRegistry::global(cx);
284    slash_command_registry.register_command(file_command::FileSlashCommand, true);
285    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
286    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
287    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
288    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
289    slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
290    slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
291    slash_command_registry.register_command(now_command::NowSlashCommand, false);
292    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
293    if let Some(prompt_builder) = prompt_builder {
294        slash_command_registry.register_command(
295            workflow_command::WorkflowSlashCommand::new(prompt_builder),
296            true,
297        );
298    }
299    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
300
301    cx.observe_flag::<docs_command::DocsSlashCommandFeatureFlag, _>({
302        let slash_command_registry = slash_command_registry.clone();
303        move |is_enabled, _cx| {
304            if is_enabled {
305                slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
306            }
307        }
308    })
309    .detach();
310    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
311        let slash_command_registry = slash_command_registry.clone();
312        move |is_enabled, _cx| {
313            if is_enabled {
314                slash_command_registry.register_command(search_command::SearchSlashCommand, true);
315            }
316        }
317    })
318    .detach();
319}
320
/// Formats a token count compactly for display.
///
/// * `0..=999` is printed verbatim (`"742"`).
/// * `1000..=9999` is rounded to the nearest hundred and printed in thousands with at
///   most one decimal digit (`1050` → `"1.1k"`, `1949` → `"1.9k"`, `1950` → `"2k"`).
/// * Anything larger is rounded to the nearest thousand (`10500` → `"11k"`).
///
/// Halves round up. The function never panics: rounding is done without adding to
/// `count`, so values close to `usize::MAX` cannot overflow.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            // Number of hundreds after rounding; cannot overflow in this range.
            let tenths = (count + 50) / 100;
            let thousands = tenths / 10;
            let tenth = tenths % 10;
            if tenth == 0 {
                // Covers both exact thousands and values that rounded up to the next one.
                format!("{}k", thousands)
            } else {
                format!("{}.{}k", thousands, tenth)
            }
        }
        _ => {
            // Same result as `(count + 500) / 1000`, but without the addition that
            // overflows for counts within 500 of `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
338
/// Test-only constructor hook: initializes `env_logger` when `RUST_LOG` is set.
///
/// When `std::env::var("RUST_LOG")` fails (unset or not valid Unicode), no logger
/// is installed.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}