scripting_session.rs

   1use anyhow::anyhow;
   2use collections::HashSet;
   3use futures::{
   4    channel::{mpsc, oneshot},
   5    pin_mut, SinkExt, StreamExt,
   6};
   7use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
   8use language::Buffer;
   9use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
  10use parking_lot::Mutex;
  11use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
  12use regex::Regex;
  13use std::{
  14    path::{Path, PathBuf},
  15    sync::Arc,
  16};
  17use util::{paths::PathMatcher, ResultExt};
  18
  19struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
  20
/// Runs sandboxed Lua scripts against a [`Project`].
///
/// Lua code executes on a background executor. Work that needs the project (opening,
/// reading or saving buffers, searching, outlines) is sent as a [`ForegroundFn`] through
/// `foreground_fns_tx` and executed by the loop spawned in `new`.
pub struct ScriptingSession {
    /// Project whose first visible worktree scripts operate on.
    project: Entity<Project>,
    /// Every script started by this session, indexed by `ScriptId`.
    scripts: Vec<Script>,
    /// Buffers that scripts have written to and successfully saved.
    changed_buffers: HashSet<Entity<Buffer>>,
    /// Queue of closures to run on the foreground (capacity 128, see `new`).
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Task draining the foreground queue. It is stored here, presumably so it stays alive
    /// for the session's lifetime (TODO confirm gpui drop semantics).
    _invoke_foreground_fns: Task<()>,
}
  28
  29impl ScriptingSession {
  30    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
  31        let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
  32        ScriptingSession {
  33            project,
  34            scripts: Vec::new(),
  35            changed_buffers: HashSet::default(),
  36            foreground_fns_tx,
  37            _invoke_foreground_fns: cx.spawn(|this, cx| async move {
  38                while let Some(foreground_fn) = foreground_fns_rx.next().await {
  39                    foreground_fn.0(this.clone(), cx.clone());
  40                }
  41            }),
  42        }
  43    }
  44
  45    pub fn changed_buffers(&self) -> impl ExactSizeIterator<Item = &Entity<Buffer>> {
  46        self.changed_buffers.iter()
  47    }
  48
  49    pub fn run_script(
  50        &mut self,
  51        script_src: String,
  52        cx: &mut Context<Self>,
  53    ) -> (ScriptId, Task<()>) {
  54        let id = ScriptId(self.scripts.len() as u32);
  55
  56        let stdout = Arc::new(Mutex::new(String::new()));
  57
  58        let script = Script {
  59            state: ScriptState::Running {
  60                stdout: stdout.clone(),
  61            },
  62        };
  63        self.scripts.push(script);
  64
  65        let task = self.run_lua(script_src, stdout, cx);
  66
  67        let task = cx.spawn(|session, mut cx| async move {
  68            let result = task.await;
  69
  70            session
  71                .update(&mut cx, |session, _cx| {
  72                    let script = session.get_mut(id);
  73                    let stdout = script.stdout_snapshot();
  74
  75                    script.state = match result {
  76                        Ok(()) => ScriptState::Succeeded { stdout },
  77                        Err(error) => ScriptState::Failed { stdout, error },
  78                    };
  79                })
  80                .log_err();
  81        });
  82
  83        (id, task)
  84    }
  85
  86    fn run_lua(
  87        &mut self,
  88        script: String,
  89        stdout: Arc<Mutex<String>>,
  90        cx: &mut Context<Self>,
  91    ) -> Task<anyhow::Result<()>> {
  92        const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
  93
  94        // TODO Honor all worktrees instead of the first one
  95        let worktree_info = self
  96            .project
  97            .read(cx)
  98            .visible_worktrees(cx)
  99            .next()
 100            .map(|worktree| {
 101                let worktree = worktree.read(cx);
 102                (worktree.id(), worktree.abs_path())
 103            });
 104
 105        let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());
 106
 107        let fs = self.project.read(cx).fs().clone();
 108        let foreground_fns_tx = self.foreground_fns_tx.clone();
 109
 110        let task = cx.background_spawn({
 111            let stdout = stdout.clone();
 112
 113            async move {
 114                let lua = Lua::new();
 115                lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
 116                let globals = lua.globals();
 117
 118                // Use the project root dir as the script's current working dir.
 119                if let Some(root_dir) = &root_dir {
 120                    if let Some(root_dir) = root_dir.to_str() {
 121                        globals.set("cwd", root_dir)?;
 122                    }
 123                }
 124
 125                globals.set(
 126                    "sb_print",
 127                    lua.create_function({
 128                        let stdout = stdout.clone();
 129                        move |_, args: MultiValue| Self::print(args, &stdout)
 130                    })?,
 131                )?;
 132                globals.set(
 133                    "search",
 134                    lua.create_async_function({
 135                        let foreground_fns_tx = foreground_fns_tx.clone();
 136                        let fs = fs.clone();
 137                        move |lua, regex| {
 138                            let mut foreground_fns_tx = foreground_fns_tx.clone();
 139                            let fs = fs.clone();
 140                            async move {
 141                                Self::search(&lua, &mut foreground_fns_tx, fs, regex)
 142                                    .await
 143                                    .into_lua_err()
 144                            }
 145                        }
 146                    })?,
 147                )?;
 148                globals.set(
 149                    "outline",
 150                    lua.create_async_function({
 151                        let root_dir = root_dir.clone();
 152                        let foreground_fns_tx = foreground_fns_tx.clone();
 153                        move |_lua, path| {
 154                            let mut foreground_fns_tx = foreground_fns_tx.clone();
 155                            let root_dir = root_dir.clone();
 156                            async move {
 157                                Self::outline(root_dir, &mut foreground_fns_tx, path)
 158                                    .await
 159                                    .into_lua_err()
 160                            }
 161                        }
 162                    })?,
 163                )?;
 164                globals.set(
 165                    "sb_io_open",
 166                    lua.create_async_function({
 167                        let worktree_info = worktree_info.clone();
 168                        let foreground_fns_tx = foreground_fns_tx.clone();
 169                        move |lua, (path_str, mode)| {
 170                            let worktree_info = worktree_info.clone();
 171                            let mut foreground_fns_tx = foreground_fns_tx.clone();
 172                            let fs = fs.clone();
 173                            async move {
 174                                Self::io_open(
 175                                    &lua,
 176                                    worktree_info,
 177                                    &mut foreground_fns_tx,
 178                                    fs,
 179                                    path_str,
 180                                    mode,
 181                                )
 182                                .await
 183                            }
 184                        }
 185                    })?,
 186                )?;
 187                globals.set("user_script", script)?;
 188
 189                lua.load(SANDBOX_PREAMBLE).exec_async().await?;
 190
 191                // Drop Lua instance to decrement reference count.
 192                drop(lua);
 193
 194                anyhow::Ok(())
 195            }
 196        });
 197
 198        task
 199    }
 200
 201    pub fn get(&self, script_id: ScriptId) -> &Script {
 202        &self.scripts[script_id.0 as usize]
 203    }
 204
 205    fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
 206        &mut self.scripts[script_id.0 as usize]
 207    }
 208
 209    /// Sandboxed print() function in Lua.
 210    fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
 211        for (index, arg) in args.into_iter().enumerate() {
 212            // Lua's `print()` prints tab characters between each argument.
 213            if index > 0 {
 214                stdout.lock().push('\t');
 215            }
 216
 217            // If the argument's to_string() fails, have the whole function call fail.
 218            stdout.lock().push_str(&arg.to_string()?);
 219        }
 220        stdout.lock().push('\n');
 221
 222        Ok(())
 223    }
 224
 225    /// Sandboxed io.open() function in Lua.
 226    async fn io_open(
 227        lua: &Lua,
 228        worktree_info: Option<(WorktreeId, Arc<Path>)>,
 229        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 230        fs: Arc<dyn Fs>,
 231        path_str: String,
 232        mode: Option<String>,
 233    ) -> mlua::Result<(Option<Table>, String)> {
 234        let (worktree_id, root_dir) = worktree_info
 235            .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
 236
 237        let mode = mode.unwrap_or_else(|| "r".to_string());
 238
 239        // Parse the mode string to determine read/write permissions
 240        let read_perm = mode.contains('r');
 241        let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
 242        let append = mode.contains('a');
 243        let truncate = mode.contains('w');
 244
 245        // This will be the Lua value returned from the `open` function.
 246        let file = lua.create_table()?;
 247
 248        // Store file metadata in the file
 249        file.set("__mode", mode.clone())?;
 250        file.set("__read_perm", read_perm)?;
 251        file.set("__write_perm", write_perm)?;
 252
 253        let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
 254            Ok(path) => path,
 255            Err(err) => return Ok((None, format!("{err}"))),
 256        };
 257
 258        let project_path = ProjectPath {
 259            worktree_id,
 260            path: Path::new(&path_str).into(),
 261        };
 262
 263        // flush / close method
 264        let flush_fn = {
 265            let project_path = project_path.clone();
 266            let foreground_tx = foreground_tx.clone();
 267            lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
 268                let project_path = project_path.clone();
 269                let mut foreground_tx = foreground_tx.clone();
 270                async move {
 271                    Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
 272                }
 273            })?
 274        };
 275        file.set("flush", flush_fn.clone())?;
 276        // We don't really hold files open, so we only need to flush on close
 277        file.set("close", flush_fn)?;
 278
 279        // If it's a directory, give it a custom read() and return early.
 280        if fs.is_dir(&path).await {
 281            return Self::io_file_dir(lua, fs, file, &path).await;
 282        }
 283
 284        let mut file_content = Vec::new();
 285
 286        if !truncate {
 287            // Try to read existing content if we're not truncating
 288            match Self::read_buffer(project_path.clone(), foreground_tx).await {
 289                Ok(content) => file_content = content.into_bytes(),
 290                Err(e) => return Ok((None, format!("Error reading file: {}", e))),
 291            }
 292        }
 293
 294        // If in append mode, position should be at the end
 295        let position = if append { file_content.len() } else { 0 };
 296        file.set("__position", position)?;
 297        file.set(
 298            "__content",
 299            lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
 300        )?;
 301
 302        // Create file methods
 303
 304        // read method
 305        let read_fn = lua.create_function(Self::io_file_read)?;
 306        file.set("read", read_fn)?;
 307
 308        // write method
 309        let write_fn = lua.create_function(Self::io_file_write)?;
 310        file.set("write", write_fn)?;
 311
 312        // If we got this far, the file was opened successfully
 313        Ok((Some(file), String::new()))
 314    }
 315
 316    async fn read_buffer(
 317        project_path: ProjectPath,
 318        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 319    ) -> anyhow::Result<String> {
 320        Self::run_foreground_fn(
 321            "read file from buffer",
 322            foreground_tx,
 323            Box::new(move |session, mut cx| {
 324                session.update(&mut cx, |session, cx| {
 325                    let open_buffer_task = session
 326                        .project
 327                        .update(cx, |project, cx| project.open_buffer(project_path, cx));
 328
 329                    cx.spawn(|_, cx| async move {
 330                        let buffer = open_buffer_task.await?;
 331
 332                        let text = buffer.read_with(&cx, |buffer, _cx| buffer.text())?;
 333                        Ok(text)
 334                    })
 335                })
 336            }),
 337        )
 338        .await??
 339        .await
 340    }
 341
 342    async fn io_file_flush(
 343        file_userdata: mlua::Table,
 344        project_path: ProjectPath,
 345        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 346    ) -> mlua::Result<bool> {
 347        let write_perm = file_userdata.get::<bool>("__write_perm")?;
 348
 349        if write_perm {
 350            let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 351            let content_ref = content.borrow::<FileContent>()?;
 352            let text = {
 353                let mut content_vec = content_ref.0.lock();
 354                let content_vec = std::mem::take(&mut *content_vec);
 355                String::from_utf8(content_vec).into_lua_err()?
 356            };
 357
 358            Self::write_to_buffer(project_path, text, foreground_tx)
 359                .await
 360                .into_lua_err()?;
 361        }
 362
 363        Ok(true)
 364    }
 365
    /// Replaces the contents of the buffer at `project_path` with `text` and saves it.
    ///
    /// On the foreground this:
    /// - opens the buffer,
    /// - computes a diff from its current contents to `text`,
    /// - applies that diff and saves the buffer through the project,
    /// - after a successful save, records the buffer in `changed_buffers`.
    ///
    /// # Errors
    /// Fails if the work cannot be scheduled, the session or buffer is released, or opening
    /// or saving fails.
    async fn write_to_buffer(
        project_path: ProjectPath,
        text: String,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<()> {
        Self::run_foreground_fn(
            "write to buffer",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    let open_buffer_task = session
                        .project
                        .update(cx, |project, cx| project.open_buffer(project_path, cx));

                    cx.spawn(move |session, mut cx| async move {
                        let buffer = open_buffer_task.await?;

                        // Apply the change as a diff rather than a wholesale replacement.
                        let diff = buffer
                            .update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
                            .await;

                        buffer.update(&mut cx, |buffer, cx| {
                            buffer.apply_diff(diff, cx);
                        })?;

                        // Save through the project; `buffer` is cloned because it is still
                        // needed below for `changed_buffers`.
                        session
                            .update(&mut cx, {
                                let buffer = buffer.clone();

                                |session, cx| {
                                    session
                                        .project
                                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                                }
                            })?
                            .await?;

                        // If we saved successfully, mark buffer as changed
                        session.update(&mut cx, |session, _cx| {
                            session.changed_buffers.insert(buffer);
                        })
                    })
                })
            }),
        )
        // Outer `?`s: scheduling failure and session-update failure; the final `.await`
        // drives the spawned open/diff/save task.
        .await??
        .await
    }
 414
 415    async fn io_file_dir(
 416        lua: &Lua,
 417        fs: Arc<dyn Fs>,
 418        file: Table,
 419        path: &Path,
 420    ) -> mlua::Result<(Option<Table>, String)> {
 421        // Create a special directory handle
 422        file.set("__is_directory", true)?;
 423
 424        // Store directory entries
 425        let entries = match fs.read_dir(&path).await {
 426            Ok(entries) => {
 427                let mut entry_names = Vec::new();
 428
 429                // Process the stream of directory entries
 430                pin_mut!(entries);
 431                while let Some(Ok(entry_result)) = entries.next().await {
 432                    if let Some(file_name) = entry_result.file_name() {
 433                        entry_names.push(file_name.to_string_lossy().into_owned());
 434                    }
 435                }
 436
 437                entry_names
 438            }
 439            Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
 440        };
 441
 442        // Save the list of entries
 443        file.set("__dir_entries", entries)?;
 444        file.set("__dir_position", 0usize)?;
 445
 446        // Create a directory-specific read function
 447        let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
 448            let position = file_userdata.get::<usize>("__dir_position")?;
 449            let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
 450
 451            if position >= entries.len() {
 452                return Ok(None); // No more entries
 453            }
 454
 455            let entry = entries[position].clone();
 456            file_userdata.set("__dir_position", position + 1)?;
 457
 458            Ok(Some(entry))
 459        })?;
 460        file.set("read", read_fn)?;
 461
 462        // If we got this far, the directory was opened successfully
 463        return Ok((Some(file), String::new()));
 464    }
 465
 466    fn io_file_read(
 467        lua: &Lua,
 468        (file_userdata, format): (Table, Option<mlua::Value>),
 469    ) -> mlua::Result<Option<mlua::String>> {
 470        let read_perm = file_userdata.get::<bool>("__read_perm")?;
 471        if !read_perm {
 472            return Err(mlua::Error::runtime("File not open for reading"));
 473        }
 474
 475        let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 476        let position = file_userdata.get::<usize>("__position")?;
 477        let content_ref = content.borrow::<FileContent>()?;
 478        let content = content_ref.0.lock();
 479
 480        if position >= content.len() {
 481            return Ok(None); // EOF
 482        }
 483
 484        let (result, new_position) = match Self::io_file_read_format(format)? {
 485            FileReadFormat::All => {
 486                // Read entire file from current position
 487                let result = content[position..].to_vec();
 488                (Some(result), content.len())
 489            }
 490            FileReadFormat::Line => {
 491                if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
 492                {
 493                    let mut line = content[position..position + next_newline_ix].to_vec();
 494                    if line.ends_with(b"\r") {
 495                        line.pop();
 496                    }
 497                    (Some(line), position + next_newline_ix + 1)
 498                } else if position < content.len() {
 499                    let line = content[position..].to_vec();
 500                    (Some(line), content.len())
 501                } else {
 502                    (None, position) // EOF
 503                }
 504            }
 505            FileReadFormat::LineWithLineFeed => {
 506                if position < content.len() {
 507                    let next_line_ix = content[position..]
 508                        .iter()
 509                        .position(|c| *c == b'\n')
 510                        .map_or(content.len(), |ix| position + ix + 1);
 511                    let line = content[position..next_line_ix].to_vec();
 512                    (Some(line), next_line_ix)
 513                } else {
 514                    (None, position) // EOF
 515                }
 516            }
 517            FileReadFormat::Bytes(n) => {
 518                let end = std::cmp::min(position + n, content.len());
 519                let result = content[position..end].to_vec();
 520                (Some(result), end)
 521            }
 522        };
 523
 524        // Update the position in the file userdata
 525        if new_position != position {
 526            file_userdata.set("__position", new_position)?;
 527        }
 528
 529        // Convert the result to a Lua string
 530        match result {
 531            Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
 532            None => Ok(None),
 533        }
 534    }
 535
 536    fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
 537        let format = match format {
 538            Some(mlua::Value::String(s)) => {
 539                let lossy_string = s.to_string_lossy();
 540                let format_str: &str = lossy_string.as_ref();
 541
 542                // Only consider the first 2 bytes, since it's common to pass e.g. "*all"  instead of "*a"
 543                match &format_str[0..2] {
 544                    "*a" => FileReadFormat::All,
 545                    "*l" => FileReadFormat::Line,
 546                    "*L" => FileReadFormat::LineWithLineFeed,
 547                    "*n" => {
 548                        // Try to parse as a number (number of bytes to read)
 549                        match format_str.parse::<usize>() {
 550                            Ok(n) => FileReadFormat::Bytes(n),
 551                            Err(_) => {
 552                                return Err(mlua::Error::runtime(format!(
 553                                    "Invalid format: {}",
 554                                    format_str
 555                                )))
 556                            }
 557                        }
 558                    }
 559                    _ => {
 560                        return Err(mlua::Error::runtime(format!(
 561                            "Unsupported format: {}",
 562                            format_str
 563                        )))
 564                    }
 565                }
 566            }
 567            Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
 568            Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
 569            Some(value) => {
 570                return Err(mlua::Error::runtime(format!(
 571                    "Invalid file format {:?}",
 572                    value
 573                )))
 574            }
 575            None => FileReadFormat::Line, // Default is to read a line
 576        };
 577
 578        Ok(format)
 579    }
 580
 581    fn io_file_write(
 582        _lua: &Lua,
 583        (file_userdata, text): (Table, mlua::String),
 584    ) -> mlua::Result<bool> {
 585        let write_perm = file_userdata.get::<bool>("__write_perm")?;
 586        if !write_perm {
 587            return Err(mlua::Error::runtime("File not open for writing"));
 588        }
 589
 590        let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 591        let position = file_userdata.get::<usize>("__position")?;
 592        let content_ref = content.borrow::<FileContent>()?;
 593        let mut content_vec = content_ref.0.lock();
 594
 595        let bytes = text.as_bytes();
 596
 597        // Ensure the vector has enough capacity
 598        if position + bytes.len() > content_vec.len() {
 599            content_vec.resize(position + bytes.len(), 0);
 600        }
 601
 602        // Write the bytes
 603        for (i, &byte) in bytes.iter().enumerate() {
 604            content_vec[position + i] = byte;
 605        }
 606
 607        // Update position
 608        let new_position = position + bytes.len();
 609        file_userdata.set("__position", new_position)?;
 610
 611        Ok(true)
 612    }
 613
 614    async fn search(
 615        lua: &Lua,
 616        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 617        fs: Arc<dyn Fs>,
 618        regex: String,
 619    ) -> anyhow::Result<Table> {
 620        // TODO: Allow specification of these options.
 621        let search_query = SearchQuery::regex(
 622            &regex,
 623            false,
 624            false,
 625            false,
 626            PathMatcher::default(),
 627            PathMatcher::default(),
 628            None,
 629        );
 630        let search_query = match search_query {
 631            Ok(query) => query,
 632            Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
 633        };
 634
 635        // TODO: Should use `search_query.regex`. The tool description should also be updated,
 636        // as it specifies standard regex.
 637        let search_regex = match Regex::new(&regex) {
 638            Ok(re) => re,
 639            Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
 640        };
 641
 642        let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
 643
 644        let mut search_results: Vec<Table> = Vec::new();
 645        while let Some(path) = abs_paths_rx.next().await {
 646            // Skip files larger than 1MB
 647            if let Ok(Some(metadata)) = fs.metadata(&path).await {
 648                if metadata.len > 1_000_000 {
 649                    continue;
 650                }
 651            }
 652
 653            // Attempt to read the file as text
 654            if let Ok(content) = fs.load(&path).await {
 655                let mut matches = Vec::new();
 656
 657                // Find all regex matches in the content
 658                for capture in search_regex.find_iter(&content) {
 659                    matches.push(capture.as_str().to_string());
 660                }
 661
 662                // If we found matches, create a result entry
 663                if !matches.is_empty() {
 664                    let result_entry = lua.create_table()?;
 665                    result_entry.set("path", path.to_string_lossy().to_string())?;
 666
 667                    let matches_table = lua.create_table()?;
 668                    for (ix, m) in matches.iter().enumerate() {
 669                        matches_table.set(ix + 1, m.clone())?;
 670                    }
 671                    result_entry.set("matches", matches_table)?;
 672
 673                    search_results.push(result_entry);
 674                }
 675            }
 676        }
 677
 678        // Create a table to hold our results
 679        let results_table = lua.create_table()?;
 680        for (ix, entry) in search_results.into_iter().enumerate() {
 681            results_table.set(ix + 1, entry)?;
 682        }
 683
 684        Ok(results_table)
 685    }
 686
    /// Asks the project's worktree store for files that may match `search_query`.
    ///
    /// On the foreground, this starts the worktree store's `find_search_candidates`. It
    /// returns an unbounded channel that yields each candidate's absolute path as it is found.
    ///
    /// Paths are forwarded by a detached task. That task:
    /// - skips (and logs) paths that cannot be made absolute;
    /// - stops when the receiver is dropped or the worktree store is released.
    ///
    /// # Errors
    /// Fails if the work cannot be scheduled or the session has been released.
    async fn find_search_candidates(
        search_query: SearchQuery,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
        Self::run_foreground_fn(
            "finding search file candidates",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    session.project.update(cx, |project, cx| {
                        project.worktree_store().update(cx, |worktree_store, cx| {
                            // TODO: Better limit? For now this is the same as
                            // MAX_SEARCH_RESULT_FILES.
                            let limit = 5000;
                            // TODO: Providing non-empty open_entries can make this a bit more
                            // efficient as it can skip checking that these paths are textual.
                            let open_entries = HashSet::default();
                            let candidates = worktree_store.find_search_candidates(
                                search_query,
                                limit,
                                open_entries,
                                project.fs().clone(),
                                cx,
                            );
                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
                            // Convert each worktree-relative candidate to an absolute path and
                            // forward it to the caller.
                            cx.spawn(|worktree_store, cx| async move {
                                pin_mut!(candidates);

                                while let Some(project_path) = candidates.next().await {
                                    worktree_store.read_with(&cx, |worktree_store, cx| {
                                        if let Some(worktree) = worktree_store
                                            .worktree_for_id(project_path.worktree_id, cx)
                                        {
                                            if let Some(abs_path) = worktree
                                                .read(cx)
                                                .absolutize(&project_path.path)
                                                .log_err()
                                            {
                                                // Errors once the receiver is dropped,
                                                // which ends this task.
                                                abs_paths_tx.unbounded_send(abs_path)?;
                                            }
                                        }
                                        anyhow::Ok(())
                                    })??;
                                }
                                anyhow::Ok(())
                            })
                            .detach();
                            abs_paths_rx
                        })
                    })
                })
            }),
        )
        .await?
    }
 742
 743    async fn outline(
 744        root_dir: Option<Arc<Path>>,
 745        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 746        path_str: String,
 747    ) -> anyhow::Result<String> {
 748        let root_dir = root_dir
 749            .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
 750        let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
 751        let outline = Self::run_foreground_fn(
 752            "getting code outline",
 753            foreground_tx,
 754            Box::new(move |session, cx| {
 755                cx.spawn(move |mut cx| async move {
 756                    // TODO: This will not use file content from `fs_changes`. It will also reflect
 757                    // user changes that have not been saved.
 758                    let buffer = session
 759                        .update(&mut cx, |session, cx| {
 760                            session
 761                                .project
 762                                .update(cx, |project, cx| project.open_local_buffer(&path, cx))
 763                        })?
 764                        .await?;
 765                    buffer.update(&mut cx, |buffer, _cx| {
 766                        if let Some(outline) = buffer.snapshot().outline(None) {
 767                            Ok(outline)
 768                        } else {
 769                            Err(anyhow!("No outline for file {path_str}"))
 770                        }
 771                    })
 772                })
 773            }),
 774        )
 775        .await?
 776        .await??;
 777
 778        Ok(outline
 779            .items
 780            .into_iter()
 781            .map(|item| {
 782                if item.text.contains('\n') {
 783                    log::error!("Outline item unexpectedly contains newline");
 784                }
 785                format!("{}{}", "  ".repeat(item.depth), item.text)
 786            })
 787            .collect::<Vec<String>>()
 788            .join("\n"))
 789    }
 790
 791    async fn run_foreground_fn<R: Send + 'static>(
 792        description: &str,
 793        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 794        function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
 795    ) -> anyhow::Result<R> {
 796        let (response_tx, response_rx) = oneshot::channel();
 797        let send_result = foreground_tx
 798            .send(ForegroundFn(Box::new(move |this, cx| {
 799                response_tx.send(function(this, cx)).ok();
 800            })))
 801            .await;
 802        match send_result {
 803            Ok(()) => (),
 804            Err(err) => {
 805                return Err(anyhow::Error::new(err).context(format!(
 806                    "Internal error while enqueuing work for {description}"
 807                )));
 808            }
 809        }
 810        match response_rx.await {
 811            Ok(result) => Ok(result),
 812            Err(oneshot::Canceled) => Err(anyhow!(
 813                "Internal error: response oneshot was canceled while {description}."
 814            )),
 815        }
 816    }
 817
 818    fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
 819        let path = Path::new(&path_str);
 820        if path.is_absolute() {
 821            // Check if path starts with root_dir prefix without resolving symlinks
 822            if path.starts_with(&root_dir) {
 823                Ok(path.to_path_buf())
 824            } else {
 825                Err(anyhow!(
 826                    "Error: Absolute path {} is outside the current working directory",
 827                    path_str
 828                ))
 829            }
 830        } else {
 831            // TODO: Does use of `../` break sandbox - is path canonicalization needed?
 832            Ok(root_dir.join(path))
 833        }
 834    }
 835}
 836
