scripting_session.rs

   1use anyhow::anyhow;
   2use buffer_diff::BufferDiff;
   3use collections::{HashMap, HashSet};
   4use futures::{
   5    channel::{mpsc, oneshot},
   6    pin_mut, SinkExt, StreamExt,
   7};
   8use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
   9use language::Buffer;
  10use mlua::{ExternalResult, Lua, MultiValue, Table, UserData, UserDataMethods};
  11use parking_lot::Mutex;
  12use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
  13use regex::Regex;
  14use std::{
  15    path::{Path, PathBuf},
  16    sync::Arc,
  17};
  18use util::{paths::PathMatcher, ResultExt};
  19
/// A closure sent from background (Lua) code to be run on the foreground thread with
/// the session and an `AsyncApp`. Closures are sent through
/// `ScriptingSession::foreground_fns_tx` and executed by the task spawned in
/// `ScriptingSession::new`.
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
  21
/// Edits that scripts in this session have made to a single buffer.
struct BufferChanges {
    /// Diff whose base text is a branch of the buffer with all of `edit_ids` undone.
    diff: Entity<BufferDiff>,
    /// Ids of every edit operation applied to the buffer by scripts; they are undone on a
    /// branch of the buffer to reconstruct the pre-script base text.
    edit_ids: Vec<clock::Lamport>,
}
  26
