use super::create_label_for_command;
use super::{SlashCommand, SlashCommandOutput};
use anyhow::{anyhow, Result};
use assistant_slash_command::ArgumentCompletion;
use feature_flags::FeatureFlag;
use futures::StreamExt;
use gpui::{AppContext, AsyncAppContext, Task, WeakView};
use language::{CodeLabel, LspAdapterDelegate};
use language_model::{
    LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, Role,
};
use semantic_index::{FileSummary, SemanticDb};
use smol::channel;
use std::sync::{atomic::AtomicBool, Arc};
use ui::{BorrowAppContext, WindowContext};
use util::ResultExt;
use workspace::Workspace;

pub struct AutoSlashCommandFeatureFlag;

impl FeatureFlag for AutoSlashCommandFeatureFlag {
    const NAME: &'static str = "auto-slash-command";
}

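/// The `/auto` slash command: asks the configured language model which of the supported
/// slash commands to run for the user's prompt, based on the project's indexed file
/// summaries, then splices those commands in ahead of the original prompt.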
pub(crate) struct AutoCommand;

impl SlashCommand for AutoCommand {
    fn name(&self) -> String {
        "auto".into()
    }

    fn description(&self) -> String {
        "Automatically infer what context to add, based on your prompt".into()
    }

    fn menu_text(&self) -> String {
        "Automatically Infer Context".into()
    }

    fn label(&self, cx: &AppContext) -> CodeLabel {
        create_label_for_command("auto", &["--prompt"], cx)
    }

    fn complete_argument(
        self: Arc<Self>,
        _arguments: &[String],
        _cancel: Arc<AtomicBool>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>> {
        // There's no autocomplete for a prompt, since it's arbitrary text.
        // However, we can use this opportunity to kick off a drain of the backlog.
        // That way, it can hopefully be done resummarizing by the time we've actually
        // typed out our prompt. This re-runs on every keystroke during autocomplete,
        // but in the future, we could instead do it only once, when /auto is first entered.
        let Some(workspace) = workspace.and_then(|ws| ws.upgrade()) else {
            log::warn!("workspace was dropped or unavailable during /auto autocomplete");

            return Task::ready(Ok(Vec::new()));
        };

        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("No project indexer, cannot use /auto")));
        };

        let cx: &mut AppContext = cx;

        cx.spawn(|cx: gpui::AsyncAppContext| async move {
            let task = project_index.read_with(&cx, |project_index, cx| {
                project_index.flush_summary_backlogs(cx)
            })?;

            cx.background_executor().spawn(task).await;

            anyhow::Ok(Vec::new())
        })
    }

    fn requires_argument(&self) -> bool {
        true
    }

    fn run(
        self: Arc<Self>,
        arguments: &[String],
        workspace: WeakView<Workspace>,
        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<Result<SlashCommandOutput>> {
        let Some(workspace) = workspace.upgrade() else {
            return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
        };
        if arguments.is_empty() {
            return Task::ready(Err(anyhow!("missing prompt")));
        };
        let argument = arguments.join(" ");
        let original_prompt = argument.to_string();
        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("no project indexer")));
        };

        let task = cx.spawn(|cx: gpui::AsyncWindowContext| async move {
            let summaries = project_index
                .read_with(&cx, |project_index, cx| project_index.all_summaries(cx))?
                .await?;

            commands_for_summaries(&summaries, &original_prompt, &cx).await
        });

        // As a convenience, append /auto's argument to the end of the prompt
        // so you don't have to write it again.
        let original_prompt = argument.to_string();

        cx.background_executor().spawn(async move {
            let commands = task.await?;
            let mut prompt = String::new();

            log::info!(
                "Translating this response into slash-commands: {:?}",
                commands
            );

            for command in commands {
                prompt.push('/');
                prompt.push_str(&command.name);
                prompt.push(' ');
                prompt.push_str(&command.arg);
                prompt.push('\n');
            }

            prompt.push('\n');
            prompt.push_str(&original_prompt);

            Ok(SlashCommandOutput {
                text: prompt,
                sections: Vec::new(),
                run_commands_in_text: true,
            })
        })
    }
}

const PROMPT_INSTRUCTIONS_BEFORE_SUMMARY: &str = include_str!("prompt_before_summary.txt");
const PROMPT_INSTRUCTIONS_AFTER_SUMMARY: &str = include_str!("prompt_after_summary.txt");

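/// Build the context-inference prompt: the instruction preamble, the JSON-serialized
/// file summaries, the instruction postamble, and finally the user's original prompt.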
fn summaries_prompt(summaries: &[FileSummary], original_prompt: &str) -> String {
    let json_summaries = serde_json::to_string(summaries).unwrap();

    format!("{PROMPT_INSTRUCTIONS_BEFORE_SUMMARY}\n{json_summaries}\n{PROMPT_INSTRUCTIONS_AFTER_SUMMARY}\n{original_prompt}")
}

/// The slash commands that the model is told about, and which we look for in the inference response.
const SUPPORTED_SLASH_COMMANDS: &[&str] = &["search", "file"];

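/// A slash command to run (without its leading slash) and its argument, parsed from one
/// line of the model's response.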
#[derive(Debug, Clone)]
struct CommandToRun {
    name: String,
    arg: String,
}

/// Given the pre-indexed file summaries for this project, as well as the original prompt
/// string passed to `/auto`, get a list of slash commands to run, along with their arguments.
///
/// The prompt's output does not include the slashes (to reduce the chance that it makes a mistake),
/// so taking one of these returned Strings and turning it into a real slash-command-with-argument
/// involves prepending a slash to it.
///
/// This function will validate that each of the returned lines begins with one of SUPPORTED_SLASH_COMMANDS.
/// Any other lines it encounters will be discarded, with a warning logged.
async fn commands_for_summaries(
    summaries: &[FileSummary],
    original_prompt: &str,
    cx: &AsyncAppContext,
) -> Result<Vec<CommandToRun>> {
    if summaries.is_empty() {
        log::warn!("Inferring no context because there were no summaries available.");
        return Ok(Vec::new());
    }

    // Use the globally configured model to translate the summaries into slash-commands,
    // because Qwen2-7B-Instruct has not done a good job at that task.
    let Some(model) = cx.update(|cx| LanguageModelRegistry::read_global(cx).active_model())? else {
        log::warn!("Can't infer context because there's no active model.");
        return Ok(Vec::new());
    };
    // Only go up to 90% of the actual max token count, to reduce chances of
    // exceeding the token count due to inaccuracies in the token counting heuristic.
    let max_token_count = (model.max_token_count() * 9) / 10;

    // Rather than recursing (which would require this async function use a pinned box),
    // we use an explicit stack of arguments and answers for when we need to "recurse."
    let mut stack = vec![summaries];
    let mut final_response = Vec::new();
    let mut prompts = Vec::new();

    // TODO We only need to create multiple Requests because we currently
    // don't have the ability to tell if a CompletionProvider::complete response
    // was a "too many tokens in this request" error. If we had that, then
    // we could try the request once, instead of having to make separate requests
    // to check the token count and then afterwards to run the actual prompt.
    let make_request = |prompt: String| LanguageModelRequest {
        messages: vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            // Nothing in here will benefit from caching
            cache: false,
        }],
        tools: Vec::new(),
        stop: Vec::new(),
        temperature: 1.0,
    };

    while let Some(current_summaries) = stack.pop() {
        // The split can result in one slice being empty and the other having one element.
        // Whenever that happens, skip the empty one.
        if current_summaries.is_empty() {
            continue;
        }

        log::info!(
            "Inferring prompt context using {} file summaries",
            current_summaries.len()
        );

        let prompt = summaries_prompt(&current_summaries, original_prompt);
        let start = std::time::Instant::now();
        // Per OpenAI, 1 token ~= 4 chars in English (we go with 4.5 to overestimate a bit, because failed API requests cost a lot of perf)
        // Verifying this against an actual model.count_tokens() confirms that it's usually within ~5% of the correct answer, whereas
        // getting the correct answer from tiktoken takes hundreds of milliseconds (compared to this arithmetic being ~free).
        // source: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
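        // Equivalent to prompt.len() / 4.5, written with integer arithmetic.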
        let token_estimate = prompt.len() * 2 / 9;
        let duration = start.elapsed();
        log::info!(
            "Time taken to count tokens for prompt of length {:?}B: {:?}",
            prompt.len(),
            duration
        );

        if token_estimate < max_token_count {
            prompts.push(prompt);
        } else if current_summaries.len() == 1 {
            log::warn!("Inferring context for a single file's summary failed because the prompt's token length exceeded the model's token limit.");
        } else {
            log::info!(
                "Context inference using file summaries resulted in a prompt containing {token_estimate} tokens, which exceeded the model's max of {max_token_count}. Retrying as two separate prompts, each including half the number of summaries.",
            );
            let (left, right) = current_summaries.split_at(current_summaries.len() / 2);
            stack.push(right);
            stack.push(left);
        }
    }

    let all_start = std::time::Instant::now();

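    // Each completion stream sends the commands it parses through this channel; once all
    // streams have finished, the receiver is drained and the results de-duplicated below.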
    let (tx, rx) = channel::bounded(1024);

    let completion_streams = prompts
        .into_iter()
        .map(|prompt| {
            let request = make_request(prompt.clone());
            let model = model.clone();
            let tx = tx.clone();
            let stream = model.stream_completion(request, &cx);

            (stream, tx)
        })
        .collect::<Vec<_>>();

    cx.background_executor()
        .spawn(async move {
            let futures = completion_streams
                .into_iter()
                .enumerate()
                .map(|(ix, (stream, tx))| async move {
                    let start = std::time::Instant::now();
                    let events = stream.await?;
                    log::info!("Time taken for awaiting /auto chunk stream #{ix}: {:?}", start.elapsed());

                    let completion: String = events
                        .filter_map(|event| async {
                            if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
                                Some(text)
                            } else {
                                None
                            }
                        })
                        .collect()
                        .await;

                    log::info!("Time taken for all /auto chunks to come back for #{ix}: {:?}", start.elapsed());

                    for line in completion.split('\n') {
                        if let Some(first_space) = line.find(' ') {
                            let command = &line[..first_space].trim();
                            let arg = &line[first_space..].trim();

                            tx.send(CommandToRun {
                                name: command.to_string(),
                                arg: arg.to_string(),
                            })
                            .await?;
                        } else if !line.trim().is_empty() {
                            // All slash-commands currently supported in context inference need a space for the argument.
                            log::warn!(
                                "Context inference returned a non-blank line that contained no spaces (meaning no argument for the slash command): {:?}",
                                line
                            );
                        }
                    }

                    anyhow::Ok(())
                })
                .collect::<Vec<_>>();

            let _ = futures::future::try_join_all(futures).await.log_err();

            let duration = all_start.elapsed();
            eprintln!("All futures completed in {:?}", duration);
        })
        .await;

    drop(tx); // Close the channel so that rx.collect() won't hang. This is safe because all futures have completed.
    let results = rx.collect::<Vec<_>>().await;
    eprintln!(
        "Finished collecting from the channel with {} results",
        results.len()
    );
    for command in results {
        // Don't return empty or duplicate commands
        if !command.name.is_empty()
            && !final_response
                .iter()
                .any(|cmd: &CommandToRun| cmd.name == command.name && cmd.arg == command.arg)
        {
            if SUPPORTED_SLASH_COMMANDS
                .iter()
                .any(|supported| &command.name == supported)
            {
                final_response.push(command);
            } else {
                log::warn!(
                    "Context inference returned an unrecognized slash command: {:?}",
                    command
                );
            }
        }
    }

    // Sort the commands by name (reversed just so that /search appears before /file)
    final_response.sort_by(|cmd1, cmd2| cmd1.name.cmp(&cmd2.name).reverse());

    Ok(final_response)
}