1use anyhow::anyhow;
2use buffer_diff::BufferDiff;
3use collections::{HashMap, HashSet};
4use futures::{
5 channel::{mpsc, oneshot},
6 pin_mut, SinkExt, StreamExt,
7};
8use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
9use language::Buffer;
10use mlua::{ExternalResult, Lua, MultiValue, ObjectLike, Table, UserData, UserDataMethods};
11use parking_lot::Mutex;
12use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
13use regex::Regex;
14use std::{
15 path::{Path, PathBuf},
16 sync::Arc,
17};
18use util::{paths::PathMatcher, ResultExt};
19
/// A unit of work sent from the background Lua runtime to the foreground task spawned in
/// `ScriptingSession::new`, which invokes it with a weak handle to the session and an
/// async app context.
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);

/// Edits that scripts have made to a single buffer.
struct BufferChanges {
    /// Diff for the buffer; its base text is set from a branch of the buffer with the
    /// script's edits undone (see `write_to_buffer`).
    diff: Entity<BufferDiff>,
    /// Ids of the edits scripts applied to this buffer, accumulated across every write.
    edit_ids: Vec<clock::Lamport>,
}

/// Runs Lua scripts against a project and records which buffers they modify.
pub struct ScriptingSession {
    project: Entity<Project>,
    /// Every script started in this session; a `ScriptId` is an index into this list.
    scripts: Vec<Script>,
    /// Buffers written to by scripts, together with the edits made to each.
    changes_by_buffer: HashMap<Entity<Buffer>, BufferChanges>,
    /// Sender used by the background Lua runtime to schedule work on the foreground.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Foreground task (spawned in `new`) that invokes each received `ForegroundFn`.
    _invoke_foreground_fns: Task<()>,
}
34
35impl ScriptingSession {
36 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
37 let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
38 ScriptingSession {
39 project,
40 scripts: Vec::new(),
41 changes_by_buffer: HashMap::default(),
42 foreground_fns_tx,
43 _invoke_foreground_fns: cx.spawn(async move |this, cx| {
44 while let Some(foreground_fn) = foreground_fns_rx.next().await {
45 foreground_fn.0(this.clone(), cx.clone());
46 }
47 }),
48 }
49 }
50
51 pub fn changed_buffers(&self) -> impl ExactSizeIterator<Item = &Entity<Buffer>> {
52 self.changes_by_buffer.keys()
53 }
54
55 pub fn run_script(
56 &mut self,
57 script_src: String,
58 cx: &mut Context<Self>,
59 ) -> (ScriptId, Task<()>) {
60 let id = ScriptId(self.scripts.len() as u32);
61
62 let stdout = Arc::new(Mutex::new(String::new()));
63
64 let script = Script {
65 state: ScriptState::Running {
66 stdout: stdout.clone(),
67 },
68 };
69 self.scripts.push(script);
70
71 let task = self.run_lua(script_src, stdout, cx);
72
73 let task = cx.spawn(async move |session, cx| {
74 let result = task.await;
75
76 session
77 .update(cx, |session, _cx| {
78 let script = session.get_mut(id);
79 let stdout = script.stdout_snapshot();
80
81 script.state = match result {
82 Ok(()) => ScriptState::Succeeded { stdout },
83 Err(error) => ScriptState::Failed { stdout, error },
84 };
85 })
86 .log_err();
87 });
88
89 (id, task)
90 }
91
92 fn run_lua(
93 &mut self,
94 script: String,
95 stdout: Arc<Mutex<String>>,
96 cx: &mut Context<Self>,
97 ) -> Task<anyhow::Result<()>> {
98 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
99
100 // TODO Honor all worktrees instead of the first one
101 let worktree_info = self
102 .project
103 .read(cx)
104 .visible_worktrees(cx)
105 .next()
106 .map(|worktree| {
107 let worktree = worktree.read(cx);
108 (worktree.id(), worktree.abs_path())
109 });
110
111 let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());
112
113 let fs = self.project.read(cx).fs().clone();
114 let foreground_fns_tx = self.foreground_fns_tx.clone();
115
116 let task = cx.background_spawn({
117 let stdout = stdout.clone();
118
119 async move {
120 let lua = Lua::new();
121 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
122 let globals = lua.globals();
123
124 // Use the project root dir as the script's current working dir.
125 if let Some(root_dir) = &root_dir {
126 if let Some(root_dir) = root_dir.to_str() {
127 globals.set("cwd", root_dir)?;
128 }
129 }
130
131 globals.set(
132 "sb_print",
133 lua.create_function({
134 let stdout = stdout.clone();
135 move |_, args: MultiValue| Self::print(args, &stdout)
136 })?,
137 )?;
138 globals.set(
139 "search",
140 lua.create_async_function({
141 let foreground_fns_tx = foreground_fns_tx.clone();
142 let fs = fs.clone();
143 move |lua, regex| {
144 let mut foreground_fns_tx = foreground_fns_tx.clone();
145 let fs = fs.clone();
146 async move {
147 Self::search(&lua, &mut foreground_fns_tx, fs, regex)
148 .await
149 .into_lua_err()
150 }
151 }
152 })?,
153 )?;
154 globals.set(
155 "outline",
156 lua.create_async_function({
157 let root_dir = root_dir.clone();
158 let foreground_fns_tx = foreground_fns_tx.clone();
159 move |_lua, path| {
160 let mut foreground_fns_tx = foreground_fns_tx.clone();
161 let root_dir = root_dir.clone();
162 async move {
163 Self::outline(root_dir, &mut foreground_fns_tx, path)
164 .await
165 .into_lua_err()
166 }
167 }
168 })?,
169 )?;
170 globals.set(
171 "sb_io_open",
172 lua.create_async_function({
173 let worktree_info = worktree_info.clone();
174 let foreground_fns_tx = foreground_fns_tx.clone();
175 move |lua, (path_str, mode)| {
176 let worktree_info = worktree_info.clone();
177 let mut foreground_fns_tx = foreground_fns_tx.clone();
178 let fs = fs.clone();
179 async move {
180 Self::io_open(
181 &lua,
182 worktree_info,
183 &mut foreground_fns_tx,
184 fs,
185 path_str,
186 mode,
187 )
188 .await
189 }
190 }
191 })?,
192 )?;
193 globals.set("user_script", script)?;
194
195 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
196
197 anyhow::Ok(())
198 }
199 });
200
201 task
202 }
203
204 pub fn get(&self, script_id: ScriptId) -> &Script {
205 &self.scripts[script_id.0 as usize]
206 }
207
208 fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
209 &mut self.scripts[script_id.0 as usize]
210 }
211
212 /// Sandboxed print() function in Lua.
213 fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
214 for (index, arg) in args.into_iter().enumerate() {
215 // Lua's `print()` prints tab characters between each argument.
216 if index > 0 {
217 stdout.lock().push('\t');
218 }
219
220 // If the argument's to_string() fails, have the whole function call fail.
221 stdout.lock().push_str(&arg.to_string()?);
222 }
223 stdout.lock().push('\n');
224
225 Ok(())
226 }
227
228 /// Sandboxed io.open() function in Lua.
229 async fn io_open(
230 lua: &Lua,
231 worktree_info: Option<(WorktreeId, Arc<Path>)>,
232 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
233 fs: Arc<dyn Fs>,
234 path_str: String,
235 mode: Option<String>,
236 ) -> mlua::Result<(Option<Table>, String)> {
237 let (worktree_id, root_dir) = worktree_info
238 .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
239
240 let mode = mode.unwrap_or_else(|| "r".to_string());
241
242 // Parse the mode string to determine read/write permissions
243 let read_perm = mode.contains('r');
244 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
245 let append = mode.contains('a');
246 let truncate = mode.contains('w');
247
248 // This will be the Lua value returned from the `open` function.
249 let file = lua.create_table()?;
250
251 // Store file metadata in the file
252 file.set("__mode", mode.clone())?;
253 file.set("__read_perm", read_perm)?;
254 file.set("__write_perm", write_perm)?;
255
256 let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
257 Ok(path) => path,
258 Err(err) => return Ok((None, format!("{err}"))),
259 };
260
261 let project_path = ProjectPath {
262 worktree_id,
263 path: Path::new(&path_str).into(),
264 };
265
266 // flush / close method
267 let flush_fn = {
268 let project_path = project_path.clone();
269 let foreground_tx = foreground_tx.clone();
270 lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
271 let project_path = project_path.clone();
272 let mut foreground_tx = foreground_tx.clone();
273 async move {
274 Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
275 }
276 })?
277 };
278 file.set("flush", flush_fn.clone())?;
279 // We don't really hold files open, so we only need to flush on close
280 file.set("close", flush_fn)?;
281
282 // If it's a directory, give it a custom read() and return early.
283 if fs.is_dir(&path).await {
284 return Self::io_file_dir(lua, fs, file, &path).await;
285 }
286
287 let mut file_content = Vec::new();
288
289 if !truncate {
290 // Try to read existing content if we're not truncating
291 match Self::read_buffer(project_path.clone(), foreground_tx).await {
292 Ok(content) => file_content = content.into_bytes(),
293 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
294 }
295 }
296
297 // If in append mode, position should be at the end
298 let position = if append { file_content.len() } else { 0 };
299 file.set("__position", position)?;
300 file.set(
301 "__content",
302 lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
303 )?;
304
305 // Create file methods
306
307 // read method
308 let read_fn = lua.create_function(Self::io_file_read)?;
309 file.set("read", read_fn)?;
310
311 // lines method
312 let lines_fn = lua.create_function(Self::io_file_lines)?;
313 file.set("lines", lines_fn)?;
314
315 // write method
316 let write_fn = lua.create_function(Self::io_file_write)?;
317 file.set("write", write_fn)?;
318
319 // If we got this far, the file was opened successfully
320 Ok((Some(file), String::new()))
321 }
322
323 async fn read_buffer(
324 project_path: ProjectPath,
325 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
326 ) -> anyhow::Result<String> {
327 Self::run_foreground_fn(
328 "read file from buffer",
329 foreground_tx,
330 Box::new(move |session, mut cx| {
331 session.update(&mut cx, |session, cx| {
332 let open_buffer_task = session
333 .project
334 .update(cx, |project, cx| project.open_buffer(project_path, cx));
335
336 cx.spawn(async move |_, cx| {
337 let buffer = open_buffer_task.await?;
338
339 let text = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
340 Ok(text)
341 })
342 })
343 }),
344 )
345 .await??
346 .await
347 }
348
349 async fn io_file_flush(
350 file_userdata: mlua::Table,
351 project_path: ProjectPath,
352 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
353 ) -> mlua::Result<bool> {
354 let write_perm = file_userdata.get::<bool>("__write_perm")?;
355
356 if write_perm {
357 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
358 let content_ref = content.borrow::<FileContent>()?;
359 let text = {
360 let mut content_vec = content_ref.0.lock();
361 let content_vec = std::mem::take(&mut *content_vec);
362 String::from_utf8(content_vec).into_lua_err()?
363 };
364
365 Self::write_to_buffer(project_path, text, foreground_tx)
366 .await
367 .into_lua_err()?;
368 }
369
370 Ok(true)
371 }
372
373 async fn write_to_buffer(
374 project_path: ProjectPath,
375 text: String,
376 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
377 ) -> anyhow::Result<()> {
378 Self::run_foreground_fn(
379 "write to buffer",
380 foreground_tx,
381 Box::new(move |session, mut cx| {
382 session.update(&mut cx, |session, cx| {
383 let open_buffer_task = session
384 .project
385 .update(cx, |project, cx| project.open_buffer(project_path, cx));
386
387 cx.spawn(async move |session, cx| {
388 let buffer = open_buffer_task.await?;
389
390 let diff = buffer.update(cx, |buffer, cx| buffer.diff(text, cx))?.await;
391
392 let edit_ids = buffer.update(cx, |buffer, cx| {
393 buffer.finalize_last_transaction();
394 buffer.apply_diff(diff, cx);
395 let transaction = buffer.finalize_last_transaction();
396 transaction
397 .map_or(Vec::new(), |transaction| transaction.edit_ids.clone())
398 })?;
399
400 session
401 .update(cx, {
402 let buffer = buffer.clone();
403
404 |session, cx| {
405 session
406 .project
407 .update(cx, |project, cx| project.save_buffer(buffer, cx))
408 }
409 })?
410 .await?;
411
412 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
413
414 // If we saved successfully, mark buffer as changed
415 let buffer_without_changes =
416 buffer.update(cx, |buffer, cx| buffer.branch(cx))?;
417 session
418 .update(cx, |session, cx| {
419 let changed_buffer = session
420 .changes_by_buffer
421 .entry(buffer)
422 .or_insert_with(|| BufferChanges {
423 diff: cx.new(|cx| BufferDiff::new(&snapshot, cx)),
424 edit_ids: Vec::new(),
425 });
426 changed_buffer.edit_ids.extend(edit_ids);
427 let operations_to_undo = changed_buffer
428 .edit_ids
429 .iter()
430 .map(|edit_id| (*edit_id, u32::MAX))
431 .collect::<HashMap<_, _>>();
432 buffer_without_changes.update(cx, |buffer, cx| {
433 buffer.undo_operations(operations_to_undo, cx);
434 });
435 changed_buffer.diff.update(cx, |diff, cx| {
436 diff.set_base_text(buffer_without_changes, snapshot.text, cx)
437 })
438 })?
439 .await?;
440
441 Ok(())
442 })
443 })
444 }),
445 )
446 .await??
447 .await
448 }
449
450 async fn io_file_dir(
451 lua: &Lua,
452 fs: Arc<dyn Fs>,
453 file: Table,
454 path: &Path,
455 ) -> mlua::Result<(Option<Table>, String)> {
456 // Create a special directory handle
457 file.set("__is_directory", true)?;
458
459 // Store directory entries
460 let entries = match fs.read_dir(&path).await {
461 Ok(entries) => {
462 let mut entry_names = Vec::new();
463
464 // Process the stream of directory entries
465 pin_mut!(entries);
466 while let Some(Ok(entry_result)) = entries.next().await {
467 if let Some(file_name) = entry_result.file_name() {
468 entry_names.push(file_name.to_string_lossy().into_owned());
469 }
470 }
471
472 entry_names
473 }
474 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
475 };
476
477 // Save the list of entries
478 file.set("__dir_entries", entries)?;
479 file.set("__dir_position", 0usize)?;
480
481 // Create a directory-specific read function
482 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
483 let position = file_userdata.get::<usize>("__dir_position")?;
484 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
485
486 if position >= entries.len() {
487 return Ok(None); // No more entries
488 }
489
490 let entry = entries[position].clone();
491 file_userdata.set("__dir_position", position + 1)?;
492
493 Ok(Some(entry))
494 })?;
495 file.set("read", read_fn)?;
496
497 // If we got this far, the directory was opened successfully
498 return Ok((Some(file), String::new()));
499 }
500
501 fn io_file_read(
502 lua: &Lua,
503 (file_userdata, format): (Table, Option<mlua::Value>),
504 ) -> mlua::Result<Option<mlua::String>> {
505 let read_perm = file_userdata.get::<bool>("__read_perm")?;
506 if !read_perm {
507 return Err(mlua::Error::runtime("File not open for reading"));
508 }
509
510 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
511 let position = file_userdata.get::<usize>("__position")?;
512 let content_ref = content.borrow::<FileContent>()?;
513 let content = content_ref.0.lock();
514
515 if position >= content.len() {
516 return Ok(None); // EOF
517 }
518
519 let (result, new_position) = match Self::io_file_read_format(format)? {
520 FileReadFormat::All => {
521 // Read entire file from current position
522 let result = content[position..].to_vec();
523 (Some(result), content.len())
524 }
525 FileReadFormat::Line => {
526 if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
527 {
528 let mut line = content[position..position + next_newline_ix].to_vec();
529 if line.ends_with(b"\r") {
530 line.pop();
531 }
532 (Some(line), position + next_newline_ix + 1)
533 } else if position < content.len() {
534 let line = content[position..].to_vec();
535 (Some(line), content.len())
536 } else {
537 (None, position) // EOF
538 }
539 }
540 FileReadFormat::LineWithLineFeed => {
541 if position < content.len() {
542 let next_line_ix = content[position..]
543 .iter()
544 .position(|c| *c == b'\n')
545 .map_or(content.len(), |ix| position + ix + 1);
546 let line = content[position..next_line_ix].to_vec();
547 (Some(line), next_line_ix)
548 } else {
549 (None, position) // EOF
550 }
551 }
552 FileReadFormat::Bytes(n) => {
553 let end = std::cmp::min(position + n, content.len());
554 let result = content[position..end].to_vec();
555 (Some(result), end)
556 }
557 };
558
559 // Update the position in the file userdata
560 if new_position != position {
561 file_userdata.set("__position", new_position)?;
562 }
563
564 // Convert the result to a Lua string
565 match result {
566 Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
567 None => Ok(None),
568 }
569 }
570
571 fn io_file_lines(lua: &Lua, file_userdata: Table) -> mlua::Result<mlua::Function> {
572 let read_perm = file_userdata.get::<bool>("__read_perm")?;
573 if !read_perm {
574 return Err(mlua::Error::runtime("File not open for reading"));
575 }
576
577 lua.create_function::<_, _, mlua::Value>(move |lua, _: ()| {
578 file_userdata.call_method("read", lua.create_string("*l")?)
579 })
580 }
581
582 fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
583 let format = match format {
584 Some(mlua::Value::String(s)) => {
585 let lossy_string = s.to_string_lossy();
586 let format_str: &str = lossy_string.as_ref();
587
588 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
589 match &format_str[0..2] {
590 "*a" => FileReadFormat::All,
591 "*l" => FileReadFormat::Line,
592 "*L" => FileReadFormat::LineWithLineFeed,
593 "*n" => {
594 // Try to parse as a number (number of bytes to read)
595 match format_str.parse::<usize>() {
596 Ok(n) => FileReadFormat::Bytes(n),
597 Err(_) => {
598 return Err(mlua::Error::runtime(format!(
599 "Invalid format: {}",
600 format_str
601 )))
602 }
603 }
604 }
605 _ => {
606 return Err(mlua::Error::runtime(format!(
607 "Unsupported format: {}",
608 format_str
609 )))
610 }
611 }
612 }
613 Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
614 Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
615 Some(value) => {
616 return Err(mlua::Error::runtime(format!(
617 "Invalid file format {:?}",
618 value
619 )))
620 }
621 None => FileReadFormat::Line, // Default is to read a line
622 };
623
624 Ok(format)
625 }
626
627 fn io_file_write(
628 _lua: &Lua,
629 (file_userdata, text): (Table, mlua::String),
630 ) -> mlua::Result<bool> {
631 let write_perm = file_userdata.get::<bool>("__write_perm")?;
632 if !write_perm {
633 return Err(mlua::Error::runtime("File not open for writing"));
634 }
635
636 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
637 let position = file_userdata.get::<usize>("__position")?;
638 let content_ref = content.borrow::<FileContent>()?;
639 let mut content_vec = content_ref.0.lock();
640
641 let bytes = text.as_bytes();
642
643 // Ensure the vector has enough capacity
644 if position + bytes.len() > content_vec.len() {
645 content_vec.resize(position + bytes.len(), 0);
646 }
647
648 // Write the bytes
649 for (i, &byte) in bytes.iter().enumerate() {
650 content_vec[position + i] = byte;
651 }
652
653 // Update position
654 let new_position = position + bytes.len();
655 file_userdata.set("__position", new_position)?;
656
657 Ok(true)
658 }
659
660 async fn search(
661 lua: &Lua,
662 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
663 fs: Arc<dyn Fs>,
664 regex: String,
665 ) -> anyhow::Result<Table> {
666 // TODO: Allow specification of these options.
667 let search_query = SearchQuery::regex(
668 ®ex,
669 false,
670 false,
671 false,
672 PathMatcher::default(),
673 PathMatcher::default(),
674 None,
675 );
676 let search_query = match search_query {
677 Ok(query) => query,
678 Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
679 };
680
681 // TODO: Should use `search_query.regex`. The tool description should also be updated,
682 // as it specifies standard regex.
683 let search_regex = match Regex::new(®ex) {
684 Ok(re) => re,
685 Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
686 };
687
688 let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
689
690 let mut search_results: Vec<Table> = Vec::new();
691 while let Some(path) = abs_paths_rx.next().await {
692 // Skip files larger than 1MB
693 if let Ok(Some(metadata)) = fs.metadata(&path).await {
694 if metadata.len > 1_000_000 {
695 continue;
696 }
697 }
698
699 // Attempt to read the file as text
700 if let Ok(content) = fs.load(&path).await {
701 let mut matches = Vec::new();
702
703 // Find all regex matches in the content
704 for capture in search_regex.find_iter(&content) {
705 matches.push(capture.as_str().to_string());
706 }
707
708 // If we found matches, create a result entry
709 if !matches.is_empty() {
710 let result_entry = lua.create_table()?;
711 result_entry.set("path", path.to_string_lossy().to_string())?;
712
713 let matches_table = lua.create_table()?;
714 for (ix, m) in matches.iter().enumerate() {
715 matches_table.set(ix + 1, m.clone())?;
716 }
717 result_entry.set("matches", matches_table)?;
718
719 search_results.push(result_entry);
720 }
721 }
722 }
723
724 // Create a table to hold our results
725 let results_table = lua.create_table()?;
726 for (ix, entry) in search_results.into_iter().enumerate() {
727 results_table.set(ix + 1, entry)?;
728 }
729
730 Ok(results_table)
731 }
732
733 async fn find_search_candidates(
734 search_query: SearchQuery,
735 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
736 ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
737 Self::run_foreground_fn(
738 "finding search file candidates",
739 foreground_tx,
740 Box::new(move |session, mut cx| {
741 session.update(&mut cx, |session, cx| {
742 session.project.update(cx, |project, cx| {
743 project.worktree_store().update(cx, |worktree_store, cx| {
744 // TODO: Better limit? For now this is the same as
745 // MAX_SEARCH_RESULT_FILES.
746 let limit = 5000;
747 // TODO: Providing non-empty open_entries can make this a bit more
748 // efficient as it can skip checking that these paths are textual.
749 let open_entries = HashSet::default();
750 let candidates = worktree_store.find_search_candidates(
751 search_query,
752 limit,
753 open_entries,
754 project.fs().clone(),
755 cx,
756 );
757 let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
758 cx.spawn(async move |worktree_store, cx| {
759 pin_mut!(candidates);
760
761 while let Some(project_path) = candidates.next().await {
762 worktree_store.read_with(cx, |worktree_store, cx| {
763 if let Some(worktree) = worktree_store
764 .worktree_for_id(project_path.worktree_id, cx)
765 {
766 if let Some(abs_path) = worktree
767 .read(cx)
768 .absolutize(&project_path.path)
769 .log_err()
770 {
771 abs_paths_tx.unbounded_send(abs_path)?;
772 }
773 }
774 anyhow::Ok(())
775 })??;
776 }
777 anyhow::Ok(())
778 })
779 .detach();
780 abs_paths_rx
781 })
782 })
783 })
784 }),
785 )
786 .await?
787 }
788
789 async fn outline(
790 root_dir: Option<Arc<Path>>,
791 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
792 path_str: String,
793 ) -> anyhow::Result<String> {
794 let root_dir = root_dir
795 .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
796 let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
797 let outline = Self::run_foreground_fn(
798 "getting code outline",
799 foreground_tx,
800 Box::new(move |session, cx| {
801 cx.spawn(async move |cx| {
802 // TODO: This will not use file content from `fs_changes`. It will also reflect
803 // user changes that have not been saved.
804 let buffer = session
805 .update(cx, |session, cx| {
806 session
807 .project
808 .update(cx, |project, cx| project.open_local_buffer(&path, cx))
809 })?
810 .await?;
811 buffer.update(cx, |buffer, _cx| {
812 if let Some(outline) = buffer.snapshot().outline(None) {
813 Ok(outline)
814 } else {
815 Err(anyhow!("No outline for file {path_str}"))
816 }
817 })
818 })
819 }),
820 )
821 .await?
822 .await??;
823
824 Ok(outline
825 .items
826 .into_iter()
827 .map(|item| {
828 if item.text.contains('\n') {
829 log::error!("Outline item unexpectedly contains newline");
830 }
831 format!("{}{}", " ".repeat(item.depth), item.text)
832 })
833 .collect::<Vec<String>>()
834 .join("\n"))
835 }
836
837 async fn run_foreground_fn<R: Send + 'static>(
838 description: &str,
839 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
840 function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
841 ) -> anyhow::Result<R> {
842 let (response_tx, response_rx) = oneshot::channel();
843 let send_result = foreground_tx
844 .send(ForegroundFn(Box::new(move |this, cx| {
845 response_tx.send(function(this, cx)).ok();
846 })))
847 .await;
848 match send_result {
849 Ok(()) => (),
850 Err(err) => {
851 return Err(anyhow::Error::new(err).context(format!(
852 "Internal error while enqueuing work for {description}"
853 )));
854 }
855 }
856 match response_rx.await {
857 Ok(result) => Ok(result),
858 Err(oneshot::Canceled) => Err(anyhow!(
859 "Internal error: response oneshot was canceled while {description}."
860 )),
861 }
862 }
863
864 fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
865 let path = Path::new(&path_str);
866 if path.is_absolute() {
867 // Check if path starts with root_dir prefix without resolving symlinks
868 if path.starts_with(&root_dir) {
869 Ok(path.to_path_buf())
870 } else {
871 Err(anyhow!(
872 "Error: Absolute path {} is outside the current working directory",
873 path_str
874 ))
875 }
876 } else {
877 // TODO: Does use of `../` break sandbox - is path canonicalization needed?
878 Ok(root_dir.join(path))
879 }
880 }
881}
882
/// How `read()` on a sandboxed file handle consumes content, mirroring the formats of
/// Lua's `file:read`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FileReadFormat {
    /// `"*a"`: everything from the current position to the end of the file.
    All,
    /// `"*l"`: the next line, without its line terminator.
    Line,
    /// `"*L"`: the next line, keeping its trailing `\n` if there is one.
    LineWithLineFeed,
    /// A numeric argument: up to this many bytes.
    Bytes(usize),
}
889
890struct FileContent(Arc<Mutex<Vec<u8>>>);
891
892impl UserData for FileContent {
893 fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
894 // FileContent doesn't have any methods so far.
895 }
896}
897
/// Identifies a script started in a `ScriptingSession`; wraps the script's index in the
/// session's list of scripts.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);

