1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
2
3pub mod assistant_panel;
4pub mod assistant_settings;
5mod context;
6pub mod context_store;
7mod inline_assistant;
8mod model_selector;
9mod prompt_library;
10mod prompts;
11mod slash_command;
12mod slash_command_picker;
13pub mod slash_command_settings;
14mod streaming_diff;
15mod terminal_inline_assistant;
16mod workflow;
17
18pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
19use assistant_settings::AssistantSettings;
20use assistant_slash_command::SlashCommandRegistry;
21use client::{proto, Client};
22use command_palette_hooks::CommandPaletteFilter;
23pub use context::*;
24use context_servers::ContextServerRegistry;
25pub use context_store::*;
26use feature_flags::FeatureFlagAppExt;
27use fs::Fs;
28use gpui::Context as _;
29use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
30use indexed_docs::IndexedDocsRegistry;
31pub(crate) use inline_assistant::*;
32use language_model::{
33 LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
34};
35pub(crate) use model_selector::*;
36pub use prompts::PromptBuilder;
37use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
38use serde::{Deserialize, Serialize};
39use settings::{update_settings_file, Settings, SettingsStore};
40use slash_command::{
41 context_server_command, default_command, diagnostics_command, docs_command, fetch_command,
42 file_command, now_command, project_command, prompt_command, search_command, symbols_command,
43 tab_command, terminal_command, workflow_command,
44};
45use std::sync::Arc;
46pub(crate) use streaming_diff::*;
47use util::ResultExt;
48pub use workflow::*;
49
50use crate::slash_command_settings::SlashCommandSettings;
51
// Unit (data-less) actions in the `assistant` namespace. This is the same
// namespace string as `Assistant::NAMESPACE`, which `Assistant::set_enabled`
// shows or hides in the command palette.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
    ]
);
68
/// Default number of context lines (50).
///
/// NOTE(review): not referenced in this file; its exact meaning (e.g. how many
/// surrounding source lines are included) should be confirmed at its use sites.
const DEFAULT_CONTEXT_LINES: usize = 50;
70
/// Action that starts an inline assist.
///
/// Unlike the unit actions declared with `actions!`, this one carries data and
/// derives `Deserialize`, so it is registered with `impl_actions!` and can be
/// dispatched with arguments (presumably from keymap JSON — confirm).
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
    /// Optional prompt text supplied with the action.
    /// NOTE(review): presumably pre-fills the inline assist prompt — confirm in
    /// `inline_assistant`.
    prompt: Option<String>,
}

impl_actions!(assistant, [InlineAssist]);
77
78#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
79pub struct MessageId(clock::Lamport);
80
81impl MessageId {
82 pub fn as_u64(self) -> u64 {
83 self.0.as_u64()
84 }
85}
86
/// Token usage counts, deserializable from a language-model payload.
///
/// NOTE(review): not referenced in this file; presumably mirrors a provider's
/// `usage` object — confirm against the deserialization site.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Number of prompt tokens, as reported in the payload.
    pub prompt_tokens: u32,
    /// Number of completion tokens, as reported in the payload.
    pub completion_tokens: u32,
    /// Total token count, as reported in the payload (not computed here).
    pub total_tokens: u32,
}
93
/// One incremental update for a single choice in a streamed model response.
///
/// NOTE(review): not referenced in this file; the shape presumably follows an
/// OpenAI-style streaming `choices[]` entry — confirm.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The incremental message content.
    pub delta: LanguageModelResponseMessage,
    /// Reason the choice finished, if any; presumably `None` while the choice
    /// is still streaming — confirm.
    pub finish_reason: Option<String>,
}
100
/// Status of a message. Converted to and from `proto::ContextMessageStatus`
/// by `MessageStatus::from_proto` / `MessageStatus::to_proto`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    /// The message is pending. `from_proto` also yields this when the proto
    /// status has no variant set.
    Pending,
    /// The message is done.
    Done,
    /// The message failed; carries the error message text.
    Error(SharedString),
    /// The message was canceled.
    Canceled,
}
108
109impl MessageStatus {
110 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
111 match status.variant {
112 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
113 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
114 Some(proto::context_message_status::Variant::Error(error)) => {
115 MessageStatus::Error(error.message.into())
116 }
117 Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
118 None => MessageStatus::Pending,
119 }
120 }
121
122 pub fn to_proto(&self) -> proto::ContextMessageStatus {
123 match self {
124 MessageStatus::Pending => proto::ContextMessageStatus {
125 variant: Some(proto::context_message_status::Variant::Pending(
126 proto::context_message_status::Pending {},
127 )),
128 },
129 MessageStatus::Done => proto::ContextMessageStatus {
130 variant: Some(proto::context_message_status::Variant::Done(
131 proto::context_message_status::Done {},
132 )),
133 },
134 MessageStatus::Error(message) => proto::ContextMessageStatus {
135 variant: Some(proto::context_message_status::Variant::Error(
136 proto::context_message_status::Error {
137 message: message.to_string(),
138 },
139 )),
140 },
141 MessageStatus::Canceled => proto::ContextMessageStatus {
142 variant: Some(proto::context_message_status::Variant::Canceled(
143 proto::context_message_status::Canceled {},
144 )),
145 },
146 }
147 }
148}
149
/// The state pertaining to the Assistant.
///
/// Installed as a GPUI global (it implements `Global` below, and `init` calls
/// `cx.set_global(Assistant::default())`).
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    ///
    /// Starts as `false` (via `Default`) and is updated from
    /// `AssistantSettings::enabled` by `init`.
    enabled: bool,
}

