assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod context;
  4pub mod context_store;
  5mod inline_assistant;
  6mod model_selector;
  7mod prompt_library;
  8mod prompts;
  9mod slash_command;
 10mod streaming_diff;
 11mod terminal_inline_assistant;
 12
 13pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 14use assistant_settings::AssistantSettings;
 15use assistant_slash_command::SlashCommandRegistry;
 16use client::{proto, Client};
 17use command_palette_hooks::CommandPaletteFilter;
 18pub use context::*;
 19pub use context_store::*;
 20use feature_flags::FeatureFlagAppExt;
 21use fs::Fs;
 22use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 23use indexed_docs::IndexedDocsRegistry;
 24pub(crate) use inline_assistant::*;
 25use language_model::{
 26    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 27};
 28pub(crate) use model_selector::*;
 29pub use prompts::PromptBuilder;
 30use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 31use serde::{Deserialize, Serialize};
 32use settings::{update_settings_file, Settings, SettingsStore};
 33use slash_command::{
 34    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 35    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 36    tabs_command, term_command, workflow_command,
 37};
 38use std::sync::Arc;
 39pub(crate) use streaming_diff::*;
 40use util::ResultExt;
 41
// Unit actions registered in the `assistant` namespace. The command palette
// shows or hides this namespace through `Assistant::set_enabled`, which uses
// `Assistant::NAMESPACE` ("assistant").
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        ShowConfiguration,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugWorkflowSteps
    ]
);
 60
// NOTE(review): not referenced anywhere in this file. Child modules can see this
// private crate-root item, so it is presumably a default line count used by one
// of them — confirm before changing or removing.
const DEFAULT_CONTEXT_LINES: usize = 20;

/// Action in the `assistant` namespace that carries an optional `prompt`.
///
/// The action is registered through `impl_actions!` below, so it is built by
/// deserialization. Because `prompt` is an `Option`, a missing `prompt` key
/// deserializes to `None`.
/// NOTE(review): presumably the prompt pre-fills the inline assistant —
/// confirm in `inline_assistant`.
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
    prompt: Option<String>,
}

impl_actions!(assistant, [InlineAssist]);
 69
 70#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 71pub struct MessageId(clock::Lamport);
 72
 73impl MessageId {
 74    pub fn as_u64(self) -> u64 {
 75        self.0.as_u64()
 76    }
 77}
 78
/// Token usage counts, deserialize-only.
///
/// There are no serde rename attributes, so the field names below are the
/// expected keys. NOTE(review): presumably the token counts reported by a
/// model provider; not referenced in this file — confirm against callers.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
 85
/// One incremental "choice" entry, deserialize-only. Field names are the
/// expected keys (no serde renames).
///
/// NOTE(review): presumably one chunk of a streamed model response; not
/// referenced in this file — confirm against callers.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    /// The partial message carried by this chunk.
    pub delta: LanguageModelResponseMessage,
    /// Set when the provider reports why generation stopped; absent otherwise.
    pub finish_reason: Option<String>,
}
 92
/// Status of a message.
///
/// Converted to and from `proto::ContextMessageStatus` by
/// `MessageStatus::from_proto` / `MessageStatus::to_proto`. Serde-derived, so
/// the variant names are part of the serialized form.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    Pending,
    Done,
    /// Failed, carrying a message. It round-trips through the proto
    /// `Error.message` field.
    Error(SharedString),
}
 99
