1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
2
3pub mod assistant_panel;
4pub mod assistant_settings;
5mod context;
6pub mod context_store;
7mod inline_assistant;
8mod model_selector;
9mod prompt_library;
10mod prompts;
11mod slash_command;
12pub(crate) mod slash_command_picker;
13pub mod slash_command_settings;
14mod streaming_diff;
15mod terminal_inline_assistant;
16mod tools;
17mod workflow;
18
19pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
20use assistant_settings::AssistantSettings;
21use assistant_slash_command::SlashCommandRegistry;
22use assistant_tool::ToolRegistry;
23use client::{proto, Client};
24use command_palette_hooks::CommandPaletteFilter;
25pub use context::*;
26use context_servers::ContextServerRegistry;
27pub use context_store::*;
28use feature_flags::FeatureFlagAppExt;
29use fs::Fs;
30use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
31use gpui::{impl_actions, Context as _};
32use indexed_docs::IndexedDocsRegistry;
33pub(crate) use inline_assistant::*;
34use language_model::{
35 LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
36};
37pub(crate) use model_selector::*;
38pub use prompts::PromptBuilder;
39use prompts::PromptLoadingParams;
40use semantic_index::{CloudEmbeddingProvider, SemanticDb};
41use serde::{Deserialize, Serialize};
42use settings::{update_settings_file, Settings, SettingsStore};
43use slash_command::{
44 auto_command, context_server_command, default_command, delta_command, diagnostics_command,
45 docs_command, fetch_command, file_command, now_command, project_command, prompt_command,
46 search_command, symbols_command, tab_command, terminal_command, workflow_command,
47};
48use std::path::PathBuf;
49use std::sync::Arc;
50pub(crate) use streaming_diff::*;
51use util::ResultExt;
52pub use workflow::*;
53
54use crate::slash_command_settings::SlashCommandSettings;
55
// Unit actions registered under the `assistant` namespace. `Assistant::set_enabled`
// hides or shows this same namespace ("assistant") in the command palette according
// to the `enabled` assistant setting.
actions!(
    assistant,
    [
        Assist,
        Split,
        CopyCode,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        NewContext,
        ToggleModelSelector,
    ]
);
74
/// Action that carries a list of file paths to insert into the assistant.
///
/// Registered in the `assistant` action namespace via `impl_actions!` below.
/// NOTE(review): the variant names suggest `ProjectPaths` holds paths inside the
/// open project and `ExternalFiles` holds files from outside it (e.g. a drag from the
/// OS); confirm against the handler that consumes this action.
#[derive(PartialEq, Clone, Deserialize)]
pub enum InsertDraggedFiles {
    ProjectPaths(Vec<PathBuf>),
    ExternalFiles(Vec<PathBuf>),
}