/// Runs Lua scripts against a project, recording each script's output and the buffer
/// edits scripts make.
pub struct ScriptingSession {
    project: Entity<Project>,
    /// Every script started in this session; a `ScriptId` is an index into this list.
    scripts: Vec<Script>,
    /// Buffers written by scripts, along with the edits needed to show them as a diff.
    changes_by_buffer: HashMap<Entity<Buffer>, BufferChanges>,
    /// Sender used by background script code to queue closures for the foreground thread.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Task that receives and runs the queued closures; held for the session's lifetime.
    _invoke_foreground_fns: Task<()>,
}
  34
  35impl ScriptingSession {
  36    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
  37        let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
  38        ScriptingSession {
  39            project,
  40            scripts: Vec::new(),
  41            changes_by_buffer: HashMap::default(),
  42            foreground_fns_tx,
  43            _invoke_foreground_fns: cx.spawn(|this, cx| async move {
  44                while let Some(foreground_fn) = foreground_fns_rx.next().await {
  45                    foreground_fn.0(this.clone(), cx.clone());
  46                }
  47            }),
  48        }
  49    }
  50
  51    pub fn changed_buffers(&self) -> impl ExactSizeIterator<Item = &Entity<Buffer>> {
  52        self.changes_by_buffer.keys()
  53    }
  54
  55    pub fn run_script(
  56        &mut self,
  57        script_src: String,
  58        cx: &mut Context<Self>,
  59    ) -> (ScriptId, Task<()>) {
  60        let id = ScriptId(self.scripts.len() as u32);
  61
  62        let stdout = Arc::new(Mutex::new(String::new()));
  63
  64        let script = Script {
  65            state: ScriptState::Running {
  66                stdout: stdout.clone(),
  67            },
  68        };
  69        self.scripts.push(script);
  70
  71        let task = self.run_lua(script_src, stdout, cx);
  72
  73        let task = cx.spawn(|session, mut cx| async move {
  74            let result = task.await;
  75
  76            session
  77                .update(&mut cx, |session, _cx| {
  78                    let script = session.get_mut(id);
  79                    let stdout = script.stdout_snapshot();
  80
  81                    script.state = match result {
  82                        Ok(()) => ScriptState::Succeeded { stdout },
  83                        Err(error) => ScriptState::Failed { stdout, error },
  84                    };
  85                })
  86                .log_err();
  87        });
  88
  89        (id, task)
  90    }
  91
  92    fn run_lua(
  93        &mut self,
  94        script: String,
  95        stdout: Arc<Mutex<String>>,
  96        cx: &mut Context<Self>,
  97    ) -> Task<anyhow::Result<()>> {
  98        const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
  99
 100        // TODO Honor all worktrees instead of the first one
 101        let worktree_info = self
 102            .project
 103            .read(cx)
 104            .visible_worktrees(cx)
 105            .next()
 106            .map(|worktree| {
 107                let worktree = worktree.read(cx);
 108                (worktree.id(), worktree.abs_path())
 109            });
 110
 111        let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());
 112
 113        let fs = self.project.read(cx).fs().clone();
 114        let foreground_fns_tx = self.foreground_fns_tx.clone();
 115
 116        let task = cx.background_spawn({
 117            let stdout = stdout.clone();
 118
 119            async move {
 120                let lua = Lua::new();
 121                lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
 122                let globals = lua.globals();
 123
 124                // Use the project root dir as the script's current working dir.
 125                if let Some(root_dir) = &root_dir {
 126                    if let Some(root_dir) = root_dir.to_str() {
 127                        globals.set("cwd", root_dir)?;
 128                    }
 129                }
 130
 131                globals.set(
 132                    "sb_print",
 133                    lua.create_function({
 134                        let stdout = stdout.clone();
 135                        move |_, args: MultiValue| Self::print(args, &stdout)
 136                    })?,
 137                )?;
 138                globals.set(
 139                    "search",
 140                    lua.create_async_function({
 141                        let foreground_fns_tx = foreground_fns_tx.clone();
 142                        let fs = fs.clone();
 143                        move |lua, regex| {
 144                            let mut foreground_fns_tx = foreground_fns_tx.clone();
 145                            let fs = fs.clone();
 146                            async move {
 147                                Self::search(&lua, &mut foreground_fns_tx, fs, regex)
 148                                    .await
 149                                    .into_lua_err()
 150                            }
 151                        }
 152                    })?,
 153                )?;
 154                globals.set(
 155                    "outline",
 156                    lua.create_async_function({
 157                        let root_dir = root_dir.clone();
 158                        let foreground_fns_tx = foreground_fns_tx.clone();
 159                        move |_lua, path| {
 160                            let mut foreground_fns_tx = foreground_fns_tx.clone();
 161                            let root_dir = root_dir.clone();
 162                            async move {
 163                                Self::outline(root_dir, &mut foreground_fns_tx, path)
 164                                    .await
 165                                    .into_lua_err()
 166                            }
 167                        }
 168                    })?,
 169                )?;
 170                globals.set(
 171                    "sb_io_open",
 172                    lua.create_async_function({
 173                        let worktree_info = worktree_info.clone();
 174                        let foreground_fns_tx = foreground_fns_tx.clone();
 175                        move |lua, (path_str, mode)| {
 176                            let worktree_info = worktree_info.clone();
 177                            let mut foreground_fns_tx = foreground_fns_tx.clone();
 178                            let fs = fs.clone();
 179                            async move {
 180                                Self::io_open(
 181                                    &lua,
 182                                    worktree_info,
 183                                    &mut foreground_fns_tx,
 184                                    fs,
 185                                    path_str,
 186                                    mode,
 187                                )
 188                                .await
 189                            }
 190                        }
 191                    })?,
 192                )?;
 193                globals.set("user_script", script)?;
 194
 195                lua.load(SANDBOX_PREAMBLE).exec_async().await?;
 196
 197                anyhow::Ok(())
 198            }
 199        });
 200
 201        task
 202    }
 203
 204    pub fn get(&self, script_id: ScriptId) -> &Script {
 205        &self.scripts[script_id.0 as usize]
 206    }
 207
 208    fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
 209        &mut self.scripts[script_id.0 as usize]
 210    }
 211
 212    /// Sandboxed print() function in Lua.
 213    fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
 214        for (index, arg) in args.into_iter().enumerate() {
 215            // Lua's `print()` prints tab characters between each argument.
 216            if index > 0 {
 217                stdout.lock().push('\t');
 218            }
 219
 220            // If the argument's to_string() fails, have the whole function call fail.
 221            stdout.lock().push_str(&arg.to_string()?);
 222        }
 223        stdout.lock().push('\n');
 224
 225        Ok(())
 226    }
 227
 228    /// Sandboxed io.open() function in Lua.
 229    async fn io_open(
 230        lua: &Lua,
 231        worktree_info: Option<(WorktreeId, Arc<Path>)>,
 232        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 233        fs: Arc<dyn Fs>,
 234        path_str: String,
 235        mode: Option<String>,
 236    ) -> mlua::Result<(Option<Table>, String)> {
 237        let (worktree_id, root_dir) = worktree_info
 238            .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
 239
 240        let mode = mode.unwrap_or_else(|| "r".to_string());
 241
 242        // Parse the mode string to determine read/write permissions
 243        let read_perm = mode.contains('r');
 244        let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
 245        let append = mode.contains('a');
 246        let truncate = mode.contains('w');
 247
 248        // This will be the Lua value returned from the `open` function.
 249        let file = lua.create_table()?;
 250
 251        // Store file metadata in the file
 252        file.set("__mode", mode.clone())?;
 253        file.set("__read_perm", read_perm)?;
 254        file.set("__write_perm", write_perm)?;
 255
 256        let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
 257            Ok(path) => path,
 258            Err(err) => return Ok((None, format!("{err}"))),
 259        };
 260
 261        let project_path = ProjectPath {
 262            worktree_id,
 263            path: Path::new(&path_str).into(),
 264        };
 265
 266        // flush / close method
 267        let flush_fn = {
 268            let project_path = project_path.clone();
 269            let foreground_tx = foreground_tx.clone();
 270            lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
 271                let project_path = project_path.clone();
 272                let mut foreground_tx = foreground_tx.clone();
 273                async move {
 274                    Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
 275                }
 276            })?
 277        };
 278        file.set("flush", flush_fn.clone())?;
 279        // We don't really hold files open, so we only need to flush on close
 280        file.set("close", flush_fn)?;
 281
 282        // If it's a directory, give it a custom read() and return early.
 283        if fs.is_dir(&path).await {
 284            return Self::io_file_dir(lua, fs, file, &path).await;
 285        }
 286
 287        let mut file_content = Vec::new();
 288
 289        if !truncate {
 290            // Try to read existing content if we're not truncating
 291            match Self::read_buffer(project_path.clone(), foreground_tx).await {
 292                Ok(content) => file_content = content.into_bytes(),
 293                Err(e) => return Ok((None, format!("Error reading file: {}", e))),
 294            }
 295        }
 296
 297        // If in append mode, position should be at the end
 298        let position = if append { file_content.len() } else { 0 };
 299        file.set("__position", position)?;
 300        file.set(
 301            "__content",
 302            lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
 303        )?;
 304
 305        // Create file methods
 306
 307        // read method
 308        let read_fn = lua.create_function(Self::io_file_read)?;
 309        file.set("read", read_fn)?;
 310
 311        // write method
 312        let write_fn = lua.create_function(Self::io_file_write)?;
 313        file.set("write", write_fn)?;
 314
 315        // If we got this far, the file was opened successfully
 316        Ok((Some(file), String::new()))
 317    }
 318
 319    async fn read_buffer(
 320        project_path: ProjectPath,
 321        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 322    ) -> anyhow::Result<String> {
 323        Self::run_foreground_fn(
 324            "read file from buffer",
 325            foreground_tx,
 326            Box::new(move |session, mut cx| {
 327                session.update(&mut cx, |session, cx| {
 328                    let open_buffer_task = session
 329                        .project
 330                        .update(cx, |project, cx| project.open_buffer(project_path, cx));
 331
 332                    cx.spawn(|_, cx| async move {
 333                        let buffer = open_buffer_task.await?;
 334
 335                        let text = buffer.read_with(&cx, |buffer, _cx| buffer.text())?;
 336                        Ok(text)
 337                    })
 338                })
 339            }),
 340        )
 341        .await??
 342        .await
 343    }
 344
 345    async fn io_file_flush(
 346        file_userdata: mlua::Table,
 347        project_path: ProjectPath,
 348        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 349    ) -> mlua::Result<bool> {
 350        let write_perm = file_userdata.get::<bool>("__write_perm")?;
 351
 352        if write_perm {
 353            let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 354            let content_ref = content.borrow::<FileContent>()?;
 355            let text = {
 356                let mut content_vec = content_ref.0.lock();
 357                let content_vec = std::mem::take(&mut *content_vec);
 358                String::from_utf8(content_vec).into_lua_err()?
 359            };
 360
 361            Self::write_to_buffer(project_path, text, foreground_tx)
 362                .await
 363                .into_lua_err()?;
 364        }
 365
 366        Ok(true)
 367    }
 368
 369    async fn write_to_buffer(
 370        project_path: ProjectPath,
 371        text: String,
 372        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 373    ) -> anyhow::Result<()> {
 374        Self::run_foreground_fn(
 375            "write to buffer",
 376            foreground_tx,
 377            Box::new(move |session, mut cx| {
 378                session.update(&mut cx, |session, cx| {
 379                    let open_buffer_task = session
 380                        .project
 381                        .update(cx, |project, cx| project.open_buffer(project_path, cx));
 382
 383                    cx.spawn(move |session, mut cx| async move {
 384                        let buffer = open_buffer_task.await?;
 385
 386                        let diff = buffer
 387                            .update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
 388                            .await;
 389
 390                        let edit_ids = buffer.update(&mut cx, |buffer, cx| {
 391                            buffer.finalize_last_transaction();
 392                            buffer.apply_diff(diff, cx);
 393                            let transaction = buffer.finalize_last_transaction();
 394                            transaction
 395                                .map_or(Vec::new(), |transaction| transaction.edit_ids.clone())
 396                        })?;
 397
 398                        session
 399                            .update(&mut cx, {
 400                                let buffer = buffer.clone();
 401
 402                                |session, cx| {
 403                                    session
 404                                        .project
 405                                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
 406                                }
 407                            })?
 408                            .await?;
 409
 410                        let snapshot = buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
 411
 412                        // If we saved successfully, mark buffer as changed
 413                        let buffer_without_changes =
 414                            buffer.update(&mut cx, |buffer, cx| buffer.branch(cx))?;
 415                        session
 416                            .update(&mut cx, |session, cx| {
 417                                let changed_buffer = session
 418                                    .changes_by_buffer
 419                                    .entry(buffer)
 420                                    .or_insert_with(|| BufferChanges {
 421                                        diff: cx.new(|cx| BufferDiff::new(&snapshot, cx)),
 422                                        edit_ids: Vec::new(),
 423                                    });
 424                                changed_buffer.edit_ids.extend(edit_ids);
 425                                let operations_to_undo = changed_buffer
 426                                    .edit_ids
 427                                    .iter()
 428                                    .map(|edit_id| (*edit_id, u32::MAX))
 429                                    .collect::<HashMap<_, _>>();
 430                                buffer_without_changes.update(cx, |buffer, cx| {
 431                                    buffer.undo_operations(operations_to_undo, cx);
 432                                });
 433                                changed_buffer.diff.update(cx, |diff, cx| {
 434                                    diff.set_base_text(buffer_without_changes, snapshot.text, cx)
 435                                })
 436                            })?
 437                            .await?;
 438
 439                        Ok(())
 440                    })
 441                })
 442            }),
 443        )
 444        .await??
 445        .await
 446    }
 447
 448    async fn io_file_dir(
 449        lua: &Lua,
 450        fs: Arc<dyn Fs>,
 451        file: Table,
 452        path: &Path,
 453    ) -> mlua::Result<(Option<Table>, String)> {
 454        // Create a special directory handle
 455        file.set("__is_directory", true)?;
 456
 457        // Store directory entries
 458        let entries = match fs.read_dir(&path).await {
 459            Ok(entries) => {
 460                let mut entry_names = Vec::new();
 461
 462                // Process the stream of directory entries
 463                pin_mut!(entries);
 464                while let Some(Ok(entry_result)) = entries.next().await {
 465                    if let Some(file_name) = entry_result.file_name() {
 466                        entry_names.push(file_name.to_string_lossy().into_owned());
 467                    }
 468                }
 469
 470                entry_names
 471            }
 472            Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
 473        };
 474
 475        // Save the list of entries
 476        file.set("__dir_entries", entries)?;
 477        file.set("__dir_position", 0usize)?;
 478
 479        // Create a directory-specific read function
 480        let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
 481            let position = file_userdata.get::<usize>("__dir_position")?;
 482            let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
 483
 484            if position >= entries.len() {
 485                return Ok(None); // No more entries
 486            }
 487
 488            let entry = entries[position].clone();
 489            file_userdata.set("__dir_position", position + 1)?;
 490
 491            Ok(Some(entry))
 492        })?;
 493        file.set("read", read_fn)?;
 494
 495        // If we got this far, the directory was opened successfully
 496        return Ok((Some(file), String::new()));
 497    }
 498
 499    fn io_file_read(
 500        lua: &Lua,
 501        (file_userdata, format): (Table, Option<mlua::Value>),
 502    ) -> mlua::Result<Option<mlua::String>> {
 503        let read_perm = file_userdata.get::<bool>("__read_perm")?;
 504        if !read_perm {
 505            return Err(mlua::Error::runtime("File not open for reading"));
 506        }
 507
 508        let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 509        let position = file_userdata.get::<usize>("__position")?;
 510        let content_ref = content.borrow::<FileContent>()?;
 511        let content = content_ref.0.lock();
 512
 513        if position >= content.len() {
 514            return Ok(None); // EOF
 515        }
 516
 517        let (result, new_position) = match Self::io_file_read_format(format)? {
 518            FileReadFormat::All => {
 519                // Read entire file from current position
 520                let result = content[position..].to_vec();
 521                (Some(result), content.len())
 522            }
 523            FileReadFormat::Line => {
 524                if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
 525                {
 526                    let mut line = content[position..position + next_newline_ix].to_vec();
 527                    if line.ends_with(b"\r") {
 528                        line.pop();
 529                    }
 530                    (Some(line), position + next_newline_ix + 1)
 531                } else if position < content.len() {
 532                    let line = content[position..].to_vec();
 533                    (Some(line), content.len())
 534                } else {
 535                    (None, position) // EOF
 536                }
 537            }
 538            FileReadFormat::LineWithLineFeed => {
 539                if position < content.len() {
 540                    let next_line_ix = content[position..]
 541                        .iter()
 542                        .position(|c| *c == b'\n')
 543                        .map_or(content.len(), |ix| position + ix + 1);
 544                    let line = content[position..next_line_ix].to_vec();
 545                    (Some(line), next_line_ix)
 546                } else {
 547                    (None, position) // EOF
 548                }
 549            }
 550            FileReadFormat::Bytes(n) => {
 551                let end = std::cmp::min(position + n, content.len());
 552                let result = content[position..end].to_vec();
 553                (Some(result), end)
 554            }
 555        };
 556
 557        // Update the position in the file userdata
 558        if new_position != position {
 559            file_userdata.set("__position", new_position)?;
 560        }
 561
 562        // Convert the result to a Lua string
 563        match result {
 564            Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
 565            None => Ok(None),
 566        }
 567    }
 568
 569    fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
 570        let format = match format {
 571            Some(mlua::Value::String(s)) => {
 572                let lossy_string = s.to_string_lossy();
 573                let format_str: &str = lossy_string.as_ref();
 574
 575                // Only consider the first 2 bytes, since it's common to pass e.g. "*all"  instead of "*a"
 576                match &format_str[0..2] {
 577                    "*a" => FileReadFormat::All,
 578                    "*l" => FileReadFormat::Line,
 579                    "*L" => FileReadFormat::LineWithLineFeed,
 580                    "*n" => {
 581                        // Try to parse as a number (number of bytes to read)
 582                        match format_str.parse::<usize>() {
 583                            Ok(n) => FileReadFormat::Bytes(n),
 584                            Err(_) => {
 585                                return Err(mlua::Error::runtime(format!(
 586                                    "Invalid format: {}",
 587                                    format_str
 588                                )))
 589                            }
 590                        }
 591                    }
 592                    _ => {
 593                        return Err(mlua::Error::runtime(format!(
 594                            "Unsupported format: {}",
 595                            format_str
 596                        )))
 597                    }
 598                }
 599            }
 600            Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
 601            Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
 602            Some(value) => {
 603                return Err(mlua::Error::runtime(format!(
 604                    "Invalid file format {:?}",
 605                    value
 606                )))
 607            }
 608            None => FileReadFormat::Line, // Default is to read a line
 609        };
 610
 611        Ok(format)
 612    }
 613
 614    fn io_file_write(
 615        _lua: &Lua,
 616        (file_userdata, text): (Table, mlua::String),
 617    ) -> mlua::Result<bool> {
 618        let write_perm = file_userdata.get::<bool>("__write_perm")?;
 619        if !write_perm {
 620            return Err(mlua::Error::runtime("File not open for writing"));
 621        }
 622
 623        let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 624        let position = file_userdata.get::<usize>("__position")?;
 625        let content_ref = content.borrow::<FileContent>()?;
 626        let mut content_vec = content_ref.0.lock();
 627
 628        let bytes = text.as_bytes();
 629
 630        // Ensure the vector has enough capacity
 631        if position + bytes.len() > content_vec.len() {
 632            content_vec.resize(position + bytes.len(), 0);
 633        }
 634
 635        // Write the bytes
 636        for (i, &byte) in bytes.iter().enumerate() {
 637            content_vec[position + i] = byte;
 638        }
 639
 640        // Update position
 641        let new_position = position + bytes.len();
 642        file_userdata.set("__position", new_position)?;
 643
 644        Ok(true)
 645    }
 646
 647    async fn search(
 648        lua: &Lua,
 649        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 650        fs: Arc<dyn Fs>,
 651        regex: String,
 652    ) -> anyhow::Result<Table> {
 653        // TODO: Allow specification of these options.
 654        let search_query = SearchQuery::regex(
 655            &regex,
 656            false,
 657            false,
 658            false,
 659            PathMatcher::default(),
 660            PathMatcher::default(),
 661            None,
 662        );
 663        let search_query = match search_query {
 664            Ok(query) => query,
 665            Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
 666        };
 667
 668        // TODO: Should use `search_query.regex`. The tool description should also be updated,
 669        // as it specifies standard regex.
 670        let search_regex = match Regex::new(&regex) {
 671            Ok(re) => re,
 672            Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
 673        };
 674
 675        let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
 676
 677        let mut search_results: Vec<Table> = Vec::new();
 678        while let Some(path) = abs_paths_rx.next().await {
 679            // Skip files larger than 1MB
 680            if let Ok(Some(metadata)) = fs.metadata(&path).await {
 681                if metadata.len > 1_000_000 {
 682                    continue;
 683                }
 684            }
 685
 686            // Attempt to read the file as text
 687            if let Ok(content) = fs.load(&path).await {
 688                let mut matches = Vec::new();
 689
 690                // Find all regex matches in the content
 691                for capture in search_regex.find_iter(&content) {
 692                    matches.push(capture.as_str().to_string());
 693                }
 694
 695                // If we found matches, create a result entry
 696                if !matches.is_empty() {
 697                    let result_entry = lua.create_table()?;
 698                    result_entry.set("path", path.to_string_lossy().to_string())?;
 699
 700                    let matches_table = lua.create_table()?;
 701                    for (ix, m) in matches.iter().enumerate() {
 702                        matches_table.set(ix + 1, m.clone())?;
 703                    }
 704                    result_entry.set("matches", matches_table)?;
 705
 706                    search_results.push(result_entry);
 707                }
 708            }
 709        }
 710
 711        // Create a table to hold our results
 712        let results_table = lua.create_table()?;
 713        for (ix, entry) in search_results.into_iter().enumerate() {
 714            results_table.set(ix + 1, entry)?;
 715        }
 716
 717        Ok(results_table)
 718    }
 719
 720    async fn find_search_candidates(
 721        search_query: SearchQuery,
 722        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 723    ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
 724        Self::run_foreground_fn(
 725            "finding search file candidates",
 726            foreground_tx,
 727            Box::new(move |session, mut cx| {
 728                session.update(&mut cx, |session, cx| {
 729                    session.project.update(cx, |project, cx| {
 730                        project.worktree_store().update(cx, |worktree_store, cx| {
 731                            // TODO: Better limit? For now this is the same as
 732                            // MAX_SEARCH_RESULT_FILES.
 733                            let limit = 5000;
 734                            // TODO: Providing non-empty open_entries can make this a bit more
 735                            // efficient as it can skip checking that these paths are textual.
 736                            let open_entries = HashSet::default();
 737                            let candidates = worktree_store.find_search_candidates(
 738                                search_query,
 739                                limit,
 740                                open_entries,
 741                                project.fs().clone(),
 742                                cx,
 743                            );
 744                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
 745                            cx.spawn(|worktree_store, cx| async move {
 746                                pin_mut!(candidates);
 747
 748                                while let Some(project_path) = candidates.next().await {
 749                                    worktree_store.read_with(&cx, |worktree_store, cx| {
 750                                        if let Some(worktree) = worktree_store
 751                                            .worktree_for_id(project_path.worktree_id, cx)
 752                                        {
 753                                            if let Some(abs_path) = worktree
 754                                                .read(cx)
 755                                                .absolutize(&project_path.path)
 756                                                .log_err()
 757                                            {
 758                                                abs_paths_tx.unbounded_send(abs_path)?;
 759                                            }
 760                                        }
 761                                        anyhow::Ok(())
 762                                    })??;
 763                                }
 764                                anyhow::Ok(())
 765                            })
 766                            .detach();
 767                            abs_paths_rx
 768                        })
 769                    })
 770                })
 771            }),
 772        )
 773        .await?
 774    }
 775
 776    async fn outline(
 777        root_dir: Option<Arc<Path>>,
 778        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 779        path_str: String,
 780    ) -> anyhow::Result<String> {
 781        let root_dir = root_dir
 782            .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
 783        let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
 784        let outline = Self::run_foreground_fn(
 785            "getting code outline",
 786            foreground_tx,
 787            Box::new(move |session, cx| {
 788                cx.spawn(move |mut cx| async move {
 789                    // TODO: This will not use file content from `fs_changes`. It will also reflect
 790                    // user changes that have not been saved.
 791                    let buffer = session
 792                        .update(&mut cx, |session, cx| {
 793                            session
 794                                .project
 795                                .update(cx, |project, cx| project.open_local_buffer(&path, cx))
 796                        })?
 797                        .await?;
 798                    buffer.update(&mut cx, |buffer, _cx| {
 799                        if let Some(outline) = buffer.snapshot().outline(None) {
 800                            Ok(outline)
 801                        } else {
 802                            Err(anyhow!("No outline for file {path_str}"))
 803                        }
 804                    })
 805                })
 806            }),
 807        )
 808        .await?
 809        .await??;
 810
 811        Ok(outline
 812            .items
 813            .into_iter()
 814            .map(|item| {
 815                if item.text.contains('\n') {
 816                    log::error!("Outline item unexpectedly contains newline");
 817                }
 818                format!("{}{}", "  ".repeat(item.depth), item.text)
 819            })
 820            .collect::<Vec<String>>()
 821            .join("\n"))
 822    }
 823
 824    async fn run_foreground_fn<R: Send + 'static>(
 825        description: &str,
 826        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 827        function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
 828    ) -> anyhow::Result<R> {
 829        let (response_tx, response_rx) = oneshot::channel();
 830        let send_result = foreground_tx
 831            .send(ForegroundFn(Box::new(move |this, cx| {
 832                response_tx.send(function(this, cx)).ok();
 833            })))
 834            .await;
 835        match send_result {
 836            Ok(()) => (),
 837            Err(err) => {
 838                return Err(anyhow::Error::new(err).context(format!(
 839                    "Internal error while enqueuing work for {description}"
 840                )));
 841            }
 842        }
 843        match response_rx.await {
 844            Ok(result) => Ok(result),
 845            Err(oneshot::Canceled) => Err(anyhow!(
 846                "Internal error: response oneshot was canceled while {description}."
 847            )),
 848        }
 849    }
 850
 851    fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
 852        let path = Path::new(&path_str);
 853        if path.is_absolute() {
 854            // Check if path starts with root_dir prefix without resolving symlinks
 855            if path.starts_with(&root_dir) {
 856                Ok(path.to_path_buf())
 857            } else {
 858                Err(anyhow!(
 859                    "Error: Absolute path {} is outside the current working directory",
 860                    path_str
 861                ))
 862            }
 863        } else {
 864            // TODO: Does use of `../` break sandbox - is path canonicalization needed?
 865            Ok(root_dir.join(path))
 866        }
 867    }
 868}
 869
