auto_command.rs

use super::create_label_for_command;
use super::{SlashCommand, SlashCommandOutput};
use anyhow::{anyhow, Result};
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use feature_flags::FeatureFlag;
use futures::StreamExt;
use gpui::{AppContext, AsyncAppContext, Task, WeakView};
use language::{CodeLabel, LspAdapterDelegate};
use language_model::{
    LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, Role,
};
use semantic_index::{FileSummary, SemanticDb};
use smol::channel;
use std::sync::{atomic::AtomicBool, Arc};
use ui::{BorrowAppContext, WindowContext};
use util::ResultExt;
use workspace::Workspace;

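/// Feature flag gating the `/auto` slash command.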
pub struct AutoSlashCommandFeatureFlag;

impl FeatureFlag for AutoSlashCommandFeatureFlag {
    const NAME: &'static str = "auto-slash-command";
}

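/// `/auto`: asks the configured language model which slash commands would provide useful
/// context for the user's prompt, then emits those commands followed by the prompt itself.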
pub(crate) struct AutoCommand;

impl SlashCommand for AutoCommand {
    fn name(&self) -> String {
        "auto".into()
    }

    fn description(&self) -> String {
        "Automatically infer what context to add, based on your prompt".into()
    }

    fn menu_text(&self) -> String {
        "Automatically Infer Context".into()
    }

    fn label(&self, cx: &AppContext) -> CodeLabel {
        create_label_for_command("auto", &["--prompt"], cx)
    }

    fn complete_argument(
        self: Arc<Self>,
        _arguments: &[String],
        _cancel: Arc<AtomicBool>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>> {
        // There's no autocomplete for a prompt, since it's arbitrary text.
        // However, we can use this opportunity to kick off a drain of the backlog.
        // That way, it can hopefully be done resummarizing by the time we've actually
        // typed out our prompt. This re-runs on every keystroke during autocomplete,
        // but in the future, we could instead do it only once, when /auto is first entered.
        let Some(workspace) = workspace.and_then(|ws| ws.upgrade()) else {
            log::warn!("workspace was dropped or unavailable during /auto autocomplete");

            return Task::ready(Ok(Vec::new()));
        };

        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("No project indexer, cannot use /auto")));
        };

        let cx: &mut AppContext = cx;

        cx.spawn(|cx: gpui::AsyncAppContext| async move {
            let task = project_index.read_with(&cx, |project_index, cx| {
                project_index.flush_summary_backlogs(cx)
            })?;

            cx.background_executor().spawn(task).await;

            anyhow::Ok(Vec::new())
        })
    }

    fn requires_argument(&self) -> bool {
        true
    }

    fn run(
        self: Arc<Self>,
        arguments: &[String],
        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        _context_buffer: language::BufferSnapshot,
        workspace: WeakView<Workspace>,
        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<Result<SlashCommandOutput>> {
        let Some(workspace) = workspace.upgrade() else {
            return Task::ready(Err(anyhow!("workspace was dropped")));
        };
        if arguments.is_empty() {
            return Task::ready(Err(anyhow!("missing prompt")));
        }
        let argument = arguments.join(" ");
        let original_prompt = argument.to_string();
        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("no project indexer")));
        };

        let task = cx.spawn(|cx: gpui::AsyncWindowContext| async move {
            let summaries = project_index
                .read_with(&cx, |project_index, cx| project_index.all_summaries(cx))?
                .await?;

            commands_for_summaries(&summaries, &original_prompt, &cx).await
        });

        // As a convenience, append /auto's argument to the end of the prompt
        // so you don't have to write it again. (The binding above was moved into
        // the async block, so we recreate `original_prompt` from `argument` here.)
        let original_prompt = argument.to_string();

        cx.background_executor().spawn(async move {
            let commands = task.await?;
            let mut prompt = String::new();

            log::info!(
                "Translating this response into slash-commands: {:?}",
                commands
            );

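            // The generated output takes the form:
            //
            //     /search <arg>
            //     /file <arg>
            //
            //     <original prompt>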
            for command in commands {
                prompt.push('/');
                prompt.push_str(&command.name);
                prompt.push(' ');
                prompt.push_str(&command.arg);
                prompt.push('\n');
            }

            prompt.push('\n');
            prompt.push_str(&original_prompt);

            Ok(SlashCommandOutput {
                text: prompt,
                sections: Vec::new(),
                run_commands_in_text: true,
            })
        })
    }
}

const PROMPT_INSTRUCTIONS_BEFORE_SUMMARY: &str = include_str!("prompt_before_summary.txt");
const PROMPT_INSTRUCTIONS_AFTER_SUMMARY: &str = include_str!("prompt_after_summary.txt");

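/// Builds the prompt sent to the model: fixed instructions, the JSON-serialized file
/// summaries, more instructions, and finally the user's original prompt.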
fn summaries_prompt(summaries: &[FileSummary], original_prompt: &str) -> String {
    let json_summaries = serde_json::to_string(summaries).unwrap();

    format!("{PROMPT_INSTRUCTIONS_BEFORE_SUMMARY}\n{json_summaries}\n{PROMPT_INSTRUCTIONS_AFTER_SUMMARY}\n{original_prompt}")
}

/// The slash commands that the model is told about, and which we look for in the inference response.
const SUPPORTED_SLASH_COMMANDS: &[&str] = &["search", "file"];

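/// A slash command (without the leading slash) and its argument, as parsed from the model's response.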
#[derive(Debug, Clone)]
struct CommandToRun {
    name: String,
    arg: String,
}

/// Given the pre-indexed file summaries for this project, as well as the original prompt
/// string passed to `/auto`, get a list of slash commands to run, along with their arguments.
///
/// The model's response does not include the slashes (to reduce the chance that it makes a mistake),
/// so turning one of the returned names into a real slash-command-with-argument involves
/// prepending a slash to it.
///
/// This function validates that each returned line begins with one of SUPPORTED_SLASH_COMMANDS.
/// Any other lines are discarded, with a warning logged.
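///
/// For example, a (hypothetical) model response of:
///
/// ```text
/// search frobnicate
/// file src/main.rs
/// ```
///
/// yields a `search` command and a `file` command with those arguments.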
async fn commands_for_summaries(
    summaries: &[FileSummary],
    original_prompt: &str,
    cx: &AsyncAppContext,
) -> Result<Vec<CommandToRun>> {
    if summaries.is_empty() {
        log::warn!("Inferring no context because there were no summaries available.");
        return Ok(Vec::new());
    }

    // Use the globally configured model to translate the summaries into slash-commands,
    // because Qwen2-7B-Instruct has not done a good job at that task.
    let Some(model) = cx.update(|cx| LanguageModelRegistry::read_global(cx).active_model())? else {
        log::warn!("Can't infer context because there's no active model.");
        return Ok(Vec::new());
    };
    // Only go up to 90% of the actual max token count, to reduce the chances of
    // exceeding the model's limit due to inaccuracies in the token-counting heuristic.
    let max_token_count = (model.max_token_count() * 9) / 10;

    // Rather than recursing (which would require this async function to use a pinned box),
    // we use an explicit stack of arguments and answers for when we need to "recurse."
    let mut stack = vec![summaries];
    let mut final_response = Vec::new();
    let mut prompts = Vec::new();

    // TODO We only need to create multiple Requests because we currently
    // don't have the ability to tell if a CompletionProvider::complete response
    // was a "too many tokens in this request" error. If we had that, then
    // we could try the request once, instead of having to make separate requests
    // to check the token count and then afterwards to run the actual prompt.
    let make_request = |prompt: String| LanguageModelRequest {
        messages: vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            // Nothing in here will benefit from caching
            cache: false,
        }],
        tools: Vec::new(),
        stop: Vec::new(),
        temperature: 1.0,
    };

    while let Some(current_summaries) = stack.pop() {
        // The split can result in one slice being empty and the other having one element.
        // Whenever that happens, skip the empty one.
        if current_summaries.is_empty() {
            continue;
        }

        log::info!(
            "Inferring prompt context using {} file summaries",
            current_summaries.len()
        );

        let prompt = summaries_prompt(&current_summaries, original_prompt);
        let start = std::time::Instant::now();
        // Per OpenAI, 1 token ~= 4 chars in English (we go with 4.5 to overestimate a bit, because failed API requests cost a lot of perf)
        // Verifying this against an actual model.count_tokens() confirms that it's usually within ~5% of the correct answer, whereas
        // getting the correct answer from tiktoken takes hundreds of milliseconds (compared to this arithmetic being ~free).
        // source: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
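        // That is, estimated tokens ≈ prompt.len() / 4.5, written below with integer arithmetic as len * 2 / 9.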
        let token_estimate = prompt.len() * 2 / 9;
        let duration = start.elapsed();
        log::info!(
            "Time taken to count tokens for prompt of length {:?}B: {:?}",
            prompt.len(),
            duration
        );

        if token_estimate < max_token_count {
            prompts.push(prompt);
        } else if current_summaries.len() == 1 {
            log::warn!("Inferring context for a single file's summary failed because the prompt's token length exceeded the model's token limit.");
        } else {
            log::info!(
                "Context inference using file summaries resulted in a prompt containing {token_estimate} tokens, which exceeded the model's max of {max_token_count}. Retrying as two separate prompts, each including half the number of summaries.",
            );
            let (left, right) = current_summaries.split_at(current_summaries.len() / 2);
            stack.push(right);
            stack.push(left);
        }
    }

    let all_start = std::time::Instant::now();

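    // Fan the parsed commands from every completion stream into one bounded channel;
    // each spawned future below parses `<command> <argument>` lines from its stream's text
    // and sends them along as `CommandToRun`s.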
    let (tx, rx) = channel::bounded(1024);

    let completion_streams = prompts
        .into_iter()
        .map(|prompt| {
            let request = make_request(prompt.clone());
            let model = model.clone();
            let tx = tx.clone();
            let stream = model.stream_completion(request, &cx);

            (stream, tx)
        })
        .collect::<Vec<_>>();

    cx.background_executor()
        .spawn(async move {
            let futures = completion_streams
                .into_iter()
                .enumerate()
                .map(|(ix, (stream, tx))| async move {
                    let start = std::time::Instant::now();
                    let events = stream.await?;
                    log::info!("Time taken for awaiting /auto chunk stream #{ix}: {:?}", start.elapsed());

                    let completion: String = events
                        .filter_map(|event| async {
                            if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
                                Some(text)
                            } else {
                                None
                            }
                        })
                        .collect()
                        .await;

                    log::info!("Time taken for all /auto chunks to come back for #{ix}: {:?}", start.elapsed());

                    for line in completion.split('\n') {
                        if let Some(first_space) = line.find(' ') {
                            let command = &line[..first_space].trim();
                            let arg = &line[first_space..].trim();

                            tx.send(CommandToRun {
                                name: command.to_string(),
                                arg: arg.to_string(),
                            })
                            .await?;
                        } else if !line.trim().is_empty() {
                            // All slash-commands currently supported in context inference need a space for the argument.
                            log::warn!(
                                "Context inference returned a non-blank line that contained no spaces (meaning no argument for the slash command): {:?}",
                                line
                            );
                        }
                    }

                    anyhow::Ok(())
                })
                .collect::<Vec<_>>();

            let _ = futures::future::try_join_all(futures).await.log_err();

            let duration = all_start.elapsed();
            eprintln!("All futures completed in {:?}", duration);
        })
        .await;

    drop(tx); // Close the channel so that rx.collect() won't hang. This is safe because all futures have completed.
    let results = rx.collect::<Vec<_>>().await;
    eprintln!(
        "Finished collecting from the channel with {} results",
        results.len()
    );
    for command in results {
        // Don't return empty or duplicate commands
        if !command.name.is_empty()
            && !final_response
                .iter()
                .any(|cmd: &CommandToRun| cmd.name == command.name && cmd.arg == command.arg)
        {
            if SUPPORTED_SLASH_COMMANDS
                .iter()
                .any(|supported| &command.name == supported)
            {
                final_response.push(command);
            } else {
                log::warn!(
                    "Context inference returned an unrecognized slash command: {:?}",
                    command
                );
            }
        }
    }

    // Sort the commands by name (reversed just so that /search appears before /file)
    final_response.sort_by(|cmd1, cmd2| cmd1.name.cmp(&cmd2.name).reverse());

    Ok(final_response)
}