impl_actions!(assistant, [InsertDraggedFiles]);
82
// Default number of context lines (50).
// NOTE(review): not referenced in this file; presumably read by a submodule via
// `crate::DEFAULT_CONTEXT_LINES` — confirm what the lines are counted around.
const DEFAULT_CONTEXT_LINES: usize = 50;
84
85#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
86pub struct MessageId(clock::Lamport);
87
88impl MessageId {
89 pub fn as_u64(self) -> u64 {
90 self.0.as_u64()
91 }
92}
93
/// Token-usage counts deserialized from a language model response.
///
/// NOTE(review): presumably `total_tokens == prompt_tokens + completion_tokens`, but
/// nothing here enforces it — the values are taken as-is from the payload.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the generated completion.
    pub completion_tokens: u32,
    /// Total tokens as reported by the provider.
    pub total_tokens: u32,
}
100
/// One deserialized choice delta from a language model response.
///
/// NOTE(review): the shape (index + delta + optional finish reason) looks like a
/// streaming "choices" chunk; confirm which provider payload this is parsed from.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The incremental message content.
    pub delta: LanguageModelResponseMessage,
    /// Reason generation stopped, if present in this delta.
    pub finish_reason: Option<String>,
}
107
108#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
109pub enum MessageStatus {
110 Pending,
111 Done,
112 Error(SharedString),
113 Canceled,
114}
115
116impl MessageStatus {
117 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
118 match status.variant {
119 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
120 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
121 Some(proto::context_message_status::Variant::Error(error)) => {
122 MessageStatus::Error(error.message.into())
123 }
124 Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
125 None => MessageStatus::Pending,
126 }
127 }
128
129 pub fn to_proto(&self) -> proto::ContextMessageStatus {
130 match self {
131 MessageStatus::Pending => proto::ContextMessageStatus {
132 variant: Some(proto::context_message_status::Variant::Pending(
133 proto::context_message_status::Pending {},
134 )),
135 },
136 MessageStatus::Done => proto::ContextMessageStatus {
137 variant: Some(proto::context_message_status::Variant::Done(
138 proto::context_message_status::Done {},
139 )),
140 },
141 MessageStatus::Error(message) => proto::ContextMessageStatus {
142 variant: Some(proto::context_message_status::Variant::Error(
143 proto::context_message_status::Error {
144 message: message.to_string(),
145 },
146 )),
147 },
148 MessageStatus::Canceled => proto::ContextMessageStatus {
149 variant: Some(proto::context_message_status::Variant::Canceled(
150 proto::context_message_status::Canceled {},
151 )),
152 },
153 }
154 }
155}
156
157/// The state pertaining to the Assistant.
158#[derive(Default)]
159struct Assistant {
160 /// Whether the Assistant is enabled.
161 enabled: bool,
162}
163
164impl Global for Assistant {}
165
166impl Assistant {
167 const NAMESPACE: &'static str = "assistant";
168
169 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
170 if self.enabled == enabled {
171 return;
172 }
173
174 self.enabled = enabled;
175
176 if !enabled {
177 CommandPaletteFilter::update_global(cx, |filter, _cx| {
178 filter.hide_namespace(Self::NAMESPACE);
179 });
180
181 return;
182 }
183
184 CommandPaletteFilter::update_global(cx, |filter, _cx| {
185 filter.show_namespace(Self::NAMESPACE);
186 });
187 }
188}
189
190pub fn init(
191 fs: Arc<dyn Fs>,
192 client: Arc<Client>,
193 stdout_is_a_pty: bool,
194 cx: &mut AppContext,
195) -> Arc<PromptBuilder> {
196 cx.set_global(Assistant::default());
197 AssistantSettings::register(cx);
198 SlashCommandSettings::register(cx);
199
200 // TODO: remove this when 0.148.0 is released.
201 if AssistantSettings::get_global(cx).using_outdated_settings_version {
202 update_settings_file::<AssistantSettings>(fs.clone(), cx, {
203 let fs = fs.clone();
204 |content, cx| {
205 content.update_file(fs, cx);
206 }
207 });
208 }
209
210 cx.spawn(|mut cx| {
211 let client = client.clone();
212 async move {
213 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
214 let semantic_index = SemanticDb::new(
215 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
216 Arc::new(embedding_provider),
217 &mut cx,
218 )
219 .await?;
220
221 cx.update(|cx| cx.set_global(semantic_index))
222 }
223 })
224 .detach();
225
226 context_store::init(&client.clone().into());
227 prompt_library::init(cx);
228 init_language_model_settings(cx);
229 assistant_slash_command::init(cx);
230 assistant_tool::init(cx);
231 assistant_panel::init(cx);
232 context_servers::init(cx);
233
234 let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
235 fs: fs.clone(),
236 repo_path: stdout_is_a_pty
237 .then(|| std::env::current_dir().log_err())
238 .flatten(),
239 cx,
240 }))
241 .log_err()
242 .map(Arc::new)
243 .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
244 register_slash_commands(Some(prompt_builder.clone()), cx);
245 register_tools(cx);
246 inline_assistant::init(
247 fs.clone(),
248 prompt_builder.clone(),
249 client.telemetry().clone(),
250 cx,
251 );
252 terminal_inline_assistant::init(
253 fs.clone(),
254 prompt_builder.clone(),
255 client.telemetry().clone(),
256 cx,
257 );
258 IndexedDocsRegistry::init_global(cx);
259
260 CommandPaletteFilter::update_global(cx, |filter, _cx| {
261 filter.hide_namespace(Assistant::NAMESPACE);
262 });
263 Assistant::update_global(cx, |assistant, cx| {
264 let settings = AssistantSettings::get_global(cx);
265
266 assistant.set_enabled(settings.enabled, cx);
267 });
268 cx.observe_global::<SettingsStore>(|cx| {
269 Assistant::update_global(cx, |assistant, cx| {
270 let settings = AssistantSettings::get_global(cx);
271 assistant.set_enabled(settings.enabled, cx);
272 });
273 })
274 .detach();
275
276 register_context_server_handlers(cx);
277
278 prompt_builder
279}
280
/// Keeps slash commands in sync with the servers of the global `ContextServerManager`.
///
/// - On `ServerStarted`, the server's prompts are listed asynchronously. Each prompt
///   accepted by `context_server_command::acceptable_prompt` is recorded in the
///   `ContextServerRegistry` under the server's id and registered in the
///   `SlashCommandRegistry` as a `ContextServerSlashCommand`.
/// - On `ServerStopped`, every command recorded for that server id is removed from
///   both the `SlashCommandRegistry` and the `ContextServerRegistry`.
fn register_context_server_handlers(cx: &mut AppContext) {
    cx.subscribe(
        &context_servers::manager::ContextServerManager::global(cx),
        |manager, event, cx| match event {
            context_servers::manager::Event::ServerStarted { server_id } => {
                cx.update_model(
                    &manager,
                    |manager: &mut context_servers::manager::ContextServerManager, cx| {
                        let slash_command_registry = SlashCommandRegistry::global(cx);
                        let context_server_registry = ContextServerRegistry::global(cx);
                        // If the manager has no server with this id, nothing is registered.
                        if let Some(server) = manager.get_server(server_id) {
                            // Listing prompts is awaited, so it runs in a detached task.
                            cx.spawn(|_, _| async move {
                                // No protocol client available (presumably not initialized
                                // yet) — register nothing.
                                let Some(protocol) = server.client.read().clone() else {
                                    return;
                                };

                                // A failed `list_prompts` call is logged and otherwise ignored.
                                if let Some(prompts) = protocol.list_prompts().await.log_err() {
                                    for prompt in prompts
                                        .into_iter()
                                        .filter(context_server_command::acceptable_prompt)
                                    {
                                        log::info!(
                                            "registering context server command: {:?}",
                                            prompt.name
                                        );
                                        // Recorded under `server.id` so the `ServerStopped`
                                        // arm can find it again via `get_commands`.
                                        context_server_registry.register_command(
                                            server.id.clone(),
                                            prompt.name.as_str(),
                                        );
                                        slash_command_registry.register_command(
                                            context_server_command::ContextServerSlashCommand::new(
                                                &server, prompt,
                                            ),
                                            true,
                                        );
                                    }
                                }
                            })
                            .detach();
                        }
                    },
                );
            }
            context_servers::manager::Event::ServerStopped { server_id } => {
                let slash_command_registry = SlashCommandRegistry::global(cx);
                let context_server_registry = ContextServerRegistry::global(cx);
                // Remove every command recorded for this server from both registries.
                if let Some(commands) = context_server_registry.get_commands(server_id) {
                    for command_name in commands {
                        slash_command_registry.unregister_command_by_name(&command_name);
                        context_server_registry.unregister_command(&server_id, &command_name);
                    }
                }
            }
        },
    )
    .detach();
}
338
339fn init_language_model_settings(cx: &mut AppContext) {
340 update_active_language_model_from_settings(cx);
341
342 cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
343 .detach();
344 cx.subscribe(
345 &LanguageModelRegistry::global(cx),
346 |_, event: &language_model::Event, cx| match event {
347 language_model::Event::ProviderStateChanged
348 | language_model::Event::AddedProvider(_)
349 | language_model::Event::RemovedProvider(_) => {
350 update_active_language_model_from_settings(cx);
351 }
352 _ => {}
353 },
354 )
355 .detach();
356}
357
358fn update_active_language_model_from_settings(cx: &mut AppContext) {
359 let settings = AssistantSettings::get_global(cx);
360 let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
361 let model_id = LanguageModelId::from(settings.default_model.model.clone());
362 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
363 registry.select_active_model(&provider_name, &model_id, cx);
364 });
365}
366
367fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
368 let slash_command_registry = SlashCommandRegistry::global(cx);
369
370 slash_command_registry.register_command(file_command::FileSlashCommand, true);
371 slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
372 slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
373 slash_command_registry.register_command(tab_command::TabSlashCommand, true);
374 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
375 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
376 slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
377 slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
378 slash_command_registry.register_command(now_command::NowSlashCommand, false);
379 slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
380
381 if let Some(prompt_builder) = prompt_builder {
382 slash_command_registry.register_command(
383 workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()),
384 true,
385 );
386 }
387 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
388
389 cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
390 let slash_command_registry = slash_command_registry.clone();
391 move |is_enabled, _cx| {
392 if is_enabled {
393 // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
394 slash_command_registry.register_command(auto_command::AutoCommand, true);
395 }
396 }
397 })
398 .detach();
399
400 update_slash_commands_from_settings(cx);
401 cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
402 .detach();
403
404 cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
405 let slash_command_registry = slash_command_registry.clone();
406 move |is_enabled, _cx| {
407 if is_enabled {
408 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
409 }
410 }
411 })
412 .detach();
413}
414
415fn update_slash_commands_from_settings(cx: &mut AppContext) {
416 let slash_command_registry = SlashCommandRegistry::global(cx);
417 let settings = SlashCommandSettings::get_global(cx);
418
419 if settings.docs.enabled {
420 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
421 } else {
422 slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
423 }
424
425 if settings.project.enabled {
426 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
427 } else {
428 slash_command_registry.unregister_command(project_command::ProjectSlashCommand);
429 }
430}
431
432fn register_tools(cx: &mut AppContext) {
433 let tool_registry = ToolRegistry::global(cx);
434 tool_registry.register_tool(tools::now_tool::NowTool);
435}
436
/// Formats a token count for compact display.
///
/// - `0..=999`: the exact count (`999` → `"999"`).
/// - `1_000..=9_999`: rounded to the nearest hundred with one decimal place, dropping a
///   trailing `.0` (`1_049` → `"1k"`, `1_234` → `"1.2k"`, `9_950` → `"10k"`).
/// - `10_000` and above: rounded to the nearest thousand (`10_500` → `"11k"`,
///   `999_999` → `"1000k"`).
///
/// Halves round up. Rounding is done without adding an offset to `count`, so values
/// near `usize::MAX` do not overflow.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // At most 999 + 50, so this cannot overflow.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{thousands}k"),
                // e.g. 1_950 rounds up to a whole extra thousand.
                10 => format!("{}k", thousands + 1),
                _ => format!("{thousands}.{hundreds}k"),
            }
        }
        _ => {
            // Equivalent to `(count + 500) / 1000`, but `count + 500` overflows near
            // `usize::MAX`; split into quotient plus a round-half-up carry instead.
            let rounded_thousands = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{rounded_thousands}k")
        }
    }
}
454
/// Test builds only: runs at load time via `ctor` and initializes `env_logger`, but only
/// when `RUST_LOG` is set to a valid Unicode value, so test output stays quiet otherwise.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let Ok(_) = std::env::var("RUST_LOG") else {
        return;
    };
    env_logger::init();
}