1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
2
3pub mod assistant_panel;
4pub mod assistant_settings;
5mod context;
6pub mod context_store;
7mod inline_assistant;
8mod model_selector;
9mod prompt_library;
10mod prompts;
11mod slash_command;
12pub(crate) mod slash_command_picker;
13pub mod slash_command_settings;
14mod streaming_diff;
15mod terminal_inline_assistant;
16mod workflow;
17
18pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
19use assistant_settings::AssistantSettings;
20use assistant_slash_command::SlashCommandRegistry;
21use client::{proto, Client};
22use command_palette_hooks::CommandPaletteFilter;
23pub use context::*;
24use context_servers::ContextServerRegistry;
25pub use context_store::*;
26use feature_flags::FeatureFlagAppExt;
27use fs::Fs;
28use gpui::Context as _;
29use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
30use indexed_docs::IndexedDocsRegistry;
31pub(crate) use inline_assistant::*;
32use language_model::{
33 LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
34};
35pub(crate) use model_selector::*;
36pub use prompts::PromptBuilder;
37use prompts::PromptLoadingParams;
38use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
39use serde::{Deserialize, Serialize};
40use settings::{update_settings_file, Settings, SettingsStore};
41use slash_command::{
42 context_server_command, default_command, diagnostics_command, docs_command, fetch_command,
43 file_command, now_command, project_command, prompt_command, search_command, symbols_command,
44 tab_command, terminal_command, workflow_command,
45};
46use std::sync::Arc;
47pub(crate) use streaming_diff::*;
48use util::ResultExt;
49pub use workflow::*;
50
51use crate::slash_command_settings::SlashCommandSettings;
52
// Actions belonging to the assistant. They are declared under the `assistant`
// namespace, the same string as `Assistant::NAMESPACE`, which `Assistant::set_enabled`
// hides or shows in the command palette. (Assumes `actions!` uses its first argument
// as the action namespace — confirm against the gpui macro.)
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        NewContext,
        ToggleModelSelector,
    ]
);
70
/// Default number of context lines.
///
/// NOTE(review): no use is visible next to this definition; presumably it is read
/// by a submodule (e.g. as `crate::DEFAULT_CONTEXT_LINES`) — confirm where it applies.
const DEFAULT_CONTEXT_LINES: usize = 50;
72
73#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
74pub struct MessageId(clock::Lamport);
75
76impl MessageId {
77 pub fn as_u64(self) -> u64 {
78 self.0.as_u64()
79 }
80}
81
/// Token accounting for a language model request, as deserialized from a
/// provider response.
///
/// NOTE(review): field names look like an OpenAI-style `usage` object — confirm
/// which provider payloads this is parsed from.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Number of prompt tokens.
    pub prompt_tokens: u32,
    /// Number of completion tokens.
    pub completion_tokens: u32,
    /// Total token count; presumably `prompt_tokens + completion_tokens`
    /// (not checked here — this is whatever the provider reported).
    pub total_tokens: u32,
}

