use anyhow::anyhow;
use buffer_diff::BufferDiff;
use collections::{HashMap, HashSet};
use futures::{
    channel::{mpsc, oneshot},
    pin_mut, SinkExt, StreamExt,
};
use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use language::Buffer;
use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
use parking_lot::Mutex;
use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
use regex::Regex;
use std::{
    path::{Component, Path, PathBuf},
    sync::Arc,
};
use util::{paths::PathMatcher, ResultExt};
19
20struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
21
22struct BufferChanges {
23 diff: Entity<BufferDiff>,
24 edit_ids: Vec<clock::Lamport>,
25}
26
27pub struct ScriptingSession {
28 project: Entity<Project>,
29 scripts: Vec<Script>,
30 changes_by_buffer: HashMap<Entity<Buffer>, BufferChanges>,
31 foreground_fns_tx: mpsc::Sender<ForegroundFn>,
32 _invoke_foreground_fns: Task<()>,
33}
34
35impl ScriptingSession {
36 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
37 let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
38 ScriptingSession {
39 project,
40 scripts: Vec::new(),
41 changes_by_buffer: HashMap::default(),
42 foreground_fns_tx,
43 _invoke_foreground_fns: cx.spawn(|this, cx| async move {
44 while let Some(foreground_fn) = foreground_fns_rx.next().await {
45 foreground_fn.0(this.clone(), cx.clone());
46 }
47 }),
48 }
49 }
50
51 pub fn changed_buffers(&self) -> impl ExactSizeIterator<Item = &Entity<Buffer>> {
52 self.changes_by_buffer.keys()
53 }
54
55 pub fn run_script(
56 &mut self,
57 script_src: String,
58 cx: &mut Context<Self>,
59 ) -> (ScriptId, Task<()>) {
60 let id = ScriptId(self.scripts.len() as u32);
61
62 let stdout = Arc::new(Mutex::new(String::new()));
63
64 let script = Script {
65 state: ScriptState::Running {
66 stdout: stdout.clone(),
67 },
68 };
69 self.scripts.push(script);
70
71 let task = self.run_lua(script_src, stdout, cx);
72
73 let task = cx.spawn(|session, mut cx| async move {
74 let result = task.await;
75
76 session
77 .update(&mut cx, |session, _cx| {
78 let script = session.get_mut(id);
79 let stdout = script.stdout_snapshot();
80
81 script.state = match result {
82 Ok(()) => ScriptState::Succeeded { stdout },
83 Err(error) => ScriptState::Failed { stdout, error },
84 };
85 })
86 .log_err();
87 });
88
89 (id, task)
90 }
91
92 fn run_lua(
93 &mut self,
94 script: String,
95 stdout: Arc<Mutex<String>>,
96 cx: &mut Context<Self>,
97 ) -> Task<anyhow::Result<()>> {
98 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
99
100 // TODO Honor all worktrees instead of the first one
101 let worktree_info = self
102 .project
103 .read(cx)
104 .visible_worktrees(cx)
105 .next()
106 .map(|worktree| {
107 let worktree = worktree.read(cx);
108 (worktree.id(), worktree.abs_path())
109 });
110
111 let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());
112
113 let fs = self.project.read(cx).fs().clone();
114 let foreground_fns_tx = self.foreground_fns_tx.clone();
115
116 let task = cx.background_spawn({
117 let stdout = stdout.clone();
118
119 async move {
120 let lua = Lua::new();
121 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
122 let globals = lua.globals();
123
124 // Use the project root dir as the script's current working dir.
125 if let Some(root_dir) = &root_dir {
126 if let Some(root_dir) = root_dir.to_str() {
127 globals.set("cwd", root_dir)?;
128 }
129 }
130
131 globals.set(
132 "sb_print",
133 lua.create_function({
134 let stdout = stdout.clone();
135 move |_, args: MultiValue| Self::print(args, &stdout)
136 })?,
137 )?;
138 globals.set(
139 "search",
140 lua.create_async_function({
141 let foreground_fns_tx = foreground_fns_tx.clone();
142 let fs = fs.clone();
143 move |lua, regex| {
144 let mut foreground_fns_tx = foreground_fns_tx.clone();
145 let fs = fs.clone();
146 async move {
147 Self::search(&lua, &mut foreground_fns_tx, fs, regex)
148 .await
149 .into_lua_err()
150 }
151 }
152 })?,
153 )?;
154 globals.set(
155 "outline",
156 lua.create_async_function({
157 let root_dir = root_dir.clone();
158 let foreground_fns_tx = foreground_fns_tx.clone();
159 move |_lua, path| {
160 let mut foreground_fns_tx = foreground_fns_tx.clone();
161 let root_dir = root_dir.clone();
162 async move {
163 Self::outline(root_dir, &mut foreground_fns_tx, path)
164 .await
165 .into_lua_err()
166 }
167 }
168 })?,
169 )?;
170 globals.set(
171 "sb_io_open",
172 lua.create_async_function({
173 let worktree_info = worktree_info.clone();
174 let foreground_fns_tx = foreground_fns_tx.clone();
175 move |lua, (path_str, mode)| {
176 let worktree_info = worktree_info.clone();
177 let mut foreground_fns_tx = foreground_fns_tx.clone();
178 let fs = fs.clone();
179 async move {
180 Self::io_open(
181 &lua,
182 worktree_info,
183 &mut foreground_fns_tx,
184 fs,
185 path_str,
186 mode,
187 )
188 .await
189 }
190 }
191 })?,
192 )?;
193 globals.set("user_script", script)?;
194
195 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
196
197 anyhow::Ok(())
198 }
199 });
200
201 task
202 }
203
204 pub fn get(&self, script_id: ScriptId) -> &Script {
205 &self.scripts[script_id.0 as usize]
206 }
207
208 fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
209 &mut self.scripts[script_id.0 as usize]
210 }
211
212 /// Sandboxed print() function in Lua.
213 fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
214 for (index, arg) in args.into_iter().enumerate() {
215 // Lua's `print()` prints tab characters between each argument.
216 if index > 0 {
217 stdout.lock().push('\t');
218 }
219
220 // If the argument's to_string() fails, have the whole function call fail.
221 stdout.lock().push_str(&arg.to_string()?);
222 }
223 stdout.lock().push('\n');
224
225 Ok(())
226 }
227
228 /// Sandboxed io.open() function in Lua.
229 async fn io_open(
230 lua: &Lua,
231 worktree_info: Option<(WorktreeId, Arc<Path>)>,
232 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
233 fs: Arc<dyn Fs>,
234 path_str: String,
235 mode: Option<String>,
236 ) -> mlua::Result<(Option<Table>, String)> {
237 let (worktree_id, root_dir) = worktree_info
238 .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
239
240 let mode = mode.unwrap_or_else(|| "r".to_string());
241
242 // Parse the mode string to determine read/write permissions
243 let read_perm = mode.contains('r');
244 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
245 let append = mode.contains('a');
246 let truncate = mode.contains('w');
247
248 // This will be the Lua value returned from the `open` function.
249 let file = lua.create_table()?;
250
251 // Store file metadata in the file
252 file.set("__mode", mode.clone())?;
253 file.set("__read_perm", read_perm)?;
254 file.set("__write_perm", write_perm)?;
255
256 let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
257 Ok(path) => path,
258 Err(err) => return Ok((None, format!("{err}"))),
259 };
260
261 let project_path = ProjectPath {
262 worktree_id,
263 path: Path::new(&path_str).into(),
264 };
265
266 // flush / close method
267 let flush_fn = {
268 let project_path = project_path.clone();
269 let foreground_tx = foreground_tx.clone();
270 lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
271 let project_path = project_path.clone();
272 let mut foreground_tx = foreground_tx.clone();
273 async move {
274 Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
275 }
276 })?
277 };
278 file.set("flush", flush_fn.clone())?;
279 // We don't really hold files open, so we only need to flush on close
280 file.set("close", flush_fn)?;
281
282 // If it's a directory, give it a custom read() and return early.
283 if fs.is_dir(&path).await {
284 return Self::io_file_dir(lua, fs, file, &path).await;
285 }
286
287 let mut file_content = Vec::new();
288
289 if !truncate {
290 // Try to read existing content if we're not truncating
291 match Self::read_buffer(project_path.clone(), foreground_tx).await {
292 Ok(content) => file_content = content.into_bytes(),
293 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
294 }
295 }
296
297 // If in append mode, position should be at the end
298 let position = if append { file_content.len() } else { 0 };
299 file.set("__position", position)?;
300 file.set(
301 "__content",
302 lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
303 )?;
304
305 // Create file methods
306
307 // read method
308 let read_fn = lua.create_function(Self::io_file_read)?;
309 file.set("read", read_fn)?;
310
311 // write method
312 let write_fn = lua.create_function(Self::io_file_write)?;
313 file.set("write", write_fn)?;
314
315 // If we got this far, the file was opened successfully
316 Ok((Some(file), String::new()))
317 }
318
319 async fn read_buffer(
320 project_path: ProjectPath,
321 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
322 ) -> anyhow::Result<String> {
323 Self::run_foreground_fn(
324 "read file from buffer",
325 foreground_tx,
326 Box::new(move |session, mut cx| {
327 session.update(&mut cx, |session, cx| {
328 let open_buffer_task = session
329 .project
330 .update(cx, |project, cx| project.open_buffer(project_path, cx));
331
332 cx.spawn(|_, cx| async move {
333 let buffer = open_buffer_task.await?;
334
335 let text = buffer.read_with(&cx, |buffer, _cx| buffer.text())?;
336 Ok(text)
337 })
338 })
339 }),
340 )
341 .await??
342 .await
343 }
344
345 async fn io_file_flush(
346 file_userdata: mlua::Table,
347 project_path: ProjectPath,
348 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
349 ) -> mlua::Result<bool> {
350 let write_perm = file_userdata.get::<bool>("__write_perm")?;
351
352 if write_perm {
353 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
354 let content_ref = content.borrow::<FileContent>()?;
355 let text = {
356 let mut content_vec = content_ref.0.lock();
357 let content_vec = std::mem::take(&mut *content_vec);
358 String::from_utf8(content_vec).into_lua_err()?
359 };
360
361 Self::write_to_buffer(project_path, text, foreground_tx)
362 .await
363 .into_lua_err()?;
364 }
365
366 Ok(true)
367 }
368
369 async fn write_to_buffer(
370 project_path: ProjectPath,
371 text: String,
372 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
373 ) -> anyhow::Result<()> {
374 Self::run_foreground_fn(
375 "write to buffer",
376 foreground_tx,
377 Box::new(move |session, mut cx| {
378 session.update(&mut cx, |session, cx| {
379 let open_buffer_task = session
380 .project
381 .update(cx, |project, cx| project.open_buffer(project_path, cx));
382
383 cx.spawn(move |session, mut cx| async move {
384 let buffer = open_buffer_task.await?;
385
386 let diff = buffer
387 .update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
388 .await;
389
390 let edit_ids = buffer.update(&mut cx, |buffer, cx| {
391 buffer.finalize_last_transaction();
392 buffer.apply_diff(diff, cx);
393 let transaction = buffer.finalize_last_transaction();
394 transaction
395 .map_or(Vec::new(), |transaction| transaction.edit_ids.clone())
396 })?;
397
398 session
399 .update(&mut cx, {
400 let buffer = buffer.clone();
401
402 |session, cx| {
403 session
404 .project
405 .update(cx, |project, cx| project.save_buffer(buffer, cx))
406 }
407 })?
408 .await?;
409
410 let snapshot = buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
411
412 // If we saved successfully, mark buffer as changed
413 let buffer_without_changes =
414 buffer.update(&mut cx, |buffer, cx| buffer.branch(cx))?;
415 session
416 .update(&mut cx, |session, cx| {
417 let changed_buffer = session
418 .changes_by_buffer
419 .entry(buffer)
420 .or_insert_with(|| BufferChanges {
421 diff: cx.new(|cx| BufferDiff::new(&snapshot, cx)),
422 edit_ids: Vec::new(),
423 });
424 changed_buffer.edit_ids.extend(edit_ids);
425 let operations_to_undo = changed_buffer
426 .edit_ids
427 .iter()
428 .map(|edit_id| (*edit_id, u32::MAX))
429 .collect::<HashMap<_, _>>();
430 buffer_without_changes.update(cx, |buffer, cx| {
431 buffer.undo_operations(operations_to_undo, cx);
432 });
433 changed_buffer.diff.update(cx, |diff, cx| {
434 diff.set_base_text(buffer_without_changes, snapshot.text, cx)
435 })
436 })?
437 .await?;
438
439 Ok(())
440 })
441 })
442 }),
443 )
444 .await??
445 .await
446 }
447
448 async fn io_file_dir(
449 lua: &Lua,
450 fs: Arc<dyn Fs>,
451 file: Table,
452 path: &Path,
453 ) -> mlua::Result<(Option<Table>, String)> {
454 // Create a special directory handle
455 file.set("__is_directory", true)?;
456
457 // Store directory entries
458 let entries = match fs.read_dir(&path).await {
459 Ok(entries) => {
460 let mut entry_names = Vec::new();
461
462 // Process the stream of directory entries
463 pin_mut!(entries);
464 while let Some(Ok(entry_result)) = entries.next().await {
465 if let Some(file_name) = entry_result.file_name() {
466 entry_names.push(file_name.to_string_lossy().into_owned());
467 }
468 }
469
470 entry_names
471 }
472 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
473 };
474
475 // Save the list of entries
476 file.set("__dir_entries", entries)?;
477 file.set("__dir_position", 0usize)?;
478
479 // Create a directory-specific read function
480 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
481 let position = file_userdata.get::<usize>("__dir_position")?;
482 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
483
484 if position >= entries.len() {
485 return Ok(None); // No more entries
486 }
487
488 let entry = entries[position].clone();
489 file_userdata.set("__dir_position", position + 1)?;
490
491 Ok(Some(entry))
492 })?;
493 file.set("read", read_fn)?;
494
495 // If we got this far, the directory was opened successfully
496 return Ok((Some(file), String::new()));
497 }
498
499 fn io_file_read(
500 lua: &Lua,
501 (file_userdata, format): (Table, Option<mlua::Value>),
502 ) -> mlua::Result<Option<mlua::String>> {
503 let read_perm = file_userdata.get::<bool>("__read_perm")?;
504 if !read_perm {
505 return Err(mlua::Error::runtime("File not open for reading"));
506 }
507
508 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
509 let position = file_userdata.get::<usize>("__position")?;
510 let content_ref = content.borrow::<FileContent>()?;
511 let content = content_ref.0.lock();
512
513 if position >= content.len() {
514 return Ok(None); // EOF
515 }
516
517 let (result, new_position) = match Self::io_file_read_format(format)? {
518 FileReadFormat::All => {
519 // Read entire file from current position
520 let result = content[position..].to_vec();
521 (Some(result), content.len())
522 }
523 FileReadFormat::Line => {
524 if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
525 {
526 let mut line = content[position..position + next_newline_ix].to_vec();
527 if line.ends_with(b"\r") {
528 line.pop();
529 }
530 (Some(line), position + next_newline_ix + 1)
531 } else if position < content.len() {
532 let line = content[position..].to_vec();
533 (Some(line), content.len())
534 } else {
535 (None, position) // EOF
536 }
537 }
538 FileReadFormat::LineWithLineFeed => {
539 if position < content.len() {
540 let next_line_ix = content[position..]
541 .iter()
542 .position(|c| *c == b'\n')
543 .map_or(content.len(), |ix| position + ix + 1);
544 let line = content[position..next_line_ix].to_vec();
545 (Some(line), next_line_ix)
546 } else {
547 (None, position) // EOF
548 }
549 }
550 FileReadFormat::Bytes(n) => {
551 let end = std::cmp::min(position + n, content.len());
552 let result = content[position..end].to_vec();
553 (Some(result), end)
554 }
555 };
556
557 // Update the position in the file userdata
558 if new_position != position {
559 file_userdata.set("__position", new_position)?;
560 }
561
562 // Convert the result to a Lua string
563 match result {
564 Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
565 None => Ok(None),
566 }
567 }
568
569 fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
570 let format = match format {
571 Some(mlua::Value::String(s)) => {
572 let lossy_string = s.to_string_lossy();
573 let format_str: &str = lossy_string.as_ref();
574
575 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
576 match &format_str[0..2] {
577 "*a" => FileReadFormat::All,
578 "*l" => FileReadFormat::Line,
579 "*L" => FileReadFormat::LineWithLineFeed,
580 "*n" => {
581 // Try to parse as a number (number of bytes to read)
582 match format_str.parse::<usize>() {
583 Ok(n) => FileReadFormat::Bytes(n),
584 Err(_) => {
585 return Err(mlua::Error::runtime(format!(
586 "Invalid format: {}",
587 format_str
588 )))
589 }
590 }
591 }
592 _ => {
593 return Err(mlua::Error::runtime(format!(
594 "Unsupported format: {}",
595 format_str
596 )))
597 }
598 }
599 }
600 Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
601 Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
602 Some(value) => {
603 return Err(mlua::Error::runtime(format!(
604 "Invalid file format {:?}",
605 value
606 )))
607 }
608 None => FileReadFormat::Line, // Default is to read a line
609 };
610
611 Ok(format)
612 }
613
614 fn io_file_write(
615 _lua: &Lua,
616 (file_userdata, text): (Table, mlua::String),
617 ) -> mlua::Result<bool> {
618 let write_perm = file_userdata.get::<bool>("__write_perm")?;
619 if !write_perm {
620 return Err(mlua::Error::runtime("File not open for writing"));
621 }
622
623 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
624 let position = file_userdata.get::<usize>("__position")?;
625 let content_ref = content.borrow::<FileContent>()?;
626 let mut content_vec = content_ref.0.lock();
627
628 let bytes = text.as_bytes();
629
630 // Ensure the vector has enough capacity
631 if position + bytes.len() > content_vec.len() {
632 content_vec.resize(position + bytes.len(), 0);
633 }
634
635 // Write the bytes
636 for (i, &byte) in bytes.iter().enumerate() {
637 content_vec[position + i] = byte;
638 }
639
640 // Update position
641 let new_position = position + bytes.len();
642 file_userdata.set("__position", new_position)?;
643
644 Ok(true)
645 }
646
647 async fn search(
648 lua: &Lua,
649 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
650 fs: Arc<dyn Fs>,
651 regex: String,
652 ) -> anyhow::Result<Table> {
653 // TODO: Allow specification of these options.
654 let search_query = SearchQuery::regex(
655 ®ex,
656 false,
657 false,
658 false,
659 PathMatcher::default(),
660 PathMatcher::default(),
661 None,
662 );
663 let search_query = match search_query {
664 Ok(query) => query,
665 Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
666 };
667
668 // TODO: Should use `search_query.regex`. The tool description should also be updated,
669 // as it specifies standard regex.
670 let search_regex = match Regex::new(®ex) {
671 Ok(re) => re,
672 Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
673 };
674
675 let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
676
677 let mut search_results: Vec<Table> = Vec::new();
678 while let Some(path) = abs_paths_rx.next().await {
679 // Skip files larger than 1MB
680 if let Ok(Some(metadata)) = fs.metadata(&path).await {
681 if metadata.len > 1_000_000 {
682 continue;
683 }
684 }
685
686 // Attempt to read the file as text
687 if let Ok(content) = fs.load(&path).await {
688 let mut matches = Vec::new();
689
690 // Find all regex matches in the content
691 for capture in search_regex.find_iter(&content) {
692 matches.push(capture.as_str().to_string());
693 }
694
695 // If we found matches, create a result entry
696 if !matches.is_empty() {
697 let result_entry = lua.create_table()?;
698 result_entry.set("path", path.to_string_lossy().to_string())?;
699
700 let matches_table = lua.create_table()?;
701 for (ix, m) in matches.iter().enumerate() {
702 matches_table.set(ix + 1, m.clone())?;
703 }
704 result_entry.set("matches", matches_table)?;
705
706 search_results.push(result_entry);
707 }
708 }
709 }
710
711 // Create a table to hold our results
712 let results_table = lua.create_table()?;
713 for (ix, entry) in search_results.into_iter().enumerate() {
714 results_table.set(ix + 1, entry)?;
715 }
716
717 Ok(results_table)
718 }
719
720 async fn find_search_candidates(
721 search_query: SearchQuery,
722 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
723 ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
724 Self::run_foreground_fn(
725 "finding search file candidates",
726 foreground_tx,
727 Box::new(move |session, mut cx| {
728 session.update(&mut cx, |session, cx| {
729 session.project.update(cx, |project, cx| {
730 project.worktree_store().update(cx, |worktree_store, cx| {
731 // TODO: Better limit? For now this is the same as
732 // MAX_SEARCH_RESULT_FILES.
733 let limit = 5000;
734 // TODO: Providing non-empty open_entries can make this a bit more
735 // efficient as it can skip checking that these paths are textual.
736 let open_entries = HashSet::default();
737 let candidates = worktree_store.find_search_candidates(
738 search_query,
739 limit,
740 open_entries,
741 project.fs().clone(),
742 cx,
743 );
744 let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
745 cx.spawn(|worktree_store, cx| async move {
746 pin_mut!(candidates);
747
748 while let Some(project_path) = candidates.next().await {
749 worktree_store.read_with(&cx, |worktree_store, cx| {
750 if let Some(worktree) = worktree_store
751 .worktree_for_id(project_path.worktree_id, cx)
752 {
753 if let Some(abs_path) = worktree
754 .read(cx)
755 .absolutize(&project_path.path)
756 .log_err()
757 {
758 abs_paths_tx.unbounded_send(abs_path)?;
759 }
760 }
761 anyhow::Ok(())
762 })??;
763 }
764 anyhow::Ok(())
765 })
766 .detach();
767 abs_paths_rx
768 })
769 })
770 })
771 }),
772 )
773 .await?
774 }
775
776 async fn outline(
777 root_dir: Option<Arc<Path>>,
778 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
779 path_str: String,
780 ) -> anyhow::Result<String> {
781 let root_dir = root_dir
782 .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
783 let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
784 let outline = Self::run_foreground_fn(
785 "getting code outline",
786 foreground_tx,
787 Box::new(move |session, cx| {
788 cx.spawn(move |mut cx| async move {
789 // TODO: This will not use file content from `fs_changes`. It will also reflect
790 // user changes that have not been saved.
791 let buffer = session
792 .update(&mut cx, |session, cx| {
793 session
794 .project
795 .update(cx, |project, cx| project.open_local_buffer(&path, cx))
796 })?
797 .await?;
798 buffer.update(&mut cx, |buffer, _cx| {
799 if let Some(outline) = buffer.snapshot().outline(None) {
800 Ok(outline)
801 } else {
802 Err(anyhow!("No outline for file {path_str}"))
803 }
804 })
805 })
806 }),
807 )
808 .await?
809 .await??;
810
811 Ok(outline
812 .items
813 .into_iter()
814 .map(|item| {
815 if item.text.contains('\n') {
816 log::error!("Outline item unexpectedly contains newline");
817 }
818 format!("{}{}", " ".repeat(item.depth), item.text)
819 })
820 .collect::<Vec<String>>()
821 .join("\n"))
822 }
823
824 async fn run_foreground_fn<R: Send + 'static>(
825 description: &str,
826 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
827 function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
828 ) -> anyhow::Result<R> {
829 let (response_tx, response_rx) = oneshot::channel();
830 let send_result = foreground_tx
831 .send(ForegroundFn(Box::new(move |this, cx| {
832 response_tx.send(function(this, cx)).ok();
833 })))
834 .await;
835 match send_result {
836 Ok(()) => (),
837 Err(err) => {
838 return Err(anyhow::Error::new(err).context(format!(
839 "Internal error while enqueuing work for {description}"
840 )));
841 }
842 }
843 match response_rx.await {
844 Ok(result) => Ok(result),
845 Err(oneshot::Canceled) => Err(anyhow!(
846 "Internal error: response oneshot was canceled while {description}."
847 )),
848 }
849 }
850
851 fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
852 let path = Path::new(&path_str);
853 if path.is_absolute() {
854 // Check if path starts with root_dir prefix without resolving symlinks
855 if path.starts_with(&root_dir) {
856 Ok(path.to_path_buf())
857 } else {
858 Err(anyhow!(
859 "Error: Absolute path {} is outside the current working directory",
860 path_str
861 ))
862 }
863 } else {
864 // TODO: Does use of `../` break sandbox - is path canonicalization needed?
865 Ok(root_dir.join(path))
866 }
867 }
868}
869
/// What a sandboxed `file:read()` call should return.
enum FileReadFormat {
    /// Everything from the current position to the end ("*a").
    All,
    /// The next line without its trailing line feed ("*l").
    Line,
    /// The next line including its trailing line feed ("*L").
    LineWithLineFeed,
    /// Up to this many bytes.
    Bytes(usize),
}
876
877struct FileContent(Arc<Mutex<Vec<u8>>>);
878
879impl UserData for FileContent {
880 fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
881 // FileContent doesn't have any methods so far.
882 }
883}
884
/// Identifies a script started in a `ScriptingSession`.
///
/// The wrapped value is the script's index in the session's script list.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
887
888pub struct Script {
889 pub state: ScriptState,
890}
891
892#[derive(Debug)]
893pub enum ScriptState {
894 Running {
895 stdout: Arc<Mutex<String>>,
896 },
897 Succeeded {
898 stdout: String,
899 },
900 Failed {
901 stdout: String,
902 error: anyhow::Error,
903 },
904}
905
906impl Script {
907 /// If exited, returns a message with the output for the LLM
908 pub fn output_message_for_llm(&self) -> Option<String> {
909 match &self.state {
910 ScriptState::Running { .. } => None,
911 ScriptState::Succeeded { stdout } => {
912 format!("Here's the script output:\n{}", stdout).into()
913 }
914 ScriptState::Failed { stdout, error } => format!(
915 "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
916 error, stdout
917 )
918 .into(),
919 }
920 }
921
922 /// Get a snapshot of the script's stdout
923 pub fn stdout_snapshot(&self) -> String {
924 match &self.state {
925 ScriptState::Running { stdout } => stdout.lock().clone(),
926 ScriptState::Succeeded { stdout } => stdout.clone(),
927 ScriptState::Failed { stdout, .. } => stdout.clone(),
928 }
929 }
930}
931
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use super::*;

    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let script = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    // search

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let script = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print(" " .. match)
                end
            end
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            concat!("File: ", path!("/file1.txt"), "\nMatches:\n world\n")
        );
    }

    // io.open

    #[gpui::test]
    async fn test_open_and_read_file(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("file1.txt", "r")
            local content = file:read()
            print("Content:", content)
            file:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(output, "Content:\tHello world!\n");
        // Reading must not register the buffer as changed.
        assert_eq!(test_session.diff(cx), Vec::new());
    }

    #[gpui::test]
    async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("file1.txt", "w")
            file:write("This is new content")
            file:close()

            -- Read back to verify
            local read_file = io.open("file1.txt", "r")
            local content = read_file:read("*a")
            print("Written content:", content)
            read_file:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(output, "Written content:\tThis is new content\n");
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("file1.txt"),
                vec![(
                    "Hello world!\n".to_string(),
                    "This is new content".to_string()
                )]
            )]
        );
    }

    #[gpui::test]
    async fn test_multiple_writes(cx: &mut TestAppContext) {
        let script = r#"
            -- Test writing to a file multiple times
            local file = io.open("multiwrite.txt", "w")
            file:write("First line\n")
            file:write("Second line\n")
            file:write("Third line")
            file:close()

            -- Read back to verify
            local read_file = io.open("multiwrite.txt", "r")
            if read_file then
                local content = read_file:read("*a")
                print("Full content:", content)
                read_file:close()
            end
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            "Full content:\tFirst line\nSecond line\nThird line\n"
        );
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("multiwrite.txt"),
                vec![(
                    "".to_string(),
                    "First line\nSecond line\nThird line".to_string()
                )]
            )]
        );
    }

    #[gpui::test]
    async fn test_multiple_writes_diff_handles(cx: &mut TestAppContext) {
        let script = r#"
            -- Write to a file
            local file1 = io.open("multi_open.txt", "w")
            file1:write("Content written by first handle\n")
            file1:close()

            -- Open it again and add more content
            local file2 = io.open("multi_open.txt", "w")
            file2:write("Content written by second handle\n")
            file2:close()

            -- Open it a third time and read
            local file3 = io.open("multi_open.txt", "r")
            local content = file3:read("*a")
            print("Final content:", content)
            file3:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            "Final content:\tContent written by second handle\n\n"
        );
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("multi_open.txt"),
                vec![(
                    "".to_string(),
                    "Content written by second handle\n".to_string()
                )]
            )]
        );
    }

    #[gpui::test]
    async fn test_append_mode(cx: &mut TestAppContext) {
        let script = r#"
            -- Append more content
            file = io.open("file1.txt", "a")
            file:write("Appended content\n")
            file:close()

            -- Add even more
            file = io.open("file1.txt", "a")
            file:write("More appended content")
            file:close()

            -- Read back to verify
            local read_file = io.open("file1.txt", "r")
            local content = read_file:read("*a")
            print("Content after appends:", content)
            read_file:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            "Content after appends:\tHello world!\nAppended content\nMore appended content\n"
        );
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("file1.txt"),
                vec![(
                    "".to_string(),
                    "Appended content\nMore appended content".to_string()
                )]
            )]
        );
    }

    #[gpui::test]
    async fn test_read_formats(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("multiline.txt", "w")
            file:write("Line 1\nLine 2\nLine 3")
            file:close()

            -- Test "*a" (all)
            local f = io.open("multiline.txt", "r")
            local all = f:read("*a")
            print("All:", all)
            f:close()

            -- Test "*l" (line)
            f = io.open("multiline.txt", "r")
            local line1 = f:read("*l")
            local line2 = f:read("*l")
            local line3 = f:read("*l")
            print("Line 1:", line1)
            print("Line 2:", line2)
            print("Line 3:", line3)
            f:close()

            -- Test "*L" (line with newline)
            f = io.open("multiline.txt", "r")
            local line_with_nl = f:read("*L")
            print("Line with newline length:", #line_with_nl)
            print("Last char:", string.byte(line_with_nl, #line_with_nl))
            f:close()

            -- Test number of bytes
            f = io.open("multiline.txt", "r")
            local bytes5 = f:read(5)
            print("5 bytes:", bytes5)
            f:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
        assert!(output.contains("Line 1:\tLine 1"));
        assert!(output.contains("Line 2:\tLine 2"));
        assert!(output.contains("Line 3:\tLine 3"));
        assert!(output.contains("Line with newline length:\t7"));
        assert!(output.contains("Last char:\t10")); // LF
        assert!(output.contains("5 bytes:\tLine "));
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("multiline.txt"),
                vec![("".to_string(), "Line 1\nLine 2\nLine 3".to_string())]
            )]
        );
    }

    // helpers

    /// A `ScriptingSession` over a fake file system rooted at `/`, containing `file1.txt`
    /// and `file2.txt`.
    struct TestSession {
        session: Entity<ScriptingSession>,
    }

    impl TestSession {
        /// Sets up settings, the fake file system, the project and the session.
        async fn init(cx: &mut TestAppContext) -> Self {
            let settings_store = cx.update(SettingsStore::test);
            cx.set_global(settings_store);
            cx.update(Project::init_settings);
            cx.update(language::init);

            let fs = FakeFs::new(cx.executor());
            fs.insert_tree(
                path!("/"),
                json!({
                    "file1.txt": "Hello world!\n",
                    "file2.txt": "Goodbye moon!\n"
                }),
            )
            .await;

            let project = Project::test(fs.clone(), [Path::new(path!("/"))], cx).await;
            let session = cx.new(|cx| ScriptingSession::new(project, cx));

            TestSession { session }
        }

        /// Runs `source` to completion and returns its stdout, panicking if the script failed.
        async fn test_success(&self, source: &str, cx: &mut TestAppContext) -> String {
            let script_id = self.run_script(source, cx).await;

            self.session.read_with(cx, |session, _cx| {
                let script = session.get(script_id);
                let stdout = script.stdout_snapshot();

                if let ScriptState::Failed { error, .. } = &script.state {
                    panic!("Script failed:\n{}\n\n{}", error, stdout);
                }

                stdout
            })
        }

        /// Returns, for each changed buffer, its worktree-relative path and its
        /// `(old_text, new_text)` hunks.
        fn diff(&self, cx: &mut TestAppContext) -> Vec<(PathBuf, Vec<(String, String)>)> {
            self.session.read_with(cx, |session, cx| {
                session
                    .changes_by_buffer
                    .iter()
                    .map(|(buffer, changes)| {
                        let snapshot = buffer.read(cx).snapshot();
                        let diff = changes.diff.read(cx);
                        let hunks = diff.hunks(&snapshot, cx);
                        let path = buffer.read(cx).file().unwrap().path().clone();
                        let diffs = hunks
                            .map(|hunk| {
                                let old_text = diff
                                    .base_text()
                                    .text_for_range(hunk.diff_base_byte_range)
                                    .collect::<String>();
                                let new_text =
                                    snapshot.text_for_range(hunk.range).collect::<String>();
                                (old_text, new_text)
                            })
                            .collect();
                        (path.to_path_buf(), diffs)
                    })
                    .collect()
            })
        }

        /// Starts `source` in the session and waits until its final state is recorded.
        async fn run_script(&self, source: &str, cx: &mut TestAppContext) -> ScriptId {
            let (script_id, task) = self
                .session
                .update(cx, |session, cx| session.run_script(source.to_string(), cx));

            task.await;

            script_id
        }
    }
}