1use crate::schema::json_schema_for;
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use futures::io::BufReader;
5use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
6use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
8use project::Project;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use std::future;
12use util::get_system_shell;
13
14use std::path::Path;
15use std::sync::Arc;
16use ui::IconName;
17use util::command::new_smol_command;
18use util::markdown::MarkdownString;
19
// Arguments accepted by the terminal tool. Values of this type are obtained by
// deserializing the tool's JSON `input` with `serde_json::from_value`.
// NOTE(review): the `///` comments on the fields are presumably picked up by the
// `JsonSchema` derive as field descriptions — confirm before rewording them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct TerminalToolInput {
    /// The one-liner command to execute.
    command: String,
    /// Working directory for the command. This must be one of the root directories of the project.
    cd: String,
}
27
28pub struct TerminalTool;
29
30impl Tool for TerminalTool {
31 fn name(&self) -> String {
32 "terminal".to_string()
33 }
34
35 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
36 true
37 }
38
39 fn description(&self) -> String {
40 include_str!("./terminal_tool/description.md").to_string()
41 }
42
43 fn icon(&self) -> IconName {
44 IconName::Terminal
45 }
46
47 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
48 json_schema_for::<TerminalToolInput>(format)
49 }
50
51 fn ui_text(&self, input: &serde_json::Value) -> String {
52 match serde_json::from_value::<TerminalToolInput>(input.clone()) {
53 Ok(input) => {
54 let mut lines = input.command.lines();
55 let first_line = lines.next().unwrap_or_default();
56 let remaining_line_count = lines.count();
57 match remaining_line_count {
58 0 => MarkdownString::inline_code(&first_line).0,
59 1 => {
60 MarkdownString::inline_code(&format!(
61 "{} - {} more line",
62 first_line, remaining_line_count
63 ))
64 .0
65 }
66 n => {
67 MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
68 }
69 }
70 }
71 Err(_) => "Run terminal command".to_string(),
72 }
73 }
74
75 fn run(
76 self: Arc<Self>,
77 input: serde_json::Value,
78 _messages: &[LanguageModelRequestMessage],
79 project: Entity<Project>,
80 _action_log: Entity<ActionLog>,
81 _window: Option<AnyWindowHandle>,
82 cx: &mut App,
83 ) -> ToolResult {
84 let input: TerminalToolInput = match serde_json::from_value(input) {
85 Ok(input) => input,
86 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
87 };
88
89 let project = project.read(cx);
90 let input_path = Path::new(&input.cd);
91 let working_dir = if input.cd == "." {
92 // Accept "." as meaning "the one worktree" if we only have one worktree.
93 let mut worktrees = project.worktrees(cx);
94
95 let only_worktree = match worktrees.next() {
96 Some(worktree) => worktree,
97 None => {
98 return Task::ready(Err(anyhow!("No worktrees found in the project"))).into();
99 }
100 };
101
102 if worktrees.next().is_some() {
103 return Task::ready(Err(anyhow!(
104 "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
105 ))).into();
106 }
107
108 only_worktree.read(cx).abs_path()
109 } else if input_path.is_absolute() {
110 // Absolute paths are allowed, but only if they're in one of the project's worktrees.
111 if !project
112 .worktrees(cx)
113 .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
114 {
115 return Task::ready(Err(anyhow!(
116 "The absolute path must be within one of the project's worktrees"
117 )))
118 .into();
119 }
120
121 input_path.into()
122 } else {
123 let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
124 return Task::ready(Err(anyhow!(
125 "`cd` directory {} not found in the project",
126 &input.cd
127 )))
128 .into();
129 };
130
131 worktree.read(cx).abs_path()
132 };
133
134 cx.background_spawn(run_command_limited(working_dir, input.command))
135 .into()
136 }
137}
138
/// Maximum number of bytes of combined stdout/stderr output that `run_command_limited`
/// keeps; output beyond this is cut off and the result is reported as truncated.
const LIMIT: usize = 16 * 1024;
140
141async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
142 let shell = get_system_shell();
143
144 let mut cmd = new_smol_command(&shell)
145 .arg("-c")
146 .arg(&command)
147 .current_dir(working_dir)
148 .stdout(std::process::Stdio::piped())
149 .stderr(std::process::Stdio::piped())
150 .spawn()
151 .context("Failed to execute terminal command")?;
152
153 let mut combined_buffer = String::with_capacity(LIMIT + 1);
154
155 let mut out_reader = BufReader::new(cmd.stdout.take().context("Failed to get stdout")?);
156 let mut out_tmp_buffer = String::with_capacity(512);
157 let mut err_reader = BufReader::new(cmd.stderr.take().context("Failed to get stderr")?);
158 let mut err_tmp_buffer = String::with_capacity(512);
159
160 let mut out_line = Box::pin(
161 out_reader
162 .read_line(&mut out_tmp_buffer)
163 .left_future()
164 .fuse(),
165 );
166 let mut err_line = Box::pin(
167 err_reader
168 .read_line(&mut err_tmp_buffer)
169 .left_future()
170 .fuse(),
171 );
172
173 let mut has_stdout = true;
174 let mut has_stderr = true;
175 while (has_stdout || has_stderr) && combined_buffer.len() < LIMIT + 1 {
176 futures::select_biased! {
177 read = out_line => {
178 drop(out_line);
179 combined_buffer.extend(out_tmp_buffer.drain(..));
180 if read? == 0 {
181 out_line = Box::pin(future::pending().right_future().fuse());
182 has_stdout = false;
183 } else {
184 out_line = Box::pin(out_reader.read_line(&mut out_tmp_buffer).left_future().fuse());
185 }
186 }
187 read = err_line => {
188 drop(err_line);
189 combined_buffer.extend(err_tmp_buffer.drain(..));
190 if read? == 0 {
191 err_line = Box::pin(future::pending().right_future().fuse());
192 has_stderr = false;
193 } else {
194 err_line = Box::pin(err_reader.read_line(&mut err_tmp_buffer).left_future().fuse());
195 }
196 }
197 };
198 }
199
200 drop((out_line, err_line));
201
202 let truncated = combined_buffer.len() > LIMIT;
203 combined_buffer.truncate(LIMIT);
204
205 consume_reader(out_reader, truncated).await?;
206 consume_reader(err_reader, truncated).await?;
207
208 // Handle potential errors during status retrieval, including interruption.
209 match cmd.status().await {
210 Ok(status) => {
211 let output_string = if truncated {
212 // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
213 // multi-byte characters.
214 let last_line_ix = combined_buffer.bytes().rposition(|b| b == b'\n');
215 let buffer_content =
216 &combined_buffer[..last_line_ix.unwrap_or(combined_buffer.len())];
217
218 format!(
219 "Command output too long. The first {} bytes:\n\n{}",
220 buffer_content.len(),
221 output_block(buffer_content),
222 )
223 } else {
224 output_block(&combined_buffer)
225 };
226
227 let output_with_status = if status.success() {
228 if output_string.is_empty() {
229 "Command executed successfully.".to_string()
230 } else {
231 output_string
232 }
233 } else {
234 format!(
235 "Command failed with exit code {} (shell: {}).\n\n{}",
236 status.code().unwrap_or(-1),
237 shell,
238 output_string,
239 )
240 };
241
242 Ok(output_with_status)
243 }
244 Err(err) => {
245 // Error occurred getting status (potential interruption). Include partial output.
246 let partial_output = output_block(&combined_buffer);
247 let error_message = format!(
248 "Command failed or was interrupted.\nPartial output captured:\n\n{}",
249 partial_output
250 );
251 Err(anyhow!(err).context(error_message))
252 }
253 }
254}
255
256async fn consume_reader<T: AsyncReadExt + Unpin>(
257 mut reader: BufReader<T>,
258 truncated: bool,
259) -> Result<(), std::io::Error> {
260 loop {
261 let skipped_bytes = reader.fill_buf().await?;
262 if skipped_bytes.is_empty() {
263 break;
264 }
265 let skipped_bytes_len = skipped_bytes.len();
266 reader.consume_unpin(skipped_bytes_len);
267
268 // Should only skip if we went over the limit
269 debug_assert!(truncated);
270 }
271 Ok(())
272}
273
/// Wraps `output` in a Markdown code fence.
///
/// The closing fence always starts on its own line: a newline is appended after
/// `output` unless it already ends with one.
fn output_block(output: &str) -> String {
    const FENCE: &str = "```";

    let mut block = String::with_capacity(output.len() + 2 * FENCE.len() + 2);
    block.push_str(FENCE);
    block.push('\n');
    block.push_str(output);
    if !output.ends_with('\n') {
        block.push('\n');
    }
    block.push_str(FENCE);
    block
}
281
// Tests for `run_command_limited`. They spawn real shell commands, so they are only
// built on non-Windows targets.
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// A single `echo` comes back wrapped in a code fence.
    #[gpui::test(iterations = 10)]
    async fn test_run_command_simple(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let result =
            run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "```\nHello, World!\n```");
    }

    /// Lines written alternately to stdout and stderr are kept in the order written.
    #[gpui::test(iterations = 10)]
    async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let command = "echo 'stdout 1' && sleep 0.01 && echo 'stderr 1' >&2 && sleep 0.01 && echo 'stdout 2' && sleep 0.01 && echo 'stderr 2' >&2";
        let result = run_command_limited(Path::new(".").into(), command.to_string()).await;

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
        );
    }

    /// Output that arrives over several reads is concatenated.
    #[gpui::test(iterations = 10)]
    async fn test_multiple_output_reads(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Command with multiple outputs that might require multiple reads
        let result = run_command_limited(
            Path::new(".").into(),
            "echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(),
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
    }

    /// A single line longer than `LIMIT` is cut to exactly `LIMIT` bytes.
    #[gpui::test(iterations = 10)]
    async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let cmd = format!("echo '{}'; sleep 0.01;", "X".repeat(LIMIT * 2));

        let result = run_command_limited(Path::new(".").into(), cmd).await;

        assert!(result.is_ok());
        let output = result.unwrap();

        let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
        let content_end = output.rfind("\n```").unwrap_or(output.len());
        let content_length = content_end - content_start;

        // Output should be exactly the limit
        assert_eq!(content_length, LIMIT);
    }

    /// Multi-line output over `LIMIT` is cut back to the last complete line.
    #[gpui::test(iterations = 10)]
    async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
        let result = run_command_limited(Path::new(".").into(), cmd).await;

        assert!(result.is_ok());
        let output = result.unwrap();

        // 135 complete 121-byte lines fit in `LIMIT`; the final newline is not counted.
        assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));

        let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
        let content_end = output.rfind("\n```").unwrap_or(output.len());
        let content_length = content_end - content_start;

        assert!(content_length <= LIMIT);
    }

    /// A non-zero exit status is reported in the returned text, not as an `Err`.
    #[gpui::test(iterations = 10)]
    async fn test_command_failure(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let result = run_command_limited(Path::new(".").into(), "exit 42".to_string()).await;

        assert!(result.is_ok());
        let output = result.unwrap();

        // The failure message names the shell returned by `get_system_shell`, so ask the
        // same helper instead of re-deriving it from `$SHELL` with a separate fallback.
        let shell_path = get_system_shell();

        let expected_output = format!(
            "Command failed with exit code 42 (shell: {}).\n\n```\n\n```",
            shell_path
        );
        assert_eq!(output, expected_output);
    }
}