/// A single choice delta deserialized from a language model response —
/// presumably one chunk of a streamed completion (TODO confirm).
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The message fragment carried by this delta.
    pub delta: LanguageModelResponseMessage,
    /// Reason generation finished, if reported; `None` presumably means the
    /// choice is still in progress.
    pub finish_reason: Option<String>,
}
95
96#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
97pub enum MessageStatus {
98 Pending,
99 Done,
100 Error(SharedString),
101 Canceled,
102}
103
104impl MessageStatus {
105 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
106 match status.variant {
107 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
108 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
109 Some(proto::context_message_status::Variant::Error(error)) => {
110 MessageStatus::Error(error.message.into())
111 }
112 Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
113 None => MessageStatus::Pending,
114 }
115 }
116
117 pub fn to_proto(&self) -> proto::ContextMessageStatus {
118 match self {
119 MessageStatus::Pending => proto::ContextMessageStatus {
120 variant: Some(proto::context_message_status::Variant::Pending(
121 proto::context_message_status::Pending {},
122 )),
123 },
124 MessageStatus::Done => proto::ContextMessageStatus {
125 variant: Some(proto::context_message_status::Variant::Done(
126 proto::context_message_status::Done {},
127 )),
128 },
129 MessageStatus::Error(message) => proto::ContextMessageStatus {
130 variant: Some(proto::context_message_status::Variant::Error(
131 proto::context_message_status::Error {
132 message: message.to_string(),
133 },
134 )),
135 },
136 MessageStatus::Canceled => proto::ContextMessageStatus {
137 variant: Some(proto::context_message_status::Variant::Canceled(
138 proto::context_message_status::Canceled {},
139 )),
140 },
141 }
142 }
143}
144
145/// The state pertaining to the Assistant.
146#[derive(Default)]
147struct Assistant {
148 /// Whether the Assistant is enabled.
149 enabled: bool,
150}
151
152impl Global for Assistant {}
153
154impl Assistant {
155 const NAMESPACE: &'static str = "assistant";
156
157 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
158 if self.enabled == enabled {
159 return;
160 }
161
162 self.enabled = enabled;
163
164 if !enabled {
165 CommandPaletteFilter::update_global(cx, |filter, _cx| {
166 filter.hide_namespace(Self::NAMESPACE);
167 });
168
169 return;
170 }
171
172 CommandPaletteFilter::update_global(cx, |filter, _cx| {
173 filter.show_namespace(Self::NAMESPACE);
174 });
175 }
176}
177
178pub fn init(
179 fs: Arc<dyn Fs>,
180 client: Arc<Client>,
181 stdout_is_a_pty: bool,
182 cx: &mut AppContext,
183) -> Arc<PromptBuilder> {
184 cx.set_global(Assistant::default());
185 AssistantSettings::register(cx);
186 SlashCommandSettings::register(cx);
187
188 // TODO: remove this when 0.148.0 is released.
189 if AssistantSettings::get_global(cx).using_outdated_settings_version {
190 update_settings_file::<AssistantSettings>(fs.clone(), cx, {
191 let fs = fs.clone();
192 |content, cx| {
193 content.update_file(fs, cx);
194 }
195 });
196 }
197
198 cx.spawn(|mut cx| {
199 let client = client.clone();
200 async move {
201 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
202 let semantic_index = SemanticIndex::new(
203 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
204 Arc::new(embedding_provider),
205 &mut cx,
206 )
207 .await?;
208 cx.update(|cx| cx.set_global(semantic_index))
209 }
210 })
211 .detach();
212
213 context_store::init(&client.clone().into());
214 prompt_library::init(cx);
215 init_language_model_settings(cx);
216 assistant_slash_command::init(cx);
217 assistant_panel::init(cx);
218 context_servers::init(cx);
219
220 let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
221 fs: fs.clone(),
222 repo_path: stdout_is_a_pty
223 .then(|| std::env::current_dir().log_err())
224 .flatten(),
225 cx,
226 }))
227 .log_err()
228 .map(Arc::new)
229 .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
230 register_slash_commands(Some(prompt_builder.clone()), cx);
231 inline_assistant::init(
232 fs.clone(),
233 prompt_builder.clone(),
234 client.telemetry().clone(),
235 cx,
236 );
237 terminal_inline_assistant::init(
238 fs.clone(),
239 prompt_builder.clone(),
240 client.telemetry().clone(),
241 cx,
242 );
243 IndexedDocsRegistry::init_global(cx);
244
245 CommandPaletteFilter::update_global(cx, |filter, _cx| {
246 filter.hide_namespace(Assistant::NAMESPACE);
247 });
248 Assistant::update_global(cx, |assistant, cx| {
249 let settings = AssistantSettings::get_global(cx);
250
251 assistant.set_enabled(settings.enabled, cx);
252 });
253 cx.observe_global::<SettingsStore>(|cx| {
254 Assistant::update_global(cx, |assistant, cx| {
255 let settings = AssistantSettings::get_global(cx);
256 assistant.set_enabled(settings.enabled, cx);
257 });
258 })
259 .detach();
260
261 register_context_server_handlers(cx);
262
263 prompt_builder
264}
265
266fn register_context_server_handlers(cx: &mut AppContext) {
267 cx.subscribe(
268 &context_servers::manager::ContextServerManager::global(cx),
269 |manager, event, cx| match event {
270 context_servers::manager::Event::ServerStarted { server_id } => {
271 cx.update_model(
272 &manager,
273 |manager: &mut context_servers::manager::ContextServerManager, cx| {
274 let slash_command_registry = SlashCommandRegistry::global(cx);
275 let context_server_registry = ContextServerRegistry::global(cx);
276 if let Some(server) = manager.get_server(server_id) {
277 cx.spawn(|_, _| async move {
278 let Some(protocol) = server.client.read().clone() else {
279 return;
280 };
281
282 if let Some(prompts) = protocol.list_prompts().await.log_err() {
283 for prompt in prompts
284 .into_iter()
285 .filter(context_server_command::acceptable_prompt)
286 {
287 log::info!(
288 "registering context server command: {:?}",
289 prompt.name
290 );
291 context_server_registry.register_command(
292 server.id.clone(),
293 prompt.name.as_str(),
294 );
295 slash_command_registry.register_command(
296 context_server_command::ContextServerSlashCommand::new(
297 &server, prompt,
298 ),
299 true,
300 );
301 }
302 }
303 })
304 .detach();
305 }
306 },
307 );
308 }
309 context_servers::manager::Event::ServerStopped { server_id } => {
310 let slash_command_registry = SlashCommandRegistry::global(cx);
311 let context_server_registry = ContextServerRegistry::global(cx);
312 if let Some(commands) = context_server_registry.get_commands(server_id) {
313 for command_name in commands {
314 slash_command_registry.unregister_command_by_name(&command_name);
315 context_server_registry.unregister_command(&server_id, &command_name);
316 }
317 }
318 }
319 },
320 )
321 .detach();
322}
323
324fn init_language_model_settings(cx: &mut AppContext) {
325 update_active_language_model_from_settings(cx);
326
327 cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
328 .detach();
329 cx.subscribe(
330 &LanguageModelRegistry::global(cx),
331 |_, event: &language_model::Event, cx| match event {
332 language_model::Event::ProviderStateChanged
333 | language_model::Event::AddedProvider(_)
334 | language_model::Event::RemovedProvider(_) => {
335 update_active_language_model_from_settings(cx);
336 }
337 _ => {}
338 },
339 )
340 .detach();
341}
342
343fn update_active_language_model_from_settings(cx: &mut AppContext) {
344 let settings = AssistantSettings::get_global(cx);
345 let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
346 let model_id = LanguageModelId::from(settings.default_model.model.clone());
347 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
348 registry.select_active_model(&provider_name, &model_id, cx);
349 });
350}
351
352fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
353 let slash_command_registry = SlashCommandRegistry::global(cx);
354 slash_command_registry.register_command(file_command::FileSlashCommand, true);
355 slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
356 slash_command_registry.register_command(tab_command::TabSlashCommand, true);
357 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
358 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
359 slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
360 slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
361 slash_command_registry.register_command(now_command::NowSlashCommand, false);
362 slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
363
364 if let Some(prompt_builder) = prompt_builder {
365 slash_command_registry.register_command(
366 workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()),
367 true,
368 );
369 }
370 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
371
372 update_slash_commands_from_settings(cx);
373 cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
374 .detach();
375
376 cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
377 let slash_command_registry = slash_command_registry.clone();
378 move |is_enabled, _cx| {
379 if is_enabled {
380 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
381 }
382 }
383 })
384 .detach();
385}
386
387fn update_slash_commands_from_settings(cx: &mut AppContext) {
388 let slash_command_registry = SlashCommandRegistry::global(cx);
389 let settings = SlashCommandSettings::get_global(cx);
390
391 if settings.docs.enabled {
392 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
393 } else {
394 slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
395 }
396
397 if settings.project.enabled {
398 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
399 } else {
400 slash_command_registry.unregister_command(project_command::ProjectSlashCommand);
401 }
402}
403
/// Formats a token count for compact display.
///
/// - Below 1,000 the exact number is shown (`999` → `"999"`).
/// - From 1,000 to 9,999 the count is rounded to the nearest hundred and shown
///   in thousands with one decimal, omitting a `.0` (`1234` → `"1.2k"`,
///   `1000` → `"1k"`, `9950` → `"10k"`).
/// - From 10,000 upward the count is rounded to the nearest thousand
///   (`10500` → `"11k"`).
///
/// Halves round up. The computation cannot overflow, even for `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // `count % 1000 + 50` is at most 1049, so this cannot overflow.
            let hundreds = (count % 1000 + 50) / 100;
            if hundreds == 0 {
                format!("{}k", thousands)
            } else if hundreds == 10 {
                format!("{}k", thousands + 1)
            } else {
                format!("{}.{}k", thousands, hundreds)
            }
        }
        _ => {
            // Equivalent to `(count + 500) / 1000`, but that addition overflows
            // for counts near `usize::MAX` (panic in debug, wrong value in release).
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
421
/// Test-only initializer, run at test-binary load through `ctor`: enables
/// `env_logger` when the `RUST_LOG` environment variable is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}