/// How much data a read from a script-opened file should return.
///
/// NOTE(review): presumably parsed from the Lua `file:read` format argument (`"*a"`, `"*l"`,
/// `"*L"`, or a byte count, as exercised in the tests). The parser is not in this chunk, so
/// confirm the mapping there.
enum FileReadFormat {
    /// Read everything that remains (presumably `"*a"`).
    All,
    /// Read the next line without its line feed (presumably `"*l"`).
    Line,
    /// Read the next line including its line feed (presumably `"*L"`).
    LineWithLineFeed,
    /// Read up to the given number of bytes (presumably a numeric argument).
    Bytes(usize),
}
 876
/// A shared, mutable byte buffer handed to Lua as userdata.
///
/// NOTE(review): presumably the in-memory contents of a file opened by a script; confirm against
/// the `io.open` implementation, which is outside this chunk.
struct FileContent(Arc<Mutex<Vec<u8>>>);
 878
 879impl UserData for FileContent {
 880    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
 881        // FileContent doesn't have any methods so far.
 882    }
 883}
 884
/// Identifies a script run by a `ScriptingSession`.
///
/// NOTE(review): presumably an index into the session's `scripts` list; verify against
/// `run_script` / `get`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);
 887
/// A script run within a `ScriptingSession`.
pub struct Script {
    /// Whether the script is still running or how it finished, together with its stdout.
    pub state: ScriptState,
}
 891
