session.rs

  1use anyhow::anyhow;
  2use collections::{HashMap, HashSet};
  3use futures::{
  4    channel::{mpsc, oneshot},
  5    pin_mut, SinkExt, StreamExt,
  6};
  7use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  8use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
  9use parking_lot::Mutex;
 10use project::{search::SearchQuery, Fs, Project};
 11use regex::Regex;
 12use std::{
 13    cell::RefCell,
 14    path::{Path, PathBuf},
 15    sync::Arc,
 16};
 17use util::{paths::PathMatcher, ResultExt};
 18
 19struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
 20
/// Runs Lua scripts against a project and keeps track of every script it has started.
pub struct ScriptSession {
    /// Project the scripts operate on; its first visible worktree is used as the scripts'
    /// root directory and its `Fs` is used for searching.
    project: Entity<Project>,
    // TODO Remove this
    /// In-memory file contents written by scripts through `sb_io_open`, keyed by path.
    /// Code in this session records writes here instead of writing them to disk.
    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
    /// Sender half of the channel drained by `_invoke_foreground_fns`; background Lua code uses
    /// it to run closures that need foreground access to this session.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Foreground task that invokes each queued `ForegroundFn`. Held for the session's lifetime
    /// (presumably dropping a gpui `Task` cancels it — confirm).
    _invoke_foreground_fns: Task<()>,
    /// Scripts started in this session; a `ScriptId` is an index into this vector.
    scripts: Vec<Script>,
}
 29
 30impl ScriptSession {
 31    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 32        let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
 33        ScriptSession {
 34            project,
 35            fs_changes: Arc::new(Mutex::new(HashMap::default())),
 36            foreground_fns_tx,
 37            _invoke_foreground_fns: cx.spawn(|this, cx| async move {
 38                while let Some(foreground_fn) = foreground_fns_rx.next().await {
 39                    foreground_fn.0(this.clone(), cx.clone());
 40                }
 41            }),
 42            scripts: Vec::new(),
 43        }
 44    }
 45
 46    pub fn run_script(
 47        &mut self,
 48        script_src: String,
 49        cx: &mut Context<Self>,
 50    ) -> (ScriptId, Task<()>) {
 51        let id = ScriptId(self.scripts.len() as u32);
 52
 53        let stdout = Arc::new(Mutex::new(String::new()));
 54
 55        let script = Script {
 56            state: ScriptState::Running {
 57                stdout: stdout.clone(),
 58            },
 59        };
 60        self.scripts.push(script);
 61
 62        let task = self.run_lua(script_src, stdout, cx);
 63
 64        let task = cx.spawn(|session, mut cx| async move {
 65            let result = task.await;
 66
 67            session
 68                .update(&mut cx, |session, _cx| {
 69                    let script = session.get_mut(id);
 70                    let stdout = script.stdout_snapshot();
 71
 72                    script.state = match result {
 73                        Ok(()) => ScriptState::Succeeded { stdout },
 74                        Err(error) => ScriptState::Failed { stdout, error },
 75                    };
 76                })
 77                .log_err();
 78        });
 79
 80        (id, task)
 81    }
 82
 83    fn run_lua(
 84        &mut self,
 85        script: String,
 86        stdout: Arc<Mutex<String>>,
 87        cx: &mut Context<Self>,
 88    ) -> Task<anyhow::Result<()>> {
 89        const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
 90
 91        // TODO Remove fs_changes
 92        let fs_changes = self.fs_changes.clone();
 93        // TODO Honor all worktrees instead of the first one
 94        let root_dir = self
 95            .project
 96            .read(cx)
 97            .visible_worktrees(cx)
 98            .next()
 99            .map(|worktree| worktree.read(cx).abs_path());
100
101        let fs = self.project.read(cx).fs().clone();
102        let foreground_fns_tx = self.foreground_fns_tx.clone();
103
104        let task = cx.background_spawn({
105            let stdout = stdout.clone();
106
107            async move {
108                let lua = Lua::new();
109                lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
110                let globals = lua.globals();
111
112                // Use the project root dir as the script's current working dir.
113                if let Some(root_dir) = &root_dir {
114                    if let Some(root_dir) = root_dir.to_str() {
115                        globals.set("cwd", root_dir)?;
116                    }
117                }
118
119                globals.set(
120                    "sb_print",
121                    lua.create_function({
122                        let stdout = stdout.clone();
123                        move |_, args: MultiValue| Self::print(args, &stdout)
124                    })?,
125                )?;
126                globals.set(
127                    "search",
128                    lua.create_async_function({
129                        let foreground_fns_tx = foreground_fns_tx.clone();
130                        move |lua, regex| {
131                            let mut foreground_fns_tx = foreground_fns_tx.clone();
132                            let fs = fs.clone();
133                            async move {
134                                Self::search(&lua, &mut foreground_fns_tx, fs, regex)
135                                    .await
136                                    .into_lua_err()
137                            }
138                        }
139                    })?,
140                )?;
141                globals.set(
142                    "outline",
143                    lua.create_async_function({
144                        let root_dir = root_dir.clone();
145                        move |_lua, path| {
146                            let mut foreground_fns_tx = foreground_fns_tx.clone();
147                            let root_dir = root_dir.clone();
148                            async move {
149                                Self::outline(root_dir, &mut foreground_fns_tx, path)
150                                    .await
151                                    .into_lua_err()
152                            }
153                        }
154                    })?,
155                )?;
156                globals.set(
157                    "sb_io_open",
158                    lua.create_function({
159                        let fs_changes = fs_changes.clone();
160                        let root_dir = root_dir.clone();
161                        move |lua, (path_str, mode)| {
162                            Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
163                        }
164                    })?,
165                )?;
166                globals.set("user_script", script)?;
167
168                lua.load(SANDBOX_PREAMBLE).exec_async().await?;
169
170                // Drop Lua instance to decrement reference count.
171                drop(lua);
172
173                anyhow::Ok(())
174            }
175        });
176
177        task
178    }
179
180    pub fn get(&self, script_id: ScriptId) -> &Script {
181        &self.scripts[script_id.0 as usize]
182    }
183
184    fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
185        &mut self.scripts[script_id.0 as usize]
186    }
187
188    /// Sandboxed print() function in Lua.
189    fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
190        for (index, arg) in args.into_iter().enumerate() {
191            // Lua's `print()` prints tab characters between each argument.
192            if index > 0 {
193                stdout.lock().push('\t');
194            }
195
196            // If the argument's to_string() fails, have the whole function call fail.
197            stdout.lock().push_str(&arg.to_string()?);
198        }
199        stdout.lock().push('\n');
200
201        Ok(())
202    }
203
204    /// Sandboxed io.open() function in Lua.
205    fn io_open(
206        lua: &Lua,
207        fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
208        root_dir: Option<&Arc<Path>>,
209        path_str: String,
210        mode: Option<String>,
211    ) -> mlua::Result<(Option<Table>, String)> {
212        let root_dir = root_dir
213            .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
214
215        let mode = mode.unwrap_or_else(|| "r".to_string());
216
217        // Parse the mode string to determine read/write permissions
218        let read_perm = mode.contains('r');
219        let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
220        let append = mode.contains('a');
221        let truncate = mode.contains('w');
222
223        // This will be the Lua value returned from the `open` function.
224        let file = lua.create_table()?;
225
226        // Store file metadata in the file
227        file.set("__path", path_str.clone())?;
228        file.set("__mode", mode.clone())?;
229        file.set("__read_perm", read_perm)?;
230        file.set("__write_perm", write_perm)?;
231
232        let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
233            Ok(path) => path,
234            Err(err) => return Ok((None, format!("{err}"))),
235        };
236
237        // close method
238        let close_fn = {
239            let fs_changes = fs_changes.clone();
240            lua.create_function(move |_lua, file_userdata: mlua::Table| {
241                let write_perm = file_userdata.get::<bool>("__write_perm")?;
242                let path = file_userdata.get::<String>("__path")?;
243
244                if write_perm {
245                    // When closing a writable file, record the content
246                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
247                    let content_ref = content.borrow::<FileContent>()?;
248                    let content_vec = content_ref.0.borrow();
249
250                    // Don't actually write to disk; instead, just update fs_changes.
251                    let path_buf = PathBuf::from(&path);
252                    fs_changes
253                        .lock()
254                        .insert(path_buf.clone(), content_vec.clone());
255                }
256
257                Ok(true)
258            })?
259        };
260        file.set("close", close_fn)?;
261
262        // If it's a directory, give it a custom read() and return early.
263        if path.is_dir() {
264            // TODO handle the case where we changed it in the in-memory fs
265
266            // Create a special directory handle
267            file.set("__is_directory", true)?;
268
269            // Store directory entries
270            let entries = match std::fs::read_dir(&path) {
271                Ok(entries) => {
272                    let mut entry_names = Vec::new();
273                    for entry in entries.flatten() {
274                        entry_names.push(entry.file_name().to_string_lossy().into_owned());
275                    }
276                    entry_names
277                }
278                Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
279            };
280
281            // Save the list of entries
282            file.set("__dir_entries", entries)?;
283            file.set("__dir_position", 0usize)?;
284
285            // Create a directory-specific read function
286            let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
287                let position = file_userdata.get::<usize>("__dir_position")?;
288                let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
289
290                if position >= entries.len() {
291                    return Ok(None); // No more entries
292                }
293
294                let entry = entries[position].clone();
295                file_userdata.set("__dir_position", position + 1)?;
296
297                Ok(Some(entry))
298            })?;
299            file.set("read", read_fn)?;
300
301            // If we got this far, the directory was opened successfully
302            return Ok((Some(file), String::new()));
303        }
304
305        let fs_changes_map = fs_changes.lock();
306
307        let is_in_changes = fs_changes_map.contains_key(&path);
308        let file_exists = is_in_changes || path.exists();
309        let mut file_content = Vec::new();
310
311        if file_exists && !truncate {
312            if is_in_changes {
313                file_content = fs_changes_map.get(&path).unwrap().clone();
314            } else {
315                // Try to read existing content if file exists and we're not truncating
316                match std::fs::read(&path) {
317                    Ok(content) => file_content = content,
318                    Err(e) => return Ok((None, format!("Error reading file: {}", e))),
319                }
320            }
321        }
322
323        drop(fs_changes_map); // Unlock the fs_changes mutex.
324
325        // If in append mode, position should be at the end
326        let position = if append && file_exists {
327            file_content.len()
328        } else {
329            0
330        };
331        file.set("__position", position)?;
332        file.set(
333            "__content",
334            lua.create_userdata(FileContent(RefCell::new(file_content)))?,
335        )?;
336
337        // Create file methods
338
339        // read method
340        let read_fn = {
341            lua.create_function(
342                |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
343                    let read_perm = file_userdata.get::<bool>("__read_perm")?;
344                    if !read_perm {
345                        return Err(mlua::Error::runtime("File not open for reading"));
346                    }
347
348                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
349                    let mut position = file_userdata.get::<usize>("__position")?;
350                    let content_ref = content.borrow::<FileContent>()?;
351                    let content_vec = content_ref.0.borrow();
352
353                    if position >= content_vec.len() {
354                        return Ok(None); // EOF
355                    }
356
357                    match format {
358                        Some(mlua::Value::String(s)) => {
359                            let lossy_string = s.to_string_lossy();
360                            let format_str: &str = lossy_string.as_ref();
361
362                            // Only consider the first 2 bytes, since it's common to pass e.g. "*all"  instead of "*a"
363                            match &format_str[0..2] {
364                                "*a" => {
365                                    // Read entire file from current position
366                                    let result = String::from_utf8_lossy(&content_vec[position..])
367                                        .to_string();
368                                    position = content_vec.len();
369                                    file_userdata.set("__position", position)?;
370                                    Ok(Some(result))
371                                }
372                                "*l" => {
373                                    // Read next line
374                                    let mut line = Vec::new();
375                                    let mut found_newline = false;
376
377                                    while position < content_vec.len() {
378                                        let byte = content_vec[position];
379                                        position += 1;
380
381                                        if byte == b'\n' {
382                                            found_newline = true;
383                                            break;
384                                        }
385
386                                        // Skip \r in \r\n sequence but add it if it's alone
387                                        if byte == b'\r' {
388                                            if position < content_vec.len()
389                                                && content_vec[position] == b'\n'
390                                            {
391                                                position += 1;
392                                                found_newline = true;
393                                                break;
394                                            }
395                                        }
396
397                                        line.push(byte);
398                                    }
399
400                                    file_userdata.set("__position", position)?;
401
402                                    if !found_newline
403                                        && line.is_empty()
404                                        && position >= content_vec.len()
405                                    {
406                                        return Ok(None); // EOF
407                                    }
408
409                                    let result = String::from_utf8_lossy(&line).to_string();
410                                    Ok(Some(result))
411                                }
412                                "*n" => {
413                                    // Try to parse as a number (number of bytes to read)
414                                    match format_str.parse::<usize>() {
415                                        Ok(n) => {
416                                            let end =
417                                                std::cmp::min(position + n, content_vec.len());
418                                            let bytes = &content_vec[position..end];
419                                            let result = String::from_utf8_lossy(bytes).to_string();
420                                            position = end;
421                                            file_userdata.set("__position", position)?;
422                                            Ok(Some(result))
423                                        }
424                                        Err(_) => Err(mlua::Error::runtime(format!(
425                                            "Invalid format: {}",
426                                            format_str
427                                        ))),
428                                    }
429                                }
430                                "*L" => {
431                                    // Read next line keeping the end of line
432                                    let mut line = Vec::new();
433
434                                    while position < content_vec.len() {
435                                        let byte = content_vec[position];
436                                        position += 1;
437
438                                        line.push(byte);
439
440                                        if byte == b'\n' {
441                                            break;
442                                        }
443
444                                        // If we encounter a \r, add it and check if the next is \n
445                                        if byte == b'\r'
446                                            && position < content_vec.len()
447                                            && content_vec[position] == b'\n'
448                                        {
449                                            line.push(content_vec[position]);
450                                            position += 1;
451                                            break;
452                                        }
453                                    }
454
455                                    file_userdata.set("__position", position)?;
456
457                                    if line.is_empty() && position >= content_vec.len() {
458                                        return Ok(None); // EOF
459                                    }
460
461                                    let result = String::from_utf8_lossy(&line).to_string();
462                                    Ok(Some(result))
463                                }
464                                _ => Err(mlua::Error::runtime(format!(
465                                    "Unsupported format: {}",
466                                    format_str
467                                ))),
468                            }
469                        }
470                        Some(mlua::Value::Number(n)) => {
471                            // Read n bytes
472                            let n = n as usize;
473                            let end = std::cmp::min(position + n, content_vec.len());
474                            let bytes = &content_vec[position..end];
475                            let result = String::from_utf8_lossy(bytes).to_string();
476                            position = end;
477                            file_userdata.set("__position", position)?;
478                            Ok(Some(result))
479                        }
480                        Some(_) => Err(mlua::Error::runtime("Invalid format")),
481                        None => {
482                            // Default is to read a line
483                            let mut line = Vec::new();
484                            let mut found_newline = false;
485
486                            while position < content_vec.len() {
487                                let byte = content_vec[position];
488                                position += 1;
489
490                                if byte == b'\n' {
491                                    found_newline = true;
492                                    break;
493                                }
494
495                                // Handle \r\n
496                                if byte == b'\r' {
497                                    if position < content_vec.len()
498                                        && content_vec[position] == b'\n'
499                                    {
500                                        position += 1;
501                                        found_newline = true;
502                                        break;
503                                    }
504                                }
505
506                                line.push(byte);
507                            }
508
509                            file_userdata.set("__position", position)?;
510
511                            if !found_newline && line.is_empty() && position >= content_vec.len() {
512                                return Ok(None); // EOF
513                            }
514
515                            let result = String::from_utf8_lossy(&line).to_string();
516                            Ok(Some(result))
517                        }
518                    }
519                },
520            )?
521        };
522        file.set("read", read_fn)?;
523
524        // write method
525        let write_fn = {
526            let fs_changes = fs_changes.clone();
527
528            lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
529                let write_perm = file_userdata.get::<bool>("__write_perm")?;
530                if !write_perm {
531                    return Err(mlua::Error::runtime("File not open for writing"));
532                }
533
534                let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
535                let position = file_userdata.get::<usize>("__position")?;
536                let content_ref = content.borrow::<FileContent>()?;
537                let mut content_vec = content_ref.0.borrow_mut();
538
539                let bytes = text.as_bytes();
540
541                // Ensure the vector has enough capacity
542                if position + bytes.len() > content_vec.len() {
543                    content_vec.resize(position + bytes.len(), 0);
544                }
545
546                // Write the bytes
547                for (i, &byte) in bytes.iter().enumerate() {
548                    content_vec[position + i] = byte;
549                }
550
551                // Update position
552                let new_position = position + bytes.len();
553                file_userdata.set("__position", new_position)?;
554
555                // Update fs_changes
556                let path = file_userdata.get::<String>("__path")?;
557                let path_buf = PathBuf::from(path);
558                fs_changes.lock().insert(path_buf, content_vec.clone());
559
560                Ok(true)
561            })?
562        };
563        file.set("write", write_fn)?;
564
565        // If we got this far, the file was opened successfully
566        Ok((Some(file), String::new()))
567    }
568
569    async fn search(
570        lua: &Lua,
571        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
572        fs: Arc<dyn Fs>,
573        regex: String,
574    ) -> anyhow::Result<Table> {
575        // TODO: Allow specification of these options.
576        let search_query = SearchQuery::regex(
577            &regex,
578            false,
579            false,
580            false,
581            PathMatcher::default(),
582            PathMatcher::default(),
583            None,
584        );
585        let search_query = match search_query {
586            Ok(query) => query,
587            Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
588        };
589
590        // TODO: Should use `search_query.regex`. The tool description should also be updated,
591        // as it specifies standard regex.
592        let search_regex = match Regex::new(&regex) {
593            Ok(re) => re,
594            Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
595        };
596
597        let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
598
599        let mut search_results: Vec<Table> = Vec::new();
600        while let Some(path) = abs_paths_rx.next().await {
601            // Skip files larger than 1MB
602            if let Ok(Some(metadata)) = fs.metadata(&path).await {
603                if metadata.len > 1_000_000 {
604                    continue;
605                }
606            }
607
608            // Attempt to read the file as text
609            if let Ok(content) = fs.load(&path).await {
610                let mut matches = Vec::new();
611
612                // Find all regex matches in the content
613                for capture in search_regex.find_iter(&content) {
614                    matches.push(capture.as_str().to_string());
615                }
616
617                // If we found matches, create a result entry
618                if !matches.is_empty() {
619                    let result_entry = lua.create_table()?;
620                    result_entry.set("path", path.to_string_lossy().to_string())?;
621
622                    let matches_table = lua.create_table()?;
623                    for (ix, m) in matches.iter().enumerate() {
624                        matches_table.set(ix + 1, m.clone())?;
625                    }
626                    result_entry.set("matches", matches_table)?;
627
628                    search_results.push(result_entry);
629                }
630            }
631        }
632
633        // Create a table to hold our results
634        let results_table = lua.create_table()?;
635        for (ix, entry) in search_results.into_iter().enumerate() {
636            results_table.set(ix + 1, entry)?;
637        }
638
639        Ok(results_table)
640    }
641
    /// Asks the foreground loop to enumerate files that may match `search_query`, and returns a
    /// channel that yields their absolute paths.
    ///
    /// On the foreground, this does the following:
    /// - calls `WorktreeStore::find_search_candidates` with a limit of 5000;
    /// - spawns a detached task that converts each candidate's project path to an absolute path
    ///   and sends it on the returned channel.
    ///
    /// That task stops at its first error, including when the receiver has been dropped. Its
    /// `Result` is discarded by `.detach()`.
    ///
    /// Errors if the work cannot be enqueued, if the foreground response is canceled, or if
    /// updating the weak session handle fails (presumably because the session was released).
    async fn find_search_candidates(
        search_query: SearchQuery,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
        Self::run_foreground_fn(
            "finding search file candidates",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    session.project.update(cx, |project, cx| {
                        project.worktree_store().update(cx, |worktree_store, cx| {
                            // TODO: Better limit? For now this is the same as
                            // MAX_SEARCH_RESULT_FILES.
                            let limit = 5000;
                            // TODO: Providing non-empty open_entries can make this a bit more
                            // efficient as it can skip checking that these paths are textual.
                            let open_entries = HashSet::default();
                            let candidates = worktree_store.find_search_candidates(
                                search_query,
                                limit,
                                open_entries,
                                project.fs().clone(),
                                cx,
                            );
                            // Paths are streamed to the caller as they are resolved, rather than
                            // collected up front.
                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
                            cx.spawn(|worktree_store, cx| async move {
                                pin_mut!(candidates);

                                while let Some(project_path) = candidates.next().await {
                                    // Outer `?`: `read_with` failed; inner `?`: the closure's
                                    // own error (the send below failing).
                                    worktree_store.read_with(&cx, |worktree_store, cx| {
                                        if let Some(worktree) = worktree_store
                                            .worktree_for_id(project_path.worktree_id, cx)
                                        {
                                            // Paths that fail to absolutize are logged and skipped.
                                            if let Some(abs_path) = worktree
                                                .read(cx)
                                                .absolutize(&project_path.path)
                                                .log_err()
                                            {
                                                abs_paths_tx.unbounded_send(abs_path)?;
                                            }
                                        }
                                        anyhow::Ok(())
                                    })??;
                                }
                                anyhow::Ok(())
                            })
                            .detach();
                            abs_paths_rx
                        })
                    })
                })
            }),
        )
        .await?
    }