/// A script started in a `ScriptingSession`, tracked through its current state.
pub struct Script {
    pub state: ScriptState,
}

/// Lifecycle of a script run.
#[derive(Debug)]
pub enum ScriptState {
    /// The script is executing. `stdout` is shared with the sandboxed `print` and grows
    /// as the script prints.
    Running {
        stdout: Arc<Mutex<String>>,
    },
    /// The script finished without error; `stdout` is everything it printed.
    Succeeded {
        stdout: String,
    },
    /// The script failed with `error`; `stdout` is what it printed before failing.
    Failed {
        stdout: String,
        error: anyhow::Error,
    },
}
918
919impl Script {
920 /// If exited, returns a message with the output for the LLM
921 pub fn output_message_for_llm(&self) -> Option<String> {
922 match &self.state {
923 ScriptState::Running { .. } => None,
924 ScriptState::Succeeded { stdout } => {
925 format!("Here's the script output:\n{}", stdout).into()
926 }
927 ScriptState::Failed { stdout, error } => format!(
928 "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
929 error, stdout
930 )
931 .into(),
932 }
933 }
934
935 /// Get a snapshot of the script's stdout
936 pub fn stdout_snapshot(&self) -> String {
937 match &self.state {
938 ScriptState::Running { stdout } => stdout.lock().clone(),
939 ScriptState::Succeeded { stdout } => stdout.clone(),
940 ScriptState::Failed { stdout, .. } => stdout.clone(),
941 }
942 }
943}
944
945#[cfg(test)]
946mod tests {
947 use gpui::TestAppContext;
948 use project::FakeFs;
949 use serde_json::json;
950 use settings::SettingsStore;
951 use util::path;
952
953 use super::*;
954
955 #[gpui::test]
956 async fn test_print(cx: &mut TestAppContext) {
957 let script = r#"
958 print("Hello", "world!")
959 print("Goodbye", "moon!")
960 "#;
961
962 let test_session = TestSession::init(cx).await;
963 let output = test_session.test_success(script, cx).await;
964 assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
965 }
966
967 // search
968
969 #[gpui::test]
970 async fn test_search(cx: &mut TestAppContext) {
971 let script = r#"
972 local results = search("world")
973 for i, result in ipairs(results) do
974 print("File: " .. result.path)
975 print("Matches:")
976 for j, match in ipairs(result.matches) do
977 print(" " .. match)
978 end
979 end
980 "#;
981
982 let test_session = TestSession::init(cx).await;
983 let output = test_session.test_success(script, cx).await;
984 assert_eq!(
985 output,
986 concat!("File: ", path!("/file1.txt"), "\nMatches:\n world\n")
987 );
988 }
989
990 // io.open
991
992 #[gpui::test]
993 async fn test_open_and_read_file(cx: &mut TestAppContext) {
994 let script = r#"
995 local file = io.open("file1.txt", "r")
996 local content = file:read()
997 print("Content:", content)
998 file:close()
999 "#;
1000
1001 let test_session = TestSession::init(cx).await;
1002 let output = test_session.test_success(script, cx).await;
1003 assert_eq!(output, "Content:\tHello world!\n");
1004 assert_eq!(test_session.diff(cx), Vec::new());
1005 }
1006
1007 #[gpui::test]
1008 async fn test_lines_iterator(cx: &mut TestAppContext) {
1009 let script = r#"
1010 -- Create a test file with multiple lines
1011 local file = io.open("lines_test.txt", "w")
1012 file:write("Line 1\nLine 2\nLine 3\nLine 4\nLine 5")
1013 file:close()
1014
1015 -- Read it back using the lines iterator
1016 local read_file = io.open("lines_test.txt", "r")
1017 local count = 0
1018 for line in read_file:lines() do
1019 count = count + 1
1020 print(count .. ": " .. line)
1021 end
1022 read_file:close()
1023
1024 print("Total lines:", count)
1025 "#;
1026
1027 let test_session = TestSession::init(cx).await;
1028 let output = test_session.test_success(script, cx).await;
1029 assert_eq!(
1030 output,
1031 "1: Line 1\n2: Line 2\n3: Line 3\n4: Line 4\n5: Line 5\nTotal lines:\t5\n"
1032 );
1033 }
1034
1035 #[gpui::test]
1036 async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
1037 let script = r#"
1038 local file = io.open("file1.txt", "w")
1039 file:write("This is new content")
1040 file:close()
1041
1042 -- Read back to verify
1043 local read_file = io.open("file1.txt", "r")
1044 local content = read_file:read("*a")
1045 print("Written content:", content)
1046 read_file:close()
1047 "#;
1048
1049 let test_session = TestSession::init(cx).await;
1050 let output = test_session.test_success(script, cx).await;
1051 assert_eq!(output, "Written content:\tThis is new content\n");
1052 assert_eq!(
1053 test_session.diff(cx),
1054 vec![(
1055 PathBuf::from("file1.txt"),
1056 vec![(
1057 "Hello world!\n".to_string(),
1058 "This is new content".to_string()
1059 )]
1060 )]
1061 );
1062 }
1063
1064 #[gpui::test]
1065 async fn test_multiple_writes(cx: &mut TestAppContext) {
1066 let script = r#"
1067 -- Test writing to a file multiple times
1068 local file = io.open("multiwrite.txt", "w")
1069 file:write("First line\n")
1070 file:write("Second line\n")
1071 file:write("Third line")
1072 file:close()
1073
1074 -- Read back to verify
1075 local read_file = io.open("multiwrite.txt", "r")
1076 if read_file then
1077 local content = read_file:read("*a")
1078 print("Full content:", content)
1079 read_file:close()
1080 end
1081 "#;
1082
1083 let test_session = TestSession::init(cx).await;
1084 let output = test_session.test_success(script, cx).await;
1085 assert_eq!(
1086 output,
1087 "Full content:\tFirst line\nSecond line\nThird line\n"
1088 );
1089 assert_eq!(
1090 test_session.diff(cx),
1091 vec![(
1092 PathBuf::from("multiwrite.txt"),
1093 vec![(
1094 "".to_string(),
1095 "First line\nSecond line\nThird line".to_string()
1096 )]
1097 )]
1098 );
1099 }
1100
1101 #[gpui::test]
1102 async fn test_multiple_writes_diff_handles(cx: &mut TestAppContext) {
1103 let script = r#"
1104 -- Write to a file
1105 local file1 = io.open("multi_open.txt", "w")
1106 file1:write("Content written by first handle\n")
1107 file1:close()
1108
1109 -- Open it again and add more content
1110 local file2 = io.open("multi_open.txt", "w")
1111 file2:write("Content written by second handle\n")
1112 file2:close()
1113
1114 -- Open it a third time and read
1115 local file3 = io.open("multi_open.txt", "r")
1116 local content = file3:read("*a")
1117 print("Final content:", content)
1118 file3:close()
1119 "#;
1120
1121 let test_session = TestSession::init(cx).await;
1122 let output = test_session.test_success(script, cx).await;
1123 assert_eq!(
1124 output,
1125 "Final content:\tContent written by second handle\n\n"
1126 );
1127 assert_eq!(
1128 test_session.diff(cx),
1129 vec![(
1130 PathBuf::from("multi_open.txt"),
1131 vec![(
1132 "".to_string(),
1133 "Content written by second handle\n".to_string()
1134 )]
1135 )]
1136 );
1137 }
1138
1139 #[gpui::test]
1140 async fn test_append_mode(cx: &mut TestAppContext) {
1141 let script = r#"
1142 -- Append more content
1143 file = io.open("file1.txt", "a")
1144 file:write("Appended content\n")
1145 file:close()
1146
1147 -- Add even more
1148 file = io.open("file1.txt", "a")
1149 file:write("More appended content")
1150 file:close()
1151
1152 -- Read back to verify
1153 local read_file = io.open("file1.txt", "r")
1154 local content = read_file:read("*a")
1155 print("Content after appends:", content)
1156 read_file:close()
1157 "#;
1158
1159 let test_session = TestSession::init(cx).await;
1160 let output = test_session.test_success(script, cx).await;
1161 assert_eq!(
1162 output,
1163 "Content after appends:\tHello world!\nAppended content\nMore appended content\n"
1164 );
1165 assert_eq!(
1166 test_session.diff(cx),
1167 vec![(
1168 PathBuf::from("file1.txt"),
1169 vec![(
1170 "".to_string(),
1171 "Appended content\nMore appended content".to_string()
1172 )]
1173 )]
1174 );
1175 }
1176
1177 #[gpui::test]
1178 async fn test_read_formats(cx: &mut TestAppContext) {
1179 let script = r#"
1180 local file = io.open("multiline.txt", "w")
1181 file:write("Line 1\nLine 2\nLine 3")
1182 file:close()
1183
1184 -- Test "*a" (all)
1185 local f = io.open("multiline.txt", "r")
1186 local all = f:read("*a")
1187 print("All:", all)
1188 f:close()
1189
1190 -- Test "*l" (line)
1191 f = io.open("multiline.txt", "r")
1192 local line1 = f:read("*l")
1193 local line2 = f:read("*l")
1194 local line3 = f:read("*l")
1195 print("Line 1:", line1)
1196 print("Line 2:", line2)
1197 print("Line 3:", line3)
1198 f:close()
1199
1200 -- Test "*L" (line with newline)
1201 f = io.open("multiline.txt", "r")
1202 local line_with_nl = f:read("*L")
1203 print("Line with newline length:", #line_with_nl)
1204 print("Last char:", string.byte(line_with_nl, #line_with_nl))
1205 f:close()
1206
1207 -- Test number of bytes
1208 f = io.open("multiline.txt", "r")
1209 local bytes5 = f:read(5)
1210 print("5 bytes:", bytes5)
1211 f:close()
1212 "#;
1213
1214 let test_session = TestSession::init(cx).await;
1215 let output = test_session.test_success(script, cx).await;
1216 println!("{}", &output);
1217 assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
1218 assert!(output.contains("Line 1:\tLine 1"));
1219 assert!(output.contains("Line 2:\tLine 2"));
1220 assert!(output.contains("Line 3:\tLine 3"));
1221 assert!(output.contains("Line with newline length:\t7"));
1222 assert!(output.contains("Last char:\t10")); // LF
1223 assert!(output.contains("5 bytes:\tLine "));
1224 assert_eq!(
1225 test_session.diff(cx),
1226 vec![(
1227 PathBuf::from("multiline.txt"),
1228 vec![("".to_string(), "Line 1\nLine 2\nLine 3".to_string())]
1229 )]
1230 );
1231 }
1232
1233 // helpers
1234
1235 struct TestSession {
1236 session: Entity<ScriptingSession>,
1237 }
1238
1239 impl TestSession {
1240 async fn init(cx: &mut TestAppContext) -> Self {
1241 let settings_store = cx.update(SettingsStore::test);
1242 cx.set_global(settings_store);
1243 cx.update(Project::init_settings);
1244 cx.update(language::init);
1245
1246 let fs = FakeFs::new(cx.executor());
1247 fs.insert_tree(
1248 path!("/"),
1249 json!({
1250 "file1.txt": "Hello world!\n",
1251 "file2.txt": "Goodbye moon!\n"
1252 }),
1253 )
1254 .await;
1255
1256 let project = Project::test(fs.clone(), [Path::new(path!("/"))], cx).await;
1257 let session = cx.new(|cx| ScriptingSession::new(project, cx));
1258
1259 TestSession { session }
1260 }
1261
1262 async fn test_success(&self, source: &str, cx: &mut TestAppContext) -> String {
1263 let script_id = self.run_script(source, cx).await;
1264
1265 self.session.read_with(cx, |session, _cx| {
1266 let script = session.get(script_id);
1267 let stdout = script.stdout_snapshot();
1268
1269 if let ScriptState::Failed { error, .. } = &script.state {
1270 panic!("Script failed:\n{}\n\n{}", error, stdout);
1271 }
1272
1273 stdout
1274 })
1275 }
1276
1277 fn diff(&self, cx: &mut TestAppContext) -> Vec<(PathBuf, Vec<(String, String)>)> {
1278 self.session.read_with(cx, |session, cx| {
1279 session
1280 .changes_by_buffer
1281 .iter()
1282 .map(|(buffer, changes)| {
1283 let snapshot = buffer.read(cx).snapshot();
1284 let diff = changes.diff.read(cx);
1285 let hunks = diff.hunks(&snapshot, cx);
1286 let path = buffer.read(cx).file().unwrap().path().clone();
1287 let diffs = hunks
1288 .map(|hunk| {
1289 let old_text = diff
1290 .base_text()
1291 .text_for_range(hunk.diff_base_byte_range)
1292 .collect::<String>();
1293 let new_text =
1294 snapshot.text_for_range(hunk.range).collect::<String>();
1295 (old_text, new_text)
1296 })
1297 .collect();
1298 (path.to_path_buf(), diffs)
1299 })
1300 .collect()
1301 })
1302 }
1303
1304 async fn run_script(&self, source: &str, cx: &mut TestAppContext) -> ScriptId {
1305 let (script_id, task) = self
1306 .session
1307 .update(cx, |session, cx| session.run_script(source.to_string(), cx));
1308
1309 task.await;
1310
1311 script_id
1312 }
1313 }
1314}