assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3pub mod assistant_panel;
  4pub mod assistant_settings;
  5mod context;
  6pub(crate) mod context_inspector;
  7pub mod context_store;
  8mod inline_assistant;
  9mod model_selector;
 10mod prompt_library;
 11mod prompts;
 12mod slash_command;
 13mod streaming_diff;
 14mod terminal_inline_assistant;
 15
 16pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
 17use assistant_settings::AssistantSettings;
 18use assistant_slash_command::SlashCommandRegistry;
 19use client::{proto, Client};
 20use command_palette_hooks::CommandPaletteFilter;
 21pub use context::*;
 22pub use context_store::*;
 23use feature_flags::FeatureFlagAppExt;
 24use fs::Fs;
 25use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
 26use indexed_docs::IndexedDocsRegistry;
 27pub(crate) use inline_assistant::*;
 28use language_model::{
 29    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 30};
 31pub(crate) use model_selector::*;
 32pub use prompts::PromptBuilder;
 33use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 34use serde::{Deserialize, Serialize};
 35use settings::{update_settings_file, Settings, SettingsStore};
 36use slash_command::{
 37    active_command, default_command, diagnostics_command, docs_command, fetch_command,
 38    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
 39    tabs_command, term_command, workflow_command,
 40};
 41use std::sync::Arc;
 42pub(crate) use streaming_diff::*;
 43use util::ResultExt;
 44
// Parameterless actions in the `assistant` namespace. `Assistant::set_enabled`
// hides or shows this namespace in the command palette, driven by
// `AssistantSettings::enabled` (see `init`).
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        ShowConfiguration,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugWorkflowSteps
    ]
);
 63
/// Default number of context lines (20).
///
/// NOTE(review): not referenced in the code shown in this file; presumably the
/// number of surrounding lines included as context by one of the submodules
/// (it is visible to them as `crate::DEFAULT_CONTEXT_LINES`) — confirm at the
/// call sites.
const DEFAULT_CONTEXT_LINES: usize = 20;
 65
/// Action for the inline assistant (presumably it starts an inline assist —
/// see `inline_assistant`).
///
/// Unlike the parameterless actions declared with `actions!`, this one carries
/// data. It is therefore `Deserialize` and registered in the `assistant`
/// namespace with `impl_actions!`.
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
    // Optional prompt supplied with the action; `None` by default.
    // NOTE(review): presumably used to prefill the inline-assist prompt —
    // confirm in `inline_assistant`.
    prompt: Option<String>,
}

impl_actions!(assistant, [InlineAssist]);
 72
 73#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 74pub struct MessageId(clock::Lamport);
 75
 76impl MessageId {
 77    pub fn as_u64(self) -> u64 {
 78        self.0.as_u64()
 79    }
 80}
 81
/// Token usage counts for a language model request, deserialized with serde.
///
/// NOTE(review): the field names suggest an OpenAI-style `usage` object —
/// confirm against whatever produces this data.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported by the source (not computed here).
    pub total_tokens: u32,
}
 88
/// One incremental "choice" entry from a streamed language model response,
/// deserialized with serde.
///
/// NOTE(review): the shape resembles an OpenAI-style streaming chunk — confirm
/// against the producer.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The partial message content carried by this delta.
    pub delta: LanguageModelResponseMessage,
    /// Reason the choice finished, if the source reports one.
    pub finish_reason: Option<String>,
}
 95
 96#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 97pub enum MessageStatus {
 98    Pending,
 99    Done,
