session.rs

  1use anyhow::anyhow;
  2use collections::{HashMap, HashSet};
  3use futures::{
  4    channel::{mpsc, oneshot},
  5    pin_mut, SinkExt, StreamExt,
  6};
  7use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  8use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
  9use parking_lot::Mutex;
 10use project::{search::SearchQuery, Fs, Project};
 11use regex::Regex;
 12use std::{
 13    cell::RefCell,
 14    path::{Path, PathBuf},
 15    sync::Arc,
 16};
 17use util::{paths::PathMatcher, ResultExt};
 18
 19use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG};
 20
 21struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptSession>, AsyncApp) + Send>);
 22
/// Runs Lua scripts against a [`Project`] and keeps track of each script's source and state.
pub struct ScriptSession {
    /// Project whose worktrees, file system, and buffers the scripts operate on.
    project: Entity<Project>,
    // TODO Remove this
    /// File contents written by scripts, keyed by path. Script writes are recorded here
    /// instead of being written to disk.
    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
    /// Sends closures to `_invoke_foreground_fns`, which runs them with access to this session.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Task spawned in `new` that invokes every closure received from `foreground_fns_tx`;
    /// stored in the session so it lives as long as the session does.
    _invoke_foreground_fns: Task<()>,
    /// Every script created by `new_script`; a `ScriptId` is an index into this list.
    scripts: Vec<Script>,
}
 31
 32impl ScriptSession {
 33    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 34        let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
 35        ScriptSession {
 36            project,
 37            fs_changes: Arc::new(Mutex::new(HashMap::default())),
 38            foreground_fns_tx,
 39            _invoke_foreground_fns: cx.spawn(|this, cx| async move {
 40                while let Some(foreground_fn) = foreground_fns_rx.next().await {
 41                    foreground_fn.0(this.clone(), cx.clone());
 42                }
 43            }),
 44            scripts: Vec::new(),
 45        }
 46    }
 47
 48    pub fn new_script(&mut self) -> ScriptId {
 49        let id = ScriptId(self.scripts.len() as u32);
 50        let script = Script {
 51            id,
 52            state: ScriptState::Generating,
 53            source: SharedString::new_static(""),
 54        };
 55        self.scripts.push(script);
 56        id
 57    }
 58
 59    pub fn run_script(
 60        &mut self,
 61        script_id: ScriptId,
 62        script_src: String,
 63        cx: &mut Context<Self>,
 64    ) -> Task<anyhow::Result<()>> {
 65        let script = self.get_mut(script_id);
 66
 67        let stdout = Arc::new(Mutex::new(String::new()));
 68        script.source = script_src.clone().into();
 69        script.state = ScriptState::Running {
 70            stdout: stdout.clone(),
 71        };
 72
 73        let task = self.run_lua(script_src, stdout, cx);
 74
 75        cx.emit(ScriptEvent::Spawned(script_id));
 76
 77        cx.spawn(|session, mut cx| async move {
 78            let result = task.await;
 79
 80            session.update(&mut cx, |session, cx| {
 81                let script = session.get_mut(script_id);
 82                let stdout = script.stdout_snapshot();
 83
 84                script.state = match result {
 85                    Ok(()) => ScriptState::Succeeded { stdout },
 86                    Err(error) => ScriptState::Failed { stdout, error },
 87                };
 88
 89                cx.emit(ScriptEvent::Exited(script_id))
 90            })
 91        })
 92    }
 93
 94    fn run_lua(
 95        &mut self,
 96        script: String,
 97        stdout: Arc<Mutex<String>>,
 98        cx: &mut Context<Self>,
 99    ) -> Task<anyhow::Result<()>> {
100        const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
101
102        // TODO Remove fs_changes
103        let fs_changes = self.fs_changes.clone();
104        // TODO Honor all worktrees instead of the first one
105        let root_dir = self
106            .project
107            .read(cx)
108            .visible_worktrees(cx)
109            .next()
110            .map(|worktree| worktree.read(cx).abs_path());
111
112        let fs = self.project.read(cx).fs().clone();
113        let foreground_fns_tx = self.foreground_fns_tx.clone();
114
115        let task = cx.background_spawn({
116            let stdout = stdout.clone();
117
118            async move {
119                let lua = Lua::new();
120                lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
121                let globals = lua.globals();
122                globals.set(
123                    "sb_print",
124                    lua.create_function({
125                        let stdout = stdout.clone();
126                        move |_, args: MultiValue| Self::print(args, &stdout)
127                    })?,
128                )?;
129                globals.set(
130                    "search",
131                    lua.create_async_function({
132                        let foreground_fns_tx = foreground_fns_tx.clone();
133                        move |lua, regex| {
134                            let mut foreground_fns_tx = foreground_fns_tx.clone();
135                            let fs = fs.clone();
136                            async move {
137                                Self::search(&lua, &mut foreground_fns_tx, fs, regex)
138                                    .await
139                                    .into_lua_err()
140                            }
141                        }
142                    })?,
143                )?;
144                globals.set(
145                    "outline",
146                    lua.create_async_function({
147                        let root_dir = root_dir.clone();
148                        move |_lua, path| {
149                            let mut foreground_fns_tx = foreground_fns_tx.clone();
150                            let root_dir = root_dir.clone();
151                            async move {
152                                Self::outline(root_dir, &mut foreground_fns_tx, path)
153                                    .await
154                                    .into_lua_err()
155                            }
156                        }
157                    })?,
158                )?;
159                globals.set(
160                    "sb_io_open",
161                    lua.create_function({
162                        let fs_changes = fs_changes.clone();
163                        let root_dir = root_dir.clone();
164                        move |lua, (path_str, mode)| {
165                            Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
166                        }
167                    })?,
168                )?;
169                globals.set("user_script", script)?;
170
171                lua.load(SANDBOX_PREAMBLE).exec_async().await?;
172
173                // Drop Lua instance to decrement reference count.
174                drop(lua);
175
176                anyhow::Ok(())
177            }
178        });
179
180        task
181    }
182
183    pub fn get(&self, script_id: ScriptId) -> &Script {
184        &self.scripts[script_id.0 as usize]
185    }
186
187    fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
188        &mut self.scripts[script_id.0 as usize]
189    }
190
191    /// Sandboxed print() function in Lua.
192    fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
193        for (index, arg) in args.into_iter().enumerate() {
194            // Lua's `print()` prints tab characters between each argument.
195            if index > 0 {
196                stdout.lock().push('\t');
197            }
198
199            // If the argument's to_string() fails, have the whole function call fail.
200            stdout.lock().push_str(&arg.to_string()?);
201        }
202        stdout.lock().push('\n');
203
204        Ok(())
205    }
206
207    /// Sandboxed io.open() function in Lua.
208    fn io_open(
209        lua: &Lua,
210        fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
211        root_dir: Option<&Arc<Path>>,
212        path_str: String,
213        mode: Option<String>,
214    ) -> mlua::Result<(Option<Table>, String)> {
215        let root_dir = root_dir
216            .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
217
218        let mode = mode.unwrap_or_else(|| "r".to_string());
219
220        // Parse the mode string to determine read/write permissions
221        let read_perm = mode.contains('r');
222        let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
223        let append = mode.contains('a');
224        let truncate = mode.contains('w');
225
226        // This will be the Lua value returned from the `open` function.
227        let file = lua.create_table()?;
228
229        // Store file metadata in the file
230        file.set("__path", path_str.clone())?;
231        file.set("__mode", mode.clone())?;
232        file.set("__read_perm", read_perm)?;
233        file.set("__write_perm", write_perm)?;
234
235        let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
236            Ok(path) => path,
237            Err(err) => return Ok((None, format!("{err}"))),
238        };
239
240        // close method
241        let close_fn = {
242            let fs_changes = fs_changes.clone();
243            lua.create_function(move |_lua, file_userdata: mlua::Table| {
244                let write_perm = file_userdata.get::<bool>("__write_perm")?;
245                let path = file_userdata.get::<String>("__path")?;
246
247                if write_perm {
248                    // When closing a writable file, record the content
249                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
250                    let content_ref = content.borrow::<FileContent>()?;
251                    let content_vec = content_ref.0.borrow();
252
253                    // Don't actually write to disk; instead, just update fs_changes.
254                    let path_buf = PathBuf::from(&path);
255                    fs_changes
256                        .lock()
257                        .insert(path_buf.clone(), content_vec.clone());
258                }
259
260                Ok(true)
261            })?
262        };
263        file.set("close", close_fn)?;
264
265        // If it's a directory, give it a custom read() and return early.
266        if path.is_dir() {
267            // TODO handle the case where we changed it in the in-memory fs
268
269            // Create a special directory handle
270            file.set("__is_directory", true)?;
271
272            // Store directory entries
273            let entries = match std::fs::read_dir(&path) {
274                Ok(entries) => {
275                    let mut entry_names = Vec::new();
276                    for entry in entries.flatten() {
277                        entry_names.push(entry.file_name().to_string_lossy().into_owned());
278                    }
279                    entry_names
280                }
281                Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
282            };
283
284            // Save the list of entries
285            file.set("__dir_entries", entries)?;
286            file.set("__dir_position", 0usize)?;
287
288            // Create a directory-specific read function
289            let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
290                let position = file_userdata.get::<usize>("__dir_position")?;
291                let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
292
293                if position >= entries.len() {
294                    return Ok(None); // No more entries
295                }
296
297                let entry = entries[position].clone();
298                file_userdata.set("__dir_position", position + 1)?;
299
300                Ok(Some(entry))
301            })?;
302            file.set("read", read_fn)?;
303
304            // If we got this far, the directory was opened successfully
305            return Ok((Some(file), String::new()));
306        }
307
308        let fs_changes_map = fs_changes.lock();
309
310        let is_in_changes = fs_changes_map.contains_key(&path);
311        let file_exists = is_in_changes || path.exists();
312        let mut file_content = Vec::new();
313
314        if file_exists && !truncate {
315            if is_in_changes {
316                file_content = fs_changes_map.get(&path).unwrap().clone();
317            } else {
318                // Try to read existing content if file exists and we're not truncating
319                match std::fs::read(&path) {
320                    Ok(content) => file_content = content,
321                    Err(e) => return Ok((None, format!("Error reading file: {}", e))),
322                }
323            }
324        }
325
326        drop(fs_changes_map); // Unlock the fs_changes mutex.
327
328        // If in append mode, position should be at the end
329        let position = if append && file_exists {
330            file_content.len()
331        } else {
332            0
333        };
334        file.set("__position", position)?;
335        file.set(
336            "__content",
337            lua.create_userdata(FileContent(RefCell::new(file_content)))?,
338        )?;
339
340        // Create file methods
341
342        // read method
343        let read_fn = {
344            lua.create_function(
345                |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
346                    let read_perm = file_userdata.get::<bool>("__read_perm")?;
347                    if !read_perm {
348                        return Err(mlua::Error::runtime("File not open for reading"));
349                    }
350
351                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
352                    let mut position = file_userdata.get::<usize>("__position")?;
353                    let content_ref = content.borrow::<FileContent>()?;
354                    let content_vec = content_ref.0.borrow();
355
356                    if position >= content_vec.len() {
357                        return Ok(None); // EOF
358                    }
359
360                    match format {
361                        Some(mlua::Value::String(s)) => {
362                            let lossy_string = s.to_string_lossy();
363                            let format_str: &str = lossy_string.as_ref();
364
365                            // Only consider the first 2 bytes, since it's common to pass e.g. "*all"  instead of "*a"
366                            match &format_str[0..2] {
367                                "*a" => {
368                                    // Read entire file from current position
369                                    let result = String::from_utf8_lossy(&content_vec[position..])
370                                        .to_string();
371                                    position = content_vec.len();
372                                    file_userdata.set("__position", position)?;
373                                    Ok(Some(result))
374                                }
375                                "*l" => {
376                                    // Read next line
377                                    let mut line = Vec::new();
378                                    let mut found_newline = false;
379
380                                    while position < content_vec.len() {
381                                        let byte = content_vec[position];
382                                        position += 1;
383
384                                        if byte == b'\n' {
385                                            found_newline = true;
386                                            break;
387                                        }
388
389                                        // Skip \r in \r\n sequence but add it if it's alone
390                                        if byte == b'\r' {
391                                            if position < content_vec.len()
392                                                && content_vec[position] == b'\n'
393                                            {
394                                                position += 1;
395                                                found_newline = true;
396                                                break;
397                                            }
398                                        }
399
400                                        line.push(byte);
401                                    }
402
403                                    file_userdata.set("__position", position)?;
404
405                                    if !found_newline
406                                        && line.is_empty()
407                                        && position >= content_vec.len()
408                                    {
409                                        return Ok(None); // EOF
410                                    }
411
412                                    let result = String::from_utf8_lossy(&line).to_string();
413                                    Ok(Some(result))
414                                }
415                                "*n" => {
416                                    // Try to parse as a number (number of bytes to read)
417                                    match format_str.parse::<usize>() {
418                                        Ok(n) => {
419                                            let end =
420                                                std::cmp::min(position + n, content_vec.len());
421                                            let bytes = &content_vec[position..end];
422                                            let result = String::from_utf8_lossy(bytes).to_string();
423                                            position = end;
424                                            file_userdata.set("__position", position)?;
425                                            Ok(Some(result))
426                                        }
427                                        Err(_) => Err(mlua::Error::runtime(format!(
428                                            "Invalid format: {}",
429                                            format_str
430                                        ))),
431                                    }
432                                }
433                                "*L" => {
434                                    // Read next line keeping the end of line
435                                    let mut line = Vec::new();
436
437                                    while position < content_vec.len() {
438                                        let byte = content_vec[position];
439                                        position += 1;
440
441                                        line.push(byte);
442
443                                        if byte == b'\n' {
444                                            break;
445                                        }
446
447                                        // If we encounter a \r, add it and check if the next is \n
448                                        if byte == b'\r'
449                                            && position < content_vec.len()
450                                            && content_vec[position] == b'\n'
451                                        {
452                                            line.push(content_vec[position]);
453                                            position += 1;
454                                            break;
455                                        }
456                                    }
457
458                                    file_userdata.set("__position", position)?;
459
460                                    if line.is_empty() && position >= content_vec.len() {
461                                        return Ok(None); // EOF
462                                    }
463
464                                    let result = String::from_utf8_lossy(&line).to_string();
465                                    Ok(Some(result))
466                                }
467                                _ => Err(mlua::Error::runtime(format!(
468                                    "Unsupported format: {}",
469                                    format_str
470                                ))),
471                            }
472                        }
473                        Some(mlua::Value::Number(n)) => {
474                            // Read n bytes
475                            let n = n as usize;
476                            let end = std::cmp::min(position + n, content_vec.len());
477                            let bytes = &content_vec[position..end];
478                            let result = String::from_utf8_lossy(bytes).to_string();
479                            position = end;
480                            file_userdata.set("__position", position)?;
481                            Ok(Some(result))
482                        }
483                        Some(_) => Err(mlua::Error::runtime("Invalid format")),
484                        None => {
485                            // Default is to read a line
486                            let mut line = Vec::new();
487                            let mut found_newline = false;
488
489                            while position < content_vec.len() {
490                                let byte = content_vec[position];
491                                position += 1;
492
493                                if byte == b'\n' {
494                                    found_newline = true;
495                                    break;
496                                }
497
498                                // Handle \r\n
499                                if byte == b'\r' {
500                                    if position < content_vec.len()
501                                        && content_vec[position] == b'\n'
502                                    {
503                                        position += 1;
504                                        found_newline = true;
505                                        break;
506                                    }
507                                }
508
509                                line.push(byte);
510                            }
511
512                            file_userdata.set("__position", position)?;
513
514                            if !found_newline && line.is_empty() && position >= content_vec.len() {
515                                return Ok(None); // EOF
516                            }
517
518                            let result = String::from_utf8_lossy(&line).to_string();
519                            Ok(Some(result))
520                        }
521                    }
522                },
523            )?
524        };
525        file.set("read", read_fn)?;
526
527        // write method
528        let write_fn = {
529            let fs_changes = fs_changes.clone();
530
531            lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
532                let write_perm = file_userdata.get::<bool>("__write_perm")?;
533                if !write_perm {
534                    return Err(mlua::Error::runtime("File not open for writing"));
535                }
536
537                let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
538                let position = file_userdata.get::<usize>("__position")?;
539                let content_ref = content.borrow::<FileContent>()?;
540                let mut content_vec = content_ref.0.borrow_mut();
541
542                let bytes = text.as_bytes();
543
544                // Ensure the vector has enough capacity
545                if position + bytes.len() > content_vec.len() {
546                    content_vec.resize(position + bytes.len(), 0);
547                }
548
549                // Write the bytes
550                for (i, &byte) in bytes.iter().enumerate() {
551                    content_vec[position + i] = byte;
552                }
553
554                // Update position
555                let new_position = position + bytes.len();
556                file_userdata.set("__position", new_position)?;
557
558                // Update fs_changes
559                let path = file_userdata.get::<String>("__path")?;
560                let path_buf = PathBuf::from(path);
561                fs_changes.lock().insert(path_buf, content_vec.clone());
562
563                Ok(true)
564            })?
565        };
566        file.set("write", write_fn)?;
567
568        // If we got this far, the file was opened successfully
569        Ok((Some(file), String::new()))
570    }
571
572    async fn search(
573        lua: &Lua,
574        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
575        fs: Arc<dyn Fs>,
576        regex: String,
577    ) -> anyhow::Result<Table> {
578        // TODO: Allow specification of these options.
579        let search_query = SearchQuery::regex(
580            &regex,
581            false,
582            false,
583            false,
584            PathMatcher::default(),
585            PathMatcher::default(),
586            None,
587        );
588        let search_query = match search_query {
589            Ok(query) => query,
590            Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
591        };
592
593        // TODO: Should use `search_query.regex`. The tool description should also be updated,
594        // as it specifies standard regex.
595        let search_regex = match Regex::new(&regex) {
596            Ok(re) => re,
597            Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
598        };
599
600        let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
601
602        let mut search_results: Vec<Table> = Vec::new();
603        while let Some(path) = abs_paths_rx.next().await {
604            // Skip files larger than 1MB
605            if let Ok(Some(metadata)) = fs.metadata(&path).await {
606                if metadata.len > 1_000_000 {
607                    continue;
608                }
609            }
610
611            // Attempt to read the file as text
612            if let Ok(content) = fs.load(&path).await {
613                let mut matches = Vec::new();
614
615                // Find all regex matches in the content
616                for capture in search_regex.find_iter(&content) {
617                    matches.push(capture.as_str().to_string());
618                }
619
620                // If we found matches, create a result entry
621                if !matches.is_empty() {
622                    let result_entry = lua.create_table()?;
623                    result_entry.set("path", path.to_string_lossy().to_string())?;
624
625                    let matches_table = lua.create_table()?;
626                    for (ix, m) in matches.iter().enumerate() {
627                        matches_table.set(ix + 1, m.clone())?;
628                    }
629                    result_entry.set("matches", matches_table)?;
630
631                    search_results.push(result_entry);
632                }
633            }
634        }
635
636        // Create a table to hold our results
637        let results_table = lua.create_table()?;
638        for (ix, entry) in search_results.into_iter().enumerate() {
639            results_table.set(ix + 1, entry)?;
640        }
641
642        Ok(results_table)
643    }
644
    /// Asks the project's worktree store for files that may match `search_query`, and returns
    /// a channel yielding their absolute paths.
    ///
    /// The work runs through `run_foreground_fn`, and at most 5000 candidates are requested.
    /// Paths are forwarded by a detached task, so the receiver can keep yielding after this
    /// function returns. The channel closes when that task finishes, because it owns the sender.
    /// Candidates are skipped silently when their worktree is gone, and skipped with a logged
    /// error when their path cannot be made absolute.
    async fn find_search_candidates(
        search_query: SearchQuery,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
        Self::run_foreground_fn(
            "finding search file candidates",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    session.project.update(cx, |project, cx| {
                        project.worktree_store().update(cx, |worktree_store, cx| {
                            // TODO: Better limit? For now this is the same as
                            // MAX_SEARCH_RESULT_FILES.
                            let limit = 5000;
                            // TODO: Providing non-empty open_entries can make this a bit more
                            // efficient as it can skip checking that these paths are textual.
                            let open_entries = HashSet::default();
                            let candidates = worktree_store.find_search_candidates(
                                search_query,
                                limit,
                                open_entries,
                                project.fs().clone(),
                                cx,
                            );
                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
                            // Convert each project-relative candidate into an absolute path.
                            // The task stops early if the receiver has been dropped (send
                            // fails) or the worktree store is no longer readable.
                            cx.spawn(|worktree_store, cx| async move {
                                pin_mut!(candidates);

                                while let Some(project_path) = candidates.next().await {
                                    worktree_store.read_with(&cx, |worktree_store, cx| {
                                        if let Some(worktree) = worktree_store
                                            .worktree_for_id(project_path.worktree_id, cx)
                                        {
                                            if let Some(abs_path) = worktree
                                                .read(cx)
                                                .absolutize(&project_path.path)
                                                .log_err()
                                            {
                                                abs_paths_tx.unbounded_send(abs_path)?;
                                            }
                                        }
                                        anyhow::Ok(())
                                    })??;
                                }
                                anyhow::Ok(())
                            })
                            .detach();
                            abs_paths_rx
                        })
                    })
                })
            }),
        )
        // Outer `?`: enqueueing / response failure; the remaining Result comes from
        // `session.update` and is returned as-is.
        .await?
    }