/// The amount of data a single read should consume from a file.
///
/// NOTE(review): these variants appear to correspond to the formats accepted by Lua's
/// `file:read` (`"*a"`, `"*l"`, `"*L"`, or a byte count), as exercised by `test_read_formats`.
/// The code that parses them is outside this view, so confirm the mapping there.
enum FileReadFormat {
    /// Everything remaining in the file (presumably `"*a"`).
    All,
    /// The next line, presumably without its trailing line feed (`"*l"`).
    Line,
    /// The next line including its trailing line feed (presumably `"*L"`).
    LineWithLineFeed,
    /// The given number of bytes (presumably from a numeric `file:read` argument).
    Bytes(usize),
}
 843
/// The bytes of a file, shared via `Arc` and guarded by a mutex.
///
/// NOTE(review): presumably this holds the in-memory contents behind a Lua file handle. The code
/// constructing it is outside this view, so confirm.
struct FileContent(Arc<Mutex<Vec<u8>>>);
 845
 846impl UserData for FileContent {
 847    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
 848        // FileContent doesn't have any methods so far.
 849    }
 850}
 851
/// Identifies a script started by `ScriptingSession::run_script`.
///
/// `run_script` assigns the wrapped value from the number of scripts the session already holds
/// (`scripts.len()`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
 854