impl Global for Assistant {}
158
159impl Assistant {
160 const NAMESPACE: &'static str = "assistant";
161
162 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
163 if self.enabled == enabled {
164 return;
165 }
166
167 self.enabled = enabled;
168
169 if !enabled {
170 CommandPaletteFilter::update_global(cx, |filter, _cx| {
171 filter.hide_namespace(Self::NAMESPACE);
172 });
173
174 return;
175 }
176
177 CommandPaletteFilter::update_global(cx, |filter, _cx| {
178 filter.show_namespace(Self::NAMESPACE);
179 });
180 }
181}
182
183pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
184 cx.set_global(Assistant::default());
185 AssistantSettings::register(cx);
186 SlashCommandSettings::register(cx);
187
188 // TODO: remove this when 0.148.0 is released.
189 if AssistantSettings::get_global(cx).using_outdated_settings_version {
190 update_settings_file::<AssistantSettings>(fs.clone(), cx, {
191 let fs = fs.clone();
192 |content, cx| {
193 content.update_file(fs, cx);
194 }
195 });
196 }
197
198 cx.spawn(|mut cx| {
199 let client = client.clone();
200 async move {
201 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
202 let semantic_index = SemanticIndex::new(
203 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
204 Arc::new(embedding_provider),
205 &mut cx,
206 )
207 .await?;
208 cx.update(|cx| cx.set_global(semantic_index))
209 }
210 })
211 .detach();
212
213 context_store::init(&client);
214 prompt_library::init(cx);
215 init_language_model_settings(cx);
216 assistant_slash_command::init(cx);
217 assistant_panel::init(cx);
218 context_servers::init(cx);
219
220 let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
221 .log_err()
222 .map(Arc::new)
223 .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
224 register_slash_commands(Some(prompt_builder.clone()), cx);
225 inline_assistant::init(
226 fs.clone(),
227 prompt_builder.clone(),
228 client.telemetry().clone(),
229 cx,
230 );
231 terminal_inline_assistant::init(
232 fs.clone(),
233 prompt_builder.clone(),
234 client.telemetry().clone(),
235 cx,
236 );
237 IndexedDocsRegistry::init_global(cx);
238
239 CommandPaletteFilter::update_global(cx, |filter, _cx| {
240 filter.hide_namespace(Assistant::NAMESPACE);
241 });
242 Assistant::update_global(cx, |assistant, cx| {
243 let settings = AssistantSettings::get_global(cx);
244
245 assistant.set_enabled(settings.enabled, cx);
246 });
247 cx.observe_global::<SettingsStore>(|cx| {
248 Assistant::update_global(cx, |assistant, cx| {
249 let settings = AssistantSettings::get_global(cx);
250 assistant.set_enabled(settings.enabled, cx);
251 });
252 })
253 .detach();
254
255 register_context_server_handlers(cx);
256
257 prompt_builder
258}
259
260fn register_context_server_handlers(cx: &mut AppContext) {
261 cx.subscribe(
262 &context_servers::manager::ContextServerManager::global(cx),
263 |manager, event, cx| match event {
264 context_servers::manager::Event::ServerStarted { server_id } => {
265 cx.update_model(
266 &manager,
267 |manager: &mut context_servers::manager::ContextServerManager, cx| {
268 let slash_command_registry = SlashCommandRegistry::global(cx);
269 let context_server_registry = ContextServerRegistry::global(cx);
270 if let Some(server) = manager.get_server(server_id) {
271 cx.spawn(|_, _| async move {
272 let Some(protocol) = server.client.read().clone() else {
273 return;
274 };
275
276 if let Some(prompts) = protocol.list_prompts().await.log_err() {
277 for prompt in prompts
278 .into_iter()
279 .filter(context_server_command::acceptable_prompt)
280 {
281 log::info!(
282 "registering context server command: {:?}",
283 prompt.name
284 );
285 context_server_registry.register_command(
286 server.id.clone(),
287 prompt.name.as_str(),
288 );
289 slash_command_registry.register_command(
290 context_server_command::ContextServerSlashCommand::new(
291 &server, prompt,
292 ),
293 true,
294 );
295 }
296 }
297 })
298 .detach();
299 }
300 },
301 );
302 }
303 context_servers::manager::Event::ServerStopped { server_id } => {
304 let slash_command_registry = SlashCommandRegistry::global(cx);
305 let context_server_registry = ContextServerRegistry::global(cx);
306 if let Some(commands) = context_server_registry.get_commands(server_id) {
307 for command_name in commands {
308 slash_command_registry.unregister_command_by_name(&command_name);
309 context_server_registry.unregister_command(&server_id, &command_name);
310 }
311 }
312 }
313 },
314 )
315 .detach();
316}
317
318fn init_language_model_settings(cx: &mut AppContext) {
319 update_active_language_model_from_settings(cx);
320
321 cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
322 .detach();
323 cx.subscribe(
324 &LanguageModelRegistry::global(cx),
325 |_, event: &language_model::Event, cx| match event {
326 language_model::Event::ProviderStateChanged
327 | language_model::Event::AddedProvider(_)
328 | language_model::Event::RemovedProvider(_) => {
329 update_active_language_model_from_settings(cx);
330 }
331 _ => {}
332 },
333 )
334 .detach();
335}
336
337fn update_active_language_model_from_settings(cx: &mut AppContext) {
338 let settings = AssistantSettings::get_global(cx);
339 let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
340 let model_id = LanguageModelId::from(settings.default_model.model.clone());
341 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
342 registry.select_active_model(&provider_name, &model_id, cx);
343 });
344}
345
346fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
347 let slash_command_registry = SlashCommandRegistry::global(cx);
348 slash_command_registry.register_command(file_command::FileSlashCommand, true);
349 slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
350 slash_command_registry.register_command(tab_command::TabSlashCommand, true);
351 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
352 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
353 slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
354 slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
355 slash_command_registry.register_command(now_command::NowSlashCommand, false);
356 slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
357
358 if let Some(prompt_builder) = prompt_builder {
359 slash_command_registry.register_command(
360 workflow_command::WorkflowSlashCommand::new(prompt_builder),
361 true,
362 );
363 }
364 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
365
366 update_slash_commands_from_settings(cx);
367 cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
368 .detach();
369
370 cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
371 let slash_command_registry = slash_command_registry.clone();
372 move |is_enabled, _cx| {
373 if is_enabled {
374 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
375 }
376 }
377 })
378 .detach();
379}
380
381fn update_slash_commands_from_settings(cx: &mut AppContext) {
382 let slash_command_registry = SlashCommandRegistry::global(cx);
383 let settings = SlashCommandSettings::get_global(cx);
384
385 if settings.docs.enabled {
386 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
387 } else {
388 slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
389 }
390
391 if settings.project.enabled {
392 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
393 } else {
394 slash_command_registry.unregister_command(project_command::ProjectSlashCommand);
395 }
396}
397
/// Formats a token count compactly for display.
///
/// - Below 1,000 the count is shown verbatim (`999 -> "999"`).
/// - From 1,000 to 9,999 it is rounded half-up to tenths of a thousand, and a
///   trailing `.0` is dropped (`1000 -> "1k"`, `1050 -> "1.1k"`, `9950 -> "10k"`).
/// - From 10,000 up it is rounded half-up to whole thousands (`10500 -> "11k"`).
///
/// Never panics: the rounding avoids `count + 500`, which would overflow for
/// counts near `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded to the nearest hundred; at most (999 + 50) / 100 == 10.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                // e.g. 1950 rounds to 2.0k, displayed as "2k".
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Round half-up without adding to `count` (which could overflow).
            // `count / 1000 + 1` cannot overflow.
            let thousands = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", thousands)
        }
    }
}
415
/// Test-only logger setup, run at test-binary load time via `#[ctor::ctor]`.
///
/// `env_logger` is initialized only when `RUST_LOG` is set (and valid Unicode,
/// since `std::env::var` returns `Err` otherwise), so tests are quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}