use anyhow::{anyhow, Result};
use assistant_slash_command::{
    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
    SlashCommandResult,
};
use feature_flags::FeatureFlag;
use futures::StreamExt;
use gpui::{App, AsyncApp, Task, WeakEntity, Window};
use language::{CodeLabel, LspAdapterDelegate};
use language_model::{
    LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, Role,
};
use semantic_index::{FileSummary, SemanticDb};
use smol::channel;
use std::sync::{atomic::AtomicBool, Arc};
use ui::{prelude::*, BorrowAppContext};
use util::ResultExt;
use workspace::Workspace;

use crate::create_label_for_command;

pub struct AutoSlashCommandFeatureFlag;

impl FeatureFlag for AutoSlashCommandFeatureFlag {
    const NAME: &'static str = "auto-slash-command";
}

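/// The `/auto` slash command: infers which supported slash commands (currently
/// `/search` and `/file`) to run for a prompt, based on the project's indexed file summaries.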
pub struct AutoCommand;

impl SlashCommand for AutoCommand {
    fn name(&self) -> String {
        "auto".into()
    }

    fn description(&self) -> String {
        "Automatically infer what context to add".into()
    }

    fn icon(&self) -> IconName {
        IconName::Wand
    }

    fn menu_text(&self) -> String {
        self.description()
    }

    fn label(&self, cx: &App) -> CodeLabel {
        create_label_for_command("auto", &["--prompt"], cx)
    }

    fn complete_argument(
        self: Arc<Self>,
        _arguments: &[String],
        _cancel: Arc<AtomicBool>,
        workspace: Option<WeakEntity<Workspace>>,
        _window: &mut Window,
        cx: &mut App,
    ) -> Task<Result<Vec<ArgumentCompletion>>> {
        // There's no autocomplete for a prompt, since it's arbitrary text.
        // However, we can use this opportunity to kick off a drain of the backlog,
        // so that resummarizing can hopefully finish by the time we've actually
        // typed out our prompt. This re-runs on every keystroke during autocomplete,
        // but in the future, we could instead do it only once, when /auto is first entered.
        let Some(workspace) = workspace.and_then(|ws| ws.upgrade()) else {
            log::warn!("workspace was dropped or unavailable during /auto autocomplete");
            return Task::ready(Ok(Vec::new()));
        };

        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("No project indexer, cannot use /auto")));
        };

        cx.spawn(|cx: gpui::AsyncApp| async move {
            let task = project_index.read_with(&cx, |project_index, cx| {
                project_index.flush_summary_backlogs(cx)
            })?;

            cx.background_executor().spawn(task).await;

            anyhow::Ok(Vec::new())
        })
    }

    fn requires_argument(&self) -> bool {
        true
    }

    fn run(
        self: Arc<Self>,
        arguments: &[String],
        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        _context_buffer: language::BufferSnapshot,
        workspace: WeakEntity<Workspace>,
        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<SlashCommandResult> {
        let Some(workspace) = workspace.upgrade() else {
            return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
        };
        if arguments.is_empty() {
            return Task::ready(Err(anyhow!("missing prompt")));
        }
        let argument = arguments.join(" ");
        let original_prompt = argument.to_string();
        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("no project indexer")));
        };

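        // Fetch every file summary from the project index, then ask the model which
        // slash commands to run for this prompt.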
        let task = window.spawn(cx, |cx| async move {
            let summaries = project_index
                .read_with(&cx, |project_index, cx| project_index.all_summaries(cx))?
                .await?;

            commands_for_summaries(&summaries, &original_prompt, &cx).await
        });

        // As a convenience, append /auto's argument to the end of the prompt
        // so you don't have to write it again.
        let original_prompt = argument.to_string();

        cx.background_executor().spawn(async move {
            let commands = task.await?;
            let mut prompt = String::new();

            log::info!(
                "Translating this response into slash-commands: {:?}",
                commands
            );

            for command in commands {
                prompt.push('/');
                prompt.push_str(&command.name);
                prompt.push(' ');
                prompt.push_str(&command.arg);
                prompt.push('\n');
            }

            prompt.push('\n');
            prompt.push_str(&original_prompt);

            Ok(SlashCommandOutput {
                text: prompt,
                sections: Vec::new(),
                run_commands_in_text: true,
            }
            .to_event_stream())
        })
    }
}

const PROMPT_INSTRUCTIONS_BEFORE_SUMMARY: &str = include_str!("prompt_before_summary.txt");
const PROMPT_INSTRUCTIONS_AFTER_SUMMARY: &str = include_str!("prompt_after_summary.txt");

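/// Build the prompt sent to the model: the instruction preamble, the file summaries
/// serialized as JSON, the closing instructions, and finally the user's original prompt.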
fn summaries_prompt(summaries: &[FileSummary], original_prompt: &str) -> String {
    let json_summaries = serde_json::to_string(summaries).unwrap();

    format!("{PROMPT_INSTRUCTIONS_BEFORE_SUMMARY}\n{json_summaries}\n{PROMPT_INSTRUCTIONS_AFTER_SUMMARY}\n{original_prompt}")
}

/// The slash commands that the model is told about, and which we look for in the inference response.
const SUPPORTED_SLASH_COMMANDS: &[&str] = &["search", "file"];

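/// A slash command (without the leading slash) and its argument, as parsed from the model's response.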
#[derive(Debug, Clone)]
struct CommandToRun {
    name: String,
    arg: String,
}

/// Given the pre-indexed file summaries for this project, as well as the original prompt
/// string passed to `/auto`, get a list of slash commands to run, along with their arguments.
///
/// The model's output does not include the slashes (to reduce the chance that it makes a mistake),
/// so turning one of the returned entries into a real slash-command-with-argument
/// involves prepending a slash to it.
///
/// This function validates that each returned line begins with one of SUPPORTED_SLASH_COMMANDS.
/// Any other lines it encounters are discarded, with a warning logged.
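/// For example, the response line `file src/main.rs` is returned as the command `file`
/// with argument `src/main.rs`, which the caller renders as `/file src/main.rs`.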
async fn commands_for_summaries(
    summaries: &[FileSummary],
    original_prompt: &str,
    cx: &AsyncApp,
) -> Result<Vec<CommandToRun>> {
    if summaries.is_empty() {
        log::warn!("Inferring no context because there were no summaries available.");
        return Ok(Vec::new());
    }

    // Use the globally configured model to translate the summaries into slash-commands,
    // because Qwen2-7B-Instruct has not done a good job at that task.
    let Some(model) = cx.update(|cx| LanguageModelRegistry::read_global(cx).active_model())? else {
        log::warn!("Can't infer context because there's no active model.");
        return Ok(Vec::new());
    };
    // Only go up to 90% of the actual max token count, to reduce chances of
    // exceeding the token count due to inaccuracies in the token counting heuristic.
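    // For example, a model with a 128,000-token limit gets a budget of 115,200 tokens here.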
    let max_token_count = (model.max_token_count() * 9) / 10;

    // Rather than recursing (which would require this async function to use a pinned box),
    // we use an explicit stack of arguments and answers for when we need to "recurse."
    let mut stack = vec![summaries];
    let mut final_response = Vec::new();
    let mut prompts = Vec::new();

    // TODO We only need to create multiple requests because we currently can't tell
    // whether a CompletionProvider::complete response was a "too many tokens in this
    // request" error. If we could, then we could try the request once, instead of
    // having to make separate requests to check the token count and then afterwards
    // to run the actual prompt.
    let make_request = |prompt: String| LanguageModelRequest {
        messages: vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            // Nothing in here will benefit from caching.
            cache: false,
        }],
        tools: Vec::new(),
        stop: Vec::new(),
        temperature: None,
    };

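    // Greedily build prompts that fit within the token budget, splitting a batch of
    // summaries in half whenever its prompt would be too large.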
    while let Some(current_summaries) = stack.pop() {
        // The split can result in one slice being empty and the other having one element.
        // Whenever that happens, skip the empty one.
        if current_summaries.is_empty() {
            continue;
        }

        log::info!(
            "Inferring prompt context using {} file summaries",
            current_summaries.len()
        );

        let prompt = summaries_prompt(&current_summaries, original_prompt);
        let start = std::time::Instant::now();
        // Per OpenAI, 1 token ~= 4 chars in English (we go with 4.5 to overestimate a bit, because failed API requests cost a lot of perf)
        // Verifying this against an actual model.count_tokens() confirms that it's usually within ~5% of the correct answer, whereas
        // getting the correct answer from tiktoken takes hundreds of milliseconds (compared to this arithmetic being ~free).
        // source: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
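        // e.g. a 9,000-byte prompt estimates to 9,000 * 2 / 9 = 2,000 tokens.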
        let token_estimate = prompt.len() * 2 / 9;
        let duration = start.elapsed();
        log::info!(
            "Time taken to count tokens for prompt of length {:?}B: {:?}",
            prompt.len(),
            duration
        );

        if token_estimate < max_token_count {
            prompts.push(prompt);
        } else if current_summaries.len() == 1 {
            log::warn!("Inferring context for a single file's summary failed because the prompt's token length exceeded the model's token limit.");
        } else {
            log::info!(
                "Context inference using file summaries resulted in a prompt containing {token_estimate} tokens, which exceeded the model's max of {max_token_count}. Retrying as two separate prompts, each including half the number of summaries.",
            );
            let (left, right) = current_summaries.split_at(current_summaries.len() / 2);
            stack.push(right);
            stack.push(left);
        }
    }

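    // Run all the prompts concurrently, funneling each parsed command through a bounded channel.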
    let all_start = std::time::Instant::now();

    let (tx, rx) = channel::bounded(1024);

    let completion_streams = prompts
        .into_iter()
        .map(|prompt| {
            let request = make_request(prompt.clone());
            let model = model.clone();
            let tx = tx.clone();
            let stream = model.stream_completion(request, &cx);

            (stream, tx)
        })
        .collect::<Vec<_>>();

    cx.background_executor()
        .spawn(async move {
            let futures = completion_streams
                .into_iter()
                .enumerate()
                .map(|(ix, (stream, tx))| async move {
                    let start = std::time::Instant::now();
                    let events = stream.await?;
                    log::info!("Time taken for awaiting /auto chunk stream #{ix}: {:?}", start.elapsed());

                    let completion: String = events
                        .filter_map(|event| async {
                            if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
                                Some(text)
                            } else {
                                None
                            }
                        })
                        .collect()
                        .await;

                    log::info!("Time taken for all /auto chunks to come back for #{ix}: {:?}", start.elapsed());

                    for line in completion.split('\n') {
                        if let Some(first_space) = line.find(' ') {
                            let command = &line[..first_space].trim();
                            let arg = &line[first_space..].trim();

                            tx.send(CommandToRun {
                                name: command.to_string(),
                                arg: arg.to_string(),
                            })
                            .await?;
                        } else if !line.trim().is_empty() {
                            // All slash-commands currently supported in context inference need a space for the argument.
                            log::warn!(
                                "Context inference returned a non-blank line that contained no spaces (meaning no argument for the slash command): {:?}",
                                line
                            );
                        }
                    }

                    anyhow::Ok(())
                })
                .collect::<Vec<_>>();

            let _ = futures::future::try_join_all(futures).await.log_err();

            let duration = all_start.elapsed();
            eprintln!("All futures completed in {:?}", duration);
        })
        .await;
    drop(tx); // Close the channel so that rx.collect() won't hang. This is safe because all futures have completed.
    let results = rx.collect::<Vec<_>>().await;
    eprintln!(
        "Finished collecting from the channel with {} results",
        results.len()
    );
    for command in results {
        // Don't return empty or duplicate commands.
        if !command.name.is_empty()
            && !final_response
                .iter()
                .any(|cmd: &CommandToRun| cmd.name == command.name && cmd.arg == command.arg)
        {
            if SUPPORTED_SLASH_COMMANDS
                .iter()
                .any(|supported| &command.name == supported)
            {
                final_response.push(command);
            } else {
                log::warn!(
                    "Context inference returned an unrecognized slash command: {:?}",
                    command
                );
            }
        }
    }

    // Sort the commands by name (reversed, just so that /search appears before /file).
    final_response.sort_by(|cmd1, cmd2| cmd1.name.cmp(&cmd2.name).reverse());

    Ok(final_response)
}