assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3mod assistant_configuration;
  4pub mod assistant_panel;
  5mod inline_assistant;
  6pub mod slash_command_settings;
  7mod terminal_inline_assistant;
  8
  9use std::sync::Arc;
 10
 11use assistant_settings::AssistantSettings;
 12use assistant_slash_command::SlashCommandRegistry;
 13use assistant_slash_commands::{ProjectSlashCommandFeatureFlag, SearchSlashCommandFeatureFlag};
 14use client::Client;
 15use command_palette_hooks::CommandPaletteFilter;
 16use feature_flags::FeatureFlagAppExt;
 17use fs::Fs;
 18use gpui::{actions, App, Global, UpdateGlobal};
 19use language_model::{
 20    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 21};
 22use prompt_store::PromptBuilder;
 23use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 24use serde::Deserialize;
 25use settings::{Settings, SettingsStore};
 26
 27pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
 28pub(crate) use crate::inline_assistant::*;
 29use crate::slash_command_settings::SlashCommandSettings;
 30
// Declares the actions that live in the `assistant` action namespace.
actions!(
    assistant,
    [
        InsertActivePrompt,
        DeployHistory,
        NewChat,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);

// Default number of context lines (50).
// NOTE(review): not referenced in this file; presumably consumed by a child module
// (e.g. the inline assistants) via `crate::DEFAULT_CONTEXT_LINES` — confirm.
const DEFAULT_CONTEXT_LINES: usize = 50;
 43
 44#[derive(Deserialize, Debug)]
 45pub struct LanguageModelUsage {
 46    pub prompt_tokens: u32,
 47    pub completion_tokens: u32,
 48    pub total_tokens: u32,
 49}
 50
 51#[derive(Deserialize, Debug)]
 52pub struct LanguageModelChoiceDelta {
 53    pub index: u32,
 54    pub delta: LanguageModelResponseMessage,
 55    pub finish_reason: Option<String>,
 56}
 57
 58/// The state pertaining to the Assistant.
 59#[derive(Default)]
 60struct Assistant {
 61    /// Whether the Assistant is enabled.
 62    enabled: bool,
 63}
 64
 65impl Global for Assistant {}
 66
 67impl Assistant {
 68    const NAMESPACE: &'static str = "assistant";
 69
 70    fn set_enabled(&mut self, enabled: bool, cx: &mut App) {
 71        if self.enabled == enabled {
 72            return;
 73        }
 74
 75        self.enabled = enabled;
 76
 77        if !enabled {
 78            CommandPaletteFilter::update_global(cx, |filter, _cx| {
 79                filter.hide_namespace(Self::NAMESPACE);
 80            });
 81
 82            return;
 83        }
 84
 85        CommandPaletteFilter::update_global(cx, |filter, _cx| {
 86            filter.show_namespace(Self::NAMESPACE);
 87        });
 88    }
 89}
 90
 91pub fn init(
 92    fs: Arc<dyn Fs>,
 93    client: Arc<Client>,
 94    prompt_builder: Arc<PromptBuilder>,
 95    cx: &mut App,
 96) {
 97    cx.set_global(Assistant::default());
 98    AssistantSettings::register(cx);
 99    SlashCommandSettings::register(cx);
100
101    cx.spawn({
102        let client = client.clone();
103        async move |cx| {
104            let is_search_slash_command_enabled = cx
105                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
106                .await;
107            let is_project_slash_command_enabled = cx
108                .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
109                .await;
110
111            if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
112                return Ok(());
113            }
114
115            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
116            let semantic_index = SemanticDb::new(
117                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
118                Arc::new(embedding_provider),
119                cx,
120            )
121            .await?;
122
123            cx.update(|cx| cx.set_global(semantic_index))
124        }
125    })
126    .detach();
127
128    assistant_context_editor::init(client.clone(), cx);
129    prompt_library::init(cx);
130    init_language_model_settings(cx);
131    assistant_slash_command::init(cx);
132    assistant_tool::init(cx);
133    assistant_panel::init(cx);
134    context_server::init(cx);
135
136    register_slash_commands(Some(prompt_builder.clone()), cx);
137    inline_assistant::init(
138        fs.clone(),
139        prompt_builder.clone(),
140        client.telemetry().clone(),
141        cx,
142    );
143    terminal_inline_assistant::init(
144        fs.clone(),
145        prompt_builder.clone(),
146        client.telemetry().clone(),
147        cx,
148    );
149    indexed_docs::init(cx);
150
151    CommandPaletteFilter::update_global(cx, |filter, _cx| {
152        filter.hide_namespace(Assistant::NAMESPACE);
153    });
154    Assistant::update_global(cx, |assistant, cx| {
155        let settings = AssistantSettings::get_global(cx);
156
157        assistant.set_enabled(settings.enabled, cx);
158    });
159    cx.observe_global::<SettingsStore>(|cx| {
160        Assistant::update_global(cx, |assistant, cx| {
161            let settings = AssistantSettings::get_global(cx);
162            assistant.set_enabled(settings.enabled, cx);
163        });
164    })
165    .detach();
166}
167
168fn init_language_model_settings(cx: &mut App) {
169    update_active_language_model_from_settings(cx);
170
171    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
172        .detach();
173    cx.subscribe(
174        &LanguageModelRegistry::global(cx),
175        |_, event: &language_model::Event, cx| match event {
176            language_model::Event::ProviderStateChanged
177            | language_model::Event::AddedProvider(_)
178            | language_model::Event::RemovedProvider(_) => {
179                update_active_language_model_from_settings(cx);
180            }
181            _ => {}
182        },
183    )
184    .detach();
185}
186
187fn update_active_language_model_from_settings(cx: &mut App) {
188    let settings = AssistantSettings::get_global(cx);
189    let active_model_provider_name =
190        LanguageModelProviderId::from(settings.default_model.provider.clone());
191    let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
192    let editor_provider_name =
193        LanguageModelProviderId::from(settings.editor_model.provider.clone());
194    let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone());
195    let inline_alternatives = settings
196        .inline_alternatives
197        .iter()
198        .map(|alternative| {
199            (
200                LanguageModelProviderId::from(alternative.provider.clone()),
201                LanguageModelId::from(alternative.model.clone()),
202            )
203        })
204        .collect::<Vec<_>>();
205    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
206        registry.select_active_model(&active_model_provider_name, &active_model_id, cx);
207        registry.select_editor_model(&editor_provider_name, &editor_model_id, cx);
208        registry.select_inline_alternative_models(inline_alternatives, cx);
209    });
210}
211
212fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut App) {
213    let slash_command_registry = SlashCommandRegistry::global(cx);
214
215    slash_command_registry.register_command(assistant_slash_commands::FileSlashCommand, true);
216    slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true);
217    slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true);
218    slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true);
219    slash_command_registry
220        .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
221    slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true);
222    slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true);
223    slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false);
224    slash_command_registry.register_command(assistant_slash_commands::TerminalSlashCommand, true);
225    slash_command_registry.register_command(assistant_slash_commands::NowSlashCommand, false);
226    slash_command_registry
227        .register_command(assistant_slash_commands::DiagnosticsSlashCommand, true);
228    slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true);
229
230    if let Some(prompt_builder) = prompt_builder {
231        cx.observe_flag::<assistant_slash_commands::ProjectSlashCommandFeatureFlag, _>({
232            let slash_command_registry = slash_command_registry.clone();
233            move |is_enabled, _cx| {
234                if is_enabled {
235                    slash_command_registry.register_command(
236                        assistant_slash_commands::ProjectSlashCommand::new(prompt_builder.clone()),
237                        true,
238                    );
239                }
240            }
241        })
242        .detach();
243    }
244
245    cx.observe_flag::<assistant_slash_commands::AutoSlashCommandFeatureFlag, _>({
246        let slash_command_registry = slash_command_registry.clone();
247        move |is_enabled, _cx| {
248            if is_enabled {
249                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
250                slash_command_registry
251                    .register_command(assistant_slash_commands::AutoCommand, true);
252            }
253        }
254    })
255    .detach();
256
257    cx.observe_flag::<assistant_slash_commands::StreamingExampleSlashCommandFeatureFlag, _>({
258        let slash_command_registry = slash_command_registry.clone();
259        move |is_enabled, _cx| {
260            if is_enabled {
261                slash_command_registry.register_command(
262                    assistant_slash_commands::StreamingExampleSlashCommand,
263                    false,
264                );
265            }
266        }
267    })
268    .detach();
269
270    update_slash_commands_from_settings(cx);
271    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
272        .detach();
273
274    cx.observe_flag::<assistant_slash_commands::SearchSlashCommandFeatureFlag, _>({
275        let slash_command_registry = slash_command_registry.clone();
276        move |is_enabled, _cx| {
277            if is_enabled {
278                slash_command_registry
279                    .register_command(assistant_slash_commands::SearchSlashCommand, true);
280            }
281        }
282    })
283    .detach();
284}
285
286fn update_slash_commands_from_settings(cx: &mut App) {
287    let slash_command_registry = SlashCommandRegistry::global(cx);
288    let settings = SlashCommandSettings::get_global(cx);
289
290    if settings.docs.enabled {
291        slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
292    } else {
293        slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
294    }
295
296    if settings.cargo_workspace.enabled {
297        slash_command_registry
298            .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
299    } else {
300        slash_command_registry
301            .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
302    }
303}
304
/// Test-only constructor that runs before the tests start.
///
/// It installs `env_logger` only when the `RUST_LOG` environment variable is
/// set, and readable as Unicode.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if rust_log_requested {
        env_logger::init();
    }
}