100    Error(SharedString),
101}
102
103impl MessageStatus {
104    pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
105        match status.variant {
106            Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
107            Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
108            Some(proto::context_message_status::Variant::Error(error)) => {
109                MessageStatus::Error(error.message.into())
110            }
111            None => MessageStatus::Pending,
112        }
113    }
114
115    pub fn to_proto(&self) -> proto::ContextMessageStatus {
116        match self {
117            MessageStatus::Pending => proto::ContextMessageStatus {
118                variant: Some(proto::context_message_status::Variant::Pending(
119                    proto::context_message_status::Pending {},
120                )),
121            },
122            MessageStatus::Done => proto::ContextMessageStatus {
123                variant: Some(proto::context_message_status::Variant::Done(
124                    proto::context_message_status::Done {},
125                )),
126            },
127            MessageStatus::Error(message) => proto::ContextMessageStatus {
128                variant: Some(proto::context_message_status::Variant::Error(
129                    proto::context_message_status::Error {
130                        message: message.to_string(),
131                    },
132                )),
133            },
134        }
135    }
136}
137
138/// The state pertaining to the Assistant.
139#[derive(Default)]
140struct Assistant {
141    /// Whether the Assistant is enabled.
142    enabled: bool,
143}
144
145impl Global for Assistant {}
146
147impl Assistant {
148    const NAMESPACE: &'static str = "assistant";
149
150    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
151        if self.enabled == enabled {
152            return;
153        }
154
155        self.enabled = enabled;
156
157        if !enabled {
158            CommandPaletteFilter::update_global(cx, |filter, _cx| {
159                filter.hide_namespace(Self::NAMESPACE);
160            });
161
162            return;
163        }
164
165        CommandPaletteFilter::update_global(cx, |filter, _cx| {
166            filter.show_namespace(Self::NAMESPACE);
167        });
168    }
169}
170
171pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
172    cx.set_global(Assistant::default());
173    AssistantSettings::register(cx);
174
175    // TODO: remove this when 0.148.0 is released.
176    if AssistantSettings::get_global(cx).using_outdated_settings_version {
177        update_settings_file::<AssistantSettings>(fs.clone(), cx, {
178            let fs = fs.clone();
179            |content, cx| {
180                content.update_file(fs, cx);
181            }
182        });
183    }
184
185    cx.spawn(|mut cx| {
186        let client = client.clone();
187        async move {
188            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
189            let semantic_index = SemanticIndex::new(
190                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
191                Arc::new(embedding_provider),
192                &mut cx,
193            )
194            .await?;
195            cx.update(|cx| cx.set_global(semantic_index))
196        }
197    })
198    .detach();
199
200    context_store::init(&client);
201    prompt_library::init(cx);
202    init_language_model_settings(cx);
203    assistant_slash_command::init(cx);
204    assistant_panel::init(cx);
205
206    let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
207        .log_err()
208        .map(Arc::new)
209        .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
210    register_slash_commands(Some(prompt_builder.clone()), cx);
211    inline_assistant::init(
212        fs.clone(),
213        prompt_builder.clone(),
214        client.telemetry().clone(),
215        cx,
216    );
217    terminal_inline_assistant::init(
218        fs.clone(),
219        prompt_builder.clone(),
220        client.telemetry().clone(),
221        cx,
222    );
223    IndexedDocsRegistry::init_global(cx);
224
225    CommandPaletteFilter::update_global(cx, |filter, _cx| {
226        filter.hide_namespace(Assistant::NAMESPACE);
227    });
228    Assistant::update_global(cx, |assistant, cx| {
229        let settings = AssistantSettings::get_global(cx);
230
231        assistant.set_enabled(settings.enabled, cx);
232    });
233    cx.observe_global::<SettingsStore>(|cx| {
234        Assistant::update_global(cx, |assistant, cx| {
235            let settings = AssistantSettings::get_global(cx);
236            assistant.set_enabled(settings.enabled, cx);
237        });
238    })
239    .detach();
240
241    prompt_builder
242}
243
244fn init_language_model_settings(cx: &mut AppContext) {
245    update_active_language_model_from_settings(cx);
246
247    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
248        .detach();
249    cx.subscribe(
250        &LanguageModelRegistry::global(cx),
251        |_, event: &language_model::Event, cx| match event {
252            language_model::Event::ProviderStateChanged
253            | language_model::Event::AddedProvider(_)
254            | language_model::Event::RemovedProvider(_) => {
255                update_active_language_model_from_settings(cx);
256            }
257            _ => {}
258        },
259    )
260    .detach();
261}
262
263fn update_active_language_model_from_settings(cx: &mut AppContext) {
264    let settings = AssistantSettings::get_global(cx);
265    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
266    let model_id = LanguageModelId::from(settings.default_model.model.clone());
267    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
268        registry.select_active_model(&provider_name, &model_id, cx);
269    });
270}
271
272fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
273    let slash_command_registry = SlashCommandRegistry::global(cx);
274    slash_command_registry.register_command(file_command::FileSlashCommand, true);
275    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
276    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
277    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
278    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
279    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
280    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
281    slash_command_registry.register_command(term_command::TermSlashCommand, true);
282    slash_command_registry.register_command(now_command::NowSlashCommand, true);
283    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
284    if let Some(prompt_builder) = prompt_builder {
285        slash_command_registry.register_command(
286            workflow_command::WorkflowSlashCommand::new(prompt_builder),
287            true,
288        );
289    }
290    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
291
292    cx.observe_flag::<docs_command::DocsSlashCommandFeatureFlag, _>({
293        let slash_command_registry = slash_command_registry.clone();
294        move |is_enabled, _cx| {
295            if is_enabled {
296                slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
297            }
298        }
299    })
300    .detach();
301    cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
302        let slash_command_registry = slash_command_registry.clone();
303        move |is_enabled, _cx| {
304            if is_enabled {
305                slash_command_registry.register_command(search_command::SearchSlashCommand, true);
306            }
307        }
308    })
309    .detach();
310}
311
/// Formats a token count compactly for display.
///
/// - Below 1,000 the exact count is shown (`999` → `"999"`).
/// - From 1,000 to 9,999 the count is rounded half-up to tenths of a thousand.
///   A trailing `.0` is dropped (`1_050` → `"1.1k"`, `1_949` → `"1.9k"`,
///   `1_950` → `"2k"`).
/// - From 10,000 upward it is rounded half-up to whole thousands
///   (`10_500` → `"11k"`).
///
/// Never panics or wraps, even for `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remaining hundreds, rounded half-up; may reach 10 and carry.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Equivalent to `(count + 500) / 1000`, but without the addition that
            // overflows for counts within 500 of `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
329
/// Test-only logger setup, run via `ctor` when the test binary loads.
///
/// `env_logger` is initialized only when `RUST_LOG` is set to a valid Unicode
/// value (`std::env::var` returns `Err` otherwise), so tests log nothing by
/// default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}