1use anyhow::anyhow;
2use collections::HashSet;
3use futures::{
4 channel::{mpsc, oneshot},
5 pin_mut, SinkExt, StreamExt,
6};
7use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
8use language::Buffer;
9use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
10use parking_lot::Mutex;
11use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
12use regex::Regex;
13use std::{
14 path::{Path, PathBuf},
15 sync::Arc,
16};
17use util::{paths::PathMatcher, ResultExt};
18
/// A unit of work queued from Lua callbacks (which run inside `cx.background_spawn`) to the
/// session's foreground task. It is invoked with a weak handle to the session and an
/// `AsyncApp`, so it can update entities such as the project and its buffers.
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
20
/// A Lua scripting session bound to a single [`Project`].
///
/// Scripts are run in a sandbox whose file and search functions operate on the project's
/// first visible worktree.
pub struct ScriptingSession {
    /// Project whose worktrees, buffers and file system the scripts operate on.
    project: Entity<Project>,
    /// Every script started in this session; a [`ScriptId`] is an index into this list.
    scripts: Vec<Script>,
    /// Buffers that scripts have written to and successfully saved.
    changed_buffers: HashSet<Entity<Buffer>>,
    /// Sending half of the queue drained by `_invoke_foreground_fns`.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Task that runs queued [`ForegroundFn`]s; held for the session's lifetime and never
    /// read directly.
    _invoke_foreground_fns: Task<()>,
}
28
29impl ScriptingSession {
30 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
31 let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
32 ScriptingSession {
33 project,
34 scripts: Vec::new(),
35 changed_buffers: HashSet::default(),
36 foreground_fns_tx,
37 _invoke_foreground_fns: cx.spawn(|this, cx| async move {
38 while let Some(foreground_fn) = foreground_fns_rx.next().await {
39 foreground_fn.0(this.clone(), cx.clone());
40 }
41 }),
42 }
43 }
44
45 pub fn changed_buffers(&self) -> impl ExactSizeIterator<Item = &Entity<Buffer>> {
46 self.changed_buffers.iter()
47 }
48
49 pub fn run_script(
50 &mut self,
51 script_src: String,
52 cx: &mut Context<Self>,
53 ) -> (ScriptId, Task<()>) {
54 let id = ScriptId(self.scripts.len() as u32);
55
56 let stdout = Arc::new(Mutex::new(String::new()));
57
58 let script = Script {
59 state: ScriptState::Running {
60 stdout: stdout.clone(),
61 },
62 };
63 self.scripts.push(script);
64
65 let task = self.run_lua(script_src, stdout, cx);
66
67 let task = cx.spawn(|session, mut cx| async move {
68 let result = task.await;
69
70 session
71 .update(&mut cx, |session, _cx| {
72 let script = session.get_mut(id);
73 let stdout = script.stdout_snapshot();
74
75 script.state = match result {
76 Ok(()) => ScriptState::Succeeded { stdout },
77 Err(error) => ScriptState::Failed { stdout, error },
78 };
79 })
80 .log_err();
81 });
82
83 (id, task)
84 }
85
86 fn run_lua(
87 &mut self,
88 script: String,
89 stdout: Arc<Mutex<String>>,
90 cx: &mut Context<Self>,
91 ) -> Task<anyhow::Result<()>> {
92 const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
93
94 // TODO Honor all worktrees instead of the first one
95 let worktree_info = self
96 .project
97 .read(cx)
98 .visible_worktrees(cx)
99 .next()
100 .map(|worktree| {
101 let worktree = worktree.read(cx);
102 (worktree.id(), worktree.abs_path())
103 });
104
105 let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());
106
107 let fs = self.project.read(cx).fs().clone();
108 let foreground_fns_tx = self.foreground_fns_tx.clone();
109
110 let task = cx.background_spawn({
111 let stdout = stdout.clone();
112
113 async move {
114 let lua = Lua::new();
115 lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
116 let globals = lua.globals();
117
118 // Use the project root dir as the script's current working dir.
119 if let Some(root_dir) = &root_dir {
120 if let Some(root_dir) = root_dir.to_str() {
121 globals.set("cwd", root_dir)?;
122 }
123 }
124
125 globals.set(
126 "sb_print",
127 lua.create_function({
128 let stdout = stdout.clone();
129 move |_, args: MultiValue| Self::print(args, &stdout)
130 })?,
131 )?;
132 globals.set(
133 "search",
134 lua.create_async_function({
135 let foreground_fns_tx = foreground_fns_tx.clone();
136 let fs = fs.clone();
137 move |lua, regex| {
138 let mut foreground_fns_tx = foreground_fns_tx.clone();
139 let fs = fs.clone();
140 async move {
141 Self::search(&lua, &mut foreground_fns_tx, fs, regex)
142 .await
143 .into_lua_err()
144 }
145 }
146 })?,
147 )?;
148 globals.set(
149 "outline",
150 lua.create_async_function({
151 let root_dir = root_dir.clone();
152 let foreground_fns_tx = foreground_fns_tx.clone();
153 move |_lua, path| {
154 let mut foreground_fns_tx = foreground_fns_tx.clone();
155 let root_dir = root_dir.clone();
156 async move {
157 Self::outline(root_dir, &mut foreground_fns_tx, path)
158 .await
159 .into_lua_err()
160 }
161 }
162 })?,
163 )?;
164 globals.set(
165 "sb_io_open",
166 lua.create_async_function({
167 let worktree_info = worktree_info.clone();
168 let foreground_fns_tx = foreground_fns_tx.clone();
169 move |lua, (path_str, mode)| {
170 let worktree_info = worktree_info.clone();
171 let mut foreground_fns_tx = foreground_fns_tx.clone();
172 let fs = fs.clone();
173 async move {
174 Self::io_open(
175 &lua,
176 worktree_info,
177 &mut foreground_fns_tx,
178 fs,
179 path_str,
180 mode,
181 )
182 .await
183 }
184 }
185 })?,
186 )?;
187 globals.set("user_script", script)?;
188
189 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
190
191 // Drop Lua instance to decrement reference count.
192 drop(lua);
193
194 anyhow::Ok(())
195 }
196 });
197
198 task
199 }
200
201 pub fn get(&self, script_id: ScriptId) -> &Script {
202 &self.scripts[script_id.0 as usize]
203 }
204
205 fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
206 &mut self.scripts[script_id.0 as usize]
207 }
208
209 /// Sandboxed print() function in Lua.
210 fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
211 for (index, arg) in args.into_iter().enumerate() {
212 // Lua's `print()` prints tab characters between each argument.
213 if index > 0 {
214 stdout.lock().push('\t');
215 }
216
217 // If the argument's to_string() fails, have the whole function call fail.
218 stdout.lock().push_str(&arg.to_string()?);
219 }
220 stdout.lock().push('\n');
221
222 Ok(())
223 }
224
225 /// Sandboxed io.open() function in Lua.
226 async fn io_open(
227 lua: &Lua,
228 worktree_info: Option<(WorktreeId, Arc<Path>)>,
229 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
230 fs: Arc<dyn Fs>,
231 path_str: String,
232 mode: Option<String>,
233 ) -> mlua::Result<(Option<Table>, String)> {
234 let (worktree_id, root_dir) = worktree_info
235 .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
236
237 let mode = mode.unwrap_or_else(|| "r".to_string());
238
239 // Parse the mode string to determine read/write permissions
240 let read_perm = mode.contains('r');
241 let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
242 let append = mode.contains('a');
243 let truncate = mode.contains('w');
244
245 // This will be the Lua value returned from the `open` function.
246 let file = lua.create_table()?;
247
248 // Store file metadata in the file
249 file.set("__mode", mode.clone())?;
250 file.set("__read_perm", read_perm)?;
251 file.set("__write_perm", write_perm)?;
252
253 let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
254 Ok(path) => path,
255 Err(err) => return Ok((None, format!("{err}"))),
256 };
257
258 let project_path = ProjectPath {
259 worktree_id,
260 path: Path::new(&path_str).into(),
261 };
262
263 // flush / close method
264 let flush_fn = {
265 let project_path = project_path.clone();
266 let foreground_tx = foreground_tx.clone();
267 lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
268 let project_path = project_path.clone();
269 let mut foreground_tx = foreground_tx.clone();
270 async move {
271 Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
272 }
273 })?
274 };
275 file.set("flush", flush_fn.clone())?;
276 // We don't really hold files open, so we only need to flush on close
277 file.set("close", flush_fn)?;
278
279 // If it's a directory, give it a custom read() and return early.
280 if fs.is_dir(&path).await {
281 return Self::io_file_dir(lua, fs, file, &path).await;
282 }
283
284 let mut file_content = Vec::new();
285
286 if !truncate {
287 // Try to read existing content if we're not truncating
288 match Self::read_buffer(project_path.clone(), foreground_tx).await {
289 Ok(content) => file_content = content.into_bytes(),
290 Err(e) => return Ok((None, format!("Error reading file: {}", e))),
291 }
292 }
293
294 // If in append mode, position should be at the end
295 let position = if append { file_content.len() } else { 0 };
296 file.set("__position", position)?;
297 file.set(
298 "__content",
299 lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
300 )?;
301
302 // Create file methods
303
304 // read method
305 let read_fn = lua.create_function(Self::io_file_read)?;
306 file.set("read", read_fn)?;
307
308 // write method
309 let write_fn = lua.create_function(Self::io_file_write)?;
310 file.set("write", write_fn)?;
311
312 // If we got this far, the file was opened successfully
313 Ok((Some(file), String::new()))
314 }
315
316 async fn read_buffer(
317 project_path: ProjectPath,
318 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
319 ) -> anyhow::Result<String> {
320 Self::run_foreground_fn(
321 "read file from buffer",
322 foreground_tx,
323 Box::new(move |session, mut cx| {
324 session.update(&mut cx, |session, cx| {
325 let open_buffer_task = session
326 .project
327 .update(cx, |project, cx| project.open_buffer(project_path, cx));
328
329 cx.spawn(|_, cx| async move {
330 let buffer = open_buffer_task.await?;
331
332 let text = buffer.read_with(&cx, |buffer, _cx| buffer.text())?;
333 Ok(text)
334 })
335 })
336 }),
337 )
338 .await??
339 .await
340 }
341
342 async fn io_file_flush(
343 file_userdata: mlua::Table,
344 project_path: ProjectPath,
345 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
346 ) -> mlua::Result<bool> {
347 let write_perm = file_userdata.get::<bool>("__write_perm")?;
348
349 if write_perm {
350 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
351 let content_ref = content.borrow::<FileContent>()?;
352 let text = {
353 let mut content_vec = content_ref.0.lock();
354 let content_vec = std::mem::take(&mut *content_vec);
355 String::from_utf8(content_vec).into_lua_err()?
356 };
357
358 Self::write_to_buffer(project_path, text, foreground_tx)
359 .await
360 .into_lua_err()?;
361 }
362
363 Ok(true)
364 }
365
366 async fn write_to_buffer(
367 project_path: ProjectPath,
368 text: String,
369 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
370 ) -> anyhow::Result<()> {
371 Self::run_foreground_fn(
372 "write to buffer",
373 foreground_tx,
374 Box::new(move |session, mut cx| {
375 session.update(&mut cx, |session, cx| {
376 let open_buffer_task = session
377 .project
378 .update(cx, |project, cx| project.open_buffer(project_path, cx));
379
380 cx.spawn(move |session, mut cx| async move {
381 let buffer = open_buffer_task.await?;
382
383 let diff = buffer
384 .update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
385 .await;
386
387 buffer.update(&mut cx, |buffer, cx| {
388 buffer.apply_diff(diff, cx);
389 })?;
390
391 session
392 .update(&mut cx, {
393 let buffer = buffer.clone();
394
395 |session, cx| {
396 session
397 .project
398 .update(cx, |project, cx| project.save_buffer(buffer, cx))
399 }
400 })?
401 .await?;
402
403 // If we saved successfully, mark buffer as changed
404 session.update(&mut cx, |session, _cx| {
405 session.changed_buffers.insert(buffer);
406 })
407 })
408 })
409 }),
410 )
411 .await??
412 .await
413 }
414
415 async fn io_file_dir(
416 lua: &Lua,
417 fs: Arc<dyn Fs>,
418 file: Table,
419 path: &Path,
420 ) -> mlua::Result<(Option<Table>, String)> {
421 // Create a special directory handle
422 file.set("__is_directory", true)?;
423
424 // Store directory entries
425 let entries = match fs.read_dir(&path).await {
426 Ok(entries) => {
427 let mut entry_names = Vec::new();
428
429 // Process the stream of directory entries
430 pin_mut!(entries);
431 while let Some(Ok(entry_result)) = entries.next().await {
432 if let Some(file_name) = entry_result.file_name() {
433 entry_names.push(file_name.to_string_lossy().into_owned());
434 }
435 }
436
437 entry_names
438 }
439 Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
440 };
441
442 // Save the list of entries
443 file.set("__dir_entries", entries)?;
444 file.set("__dir_position", 0usize)?;
445
446 // Create a directory-specific read function
447 let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
448 let position = file_userdata.get::<usize>("__dir_position")?;
449 let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
450
451 if position >= entries.len() {
452 return Ok(None); // No more entries
453 }
454
455 let entry = entries[position].clone();
456 file_userdata.set("__dir_position", position + 1)?;
457
458 Ok(Some(entry))
459 })?;
460 file.set("read", read_fn)?;
461
462 // If we got this far, the directory was opened successfully
463 return Ok((Some(file), String::new()));
464 }
465
466 fn io_file_read(
467 lua: &Lua,
468 (file_userdata, format): (Table, Option<mlua::Value>),
469 ) -> mlua::Result<Option<mlua::String>> {
470 let read_perm = file_userdata.get::<bool>("__read_perm")?;
471 if !read_perm {
472 return Err(mlua::Error::runtime("File not open for reading"));
473 }
474
475 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
476 let position = file_userdata.get::<usize>("__position")?;
477 let content_ref = content.borrow::<FileContent>()?;
478 let content = content_ref.0.lock();
479
480 if position >= content.len() {
481 return Ok(None); // EOF
482 }
483
484 let (result, new_position) = match Self::io_file_read_format(format)? {
485 FileReadFormat::All => {
486 // Read entire file from current position
487 let result = content[position..].to_vec();
488 (Some(result), content.len())
489 }
490 FileReadFormat::Line => {
491 if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
492 {
493 let mut line = content[position..position + next_newline_ix].to_vec();
494 if line.ends_with(b"\r") {
495 line.pop();
496 }
497 (Some(line), position + next_newline_ix + 1)
498 } else if position < content.len() {
499 let line = content[position..].to_vec();
500 (Some(line), content.len())
501 } else {
502 (None, position) // EOF
503 }
504 }
505 FileReadFormat::LineWithLineFeed => {
506 if position < content.len() {
507 let next_line_ix = content[position..]
508 .iter()
509 .position(|c| *c == b'\n')
510 .map_or(content.len(), |ix| position + ix + 1);
511 let line = content[position..next_line_ix].to_vec();
512 (Some(line), next_line_ix)
513 } else {
514 (None, position) // EOF
515 }
516 }
517 FileReadFormat::Bytes(n) => {
518 let end = std::cmp::min(position + n, content.len());
519 let result = content[position..end].to_vec();
520 (Some(result), end)
521 }
522 };
523
524 // Update the position in the file userdata
525 if new_position != position {
526 file_userdata.set("__position", new_position)?;
527 }
528
529 // Convert the result to a Lua string
530 match result {
531 Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
532 None => Ok(None),
533 }
534 }
535
536 fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
537 let format = match format {
538 Some(mlua::Value::String(s)) => {
539 let lossy_string = s.to_string_lossy();
540 let format_str: &str = lossy_string.as_ref();
541
542 // Only consider the first 2 bytes, since it's common to pass e.g. "*all" instead of "*a"
543 match &format_str[0..2] {
544 "*a" => FileReadFormat::All,
545 "*l" => FileReadFormat::Line,
546 "*L" => FileReadFormat::LineWithLineFeed,
547 "*n" => {
548 // Try to parse as a number (number of bytes to read)
549 match format_str.parse::<usize>() {
550 Ok(n) => FileReadFormat::Bytes(n),
551 Err(_) => {
552 return Err(mlua::Error::runtime(format!(
553 "Invalid format: {}",
554 format_str
555 )))
556 }
557 }
558 }
559 _ => {
560 return Err(mlua::Error::runtime(format!(
561 "Unsupported format: {}",
562 format_str
563 )))
564 }
565 }
566 }
567 Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
568 Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
569 Some(value) => {
570 return Err(mlua::Error::runtime(format!(
571 "Invalid file format {:?}",
572 value
573 )))
574 }
575 None => FileReadFormat::Line, // Default is to read a line
576 };
577
578 Ok(format)
579 }
580
581 fn io_file_write(
582 _lua: &Lua,
583 (file_userdata, text): (Table, mlua::String),
584 ) -> mlua::Result<bool> {
585 let write_perm = file_userdata.get::<bool>("__write_perm")?;
586 if !write_perm {
587 return Err(mlua::Error::runtime("File not open for writing"));
588 }
589
590 let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
591 let position = file_userdata.get::<usize>("__position")?;
592 let content_ref = content.borrow::<FileContent>()?;
593 let mut content_vec = content_ref.0.lock();
594
595 let bytes = text.as_bytes();
596
597 // Ensure the vector has enough capacity
598 if position + bytes.len() > content_vec.len() {
599 content_vec.resize(position + bytes.len(), 0);
600 }
601
602 // Write the bytes
603 for (i, &byte) in bytes.iter().enumerate() {
604 content_vec[position + i] = byte;
605 }
606
607 // Update position
608 let new_position = position + bytes.len();
609 file_userdata.set("__position", new_position)?;
610
611 Ok(true)
612 }
613
614 async fn search(
615 lua: &Lua,
616 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
617 fs: Arc<dyn Fs>,
618 regex: String,
619 ) -> anyhow::Result<Table> {
620 // TODO: Allow specification of these options.
621 let search_query = SearchQuery::regex(
622 ®ex,
623 false,
624 false,
625 false,
626 PathMatcher::default(),
627 PathMatcher::default(),
628 None,
629 );
630 let search_query = match search_query {
631 Ok(query) => query,
632 Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
633 };
634
635 // TODO: Should use `search_query.regex`. The tool description should also be updated,
636 // as it specifies standard regex.
637 let search_regex = match Regex::new(®ex) {
638 Ok(re) => re,
639 Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
640 };
641
642 let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
643
644 let mut search_results: Vec<Table> = Vec::new();
645 while let Some(path) = abs_paths_rx.next().await {
646 // Skip files larger than 1MB
647 if let Ok(Some(metadata)) = fs.metadata(&path).await {
648 if metadata.len > 1_000_000 {
649 continue;
650 }
651 }
652
653 // Attempt to read the file as text
654 if let Ok(content) = fs.load(&path).await {
655 let mut matches = Vec::new();
656
657 // Find all regex matches in the content
658 for capture in search_regex.find_iter(&content) {
659 matches.push(capture.as_str().to_string());
660 }
661
662 // If we found matches, create a result entry
663 if !matches.is_empty() {
664 let result_entry = lua.create_table()?;
665 result_entry.set("path", path.to_string_lossy().to_string())?;
666
667 let matches_table = lua.create_table()?;
668 for (ix, m) in matches.iter().enumerate() {
669 matches_table.set(ix + 1, m.clone())?;
670 }
671 result_entry.set("matches", matches_table)?;
672
673 search_results.push(result_entry);
674 }
675 }
676 }
677
678 // Create a table to hold our results
679 let results_table = lua.create_table()?;
680 for (ix, entry) in search_results.into_iter().enumerate() {
681 results_table.set(ix + 1, entry)?;
682 }
683
684 Ok(results_table)
685 }
686
687 async fn find_search_candidates(
688 search_query: SearchQuery,
689 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
690 ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
691 Self::run_foreground_fn(
692 "finding search file candidates",
693 foreground_tx,
694 Box::new(move |session, mut cx| {
695 session.update(&mut cx, |session, cx| {
696 session.project.update(cx, |project, cx| {
697 project.worktree_store().update(cx, |worktree_store, cx| {
698 // TODO: Better limit? For now this is the same as
699 // MAX_SEARCH_RESULT_FILES.
700 let limit = 5000;
701 // TODO: Providing non-empty open_entries can make this a bit more
702 // efficient as it can skip checking that these paths are textual.
703 let open_entries = HashSet::default();
704 let candidates = worktree_store.find_search_candidates(
705 search_query,
706 limit,
707 open_entries,
708 project.fs().clone(),
709 cx,
710 );
711 let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
712 cx.spawn(|worktree_store, cx| async move {
713 pin_mut!(candidates);
714
715 while let Some(project_path) = candidates.next().await {
716 worktree_store.read_with(&cx, |worktree_store, cx| {
717 if let Some(worktree) = worktree_store
718 .worktree_for_id(project_path.worktree_id, cx)
719 {
720 if let Some(abs_path) = worktree
721 .read(cx)
722 .absolutize(&project_path.path)
723 .log_err()
724 {
725 abs_paths_tx.unbounded_send(abs_path)?;
726 }
727 }
728 anyhow::Ok(())
729 })??;
730 }
731 anyhow::Ok(())
732 })
733 .detach();
734 abs_paths_rx
735 })
736 })
737 })
738 }),
739 )
740 .await?
741 }
742
743 async fn outline(
744 root_dir: Option<Arc<Path>>,
745 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
746 path_str: String,
747 ) -> anyhow::Result<String> {
748 let root_dir = root_dir
749 .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
750 let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
751 let outline = Self::run_foreground_fn(
752 "getting code outline",
753 foreground_tx,
754 Box::new(move |session, cx| {
755 cx.spawn(move |mut cx| async move {
756 // TODO: This will not use file content from `fs_changes`. It will also reflect
757 // user changes that have not been saved.
758 let buffer = session
759 .update(&mut cx, |session, cx| {
760 session
761 .project
762 .update(cx, |project, cx| project.open_local_buffer(&path, cx))
763 })?
764 .await?;
765 buffer.update(&mut cx, |buffer, _cx| {
766 if let Some(outline) = buffer.snapshot().outline(None) {
767 Ok(outline)
768 } else {
769 Err(anyhow!("No outline for file {path_str}"))
770 }
771 })
772 })
773 }),
774 )
775 .await?
776 .await??;
777
778 Ok(outline
779 .items
780 .into_iter()
781 .map(|item| {
782 if item.text.contains('\n') {
783 log::error!("Outline item unexpectedly contains newline");
784 }
785 format!("{}{}", " ".repeat(item.depth), item.text)
786 })
787 .collect::<Vec<String>>()
788 .join("\n"))
789 }
790
791 async fn run_foreground_fn<R: Send + 'static>(
792 description: &str,
793 foreground_tx: &mut mpsc::Sender<ForegroundFn>,
794 function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
795 ) -> anyhow::Result<R> {
796 let (response_tx, response_rx) = oneshot::channel();
797 let send_result = foreground_tx
798 .send(ForegroundFn(Box::new(move |this, cx| {
799 response_tx.send(function(this, cx)).ok();
800 })))
801 .await;
802 match send_result {
803 Ok(()) => (),
804 Err(err) => {
805 return Err(anyhow::Error::new(err).context(format!(
806 "Internal error while enqueuing work for {description}"
807 )));
808 }
809 }
810 match response_rx.await {
811 Ok(result) => Ok(result),
812 Err(oneshot::Canceled) => Err(anyhow!(
813 "Internal error: response oneshot was canceled while {description}."
814 )),
815 }
816 }
817
818 fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
819 let path = Path::new(&path_str);
820 if path.is_absolute() {
821 // Check if path starts with root_dir prefix without resolving symlinks
822 if path.starts_with(&root_dir) {
823 Ok(path.to_path_buf())
824 } else {
825 Err(anyhow!(
826 "Error: Absolute path {} is outside the current working directory",
827 path_str
828 ))
829 }
830 } else {
831 // TODO: Does use of `../` break sandbox - is path canonicalization needed?
832 Ok(root_dir.join(path))
833 }
834 }
835}
836
/// Parsed form of the `format` argument of a Lua `file:read(...)` call.
enum FileReadFormat {
    /// Everything from the current position to the end of the content ("*a").
    All,
    /// The next line without its trailing `\n` (and any `\r` before it) ("*l", the default).
    Line,
    /// The next line including its trailing `\n` ("*L").
    LineWithLineFeed,
    /// Up to this many bytes (numeric format argument).
    Bytes(usize),
}
843
/// In-memory bytes of a file opened through the sandboxed `io.open`; stored as userdata
/// under the file table's `__content` key and shared through `Arc<Mutex<..>>`.
struct FileContent(Arc<Mutex<Vec<u8>>>);
845
/// Lets `FileContent` be wrapped as Lua userdata; Lua code accesses it only through the
/// Rust-side file methods, so no Lua-visible methods are registered.
impl UserData for FileContent {
    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
        // FileContent doesn't have any methods so far.
    }
}
851
/// Identifies a script within its `ScriptingSession`; wraps the index of the script in the
/// session's script list.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
854
/// A script started by a `ScriptingSession`.
pub struct Script {
    /// Whether the script is still running, succeeded, or failed, plus its output.
    pub state: ScriptState,
}
858
/// Lifecycle state of a [`Script`].
#[derive(Debug)]
pub enum ScriptState {
    /// The script is executing; `stdout` is the shared buffer its `print` calls append to.
    Running {
        stdout: Arc<Mutex<String>>,
    },
    /// The script finished without error; `stdout` is the final captured output.
    Succeeded {
        stdout: String,
    },
    /// The script returned an error; `stdout` is what it printed before failing.
    Failed {
        stdout: String,
        error: anyhow::Error,
    },
}
872
873impl Script {
874 /// If exited, returns a message with the output for the LLM
875 pub fn output_message_for_llm(&self) -> Option<String> {
876 match &self.state {
877 ScriptState::Running { .. } => None,
878 ScriptState::Succeeded { stdout } => {
879 format!("Here's the script output:\n{}", stdout).into()
880 }
881 ScriptState::Failed { stdout, error } => format!(
882 "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
883 error, stdout
884 )
885 .into(),
886 }
887 }
888
889 /// Get a snapshot of the script's stdout
890 pub fn stdout_snapshot(&self) -> String {
891 match &self.state {
892 ScriptState::Running { stdout } => stdout.lock().clone(),
893 ScriptState::Succeeded { stdout } => stdout.clone(),
894 ScriptState::Failed { stdout, .. } => stdout.clone(),
895 }
896 }
897}
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use super::*;

    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let script = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    // search

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let script = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print(" " .. match)
                end
            end
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            concat!("File: ", path!("/file1.txt"), "\nMatches:\n world\n")
        );
    }

    // io.open

    #[gpui::test]
    async fn test_open_and_read_file(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("file1.txt", "r")
            local content = file:read()
            print("Content:", content)
            file:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(output, "Content:\tHello world!\n");

        // Only read, should not be marked as changed
        assert!(!test_session.was_marked_changed("file1.txt", cx));
    }

    #[gpui::test]
    async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("file1.txt", "w")
            file:write("This is new content")
            file:close()

            -- Read back to verify
            local read_file = io.open("file1.txt", "r")
            local content = read_file:read("*a")
            print("Written content:", content)
            read_file:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(output, "Written content:\tThis is new content\n");
        assert!(test_session.was_marked_changed("file1.txt", cx));
    }

    #[gpui::test]
    async fn test_multiple_writes(cx: &mut TestAppContext) {
        let script = r#"
            -- Test writing to a file multiple times
            local file = io.open("multiwrite.txt", "w")
            file:write("First line\n")
            file:write("Second line\n")
            file:write("Third line")
            file:close()

            -- Read back to verify
            local read_file = io.open("multiwrite.txt", "r")
            if read_file then
                local content = read_file:read("*a")
                print("Full content:", content)
                read_file:close()
            end
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            "Full content:\tFirst line\nSecond line\nThird line\n"
        );
        assert!(test_session.was_marked_changed("multiwrite.txt", cx));
    }

    #[gpui::test]
    async fn test_multiple_writes_diff_handles(cx: &mut TestAppContext) {
        let script = r#"
            -- Write to a file
            local file1 = io.open("multi_open.txt", "w")
            file1:write("Content written by first handle\n")
            file1:close()

            -- Open it again and add more content
            local file2 = io.open("multi_open.txt", "w")
            file2:write("Content written by second handle\n")
            file2:close()

            -- Open it a third time and read
            local file3 = io.open("multi_open.txt", "r")
            local content = file3:read("*a")
            print("Final content:", content)
            file3:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            "Final content:\tContent written by second handle\n\n"
        );
        assert!(test_session.was_marked_changed("multi_open.txt", cx));
    }

    #[gpui::test]
    async fn test_append_mode(cx: &mut TestAppContext) {
        let script = r#"
            -- Test append mode
            local file = io.open("append.txt", "w")
            file:write("Initial content\n")
            file:close()

            -- Append more content
            file = io.open("append.txt", "a")
            file:write("Appended content\n")
            file:close()

            -- Add even more
            file = io.open("append.txt", "a")
            file:write("More appended content")
            file:close()

            -- Read back to verify
            local read_file = io.open("append.txt", "r")
            local content = read_file:read("*a")
            print("Content after appends:", content)
            read_file:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            "Content after appends:\tInitial content\nAppended content\nMore appended content\n"
        );
        assert!(test_session.was_marked_changed("append.txt", cx));
    }

    #[gpui::test]
    async fn test_read_formats(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("multiline.txt", "w")
            file:write("Line 1\nLine 2\nLine 3")
            file:close()

            -- Test "*a" (all)
            local f = io.open("multiline.txt", "r")
            local all = f:read("*a")
            print("All:", all)
            f:close()

            -- Test "*l" (line)
            f = io.open("multiline.txt", "r")
            local line1 = f:read("*l")
            local line2 = f:read("*l")
            local line3 = f:read("*l")
            print("Line 1:", line1)
            print("Line 2:", line2)
            print("Line 3:", line3)
            f:close()

            -- Test "*L" (line with newline)
            f = io.open("multiline.txt", "r")
            local line_with_nl = f:read("*L")
            print("Line with newline length:", #line_with_nl)
            print("Last char:", string.byte(line_with_nl, #line_with_nl))
            f:close()

            -- Test number of bytes
            f = io.open("multiline.txt", "r")
            local bytes5 = f:read(5)
            print("5 bytes:", bytes5)
            f:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
        assert!(output.contains("Line 1:\tLine 1"));
        assert!(output.contains("Line 2:\tLine 2"));
        assert!(output.contains("Line 3:\tLine 3"));
        assert!(output.contains("Line with newline length:\t7"));
        assert!(output.contains("Last char:\t10")); // LF
        assert!(output.contains("5 bytes:\tLine "));
        assert!(test_session.was_marked_changed("multiline.txt", cx));
    }

    // helpers

    /// A scripting session over a fake file system rooted at `/` containing
    /// `file1.txt` and `file2.txt`.
    struct TestSession {
        session: Entity<ScriptingSession>,
    }

    impl TestSession {
        /// Initializes settings and language support, builds the fake project, and
        /// creates the session.
        async fn init(cx: &mut TestAppContext) -> Self {
            let settings_store = cx.update(SettingsStore::test);
            cx.set_global(settings_store);
            cx.update(Project::init_settings);
            cx.update(language::init);

            let fs = FakeFs::new(cx.executor());
            fs.insert_tree(
                path!("/"),
                json!({
                    "file1.txt": "Hello world!",
                    "file2.txt": "Goodbye moon!"
                }),
            )
            .await;

            let project = Project::test(fs.clone(), [Path::new(path!("/"))], cx).await;
            let session = cx.new(|cx| ScriptingSession::new(project, cx));

            TestSession { session }
        }

        /// Runs `source` to completion and returns its stdout, panicking (with the error
        /// and output) if the script failed.
        async fn test_success(&self, source: &str, cx: &mut TestAppContext) -> String {
            let script_id = self.run_script(source, cx).await;

            self.session.read_with(cx, |session, _cx| {
                let script = session.get(script_id);
                let stdout = script.stdout_snapshot();

                if let ScriptState::Failed { error, .. } = &script.state {
                    panic!("Script failed:\n{}\n\n{}", error, stdout);
                }

                stdout
            })
        }

        /// Whether exactly one changed buffer's path ends with `path_str`; panics if
        /// several match.
        fn was_marked_changed(&self, path_str: &str, cx: &mut TestAppContext) -> bool {
            self.session.read_with(cx, |session, cx| {
                let count_changed = session
                    .changed_buffers
                    .iter()
                    .filter(|buffer| buffer.read(cx).file().unwrap().path().ends_with(path_str))
                    .count();

                assert!(count_changed < 2, "Multiple buffers matched for same path");

                count_changed > 0
            })
        }

        /// Starts `source` and waits for the script's completion task.
        async fn run_script(&self, source: &str, cx: &mut TestAppContext) -> ScriptId {
            let (script_id, task) = self
                .session
                .update(cx, |session, cx| session.run_script(source.to_string(), cx));

            task.await;

            script_id
        }
    }
}