1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
2
3pub mod assistant_panel;
4pub mod assistant_settings;
5mod context;
6pub mod context_store;
7mod inline_assistant;
8mod model_selector;
9mod patch;
10mod prompt_library;
11mod prompts;
12mod slash_command;
13pub(crate) mod slash_command_picker;
14pub mod slash_command_settings;
15mod slash_command_working_set;
16mod streaming_diff;
17mod terminal_inline_assistant;
18mod tool_working_set;
19mod tools;
20
21use crate::slash_command::project_command::ProjectSlashCommandFeatureFlag;
22pub use crate::slash_command_working_set::{SlashCommandId, SlashCommandWorkingSet};
23pub use crate::tool_working_set::{ToolId, ToolWorkingSet};
24pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
25use assistant_settings::AssistantSettings;
26use assistant_slash_command::SlashCommandRegistry;
27use assistant_tool::ToolRegistry;
28use client::{proto, Client};
29use command_palette_hooks::CommandPaletteFilter;
30pub use context::*;
31pub use context_store::*;
32use feature_flags::FeatureFlagAppExt;
33use fs::Fs;
34use gpui::impl_actions;
35use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
36use indexed_docs::IndexedDocsRegistry;
37pub(crate) use inline_assistant::*;
38use language_model::{
39 LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
40};
41pub(crate) use model_selector::*;
42pub use patch::*;
43pub use prompts::PromptBuilder;
44use prompts::PromptLoadingParams;
45use semantic_index::{CloudEmbeddingProvider, SemanticDb};
46use serde::{Deserialize, Serialize};
47use settings::{update_settings_file, Settings, SettingsStore};
48use slash_command::search_command::SearchSlashCommandFeatureFlag;
49use slash_command::{
50 auto_command, cargo_workspace_command, default_command, delta_command, diagnostics_command,
51 docs_command, fetch_command, file_command, now_command, project_command, prompt_command,
52 search_command, selection_command, symbols_command, tab_command, terminal_command,
53};
54use std::path::PathBuf;
55use std::sync::Arc;
56pub(crate) use streaming_diff::*;
57use util::ResultExt;
58
59use crate::slash_command::streaming_example_command;
60use crate::slash_command_settings::SlashCommandSettings;
61
// Parameterless actions registered under the `assistant` namespace.
// NOTE(review): this appears to be the same namespace string as
// `Assistant::NAMESPACE` ("assistant"), which `Assistant::set_enabled` shows or
// hides in the command palette — confirm they are meant to stay in sync.
actions!(
    assistant,
    [
        Assist,
        Edit,
        Split,
        CopyCode,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        NewContext,
        ToggleModelSelector,
        CycleNextInlineAssist,
        CyclePreviousInlineAssist
    ]
);
83
/// Action carrying a list of file paths to insert into the assistant.
///
/// Deserializable so it can be dispatched with a payload (see `impl_actions!`
/// below).
#[derive(PartialEq, Clone, Deserialize)]
pub enum InsertDraggedFiles {
    /// Paths presumably belonging to the open project — TODO confirm against the drop handler.
    ProjectPaths(Vec<PathBuf>),
    /// Paths presumably from outside the project (e.g. dropped from the OS) — TODO confirm.
    ExternalFiles(Vec<PathBuf>),
}

