terminal_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use futures::io::BufReader;
  5use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
  6use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use std::future;
 12use util::get_system_shell;
 13
 14use std::path::Path;
 15use std::sync::Arc;
 16use ui::IconName;
 17use util::command::new_smol_command;
 18use util::markdown::MarkdownString;
 19
// Arguments the model passes to the terminal tool.
//
// NOTE: schemars' `JsonSchema` derive uses the `///` field comments below as the field
// descriptions of the schema returned by `TerminalTool::input_schema`, so editing them changes
// what the model is told about these arguments. Plain `//` comments (like this one) are not
// part of the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct TerminalToolInput {
    /// The one-liner command to execute.
    command: String,
    // Besides a worktree root name, `TerminalTool::run` also accepts "." (only when the project
    // has exactly one worktree) and absolute paths inside one of the project's worktrees.
    /// Working directory for the command. This must be one of the root directories of the project.
    cd: String,
}
 27
/// Assistant tool that runs a shell command in one of the project's worktrees and returns its
/// combined stdout/stderr, capped at `LIMIT` bytes.
///
/// `needs_confirmation` always returns `true` for this tool.
pub struct TerminalTool;
 29
 30impl Tool for TerminalTool {
 31    fn name(&self) -> String {
 32        "terminal".to_string()
 33    }
 34
 35    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 36        true
 37    }
 38
 39    fn description(&self) -> String {
 40        include_str!("./terminal_tool/description.md").to_string()
 41    }
 42
 43    fn icon(&self) -> IconName {
 44        IconName::Terminal
 45    }
 46
 47    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 48        json_schema_for::<TerminalToolInput>(format)
 49    }
 50
 51    fn ui_text(&self, input: &serde_json::Value) -> String {
 52        match serde_json::from_value::<TerminalToolInput>(input.clone()) {
 53            Ok(input) => {
 54                let mut lines = input.command.lines();
 55                let first_line = lines.next().unwrap_or_default();
 56                let remaining_line_count = lines.count();
 57                match remaining_line_count {
 58                    0 => MarkdownString::inline_code(&first_line).0,
 59                    1 => {
 60                        MarkdownString::inline_code(&format!(
 61                            "{} - {} more line",
 62                            first_line, remaining_line_count
 63                        ))
 64                        .0
 65                    }
 66                    n => {
 67                        MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
 68                    }
 69                }
 70            }
 71            Err(_) => "Run terminal command".to_string(),
 72        }
 73    }
 74
 75    fn run(
 76        self: Arc<Self>,
 77        input: serde_json::Value,
 78        _messages: &[LanguageModelRequestMessage],
 79        project: Entity<Project>,
 80        _action_log: Entity<ActionLog>,
 81        _window: Option<AnyWindowHandle>,
 82        cx: &mut App,
 83    ) -> ToolResult {
 84        let input: TerminalToolInput = match serde_json::from_value(input) {
 85            Ok(input) => input,
 86            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 87        };
 88
 89        let project = project.read(cx);
 90        let input_path = Path::new(&input.cd);
 91        let working_dir = if input.cd == "." {
 92            // Accept "." as meaning "the one worktree" if we only have one worktree.
 93            let mut worktrees = project.worktrees(cx);
 94
 95            let only_worktree = match worktrees.next() {
 96                Some(worktree) => worktree,
 97                None => {
 98                    return Task::ready(Err(anyhow!("No worktrees found in the project"))).into();
 99                }
100            };
101
102            if worktrees.next().is_some() {
103                return Task::ready(Err(anyhow!(
104                    "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
105                ))).into();
106            }
107
108            only_worktree.read(cx).abs_path()
109        } else if input_path.is_absolute() {
110            // Absolute paths are allowed, but only if they're in one of the project's worktrees.
111            if !project
112                .worktrees(cx)
113                .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
114            {
115                return Task::ready(Err(anyhow!(
116                    "The absolute path must be within one of the project's worktrees"
117                )))
118                .into();
119            }
120
121            input_path.into()
122        } else {
123            let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
124                return Task::ready(Err(anyhow!(
125                    "`cd` directory {} not found in the project",
126                    &input.cd
127                )))
128                .into();
129            };
130
131            worktree.read(cx).abs_path()
132        };
133
134        cx.background_spawn(run_command_limited(working_dir, input.command))
135            .into()
136    }
137}
138
/// Maximum number of bytes of combined stdout/stderr output kept for the model (16 KiB).
const LIMIT: usize = 16_384;
140
141async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
142    let shell = get_system_shell();
143
144    let mut cmd = new_smol_command(&shell)
145        .arg("-c")
146        .arg(&command)
147        .current_dir(working_dir)
148        .stdout(std::process::Stdio::piped())
149        .stderr(std::process::Stdio::piped())
150        .spawn()
151        .context("Failed to execute terminal command")?;
152
153    let mut combined_buffer = String::with_capacity(LIMIT + 1);
154
155    let mut out_reader = BufReader::new(cmd.stdout.take().context("Failed to get stdout")?);
156    let mut out_tmp_buffer = String::with_capacity(512);
157    let mut err_reader = BufReader::new(cmd.stderr.take().context("Failed to get stderr")?);
158    let mut err_tmp_buffer = String::with_capacity(512);
159
160    let mut out_line = Box::pin(
161        out_reader
162            .read_line(&mut out_tmp_buffer)
163            .left_future()
164            .fuse(),
165    );
166    let mut err_line = Box::pin(
167        err_reader
168            .read_line(&mut err_tmp_buffer)
169            .left_future()
170            .fuse(),
171    );
172
173    let mut has_stdout = true;
174    let mut has_stderr = true;
175    while (has_stdout || has_stderr) && combined_buffer.len() < LIMIT + 1 {
176        futures::select_biased! {
177            read = out_line => {
178                drop(out_line);
179                combined_buffer.extend(out_tmp_buffer.drain(..));
180                if read? == 0 {
181                    out_line = Box::pin(future::pending().right_future().fuse());
182                    has_stdout = false;
183                } else {
184                    out_line = Box::pin(out_reader.read_line(&mut out_tmp_buffer).left_future().fuse());
185                }
186            }
187            read = err_line => {
188                drop(err_line);
189                combined_buffer.extend(err_tmp_buffer.drain(..));
190                if read? == 0 {
191                    err_line = Box::pin(future::pending().right_future().fuse());
192                    has_stderr = false;
193                } else {
194                    err_line = Box::pin(err_reader.read_line(&mut err_tmp_buffer).left_future().fuse());
195                }
196            }
197        };
198    }
199
200    drop((out_line, err_line));
201
202    let truncated = combined_buffer.len() > LIMIT;
203    combined_buffer.truncate(LIMIT);
204
205    consume_reader(out_reader, truncated).await?;
206    consume_reader(err_reader, truncated).await?;
207
208    // Handle potential errors during status retrieval, including interruption.
209    match cmd.status().await {
210        Ok(status) => {
211            let output_string = if truncated {
212                // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
213                // multi-byte characters.
214                let last_line_ix = combined_buffer.bytes().rposition(|b| b == b'\n');
215                let buffer_content =
216                    &combined_buffer[..last_line_ix.unwrap_or(combined_buffer.len())];
217
218                format!(
219                    "Command output too long. The first {} bytes:\n\n{}",
220                    buffer_content.len(),
221                    output_block(buffer_content),
222                )
223            } else {
224                output_block(&combined_buffer)
225            };
226
227            let output_with_status = if status.success() {
228                if output_string.is_empty() {
229                    "Command executed successfully.".to_string()
230                } else {
231                    output_string
232                }
233            } else {
234                format!(
235                    "Command failed with exit code {} (shell: {}).\n\n{}",
236                    status.code().unwrap_or(-1),
237                    shell,
238                    output_string,
239                )
240            };
241
242            Ok(output_with_status)
243        }
244        Err(err) => {
245            // Error occurred getting status (potential interruption). Include partial output.
246            let partial_output = output_block(&combined_buffer);
247            let error_message = format!(
248                "Command failed or was interrupted.\nPartial output captured:\n\n{}",
249                partial_output
250            );
251            Err(anyhow!(err).context(error_message))
252        }
253    }
254}
255
256async fn consume_reader<T: AsyncReadExt + Unpin>(
257    mut reader: BufReader<T>,
258    truncated: bool,
259) -> Result<(), std::io::Error> {
260    loop {
261        let skipped_bytes = reader.fill_buf().await?;
262        if skipped_bytes.is_empty() {
263            break;
264        }
265        let skipped_bytes_len = skipped_bytes.len();
266        reader.consume_unpin(skipped_bytes_len);
267
268        // Should only skip if we went over the limit
269        debug_assert!(truncated);
270    }
271    Ok(())
272}
273
/// Wraps `output` in a Markdown code fence, inserting a newline before the closing fence
/// unless `output` already ends with one.
fn output_block(output: &str) -> String {
    let closing_newline = if output.ends_with('\n') { "" } else { "\n" };
    format!("```\n{output}{closing_newline}```")
}
281
// Tests for `run_command_limited`. They spawn real shell processes using POSIX shell syntax
// (`>&2`, `&&`, `sleep`), which is why the module is compiled only on non-Windows targets.
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    // NOTE(review): every test calls `allow_parking()`, presumably because it awaits real
    // process I/O rather than only simulated executor tasks — confirm against gpui's docs.

    #[gpui::test(iterations = 10)]
    async fn test_run_command_simple(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let result =
            run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "```\nHello, World!\n```");
    }

    #[gpui::test(iterations = 10)]
    async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // The short sleeps between writes make the order in which the two pipes become readable
        // match the order of the writes, so lines from both streams are expected in write order.
        let command = "echo 'stdout 1' && sleep 0.01 && echo 'stderr 1' >&2 && sleep 0.01 && echo 'stdout 2' && sleep 0.01 && echo 'stderr 2' >&2";
        let result = run_command_limited(Path::new(".").into(), command.to_string()).await;

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_multiple_output_reads(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Command with multiple outputs that might require multiple reads
        let result = run_command_limited(
            Path::new(".").into(),
            "echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(),
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
    }

    #[gpui::test(iterations = 10)]
    async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // A single line of `2 * LIMIT` bytes: it contains no newline before the cut, so the whole
        // first `LIMIT` bytes are reported.
        let cmd = format!("echo '{}'; sleep 0.01;", "X".repeat(LIMIT * 2));

        let result = run_command_limited(Path::new(".").into(), cmd).await;

        assert!(result.is_ok());
        let output = result.unwrap();

        let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
        let content_end = output.rfind("\n```").unwrap_or(output.len());
        let content_length = content_end - content_start;

        // Output should be exactly the limit
        assert_eq!(content_length, LIMIT);
    }

    #[gpui::test(iterations = 10)]
    async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // 160 lines of 121 bytes (120 'X' + '\n'). The first 16384 bytes hold 135 full lines
        // (16335 bytes), and the cut-back to the last newline drops that trailing '\n',
        // leaving 16334 bytes.
        let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
        let result = run_command_limited(Path::new(".").into(), cmd).await;

        assert!(result.is_ok());
        let output = result.unwrap();

        assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));

        let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
        let content_end = output.rfind("\n```").unwrap_or(output.len());
        let content_length = content_end - content_start;

        assert!(content_length <= LIMIT);
    }

    #[gpui::test(iterations = 10)]
    async fn test_command_failure(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let result = run_command_limited(Path::new(".").into(), "exit 42".to_string()).await;

        assert!(result.is_ok());
        let output = result.unwrap();

        // The failure message embeds the full shell path. This assumes `get_system_shell()`
        // resolves to `$SHELL` (or "bash" when unset) on the test machine — NOTE(review): confirm.
        let shell_path = std::env::var("SHELL").unwrap_or("bash".to_string());

        let expected_output = format!(
            "Command failed with exit code 42 (shell: {}).\n\n```\n\n```",
            shell_path
        );
        assert_eq!(output, expected_output);
    }
}