auto_command.rs

use anyhow::{anyhow, Result};
use assistant_slash_command::{
    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
    SlashCommandResult,
};
use feature_flags::FeatureFlag;
use futures::StreamExt;
use gpui::{AppContext, AsyncAppContext, Task, WeakView};
use language::{CodeLabel, LspAdapterDelegate};
use language_model::{
    LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, Role,
};
use semantic_index::{FileSummary, SemanticDb};
use smol::channel;
use std::sync::{atomic::AtomicBool, Arc};
use ui::{BorrowAppContext, WindowContext};
use util::ResultExt;
use workspace::Workspace;

use crate::slash_command::create_label_for_command;

pub struct AutoSlashCommandFeatureFlag;

impl FeatureFlag for AutoSlashCommandFeatureFlag {
    const NAME: &'static str = "auto-slash-command";
}

pub(crate) struct AutoCommand;

impl SlashCommand for AutoCommand {
    fn name(&self) -> String {
        "auto".into()
    }

    fn description(&self) -> String {
        "Automatically infer what context to add".into()
    }

    fn menu_text(&self) -> String {
        self.description()
    }

    fn label(&self, cx: &AppContext) -> CodeLabel {
        create_label_for_command("auto", &["--prompt"], cx)
    }

    fn complete_argument(
        self: Arc<Self>,
        _arguments: &[String],
        _cancel: Arc<AtomicBool>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>> {
        // There's no autocomplete for a prompt, since it's arbitrary text.
        // However, we can use this opportunity to kick off a drain of the backlog,
        // so that resummarizing will hopefully be finished by the time we've typed
        // out our prompt. This re-runs on every keystroke during autocomplete,
        // but in the future we could instead do it only once, when /auto is first entered.
        let Some(workspace) = workspace.and_then(|ws| ws.upgrade()) else {
            log::warn!("workspace was dropped or unavailable during /auto autocomplete");

            return Task::ready(Ok(Vec::new()));
        };

        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("No project indexer, cannot use /auto")));
        };

        let cx: &mut AppContext = cx;

        cx.spawn(|cx: gpui::AsyncAppContext| async move {
            let task = project_index.read_with(&cx, |project_index, cx| {
                project_index.flush_summary_backlogs(cx)
            })?;

            cx.background_executor().spawn(task).await;

            anyhow::Ok(Vec::new())
        })
    }

    fn requires_argument(&self) -> bool {
        true
    }

    fn run(
        self: Arc<Self>,
        arguments: &[String],
        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        _context_buffer: language::BufferSnapshot,
        workspace: WeakView<Workspace>,
        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<SlashCommandResult> {
        let Some(workspace) = workspace.upgrade() else {
            return Task::ready(Err(anyhow!("workspace was dropped")));
        };
        if arguments.is_empty() {
            return Task::ready(Err(anyhow!("missing prompt")));
        }
        let argument = arguments.join(" ");
        let original_prompt = argument.to_string();
        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("no project indexer")));
        };

        let task = cx.spawn(|cx: gpui::AsyncWindowContext| async move {
            let summaries = project_index
                .read_with(&cx, |project_index, cx| project_index.all_summaries(cx))?
                .await?;

            commands_for_summaries(&summaries, &original_prompt, &cx).await
        });

        // As a convenience, append /auto's argument to the end of the prompt
        // so you don't have to write it again.
        let original_prompt = argument.to_string();

        cx.background_executor().spawn(async move {
            let commands = task.await?;
            let mut prompt = String::new();

            log::info!(
                "Translating this response into slash-commands: {:?}",
                commands
            );

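            // Each inferred command becomes one line of the output, e.g. (illustrative):
            //
            //     /search <some query>
            //     /file <some path>
            //
            // followed by a blank line and then the original /auto prompt.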
            for command in commands {
                prompt.push('/');
                prompt.push_str(&command.name);
                prompt.push(' ');
                prompt.push_str(&command.arg);
                prompt.push('\n');
            }

            prompt.push('\n');
            prompt.push_str(&original_prompt);

            Ok(SlashCommandOutput {
                text: prompt,
                sections: Vec::new(),
                run_commands_in_text: true,
            }
            .to_event_stream())
        })
    }
}

const PROMPT_INSTRUCTIONS_BEFORE_SUMMARY: &str = include_str!("prompt_before_summary.txt");
const PROMPT_INSTRUCTIONS_AFTER_SUMMARY: &str = include_str!("prompt_after_summary.txt");

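/// Build the full context-inference prompt: the fixed instructions, then the
/// JSON-serialized file summaries, then the closing instructions, then the user's original prompt.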
fn summaries_prompt(summaries: &[FileSummary], original_prompt: &str) -> String {
    let json_summaries = serde_json::to_string(summaries).unwrap();

    format!("{PROMPT_INSTRUCTIONS_BEFORE_SUMMARY}\n{json_summaries}\n{PROMPT_INSTRUCTIONS_AFTER_SUMMARY}\n{original_prompt}")
}

/// The slash commands that the model is told about, and which we look for in the inference response.
const SUPPORTED_SLASH_COMMANDS: &[&str] = &["search", "file"];

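/// A slash command to run, parsed from one line of the model's response.
/// For example (illustrative), a response line like `file src/main.rs` becomes
/// `CommandToRun { name: "file", arg: "src/main.rs" }`, which is later rendered as `/file src/main.rs`.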
#[derive(Debug, Clone)]
struct CommandToRun {
    name: String,
    arg: String,
}

/// Given the pre-indexed file summaries for this project, as well as the original prompt
/// string passed to `/auto`, get a list of slash commands to run, along with their arguments.
///
/// The prompt's output does not include the slashes (to reduce the chance that it makes a mistake),
/// so taking one of these returned Strings and turning it into a real slash-command-with-argument
/// involves prepending a slash to it.
///
/// This function will validate that each of the returned lines begins with one of SUPPORTED_SLASH_COMMANDS.
/// Any other lines it encounters will be discarded, with a warning logged.
async fn commands_for_summaries(
    summaries: &[FileSummary],
    original_prompt: &str,
    cx: &AsyncAppContext,
) -> Result<Vec<CommandToRun>> {
    if summaries.is_empty() {
        log::warn!("Inferring no context because there were no summaries available.");
        return Ok(Vec::new());
    }

    // Use the globally configured model to translate the summaries into slash-commands,
    // because Qwen2-7B-Instruct has not done a good job at that task.
    let Some(model) = cx.update(|cx| LanguageModelRegistry::read_global(cx).active_model())? else {
        log::warn!("Can't infer context because there's no active model.");
        return Ok(Vec::new());
    };
    // Only go up to 90% of the actual max token count, to reduce chances of
    // exceeding the token count due to inaccuracies in the token counting heuristic.
    let max_token_count = (model.max_token_count() * 9) / 10;

    // Rather than recursing (which would require this async function to use a pinned box),
    // we use an explicit stack of arguments and answers for when we need to "recurse."
    let mut stack = vec![summaries];
    let mut final_response = Vec::new();
    let mut prompts = Vec::new();

    // TODO We only need to create multiple Requests because we currently
    // don't have the ability to tell if a CompletionProvider::complete response
    // was a "too many tokens in this request" error. If we had that, then
    // we could try the request once, instead of having to make separate requests
    // to check the token count and then afterwards to run the actual prompt.
    let make_request = |prompt: String| LanguageModelRequest {
        messages: vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            // Nothing in here will benefit from caching
            cache: false,
        }],
        tools: Vec::new(),
        stop: Vec::new(),
        temperature: None,
    };

    while let Some(current_summaries) = stack.pop() {
        // The split can result in one slice being empty and the other having one element.
        // Whenever that happens, skip the empty one.
        if current_summaries.is_empty() {
            continue;
        }

        log::info!(
            "Inferring prompt context using {} file summaries",
            current_summaries.len()
        );

        let prompt = summaries_prompt(current_summaries, original_prompt);
        let start = std::time::Instant::now();
        // Per OpenAI, 1 token ~= 4 chars in English (we go with 4.5 to overestimate a bit, because failed API requests cost a lot of perf)
        // Verifying this against an actual model.count_tokens() confirms that it's usually within ~5% of the correct answer, whereas
        // getting the correct answer from tiktoken takes hundreds of milliseconds (compared to this arithmetic being ~free).
        // source: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
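        // For example (illustrative): a 90,000-character prompt is estimated at 90,000 * 2 / 9 = 20,000 tokens.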
        let token_estimate = prompt.len() * 2 / 9;
        let duration = start.elapsed();
        log::info!(
            "Time taken to count tokens for prompt of length {:?}B: {:?}",
            prompt.len(),
            duration
        );

        if token_estimate < max_token_count {
            prompts.push(prompt);
        } else if current_summaries.len() == 1 {
            log::warn!("Inferring context for a single file's summary failed because the prompt's token length exceeded the model's token limit.");
        } else {
            log::info!(
                "Context inference using file summaries resulted in a prompt containing {token_estimate} tokens, which exceeded the model's max of {max_token_count}. Retrying as two separate prompts, each including half the number of summaries.",
            );
            let (left, right) = current_summaries.split_at(current_summaries.len() / 2);
            stack.push(right);
            stack.push(left);
        }
    }

    let all_start = std::time::Instant::now();

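    // Bounded channel used to fan in the parsed commands from the parallel completion streams.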
    let (tx, rx) = channel::bounded(1024);

    let completion_streams = prompts
        .into_iter()
        .map(|prompt| {
            let request = make_request(prompt);
            let model = model.clone();
            let tx = tx.clone();
            let stream = model.stream_completion(request, &cx);

            (stream, tx)
        })
        .collect::<Vec<_>>();

    cx.background_executor()
        .spawn(async move {
            let futures = completion_streams
                .into_iter()
                .enumerate()
                .map(|(ix, (stream, tx))| async move {
                    let start = std::time::Instant::now();
                    let events = stream.await?;
                    log::info!("Time taken for awaiting /auto chunk stream #{ix}: {:?}", start.elapsed());

                    let completion: String = events
                        .filter_map(|event| async {
                            if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
                                Some(text)
                            } else {
                                None
                            }
                        })
                        .collect()
                        .await;

                    log::info!("Time taken for all /auto chunks to come back for #{ix}: {:?}", start.elapsed());

                    for line in completion.split('\n') {
                        if let Some(first_space) = line.find(' ') {
                            let command = line[..first_space].trim();
                            let arg = line[first_space..].trim();

                            tx.send(CommandToRun {
                                name: command.to_string(),
                                arg: arg.to_string(),
                            })
                            .await?;
                        } else if !line.trim().is_empty() {
                            // All slash-commands currently supported in context inference need a space for the argument.
                            log::warn!(
                                "Context inference returned a non-blank line that contained no spaces (meaning no argument for the slash command): {:?}",
                                line
                            );
                        }
                    }

                    anyhow::Ok(())
                })
                .collect::<Vec<_>>();

            let _ = futures::future::try_join_all(futures).await.log_err();

            let duration = all_start.elapsed();
            eprintln!("All futures completed in {:?}", duration);
        })
        .await;

    drop(tx); // Close the channel so that rx.collect() won't hang. This is safe because all futures have completed.
    let results = rx.collect::<Vec<_>>().await;
    eprintln!(
        "Finished collecting from the channel with {} results",
        results.len()
    );
    for command in results {
        // Don't return empty or duplicate commands
        if !command.name.is_empty()
            && !final_response
                .iter()
                .any(|cmd: &CommandToRun| cmd.name == command.name && cmd.arg == command.arg)
        {
            if SUPPORTED_SLASH_COMMANDS
                .iter()
                .any(|supported| &command.name == supported)
            {
                final_response.push(command);
            } else {
                log::warn!(
                    "Context inference returned an unrecognized slash command: {:?}",
                    command
                );
            }
        }
    }

    // Sort the commands by name (reversed just so that /search appears before /file)
    final_response.sort_by(|cmd1, cmd2| cmd1.name.cmp(&cmd2.name).reverse());

    Ok(final_response)
}