session.rs

  1use anyhow::anyhow;
  2use collections::{HashMap, HashSet};
  3use futures::{
  4    channel::{mpsc, oneshot},
  5    pin_mut, SinkExt, StreamExt,
  6};
  7use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  8use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
  9use parking_lot::Mutex;
 10use project::{search::SearchQuery, Fs, Project};
 11use regex::Regex;
 12use std::{
 13    cell::RefCell,
 14    path::{Path, PathBuf},
 15    sync::Arc,
 16};
 17use util::{paths::PathMatcher, ResultExt};
 18
 19use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG};
 20
/// A one-shot callback that is sent over `ScriptSession::foreground_fns_tx` and invoked by
/// the task stored in `ScriptSession::_invoke_foreground_fns`, which passes it a weak handle
/// to the session and an `AsyncApp`.
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
 22
/// Holds the scripts created for a project together with the channel and task used to run
/// `ForegroundFn` callbacks on behalf of running scripts.
pub struct ScriptSession {
    /// Project whose worktrees and `Fs` are read when a script is run.
    project: Entity<Project>,
    // TODO Remove this
    /// In-memory file contents keyed by path. The sandboxed `io.open` file handles record
    /// written content here instead of writing to disk.
    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
    /// Sending half of the `ForegroundFn` channel; cloned into the Lua functions of each
    /// script that is run.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Task that receives `ForegroundFn`s from the channel and invokes them one by one.
    _invoke_foreground_fns: Task<()>,
    /// Every script created by `new_script`; a `ScriptId` is an index into this vector.
    scripts: Vec<Script>,
}
 31
 32impl ScriptSession {
 33    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 34        let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
 35        ScriptSession {
 36            project,
 37            fs_changes: Arc::new(Mutex::new(HashMap::default())),
 38            foreground_fns_tx,
 39            _invoke_foreground_fns: cx.spawn(|this, cx| async move {
 40                while let Some(foreground_fn) = foreground_fns_rx.next().await {
 41                    foreground_fn.0(this.clone(), cx.clone());
 42                }
 43            }),
 44            scripts: Vec::new(),
 45        }
 46    }
 47
 48    pub fn new_script(&mut self) -> ScriptId {
 49        let id = ScriptId(self.scripts.len() as u32);
 50        let script = Script {
 51            id,
 52            state: ScriptState::Generating,
 53            source: SharedString::new_static(""),
 54        };
 55        self.scripts.push(script);
 56        id
 57    }
 58
 59    pub fn run_script(
 60        &mut self,
 61        script_id: ScriptId,
 62        script_src: String,
 63        cx: &mut Context<Self>,
 64    ) -> Task<anyhow::Result<()>> {
 65        let script = self.get_mut(script_id);
 66
 67        let stdout = Arc::new(Mutex::new(String::new()));
 68        script.source = script_src.clone().into();
 69        script.state = ScriptState::Running {
 70            stdout: stdout.clone(),
 71        };
 72
 73        let task = self.run_lua(script_src, stdout, cx);
 74
 75        cx.emit(ScriptEvent::Spawned(script_id));
 76
 77        cx.spawn(|session, mut cx| async move {
 78            let result = task.await;
 79
 80            session.update(&mut cx, |session, cx| {
 81                let script = session.get_mut(script_id);
 82                let stdout = script.stdout_snapshot();
 83
 84                script.state = match result {
 85                    Ok(()) => ScriptState::Succeeded { stdout },
 86                    Err(error) => ScriptState::Failed { stdout, error },
 87                };
 88
 89                cx.emit(ScriptEvent::Exited(script_id))
 90            })
 91        })
 92    }
 93
 94    fn run_lua(
 95        &mut self,
 96        script: String,
 97        stdout: Arc<Mutex<String>>,
 98        cx: &mut Context<Self>,
 99    ) -> Task<anyhow::Result<()>> {
100        const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
101
102        // TODO Remove fs_changes
103        let fs_changes = self.fs_changes.clone();
104        // TODO Honor all worktrees instead of the first one
105        let root_dir = self
106            .project
107            .read(cx)
108            .visible_worktrees(cx)
109            .next()
110            .map(|worktree| worktree.read(cx).abs_path());
111
112        let fs = self.project.read(cx).fs().clone();
113        let foreground_fns_tx = self.foreground_fns_tx.clone();
114
115        let task = cx.background_spawn({
116            let stdout = stdout.clone();
117
118            async move {
119                let lua = Lua::new();
120                lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
121                let globals = lua.globals();
122
123                // Use the project root dir as the script's current working dir.
124                if let Some(root_dir) = &root_dir {
125                    if let Some(root_dir) = root_dir.to_str() {
126                        globals.set("cwd", root_dir)?;
127                    }
128                }
129
130                globals.set(
131                    "sb_print",
132                    lua.create_function({
133                        let stdout = stdout.clone();
134                        move |_, args: MultiValue| Self::print(args, &stdout)
135                    })?,
136                )?;
137                globals.set(
138                    "search",
139                    lua.create_async_function({
140                        let foreground_fns_tx = foreground_fns_tx.clone();
141                        move |lua, regex| {
142                            let mut foreground_fns_tx = foreground_fns_tx.clone();
143                            let fs = fs.clone();
144                            async move {
145                                Self::search(&lua, &mut foreground_fns_tx, fs, regex)
146                                    .await
147                                    .into_lua_err()
148                            }
149                        }
150                    })?,
151                )?;
152                globals.set(
153                    "outline",
154                    lua.create_async_function({
155                        let root_dir = root_dir.clone();
156                        move |_lua, path| {
157                            let mut foreground_fns_tx = foreground_fns_tx.clone();
158                            let root_dir = root_dir.clone();
159                            async move {
160                                Self::outline(root_dir, &mut foreground_fns_tx, path)
161                                    .await
162                                    .into_lua_err()
163                            }
164                        }
165                    })?,
166                )?;
167                globals.set(
168                    "sb_io_open",
169                    lua.create_function({
170                        let fs_changes = fs_changes.clone();
171                        let root_dir = root_dir.clone();
172                        move |lua, (path_str, mode)| {
173                            Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
174                        }
175                    })?,
176                )?;
177                globals.set("user_script", script)?;
178
179                lua.load(SANDBOX_PREAMBLE).exec_async().await?;
180
181                // Drop Lua instance to decrement reference count.
182                drop(lua);
183
184                anyhow::Ok(())
185            }
186        });
187
188        task
189    }
190
191    pub fn get(&self, script_id: ScriptId) -> &Script {
192        &self.scripts[script_id.0 as usize]
193    }
194
195    fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
196        &mut self.scripts[script_id.0 as usize]
197    }
198
199    /// Sandboxed print() function in Lua.
200    fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
201        for (index, arg) in args.into_iter().enumerate() {
202            // Lua's `print()` prints tab characters between each argument.
203            if index > 0 {
204                stdout.lock().push('\t');
205            }
206
207            // If the argument's to_string() fails, have the whole function call fail.
208            stdout.lock().push_str(&arg.to_string()?);
209        }
210        stdout.lock().push('\n');
211
212        Ok(())
213    }
214
215    /// Sandboxed io.open() function in Lua.
216    fn io_open(
217        lua: &Lua,
218        fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
219        root_dir: Option<&Arc<Path>>,
220        path_str: String,
221        mode: Option<String>,
222    ) -> mlua::Result<(Option<Table>, String)> {
223        let root_dir = root_dir
224            .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
225
226        let mode = mode.unwrap_or_else(|| "r".to_string());
227
228        // Parse the mode string to determine read/write permissions
229        let read_perm = mode.contains('r');
230        let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
231        let append = mode.contains('a');
232        let truncate = mode.contains('w');
233
234        // This will be the Lua value returned from the `open` function.
235        let file = lua.create_table()?;
236
237        // Store file metadata in the file
238        file.set("__path", path_str.clone())?;
239        file.set("__mode", mode.clone())?;
240        file.set("__read_perm", read_perm)?;
241        file.set("__write_perm", write_perm)?;
242
243        let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
244            Ok(path) => path,
245            Err(err) => return Ok((None, format!("{err}"))),
246        };
247
248        // close method
249        let close_fn = {
250            let fs_changes = fs_changes.clone();
251            lua.create_function(move |_lua, file_userdata: mlua::Table| {
252                let write_perm = file_userdata.get::<bool>("__write_perm")?;
253                let path = file_userdata.get::<String>("__path")?;
254
255                if write_perm {
256                    // When closing a writable file, record the content
257                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
258                    let content_ref = content.borrow::<FileContent>()?;
259                    let content_vec = content_ref.0.borrow();
260
261                    // Don't actually write to disk; instead, just update fs_changes.
262                    let path_buf = PathBuf::from(&path);
263                    fs_changes
264                        .lock()
265                        .insert(path_buf.clone(), content_vec.clone());
266                }
267
268                Ok(true)
269            })?
270        };
271        file.set("close", close_fn)?;
272
273        // If it's a directory, give it a custom read() and return early.
274        if path.is_dir() {
275            // TODO handle the case where we changed it in the in-memory fs
276
277            // Create a special directory handle
278            file.set("__is_directory", true)?;
279
280            // Store directory entries
281            let entries = match std::fs::read_dir(&path) {
282                Ok(entries) => {
283                    let mut entry_names = Vec::new();
284                    for entry in entries.flatten() {
285                        entry_names.push(entry.file_name().to_string_lossy().into_owned());
286                    }
287                    entry_names
288                }
289                Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
290            };
291
292            // Save the list of entries
293            file.set("__dir_entries", entries)?;
294            file.set("__dir_position", 0usize)?;
295
296            // Create a directory-specific read function
297            let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
298                let position = file_userdata.get::<usize>("__dir_position")?;
299                let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
300
301                if position >= entries.len() {
302                    return Ok(None); // No more entries
303                }
304
305                let entry = entries[position].clone();
306                file_userdata.set("__dir_position", position + 1)?;
307
308                Ok(Some(entry))
309            })?;
310            file.set("read", read_fn)?;
311
312            // If we got this far, the directory was opened successfully
313            return Ok((Some(file), String::new()));
314        }
315
316        let fs_changes_map = fs_changes.lock();
317
318        let is_in_changes = fs_changes_map.contains_key(&path);
319        let file_exists = is_in_changes || path.exists();
320        let mut file_content = Vec::new();
321
322        if file_exists && !truncate {
323            if is_in_changes {
324                file_content = fs_changes_map.get(&path).unwrap().clone();
325            } else {
326                // Try to read existing content if file exists and we're not truncating
327                match std::fs::read(&path) {
328                    Ok(content) => file_content = content,
329                    Err(e) => return Ok((None, format!("Error reading file: {}", e))),
330                }
331            }
332        }
333
334        drop(fs_changes_map); // Unlock the fs_changes mutex.
335
336        // If in append mode, position should be at the end
337        let position = if append && file_exists {
338            file_content.len()
339        } else {
340            0
341        };
342        file.set("__position", position)?;
343        file.set(
344            "__content",
345            lua.create_userdata(FileContent(RefCell::new(file_content)))?,
346        )?;
347
348        // Create file methods
349
350        // read method
351        let read_fn = {
352            lua.create_function(
353                |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
354                    let read_perm = file_userdata.get::<bool>("__read_perm")?;
355                    if !read_perm {
356                        return Err(mlua::Error::runtime("File not open for reading"));
357                    }
358
359                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
360                    let mut position = file_userdata.get::<usize>("__position")?;
361                    let content_ref = content.borrow::<FileContent>()?;
362                    let content_vec = content_ref.0.borrow();
363
364                    if position >= content_vec.len() {
365                        return Ok(None); // EOF
366                    }
367
368                    match format {
369                        Some(mlua::Value::String(s)) => {
370                            let lossy_string = s.to_string_lossy();
371                            let format_str: &str = lossy_string.as_ref();
372
373                            // Only consider the first 2 bytes, since it's common to pass e.g. "*all"  instead of "*a"
374                            match &format_str[0..2] {
375                                "*a" => {
376                                    // Read entire file from current position
377                                    let result = String::from_utf8_lossy(&content_vec[position..])
378                                        .to_string();
379                                    position = content_vec.len();
380                                    file_userdata.set("__position", position)?;
381                                    Ok(Some(result))
382                                }
383                                "*l" => {
384                                    // Read next line
385                                    let mut line = Vec::new();
386                                    let mut found_newline = false;
387
388                                    while position < content_vec.len() {
389                                        let byte = content_vec[position];
390                                        position += 1;
391
392                                        if byte == b'\n' {
393                                            found_newline = true;
394                                            break;
395                                        }
396
397                                        // Skip \r in \r\n sequence but add it if it's alone
398                                        if byte == b'\r' {
399                                            if position < content_vec.len()
400                                                && content_vec[position] == b'\n'
401                                            {
402                                                position += 1;
403                                                found_newline = true;
404                                                break;
405                                            }
406                                        }
407
408                                        line.push(byte);
409                                    }
410
411                                    file_userdata.set("__position", position)?;
412
413                                    if !found_newline
414                                        && line.is_empty()
415                                        && position >= content_vec.len()
416                                    {
417                                        return Ok(None); // EOF
418                                    }
419
420                                    let result = String::from_utf8_lossy(&line).to_string();
421                                    Ok(Some(result))
422                                }
423                                "*n" => {
424                                    // Try to parse as a number (number of bytes to read)
425                                    match format_str.parse::<usize>() {
426                                        Ok(n) => {
427                                            let end =
428                                                std::cmp::min(position + n, content_vec.len());
429                                            let bytes = &content_vec[position..end];
430                                            let result = String::from_utf8_lossy(bytes).to_string();
431                                            position = end;
432                                            file_userdata.set("__position", position)?;
433                                            Ok(Some(result))
434                                        }
435                                        Err(_) => Err(mlua::Error::runtime(format!(
436                                            "Invalid format: {}",
437                                            format_str
438                                        ))),
439                                    }
440                                }
441                                "*L" => {
442                                    // Read next line keeping the end of line
443                                    let mut line = Vec::new();
444
445                                    while position < content_vec.len() {
446                                        let byte = content_vec[position];
447                                        position += 1;
448
449                                        line.push(byte);
450
451                                        if byte == b'\n' {
452                                            break;
453                                        }
454
455                                        // If we encounter a \r, add it and check if the next is \n
456                                        if byte == b'\r'
457                                            && position < content_vec.len()
458                                            && content_vec[position] == b'\n'
459                                        {
460                                            line.push(content_vec[position]);
461                                            position += 1;
462                                            break;
463                                        }
464                                    }
465
466                                    file_userdata.set("__position", position)?;
467
468                                    if line.is_empty() && position >= content_vec.len() {
469                                        return Ok(None); // EOF
470                                    }
471
472                                    let result = String::from_utf8_lossy(&line).to_string();
473                                    Ok(Some(result))
474                                }
475                                _ => Err(mlua::Error::runtime(format!(
476                                    "Unsupported format: {}",
477                                    format_str
478                                ))),
479                            }
480                        }
481                        Some(mlua::Value::Number(n)) => {
482                            // Read n bytes
483                            let n = n as usize;
484                            let end = std::cmp::min(position + n, content_vec.len());
485                            let bytes = &content_vec[position..end];
486                            let result = String::from_utf8_lossy(bytes).to_string();
487                            position = end;
488                            file_userdata.set("__position", position)?;
489                            Ok(Some(result))
490                        }
491                        Some(_) => Err(mlua::Error::runtime("Invalid format")),
492                        None => {
493                            // Default is to read a line
494                            let mut line = Vec::new();
495                            let mut found_newline = false;
496
497                            while position < content_vec.len() {
498                                let byte = content_vec[position];
499                                position += 1;
500
501                                if byte == b'\n' {
502                                    found_newline = true;
503                                    break;
504                                }
505
506                                // Handle \r\n
507                                if byte == b'\r' {
508                                    if position < content_vec.len()
509                                        && content_vec[position] == b'\n'
510                                    {
511                                        position += 1;
512                                        found_newline = true;
513                                        break;
514                                    }
515                                }
516
517                                line.push(byte);
518                            }
519
520                            file_userdata.set("__position", position)?;
521
522                            if !found_newline && line.is_empty() && position >= content_vec.len() {
523                                return Ok(None); // EOF
524                            }
525
526                            let result = String::from_utf8_lossy(&line).to_string();
527                            Ok(Some(result))
528                        }
529                    }
530                },
531            )?
532        };
533        file.set("read", read_fn)?;
534
535        // write method
536        let write_fn = {
537            let fs_changes = fs_changes.clone();
538
539            lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
540                let write_perm = file_userdata.get::<bool>("__write_perm")?;
541                if !write_perm {
542                    return Err(mlua::Error::runtime("File not open for writing"));
543                }
544
545                let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
546                let position = file_userdata.get::<usize>("__position")?;
547                let content_ref = content.borrow::<FileContent>()?;
548                let mut content_vec = content_ref.0.borrow_mut();
549
550                let bytes = text.as_bytes();
551
552                // Ensure the vector has enough capacity
553                if position + bytes.len() > content_vec.len() {
554                    content_vec.resize(position + bytes.len(), 0);
555                }
556
557                // Write the bytes
558                for (i, &byte) in bytes.iter().enumerate() {
559                    content_vec[position + i] = byte;
560                }
561
562                // Update position
563                let new_position = position + bytes.len();
564                file_userdata.set("__position", new_position)?;
565
566                // Update fs_changes
567                let path = file_userdata.get::<String>("__path")?;
568                let path_buf = PathBuf::from(path);
569                fs_changes.lock().insert(path_buf, content_vec.clone());
570
571                Ok(true)
572            })?
573        };
574        file.set("write", write_fn)?;
575
576        // If we got this far, the file was opened successfully
577        Ok((Some(file), String::new()))
578    }
579
580    async fn search(
581        lua: &Lua,
582        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
583        fs: Arc<dyn Fs>,
584        regex: String,
585    ) -> anyhow::Result<Table> {
586        // TODO: Allow specification of these options.
587        let search_query = SearchQuery::regex(
588            &regex,
589            false,
590            false,
591            false,
592            PathMatcher::default(),
593            PathMatcher::default(),
594            None,
595        );
596        let search_query = match search_query {
597            Ok(query) => query,
598            Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
599        };
600
601        // TODO: Should use `search_query.regex`. The tool description should also be updated,
602        // as it specifies standard regex.
603        let search_regex = match Regex::new(&regex) {
604            Ok(re) => re,
605            Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
606        };
607
608        let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
609
610        let mut search_results: Vec<Table> = Vec::new();
611        while let Some(path) = abs_paths_rx.next().await {
612            // Skip files larger than 1MB
613            if let Ok(Some(metadata)) = fs.metadata(&path).await {
614                if metadata.len > 1_000_000 {
615                    continue;
616                }
617            }
618
619            // Attempt to read the file as text
620            if let Ok(content) = fs.load(&path).await {
621                let mut matches = Vec::new();
622
623                // Find all regex matches in the content
624                for capture in search_regex.find_iter(&content) {
625                    matches.push(capture.as_str().to_string());
626                }
627
628                // If we found matches, create a result entry
629                if !matches.is_empty() {
630                    let result_entry = lua.create_table()?;
631                    result_entry.set("path", path.to_string_lossy().to_string())?;
632
633                    let matches_table = lua.create_table()?;
634                    for (ix, m) in matches.iter().enumerate() {
635                        matches_table.set(ix + 1, m.clone())?;
636                    }
637                    result_entry.set("matches", matches_table)?;
638
639                    search_results.push(result_entry);
640                }
641            }
642        }
643
644        // Create a table to hold our results
645        let results_table = lua.create_table()?;
646        for (ix, entry) in search_results.into_iter().enumerate() {
647            results_table.set(ix + 1, entry)?;
648        }
649
650        Ok(results_table)
651    }
652
653    async fn find_search_candidates(
654        search_query: SearchQuery,
655        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
656    ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
657        Self::run_foreground_fn(
658            "finding search file candidates",
659            foreground_tx,
660            Box::new(move |session, mut cx| {
661                session.update(&mut cx, |session, cx| {
662                    session.project.update(cx, |project, cx| {
663                        project.worktree_store().update(cx, |worktree_store, cx| {
664                            // TODO: Better limit? For now this is the same as
665                            // MAX_SEARCH_RESULT_FILES.
666                            let limit = 5000;
667                            // TODO: Providing non-empty open_entries can make this a bit more
668                            // efficient as it can skip checking that these paths are textual.
669                            let open_entries = HashSet::default();
670                            let candidates = worktree_store.find_search_candidates(
671                                search_query,
672                                limit,
673                                open_entries,
674                                project.fs().clone(),
675                                cx,
676                            );
677                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
678                            cx.spawn(|worktree_store, cx| async move {
679                                pin_mut!(candidates);
680
681                                while let Some(project_path) = candidates.next().await {
682                                    worktree_store.read_with(&cx, |worktree_store, cx| {
683                                        if let Some(worktree) = worktree_store
684                                            .worktree_for_id(project_path.worktree_id, cx)
685                                        {
686                                            if let Some(abs_path) = worktree
687                                                .read(cx)
688                                                .absolutize(&project_path.path)
689                                                .log_err()
690                                            {
691                                                abs_paths_tx.unbounded_send(abs_path)?;
692                                            }
693                                        }
694                                        anyhow::Ok(())
695                                    })??;
696                                }
697                                anyhow::Ok(())
698                            })
699                            .detach();
700                            abs_paths_rx
701                        })
702                    })
703                })
704            }),
705        )
706        .await?
707    }
708
709    async fn outline(
710        root_dir: Option<Arc<Path>>,
711        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
712        path_str: String,
713    ) -> anyhow::Result<String> {
714        let root_dir = root_dir
715            .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
716        let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
717        let outline = Self::run_foreground_fn(
718            "getting code outline",
719            foreground_tx,
720            Box::new(move |session, cx| {
721                cx.spawn(move |mut cx| async move {
722                    // TODO: This will not use file content from `fs_changes`. It will also reflect
723                    // user changes that have not been saved.
724                    let buffer = session
725                        .update(&mut cx, |session, cx| {
726                            session
727                                .project
728                                .update(cx, |project, cx| project.open_local_buffer(&path, cx))
729                        })?
730                        .await?;
731                    buffer.update(&mut cx, |buffer, _cx| {
732                        if let Some(outline) = buffer.snapshot().outline(None) {
733                            Ok(outline)
734                        } else {
735                            Err(anyhow!("No outline for file {path_str}"))
736                        }
737                    })
738                })
739            }),
740        )
741        .await?
742        .await??;
743
744        Ok(outline
745            .items
746            .into_iter()
747            .map(|item| {
748                if item.text.contains('\n') {
749                    log::error!("Outline item unexpectedly contains newline");
750                }
751                format!("{}{}", "  ".repeat(item.depth), item.text)
752            })
753            .collect::<Vec<String>>()
754            .join("\n"))
755    }
756
757    async fn run_foreground_fn<R: Send + 'static>(
758        description: &str,
759        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
760        function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
761    ) -> anyhow::Result<R> {
762        let (response_tx, response_rx) = oneshot::channel();
763        let send_result = foreground_tx
764            .send(ForegroundFn(Box::new(move |this, cx| {
765                response_tx.send(function(this, cx)).ok();
766            })))
767            .await;
768        match send_result {
769            Ok(()) => (),
770            Err(err) => {
771                return Err(anyhow::Error::new(err).context(format!(
772                    "Internal error while enqueuing work for {description}"
773                )));
774            }
775        }
776        match response_rx.await {
777            Ok(result) => Ok(result),
778            Err(oneshot::Canceled) => Err(anyhow!(
779                "Internal error: response oneshot was canceled while {description}."
780            )),
781        }
782    }
783
784    fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
785        let path = Path::new(&path_str);
786        if path.is_absolute() {
787            // Check if path starts with root_dir prefix without resolving symlinks
788            if path.starts_with(&root_dir) {
789                Ok(path.to_path_buf())
790            } else {
791                Err(anyhow!(
792                    "Error: Absolute path {} is outside the current working directory",
793                    path_str
794                ))
795            }
796        } else {
797            // TODO: Does use of `../` break sandbox - is path canonicalization needed?
798            Ok(root_dir.join(path))
799        }
800    }
801}
802
803struct FileContent(RefCell<Vec<u8>>);
804
805impl UserData for FileContent {
806    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
807        // FileContent doesn't have any methods so far.
808    }
809}
810
811#[derive(Debug)]
812pub enum ScriptEvent {
813    Spawned(ScriptId),
814    Exited(ScriptId),
815}
816
817impl EventEmitter<ScriptEvent> for ScriptSession {}
818
/// Identifier of a script within a `ScriptSession`.
///
/// `ScriptSession::new_script` assigns the script's index in the session's script list,
/// so ids are only meaningful for the session that created them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
821
822pub struct Script {
823    pub id: ScriptId,
824    pub state: ScriptState,
825    pub source: SharedString,
826}
827
828pub enum ScriptState {
829    Generating,
830    Running {
831        stdout: Arc<Mutex<String>>,
832    },
833    Succeeded {
834        stdout: String,
835    },
836    Failed {
837        stdout: String,
838        error: anyhow::Error,
839    },
840}
841
842impl Script {
843    pub fn source_tag(&self) -> String {
844        format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG)
845    }
846
847    /// If exited, returns a message with the output for the LLM
848    pub fn output_message_for_llm(&self) -> Option<String> {
849        match &self.state {
850            ScriptState::Generating { .. } => None,
851            ScriptState::Running { .. } => None,
852            ScriptState::Succeeded { stdout } => {
853                format!("Here's the script output:\n{}", stdout).into()
854            }
855            ScriptState::Failed { stdout, error } => format!(
856                "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
857                error, stdout
858            )
859            .into(),
860        }
861    }
862
863    /// Get a snapshot of the script's stdout
864    pub fn stdout_snapshot(&self) -> String {
865        match &self.state {
866            ScriptState::Generating { .. } => String::new(),
867            ScriptState::Running { stdout } => stdout.lock().clone(),
868            ScriptState::Succeeded { stdout } => stdout.clone(),
869            ScriptState::Failed { stdout, .. } => stdout.clone(),
870        }
871    }
872
873    /// Returns the error if the script failed, otherwise None
874    pub fn error(&self) -> Option<&anyhow::Error> {
875        match &self.state {
876            ScriptState::Generating { .. } => None,
877            ScriptState::Running { .. } => None,
878            ScriptState::Succeeded { .. } => None,
879            ScriptState::Failed { error, .. } => Some(error),
880        }
881    }
882}
883
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    /// `print` separates its arguments with tabs and terminates each call with a newline.
    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let lua_source = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let stdout = test_script(lua_source, cx).await.unwrap();

        assert_eq!(stdout, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    /// `search` reports only the file containing the query, along with the matched text.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let lua_source = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print("  " .. match)
                end
            end
        "#;

        let stdout = test_script(lua_source, cx).await.unwrap();

        assert_eq!(stdout, "File: /file1.txt\nMatches:\n  world\n");
    }

    /// Runs `source` in a fresh session over a two-file fake file system rooted at `/`
    /// and returns whatever the script printed, or the error the run produced.
    async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
        init_test(cx);

        let fake_fs = FakeFs::new(cx.executor());
        let tree = json!({
            "file1.txt": "Hello world!",
            "file2.txt": "Goodbye moon!"
        });
        fake_fs.insert_tree("/", tree).await;

        let project = Project::test(fake_fs, [Path::new("/")], cx).await;
        let session = cx.new(|cx| ScriptSession::new(project, cx));

        let (script_id, run) = session.update(cx, |session, cx| {
            let id = session.new_script();
            (id, session.run_script(id, source.to_string(), cx))
        });
        run.await?;

        let stdout = session.read_with(cx, |session, _cx| session.get(script_id).stdout_snapshot());
        Ok(stdout)
    }

    /// Installs the globals the project needs: a test settings store and project settings.
    fn init_test(cx: &mut TestAppContext) {
        let store = cx.update(SettingsStore::test);
        cx.set_global(store);
        cx.update(Project::init_settings);
    }
}