assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3mod assistant_configuration;
  4pub mod assistant_panel;
  5mod inline_assistant;
  6pub mod slash_command_settings;
  7mod terminal_inline_assistant;
  8
  9use std::sync::Arc;
 10
 11use assistant_settings::AssistantSettings;
 12use assistant_slash_command::SlashCommandRegistry;
 13use assistant_slash_commands::{ProjectSlashCommandFeatureFlag, SearchSlashCommandFeatureFlag};
 14use client::Client;
 15use command_palette_hooks::CommandPaletteFilter;
 16use feature_flags::FeatureFlagAppExt;
 17use fs::Fs;
 18use gpui::{actions, AppContext, Global, UpdateGlobal};
 19use language_model::{
 20    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 21};
 22use prompt_library::{PromptBuilder, PromptLoadingParams};
 23use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 24use serde::Deserialize;
 25use settings::{Settings, SettingsStore};
 26use util::ResultExt;
 27
 28pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
 29pub(crate) use crate::inline_assistant::*;
 30use crate::slash_command_settings::SlashCommandSettings;
 31
 32actions!(
 33    assistant,
 34    [
 35        InsertActivePrompt,
 36        DeployHistory,
 37        DeployPromptLibrary,
 38        NewContext,
 39        CycleNextInlineAssist,
 40        CyclePreviousInlineAssist
 41    ]
 42);
 43
 44const DEFAULT_CONTEXT_LINES: usize = 50;
 45
 46#[derive(Deserialize, Debug)]
 47pub struct LanguageModelUsage {
 48    pub prompt_tokens: u32,
 49    pub completion_tokens: u32,
 50    pub total_tokens: u32,
 51}
 52
 53#[derive(Deserialize, Debug)]
 54pub struct LanguageModelChoiceDelta {
 55    pub index: u32,
 56    pub delta: LanguageModelResponseMessage,
 57    pub finish_reason: Option<String>,
 58}
 59
 60/// The state pertaining to the Assistant.
 61#[derive(Default)]
 62struct Assistant {
 63    /// Whether the Assistant is enabled.
 64    enabled: bool,
 65}
 66
 67impl Global for Assistant {}
 68
 69impl Assistant {
 70    const NAMESPACE: &'static str = "assistant";
 71
 72    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
 73        if self.enabled == enabled {
 74            return;
 75        }
 76
 77        self.enabled = enabled;
 78
 79        if !enabled {
 80            CommandPaletteFilter::update_global(cx, |filter, _cx| {
 81                filter.hide_namespace(Self::NAMESPACE);
 82            });
 83
 84            return;
 85        }
 86
 87        CommandPaletteFilter::update_global(cx, |filter, _cx| {
 88            filter.show_namespace(Self::NAMESPACE);
 89        });
 90    }
 91}
 92
 93pub fn init(
 94    fs: Arc<dyn Fs>,
 95    client: Arc<Client>,
 96    stdout_is_a_pty: bool,
 97    cx: &mut AppContext,
 98) -> Arc<PromptBuilder> {
 99    cx.set_global(Assistant::default());
100    AssistantSettings::register(cx);
101    SlashCommandSettings::register(cx);
102
103    cx.spawn(|mut cx| {
104        let client = client.clone();
105        async move {
106            let is_search_slash_command_enabled = cx
107                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
108                .await;
109            let is_project_slash_command_enabled = cx
110                .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
111                .await;
112
113            if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
114                return Ok(());
115            }
116
117            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
118            let semantic_index = SemanticDb::new(
119                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
120                Arc::new(embedding_provider),
121                &mut cx,
122            )
123            .await?;
124
125            cx.update(|cx| cx.set_global(semantic_index))
126        }
127    })
128    .detach();
129
130    assistant_context_editor::init(client.clone(), cx);
131    prompt_library::init(cx);
132    init_language_model_settings(cx);
133    assistant_slash_command::init(cx);
134    assistant_tool::init(cx);
135    assistant_panel::init(cx);
136    context_server::init(cx);
137
138    let prompt_builder = PromptBuilder::new(Some(PromptLoadingParams {
139        fs: fs.clone(),
140        repo_path: stdout_is_a_pty
141            .then(|| std::env::current_dir().log_err())
142            .flatten(),
143        cx,
144    }))
145    .log_err()
146    .map(Arc::new)
147    .unwrap_or_else(|| Arc::new(PromptBuilder::new(None).unwrap()));
148    register_slash_commands(Some(prompt_builder.clone()), cx);
149    inline_assistant::init(
150        fs.clone(),
151        prompt_builder.clone(),
152        client.telemetry().clone(),
153        cx,
154    );
155    terminal_inline_assistant::init(
156        fs.clone(),
157        prompt_builder.clone(),
158        client.telemetry().clone(),
159        cx,
160    );
161    indexed_docs::init(cx);
162
163    CommandPaletteFilter::update_global(cx, |filter, _cx| {
164        filter.hide_namespace(Assistant::NAMESPACE);
165    });
166    Assistant::update_global(cx, |assistant, cx| {
167        let settings = AssistantSettings::get_global(cx);
168
169        assistant.set_enabled(settings.enabled, cx);
170    });
171    cx.observe_global::<SettingsStore>(|cx| {
172        Assistant::update_global(cx, |assistant, cx| {
173            let settings = AssistantSettings::get_global(cx);
174            assistant.set_enabled(settings.enabled, cx);
175        });
176    })
177    .detach();
178
179    prompt_builder
180}
181
182fn init_language_model_settings(cx: &mut AppContext) {
183    update_active_language_model_from_settings(cx);
184
185    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
186        .detach();
187    cx.subscribe(
188        &LanguageModelRegistry::global(cx),
189        |_, event: &language_model::Event, cx| match event {
190            language_model::Event::ProviderStateChanged
191            | language_model::Event::AddedProvider(_)
192            | language_model::Event::RemovedProvider(_) => {
193                update_active_language_model_from_settings(cx);
194            }
195            _ => {}
196        },
197    )
198    .detach();
199}
200
201fn update_active_language_model_from_settings(cx: &mut AppContext) {
202    let settings = AssistantSettings::get_global(cx);
203    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
204    let model_id = LanguageModelId::from(settings.default_model.model.clone());
205    let inline_alternatives = settings
206        .inline_alternatives
207        .iter()
208        .map(|alternative| {
209            (
210                LanguageModelProviderId::from(alternative.provider.clone()),
211                LanguageModelId::from(alternative.model.clone()),
212            )
213        })
214        .collect::<Vec<_>>();
215    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
216        registry.select_active_model(&provider_name, &model_id, cx);
217        registry.select_inline_alternative_models(inline_alternatives, cx);
218    });
219}
220
221fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
222    let slash_command_registry = SlashCommandRegistry::global(cx);
223
224    slash_command_registry.register_command(assistant_slash_commands::FileSlashCommand, true);
225    slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true);
226    slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true);
227    slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true);
228    slash_command_registry
229        .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
230    slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true);
231    slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true);
232    slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false);
233    slash_command_registry.register_command(assistant_slash_commands::TerminalSlashCommand, true);
234    slash_command_registry.register_command(assistant_slash_commands::NowSlashCommand, false);
235    slash_command_registry
236        .register_command(assistant_slash_commands::DiagnosticsSlashCommand, true);
237    slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true);
238
239    if let Some(prompt_builder) = prompt_builder {
240        cx.observe_flag::<assistant_slash_commands::ProjectSlashCommandFeatureFlag, _>({
241            let slash_command_registry = slash_command_registry.clone();
242            move |is_enabled, _cx| {
243                if is_enabled {
244                    slash_command_registry.register_command(
245                        assistant_slash_commands::ProjectSlashCommand::new(prompt_builder.clone()),
246                        true,
247                    );
248                }
249            }
250        })
251        .detach();
252    }
253
254    cx.observe_flag::<assistant_slash_commands::AutoSlashCommandFeatureFlag, _>({
255        let slash_command_registry = slash_command_registry.clone();
256        move |is_enabled, _cx| {
257            if is_enabled {
258                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
259                slash_command_registry
260                    .register_command(assistant_slash_commands::AutoCommand, true);
261            }
262        }
263    })
264    .detach();
265
266    cx.observe_flag::<assistant_slash_commands::StreamingExampleSlashCommandFeatureFlag, _>({
267        let slash_command_registry = slash_command_registry.clone();
268        move |is_enabled, _cx| {
269            if is_enabled {
270                slash_command_registry.register_command(
271                    assistant_slash_commands::StreamingExampleSlashCommand,
272                    false,
273                );
274            }
275        }
276    })
277    .detach();
278
279    update_slash_commands_from_settings(cx);
280    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
281        .detach();
282
283    cx.observe_flag::<assistant_slash_commands::SearchSlashCommandFeatureFlag, _>({
284        let slash_command_registry = slash_command_registry.clone();
285        move |is_enabled, _cx| {
286            if is_enabled {
287                slash_command_registry
288                    .register_command(assistant_slash_commands::SearchSlashCommand, true);
289            }
290        }
291    })
292    .detach();
293}
294
295fn update_slash_commands_from_settings(cx: &mut AppContext) {
296    let slash_command_registry = SlashCommandRegistry::global(cx);
297    let settings = SlashCommandSettings::get_global(cx);
298
299    if settings.docs.enabled {
300        slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
301    } else {
302        slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
303    }
304
305    if settings.cargo_workspace.enabled {
306        slash_command_registry
307            .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
308    } else {
309        slash_command_registry
310            .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
311    }
312}
313
/// Test-only constructor hook: initializes `env_logger` when the `RUST_LOG` environment
/// variable can be read, and otherwise leaves logging uninitialized.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }

    env_logger::init();
}
320}