1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
2
3pub mod assistant_panel;
4pub mod assistant_settings;
5mod context;
6pub mod context_store;
7mod inline_assistant;
8mod model_selector;
9mod patch;
10mod prompt_library;
11mod prompts;
12mod slash_command;
13pub(crate) mod slash_command_picker;
14pub mod slash_command_settings;
15mod streaming_diff;
16mod terminal_inline_assistant;
17mod tools;
18
19pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
20use assistant_settings::AssistantSettings;
21use assistant_slash_command::SlashCommandRegistry;
22use assistant_tool::ToolRegistry;
23use client::{proto, Client};
24use command_palette_hooks::CommandPaletteFilter;
25pub use context::*;
26use context_servers::ContextServerRegistry;
27pub use context_store::*;
28use feature_flags::FeatureFlagAppExt;
29use fs::Fs;
30use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
31use gpui::{impl_actions, Context as _};
32use indexed_docs::IndexedDocsRegistry;
33pub(crate) use inline_assistant::*;
34use language_model::{
35 LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
36};
37pub(crate) use model_selector::*;
38pub use patch::*;
39pub use prompts::PromptBuilder;
40use prompts::PromptLoadingParams;
41use semantic_index::{CloudEmbeddingProvider, SemanticDb};
42use serde::{Deserialize, Serialize};
43use settings::{update_settings_file, Settings, SettingsStore};
44use slash_command::{
45 auto_command, cargo_workspace_command, context_server_command, default_command, delta_command,
46 diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command,
47 prompt_command, search_command, symbols_command, tab_command, terminal_command,
48};
49use std::path::PathBuf;
50use std::sync::Arc;
51pub(crate) use streaming_diff::*;
52use util::ResultExt;
53
54use crate::slash_command_settings::SlashCommandSettings;
55
// Declares the `assistant` namespace's simple (data-less) actions. Actions that
// carry data, such as `InsertDraggedFiles`, are declared with `impl_actions!`.
//
// NOTE(review): this namespace presumably matches `Assistant::NAMESPACE`
// ("assistant"), whose command-palette visibility `Assistant::set_enabled`
// toggles — confirm that hiding the namespace covers these actions.
actions!(
    assistant,
    [
        Assist,
        Edit,
        Split,
        CopyCode,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        NewContext,
        ToggleModelSelector,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
77
/// An action carrying file paths that were dragged into the assistant.
///
/// NOTE(review): the distinction between the variants is inferred from their
/// names (paths inside the project vs. files from outside it) — confirm against
/// the action's handler.
#[derive(PartialEq, Clone, Deserialize)]
pub enum InsertDraggedFiles {
    /// Paths belonging to the project (presumably — verify how they are resolved).
    ProjectPaths(Vec<PathBuf>),
    /// Files dragged in from outside the project (presumably — verify).
    ExternalFiles(Vec<PathBuf>),
}

// Registers `InsertDraggedFiles` as a data-carrying action in the `assistant` namespace.
impl_actions!(assistant, [InsertDraggedFiles]);
85
/// Default number of context lines (50).
///
/// NOTE(review): no consumer of this constant is visible in this file; confirm
/// where and how it is applied.
const DEFAULT_CONTEXT_LINES: usize = 50;
87
88#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
89pub struct MessageId(clock::Lamport);
90
91impl MessageId {
92 pub fn as_u64(self) -> u64 {
93 self.0.as_u64()
94 }
95}
96
/// Token usage counts deserialized from a language model response.
///
/// NOTE(review): nothing in this file reads these fields; the meaning of each
/// count below is inferred from its name.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total tokens (presumably prompt + completion — verify with the provider).
    pub total_tokens: u32,
}
103
/// One choice delta deserialized from a language model response.
///
/// NOTE(review): field meanings are inferred from their names; this file does
/// not consume the type.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The incremental message content.
    pub delta: LanguageModelResponseMessage,
    /// Why generation stopped, if reported; presumably `None` until the final delta.
    pub finish_reason: Option<String>,
}
110
/// Status of a message, convertible to and from `proto::ContextMessageStatus`
/// via `MessageStatus::from_proto` / `MessageStatus::to_proto`.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    /// The message is still in progress.
    Pending,
    /// The message finished successfully.
    Done,
    /// The message failed; carries the error text.
    Error(SharedString),
    /// The message was canceled.
    Canceled,
}
118
119impl MessageStatus {
120 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
121 match status.variant {
122 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
123 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
124 Some(proto::context_message_status::Variant::Error(error)) => {
125 MessageStatus::Error(error.message.into())
126 }
127 Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
128 None => MessageStatus::Pending,
129 }
130 }
131
132 pub fn to_proto(&self) -> proto::ContextMessageStatus {
133 match self {
134 MessageStatus::Pending => proto::ContextMessageStatus {
135 variant: Some(proto::context_message_status::Variant::Pending(
136 proto::context_message_status::Pending {},
137 )),
138 },
139 MessageStatus::Done => proto::ContextMessageStatus {
140 variant: Some(proto::context_message_status::Variant::Done(
141 proto::context_message_status::Done {},
142 )),
143 },
144 MessageStatus::Error(message) => proto::ContextMessageStatus {
145 variant: Some(proto::context_message_status::Variant::Error(
146 proto::context_message_status::Error {
147 message: message.to_string(),
148 },
149 )),
150 },
151 MessageStatus::Canceled => proto::ContextMessageStatus {
152 variant: Some(proto::context_message_status::Variant::Canceled(
153 proto::context_message_status::Canceled {},
154 )),
155 },
156 }
157 }
158}
159
/// The state pertaining to the Assistant, stored as a gpui global.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    ///
    /// Starts as `false` (the derived `Default`); `init` then syncs it from
    /// `AssistantSettings::enabled` and keeps it in sync on settings changes.
    enabled: bool,
}

