1use crate::schema::json_schema_for;
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use futures::io::BufReader;
5use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
6use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
8use project::Project;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use std::future;
12use util::get_system_shell;
13
14use std::path::Path;
15use std::sync::Arc;
16use ui::IconName;
17use util::command::new_smol_command;
18use util::markdown::MarkdownInlineCode;
19
// Arguments for the terminal tool. They are deserialized from the tool-call
// JSON in `TerminalTool::ui_text` and `TerminalTool::run`, and the schema that
// `input_schema` returns is generated from this struct.
// The `///` comments on the fields are turned into field descriptions by the
// `JsonSchema` derive, so editing them changes the generated schema text —
// treat them as runtime strings, not just documentation.
// NOTE(review): `run` also accepts `"."` (only when the project has exactly one
// worktree) and absolute paths inside a worktree for `cd`, which the field
// description below does not mention — confirm whether that should be documented.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct TerminalToolInput {
    /// The one-liner command to execute.
    command: String,
    /// Working directory for the command. This must be one of the root directories of the project.
    cd: String,
}
27
28pub struct TerminalTool;
29
30impl Tool for TerminalTool {
31 fn name(&self) -> String {
32 "terminal".to_string()
33 }
34
35 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
36 true
37 }
38
39 fn description(&self) -> String {
40 include_str!("./terminal_tool/description.md").to_string()
41 }
42
43 fn icon(&self) -> IconName {
44 IconName::Terminal
45 }
46
47 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
48 json_schema_for::<TerminalToolInput>(format)
49 }
50
51 fn ui_text(&self, input: &serde_json::Value) -> String {
52 match serde_json::from_value::<TerminalToolInput>(input.clone()) {
53 Ok(input) => {
54 let mut lines = input.command.lines();
55 let first_line = lines.next().unwrap_or_default();
56 let remaining_line_count = lines.count();
57 match remaining_line_count {
58 0 => MarkdownInlineCode(&first_line).to_string(),
59 1 => MarkdownInlineCode(&format!(
60 "{} - {} more line",
61 first_line, remaining_line_count
62 ))
63 .to_string(),
64 n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n))
65 .to_string(),
66 }
67 }
68 Err(_) => "Run terminal command".to_string(),
69 }
70 }
71
72 fn run(
73 self: Arc<Self>,
74 input: serde_json::Value,
75 _messages: &[LanguageModelRequestMessage],
76 project: Entity<Project>,
77 _action_log: Entity<ActionLog>,
78 _window: Option<AnyWindowHandle>,
79 cx: &mut App,
80 ) -> ToolResult {
81 let input: TerminalToolInput = match serde_json::from_value(input) {
82 Ok(input) => input,
83 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
84 };
85
86 let project = project.read(cx);
87 let input_path = Path::new(&input.cd);
88 let working_dir = if input.cd == "." {
89 // Accept "." as meaning "the one worktree" if we only have one worktree.
90 let mut worktrees = project.worktrees(cx);
91
92 let only_worktree = match worktrees.next() {
93 Some(worktree) => worktree,
94 None => {
95 return Task::ready(Err(anyhow!("No worktrees found in the project"))).into();
96 }
97 };
98
99 if worktrees.next().is_some() {
100 return Task::ready(Err(anyhow!(
101 "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
102 ))).into();
103 }
104
105 only_worktree.read(cx).abs_path()
106 } else if input_path.is_absolute() {
107 // Absolute paths are allowed, but only if they're in one of the project's worktrees.
108 if !project
109 .worktrees(cx)
110 .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
111 {
112 return Task::ready(Err(anyhow!(
113 "The absolute path must be within one of the project's worktrees"
114 )))
115 .into();
116 }
117
118 input_path.into()
119 } else {
120 let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
121 return Task::ready(Err(anyhow!(
122 "`cd` directory {} not found in the project",
123 &input.cd
124 )))
125 .into();
126 };
127
128 worktree.read(cx).abs_path()
129 };
130
131 cx.background_spawn(run_command_limited(working_dir, input.command))
132 .into()
133 }
134}
135
136const LIMIT: usize = 16 * 1024;
137
138async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
139 let shell = get_system_shell();
140
141 let mut cmd = new_smol_command(&shell)
142 .arg("-c")
143 .arg(&command)
144 .current_dir(working_dir)
145 .stdout(std::process::Stdio::piped())
146 .stderr(std::process::Stdio::piped())
147 .spawn()
148 .context("Failed to execute terminal command")?;
149
150 let mut combined_buffer = String::with_capacity(LIMIT + 1);
151
152 let mut out_reader = BufReader::new(cmd.stdout.take().context("Failed to get stdout")?);
153 let mut out_tmp_buffer = String::with_capacity(512);
154 let mut err_reader = BufReader::new(cmd.stderr.take().context("Failed to get stderr")?);
155 let mut err_tmp_buffer = String::with_capacity(512);
156
157 let mut out_line = Box::pin(
158 out_reader
159 .read_line(&mut out_tmp_buffer)
160 .left_future()
161 .fuse(),
162 );
163 let mut err_line = Box::pin(
164 err_reader
165 .read_line(&mut err_tmp_buffer)
166 .left_future()
167 .fuse(),
168 );
169
170 let mut has_stdout = true;
171 let mut has_stderr = true;
172 while (has_stdout || has_stderr) && combined_buffer.len() < LIMIT + 1 {
173 futures::select_biased! {
174 read = out_line => {
175 drop(out_line);
176 combined_buffer.extend(out_tmp_buffer.drain(..));
177 if read? == 0 {
178 out_line = Box::pin(future::pending().right_future().fuse());
179 has_stdout = false;
180 } else {
181 out_line = Box::pin(out_reader.read_line(&mut out_tmp_buffer).left_future().fuse());
182 }
183 }
184 read = err_line => {
185 drop(err_line);
186 combined_buffer.extend(err_tmp_buffer.drain(..));
187 if read? == 0 {
188 err_line = Box::pin(future::pending().right_future().fuse());
189 has_stderr = false;
190 } else {
191 err_line = Box::pin(err_reader.read_line(&mut err_tmp_buffer).left_future().fuse());
192 }
193 }
194 };
195 }
196
197 drop((out_line, err_line));
198
199 let truncated = combined_buffer.len() > LIMIT;
200 combined_buffer.truncate(LIMIT);
201
202 consume_reader(out_reader, truncated).await?;
203 consume_reader(err_reader, truncated).await?;
204
205 let status = cmd.status().await.context("Failed to get command status")?;
206
207 let output_string = if truncated {
208 // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
209 // multi-byte characters.
210 let last_line_ix = combined_buffer.bytes().rposition(|b| b == b'\n');
211 let combined_buffer = &combined_buffer[..last_line_ix.unwrap_or(combined_buffer.len())];
212
213 format!(
214 "Command output too long. The first {} bytes:\n\n{}",
215 combined_buffer.len(),
216 output_block(&combined_buffer),
217 )
218 } else {
219 output_block(&combined_buffer)
220 };
221
222 let output_with_status = if status.success() {
223 if output_string.is_empty() {
224 "Command executed successfully.".to_string()
225 } else {
226 output_string.to_string()
227 }
228 } else {
229 format!(
230 "Command failed with exit code {} (shell: {}).\n\n{}",
231 status.code().unwrap_or(-1),
232 shell,
233 output_string,
234 )
235 };
236
237 Ok(output_with_status)
238}
239
240async fn consume_reader<T: AsyncReadExt + Unpin>(
241 mut reader: BufReader<T>,
242 truncated: bool,
243) -> Result<(), std::io::Error> {
244 loop {
245 let skipped_bytes = reader.fill_buf().await?;
246 if skipped_bytes.is_empty() {
247 break;
248 }
249 let skipped_bytes_len = skipped_bytes.len();
250 reader.consume_unpin(skipped_bytes_len);
251
252 // Should only skip if we went over the limit
253 debug_assert!(truncated);
254 }
255 Ok(())
256}
257
/// Wraps `output` in a Markdown code fence.
///
/// A newline is inserted before the closing fence only when `output` does not
/// already end with one.
fn output_block(output: &str) -> String {
    let separator = if output.ends_with('\n') { "" } else { "\n" };
    format!("```\n{output}{separator}```")
}
265
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Runs `command` in the current directory and returns the tool output,
    /// failing the test if `run_command_limited` itself returned an error.
    async fn run_in_cwd(command: impl Into<String>) -> String {
        match run_command_limited(Path::new(".").into(), command.into()).await {
            Ok(output) => output,
            Err(err) => panic!("run_command_limited failed: {err:?}"),
        }
    }

    /// Byte length of the text between the first opening fence and the last
    /// closing fence of `output`.
    fn fenced_content_len(output: &str) -> usize {
        let start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
        let end = output.rfind("\n```").unwrap_or(output.len());
        end - start
    }

    #[gpui::test(iterations = 10)]
    async fn test_run_command_simple(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let output = run_in_cwd("echo 'Hello, World!'").await;
        assert_eq!(output, "```\nHello, World!\n```");
    }

    #[gpui::test(iterations = 10)]
    async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let command = "echo 'stdout 1' && sleep 0.01 && echo 'stderr 1' >&2 && sleep 0.01 && echo 'stdout 2' && sleep 0.01 && echo 'stderr 2' >&2";
        let output = run_in_cwd(command).await;
        assert_eq!(output, "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```");
    }

    #[gpui::test(iterations = 10)]
    async fn test_multiple_output_reads(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Pauses between writes so the output likely arrives in several reads.
        let output = run_in_cwd("echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'").await;
        assert_eq!(output, "```\n1\n2\n3\n```");
    }

    #[gpui::test(iterations = 10)]
    async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let command = format!("echo '{}'; sleep 0.01;", "X".repeat(LIMIT * 2));
        let output = run_in_cwd(command).await;

        // With no newline to cut back to, exactly LIMIT bytes are kept.
        assert_eq!(fenced_content_len(&output), LIMIT);
    }

    #[gpui::test(iterations = 10)]
    async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let command = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
        let output = run_in_cwd(command).await;

        assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
        assert!(fenced_content_len(&output) <= LIMIT);
    }

    #[gpui::test(iterations = 10)]
    async fn test_command_failure(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let output = run_in_cwd("exit 42").await;

        let shell_path = std::env::var("SHELL").unwrap_or("bash".to_string());
        let expected_output = format!(
            "Command failed with exit code 42 (shell: {}).\n\n```\n\n```",
            shell_path
        );
        assert_eq!(output, expected_output);
    }
}