/// A script tracked by a `ScriptingSession`.
pub struct Script {
    /// Whether the script is running, succeeded, or failed, together with its output.
    pub state: ScriptState,
}
 858
/// Lifecycle state of a [`Script`].
#[derive(Debug)]
pub enum ScriptState {
    /// The script has not exited yet.
    Running {
        /// Output printed so far, behind a shared mutex; `Script::stdout_snapshot` locks it to
        /// read the current contents.
        stdout: Arc<Mutex<String>>,
    },
    /// The script exited successfully.
    Succeeded {
        /// Everything the script printed.
        stdout: String,
    },
    /// The script exited with an error.
    Failed {
        /// Everything the script printed before failing.
        stdout: String,
        /// The error that ended the script.
        error: anyhow::Error,
    },
}
 872
 873impl Script {
 874    /// If exited, returns a message with the output for the LLM
 875    pub fn output_message_for_llm(&self) -> Option<String> {
 876        match &self.state {
 877            ScriptState::Running { .. } => None,
 878            ScriptState::Succeeded { stdout } => {
 879                format!("Here's the script output:\n{}", stdout).into()
 880            }
 881            ScriptState::Failed { stdout, error } => format!(
 882                "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
 883                error, stdout
 884            )
 885            .into(),
 886        }
 887    }
 888
 889    /// Get a snapshot of the script's stdout
 890    pub fn stdout_snapshot(&self) -> String {
 891        match &self.state {
 892            ScriptState::Running { stdout } => stdout.lock().clone(),
 893            ScriptState::Succeeded { stdout } => stdout.clone(),
 894            ScriptState::Failed { stdout, .. } => stdout.clone(),
 895        }
 896    }
 897}
 898#[cfg(test)]
 899mod tests {
 900    use gpui::TestAppContext;
 901    use project::FakeFs;
 902    use serde_json::json;
 903    use settings::SettingsStore;
 904    use util::path;
 905
 906    use super::*;
 907
 908    #[gpui::test]
 909    async fn test_print(cx: &mut TestAppContext) {
 910        let script = r#"
 911            print("Hello", "world!")
 912            print("Goodbye", "moon!")
 913        "#;
 914
 915        let test_session = TestSession::init(cx).await;
 916        let output = test_session.test_success(script, cx).await;
 917        assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
 918    }
 919
 920    // search
 921
 922    #[gpui::test]
 923    async fn test_search(cx: &mut TestAppContext) {
 924        let script = r#"
 925            local results = search("world")
 926            for i, result in ipairs(results) do
 927                print("File: " .. result.path)
 928                print("Matches:")
 929                for j, match in ipairs(result.matches) do
 930                    print("  " .. match)
 931                end
 932            end
 933        "#;
 934
 935        let test_session = TestSession::init(cx).await;
 936        let output = test_session.test_success(script, cx).await;
 937        assert_eq!(
 938            output,
 939            concat!("File: ", path!("/file1.txt"), "\nMatches:\n  world\n")
 940        );
 941    }
 942
 943    // io.open
 944
 945    #[gpui::test]
 946    async fn test_open_and_read_file(cx: &mut TestAppContext) {
 947        let script = r#"
 948            local file = io.open("file1.txt", "r")
 949            local content = file:read()
 950            print("Content:", content)
 951            file:close()
 952        "#;
 953
 954        let test_session = TestSession::init(cx).await;
 955        let output = test_session.test_success(script, cx).await;
 956        assert_eq!(output, "Content:\tHello world!\n");
 957
 958        // Only read, should not be marked as changed
 959        assert!(!test_session.was_marked_changed("file1.txt", cx));
 960    }
 961
 962    #[gpui::test]
 963    async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
 964        let script = r#"
 965            local file = io.open("file1.txt", "w")
 966            file:write("This is new content")
 967            file:close()
 968
 969            -- Read back to verify
 970            local read_file = io.open("file1.txt", "r")
 971            local content = read_file:read("*a")
 972            print("Written content:", content)
 973            read_file:close()
 974        "#;
 975
 976        let test_session = TestSession::init(cx).await;
 977        let output = test_session.test_success(script, cx).await;
 978        assert_eq!(output, "Written content:\tThis is new content\n");
 979        assert!(test_session.was_marked_changed("file1.txt", cx));
 980    }
 981
 982    #[gpui::test]
 983    async fn test_multiple_writes(cx: &mut TestAppContext) {
 984        let script = r#"
 985            -- Test writing to a file multiple times
 986            local file = io.open("multiwrite.txt", "w")
 987            file:write("First line\n")
 988            file:write("Second line\n")
 989            file:write("Third line")
 990            file:close()
 991
 992            -- Read back to verify
 993            local read_file = io.open("multiwrite.txt", "r")
 994            if read_file then
 995                local content = read_file:read("*a")
 996                print("Full content:", content)
 997                read_file:close()
 998            end
 999        "#;
1000
1001        let test_session = TestSession::init(cx).await;
1002        let output = test_session.test_success(script, cx).await;
1003        assert_eq!(
1004            output,
1005            "Full content:\tFirst line\nSecond line\nThird line\n"
1006        );
1007        assert!(test_session.was_marked_changed("multiwrite.txt", cx));
1008    }
1009
1010    #[gpui::test]
1011    async fn test_multiple_writes_diff_handles(cx: &mut TestAppContext) {
1012        let script = r#"
1013            -- Write to a file
1014            local file1 = io.open("multi_open.txt", "w")
1015            file1:write("Content written by first handle\n")
1016            file1:close()
1017
1018            -- Open it again and add more content
1019            local file2 = io.open("multi_open.txt", "w")
1020            file2:write("Content written by second handle\n")
1021            file2:close()
1022
1023            -- Open it a third time and read
1024            local file3 = io.open("multi_open.txt", "r")
1025            local content = file3:read("*a")
1026            print("Final content:", content)
1027            file3:close()
1028        "#;
1029
1030        let test_session = TestSession::init(cx).await;
1031        let output = test_session.test_success(script, cx).await;
1032        assert_eq!(
1033            output,
1034            "Final content:\tContent written by second handle\n\n"
1035        );
1036        assert!(test_session.was_marked_changed("multi_open.txt", cx));
1037    }
1038
1039    #[gpui::test]
1040    async fn test_append_mode(cx: &mut TestAppContext) {
1041        let script = r#"
1042            -- Test append mode
1043            local file = io.open("append.txt", "w")
1044            file:write("Initial content\n")
1045            file:close()
1046
1047            -- Append more content
1048            file = io.open("append.txt", "a")
1049            file:write("Appended content\n")
1050            file:close()
1051
1052            -- Add even more
1053            file = io.open("append.txt", "a")
1054            file:write("More appended content")
1055            file:close()
1056
1057            -- Read back to verify
1058            local read_file = io.open("append.txt", "r")
1059            local content = read_file:read("*a")
1060            print("Content after appends:", content)
1061            read_file:close()
1062        "#;
1063
1064        let test_session = TestSession::init(cx).await;
1065        let output = test_session.test_success(script, cx).await;
1066        assert_eq!(
1067            output,
1068            "Content after appends:\tInitial content\nAppended content\nMore appended content\n"
1069        );
1070        assert!(test_session.was_marked_changed("append.txt", cx));
1071    }
1072
1073    #[gpui::test]
1074    async fn test_read_formats(cx: &mut TestAppContext) {
1075        let script = r#"
1076            local file = io.open("multiline.txt", "w")
1077            file:write("Line 1\nLine 2\nLine 3")
1078            file:close()
1079
1080            -- Test "*a" (all)
1081            local f = io.open("multiline.txt", "r")
1082            local all = f:read("*a")
1083            print("All:", all)
1084            f:close()
1085
1086            -- Test "*l" (line)
1087            f = io.open("multiline.txt", "r")
1088            local line1 = f:read("*l")
1089            local line2 = f:read("*l")
1090            local line3 = f:read("*l")
1091            print("Line 1:", line1)
1092            print("Line 2:", line2)
1093            print("Line 3:", line3)
1094            f:close()
1095
1096            -- Test "*L" (line with newline)
1097            f = io.open("multiline.txt", "r")
1098            local line_with_nl = f:read("*L")
1099            print("Line with newline length:", #line_with_nl)
1100            print("Last char:", string.byte(line_with_nl, #line_with_nl))
1101            f:close()
1102
1103            -- Test number of bytes
1104            f = io.open("multiline.txt", "r")
1105            local bytes5 = f:read(5)
1106            print("5 bytes:", bytes5)
1107            f:close()
1108        "#;
1109
1110        let test_session = TestSession::init(cx).await;
1111        let output = test_session.test_success(script, cx).await;
1112        println!("{}", &output);
1113        assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
1114        assert!(output.contains("Line 1:\tLine 1"));
1115        assert!(output.contains("Line 2:\tLine 2"));
1116        assert!(output.contains("Line 3:\tLine 3"));
1117        assert!(output.contains("Line with newline length:\t7"));
1118        assert!(output.contains("Last char:\t10")); // LF
1119        assert!(output.contains("5 bytes:\tLine "));
1120        assert!(test_session.was_marked_changed("multiline.txt", cx));
1121    }
1122
1123    // helpers
1124
1125    struct TestSession {
1126        session: Entity<ScriptingSession>,
1127    }
1128
1129    impl TestSession {
1130        async fn init(cx: &mut TestAppContext) -> Self {
1131            let settings_store = cx.update(SettingsStore::test);
1132            cx.set_global(settings_store);
1133            cx.update(Project::init_settings);
1134            cx.update(language::init);
1135
1136            let fs = FakeFs::new(cx.executor());
1137            fs.insert_tree(
1138                path!("/"),
1139                json!({
1140                    "file1.txt": "Hello world!",
1141                    "file2.txt": "Goodbye moon!"
1142                }),
1143            )
1144            .await;
1145
1146            let project = Project::test(fs.clone(), [Path::new(path!("/"))], cx).await;
1147            let session = cx.new(|cx| ScriptingSession::new(project, cx));
1148
1149            TestSession { session }
1150        }
1151
1152        async fn test_success(&self, source: &str, cx: &mut TestAppContext) -> String {
1153            let script_id = self.run_script(source, cx).await;
1154
1155            self.session.read_with(cx, |session, _cx| {
1156                let script = session.get(script_id);
1157                let stdout = script.stdout_snapshot();
1158
1159                if let ScriptState::Failed { error, .. } = &script.state {
1160                    panic!("Script failed:\n{}\n\n{}", error, stdout);
1161                }
1162
1163                stdout
1164            })
1165        }
1166
1167        fn was_marked_changed(&self, path_str: &str, cx: &mut TestAppContext) -> bool {
1168            self.session.read_with(cx, |session, cx| {
1169                let count_changed = session
1170                    .changed_buffers
1171                    .iter()
1172                    .filter(|buffer| buffer.read(cx).file().unwrap().path().ends_with(path_str))
1173                    .count();
1174
1175                assert!(count_changed < 2, "Multiple buffers matched for same path");
1176
1177                count_changed > 0
1178            })
1179        }
1180
1181        async fn run_script(&self, source: &str, cx: &mut TestAppContext) -> ScriptId {
1182            let (script_id, task) = self
1183                .session
1184                .update(cx, |session, cx| session.run_script(source.to_string(), cx));
1185
1186            task.await;
1187
1188            script_id
1189        }
1190    }
1191}