assistant.rs

  1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
  2
  3mod assistant_configuration;
  4pub mod assistant_panel;
  5mod inline_assistant;
  6pub mod slash_command_settings;
  7mod terminal_inline_assistant;
  8
  9use std::sync::Arc;
 10
 11use assistant_settings::AssistantSettings;
 12use assistant_slash_command::SlashCommandRegistry;
 13use assistant_slash_commands::SearchSlashCommandFeatureFlag;
 14use client::Client;
 15use command_palette_hooks::CommandPaletteFilter;
 16use feature_flags::FeatureFlagAppExt;
 17use fs::Fs;
 18use gpui::{actions, App, Global, ReadGlobal, UpdateGlobal};
 19use language_model::{
 20    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 21};
 22use prompt_store::PromptBuilder;
 23use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 24use serde::Deserialize;
 25use settings::{Settings, SettingsStore};
 26
 27pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
 28pub(crate) use crate::inline_assistant::*;
 29use crate::slash_command_settings::SlashCommandSettings;
 30
 // Declares the actions in the `assistant` namespace. That namespace string matches
// `Assistant::NAMESPACE`, which `Assistant::set_enabled` shows or hides in the command
// palette.
actions!(
    assistant,
    [
        InsertActivePrompt,
        DeployHistory,
        NewChat,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
 41
 // NOTE(review): not referenced in this file. It is presumably the default number of
// surrounding lines of context used by the inline assistant modules (reachable via
// `crate::`); TODO confirm.
const DEFAULT_CONTEXT_LINES: usize = 50;
 43
 /// Token-usage counts for a language model exchange, deserializable via serde.
///
/// Not referenced in this file; presumably parsed from a provider response elsewhere
/// (TODO confirm).
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Number of tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total token count as reported by the source (not computed here).
    pub total_tokens: u32,
}
 50
 /// One incremental choice entry of a language model response, deserializable via serde.
///
/// Not referenced in this file; presumably one chunk of a streamed completion
/// (TODO confirm).
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The message fragment carried by this delta.
    pub delta: LanguageModelResponseMessage,
    /// Why generation stopped, if reported; `None` when absent from the payload.
    pub finish_reason: Option<String>,
}
 57
 58/// The state pertaining to the Assistant.
 59#[derive(Default)]
 60struct Assistant {
 61    /// Whether the Assistant is enabled.
 62    enabled: bool,
 63}
 64
 65impl Global for Assistant {}
 66
 67impl Assistant {
 68    const NAMESPACE: &'static str = "assistant";
 69
 70    fn set_enabled(&mut self, enabled: bool, cx: &mut App) {
 71        if self.enabled == enabled {
 72            return;
 73        }
 74
 75        self.enabled = enabled;
 76
 77        if !enabled {
 78            CommandPaletteFilter::update_global(cx, |filter, _cx| {
 79                filter.hide_namespace(Self::NAMESPACE);
 80            });
 81
 82            return;
 83        }
 84
 85        CommandPaletteFilter::update_global(cx, |filter, _cx| {
 86            filter.show_namespace(Self::NAMESPACE);
 87        });
 88    }
 89
 90    pub fn enabled(cx: &App) -> bool {
 91        Self::global(cx).enabled
 92    }
 93}
 94
 95pub fn init(
 96    fs: Arc<dyn Fs>,
 97    client: Arc<Client>,
 98    prompt_builder: Arc<PromptBuilder>,
 99    cx: &mut App,
100) {
101    cx.set_global(Assistant::default());
102    AssistantSettings::register(cx);
103    SlashCommandSettings::register(cx);
104
105    cx.spawn({
106        let client = client.clone();
107        async move |cx| {
108            let is_search_slash_command_enabled = cx
109                .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
110                .await;
111
112            if !is_search_slash_command_enabled {
113                return Ok(());
114            }
115
116            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
117            let semantic_index = SemanticDb::new(
118                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
119                Arc::new(embedding_provider),
120                cx,
121            )
122            .await?;
123
124            cx.update(|cx| cx.set_global(semantic_index))
125        }
126    })
127    .detach();
128
129    assistant_context_editor::init(client.clone(), cx);
130    prompt_library::init(cx);
131    init_language_model_settings(cx);
132    assistant_slash_command::init(cx);
133    assistant_tool::init(cx);
134    assistant_panel::init(cx);
135    context_server::init(cx);
136
137    register_slash_commands(cx);
138    inline_assistant::init(
139        fs.clone(),
140        prompt_builder.clone(),
141        client.telemetry().clone(),
142        cx,
143    );
144    terminal_inline_assistant::init(
145        fs.clone(),
146        prompt_builder.clone(),
147        client.telemetry().clone(),
148        cx,
149    );
150    indexed_docs::init(cx);
151
152    CommandPaletteFilter::update_global(cx, |filter, _cx| {
153        filter.hide_namespace(Assistant::NAMESPACE);
154    });
155    Assistant::update_global(cx, |assistant, cx| {
156        let settings = AssistantSettings::get_global(cx);
157
158        assistant.set_enabled(settings.enabled, cx);
159    });
160    cx.observe_global::<SettingsStore>(|cx| {
161        Assistant::update_global(cx, |assistant, cx| {
162            let settings = AssistantSettings::get_global(cx);
163            assistant.set_enabled(settings.enabled, cx);
164        });
165    })
166    .detach();
167}
168
169fn init_language_model_settings(cx: &mut App) {
170    update_active_language_model_from_settings(cx);
171
172    cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
173        .detach();
174    cx.subscribe(
175        &LanguageModelRegistry::global(cx),
176        |_, event: &language_model::Event, cx| match event {
177            language_model::Event::ProviderStateChanged
178            | language_model::Event::AddedProvider(_)
179            | language_model::Event::RemovedProvider(_) => {
180                update_active_language_model_from_settings(cx);
181            }
182            _ => {}
183        },
184    )
185    .detach();
186}
187
188fn update_active_language_model_from_settings(cx: &mut App) {
189    let settings = AssistantSettings::get_global(cx);
190    let active_model_provider_name =
191        LanguageModelProviderId::from(settings.default_model.provider.clone());
192    let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
193    let editor_provider_name =
194        LanguageModelProviderId::from(settings.editor_model.provider.clone());
195    let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone());
196    let inline_alternatives = settings
197        .inline_alternatives
198        .iter()
199        .map(|alternative| {
200            (
201                LanguageModelProviderId::from(alternative.provider.clone()),
202                LanguageModelId::from(alternative.model.clone()),
203            )
204        })
205        .collect::<Vec<_>>();
206    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
207        registry.select_active_model(&active_model_provider_name, &active_model_id, cx);
208        registry.select_editor_model(&editor_provider_name, &editor_model_id, cx);
209        registry.select_inline_alternative_models(inline_alternatives, cx);
210    });
211}
212
213fn register_slash_commands(cx: &mut App) {
214    let slash_command_registry = SlashCommandRegistry::global(cx);
215
216    slash_command_registry.register_command(assistant_slash_commands::FileSlashCommand, true);
217    slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true);
218    slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true);
219    slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true);
220    slash_command_registry
221        .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
222    slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true);
223    slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true);
224    slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false);
225    slash_command_registry.register_command(assistant_slash_commands::TerminalSlashCommand, true);
226    slash_command_registry.register_command(assistant_slash_commands::NowSlashCommand, false);
227    slash_command_registry
228        .register_command(assistant_slash_commands::DiagnosticsSlashCommand, true);
229    slash_command_registry.register_command(assistant_slash_commands::FetchSlashCommand, true);
230
231    cx.observe_flag::<assistant_slash_commands::StreamingExampleSlashCommandFeatureFlag, _>({
232        let slash_command_registry = slash_command_registry.clone();
233        move |is_enabled, _cx| {
234            if is_enabled {
235                slash_command_registry.register_command(
236                    assistant_slash_commands::StreamingExampleSlashCommand,
237                    false,
238                );
239            }
240        }
241    })
242    .detach();
243
244    update_slash_commands_from_settings(cx);
245    cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
246        .detach();
247
248    cx.observe_flag::<assistant_slash_commands::SearchSlashCommandFeatureFlag, _>({
249        let slash_command_registry = slash_command_registry.clone();
250        move |is_enabled, _cx| {
251            if is_enabled {
252                slash_command_registry
253                    .register_command(assistant_slash_commands::SearchSlashCommand, true);
254            }
255        }
256    })
257    .detach();
258}
259
260fn update_slash_commands_from_settings(cx: &mut App) {
261    let slash_command_registry = SlashCommandRegistry::global(cx);
262    let settings = SlashCommandSettings::get_global(cx);
263
264    if settings.docs.enabled {
265        slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
266    } else {
267        slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
268    }
269
270    if settings.cargo_workspace.enabled {
271        slash_command_registry
272            .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
273    } else {
274        slash_command_registry
275            .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
276    }
277}
278
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Test-only logger setup: stay silent unless the run sets RUST_LOG (to a valid
    // Unicode value; `env::var` errors otherwise).
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}