scripting_session.rs

   1use anyhow::anyhow;
   2use buffer_diff::BufferDiff;
   3use collections::{HashMap, HashSet};
   4use futures::{
   5    channel::{mpsc, oneshot},
   6    pin_mut, SinkExt, StreamExt,
   7};
   8use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
   9use language::Buffer;
  10use mlua::{ExternalResult, Lua, MultiValue, ObjectLike, Table, UserData, UserDataMethods};
  11use parking_lot::Mutex;
  12use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
  13use regex::Regex;
  14use std::{
  15    path::{Path, PathBuf},
  16    sync::Arc,
  17};
  18use util::{paths::PathMatcher, ResultExt};
  19
/// A closure sent from a script's background task to the session's foreground loop (spawned in
/// `ScriptingSession::new`), which calls it with a weak handle to the session and an `AsyncApp`.
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
  21
/// Changes that scripts in a session have made to one buffer.
struct BufferChanges {
    /// Diff whose base text is a branch of the buffer with all recorded script edits undone,
    /// so it shows everything scripts changed in this buffer.
    diff: Entity<BufferDiff>,
    /// IDs of every edit scripts applied to this buffer, accumulated across writes.
    edit_ids: Vec<clock::Lamport>,
}
  26
/// Runs sandboxed Lua scripts against a project and tracks the buffers those scripts modify.
pub struct ScriptingSession {
    project: Entity<Project>,
    /// Every script started in this session, indexed by `ScriptId`.
    scripts: Vec<Script>,
    /// Buffers written by scripts, with the diff and edit ids of the scripts' changes.
    changes_by_buffer: HashMap<Entity<Buffer>, BufferChanges>,
    /// Used by script functions (running off the main thread) to queue work for the foreground.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Foreground loop, spawned in `new`, that receives and invokes queued `ForegroundFn`s;
    /// stored here so it is owned by the session.
    _invoke_foreground_fns: Task<()>,
}
  34
  35impl ScriptingSession {
  36    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
  37        let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
  38        ScriptingSession {
  39            project,
  40            scripts: Vec::new(),
  41            changes_by_buffer: HashMap::default(),
  42            foreground_fns_tx,
  43            _invoke_foreground_fns: cx.spawn(async move |this, cx| {
  44                while let Some(foreground_fn) = foreground_fns_rx.next().await {
  45                    foreground_fn.0(this.clone(), cx.clone());
  46                }
  47            }),
  48        }
  49    }
  50
  51    pub fn changed_buffers(&self) -> impl ExactSizeIterator<Item = &Entity<Buffer>> {
  52        self.changes_by_buffer.keys()
  53    }
  54
  55    pub fn run_script(
  56        &mut self,
  57        script_src: String,
  58        cx: &mut Context<Self>,
  59    ) -> (ScriptId, Task<()>) {
  60        let id = ScriptId(self.scripts.len() as u32);
  61
  62        let stdout = Arc::new(Mutex::new(String::new()));
  63
  64        let script = Script {
  65            state: ScriptState::Running {
  66                stdout: stdout.clone(),
  67            },
  68        };
  69        self.scripts.push(script);
  70
  71        let task = self.run_lua(script_src, stdout, cx);
  72
  73        let task = cx.spawn(async move |session, cx| {
  74            let result = task.await;
  75
  76            session
  77                .update(cx, |session, _cx| {
  78                    let script = session.get_mut(id);
  79                    let stdout = script.stdout_snapshot();
  80
  81                    script.state = match result {
  82                        Ok(()) => ScriptState::Succeeded { stdout },
  83                        Err(error) => ScriptState::Failed { stdout, error },
  84                    };
  85                })
  86                .log_err();
  87        });
  88
  89        (id, task)
  90    }
  91
    /// Runs `script` in a fresh, sandboxed Lua VM on a background thread.
    ///
    /// First installs the Rust-backed globals the sandbox exposes:
    /// - `cwd`: the first visible worktree's root, when it is valid UTF-8.
    /// - `sb_print`, `search`, `outline` and `sb_io_open`.
    ///
    /// It then stores the script source in the `user_script` global and executes
    /// `sandbox_preamble.lua`.
    /// NOTE(review): the preamble is not visible here; it presumably installs these globals into
    /// the sandboxed environment and runs `user_script` — confirm.
    ///
    /// Everything passed to `sb_print` is appended to `stdout`. Functions that need project
    /// state queue closures for the session's foreground loop via `foreground_fns_tx`. The VM's
    /// memory is capped at 2 GB.
    ///
    /// The returned task fails if VM setup fails or executing the preamble raises an error.
    fn run_lua(
        &mut self,
        script: String,
        stdout: Arc<Mutex<String>>,
        cx: &mut Context<Self>,
    ) -> Task<anyhow::Result<()>> {
        const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");

        // TODO Honor all worktrees instead of the first one
        let worktree_info = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .next()
            .map(|worktree| {
                let worktree = worktree.read(cx);
                (worktree.id(), worktree.abs_path())
            });

        let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());

        let fs = self.project.read(cx).fs().clone();
        let foreground_fns_tx = self.foreground_fns_tx.clone();

        let task = cx.background_spawn({
            let stdout = stdout.clone();

            async move {
                let lua = Lua::new();
                lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
                let globals = lua.globals();

                // Use the project root dir as the script's current working dir.
                // `cwd` stays unset when there is no worktree or its path is not valid UTF-8.
                if let Some(root_dir) = &root_dir {
                    if let Some(root_dir) = root_dir.to_str() {
                        globals.set("cwd", root_dir)?;
                    }
                }

                // Sandboxed `print`: output is captured in `stdout` instead of the process stdout.
                globals.set(
                    "sb_print",
                    lua.create_function({
                        let stdout = stdout.clone();
                        move |_, args: MultiValue| Self::print(args, &stdout)
                    })?,
                )?;
                // The async functions below clone their captured handles on every call because
                // each call's `async move` future must own what it uses.
                globals.set(
                    "search",
                    lua.create_async_function({
                        let foreground_fns_tx = foreground_fns_tx.clone();
                        let fs = fs.clone();
                        move |lua, regex| {
                            let mut foreground_fns_tx = foreground_fns_tx.clone();
                            let fs = fs.clone();
                            async move {
                                Self::search(&lua, &mut foreground_fns_tx, fs, regex)
                                    .await
                                    .into_lua_err()
                            }
                        }
                    })?,
                )?;
                globals.set(
                    "outline",
                    lua.create_async_function({
                        let root_dir = root_dir.clone();
                        let foreground_fns_tx = foreground_fns_tx.clone();
                        move |_lua, path| {
                            let mut foreground_fns_tx = foreground_fns_tx.clone();
                            let root_dir = root_dir.clone();
                            async move {
                                Self::outline(root_dir, &mut foreground_fns_tx, path)
                                    .await
                                    .into_lua_err()
                            }
                        }
                    })?,
                )?;
                // Sandboxed `io.open`. This closure takes the last use of `fs`, so it moves it.
                globals.set(
                    "sb_io_open",
                    lua.create_async_function({
                        let worktree_info = worktree_info.clone();
                        let foreground_fns_tx = foreground_fns_tx.clone();
                        move |lua, (path_str, mode)| {
                            let worktree_info = worktree_info.clone();
                            let mut foreground_fns_tx = foreground_fns_tx.clone();
                            let fs = fs.clone();
                            async move {
                                Self::io_open(
                                    &lua,
                                    worktree_info,
                                    &mut foreground_fns_tx,
                                    fs,
                                    path_str,
                                    mode,
                                )
                                .await
                            }
                        }
                    })?,
                )?;
                // The script source is handed to the preamble through this global.
                globals.set("user_script", script)?;

                lua.load(SANDBOX_PREAMBLE).exec_async().await?;

                anyhow::Ok(())
            }
        });

        task
    }
 203
 204    pub fn get(&self, script_id: ScriptId) -> &Script {
 205        &self.scripts[script_id.0 as usize]
 206    }
 207
 208    fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
 209        &mut self.scripts[script_id.0 as usize]
 210    }
 211
 212    /// Sandboxed print() function in Lua.
 213    fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
 214        for (index, arg) in args.into_iter().enumerate() {
 215            // Lua's `print()` prints tab characters between each argument.
 216            if index > 0 {
 217                stdout.lock().push('\t');
 218            }
 219
 220            // If the argument's to_string() fails, have the whole function call fail.
 221            stdout.lock().push_str(&arg.to_string()?);
 222        }
 223        stdout.lock().push('\n');
 224
 225        Ok(())
 226    }
 227
 228    /// Sandboxed io.open() function in Lua.
 229    async fn io_open(
 230        lua: &Lua,
 231        worktree_info: Option<(WorktreeId, Arc<Path>)>,
 232        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 233        fs: Arc<dyn Fs>,
 234        path_str: String,
 235        mode: Option<String>,
 236    ) -> mlua::Result<(Option<Table>, String)> {
 237        let (worktree_id, root_dir) = worktree_info
 238            .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
 239
 240        let mode = mode.unwrap_or_else(|| "r".to_string());
 241
 242        // Parse the mode string to determine read/write permissions
 243        let read_perm = mode.contains('r');
 244        let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
 245        let append = mode.contains('a');
 246        let truncate = mode.contains('w');
 247
 248        // This will be the Lua value returned from the `open` function.
 249        let file = lua.create_table()?;
 250
 251        // Store file metadata in the file
 252        file.set("__mode", mode.clone())?;
 253        file.set("__read_perm", read_perm)?;
 254        file.set("__write_perm", write_perm)?;
 255
 256        let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
 257            Ok(path) => path,
 258            Err(err) => return Ok((None, format!("{err}"))),
 259        };
 260
 261        let project_path = ProjectPath {
 262            worktree_id,
 263            path: Path::new(&path_str).into(),
 264        };
 265
 266        // flush / close method
 267        let flush_fn = {
 268            let project_path = project_path.clone();
 269            let foreground_tx = foreground_tx.clone();
 270            lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
 271                let project_path = project_path.clone();
 272                let mut foreground_tx = foreground_tx.clone();
 273                async move {
 274                    Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
 275                }
 276            })?
 277        };
 278        file.set("flush", flush_fn.clone())?;
 279        // We don't really hold files open, so we only need to flush on close
 280        file.set("close", flush_fn)?;
 281
 282        // If it's a directory, give it a custom read() and return early.
 283        if fs.is_dir(&path).await {
 284            return Self::io_file_dir(lua, fs, file, &path).await;
 285        }
 286
 287        let mut file_content = Vec::new();
 288
 289        if !truncate {
 290            // Try to read existing content if we're not truncating
 291            match Self::read_buffer(project_path.clone(), foreground_tx).await {
 292                Ok(content) => file_content = content.into_bytes(),
 293                Err(e) => return Ok((None, format!("Error reading file: {}", e))),
 294            }
 295        }
 296
 297        // If in append mode, position should be at the end
 298        let position = if append { file_content.len() } else { 0 };
 299        file.set("__position", position)?;
 300        file.set(
 301            "__content",
 302            lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
 303        )?;
 304
 305        // Create file methods
 306
 307        // read method
 308        let read_fn = lua.create_function(Self::io_file_read)?;
 309        file.set("read", read_fn)?;
 310
 311        // lines method
 312        let lines_fn = lua.create_function(Self::io_file_lines)?;
 313        file.set("lines", lines_fn)?;
 314
 315        // write method
 316        let write_fn = lua.create_function(Self::io_file_write)?;
 317        file.set("write", write_fn)?;
 318
 319        // If we got this far, the file was opened successfully
 320        Ok((Some(file), String::new()))
 321    }
 322
 323    async fn read_buffer(
 324        project_path: ProjectPath,
 325        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 326    ) -> anyhow::Result<String> {
 327        Self::run_foreground_fn(
 328            "read file from buffer",
 329            foreground_tx,
 330            Box::new(move |session, mut cx| {
 331                session.update(&mut cx, |session, cx| {
 332                    let open_buffer_task = session
 333                        .project
 334                        .update(cx, |project, cx| project.open_buffer(project_path, cx));
 335
 336                    cx.spawn(async move |_, cx| {
 337                        let buffer = open_buffer_task.await?;
 338
 339                        let text = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
 340                        Ok(text)
 341                    })
 342                })
 343            }),
 344        )
 345        .await??
 346        .await
 347    }
 348
 349    async fn io_file_flush(
 350        file_userdata: mlua::Table,
 351        project_path: ProjectPath,
 352        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 353    ) -> mlua::Result<bool> {
 354        let write_perm = file_userdata.get::<bool>("__write_perm")?;
 355
 356        if write_perm {
 357            let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 358            let content_ref = content.borrow::<FileContent>()?;
 359            let text = {
 360                let mut content_vec = content_ref.0.lock();
 361                let content_vec = std::mem::take(&mut *content_vec);
 362                String::from_utf8(content_vec).into_lua_err()?
 363            };
 364
 365            Self::write_to_buffer(project_path, text, foreground_tx)
 366                .await
 367                .into_lua_err()?;
 368        }
 369
 370        Ok(true)
 371    }
 372
    /// Replaces the contents of the project buffer at `project_path` with `text`, saves it,
    /// and records the change in the session's `changes_by_buffer`.
    ///
    /// Runs on the session's foreground loop via `run_foreground_fn`. Steps:
    /// 1. Apply the new text as the diff computed by `Buffer::diff`, in its own finalized
    ///    transaction.
    /// 2. Append the resulting edit ids to the buffer's `BufferChanges`.
    /// 3. Reset the buffer's `BufferDiff` base text to a branch of the buffer with every
    ///    recorded script edit undone.
    ///
    /// As a result, the diff covers all changes scripts have made to this buffer so far.
    async fn write_to_buffer(
        project_path: ProjectPath,
        text: String,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<()> {
        Self::run_foreground_fn(
            "write to buffer",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    let open_buffer_task = session
                        .project
                        .update(cx, |project, cx| project.open_buffer(project_path, cx));

                    cx.spawn(async move |session, cx| {
                        let buffer = open_buffer_task.await?;

                        // Compute the edits that turn the buffer's current text into `text`.
                        let diff = buffer.update(cx, |buffer, cx| buffer.diff(text, cx))?.await;

                        // Finalizing before and after applying the diff isolates this write's
                        // edits in one transaction, whose edit ids are recorded below.
                        let edit_ids = buffer.update(cx, |buffer, cx| {
                            buffer.finalize_last_transaction();
                            buffer.apply_diff(diff, cx);
                            let transaction = buffer.finalize_last_transaction();
                            transaction
                                .map_or(Vec::new(), |transaction| transaction.edit_ids.clone())
                        })?;

                        // Save the buffer through the project.
                        session
                            .update(cx, {
                                let buffer = buffer.clone();

                                |session, cx| {
                                    session
                                        .project
                                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                                }
                            })?
                            .await?;

                        // Snapshot including this write; used to create the diff if needed and
                        // as the diff's compared text.
                        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;

                        // If we saved successfully, mark buffer as changed
                        // A branch of the buffer; all script edits are undone on it below to
                        // produce the diff's base text.
                        let buffer_without_changes =
                            buffer.update(cx, |buffer, cx| buffer.branch(cx))?;
                        session
                            .update(cx, |session, cx| {
                                let changed_buffer = session
                                    .changes_by_buffer
                                    .entry(buffer)
                                    .or_insert_with(|| BufferChanges {
                                        diff: cx.new(|cx| BufferDiff::new(&snapshot, cx)),
                                        edit_ids: Vec::new(),
                                    });
                                changed_buffer.edit_ids.extend(edit_ids);
                                // NOTE(review): `u32::MAX` as the undo count presumably means
                                // "fully undone" — confirm against `Buffer::undo_operations`.
                                let operations_to_undo = changed_buffer
                                    .edit_ids
                                    .iter()
                                    .map(|edit_id| (*edit_id, u32::MAX))
                                    .collect::<HashMap<_, _>>();
                                buffer_without_changes.update(cx, |buffer, cx| {
                                    buffer.undo_operations(operations_to_undo, cx);
                                });
                                changed_buffer.diff.update(cx, |diff, cx| {
                                    diff.set_base_text(buffer_without_changes, snapshot.text, cx)
                                })
                            })?
                            .await?;

                        Ok(())
                    })
                })
            }),
        )
        .await??
        .await
    }
 449
 450    async fn io_file_dir(
 451        lua: &Lua,
 452        fs: Arc<dyn Fs>,
 453        file: Table,
 454        path: &Path,
 455    ) -> mlua::Result<(Option<Table>, String)> {
 456        // Create a special directory handle
 457        file.set("__is_directory", true)?;
 458
 459        // Store directory entries
 460        let entries = match fs.read_dir(&path).await {
 461            Ok(entries) => {
 462                let mut entry_names = Vec::new();
 463
 464                // Process the stream of directory entries
 465                pin_mut!(entries);
 466                while let Some(Ok(entry_result)) = entries.next().await {
 467                    if let Some(file_name) = entry_result.file_name() {
 468                        entry_names.push(file_name.to_string_lossy().into_owned());
 469                    }
 470                }
 471
 472                entry_names
 473            }
 474            Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
 475        };
 476
 477        // Save the list of entries
 478        file.set("__dir_entries", entries)?;
 479        file.set("__dir_position", 0usize)?;
 480
 481        // Create a directory-specific read function
 482        let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
 483            let position = file_userdata.get::<usize>("__dir_position")?;
 484            let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
 485
 486            if position >= entries.len() {
 487                return Ok(None); // No more entries
 488            }
 489
 490            let entry = entries[position].clone();
 491            file_userdata.set("__dir_position", position + 1)?;
 492
 493            Ok(Some(entry))
 494        })?;
 495        file.set("read", read_fn)?;
 496
 497        // If we got this far, the directory was opened successfully
 498        return Ok((Some(file), String::new()));
 499    }
 500
 501    fn io_file_read(
 502        lua: &Lua,
 503        (file_userdata, format): (Table, Option<mlua::Value>),
 504    ) -> mlua::Result<Option<mlua::String>> {
 505        let read_perm = file_userdata.get::<bool>("__read_perm")?;
 506        if !read_perm {
 507            return Err(mlua::Error::runtime("File not open for reading"));
 508        }
 509
 510        let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 511        let position = file_userdata.get::<usize>("__position")?;
 512        let content_ref = content.borrow::<FileContent>()?;
 513        let content = content_ref.0.lock();
 514
 515        if position >= content.len() {
 516            return Ok(None); // EOF
 517        }
 518
 519        let (result, new_position) = match Self::io_file_read_format(format)? {
 520            FileReadFormat::All => {
 521                // Read entire file from current position
 522                let result = content[position..].to_vec();
 523                (Some(result), content.len())
 524            }
 525            FileReadFormat::Line => {
 526                if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
 527                {
 528                    let mut line = content[position..position + next_newline_ix].to_vec();
 529                    if line.ends_with(b"\r") {
 530                        line.pop();
 531                    }
 532                    (Some(line), position + next_newline_ix + 1)
 533                } else if position < content.len() {
 534                    let line = content[position..].to_vec();
 535                    (Some(line), content.len())
 536                } else {
 537                    (None, position) // EOF
 538                }
 539            }
 540            FileReadFormat::LineWithLineFeed => {
 541                if position < content.len() {
 542                    let next_line_ix = content[position..]
 543                        .iter()
 544                        .position(|c| *c == b'\n')
 545                        .map_or(content.len(), |ix| position + ix + 1);
 546                    let line = content[position..next_line_ix].to_vec();
 547                    (Some(line), next_line_ix)
 548                } else {
 549                    (None, position) // EOF
 550                }
 551            }
 552            FileReadFormat::Bytes(n) => {
 553                let end = std::cmp::min(position + n, content.len());
 554                let result = content[position..end].to_vec();
 555                (Some(result), end)
 556            }
 557        };
 558
 559        // Update the position in the file userdata
 560        if new_position != position {
 561            file_userdata.set("__position", new_position)?;
 562        }
 563
 564        // Convert the result to a Lua string
 565        match result {
 566            Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
 567            None => Ok(None),
 568        }
 569    }
 570
 571    fn io_file_lines(lua: &Lua, file_userdata: Table) -> mlua::Result<mlua::Function> {
 572        let read_perm = file_userdata.get::<bool>("__read_perm")?;
 573        if !read_perm {
 574            return Err(mlua::Error::runtime("File not open for reading"));
 575        }
 576
 577        lua.create_function::<_, _, mlua::Value>(move |lua, _: ()| {
 578            file_userdata.call_method("read", lua.create_string("*l")?)
 579        })
 580    }
 581
 582    fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
 583        let format = match format {
 584            Some(mlua::Value::String(s)) => {
 585                let lossy_string = s.to_string_lossy();
 586                let format_str: &str = lossy_string.as_ref();
 587
 588                // Only consider the first 2 bytes, since it's common to pass e.g. "*all"  instead of "*a"
 589                match &format_str[0..2] {
 590                    "*a" => FileReadFormat::All,
 591                    "*l" => FileReadFormat::Line,
 592                    "*L" => FileReadFormat::LineWithLineFeed,
 593                    "*n" => {
 594                        // Try to parse as a number (number of bytes to read)
 595                        match format_str.parse::<usize>() {
 596                            Ok(n) => FileReadFormat::Bytes(n),
 597                            Err(_) => {
 598                                return Err(mlua::Error::runtime(format!(
 599                                    "Invalid format: {}",
 600                                    format_str
 601                                )))
 602                            }
 603                        }
 604                    }
 605                    _ => {
 606                        return Err(mlua::Error::runtime(format!(
 607                            "Unsupported format: {}",
 608                            format_str
 609                        )))
 610                    }
 611                }
 612            }
 613            Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
 614            Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
 615            Some(value) => {
 616                return Err(mlua::Error::runtime(format!(
 617                    "Invalid file format {:?}",
 618                    value
 619                )))
 620            }
 621            None => FileReadFormat::Line, // Default is to read a line
 622        };
 623
 624        Ok(format)
 625    }
 626
 627    fn io_file_write(
 628        _lua: &Lua,
 629        (file_userdata, text): (Table, mlua::String),
 630    ) -> mlua::Result<bool> {
 631        let write_perm = file_userdata.get::<bool>("__write_perm")?;
 632        if !write_perm {
 633            return Err(mlua::Error::runtime("File not open for writing"));
 634        }
 635
 636        let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 637        let position = file_userdata.get::<usize>("__position")?;
 638        let content_ref = content.borrow::<FileContent>()?;
 639        let mut content_vec = content_ref.0.lock();
 640
 641        let bytes = text.as_bytes();
 642
 643        // Ensure the vector has enough capacity
 644        if position + bytes.len() > content_vec.len() {
 645            content_vec.resize(position + bytes.len(), 0);
 646        }
 647
 648        // Write the bytes
 649        for (i, &byte) in bytes.iter().enumerate() {
 650            content_vec[position + i] = byte;
 651        }
 652
 653        // Update position
 654        let new_position = position + bytes.len();
 655        file_userdata.set("__position", new_position)?;
 656
 657        Ok(true)
 658    }
 659
 660    async fn search(
 661        lua: &Lua,
 662        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 663        fs: Arc<dyn Fs>,
 664        regex: String,
 665    ) -> anyhow::Result<Table> {
 666        // TODO: Allow specification of these options.
 667        let search_query = SearchQuery::regex(
 668            &regex,
 669            false,
 670            false,
 671            false,
 672            PathMatcher::default(),
 673            PathMatcher::default(),
 674            None,
 675        );
 676        let search_query = match search_query {
 677            Ok(query) => query,
 678            Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
 679        };
 680
 681        // TODO: Should use `search_query.regex`. The tool description should also be updated,
 682        // as it specifies standard regex.
 683        let search_regex = match Regex::new(&regex) {
 684            Ok(re) => re,
 685            Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
 686        };
 687
 688        let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
 689
 690        let mut search_results: Vec<Table> = Vec::new();
 691        while let Some(path) = abs_paths_rx.next().await {
 692            // Skip files larger than 1MB
 693            if let Ok(Some(metadata)) = fs.metadata(&path).await {
 694                if metadata.len > 1_000_000 {
 695                    continue;
 696                }
 697            }
 698
 699            // Attempt to read the file as text
 700            if let Ok(content) = fs.load(&path).await {
 701                let mut matches = Vec::new();
 702
 703                // Find all regex matches in the content
 704                for capture in search_regex.find_iter(&content) {
 705                    matches.push(capture.as_str().to_string());
 706                }
 707
 708                // If we found matches, create a result entry
 709                if !matches.is_empty() {
 710                    let result_entry = lua.create_table()?;
 711                    result_entry.set("path", path.to_string_lossy().to_string())?;
 712
 713                    let matches_table = lua.create_table()?;
 714                    for (ix, m) in matches.iter().enumerate() {
 715                        matches_table.set(ix + 1, m.clone())?;
 716                    }
 717                    result_entry.set("matches", matches_table)?;
 718
 719                    search_results.push(result_entry);
 720                }
 721            }
 722        }
 723
 724        // Create a table to hold our results
 725        let results_table = lua.create_table()?;
 726        for (ix, entry) in search_results.into_iter().enumerate() {
 727            results_table.set(ix + 1, entry)?;
 728        }
 729
 730        Ok(results_table)
 731    }
 732
    /// Starts looking for files that may match `search_query` and returns a channel that
    /// yields their absolute paths.
    ///
    /// On the session's foreground loop, this asks the project's worktree store for up to 5000
    /// candidate project paths. A detached task resolves each candidate to an absolute path
    /// and sends it on the returned channel. Candidates are skipped when:
    /// - their worktree can no longer be found;
    /// - they fail to absolutize (the failure is logged).
    ///
    /// The forwarding task stops early if the receiver is dropped or the worktree store is
    /// released. Its error is discarded because the task is detached.
    async fn find_search_candidates(
        search_query: SearchQuery,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
        Self::run_foreground_fn(
            "finding search file candidates",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    session.project.update(cx, |project, cx| {
                        project.worktree_store().update(cx, |worktree_store, cx| {
                            // TODO: Better limit? For now this is the same as
                            // MAX_SEARCH_RESULT_FILES.
                            let limit = 5000;
                            // TODO: Providing non-empty open_entries can make this a bit more
                            // efficient as it can skip checking that these paths are textual.
                            let open_entries = HashSet::default();
                            let candidates = worktree_store.find_search_candidates(
                                search_query,
                                limit,
                                open_entries,
                                project.fs().clone(),
                                cx,
                            );
                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
                            // Forward candidates as absolute paths while the caller consumes
                            // `abs_paths_rx`.
                            cx.spawn(async move |worktree_store, cx| {
                                pin_mut!(candidates);

                                while let Some(project_path) = candidates.next().await {
                                    worktree_store.read_with(cx, |worktree_store, cx| {
                                        if let Some(worktree) = worktree_store
                                            .worktree_for_id(project_path.worktree_id, cx)
                                        {
                                            if let Some(abs_path) = worktree
                                                .read(cx)
                                                .absolutize(&project_path.path)
                                                .log_err()
                                            {
                                                // Errors once the receiver has been dropped,
                                                // which ends this loop.
                                                abs_paths_tx.unbounded_send(abs_path)?;
                                            }
                                        }
                                        anyhow::Ok(())
                                    })??;
                                }
                                anyhow::Ok(())
                            })
                            .detach();
                            abs_paths_rx
                        })
                    })
                })
            }),
        )
        .await?
    }
 788
 789    async fn outline(
 790        root_dir: Option<Arc<Path>>,
 791        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 792        path_str: String,
 793    ) -> anyhow::Result<String> {
 794        let root_dir = root_dir
 795            .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
 796        let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
 797        let outline = Self::run_foreground_fn(
 798            "getting code outline",
 799            foreground_tx,
 800            Box::new(move |session, cx| {
 801                cx.spawn(async move |cx| {
 802                    // TODO: This will not use file content from `fs_changes`. It will also reflect
 803                    // user changes that have not been saved.
 804                    let buffer = session
 805                        .update(cx, |session, cx| {
 806                            session
 807                                .project
 808                                .update(cx, |project, cx| project.open_local_buffer(&path, cx))
 809                        })?
 810                        .await?;
 811                    buffer.update(cx, |buffer, _cx| {
 812                        if let Some(outline) = buffer.snapshot().outline(None) {
 813                            Ok(outline)
 814                        } else {
 815                            Err(anyhow!("No outline for file {path_str}"))
 816                        }
 817                    })
 818                })
 819            }),
 820        )
 821        .await?
 822        .await??;
 823
 824        Ok(outline
 825            .items
 826            .into_iter()
 827            .map(|item| {
 828                if item.text.contains('\n') {
 829                    log::error!("Outline item unexpectedly contains newline");
 830                }
 831                format!("{}{}", "  ".repeat(item.depth), item.text)
 832            })
 833            .collect::<Vec<String>>()
 834            .join("\n"))
 835    }
 836
 837    async fn run_foreground_fn<R: Send + 'static>(
 838        description: &str,
 839        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 840        function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
 841    ) -> anyhow::Result<R> {
 842        let (response_tx, response_rx) = oneshot::channel();
 843        let send_result = foreground_tx
 844            .send(ForegroundFn(Box::new(move |this, cx| {
 845                response_tx.send(function(this, cx)).ok();
 846            })))
 847            .await;
 848        match send_result {
 849            Ok(()) => (),
 850            Err(err) => {
 851                return Err(anyhow::Error::new(err).context(format!(
 852                    "Internal error while enqueuing work for {description}"
 853                )));
 854            }
 855        }
 856        match response_rx.await {
 857            Ok(result) => Ok(result),
 858            Err(oneshot::Canceled) => Err(anyhow!(
 859                "Internal error: response oneshot was canceled while {description}."
 860            )),
 861        }
 862    }
 863
 864    fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
 865        let path = Path::new(&path_str);
 866        if path.is_absolute() {
 867            // Check if path starts with root_dir prefix without resolving symlinks
 868            if path.starts_with(&root_dir) {
 869                Ok(path.to_path_buf())
 870            } else {
 871                Err(anyhow!(
 872                    "Error: Absolute path {} is outside the current working directory",
 873                    path_str
 874                ))
 875            }
 876        } else {
 877            // TODO: Does use of `../` break sandbox - is path canonicalization needed?
 878            Ok(root_dir.join(path))
 879        }
 880    }
 881}
 882