/// Lifecycle state of a `Script`, carrying the stdout it has produced.
#[derive(Debug)]
pub enum ScriptState {
    /// The script has not finished yet.
    Running {
        /// Output captured so far. It sits behind a shared lock so it can be snapshotted while
        /// it is presumably still being appended to by the running script.
        stdout: Arc<Mutex<String>>,
    },
    /// The script finished without error.
    Succeeded {
        /// Everything the script printed.
        stdout: String,
    },
    /// The script stopped with an error.
    Failed {
        /// Everything the script printed before failing.
        stdout: String,
        /// The error that ended the script.
        error: anyhow::Error,
    },
}
 905
 906impl Script {
 907    /// If exited, returns a message with the output for the LLM
 908    pub fn output_message_for_llm(&self) -> Option<String> {
 909        match &self.state {
 910            ScriptState::Running { .. } => None,
 911            ScriptState::Succeeded { stdout } => {
 912                format!("Here's the script output:\n{}", stdout).into()
 913            }
 914            ScriptState::Failed { stdout, error } => format!(
 915                "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
 916                error, stdout
 917            )
 918            .into(),
 919        }
 920    }
 921
 922    /// Get a snapshot of the script's stdout
 923    pub fn stdout_snapshot(&self) -> String {
 924        match &self.state {
 925            ScriptState::Running { stdout } => stdout.lock().clone(),
 926            ScriptState::Succeeded { stdout } => stdout.clone(),
 927            ScriptState::Failed { stdout, .. } => stdout.clone(),
 928        }
 929    }
 930}
 931