697
698    async fn outline(
699        root_dir: Option<Arc<Path>>,
700        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
701        path_str: String,
702    ) -> anyhow::Result<String> {
703        let root_dir = root_dir
704            .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
705        let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
706        let outline = Self::run_foreground_fn(
707            "getting code outline",
708            foreground_tx,
709            Box::new(move |session, cx| {
710                cx.spawn(move |mut cx| async move {
711                    // TODO: This will not use file content from `fs_changes`. It will also reflect
712                    // user changes that have not been saved.
713                    let buffer = session
714                        .update(&mut cx, |session, cx| {
715                            session
716                                .project
717                                .update(cx, |project, cx| project.open_local_buffer(&path, cx))
718                        })?
719                        .await?;
720                    buffer.update(&mut cx, |buffer, _cx| {
721                        if let Some(outline) = buffer.snapshot().outline(None) {
722                            Ok(outline)
723                        } else {
724                            Err(anyhow!("No outline for file {path_str}"))
725                        }
726                    })
727                })
728            }),
729        )
730        .await?
731        .await??;
732
733        Ok(outline
734            .items
735            .into_iter()
736            .map(|item| {
737                if item.text.contains('\n') {
738                    log::error!("Outline item unexpectedly contains newline");
739                }
740                format!("{}{}", "  ".repeat(item.depth), item.text)
741            })
742            .collect::<Vec<String>>()
743            .join("\n"))
744    }
745
746    async fn run_foreground_fn<R: Send + 'static>(
747        description: &str,
748        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
749        function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
750    ) -> anyhow::Result<R> {
751        let (response_tx, response_rx) = oneshot::channel();
752        let send_result = foreground_tx
753            .send(ForegroundFn(Box::new(move |this, cx| {
754                response_tx.send(function(this, cx)).ok();
755            })))
756            .await;
757        match send_result {
758            Ok(()) => (),
759            Err(err) => {
760                return Err(anyhow::Error::new(err).context(format!(
761                    "Internal error while enqueuing work for {description}"
762                )));
763            }
764        }
765        match response_rx.await {
766            Ok(result) => Ok(result),
767            Err(oneshot::Canceled) => Err(anyhow!(
768                "Internal error: response oneshot was canceled while {description}."
769            )),
770        }
771    }
772
773    fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
774        let path = Path::new(&path_str);
775        if path.is_absolute() {
776            // Check if path starts with root_dir prefix without resolving symlinks
777            if path.starts_with(&root_dir) {
778                Ok(path.to_path_buf())
779            } else {
780                Err(anyhow!(
781                    "Error: Absolute path {} is outside the current working directory",
782                    path_str
783                ))
784            }
785        } else {
786            // TODO: Does use of `../` break sandbox - is path canonicalization needed?
787            Ok(root_dir.join(path))
788        }
789    }
790}
791
/// Raw bytes of a file's contents, wrapped so they can be handed to Lua as userdata.
///
/// The `RefCell` allows the bytes to be mutated through a shared reference. Presumably this is
/// so Lua-facing operations elsewhere in this file can modify the buffer; TODO confirm.
struct FileContent(RefCell<Vec<u8>>);

