terminal_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use futures::io::BufReader;
  5use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
  6use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use std::future;
 12use util::get_system_shell;
 13
 14use std::path::Path;
 15use std::sync::Arc;
 16use ui::IconName;
 17use util::command::new_smol_command;
 18use util::markdown::MarkdownInlineCode;
 19
// Arguments of a `terminal` tool call, deserialized from the tool-call JSON in `run` and
// `ui_text`. The `JsonSchema` derive feeds `input_schema`, and schemars uses `///` doc
// comments as schema descriptions, so the field docs below are model-facing text. Extra
// notes are kept as plain `//` comments so they do not change that schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct TerminalToolInput {
    /// The one-liner command to execute.
    command: String,
    // `run` also accepts "." (only when the project has exactly one worktree) and absolute
    // paths that lie inside one of the project's worktrees.
    /// Working directory for the command. This must be one of the root directories of the project.
    cd: String,
}
 27
/// Agent tool that runs a shell command in one of the project's root directories and
/// returns its combined, size-limited stdout/stderr output together with the exit status.
pub struct TerminalTool;
 29
 30impl Tool for TerminalTool {
 31    fn name(&self) -> String {
 32        "terminal".to_string()
 33    }
 34
 35    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 36        true
 37    }
 38
 39    fn description(&self) -> String {
 40        include_str!("./terminal_tool/description.md").to_string()
 41    }
 42
 43    fn icon(&self) -> IconName {
 44        IconName::Terminal
 45    }
 46
 47    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 48        json_schema_for::<TerminalToolInput>(format)
 49    }
 50
 51    fn ui_text(&self, input: &serde_json::Value) -> String {
 52        match serde_json::from_value::<TerminalToolInput>(input.clone()) {
 53            Ok(input) => {
 54                let mut lines = input.command.lines();
 55                let first_line = lines.next().unwrap_or_default();
 56                let remaining_line_count = lines.count();
 57                match remaining_line_count {
 58                    0 => MarkdownInlineCode(&first_line).to_string(),
 59                    1 => MarkdownInlineCode(&format!(
 60                        "{} - {} more line",
 61                        first_line, remaining_line_count
 62                    ))
 63                    .to_string(),
 64                    n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n))
 65                        .to_string(),
 66                }
 67            }
 68            Err(_) => "Run terminal command".to_string(),
 69        }
 70    }
 71
 72    fn run(
 73        self: Arc<Self>,
 74        input: serde_json::Value,
 75        _messages: &[LanguageModelRequestMessage],
 76        project: Entity<Project>,
 77        _action_log: Entity<ActionLog>,
 78        _window: Option<AnyWindowHandle>,
 79        cx: &mut App,
 80    ) -> ToolResult {
 81        let input: TerminalToolInput = match serde_json::from_value(input) {
 82            Ok(input) => input,
 83            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 84        };
 85
 86        let project = project.read(cx);
 87        let input_path = Path::new(&input.cd);
 88        let working_dir = if input.cd == "." {
 89            // Accept "." as meaning "the one worktree" if we only have one worktree.
 90            let mut worktrees = project.worktrees(cx);
 91
 92            let only_worktree = match worktrees.next() {
 93                Some(worktree) => worktree,
 94                None => {
 95                    return Task::ready(Err(anyhow!("No worktrees found in the project"))).into();
 96                }
 97            };
 98
 99            if worktrees.next().is_some() {
100                return Task::ready(Err(anyhow!(
101                    "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
102                ))).into();
103            }
104
105            only_worktree.read(cx).abs_path()
106        } else if input_path.is_absolute() {
107            // Absolute paths are allowed, but only if they're in one of the project's worktrees.
108            if !project
109                .worktrees(cx)
110                .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
111            {
112                return Task::ready(Err(anyhow!(
113                    "The absolute path must be within one of the project's worktrees"
114                )))
115                .into();
116            }
117
118            input_path.into()
119        } else {
120            let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
121                return Task::ready(Err(anyhow!(
122                    "`cd` directory {} not found in the project",
123                    &input.cd
124                )))
125                .into();
126            };
127
128            worktree.read(cx).abs_path()
129        };
130
131        cx.background_spawn(run_command_limited(working_dir, input.command))
132            .into()
133    }
134}
135
/// Maximum number of bytes of command output returned to the model (16 KiB).
const LIMIT: usize = 16_384;
137
138async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
139    let shell = get_system_shell();
140
141    let mut cmd = new_smol_command(&shell)
142        .arg("-c")
143        .arg(&command)
144        .current_dir(working_dir)
145        .stdout(std::process::Stdio::piped())
146        .stderr(std::process::Stdio::piped())
147        .spawn()
148        .context("Failed to execute terminal command")?;
149
150    let mut combined_buffer = String::with_capacity(LIMIT + 1);
151
152    let mut out_reader = BufReader::new(cmd.stdout.take().context("Failed to get stdout")?);
153    let mut out_tmp_buffer = String::with_capacity(512);
154    let mut err_reader = BufReader::new(cmd.stderr.take().context("Failed to get stderr")?);
155    let mut err_tmp_buffer = String::with_capacity(512);
156
157    let mut out_line = Box::pin(
158        out_reader
159            .read_line(&mut out_tmp_buffer)
160            .left_future()
161            .fuse(),
162    );
163    let mut err_line = Box::pin(
164        err_reader
165            .read_line(&mut err_tmp_buffer)
166            .left_future()
167            .fuse(),
168    );
169
170    let mut has_stdout = true;
171    let mut has_stderr = true;
172    while (has_stdout || has_stderr) && combined_buffer.len() < LIMIT + 1 {
173        futures::select_biased! {
174            read = out_line => {
175                drop(out_line);
176                combined_buffer.extend(out_tmp_buffer.drain(..));
177                if read? == 0 {
178                    out_line = Box::pin(future::pending().right_future().fuse());
179                    has_stdout = false;
180                } else {
181                    out_line = Box::pin(out_reader.read_line(&mut out_tmp_buffer).left_future().fuse());
182                }
183            }
184            read = err_line => {
185                drop(err_line);
186                combined_buffer.extend(err_tmp_buffer.drain(..));
187                if read? == 0 {
188                    err_line = Box::pin(future::pending().right_future().fuse());
189                    has_stderr = false;
190                } else {
191                    err_line = Box::pin(err_reader.read_line(&mut err_tmp_buffer).left_future().fuse());
192                }
193            }
194        };
195    }
196
197    drop((out_line, err_line));
198
199    let truncated = combined_buffer.len() > LIMIT;
200    combined_buffer.truncate(LIMIT);
201
202    consume_reader(out_reader, truncated).await?;
203    consume_reader(err_reader, truncated).await?;
204
205    let status = cmd.status().await.context("Failed to get command status")?;
206
207    let output_string = if truncated {
208        // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
209        // multi-byte characters.
210        let last_line_ix = combined_buffer.bytes().rposition(|b| b == b'\n');
211        let combined_buffer = &combined_buffer[..last_line_ix.unwrap_or(combined_buffer.len())];
212
213        format!(
214            "Command output too long. The first {} bytes:\n\n{}",
215            combined_buffer.len(),
216            output_block(&combined_buffer),
217        )
218    } else {
219        output_block(&combined_buffer)
220    };
221
222    let output_with_status = if status.success() {
223        if output_string.is_empty() {
224            "Command executed successfully.".to_string()
225        } else {
226            output_string.to_string()
227        }
228    } else {
229        format!(
230            "Command failed with exit code {} (shell: {}).\n\n{}",
231            status.code().unwrap_or(-1),
232            shell,
233            output_string,
234        )
235    };
236
237    Ok(output_with_status)
238}
239
240async fn consume_reader<T: AsyncReadExt + Unpin>(
241    mut reader: BufReader<T>,
242    truncated: bool,
243) -> Result<(), std::io::Error> {
244    loop {
245        let skipped_bytes = reader.fill_buf().await?;
246        if skipped_bytes.is_empty() {
247            break;
248        }
249        let skipped_bytes_len = skipped_bytes.len();
250        reader.consume_unpin(skipped_bytes_len);
251
252        // Should only skip if we went over the limit
253        debug_assert!(truncated);
254    }
255    Ok(())
256}
257
/// Wraps `output` in a Markdown code fence, adding a trailing newline before the closing
/// fence only if `output` does not already end with one.
fn output_block(output: &str) -> String {
    let mut block = String::with_capacity(output.len() + 8);
    block.push_str("```\n");
    block.push_str(output);
    if !output.ends_with('\n') {
        block.push('\n');
    }
    block.push_str("```");
    block
}
265
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Runs `command` through `run_command_limited` with the current directory as `cd`.
    async fn run_here(command: impl Into<String>) -> Result<String> {
        run_command_limited(Path::new(".").into(), command.into()).await
    }

    /// Byte length of the text between the first opening fence and the last closing fence.
    fn fenced_len(output: &str) -> usize {
        let start = output.find("```\n").map_or(0, |ix| ix + 4);
        let end = output.rfind("\n```").unwrap_or(output.len());
        end - start
    }

    #[gpui::test(iterations = 10)]
    async fn test_run_command_simple(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let result = run_here("echo 'Hello, World!'").await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "```\nHello, World!\n```");
    }

    #[gpui::test(iterations = 10)]
    async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let script = "echo 'stdout 1' && sleep 0.01 && echo 'stderr 1' >&2 && sleep 0.01 && echo 'stdout 2' && sleep 0.01 && echo 'stderr 2' >&2";
        let result = run_here(script).await;

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_multiple_output_reads(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Output arrives in several chunks, forcing more than one read.
        let result = run_here("echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'").await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
    }

    #[gpui::test(iterations = 10)]
    async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let script = format!("echo '{}'; sleep 0.01;", "X".repeat(LIMIT * 2));
        let result = run_here(script).await;

        assert!(result.is_ok());
        let text = result.unwrap();

        // With no newline to cut back to, the content is exactly the limit.
        assert_eq!(fenced_len(&text), LIMIT);
    }

    #[gpui::test(iterations = 10)]
    async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let script = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
        let result = run_here(script).await;

        assert!(result.is_ok());
        let text = result.unwrap();

        assert!(text.starts_with("Command output too long. The first 16334 bytes:\n\n"));
        assert!(fenced_len(&text) <= LIMIT);
    }

    #[gpui::test(iterations = 10)]
    async fn test_command_failure(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let result = run_here("exit 42").await;

        assert!(result.is_ok());
        let text = result.unwrap();

        // The failure message names the shell that ran the command.
        let shell = std::env::var("SHELL").unwrap_or("bash".to_string());
        let expected = format!(
            "Command failed with exit code 42 (shell: {}).\n\n```\n\n```",
            shell
        );
        assert_eq!(text, expected);
    }
}