// Tests for `ScriptingSession`. Each test runs a Lua script against a fresh `TestSession`: a
// `FakeFs`-backed project rooted at `path!("/")` that contains `file1.txt` and `file2.txt`. The
// test then checks the captured stdout and, for tests that write files, the per-buffer diff.
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use super::*;

    // `print` separates its arguments with a tab and ends each call with a newline.
    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        let script = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    // search

    // `search` returns a list of results, each with a `path` and a list of `matches`. Only
    // `file1.txt` contains "world".
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        let script = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print("  " .. match)
                end
            end
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            concat!("File: ", path!("/file1.txt"), "\nMatches:\n  world\n")
        );
    }

    // io.open

    // `file:read()` with no format returns the first line without its newline. Reading alone
    // must not produce a diff.
    #[gpui::test]
    async fn test_open_and_read_file(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("file1.txt", "r")
            local content = file:read()
            print("Content:", content)
            file:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(output, "Content:\tHello world!\n");
        assert_eq!(test_session.diff(cx), Vec::new());
    }

    // Opening with "w" replaces the existing content, and a later read sees the new content.
    // The diff is a single hunk replacing the original line.
    #[gpui::test]
    async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("file1.txt", "w")
            file:write("This is new content")
            file:close()

            -- Read back to verify
            local read_file = io.open("file1.txt", "r")
            local content = read_file:read("*a")
            print("Written content:", content)
            read_file:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(output, "Written content:\tThis is new content\n");
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("file1.txt"),
                vec![(
                    "Hello world!\n".to_string(),
                    "This is new content".to_string()
                )]
            )]
        );
    }

    // Successive writes through one handle are concatenated. Creating a new file shows up as a
    // pure insertion (empty old text).
    #[gpui::test]
    async fn test_multiple_writes(cx: &mut TestAppContext) {
        let script = r#"
            -- Test writing to a file multiple times
            local file = io.open("multiwrite.txt", "w")
            file:write("First line\n")
            file:write("Second line\n")
            file:write("Third line")
            file:close()

            -- Read back to verify
            local read_file = io.open("multiwrite.txt", "r")
            if read_file then
                local content = read_file:read("*a")
                print("Full content:", content)
                read_file:close()
            end
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            "Full content:\tFirst line\nSecond line\nThird line\n"
        );
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("multiwrite.txt"),
                vec![(
                    "".to_string(),
                    "First line\nSecond line\nThird line".to_string()
                )]
            )]
        );
    }

    // A second "w" open truncates whatever the first handle wrote. The doubled newline in the
    // expected output is the file's trailing "\n" plus the one added by `print`.
    #[gpui::test]
    async fn test_multiple_writes_diff_handles(cx: &mut TestAppContext) {
        let script = r#"
            -- Write to a file
            local file1 = io.open("multi_open.txt", "w")
            file1:write("Content written by first handle\n")
            file1:close()

            -- Open it again and add more content
            local file2 = io.open("multi_open.txt", "w")
            file2:write("Content written by second handle\n")
            file2:close()

            -- Open it a third time and read
            local file3 = io.open("multi_open.txt", "r")
            local content = file3:read("*a")
            print("Final content:", content)
            file3:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            "Final content:\tContent written by second handle\n\n"
        );
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("multi_open.txt"),
                vec![(
                    "".to_string(),
                    "Content written by second handle\n".to_string()
                )]
            )]
        );
    }

    // "a" mode appends after the existing content. The diff is an insertion after the unchanged
    // "Hello world!\n" line.
    #[gpui::test]
    async fn test_append_mode(cx: &mut TestAppContext) {
        let script = r#"
            -- Append more content
            file = io.open("file1.txt", "a")
            file:write("Appended content\n")
            file:close()

            -- Add even more
            file = io.open("file1.txt", "a")
            file:write("More appended content")
            file:close()

            -- Read back to verify
            local read_file = io.open("file1.txt", "r")
            local content = read_file:read("*a")
            print("Content after appends:", content)
            read_file:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        assert_eq!(
            output,
            "Content after appends:\tHello world!\nAppended content\nMore appended content\n"
        );
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("file1.txt"),
                vec![(
                    "".to_string(),
                    "Appended content\nMore appended content".to_string()
                )]
            )]
        );
    }

    // Covers the `file:read` formats "*a" (all), "*l" (line without newline), "*L" (line with
    // newline), and a byte count.
    #[gpui::test]
    async fn test_read_formats(cx: &mut TestAppContext) {
        let script = r#"
            local file = io.open("multiline.txt", "w")
            file:write("Line 1\nLine 2\nLine 3")
            file:close()

            -- Test "*a" (all)
            local f = io.open("multiline.txt", "r")
            local all = f:read("*a")
            print("All:", all)
            f:close()

            -- Test "*l" (line)
            f = io.open("multiline.txt", "r")
            local line1 = f:read("*l")
            local line2 = f:read("*l")
            local line3 = f:read("*l")
            print("Line 1:", line1)
            print("Line 2:", line2)
            print("Line 3:", line3)
            f:close()

            -- Test "*L" (line with newline)
            f = io.open("multiline.txt", "r")
            local line_with_nl = f:read("*L")
            print("Line with newline length:", #line_with_nl)
            print("Last char:", string.byte(line_with_nl, #line_with_nl))
            f:close()

            -- Test number of bytes
            f = io.open("multiline.txt", "r")
            local bytes5 = f:read(5)
            print("5 bytes:", bytes5)
            f:close()
        "#;

        let test_session = TestSession::init(cx).await;
        let output = test_session.test_success(script, cx).await;
        // NOTE(review): debug print of the script output (visible with `--nocapture`); consider
        // removing.
        println!("{}", &output);
        assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
        assert!(output.contains("Line 1:\tLine 1"));
        assert!(output.contains("Line 2:\tLine 2"));
        assert!(output.contains("Line 3:\tLine 3"));
        assert!(output.contains("Line with newline length:\t7"));
        assert!(output.contains("Last char:\t10")); // LF
        assert!(output.contains("5 bytes:\tLine "));
        assert_eq!(
            test_session.diff(cx),
            vec![(
                PathBuf::from("multiline.txt"),
                vec![("".to_string(), "Line 1\nLine 2\nLine 3".to_string())]
            )]
        );
    }

    // helpers

    /// Test harness wrapping a `ScriptingSession` entity over a fake project.
    struct TestSession {
        session: Entity<ScriptingSession>,
    }

    impl TestSession {
        /// Sets up the test settings and language globals, then builds a `FakeFs` containing
        /// `file1.txt` ("Hello world!\n") and `file2.txt` ("Goodbye moon!\n"). Finally it creates
        /// a project over that tree and a `ScriptingSession` for it.
        async fn init(cx: &mut TestAppContext) -> Self {
            let settings_store = cx.update(SettingsStore::test);
            cx.set_global(settings_store);
            cx.update(Project::init_settings);
            cx.update(language::init);

            let fs = FakeFs::new(cx.executor());
            fs.insert_tree(
                path!("/"),
                json!({
                    "file1.txt": "Hello world!\n",
                    "file2.txt": "Goodbye moon!\n"
                }),
            )
            .await;

            let project = Project::test(fs.clone(), [Path::new(path!("/"))], cx).await;
            let session = cx.new(|cx| ScriptingSession::new(project, cx));

            TestSession { session }
        }

        /// Runs `source` to completion and returns its stdout.
        ///
        /// Panics with the error and the captured stdout if the script ended in
        /// `ScriptState::Failed`.
        async fn test_success(&self, source: &str, cx: &mut TestAppContext) -> String {
            let script_id = self.run_script(source, cx).await;

            self.session.read_with(cx, |session, _cx| {
                let script = session.get(script_id);
                let stdout = script.stdout_snapshot();

                if let ScriptState::Failed { error, .. } = &script.state {
                    panic!("Script failed:\n{}\n\n{}", error, stdout);
                }

                stdout
            })
        }

        /// Returns one entry per changed buffer: its file path and one `(old_text, new_text)`
        /// pair per diff hunk.
        ///
        /// Entries follow `HashMap` iteration order, so the order is only deterministic when a
        /// single buffer changed. Panics (`unwrap`) if a changed buffer has no file.
        fn diff(&self, cx: &mut TestAppContext) -> Vec<(PathBuf, Vec<(String, String)>)> {
            self.session.read_with(cx, |session, cx| {
                session
                    .changes_by_buffer
                    .iter()
                    .map(|(buffer, changes)| {
                        let snapshot = buffer.read(cx).snapshot();
                        let diff = changes.diff.read(cx);
                        let hunks = diff.hunks(&snapshot, cx);
                        let path = buffer.read(cx).file().unwrap().path().clone();
                        let diffs = hunks
                            .map(|hunk| {
                                // Old text comes from the diff's base; new text from the buffer.
                                let old_text = diff
                                    .base_text()
                                    .text_for_range(hunk.diff_base_byte_range)
                                    .collect::<String>();
                                let new_text =
                                    snapshot.text_for_range(hunk.range).collect::<String>();
                                (old_text, new_text)
                            })
                            .collect();
                        (path.to_path_buf(), diffs)
                    })
                    .collect()
            })
        }

        /// Starts `source` in the session, awaits the returned task, and returns the script's id.
        async fn run_script(&self, source: &str, cx: &mut TestAppContext) -> ScriptId {
            let (script_id, task) = self
                .session
                .update(cx, |session, cx| session.run_script(source.to_string(), cx));

            task.await;

            script_id
        }
    }
}