assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3mod assistant_configuration;
  4pub mod assistant_panel;
  5mod inline_assistant;
  6pub mod slash_command_settings;
  7mod terminal_inline_assistant;
  8
  9use std::sync::Arc;
 10
 11use assistant_settings::AssistantSettings;
 12use assistant_slash_command::SlashCommandRegistry;
 13use assistant_slash_commands::{ProjectSlashCommandFeatureFlag, SearchSlashCommandFeatureFlag};
 14use client::Client;
 15use command_palette_hooks::CommandPaletteFilter;
 16use feature_flags::FeatureFlagAppExt;
 17use fs::Fs;
 18use gpui::{actions, App, Global, ReadGlobal, UpdateGlobal};
 19use language_model::{
 20    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 21};
 22use prompt_store::PromptBuilder;
 23use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 24use serde::Deserialize;
 25use settings::{Settings, SettingsStore};
 26
 27pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
 28pub(crate) use crate::inline_assistant::*;
 29use crate::slash_command_settings::SlashCommandSettings;
 30
// Declares the actions below under the `assistant` namespace (first argument
// of gpui's `actions!` macro).
// NOTE(review): `Assistant::NAMESPACE` is also "assistant" and is what gets
// hidden/shown in the command palette; presumably that is what gates these
// actions' palette visibility — confirm. Handlers are registered elsewhere.
actions!(
    assistant,
    [
        InsertActivePrompt,
        DeployHistory,
        NewChat,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
 41
/// A default line count of 50, presumably the number of surrounding context
/// lines to include — TODO confirm against its consumers.
///
/// NOTE(review): not referenced in this file; child modules can read this
/// private const, so it may be used by one of them — verify it is still needed.
const DEFAULT_CONTEXT_LINES: usize = 50;
 43
/// Token counts deserialized (via serde) from a model response.
///
/// NOTE(review): presumably mirrors the `usage` object of an OpenAI-style
/// completion payload — confirm. No validation is done on the values, and the
/// type is not referenced in this file.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    // Presumably the number of tokens in the prompt.
    pub prompt_tokens: u32,
    // Presumably the number of tokens generated in the completion.
    pub completion_tokens: u32,
    // Presumably prompt + completion; not checked against the other fields.
    pub total_tokens: u32,
}
 50
/// One incremental chunk of a choice, deserialized (via serde) from a
/// streamed model response.
///
/// NOTE(review): presumably mirrors an entry of an OpenAI-style streaming
/// `choices` array — confirm. Not referenced in this file.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    // Presumably identifies which choice this chunk belongs to.
    pub index: u32,
    // The partial message carried by this chunk.
    pub delta: LanguageModelResponseMessage,
    // Optional; presumably set only once the choice has finished.
    pub finish_reason: Option<String>,
}
 57
 58/// The state pertaining to the Assistant.
 59#[derive(Default)]
 60struct Assistant {
 61    /// Whether the Assistant is enabled.
 62    enabled: bool,
 63}
 64
 65impl Global for Assistant {}
 66
 67impl Assistant {
 68    const NAMESPACE: &'static str = "assistant";
 69
 70    fn set_enabled(&mut self, enabled: bool, cx: &mut App) {
 71        if self.enabled == enabled {
 72            return;
 73        }
 74
 75        self.enabled = enabled;
 76
 77        if !enabled {
 78            CommandPaletteFilter::update_global(cx, |filter, _cx| {
 79                filter.hide_namespace(Self::NAMESPACE);
 80            });
 81
 82            return;
 83        }
 84
 85        CommandPaletteFilter::update_global(cx, |filter, _cx| {
 86            filter.show_namespace(Self::NAMESPACE);
 87        });
 88    }
 89
 90    pub fn enabled(cx: &App) -> bool {
 91        Self::global(cx).enabled
 92    }
 93}
 94
 95pub fn init(
 96    fs: Arc<dyn Fs>,
 97    client: Arc<Client>,
 98    prompt_builder: Arc<PromptBuilder>,
 99    cx: &mut App,
100) {
101    cx.set_global(Assistant::default());
102    AssistantSettings::register(cx);
103    SlashCommandSettings::register(cx);
104
105    cx.spawn({
106        let client = client.clone();
107        async move |cx| {
108            let is_search_slash_command_enabled = cx
109                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
110                .await;
111            let is_project_slash_command_enabled = cx
112                .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
113                .await;
114
115            if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
116                return Ok(());
117            }
118
119            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
120            let semantic_index = SemanticDb::new(
121                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
122                Arc::new(embedding_provider),
123                cx,
124            )
125            .await?;
126
127            cx.update(|cx| cx.set_global(semantic_index))
128        }
129    })
130    .detach();
131
132    assistant_context_editor::init(client.clone(), cx);
133    prompt_library::init(cx);
134    init_language_model_settings(cx);
135    assistant_slash_command::init(cx);
136    assistant_tool::init(cx);
137    assistant_panel::init(cx);
138    context_server::init(cx);
139
140    register_slash_commands(Some(prompt_builder.clone()), cx);
141    inline_assistant::init(
142        fs.clone(),
143        prompt_builder.clone(),
144        client.telemetry().clone(),
145        cx,
146    );
147    terminal_inline_assistant::init(
148        fs.clone(),
149        prompt_builder.clone(),
150        client.telemetry().clone(),
151        cx,
152    );
153    indexed_docs::init(cx);
154
155    CommandPaletteFilter::update_global(cx, |filter, _cx| {
156        filter.hide_namespace(Assistant::NAMESPACE);
157    });
158    Assistant::update_global(cx, |assistant, cx| {
159        let settings = AssistantSettings::get_global(cx);
160
161        assistant.set_enabled(settings.enabled, cx);
162    });
163    cx.observe_global::<SettingsStore>(|cx| {
164        Assistant::update_global(cx, |assistant, cx| {
165            let settings = AssistantSettings::get_global(cx);
166            assistant.set_enabled(settings.enabled, cx);
167        });
168    })
169    .detach();
170}
171
172fn init_language_model_settings(cx: &mut App) {
173    update_active_language_model_from_settings(cx);
174
175    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
176        .detach();
177    cx.subscribe(
178        &LanguageModelRegistry::global(cx),
179        |_, event: &language_model::Event, cx| match event {
180            language_model::Event::ProviderStateChanged
181            | language_model::Event::AddedProvider(_)
182            | language_model::Event::RemovedProvider(_) => {
183                update_active_language_model_from_settings(cx);
184            }
185            _ => {}
186        },
187    )
188    .detach();
189}
190
191fn update_active_language_model_from_settings(cx: &mut App) {
192    let settings = AssistantSettings::get_global(cx);
193    let active_model_provider_name =
194        LanguageModelProviderId::from(settings.default_model.provider.clone());
195    let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
196    let editor_provider_name =
197        LanguageModelProviderId::from(settings.editor_model.provider.clone());
198    let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone());
199    let inline_alternatives = settings
200        .inline_alternatives
201        .iter()
202        .map(|alternative| {
203            (
204                LanguageModelProviderId::from(alternative.provider.clone()),
205                LanguageModelId::from(alternative.model.clone()),
206            )
207        })
208        .collect::<Vec<_>>();
209    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
210        registry.select_active_model(&active_model_provider_name, &active_model_id, cx);
211        registry.select_editor_model(&editor_provider_name, &editor_model_id, cx);
212        registry.select_inline_alternative_models(inline_alternatives, cx);
213    });
214}
215
216fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut App) {
217    let slash_command_registry = SlashCommandRegistry::global(cx);
218
219    slash_command_registry.register_command(assistant_slash_commands::FileSlashCommand, true);
220    slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true);
221    slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true);
222    slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true);
223    slash_command_registry
224        .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
225    slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true);
226    slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true);
227    slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false);
228    slash_command_registry.register_command(assistant_slash_commands::TerminalSlashCommand, true);
229    slash_command_registry.register_command(assistant_slash_commands::NowSlashCommand, false);
230    slash_command_registry
231        .register_command(assistant_slash_commands::DiagnosticsSlashCommand, true);
232    slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true);
233
234    if let Some(prompt_builder) = prompt_builder {
235        cx.observe_flag::<assistant_slash_commands::ProjectSlashCommandFeatureFlag, _>({
236            let slash_command_registry = slash_command_registry.clone();
237            move |is_enabled, _cx| {
238                if is_enabled {
239                    slash_command_registry.register_command(
240                        assistant_slash_commands::ProjectSlashCommand::new(prompt_builder.clone()),
241                        true,
242                    );
243                }
244            }
245        })
246        .detach();
247    }
248
249    cx.observe_flag::<assistant_slash_commands::StreamingExampleSlashCommandFeatureFlag, _>({
250        let slash_command_registry = slash_command_registry.clone();
251        move |is_enabled, _cx| {
252            if is_enabled {
253                slash_command_registry.register_command(
254                    assistant_slash_commands::StreamingExampleSlashCommand,
255                    false,
256                );
257            }
258        }
259    })
260    .detach();
261
262    update_slash_commands_from_settings(cx);
263    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
264        .detach();
265
266    cx.observe_flag::<assistant_slash_commands::SearchSlashCommandFeatureFlag, _>({
267        let slash_command_registry = slash_command_registry.clone();
268        move |is_enabled, _cx| {
269            if is_enabled {
270                slash_command_registry
271                    .register_command(assistant_slash_commands::SearchSlashCommand, true);
272            }
273        }
274    })
275    .detach();
276}
277
278fn update_slash_commands_from_settings(cx: &mut App) {
279    let slash_command_registry = SlashCommandRegistry::global(cx);
280    let settings = SlashCommandSettings::get_global(cx);
281
282    if settings.docs.enabled {
283        slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
284    } else {
285        slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
286    }
287
288    if settings.cargo_workspace.enabled {
289        slash_command_registry
290            .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
291    } else {
292        slash_command_registry
293            .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
294    }
295}
296
/// Test-only hook run at program start-up (via `ctor`).
///
/// Installs `env_logger` only when `RUST_LOG` is set to a valid Unicode value,
/// so tests stay quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_present = std::env::var("RUST_LOG").is_ok();
    if !rust_log_present {
        return;
    }
    env_logger::init();
}