// Lets `Assistant` be installed with `set_global` and mutated with `update_global`.
impl Global for Assistant {}
168
169impl Assistant {
170 const NAMESPACE: &'static str = "assistant";
171
172 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
173 if self.enabled == enabled {
174 return;
175 }
176
177 self.enabled = enabled;
178
179 if !enabled {
180 CommandPaletteFilter::update_global(cx, |filter, _cx| {
181 filter.hide_namespace(Self::NAMESPACE);
182 });
183
184 return;
185 }
186
187 CommandPaletteFilter::update_global(cx, |filter, _cx| {
188 filter.show_namespace(Self::NAMESPACE);
189 });
190 }
191}
192
193pub fn init(
194 fs: Arc<dyn Fs>,
195 client: Arc<Client>,
196 stdout_is_a_pty: bool,
197 cx: &mut AppContext,
198) -> Arc<PromptBuilder> {
199 cx.set_global(Assistant::default());
200 AssistantSettings::register(cx);
201 SlashCommandSettings::register(cx);
202
203 // TODO: remove this when 0.148.0 is released.
204 if AssistantSettings::get_global(cx).using_outdated_settings_version {
205 update_settings_file::<AssistantSettings>(fs.clone(), cx, {
206 let fs = fs.clone();
207 |content, cx| {
208 content.update_file(fs, cx);
209 }
210 });
211 }
212
213 cx.spawn(|mut cx| {
214 let client = client.clone();
215 async move {
216 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
217 let semantic_index = SemanticDb::new(
218 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
219 Arc::new(embedding_provider),
220 &mut cx,
221 )
222 .await?;
223
224 cx.update(|cx| cx.set_global(semantic_index))
225 }
226 })
227 .detach();
228
229 context_store::init(&client.clone().into());
230 prompt_library::init(cx);
231 init_language_model_settings(cx);
232 assistant_slash_command::init(cx);
233 assistant_tool::init(cx);
234 assistant_panel::init(cx);
235 context_servers::init(cx);
236
237 let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
238 fs: fs.clone(),
239 repo_path: stdout_is_a_pty
240 .then(|| std::env::current_dir().log_err())
241 .flatten(),
242 cx,
243 }))
244 .log_err()
245 .map(Arc::new)
246 .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
247 register_slash_commands(Some(prompt_builder.clone()), cx);
248 register_tools(cx);
249 inline_assistant::init(
250 fs.clone(),
251 prompt_builder.clone(),
252 client.telemetry().clone(),
253 cx,
254 );
255 terminal_inline_assistant::init(
256 fs.clone(),
257 prompt_builder.clone(),
258 client.telemetry().clone(),
259 cx,
260 );
261 IndexedDocsRegistry::init_global(cx);
262
263 CommandPaletteFilter::update_global(cx, |filter, _cx| {
264 filter.hide_namespace(Assistant::NAMESPACE);
265 });
266 Assistant::update_global(cx, |assistant, cx| {
267 let settings = AssistantSettings::get_global(cx);
268
269 assistant.set_enabled(settings.enabled, cx);
270 });
271 cx.observe_global::<SettingsStore>(|cx| {
272 Assistant::update_global(cx, |assistant, cx| {
273 let settings = AssistantSettings::get_global(cx);
274 assistant.set_enabled(settings.enabled, cx);
275 });
276 })
277 .detach();
278
279 register_context_server_handlers(cx);
280
281 prompt_builder
282}
283
/// Subscribes to the global `ContextServerManager` and keeps the slash-command
/// and tool registries in sync with context servers as they start and stop.
///
/// On `ServerStarted`, the server's prompts (filtered by
/// `context_server_command::acceptable_prompt`) are registered as slash commands,
/// and its tools are registered in the `ToolRegistry`. Each step runs only if
/// the server reports the corresponding capability. Every registration is also
/// recorded in the `ContextServerRegistry` under the server's id.
///
/// On `ServerStopped`, the commands and tools recorded for that server id are
/// unregistered from both registries.
fn register_context_server_handlers(cx: &mut AppContext) {
    cx.subscribe(
        &context_servers::manager::ContextServerManager::global(cx),
        |manager, event, cx| match event {
            context_servers::manager::Event::ServerStarted { server_id } => {
                // Prompts -> slash commands.
                cx.update_model(
                    &manager,
                    |manager: &mut context_servers::manager::ContextServerManager, cx| {
                        let slash_command_registry = SlashCommandRegistry::global(cx);
                        let context_server_registry = ContextServerRegistry::global(cx);
                        if let Some(server) = manager.get_server(server_id) {
                            cx.spawn(|_, _| async move {
                                // No protocol client available for this server yet; nothing to register.
                                let Some(protocol) = server.client.read().clone() else {
                                    return;
                                };

                                if protocol.capable(context_servers::protocol::ServerCapability::Prompts) {
                                    // `log_err` logs a failed listing and skips registration.
                                    if let Some(prompts) = protocol.list_prompts().await.log_err() {
                                        for prompt in prompts
                                            .into_iter()
                                            .filter(context_server_command::acceptable_prompt)
                                        {
                                            log::info!(
                                                "registering context server command: {:?}",
                                                prompt.name
                                            );
                                            // Record the command under this server so
                                            // `ServerStopped` can find and remove it.
                                            context_server_registry.register_command(
                                                server.id.clone(),
                                                prompt.name.as_str(),
                                            );
                                            slash_command_registry.register_command(
                                                context_server_command::ContextServerSlashCommand::new(
                                                    &server, prompt,
                                                ),
                                                true,
                                            );
                                        }
                                    }
                                }
                            })
                            .detach();
                        }
                    },
                );

                // Tools -> tool registry.
                cx.update_model(
                    &manager,
                    |manager: &mut context_servers::manager::ContextServerManager, cx| {
                        let tool_registry = ToolRegistry::global(cx);
                        let context_server_registry = ContextServerRegistry::global(cx);
                        if let Some(server) = manager.get_server(server_id) {
                            cx.spawn(|_, _| async move {
                                // No protocol client available for this server yet; nothing to register.
                                let Some(protocol) = server.client.read().clone() else {
                                    return;
                                };

                                if protocol.capable(context_servers::protocol::ServerCapability::Tools) {
                                    // `log_err` logs a failed listing and skips registration.
                                    if let Some(tools) = protocol.list_tools().await.log_err() {
                                        for tool in tools.tools {
                                            log::info!(
                                                "registering context server tool: {:?}",
                                                tool.name
                                            );
                                            // Record the tool under this server so
                                            // `ServerStopped` can find and remove it.
                                            context_server_registry.register_tool(
                                                server.id.clone(),
                                                tool.name.as_str(),
                                            );
                                            tool_registry.register_tool(
                                                tools::context_server_tool::ContextServerTool::new(
                                                    server.id.clone(),
                                                    tool
                                                ),
                                            );
                                        }
                                    }
                                }
                            })
                            .detach();
                        }
                    },
                );
            }
            context_servers::manager::Event::ServerStopped { server_id } => {
                let slash_command_registry = SlashCommandRegistry::global(cx);
                let context_server_registry = ContextServerRegistry::global(cx);
                // Remove every slash command previously recorded for this server.
                if let Some(commands) = context_server_registry.get_commands(server_id) {
                    for command_name in commands {
                        slash_command_registry.unregister_command_by_name(&command_name);
                        context_server_registry.unregister_command(&server_id, &command_name);
                    }
                }

                // Remove every tool previously recorded for this server.
                if let Some(tools) = context_server_registry.get_tools(server_id) {
                    let tool_registry = ToolRegistry::global(cx);
                    for tool_name in tools {
                        tool_registry.unregister_tool_by_name(&tool_name);
                        context_server_registry.unregister_tool(&server_id, &tool_name);
                    }
                }
            }
        },
    )
    .detach();
}
388
389fn init_language_model_settings(cx: &mut AppContext) {
390 update_active_language_model_from_settings(cx);
391
392 cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
393 .detach();
394 cx.subscribe(
395 &LanguageModelRegistry::global(cx),
396 |_, event: &language_model::Event, cx| match event {
397 language_model::Event::ProviderStateChanged
398 | language_model::Event::AddedProvider(_)
399 | language_model::Event::RemovedProvider(_) => {
400 update_active_language_model_from_settings(cx);
401 }
402 _ => {}
403 },
404 )
405 .detach();
406}
407
408fn update_active_language_model_from_settings(cx: &mut AppContext) {
409 let settings = AssistantSettings::get_global(cx);
410 let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
411 let model_id = LanguageModelId::from(settings.default_model.model.clone());
412 let inline_alternatives = settings
413 .inline_alternatives
414 .iter()
415 .map(|alternative| {
416 (
417 LanguageModelProviderId::from(alternative.provider.clone()),
418 LanguageModelId::from(alternative.model.clone()),
419 )
420 })
421 .collect::<Vec<_>>();
422 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
423 registry.select_active_model(&provider_name, &model_id, cx);
424 registry.select_inline_alternative_models(inline_alternatives, cx);
425 });
426}
427
428fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
429 let slash_command_registry = SlashCommandRegistry::global(cx);
430
431 slash_command_registry.register_command(file_command::FileSlashCommand, true);
432 slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
433 slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
434 slash_command_registry.register_command(tab_command::TabSlashCommand, true);
435 slash_command_registry
436 .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
437 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
438 slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
439 slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
440 slash_command_registry.register_command(now_command::NowSlashCommand, false);
441 slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
442 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
443 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
444
445 if let Some(prompt_builder) = prompt_builder {
446 cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
447 let slash_command_registry = slash_command_registry.clone();
448 move |is_enabled, _cx| {
449 if is_enabled {
450 slash_command_registry.register_command(
451 project_command::ProjectSlashCommand::new(prompt_builder.clone()),
452 true,
453 );
454 }
455 }
456 })
457 .detach();
458 }
459
460 cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
461 let slash_command_registry = slash_command_registry.clone();
462 move |is_enabled, _cx| {
463 if is_enabled {
464 // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
465 slash_command_registry.register_command(auto_command::AutoCommand, true);
466 }
467 }
468 })
469 .detach();
470
471 update_slash_commands_from_settings(cx);
472 cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
473 .detach();
474
475 cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
476 let slash_command_registry = slash_command_registry.clone();
477 move |is_enabled, _cx| {
478 if is_enabled {
479 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
480 }
481 }
482 })
483 .detach();
484}
485
486fn update_slash_commands_from_settings(cx: &mut AppContext) {
487 let slash_command_registry = SlashCommandRegistry::global(cx);
488 let settings = SlashCommandSettings::get_global(cx);
489
490 if settings.docs.enabled {
491 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
492 } else {
493 slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
494 }
495
496 if settings.cargo_workspace.enabled {
497 slash_command_registry
498 .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
499 } else {
500 slash_command_registry
501 .unregister_command(cargo_workspace_command::CargoWorkspaceSlashCommand);
502 }
503}
504
505fn register_tools(cx: &mut AppContext) {
506 let tool_registry = ToolRegistry::global(cx);
507 tool_registry.register_tool(tools::now_tool::NowTool);
508}
509
/// Formats a token count compactly for display.
///
/// - Counts up to 999 are shown verbatim (`"999"`).
/// - Counts from 1,000 to 9,999 are rounded half-up to the nearest hundred and
///   shown with one decimal, omitting a `.0` (`"1k"`, `"1.5k"`, `"10k"`).
/// - Larger counts are rounded half-up to the nearest thousand (`"123k"`).
///
/// Never overflows, including for `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Nearest hundred of the remainder; 950..=999 rounds up to 10.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{thousands}k"),
                10 => format!("{}k", thousands + 1),
                _ => format!("{thousands}.{hundreds}k"),
            }
        }
        _ => {
            // Half-up rounding to the nearest thousand without computing
            // `count + 500`, which overflows for counts near `usize::MAX`
            // (panicking in debug builds, wrapping to "0k" in release).
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{rounded}k")
        }
    }
}
527
/// Test-only `ctor` hook that turns on `env_logger`, but only when `RUST_LOG` is
/// set to a valid Unicode value, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}