impl UserData for FileContent {
    // No methods are registered here. Lua code can hold and pass the value around, but it has
    // no methods to call on it.
    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
        // FileContent doesn't have any methods so far.
    }
}
799
/// Identifier of a script started in a `ScriptSession`.
///
/// `run_script` assigns it from the current length of the session's `scripts` list, so it also
/// serves as the script's index in that list.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
802
/// A script started in a `ScriptSession`, tracked through its current state.
pub struct Script {
    /// Whether the script is still running or how it finished, together with its output.
    pub state: ScriptState,
}
806
/// Lifecycle of a script run, carrying the text it printed.
pub enum ScriptState {
    /// The script is still executing.
    ///
    /// `stdout` is a shared, lock-protected buffer. `run_script` keeps a clone of the same
    /// `Arc`, presumably so printed output can be appended while the script runs; TODO confirm.
    Running {
        stdout: Arc<Mutex<String>>,
    },
    /// The script completed without error; `stdout` holds what it printed.
    Succeeded {
        stdout: String,
    },
    /// The script ended with `error`; `stdout` holds the output printed before it failed.
    Failed {
        stdout: String,
        error: anyhow::Error,
    },
}
819
820impl Script {
821    /// If exited, returns a message with the output for the LLM
822    pub fn output_message_for_llm(&self) -> Option<String> {
823        match &self.state {
824            ScriptState::Running { .. } => None,
825            ScriptState::Succeeded { stdout } => {
826                format!("Here's the script output:\n{}", stdout).into()
827            }
828            ScriptState::Failed { stdout, error } => format!(
829                "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
830                error, stdout
831            )
832            .into(),
833        }
834    }
835
836    /// Get a snapshot of the script's stdout
837    pub fn stdout_snapshot(&self) -> String {
838        match &self.state {
839            ScriptState::Running { stdout } => stdout.lock().clone(),
840            ScriptState::Succeeded { stdout } => stdout.clone(),
841            ScriptState::Failed { stdout, .. } => stdout.clone(),
842        }
843    }
844}
845
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    /// Installs a test settings store as a global and registers project settings, which
    /// `Project::test` needs.
    fn init_test(cx: &mut TestAppContext) {
        let store = cx.update(SettingsStore::test);
        cx.set_global(store);
        cx.update(Project::init_settings);
    }

    /// Runs `source` in a fresh `ScriptSession` over a fake filesystem rooted at `/`, containing
    /// `file1.txt` and `file2.txt`. Returns the script's stdout once it finishes.
    async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
        init_test(cx);

        let fake_fs = FakeFs::new(cx.executor());
        let tree = json!({
            "file1.txt": "Hello world!",
            "file2.txt": "Goodbye moon!"
        });
        fake_fs.insert_tree("/", tree).await;

        let project = Project::test(fake_fs, [Path::new("/")], cx).await;
        let session = cx.new(|cx| ScriptSession::new(project, cx));

        let (id, run) =
            session.update(cx, |session, cx| session.run_script(source.to_string(), cx));
        run.await;

        let stdout = session.read_with(cx, |session, _cx| session.get(id).stdout_snapshot());
        Ok(stdout)
    }

    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let script = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let stdout = test_script(script, cx).await.unwrap();
        assert_eq!(stdout, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let script = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print("  " .. match)
                end
            end
        "#;

        let stdout = test_script(script, cx).await.unwrap();
        assert_eq!(stdout, "File: /file1.txt\nMatches:\n  world\n");
    }
}