/// How much data a single read from a script-opened file should return.
///
/// NOTE(review): the variants look like they mirror the Lua `file:read` formats exercised in
/// `test_read_formats`: `"*a"` for `All`, `"*l"` for `Line`, `"*L"` for `LineWithLineFeed`,
/// and a numeric byte count for `Bytes`. The code mapping arguments onto these variants is
/// not visible here; confirm.
enum FileReadFormat {
    /// Presumably the whole remaining content — confirm.
    All,
    /// Presumably one line without its trailing line feed — confirm.
    Line,
    /// Presumably one line including its trailing line feed — confirm.
    LineWithLineFeed,
    /// Presumably up to this many bytes — confirm.
    Bytes(usize),
}
 889
/// Raw bytes behind a shared, lockable buffer.
///
/// `Arc<Mutex<..>>` lets several holders share and mutate the same bytes. The type is
/// exposed to Lua as userdata with no methods.
/// NOTE(review): presumably this holds the contents of a file opened by a script — confirm
/// against its construction sites.
struct FileContent(Arc<Mutex<Vec<u8>>>);

impl UserData for FileContent {
    // Registers no methods: from Lua this value is an opaque handle.
    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
        // FileContent doesn't have any methods so far.
    }
}
 897
/// Identifies a script run started by `ScriptingSession::run_script`.
///
/// NOTE(review): presumably an index into the session's `scripts` list, used by
/// `ScriptingSession::get` — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
 900