// Register `InsertDraggedFiles` as a payload-carrying action in the `assistant` namespace.
impl_actions!(assistant, [InsertDraggedFiles]);
91
/// Default number of context lines.
///
/// Not referenced within this file; presumably consumed by a child module of
/// this crate — verify the call site before changing the value.
const DEFAULT_CONTEXT_LINES: usize = 50;
93
/// Identifier of a message, wrapping a `clock::Lamport` value.
///
/// Ordering, equality and hashing are all derived from the wrapped Lamport value.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct MessageId(clock::Lamport);
96
97impl MessageId {
98 pub fn as_u64(self) -> u64 {
99 self.0.as_u64()
100 }
101}
102
/// Token usage counts deserialized from a language-model response.
///
/// NOTE(review): the field names mirror OpenAI-style `usage` objects; the exact
/// producing API is not visible here — confirm before relying on the semantics.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens counted for the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens counted for the completion/output.
    pub completion_tokens: u32,
    /// Total tokens as reported by the response (not computed locally).
    pub total_tokens: u32,
}
109
/// One incremental "choice" entry deserialized from a streamed model response.
///
/// NOTE(review): shape resembles OpenAI-style streaming `choices[]` entries;
/// not referenced in this file — confirm the producer before relying on it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The incremental message content.
    pub delta: LanguageModelResponseMessage,
    /// Reason the stream finished, if this delta ends it; `None` otherwise.
    pub finish_reason: Option<String>,
}
116
/// Status of a message, convertible to and from `proto::ContextMessageStatus`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    /// Presumably still in progress. `from_proto` also maps a proto status with
    /// no variant set to this value.
    Pending,
    /// Completed.
    Done,
    /// Failed; carries the error message that is sent over the wire.
    Error(SharedString),
    /// Canceled.
    Canceled,
}
124
125impl MessageStatus {
126 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
127 match status.variant {
128 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
129 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
130 Some(proto::context_message_status::Variant::Error(error)) => {
131 MessageStatus::Error(error.message.into())
132 }
133 Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
134 None => MessageStatus::Pending,
135 }
136 }
137
138 pub fn to_proto(&self) -> proto::ContextMessageStatus {
139 match self {
140 MessageStatus::Pending => proto::ContextMessageStatus {
141 variant: Some(proto::context_message_status::Variant::Pending(
142 proto::context_message_status::Pending {},
143 )),
144 },
145 MessageStatus::Done => proto::ContextMessageStatus {
146 variant: Some(proto::context_message_status::Variant::Done(
147 proto::context_message_status::Done {},
148 )),
149 },
150 MessageStatus::Error(message) => proto::ContextMessageStatus {
151 variant: Some(proto::context_message_status::Variant::Error(
152 proto::context_message_status::Error {
153 message: message.to_string(),
154 },
155 )),
156 },
157 MessageStatus::Canceled => proto::ContextMessageStatus {
158 variant: Some(proto::context_message_status::Variant::Canceled(
159 proto::context_message_status::Canceled {},
160 )),
161 },
162 }
163 }
164}
165
166/// The state pertaining to the Assistant.
167#[derive(Default)]
168struct Assistant {
169 /// Whether the Assistant is enabled.
170 enabled: bool,
171}
172
173impl Global for Assistant {}
174
175impl Assistant {
176 const NAMESPACE: &'static str = "assistant";
177
178 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
179 if self.enabled == enabled {
180 return;
181 }
182
183 self.enabled = enabled;
184
185 if !enabled {
186 CommandPaletteFilter::update_global(cx, |filter, _cx| {
187 filter.hide_namespace(Self::NAMESPACE);
188 });
189
190 return;
191 }
192
193 CommandPaletteFilter::update_global(cx, |filter, _cx| {
194 filter.show_namespace(Self::NAMESPACE);
195 });
196 }
197}
198
199pub fn init(
200 fs: Arc<dyn Fs>,
201 client: Arc<Client>,
202 stdout_is_a_pty: bool,
203 cx: &mut AppContext,
204) -> Arc<PromptBuilder> {
205 cx.set_global(Assistant::default());
206 AssistantSettings::register(cx);
207 SlashCommandSettings::register(cx);
208
209 // TODO: remove this when 0.148.0 is released.
210 if AssistantSettings::get_global(cx).using_outdated_settings_version {
211 update_settings_file::<AssistantSettings>(fs.clone(), cx, {
212 let fs = fs.clone();
213 |content, cx| {
214 content.update_file(fs, cx);
215 }
216 });
217 }
218
219 cx.spawn(|mut cx| {
220 let client = client.clone();
221 async move {
222 let is_search_slash_command_enabled = cx
223 .update(|cx| cx.wait_for_flag::<SearchSlashCommandFeatureFlag>())?
224 .await;
225 let is_project_slash_command_enabled = cx
226 .update(|cx| cx.wait_for_flag::<ProjectSlashCommandFeatureFlag>())?
227 .await;
228
229 if !is_search_slash_command_enabled && !is_project_slash_command_enabled {
230 return Ok(());
231 }
232
233 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
234 let semantic_index = SemanticDb::new(
235 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
236 Arc::new(embedding_provider),
237 &mut cx,
238 )
239 .await?;
240
241 cx.update(|cx| cx.set_global(semantic_index))
242 }
243 })
244 .detach();
245
246 context_store::init(&client.clone().into());
247 prompt_library::init(cx);
248 init_language_model_settings(cx);
249 assistant_slash_command::init(cx);
250 assistant_tool::init(cx);
251 assistant_panel::init(cx);
252 context_servers::init(cx);
253
254 let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
255 fs: fs.clone(),
256 repo_path: stdout_is_a_pty
257 .then(|| std::env::current_dir().log_err())
258 .flatten(),
259 cx,
260 }))
261 .log_err()
262 .map(Arc::new)
263 .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
264 register_slash_commands(Some(prompt_builder.clone()), cx);
265 register_tools(cx);
266 inline_assistant::init(
267 fs.clone(),
268 prompt_builder.clone(),
269 client.telemetry().clone(),
270 cx,
271 );
272 terminal_inline_assistant::init(
273 fs.clone(),
274 prompt_builder.clone(),
275 client.telemetry().clone(),
276 cx,
277 );
278 IndexedDocsRegistry::init_global(cx);
279
280 CommandPaletteFilter::update_global(cx, |filter, _cx| {
281 filter.hide_namespace(Assistant::NAMESPACE);
282 });
283 Assistant::update_global(cx, |assistant, cx| {
284 let settings = AssistantSettings::get_global(cx);
285
286 assistant.set_enabled(settings.enabled, cx);
287 });
288 cx.observe_global::<SettingsStore>(|cx| {
289 Assistant::update_global(cx, |assistant, cx| {
290 let settings = AssistantSettings::get_global(cx);
291 assistant.set_enabled(settings.enabled, cx);
292 });
293 })
294 .detach();
295
296 prompt_builder
297}
298
299fn init_language_model_settings(cx: &mut AppContext) {
300 update_active_language_model_from_settings(cx);
301
302 cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
303 .detach();
304 cx.subscribe(
305 &LanguageModelRegistry::global(cx),
306 |_, event: &language_model::Event, cx| match event {
307 language_model::Event::ProviderStateChanged
308 | language_model::Event::AddedProvider(_)
309 | language_model::Event::RemovedProvider(_) => {
310 update_active_language_model_from_settings(cx);
311 }
312 _ => {}
313 },
314 )
315 .detach();
316}
317
318fn update_active_language_model_from_settings(cx: &mut AppContext) {
319 let settings = AssistantSettings::get_global(cx);
320 let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
321 let model_id = LanguageModelId::from(settings.default_model.model.clone());
322 let inline_alternatives = settings
323 .inline_alternatives
324 .iter()
325 .map(|alternative| {
326 (
327 LanguageModelProviderId::from(alternative.provider.clone()),
328 LanguageModelId::from(alternative.model.clone()),
329 )
330 })
331 .collect::<Vec<_>>();
332 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
333 registry.select_active_model(&provider_name, &model_id, cx);
334 registry.select_inline_alternative_models(inline_alternatives, cx);
335 });
336}
337
338fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
339 let slash_command_registry = SlashCommandRegistry::global(cx);
340
341 slash_command_registry.register_command(file_command::FileSlashCommand, true);
342 slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
343 slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
344 slash_command_registry.register_command(tab_command::TabSlashCommand, true);
345 slash_command_registry
346 .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
347 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
348 slash_command_registry.register_command(selection_command::SelectionCommand, true);
349 slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
350 slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
351 slash_command_registry.register_command(now_command::NowSlashCommand, false);
352 slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
353 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
354 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
355
356 if let Some(prompt_builder) = prompt_builder {
357 cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
358 let slash_command_registry = slash_command_registry.clone();
359 move |is_enabled, _cx| {
360 if is_enabled {
361 slash_command_registry.register_command(
362 project_command::ProjectSlashCommand::new(prompt_builder.clone()),
363 true,
364 );
365 }
366 }
367 })
368 .detach();
369 }
370
371 cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
372 let slash_command_registry = slash_command_registry.clone();
373 move |is_enabled, _cx| {
374 if is_enabled {
375 // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
376 slash_command_registry.register_command(auto_command::AutoCommand, true);
377 }
378 }
379 })
380 .detach();
381
382 cx.observe_flag::<streaming_example_command::StreamingExampleSlashCommandFeatureFlag, _>({
383 let slash_command_registry = slash_command_registry.clone();
384 move |is_enabled, _cx| {
385 if is_enabled {
386 slash_command_registry.register_command(
387 streaming_example_command::StreamingExampleSlashCommand,
388 false,
389 );
390 }
391 }
392 })
393 .detach();
394
395 update_slash_commands_from_settings(cx);
396 cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
397 .detach();
398
399 cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
400 let slash_command_registry = slash_command_registry.clone();
401 move |is_enabled, _cx| {
402 if is_enabled {
403 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
404 }
405 }
406 })
407 .detach();
408}
409
410fn update_slash_commands_from_settings(cx: &mut AppContext) {
411 let slash_command_registry = SlashCommandRegistry::global(cx);
412 let settings = SlashCommandSettings::get_global(cx);
413
414 if settings.docs.enabled {
415 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
416 } else {
417 slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
418 }
419
420 if settings.cargo_workspace.enabled {
421 slash_command_registry
422 .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
423 } else {
424 slash_command_registry
425 .unregister_command(cargo_workspace_command::CargoWorkspaceSlashCommand);
426 }
427}
428
429fn register_tools(cx: &mut AppContext) {
430 let tool_registry = ToolRegistry::global(cx);
431 tool_registry.register_tool(tools::now_tool::NowTool);
432}
433
/// Formats a token count compactly for display.
///
/// - `0..=999` is shown verbatim (`999` -> `"999"`).
/// - `1000..=9999` is shown in thousands with one decimal, rounded half-up to
///   the nearest hundred; a zero decimal is dropped (`1000` -> `"1k"`,
///   `1234` -> `"1.2k"`, `1950` -> `"2k"`).
/// - Larger counts are rounded half-up to the nearest thousand
///   (`10_500` -> `"11k"`). This works for every `usize`, including values
///   near `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Round the remainder to the nearest hundred; at most (999 + 50) / 100 == 10.
            let hundreds = (count % 1000 + 50) / 100;
            if hundreds == 0 {
                format!("{}k", thousands)
            } else if hundreds == 10 {
                // The remainder rounded up to a full thousand (e.g. 1950 -> "2k").
                format!("{}k", thousands + 1)
            } else {
                format!("{}.{}k", thousands, hundreds)
            }
        }
        // Half-up rounding without computing `count + 500`, which would
        // overflow for counts within 500 of `usize::MAX`.
        _ => format!("{}k", count / 1000 + usize::from(count % 1000 >= 500)),
    }
}
451
/// Test-only constructor that installs `env_logger`, but only when `RUST_LOG`
/// is set (and readable as a string) in the environment.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let Ok(_) = std::env::var("RUST_LOG") else {
        return;
    };
    env_logger::init();
}