use anyhow::anyhow;
use buffer_diff::BufferDiff;
use collections::{HashMap, HashSet};
use futures::{
    channel::{mpsc, oneshot},
    pin_mut, SinkExt, StreamExt,
};
use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use language::Buffer;
use mlua::{ExternalResult, Lua, MultiValue, ObjectLike, Table, UserData, UserDataMethods};
use parking_lot::Mutex;
use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
use regex::Regex;
use std::{
    path::{Component, Path, PathBuf},
    sync::Arc,
};
use util::{paths::PathMatcher, ResultExt};
19
20struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
21
/// Changes that scripts in a session have made to a single buffer.
struct BufferChanges {
    /// Diff whose base text is the buffer with the scripts' edits undone, so it shows exactly
    /// what the scripts changed.
    diff: Entity<BufferDiff>,
    /// Ids of the edits the scripts applied; undone on a branch of the buffer to build the
    /// diff's base text.
    edit_ids: Vec<clock::Lamport>,
}
26
/// Runs Lua scripts against a project and tracks which buffers they changed.
pub struct ScriptingSession {
    project: Entity<Project>,
    /// All scripts started by `run_script`; a `ScriptId` is an index into this list.
    scripts: Vec<Script>,
    /// Buffers that scripts wrote to, with the changes made to each.
    changes_by_buffer: HashMap<Entity<Buffer>, BufferChanges>,
    /// Used by the background Lua task to send closures to the foreground task below.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Foreground task that runs the closures received through `foreground_fns_tx`. Held so it
    /// lives as long as the session (presumably dropping the `Task` would cancel it).
    _invoke_foreground_fns: Task<()>,
}
34
35impl ScriptingSession {
36 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
37 let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
38 ScriptingSession {
39 project,
40 scripts: Vec::new(),
41 changes_by_buffer: HashMap::default(),
42 foreground_fns_tx,
43 _invoke_foreground_fns: cx.spawn(|this, cx| async move {
44 while let Some(foreground_fn) = foreground_fns_rx.next().await {
45 foreground_fn.0(this.clone(), cx.clone());
46 }
47 }),
48 }
49 }
50
51 pub fn changed_buffers(&self) -> impl ExactSizeIterator<Item = &Entity<Buffer>> {
52 self.changes_by_buffer.keys()
53 }
54
55 pub fn run_script(
56 &mut self,
57 script_src: String,
58 cx: &mut Context<Self>,
59 ) -> (ScriptId, Task<()>) {
60 let id = ScriptId(self.scripts.len() as u32);
61
62 let stdout = Arc::new(Mutex::new(String::new()));
63
64 let script = Script {
65 state: ScriptState::Running {
66 stdout: stdout.clone(),
67 },
68 };
69 self.scripts.push(script);
70
71 let task = self.run_lua(script_src, stdout, cx);
72
73 let task = cx.spawn(|session, mut cx| async move {
74 let result = task.await;
75
76 session
77 .update(&mut cx, |session, _cx| {
78 let script = session.get_mut(id);
79 let stdout = script.stdout_snapshot();
80
81 script.state = match result {
82 Ok(()) => ScriptState::Succeeded { stdout },
83 Err(error) => ScriptState::Failed { stdout, error },
84 };
85 })
86 .log_err();
87 });
88
89 (id, task)
90 }
91
92 fn run_lua(
93 &mut self,
94 script: String,
95 stdout: Arc<Mutex<String>>,
96 cx: &mut Context<Self>,
97 ) -> Task<anyhow::Result<()>> {
98 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
99
100 // TODO Honor all worktrees instead of the first one
101 let worktree_info = self
102 .project
103 .read(cx)
104 .visible_worktrees(cx)
105 .next()
106 .map(|worktree| {
107 let worktree = worktree.read(cx);
108 (worktree.id(), worktree.abs_path())
109 });
110
111 let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());
112
113 let fs = self.project.read(cx).fs().clone();
114 let foreground_fns_tx = self.foreground_fns_tx.clone();
115
116 let task = cx.background_spawn({
117 let stdout = stdout.clone();
118
119 async move {
120 let lua = Lua::new();
121 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
122 let globals = lua.globals();
123
124 // Use the project root dir as the script's current working dir.
125 if let Some(root_dir) = &root_dir {
126 if let Some(root_dir) = root_dir.to_str() {
127 globals.set("cwd", root_dir)?;
128 }
129 }
130
131 globals.set(
132 "sb_print",
133 lua.create_function({
134 let stdout = stdout.clone();
135 move |_, args: MultiValue| Self::print(args, &stdout)
136 })?,
137 )?;
138 globals.set(
139 "search",
140 lua.create_async_function({
141 let foreground_fns_tx = foreground_fns_tx.clone();
142 let fs = fs.clone();
143 move |lua, regex| {
144 let mut foreground_fns_tx = foreground_fns_tx.clone();
145 let fs = fs.clone();
146 async move {
147 Self::search(&lua, &mut foreground_fns_tx, fs, regex)
148 .await
149 .into_lua_err()
150 }
151 }
152 })?,
153 )?;
154 globals.set(
155 "outline",
156 lua.create_async_function({
157 let root_dir = root_dir.clone();
158 let foreground_fns_tx = foreground_fns_tx.clone();
159 move |_lua, path| {
160 let mut foreground_fns_tx = foreground_fns_tx.clone();
161 let root_dir = root_dir.clone();
162 async move {
163 Self::outline(root_dir, &mut foreground_fns_tx, path)
164 .await
165 .into_lua_err()
166 }
167 }
168 })?,
169 )?;
170 globals.set(
171 "sb_io_open",
172 lua.create_async_function({
173 let worktree_info = worktree_info.clone();
174 let foreground_fns_tx = foreground_fns_tx.clone();
175 move |lua, (path_str, mode)| {
176 let worktree_info = worktree_info.clone();
177 let mut foreground_fns_tx = foreground_fns_tx.clone();
178 let fs = fs.clone();
179 async move {
180 Self::io_open(
181 &lua,
182 worktree_info,
183 &mut foreground_fns_tx,
184 fs,
185 path_str,
186 mode,
187 )
188 .await
189 }
190 }
191 })?,
192 )?;
193 globals.set("user_script", script)?;
194
195 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
196
197 anyhow::Ok(())
198 }
199 });
200
201 task
202 }
203
204 pub fn get(&self, script_id: ScriptId) -> &Script {
205 &self.scripts[script_id.0 as usize]
206 }
207
208 fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
209 &mut self.scripts[script_id.0 as usize]
210 }
211
212 /// Sandboxed print() function in Lua.
213 fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
214 for (index, arg) in args.into_iter().enumerate() {
215 // Lua's `print()` prints tab characters between each argument.
216 if index > 0 {
217 stdout.lock().push('\t');
218 }
219
220 // If the argument's to_string() fails, have the whole function call fail.
221 stdout.lock().push_str(&arg.to_string()?);
222 }
223 stdout.lock().push('\n');
224
225 Ok(())
226 }
227
228 /// Sandboxed io.open() function in Lua.
229 async fn io_open(
230 lua: &Lua,
231 worktree_info: Option<(WorktreeId, Arc<Path>)>,
232 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
233 fs: Arc<dyn Fs>,
234 path_str: String,
235 mode: Option<String>,
236 ) -> mlua::Result<(Option<Table>, String)> {
237 let (worktree_id, root_dir) = worktree_info
238 .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
239
240 let mode = mode.unwrap_or_else(|| "r".to_string());
241
242 // Parse the mode string to determine read/write permissions
243 let read_perm = mode.contains('r');
244 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
245 let append = mode.contains('a');
246 let truncate = mode.contains('w');
247
248 // This will be the Lua value returned from the `open` function.
249 let file = lua.create_table()?;
250
251 // Store file metadata in the file
252 file.set("__mode", mode.clone())?;
253 file.set("__read_perm", read_perm)?;
254 file.set("__write_perm", write_perm)?;
255
256 let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
257 Ok(path) => path,
258 Err(err) => return Ok((None, format!("{err}"))),
259 };
260
261 let project_path = ProjectPath {
262 worktree_id,
263 path: Path::new(&path_str).into(),
264 };
265
266 // flush / close method
267 let flush_fn = {
268 let project_path = project_path.clone();
269 let foreground_tx = foreground_tx.clone();
270 lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
271 let project_path = project_path.clone();
272 let mut foreground_tx = foreground_tx.clone();
273 async move {
274 Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
275 }
276 })?
277 };
278 file.set("flush", flush_fn.clone())?;
279 // We don't really hold files open, so we only need to flush on close
280 file.set("close", flush_fn)?;
281
282 // If it's a directory, give it a custom read() and return early.
283 if fs.is_dir(&path).await {
284 return Self::io_file_dir(lua, fs, file, &path).await;
285 }
286
287 let mut file_content = Vec::new();
288
289 if !truncate {
290 // Try to read existing content if we're not truncating
291 match Self::read_buffer(project_path.clone(), foreground_tx).await {
292 Ok(content) => file_content = content.into_bytes(),
293 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
294 }
295 }
296
297 // If in append mode, position should be at the end
298 let position = if append { file_content.len() } else { 0 };
299 file.set("__position", position)?;
300 file.set(
301 "__content",
302 lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
303 )?;
304
305 // Create file methods
306
307 // read method
308 let read_fn = lua.create_function(Self::io_file_read)?;
309 file.set("read", read_fn)?;
310
311 // lines method
312 let lines_fn = lua.create_function(Self::io_file_lines)?;
313 file.set("lines", lines_fn)?;
314
315 // write method
316 let write_fn = lua.create_function(Self::io_file_write)?;
317 file.set("write", write_fn)?;
318
319 // If we got this far, the file was opened successfully
320 Ok((Some(file), String::new()))
321 }
322
323 async fn read_buffer(
324 project_path: ProjectPath,
325 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
326 ) -> anyhow::Result<String> {
327 Self::run_foreground_fn(
328 "read file from buffer",
329 foreground_tx,
330 Box::new(move |session, mut cx| {
331 session.update(&mut cx, |session, cx| {
332 let open_buffer_task = session
333 .project
334 .update(cx, |project, cx| project.open_buffer(project_path, cx));
335
336 cx.spawn(|_, cx| async move {
337 let buffer = open_buffer_task.await?;
338
339 let text = buffer.read_with(&cx, |buffer, _cx| buffer.text())?;
340 Ok(text)
341 })
342 })
343 }),
344 )
345 .await??
346 .await
347 }
348
349 async fn io_file_flush(
350 file_userdata: mlua::Table,
351 project_path: ProjectPath,
352 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
353 ) -> mlua::Result<bool> {
354 let write_perm = file_userdata.get::<bool>("__write_perm")?;
355
356 if write_perm {
357 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
358 let content_ref = content.borrow::<FileContent>()?;
359 let text = {
360 let mut content_vec = content_ref.0.lock();
361 let content_vec = std::mem::take(&mut *content_vec);
362 String::from_utf8(content_vec).into_lua_err()?
363 };
364
365 Self::write_to_buffer(project_path, text, foreground_tx)
366 .await
367 .into_lua_err()?;
368 }
369
370 Ok(true)
371 }
372
373 async fn write_to_buffer(
374 project_path: ProjectPath,
375 text: String,
376 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
377 ) -> anyhow::Result<()> {
378 Self::run_foreground_fn(
379 "write to buffer",
380 foreground_tx,
381 Box::new(move |session, mut cx| {
382 session.update(&mut cx, |session, cx| {
383 let open_buffer_task = session
384 .project
385 .update(cx, |project, cx| project.open_buffer(project_path, cx));
386
387 cx.spawn(move |session, mut cx| async move {
388 let buffer = open_buffer_task.await?;
389
390 let diff = buffer
391 .update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
392 .await;
393
394 let edit_ids = buffer.update(&mut cx, |buffer, cx| {
395 buffer.finalize_last_transaction();
396 buffer.apply_diff(diff, cx);
397 let transaction = buffer.finalize_last_transaction();
398 transaction
399 .map_or(Vec::new(), |transaction| transaction.edit_ids.clone())
400 })?;
401
402 session
403 .update(&mut cx, {
404 let buffer = buffer.clone();
405
406 |session, cx| {
407 session
408 .project
409 .update(cx, |project, cx| project.save_buffer(buffer, cx))
410 }
411 })?
412 .await?;
413
414 let snapshot = buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
415
416 // If we saved successfully, mark buffer as changed
417 let buffer_without_changes =
418 buffer.update(&mut cx, |buffer, cx| buffer.branch(cx))?;
419 session
420 .update(&mut cx, |session, cx| {
421 let changed_buffer = session
422 .changes_by_buffer
423 .entry(buffer)
424 .or_insert_with(|| BufferChanges {
425 diff: cx.new(|cx| BufferDiff::new(&snapshot, cx)),
426 edit_ids: Vec::new(),
427 });
428 changed_buffer.edit_ids.extend(edit_ids);
429 let operations_to_undo = changed_buffer
430 .edit_ids
431 .iter()
432 .map(|edit_id| (*edit_id, u32::MAX))
433 .collect::<HashMap<_, _>>();
434 buffer_without_changes.update(cx, |buffer, cx| {
435 buffer.undo_operations(operations_to_undo, cx);
436 });
437 changed_buffer.diff.update(cx, |diff, cx| {
438 diff.set_base_text(buffer_without_changes, snapshot.text, cx)
439 })
440 })?
441 .await?;
442
443 Ok(())
444 })
445 })
446 }),
447 )
448 .await??
449 .await
450 }
451
452 async fn io_file_dir(
453 lua: &Lua,
454 fs: Arc<dyn Fs>,
455 file: Table,
456 path: &Path,
457 ) -> mlua::Result<(Option<Table>, String)> {
458 // Create a special directory handle
459 file.set("__is_directory", true)?;
460
461 // Store directory entries
462 let entries = match fs.read_dir(&path).await {
463 Ok(entries) => {
464 let mut entry_names = Vec::new();
465
466 // Process the stream of directory entries
467 pin_mut!(entries);
468 while let Some(Ok(entry_result)) = entries.next().await {
469 if let Some(file_name) = entry_result.file_name() {
470 entry_names.push(file_name.to_string_lossy().into_owned());
471 }
472 }
473
474 entry_names
475 }
476 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
477 };
478
479 // Save the list of entries
480 file.set("__dir_entries", entries)?;
481 file.set("__dir_position", 0usize)?;
482
483 // Create a directory-specific read function
484 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
485 let position = file_userdata.get::<usize>("__dir_position")?;
486 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
487
488 if position >= entries.len() {
489 return Ok(None); // No more entries
490 }
491
492 let entry = entries[position].clone();
493 file_userdata.set("__dir_position", position + 1)?;
494
495 Ok(Some(entry))
496 })?;
497 file.set("read", read_fn)?;
498
499 // If we got this far, the directory was opened successfully
500 return Ok((Some(file), String::new()));
501 }
502
503 fn io_file_read(
504 lua: &Lua,
505 (file_userdata, format): (Table, Option<mlua::Value>),
506 ) -> mlua::Result<Option<mlua::String>> {
507 let read_perm = file_userdata.get::<bool>("__read_perm")?;
508 if !read_perm {
509 return Err(mlua::Error::runtime("File not open for reading"));
510 }
511
512 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
513 let position = file_userdata.get::<usize>("__position")?;
514 let content_ref = content.borrow::<FileContent>()?;
515 let content = content_ref.0.lock();
516
517 if position >= content.len() {
518 return Ok(None); // EOF
519 }
520
521 let (result, new_position) = match Self::io_file_read_format(format)? {
522 FileReadFormat::All => {
523 // Read entire file from current position
524 let result = content[position..].to_vec();
525 (Some(result), content.len())
526 }
527 FileReadFormat::Line => {
528 if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
529 {
530 let mut line = content[position..position + next_newline_ix].to_vec();
531 if line.ends_with(b"\r") {
532 line.pop();
533 }
534 (Some(line), position + next_newline_ix + 1)
535 } else if position < content.len() {
536 let line = content[position..].to_vec();
537 (Some(line), content.len())
538 } else {
539 (None, position) // EOF
540 }
541 }
542 FileReadFormat::LineWithLineFeed => {
543 if position < content.len() {
544 let next_line_ix = content[position..]
545 .iter()
546 .position(|c| *c == b'\n')
547 .map_or(content.len(), |ix| position + ix + 1);
548 let line = content[position..next_line_ix].to_vec();
549 (Some(line), next_line_ix)
550 } else {
551 (None, position) // EOF
552 }
553 }
554 FileReadFormat::Bytes(n) => {
555 let end = std::cmp::min(position + n, content.len());
556 let result = content[position..end].to_vec();
557 (Some(result), end)
558 }
559 };
560
561 // Update the position in the file userdata
562 if new_position != position {
563 file_userdata.set("__position", new_position)?;
564 }
565
566 // Convert the result to a Lua string
567 match result {
568 Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
569 None => Ok(None),
570 }
571 }
572
573 fn io_file_lines(lua: &Lua, file_userdata: Table) -> mlua::Result<mlua::Function> {
574 let read_perm = file_userdata.get::<bool>("__read_perm")?;
575 if !read_perm {
576 return Err(mlua::Error::runtime("File not open for reading"));
577 }
578
579 lua.create_function::<_, _, mlua::Value>(move |lua, _: ()| {
580 file_userdata.call_method("read", lua.create_string("*l")?)
581 })
582 }
583
584 fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
585 let format = match format {
586 Some(mlua::Value::String(s)) => {
587 let lossy_string = s.to_string_lossy();
588 let format_str: &str = lossy_string.as_ref();
589
590 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
591 match &format_str[0..2] {
592 "*a" => FileReadFormat::All,
593 "*l" => FileReadFormat::Line,
594 "*L" => FileReadFormat::LineWithLineFeed,
595 "*n" => {
596 // Try to parse as a number (number of bytes to read)
597 match format_str.parse::<usize>() {
598 Ok(n) => FileReadFormat::Bytes(n),
599 Err(_) => {
600 return Err(mlua::Error::runtime(format!(
601 "Invalid format: {}",
602 format_str
603 )))
604 }
605 }
606 }
607 _ => {
608 return Err(mlua::Error::runtime(format!(
609 "Unsupported format: {}",
610 format_str
611 )))
612 }
613 }
614 }
615 Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
616 Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
617 Some(value) => {
618 return Err(mlua::Error::runtime(format!(
619 "Invalid file format {:?}",
620 value
621 )))
622 }
623 None => FileReadFormat::Line, // Default is to read a line
624 };
625
626 Ok(format)
627 }
628
629 fn io_file_write(
630 _lua: &Lua,
631 (file_userdata, text): (Table, mlua::String),
632 ) -> mlua::Result<bool> {
633 let write_perm = file_userdata.get::<bool>("__write_perm")?;
634 if !write_perm {
635 return Err(mlua::Error::runtime("File not open for writing"));
636 }
637
638 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
639 let position = file_userdata.get::<usize>("__position")?;
640 let content_ref = content.borrow::<FileContent>()?;
641 let mut content_vec = content_ref.0.lock();
642
643 let bytes = text.as_bytes();
644
645 // Ensure the vector has enough capacity
646 if position + bytes.len() > content_vec.len() {
647 content_vec.resize(position + bytes.len(), 0);
648 }
649
650 // Write the bytes
651 for (i, &byte) in bytes.iter().enumerate() {
652 content_vec[position + i] = byte;
653 }
654
655 // Update position
656 let new_position = position + bytes.len();
657 file_userdata.set("__position", new_position)?;
658
659 Ok(true)
660 }
661
662 async fn search(
663 lua: &Lua,
664 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
665 fs: Arc<dyn Fs>,
666 regex: String,
667 ) -> anyhow::Result<Table> {
668 // TODO: Allow specification of these options.
669 let search_query = SearchQuery::regex(
670 ®ex,
671 false,
672 false,
673 false,
674 PathMatcher::default(),
675 PathMatcher::default(),
676 None,
677 );
678 let search_query = match search_query {
679 Ok(query) => query,
680 Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
681 };
682
683 // TODO: Should use `search_query.regex`. The tool description should also be updated,
684 // as it specifies standard regex.
685 let search_regex = match Regex::new(®ex) {
686 Ok(re) => re,
687 Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
688 };
689
690 let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
691
692 let mut search_results: Vec<Table> = Vec::new();
693 while let Some(path) = abs_paths_rx.next().await {
694 // Skip files larger than 1MB
695 if let Ok(Some(metadata)) = fs.metadata(&path).await {
696 if metadata.len > 1_000_000 {
697 continue;
698 }
699 }
700
701 // Attempt to read the file as text
702 if let Ok(content) = fs.load(&path).await {
703 let mut matches = Vec::new();
704
705 // Find all regex matches in the content
706 for capture in search_regex.find_iter(&content) {
707 matches.push(capture.as_str().to_string());
708 }
709
710 // If we found matches, create a result entry
711 if !matches.is_empty() {
712 let result_entry = lua.create_table()?;
713 result_entry.set("path", path.to_string_lossy().to_string())?;
714
715 let matches_table = lua.create_table()?;
716 for (ix, m) in matches.iter().enumerate() {
717 matches_table.set(ix + 1, m.clone())?;
718 }
719 result_entry.set("matches", matches_table)?;
720
721 search_results.push(result_entry);
722 }
723 }
724 }
725
726 // Create a table to hold our results
727 let results_table = lua.create_table()?;
728 for (ix, entry) in search_results.into_iter().enumerate() {
729 results_table.set(ix + 1, entry)?;
730 }
731
732 Ok(results_table)
733 }
734
735 async fn find_search_candidates(
736 search_query: SearchQuery,
737 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
738 ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
739 Self::run_foreground_fn(
740 "finding search file candidates",
741 foreground_tx,
742 Box::new(move |session, mut cx| {
743 session.update(&mut cx, |session, cx| {
744 session.project.update(cx, |project, cx| {
745 project.worktree_store().update(cx, |worktree_store, cx| {
746 // TODO: Better limit? For now this is the same as
747 // MAX_SEARCH_RESULT_FILES.
748 let limit = 5000;
749 // TODO: Providing non-empty open_entries can make this a bit more
750 // efficient as it can skip checking that these paths are textual.
751 let open_entries = HashSet::default();
752 let candidates = worktree_store.find_search_candidates(
753 search_query,
754 limit,
755 open_entries,
756 project.fs().clone(),
757 cx,
758 );
759 let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
760 cx.spawn(|worktree_store, cx| async move {
761 pin_mut!(candidates);
762
763 while let Some(project_path) = candidates.next().await {
764 worktree_store.read_with(&cx, |worktree_store, cx| {
765 if let Some(worktree) = worktree_store
766 .worktree_for_id(project_path.worktree_id, cx)
767 {
768 if let Some(abs_path) = worktree
769 .read(cx)
770 .absolutize(&project_path.path)
771 .log_err()
772 {
773 abs_paths_tx.unbounded_send(abs_path)?;
774 }
775 }
776 anyhow::Ok(())
777 })??;
778 }
779 anyhow::Ok(())
780 })
781 .detach();
782 abs_paths_rx
783 })
784 })
785 })
786 }),
787 )
788 .await?
789 }
790
791 async fn outline(
792 root_dir: Option<Arc<Path>>,
793 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
794 path_str: String,
795 ) -> anyhow::Result<String> {
796 let root_dir = root_dir
797 .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
798 let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
799 let outline = Self::run_foreground_fn(
800 "getting code outline",
801 foreground_tx,
802 Box::new(move |session, cx| {
803 cx.spawn(move |mut cx| async move {
804 // TODO: This will not use file content from `fs_changes`. It will also reflect
805 // user changes that have not been saved.
806 let buffer = session
807 .update(&mut cx, |session, cx| {
808 session
809 .project
810 .update(cx, |project, cx| project.open_local_buffer(&path, cx))
811 })?
812 .await?;
813 buffer.update(&mut cx, |buffer, _cx| {
814 if let Some(outline) = buffer.snapshot().outline(None) {
815 Ok(outline)
816 } else {
817 Err(anyhow!("No outline for file {path_str}"))
818 }
819 })
820 })
821 }),
822 )
823 .await?
824 .await??;
825
826 Ok(outline
827 .items
828 .into_iter()
829 .map(|item| {
830 if item.text.contains('\n') {
831 log::error!("Outline item unexpectedly contains newline");
832 }
833 format!("{}{}", " ".repeat(item.depth), item.text)
834 })
835 .collect::<Vec<String>>()
836 .join("\n"))
837 }
838
839 async fn run_foreground_fn<R: Send + 'static>(
840 description: &str,
841 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
842 function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
843 ) -> anyhow::Result<R> {
844 let (response_tx, response_rx) = oneshot::channel();
845 let send_result = foreground_tx
846 .send(ForegroundFn(Box::new(move |this, cx| {
847 response_tx.send(function(this, cx)).ok();
848 })))
849 .await;
850 match send_result {
851 Ok(()) => (),
852 Err(err) => {
853 return Err(anyhow::Error::new(err).context(format!(
854 "Internal error while enqueuing work for {description}"
855 )));
856 }
857 }
858 match response_rx.await {
859 Ok(result) => Ok(result),
860 Err(oneshot::Canceled) => Err(anyhow!(
861 "Internal error: response oneshot was canceled while {description}."
862 )),
863 }
864 }
865
866 fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
867 let path = Path::new(&path_str);
868 if path.is_absolute() {
869 // Check if path starts with root_dir prefix without resolving symlinks
870 if path.starts_with(&root_dir) {
871 Ok(path.to_path_buf())
872 } else {
873 Err(anyhow!(
874 "Error: Absolute path {} is outside the current working directory",
875 path_str
876 ))
877 }
878 } else {
879 // TODO: Does use of `../` break sandbox - is path canonicalization needed?
880 Ok(root_dir.join(path))
881 }
882 }
883}
884
/// What `file:read` should return, parsed from its format argument.
enum FileReadFormat {
    /// Everything from the current position to the end (`"*a"`).
    All,
    /// The next line without its trailing `\n`; a trailing `\r` is also dropped (`"*l"`).
    Line,
    /// The next line including its trailing `\n` (`"*L"`).
    LineWithLineFeed,
    /// Up to this many bytes (numeric argument).
    Bytes(usize),
}
891
/// In-memory bytes of a file opened via the sandboxed `io.open`, stored as userdata under the
/// file table's `__content` key.
struct FileContent(Arc<Mutex<Vec<u8>>>);
893
894impl UserData for FileContent {
895 fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
896 // FileContent doesn't have any methods so far.
897 }
898}
899
/// Identifies a script within a `ScriptingSession`; it is the script's index in the order the
/// scripts were started.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
902
/// A script that has been started in a `ScriptingSession`.
pub struct Script {
    /// Whether the script is running, succeeded, or failed, together with its output.
    pub state: ScriptState,
}
906
/// Lifecycle state of a [`Script`].
#[derive(Debug)]
pub enum ScriptState {
    /// The script is executing.
    Running {
        /// Shared with the script's sandboxed `print`, which appends to it.
        stdout: Arc<Mutex<String>>,
    },
    /// The script finished without an error.
    Succeeded {
        /// Everything the script printed.
        stdout: String,
    },
    /// The script returned an error.
    Failed {
        /// Whatever the script printed before it failed.
        stdout: String,
        error: anyhow::Error,
    },
}
920
921impl Script {
922 /// If exited, returns a message with the output for the LLM
923 pub fn output_message_for_llm(&self) -> Option<String> {
924 match &self.state {
925 ScriptState::Running { .. } => None,
926 ScriptState::Succeeded { stdout } => {
927 format!("Here's the script output:\n{}", stdout).into()
928 }
929 ScriptState::Failed { stdout, error } => format!(
930 "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
931 error, stdout
932 )
933 .into(),
934 }
935 }
936
937 /// Get a snapshot of the script's stdout
938 pub fn stdout_snapshot(&self) -> String {
939 match &self.state {
940 ScriptState::Running { stdout } => stdout.lock().clone(),
941 ScriptState::Succeeded { stdout } => stdout.clone(),
942 ScriptState::Failed { stdout, .. } => stdout.clone(),
943 }
944 }
945}
946
947#[cfg(test)]
948mod tests {
949 use gpui::TestAppContext;
950 use project::FakeFs;
951 use serde_json::json;
952 use settings::SettingsStore;
953 use util::path;
954
955 use super::*;
956
957 #[gpui::test]
958 async fn test_print(cx: &mut TestAppContext) {
959 let script = r#"
960 print("Hello", "world!")
961 print("Goodbye", "moon!")
962 "#;
963
964 let test_session = TestSession::init(cx).await;
965 let output = test_session.test_success(script, cx).await;
966 assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
967 }
968
969 // search
970
971 #[gpui::test]
972 async fn test_search(cx: &mut TestAppContext) {
973 let script = r#"
974 local results = search("world")
975 for i, result in ipairs(results) do
976 print("File: " .. result.path)
977 print("Matches:")
978 for j, match in ipairs(result.matches) do
979 print(" " .. match)
980 end
981 end
982 "#;
983
984 let test_session = TestSession::init(cx).await;
985 let output = test_session.test_success(script, cx).await;
986 assert_eq!(
987 output,
988 concat!("File: ", path!("/file1.txt"), "\nMatches:\n world\n")
989 );
990 }
991
992 // io.open
993
994 #[gpui::test]
995 async fn test_open_and_read_file(cx: &mut TestAppContext) {
996 let script = r#"
997 local file = io.open("file1.txt", "r")
998 local content = file:read()
999 print("Content:", content)
1000 file:close()
1001 "#;
1002
1003 let test_session = TestSession::init(cx).await;
1004 let output = test_session.test_success(script, cx).await;
1005 assert_eq!(output, "Content:\tHello world!\n");
1006 assert_eq!(test_session.diff(cx), Vec::new());
1007 }
1008
1009 #[gpui::test]
1010 async fn test_lines_iterator(cx: &mut TestAppContext) {
1011 let script = r#"
1012 -- Create a test file with multiple lines
1013 local file = io.open("lines_test.txt", "w")
1014 file:write("Line 1\nLine 2\nLine 3\nLine 4\nLine 5")
1015 file:close()
1016
1017 -- Read it back using the lines iterator
1018 local read_file = io.open("lines_test.txt", "r")
1019 local count = 0
1020 for line in read_file:lines() do
1021 count = count + 1
1022 print(count .. ": " .. line)
1023 end
1024 read_file:close()
1025
1026 print("Total lines:", count)
1027 "#;
1028
1029 let test_session = TestSession::init(cx).await;
1030 let output = test_session.test_success(script, cx).await;
1031 assert_eq!(
1032 output,
1033 "1: Line 1\n2: Line 2\n3: Line 3\n4: Line 4\n5: Line 5\nTotal lines:\t5\n"
1034 );
1035 }
1036
1037 #[gpui::test]
1038 async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
1039 let script = r#"
1040 local file = io.open("file1.txt", "w")
1041 file:write("This is new content")
1042 file:close()
1043
1044 -- Read back to verify
1045 local read_file = io.open("file1.txt", "r")
1046 local content = read_file:read("*a")
1047 print("Written content:", content)
1048 read_file:close()
1049 "#;
1050
1051 let test_session = TestSession::init(cx).await;
1052 let output = test_session.test_success(script, cx).await;
1053 assert_eq!(output, "Written content:\tThis is new content\n");
1054 assert_eq!(
1055 test_session.diff(cx),
1056 vec![(
1057 PathBuf::from("file1.txt"),
1058 vec![(
1059 "Hello world!\n".to_string(),
1060 "This is new content".to_string()
1061 )]
1062 )]
1063 );
1064 }
1065
1066 #[gpui::test]
1067 async fn test_multiple_writes(cx: &mut TestAppContext) {
1068 let script = r#"
1069 -- Test writing to a file multiple times
1070 local file = io.open("multiwrite.txt", "w")
1071 file:write("First line\n")
1072 file:write("Second line\n")
1073 file:write("Third line")
1074 file:close()
1075
1076 -- Read back to verify
1077 local read_file = io.open("multiwrite.txt", "r")
1078 if read_file then
1079 local content = read_file:read("*a")
1080 print("Full content:", content)
1081 read_file:close()
1082 end
1083 "#;
1084
1085 let test_session = TestSession::init(cx).await;
1086 let output = test_session.test_success(script, cx).await;
1087 assert_eq!(
1088 output,
1089 "Full content:\tFirst line\nSecond line\nThird line\n"
1090 );
1091 assert_eq!(
1092 test_session.diff(cx),
1093 vec![(
1094 PathBuf::from("multiwrite.txt"),
1095 vec![(
1096 "".to_string(),
1097 "First line\nSecond line\nThird line".to_string()
1098 )]
1099 )]
1100 );
1101 }
1102
1103 #[gpui::test]
1104 async fn test_multiple_writes_diff_handles(cx: &mut TestAppContext) {
1105 let script = r#"
1106 -- Write to a file
1107 local file1 = io.open("multi_open.txt", "w")
1108 file1:write("Content written by first handle\n")
1109 file1:close()
1110
1111 -- Open it again and add more content
1112 local file2 = io.open("multi_open.txt", "w")
1113 file2:write("Content written by second handle\n")
1114 file2:close()
1115
1116 -- Open it a third time and read
1117 local file3 = io.open("multi_open.txt", "r")
1118 local content = file3:read("*a")
1119 print("Final content:", content)
1120 file3:close()
1121 "#;
1122
1123 let test_session = TestSession::init(cx).await;
1124 let output = test_session.test_success(script, cx).await;
1125 assert_eq!(
1126 output,
1127 "Final content:\tContent written by second handle\n\n"
1128 );
1129 assert_eq!(
1130 test_session.diff(cx),
1131 vec![(
1132 PathBuf::from("multi_open.txt"),
1133 vec![(
1134 "".to_string(),
1135 "Content written by second handle\n".to_string()
1136 )]
1137 )]
1138 );
1139 }
1140
1141 #[gpui::test]
1142 async fn test_append_mode(cx: &mut TestAppContext) {
1143 let script = r#"
1144 -- Append more content
1145 file = io.open("file1.txt", "a")
1146 file:write("Appended content\n")
1147 file:close()
1148
1149 -- Add even more
1150 file = io.open("file1.txt", "a")
1151 file:write("More appended content")
1152 file:close()
1153
1154 -- Read back to verify
1155 local read_file = io.open("file1.txt", "r")
1156 local content = read_file:read("*a")
1157 print("Content after appends:", content)
1158 read_file:close()
1159 "#;
1160
1161 let test_session = TestSession::init(cx).await;
1162 let output = test_session.test_success(script, cx).await;
1163 assert_eq!(
1164 output,
1165 "Content after appends:\tHello world!\nAppended content\nMore appended content\n"
1166 );
1167 assert_eq!(
1168 test_session.diff(cx),
1169 vec![(
1170 PathBuf::from("file1.txt"),
1171 vec![(
1172 "".to_string(),
1173 "Appended content\nMore appended content".to_string()
1174 )]
1175 )]
1176 );
1177 }
1178
1179 #[gpui::test]
1180 async fn test_read_formats(cx: &mut TestAppContext) {
1181 let script = r#"
1182 local file = io.open("multiline.txt", "w")
1183 file:write("Line 1\nLine 2\nLine 3")
1184 file:close()
1185
1186 -- Test "*a" (all)
1187 local f = io.open("multiline.txt", "r")
1188 local all = f:read("*a")
1189 print("All:", all)
1190 f:close()
1191
1192 -- Test "*l" (line)
1193 f = io.open("multiline.txt", "r")
1194 local line1 = f:read("*l")
1195 local line2 = f:read("*l")
1196 local line3 = f:read("*l")
1197 print("Line 1:", line1)
1198 print("Line 2:", line2)
1199 print("Line 3:", line3)
1200 f:close()
1201
1202 -- Test "*L" (line with newline)
1203 f = io.open("multiline.txt", "r")
1204 local line_with_nl = f:read("*L")
1205 print("Line with newline length:", #line_with_nl)
1206 print("Last char:", string.byte(line_with_nl, #line_with_nl))
1207 f:close()
1208
1209 -- Test number of bytes
1210 f = io.open("multiline.txt", "r")
1211 local bytes5 = f:read(5)
1212 print("5 bytes:", bytes5)
1213 f:close()
1214 "#;
1215
1216 let test_session = TestSession::init(cx).await;
1217 let output = test_session.test_success(script, cx).await;
1218 println!("{}", &output);
1219 assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
1220 assert!(output.contains("Line 1:\tLine 1"));
1221 assert!(output.contains("Line 2:\tLine 2"));
1222 assert!(output.contains("Line 3:\tLine 3"));
1223 assert!(output.contains("Line with newline length:\t7"));
1224 assert!(output.contains("Last char:\t10")); // LF
1225 assert!(output.contains("5 bytes:\tLine "));
1226 assert_eq!(
1227 test_session.diff(cx),
1228 vec![(
1229 PathBuf::from("multiline.txt"),
1230 vec![("".to_string(), "Line 1\nLine 2\nLine 3".to_string())]
1231 )]
1232 );
1233 }
1234
1235 // helpers
1236
1237 struct TestSession {
1238 session: Entity<ScriptingSession>,
1239 }
1240
1241 impl TestSession {
1242 async fn init(cx: &mut TestAppContext) -> Self {
1243 let settings_store = cx.update(SettingsStore::test);
1244 cx.set_global(settings_store);
1245 cx.update(Project::init_settings);
1246 cx.update(language::init);
1247
1248 let fs = FakeFs::new(cx.executor());
1249 fs.insert_tree(
1250 path!("/"),
1251 json!({
1252 "file1.txt": "Hello world!\n",
1253 "file2.txt": "Goodbye moon!\n"
1254 }),
1255 )
1256 .await;
1257
1258 let project = Project::test(fs.clone(), [Path::new(path!("/"))], cx).await;
1259 let session = cx.new(|cx| ScriptingSession::new(project, cx));
1260
1261 TestSession { session }
1262 }
1263
1264 async fn test_success(&self, source: &str, cx: &mut TestAppContext) -> String {
1265 let script_id = self.run_script(source, cx).await;
1266
1267 self.session.read_with(cx, |session, _cx| {
1268 let script = session.get(script_id);
1269 let stdout = script.stdout_snapshot();
1270
1271 if let ScriptState::Failed { error, .. } = &script.state {
1272 panic!("Script failed:\n{}\n\n{}", error, stdout);
1273 }
1274
1275 stdout
1276 })
1277 }
1278
1279 fn diff(&self, cx: &mut TestAppContext) -> Vec<(PathBuf, Vec<(String, String)>)> {
1280 self.session.read_with(cx, |session, cx| {
1281 session
1282 .changes_by_buffer
1283 .iter()
1284 .map(|(buffer, changes)| {
1285 let snapshot = buffer.read(cx).snapshot();
1286 let diff = changes.diff.read(cx);
1287 let hunks = diff.hunks(&snapshot, cx);
1288 let path = buffer.read(cx).file().unwrap().path().clone();
1289 let diffs = hunks
1290 .map(|hunk| {
1291 let old_text = diff
1292 .base_text()
1293 .text_for_range(hunk.diff_base_byte_range)
1294 .collect::<String>();
1295 let new_text =
1296 snapshot.text_for_range(hunk.range).collect::<String>();
1297 (old_text, new_text)
1298 })
1299 .collect();
1300 (path.to_path_buf(), diffs)
1301 })
1302 .collect()
1303 })
1304 }
1305
1306 async fn run_script(&self, source: &str, cx: &mut TestAppContext) -> ScriptId {
1307 let (script_id, task) = self
1308 .session
1309 .update(cx, |session, cx| session.run_script(source.to_string(), cx));
1310
1311 task.await;
1312
1313 script_id
1314 }
1315 }
1316}