/// A script run tracked by a `ScriptingSession`.
pub struct Script {
    /// Whether the script is still running or how it finished, plus its captured stdout.
    pub state: ScriptState,
}
 904
/// Lifecycle state of a [`Script`].
#[derive(Debug)]
pub enum ScriptState {
    /// The script has not finished yet.
    ///
    /// `stdout` is shared behind a mutex so output can be appended while the script runs
    /// and snapshotted in the meantime (see `Script::stdout_snapshot`).
    Running {
        stdout: Arc<Mutex<String>>,
    },
    /// The script finished without error; `stdout` holds everything it printed.
    Succeeded {
        stdout: String,
    },
    /// The script stopped with `error`; `stdout` holds whatever it printed before that.
    Failed {
        stdout: String,
        error: anyhow::Error,
    },
}
 918
 919impl Script {
 920    /// If exited, returns a message with the output for the LLM
 921    pub fn output_message_for_llm(&self) -> Option<String> {
 922        match &self.state {
 923            ScriptState::Running { .. } => None,
 924            ScriptState::Succeeded { stdout } => {
 925                format!("Here's the script output:\n{}", stdout).into()
 926            }
 927            ScriptState::Failed { stdout, error } => format!(
 928                "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
 929                error, stdout
 930            )
 931            .into(),
 932        }
 933    }
 934
 935    /// Get a snapshot of the script's stdout
 936    pub fn stdout_snapshot(&self) -> String {
 937        match &self.state {
 938            ScriptState::Running { stdout } => stdout.lock().clone(),
 939            ScriptState::Succeeded { stdout } => stdout.clone(),
 940            ScriptState::Failed { stdout, .. } => stdout.clone(),
 941        }
 942    }
 943}
 944
 945#[cfg(test)]
 946mod tests {
 947    use gpui::TestAppContext;
 948    use project::FakeFs;
 949    use serde_json::json;
 950    use settings::SettingsStore;
 951    use util::path;
 952
 953    use super::*;
 954
 955    #[gpui::test]
 956    async fn test_print(cx: &mut TestAppContext) {
 957        let script = r#"
 958            print("Hello", "world!")
 959            print("Goodbye", "moon!")
 960        "#;
 961
 962        let test_session = TestSession::init(cx).await;
 963        let output = test_session.test_success(script, cx).await;
 964        assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
 965    }
 966
 967    // search
 968
 969    #[gpui::test]
 970    async fn test_search(cx: &mut TestAppContext) {
 971        let script = r#"
 972            local results = search("world")
 973            for i, result in ipairs(results) do
 974                print("File: " .. result.path)
 975                print("Matches:")
 976                for j, match in ipairs(result.matches) do
 977                    print("  " .. match)
 978                end
 979            end
 980        "#;
 981
 982        let test_session = TestSession::init(cx).await;
 983        let output = test_session.test_success(script, cx).await;
 984        assert_eq!(
 985            output,
 986            concat!("File: ", path!("/file1.txt"), "\nMatches:\n  world\n")
 987        );
 988    }
 989
 990    // io.open
 991
 992    #[gpui::test]
 993    async fn test_open_and_read_file(cx: &mut TestAppContext) {
 994        let script = r#"
 995            local file = io.open("file1.txt", "r")
 996            local content = file:read()
 997            print("Content:", content)
 998            file:close()
 999        "#;
1000
1001        let test_session = TestSession::init(cx).await;
1002        let output = test_session.test_success(script, cx).await;
1003        assert_eq!(output, "Content:\tHello world!\n");
1004        assert_eq!(test_session.diff(cx), Vec::new());
1005    }
1006
1007    #[gpui::test]
1008    async fn test_lines_iterator(cx: &mut TestAppContext) {
1009        let script = r#"
1010            -- Create a test file with multiple lines
1011            local file = io.open("lines_test.txt", "w")
1012            file:write("Line 1\nLine 2\nLine 3\nLine 4\nLine 5")
1013            file:close()
1014
1015            -- Read it back using the lines iterator
1016            local read_file = io.open("lines_test.txt", "r")
1017            local count = 0
1018            for line in read_file:lines() do
1019                count = count + 1
1020                print(count .. ": " .. line)
1021            end
1022            read_file:close()
1023
1024            print("Total lines:", count)
1025        "#;
1026
1027        let test_session = TestSession::init(cx).await;
1028        let output = test_session.test_success(script, cx).await;
1029        assert_eq!(
1030            output,
1031            "1: Line 1\n2: Line 2\n3: Line 3\n4: Line 4\n5: Line 5\nTotal lines:\t5\n"
1032        );
1033    }
1034
1035    #[gpui::test]
1036    async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
1037        let script = r#"
1038            local file = io.open("file1.txt", "w")
1039            file:write("This is new content")
1040            file:close()
1041
1042            -- Read back to verify
1043            local read_file = io.open("file1.txt", "r")
1044            local content = read_file:read("*a")
1045            print("Written content:", content)
1046            read_file:close()
1047        "#;
1048
1049        let test_session = TestSession::init(cx).await;
1050        let output = test_session.test_success(script, cx).await;
1051        assert_eq!(output, "Written content:\tThis is new content\n");
1052        assert_eq!(
1053            test_session.diff(cx),
1054            vec![(
1055                PathBuf::from("file1.txt"),
1056                vec![(
1057                    "Hello world!\n".to_string(),
1058                    "This is new content".to_string()
1059                )]
1060            )]
1061        );
1062    }
1063
1064    #[gpui::test]
1065    async fn test_multiple_writes(cx: &mut TestAppContext) {
1066        let script = r#"
1067            -- Test writing to a file multiple times
1068            local file = io.open("multiwrite.txt", "w")
1069            file:write("First line\n")
1070            file:write("Second line\n")
1071            file:write("Third line")
1072            file:close()
1073
1074            -- Read back to verify
1075            local read_file = io.open("multiwrite.txt", "r")
1076            if read_file then
1077                local content = read_file:read("*a")
1078                print("Full content:", content)
1079                read_file:close()
1080            end
1081        "#;
1082
1083        let test_session = TestSession::init(cx).await;
1084        let output = test_session.test_success(script, cx).await;
1085        assert_eq!(
1086            output,
1087            "Full content:\tFirst line\nSecond line\nThird line\n"
1088        );
1089        assert_eq!(
1090            test_session.diff(cx),
1091            vec![(
1092                PathBuf::from("multiwrite.txt"),
1093                vec![(
1094                    "".to_string(),
1095                    "First line\nSecond line\nThird line".to_string()
1096                )]
1097            )]
1098        );
1099    }
1100
1101    #[gpui::test]
1102    async fn test_multiple_writes_diff_handles(cx: &mut TestAppContext) {
1103        let script = r#"
1104            -- Write to a file
1105            local file1 = io.open("multi_open.txt", "w")
1106            file1:write("Content written by first handle\n")
1107            file1:close()
1108
1109            -- Open it again and add more content
1110            local file2 = io.open("multi_open.txt", "w")
1111            file2:write("Content written by second handle\n")
1112            file2:close()
1113
1114            -- Open it a third time and read
1115            local file3 = io.open("multi_open.txt", "r")
1116            local content = file3:read("*a")
1117            print("Final content:", content)
1118            file3:close()
1119        "#;
1120
1121        let test_session = TestSession::init(cx).await;
1122        let output = test_session.test_success(script, cx).await;
1123        assert_eq!(
1124            output,
1125            "Final content:\tContent written by second handle\n\n"
1126        );
1127        assert_eq!(
1128            test_session.diff(cx),
1129            vec![(
1130                PathBuf::from("multi_open.txt"),
1131                vec![(
1132                    "".to_string(),
1133                    "Content written by second handle\n".to_string()
1134                )]
1135            )]
1136        );
1137    }
1138
1139    #[gpui::test]
1140    async fn test_append_mode(cx: &mut TestAppContext) {
1141        let script = r#"
1142            -- Append more content
1143            file = io.open("file1.txt", "a")
1144            file:write("Appended content\n")
1145            file:close()
1146
1147            -- Add even more
1148            file = io.open("file1.txt", "a")
1149            file:write("More appended content")
1150            file:close()
1151
1152            -- Read back to verify
1153            local read_file = io.open("file1.txt", "r")
1154            local content = read_file:read("*a")
1155            print("Content after appends:", content)
1156            read_file:close()
1157        "#;
1158
1159        let test_session = TestSession::init(cx).await;
1160        let output = test_session.test_success(script, cx).await;
1161        assert_eq!(
1162            output,
1163            "Content after appends:\tHello world!\nAppended content\nMore appended content\n"
1164        );
1165        assert_eq!(
1166            test_session.diff(cx),
1167            vec![(
1168                PathBuf::from("file1.txt"),
1169                vec![(
1170                    "".to_string(),
1171                    "Appended content\nMore appended content".to_string()
1172                )]
1173            )]
1174        );
1175    }
1176
1177    #[gpui::test]
1178    async fn test_read_formats(cx: &mut TestAppContext) {
1179        let script = r#"
1180            local file = io.open("multiline.txt", "w")
1181            file:write("Line 1\nLine 2\nLine 3")
1182            file:close()
1183
1184            -- Test "*a" (all)
1185            local f = io.open("multiline.txt", "r")
1186            local all = f:read("*a")
1187            print("All:", all)
1188            f:close()
1189
1190            -- Test "*l" (line)
1191            f = io.open("multiline.txt", "r")
1192            local line1 = f:read("*l")
1193            local line2 = f:read("*l")
1194            local line3 = f:read("*l")
1195            print("Line 1:", line1)
1196            print("Line 2:", line2)
1197            print("Line 3:", line3)
1198            f:close()
1199
1200            -- Test "*L" (line with newline)
1201            f = io.open("multiline.txt", "r")
1202            local line_with_nl = f:read("*L")
1203            print("Line with newline length:", #line_with_nl)
1204            print("Last char:", string.byte(line_with_nl, #line_with_nl))
1205            f:close()
1206
1207            -- Test number of bytes
1208            f = io.open("multiline.txt", "r")
1209            local bytes5 = f:read(5)
1210            print("5 bytes:", bytes5)
1211            f:close()
1212        "#;
1213
1214        let test_session = TestSession::init(cx).await;
1215        let output = test_session.test_success(script, cx).await;
1216        println!("{}", &output);
1217        assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
1218        assert!(output.contains("Line 1:\tLine 1"));
1219        assert!(output.contains("Line 2:\tLine 2"));
1220        assert!(output.contains("Line 3:\tLine 3"));
1221        assert!(output.contains("Line with newline length:\t7"));
1222        assert!(output.contains("Last char:\t10")); // LF
1223        assert!(output.contains("5 bytes:\tLine "));
1224        assert_eq!(
1225            test_session.diff(cx),
1226            vec![(
1227                PathBuf::from("multiline.txt"),
1228                vec![("".to_string(), "Line 1\nLine 2\nLine 3".to_string())]
1229            )]
1230        );
1231    }
1232
1233    // helpers
1234
1235    struct TestSession {
1236        session: Entity<ScriptingSession>,
1237    }
1238
1239    impl TestSession {
1240        async fn init(cx: &mut TestAppContext) -> Self {
1241            let settings_store = cx.update(SettingsStore::test);
1242            cx.set_global(settings_store);
1243            cx.update(Project::init_settings);
1244            cx.update(language::init);
1245
1246            let fs = FakeFs::new(cx.executor());
1247            fs.insert_tree(
1248                path!("/"),
1249                json!({
1250                    "file1.txt": "Hello world!\n",
1251                    "file2.txt": "Goodbye moon!\n"
1252                }),
1253            )
1254            .await;
1255
1256            let project = Project::test(fs.clone(), [Path::new(path!("/"))], cx).await;
1257            let session = cx.new(|cx| ScriptingSession::new(project, cx));
1258
1259            TestSession { session }
1260        }
1261
1262        async fn test_success(&self, source: &str, cx: &mut TestAppContext) -> String {
1263            let script_id = self.run_script(source, cx).await;
1264
1265            self.session.read_with(cx, |session, _cx| {
1266                let script = session.get(script_id);
1267                let stdout = script.stdout_snapshot();
1268
1269                if let ScriptState::Failed { error, .. } = &script.state {
1270                    panic!("Script failed:\n{}\n\n{}", error, stdout);
1271                }
1272
1273                stdout
1274            })
1275        }
1276
1277        fn diff(&self, cx: &mut TestAppContext) -> Vec<(PathBuf, Vec<(String, String)>)> {
1278            self.session.read_with(cx, |session, cx| {
1279                session
1280                    .changes_by_buffer
1281                    .iter()
1282                    .map(|(buffer, changes)| {
1283                        let snapshot = buffer.read(cx).snapshot();
1284                        let diff = changes.diff.read(cx);
1285                        let hunks = diff.hunks(&snapshot, cx);
1286                        let path = buffer.read(cx).file().unwrap().path().clone();
1287                        let diffs = hunks
1288                            .map(|hunk| {
1289                                let old_text = diff
1290                                    .base_text()
1291                                    .text_for_range(hunk.diff_base_byte_range)
1292                                    .collect::<String>();
1293                                let new_text =
1294                                    snapshot.text_for_range(hunk.range).collect::<String>();
1295                                (old_text, new_text)
1296                            })
1297                            .collect();
1298                        (path.to_path_buf(), diffs)
1299                    })
1300                    .collect()
1301            })
1302        }
1303
1304        async fn run_script(&self, source: &str, cx: &mut TestAppContext) -> ScriptId {
1305            let (script_id, task) = self
1306                .session
1307                .update(cx, |session, cx| session.run_script(source.to_string(), cx));
1308
1309            task.await;
1310
1311            script_id
1312        }
1313    }
1314}