700
701    async fn outline(
702        root_dir: Option<Arc<Path>>,
703        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
704        path_str: String,
705    ) -> anyhow::Result<String> {
706        let root_dir = root_dir
707            .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
708        let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
709        let outline = Self::run_foreground_fn(
710            "getting code outline",
711            foreground_tx,
712            Box::new(move |session, cx| {
713                cx.spawn(move |mut cx| async move {
714                    // TODO: This will not use file content from `fs_changes`. It will also reflect
715                    // user changes that have not been saved.
716                    let buffer = session
717                        .update(&mut cx, |session, cx| {
718                            session
719                                .project
720                                .update(cx, |project, cx| project.open_local_buffer(&path, cx))
721                        })?
722                        .await?;
723                    buffer.update(&mut cx, |buffer, _cx| {
724                        if let Some(outline) = buffer.snapshot().outline(None) {
725                            Ok(outline)
726                        } else {
727                            Err(anyhow!("No outline for file {path_str}"))
728                        }
729                    })
730                })
731            }),
732        )
733        .await?
734        .await??;
735
736        Ok(outline
737            .items
738            .into_iter()
739            .map(|item| {
740                if item.text.contains('\n') {
741                    log::error!("Outline item unexpectedly contains newline");
742                }
743                format!("{}{}", "  ".repeat(item.depth), item.text)
744            })
745            .collect::<Vec<String>>()
746            .join("\n"))
747    }
748
749    async fn run_foreground_fn<R: Send + 'static>(
750        description: &str,
751        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
752        function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
753    ) -> anyhow::Result<R> {
754        let (response_tx, response_rx) = oneshot::channel();
755        let send_result = foreground_tx
756            .send(ForegroundFn(Box::new(move |this, cx| {
757                response_tx.send(function(this, cx)).ok();
758            })))
759            .await;
760        match send_result {
761            Ok(()) => (),
762            Err(err) => {
763                return Err(anyhow::Error::new(err).context(format!(
764                    "Internal error while enqueuing work for {description}"
765                )));
766            }
767        }
768        match response_rx.await {
769            Ok(result) => Ok(result),
770            Err(oneshot::Canceled) => Err(anyhow!(
771                "Internal error: response oneshot was canceled while {description}."
772            )),
773        }
774    }
775
776    fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
777        let path = Path::new(&path_str);
778        if path.is_absolute() {
779            // Check if path starts with root_dir prefix without resolving symlinks
780            if path.starts_with(&root_dir) {
781                Ok(path.to_path_buf())
782            } else {
783                Err(anyhow!(
784                    "Error: Absolute path {} is outside the current working directory",
785                    path_str
786                ))
787            }
788        } else {
789            // TODO: Does use of `../` break sandbox - is path canonicalization needed?
790            Ok(root_dir.join(path))
791        }
792    }
793}
794
795struct FileContent(RefCell<Vec<u8>>);
796
797impl UserData for FileContent {
798    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
799        // FileContent doesn't have any methods so far.
800    }
801}
802
803#[derive(Debug)]
804pub enum ScriptEvent {
805    Spawned(ScriptId),
806    Exited(ScriptId),
807}
808
809impl EventEmitter<ScriptEvent> for ScriptSession {}
810
/// Identifier of a script within a [`ScriptSession`].
///
/// `ScriptSession::new_script` allocates ids from the current script count, so the first
/// script gets id 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
813
814pub struct Script {
815    pub id: ScriptId,
816    pub state: ScriptState,
817    pub source: SharedString,
818}
819
820pub enum ScriptState {
821    Generating,
822    Running {
823        stdout: Arc<Mutex<String>>,
824    },
825    Succeeded {
826        stdout: String,
827    },
828    Failed {
829        stdout: String,
830        error: anyhow::Error,
831    },
832}
833
834impl Script {
835    pub fn source_tag(&self) -> String {
836        format!("{}{}{}", SCRIPT_START_TAG, self.source, SCRIPT_END_TAG)
837    }
838
839    /// If exited, returns a message with the output for the LLM
840    pub fn output_message_for_llm(&self) -> Option<String> {
841        match &self.state {
842            ScriptState::Generating { .. } => None,
843            ScriptState::Running { .. } => None,
844            ScriptState::Succeeded { stdout } => {
845                format!("Here's the script output:\n{}", stdout).into()
846            }
847            ScriptState::Failed { stdout, error } => format!(
848                "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
849                error, stdout
850            )
851            .into(),
852        }
853    }
854
855    /// Get a snapshot of the script's stdout
856    pub fn stdout_snapshot(&self) -> String {
857        match &self.state {
858            ScriptState::Generating { .. } => String::new(),
859            ScriptState::Running { stdout } => stdout.lock().clone(),
860            ScriptState::Succeeded { stdout } => stdout.clone(),
861            ScriptState::Failed { stdout, .. } => stdout.clone(),
862        }
863    }
864
865    /// Returns the error if the script failed, otherwise None
866    pub fn error(&self) -> Option<&anyhow::Error> {
867        match &self.state {
868            ScriptState::Generating { .. } => None,
869            ScriptState::Running { .. } => None,
870            ScriptState::Succeeded { .. } => None,
871            ScriptState::Failed { error, .. } => Some(error),
872        }
873    }
874}
875
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    /// `print` separates its arguments with tabs and ends each call with a newline.
    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let script = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let stdout = test_script(script, cx).await.unwrap();
        assert_eq!(stdout, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    /// `search` returns per-file results whose `matches` contain the matched text.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let script = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print("  " .. match)
                end
            end
        "#;

        let stdout = test_script(script, cx).await.unwrap();
        assert_eq!(stdout, "File: /file1.txt\nMatches:\n  world\n");
    }

    /// Runs `source` in a fresh session over a two-file fake filesystem and returns its stdout.
    async fn test_script(source: &str, cx: &mut TestAppContext) -> anyhow::Result<String> {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let tree = json!({
            "file1.txt": "Hello world!",
            "file2.txt": "Goodbye moon!"
        });
        fs.insert_tree("/", tree).await;

        let project = Project::test(fs, [Path::new("/")], cx).await;
        let session = cx.new(|cx| ScriptSession::new(project, cx));

        let (id, run_task) = session.update(cx, |session, cx| {
            let id = session.new_script();
            (id, session.run_script(id, source.to_string(), cx))
        });
        run_task.await?;

        let stdout = session.read_with(cx, |session, _cx| session.get(id).stdout_snapshot());
        Ok(stdout)
    }

    /// Installs the test settings store and the project settings.
    fn init_test(cx: &mut TestAppContext) {
        let store = cx.update(SettingsStore::test);
        cx.set_global(store);
        cx.update(Project::init_settings);
    }
}