100impl MessageStatus {
101    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
102        match status.variant {
103            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
104            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
105            Some(proto::context_message_status::Variant::Error(error)) => {
106                MessageStatus::Error(error.message.into())
107            }
108            None => MessageStatus::Pending,
109        }
110    }
111
112    pub fn to_proto(&self) -> proto::ContextMessageStatus {
113        match self {
114            MessageStatus::Pending => proto::ContextMessageStatus {
115                variant: Some(proto::context_message_status::Variant::Pending(
116                    proto::context_message_status::Pending {},
117                )),
118            },
119            MessageStatus::Done => proto::ContextMessageStatus {
120                variant: Some(proto::context_message_status::Variant::Done(
121                    proto::context_message_status::Done {},
122                )),
123            },
124            MessageStatus::Error(message) => proto::ContextMessageStatus {
125                variant: Some(proto::context_message_status::Variant::Error(
126                    proto::context_message_status::Error {
127                        message: message.to_string(),
128                    },
129                )),
130            },
131        }
132    }
133}
134
135/// The state pertaining to the Assistant.
136#[derive(Default)]
137struct Assistant {
138    /// Whether the Assistant is enabled.
139    enabled: bool,
140}
141
142impl Global for Assistant {}
143
144impl Assistant {
145    const NAMESPACE: &'static str = "assistant";
146
147    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
148        if self.enabled == enabled {
149            return;
150        }
151
152        self.enabled = enabled;
153
154        if !enabled {
155            CommandPaletteFilter::update_global(cx, |filter, _cx| {
156                filter.hide_namespace(Self::NAMESPACE);
157            });
158
159            return;
160        }
161
162        CommandPaletteFilter::update_global(cx, |filter, _cx| {
163            filter.show_namespace(Self::NAMESPACE);
164        });
165    }
166}
167
168pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
169    cx.set_global(Assistant::default());
170    AssistantSettings::register(cx);
171
172    // TODO: remove this when 0.148.0 is released.
173    if AssistantSettings::get_global(cx).using_outdated_settings_version {
174        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
175            let fs = fs.clone();
176            |content, cx| {
177                content.update_file(fs, cx);
178            }
179        });
180    }
181
182    cx.spawn(|mut cx| {
183        let client = client.clone();
184        async move {
185            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
186            let semantic_index = SemanticIndex::new(
187                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
188                Arc::new(embedding_provider),
189                &mut cx,
190            )
191            .await?;
192            cx.update(|cx| cx.set_global(semantic_index))
193        }
194    })
195    .detach();
196
197    context_store::init(&client);
198    prompt_library::init(cx);
199    init_language_model_settings(cx);
200    assistant_slash_command::init(cx);
201    assistant_panel::init(cx);
202
203    let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
204        .log_err()
205        .map(Arc::new)
206        .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
207    register_slash_commands(Some(prompt_builder.clone()), cx);
208    inline_assistant::init(
209        fs.clone(),
210        prompt_builder.clone(),
211        client.telemetry().clone(),
212        cx,
213    );
214    terminal_inline_assistant::init(
215        fs.clone(),
216        prompt_builder.clone(),
217        client.telemetry().clone(),
218        cx,
219    );
220    IndexedDocsRegistry::init_global(cx);
221
222    CommandPaletteFilter::update_global(cx, |filter, _cx| {
223        filter.hide_namespace(Assistant::NAMESPACE);
224    });
225    Assistant::update_global(cx, |assistant, cx| {
226        let settings = AssistantSettings::get_global(cx);
227
228        assistant.set_enabled(settings.enabled, cx);
229    });
230    cx.observe_global::<SettingsStore>(|cx| {
231        Assistant::update_global(cx, |assistant, cx| {
232            let settings = AssistantSettings::get_global(cx);
233            assistant.set_enabled(settings.enabled, cx);
234        });
235    })
236    .detach();
237
238    prompt_builder
239}
240
241fn init_language_model_settings(cx: &mut AppContext) {
242    update_active_language_model_from_settings(cx);
243
244    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
245        .detach();
246    cx.subscribe(
247        &LanguageModelRegistry::global(cx),
248        |_, event: &language_model::Event, cx| match event {
249            language_model::Event::ProviderStateChanged
250            | language_model::Event::AddedProvider(_)
251            | language_model::Event::RemovedProvider(_) => {
252                update_active_language_model_from_settings(cx);
253            }
254            _ => {}
255        },
256    )
257    .detach();
258}
259
260fn update_active_language_model_from_settings(cx: &mut AppContext) {
261    let settings = AssistantSettings::get_global(cx);
262    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
263    let model_id = LanguageModelId::from(settings.default_model.model.clone());
264    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
265        registry.select_active_model(&provider_name, &model_id, cx);
266    });
267}
268
269fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
270    let slash_command_registry = SlashCommandRegistry::global(cx);
271    slash_command_registry.register_command(file_command::FileSlashCommand, true);
272    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
273    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
274    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
275    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
276    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
277    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
278    slash_command_registry.register_command(term_command::TermSlashCommand, true);
279    slash_command_registry.register_command(now_command::NowSlashCommand, true);
280    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
281    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
282    if let Some(prompt_builder) = prompt_builder {
283        slash_command_registry.register_command(
284            workflow_command::WorkflowSlashCommand::new(prompt_builder),
285            true,
286        );
287    }
288    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
289
290    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>(move |is_enabled, _cx| {
291        if is_enabled {
292            slash_command_registry.register_command(search_command::SearchSlashCommand, true);
293        }
294    })
295    .detach();
296}
297
/// Formats a token count for display.
///
/// - `0..=999` is printed as-is (`"999"`).
/// - `1000..=9999` is rounded half-up to one decimal place of thousands, with
///   the decimal dropped when it is zero (`1049 -> "1k"`, `1050 -> "1.1k"`,
///   `9999 -> "10k"`).
/// - Anything larger is rounded half-up to whole thousands (`10500 -> "11k"`).
///
/// Never panics or overflows, including for `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Round the remainder to the nearest hundred. `count % 1000 + 50`
            // is at most 1049, so this cannot overflow.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                // e.g. 1950 rounds up to the next whole thousand.
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Round half-up to whole thousands without computing
            // `count + 500`, which overflows for counts near `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
315
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Runs before the tests (via `ctor`). The logger is installed only when
    // `RUST_LOG` is set and readable as a `String` (`std::env::var` succeeds).
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}