scripting_session.rs

   1use anyhow::anyhow;
   2use buffer_diff::BufferDiff;
   3use collections::{HashMap, HashSet};
   4use futures::{
   5    channel::{mpsc, oneshot},
   6    pin_mut, SinkExt, StreamExt,
   7};
   8use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
   9use language::Buffer;
  10use mlua::{ExternalResult, Lua, MultiValue, ObjectLike, Table, UserData, UserDataMethods};
  11use parking_lot::Mutex;
  12use project::{search::SearchQuery, Fs, Project, ProjectPath, WorktreeId};
  13use regex::Regex;
  14use std::{
  15    path::{Path, PathBuf},
  16    sync::Arc,
  17};
  18use util::{paths::PathMatcher, ResultExt};
  19
/// A closure queued by a Lua callback for execution by the session's foreground task.
///
/// Closures are sent over `ScriptingSession::foreground_fns_tx`. The task spawned with
/// `cx.spawn` in `ScriptingSession::new` receives each one and calls it with a weak handle to
/// the session and an `AsyncApp`. That task is distinct from the `cx.background_spawn` task
/// that runs the Lua interpreter.
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
  21
/// The edits scripts in this session have made to a single buffer.
struct BufferChanges {
    /// Diff of the buffer against its text with all recorded script edits undone. Its base
    /// text is recomputed after every scripted write.
    diff: Entity<BufferDiff>,
    /// Edit ids of every transaction a script applied to this buffer. They accumulate across
    /// writes so that all of them can be undone together when computing the diff base.
    edit_ids: Vec<clock::Lamport>,
}
  26
/// Runs Lua scripts against a project and records which buffers they modified.
pub struct ScriptingSession {
    project: Entity<Project>,
    /// Every script started in this session. A `ScriptId` is an index into this list.
    scripts: Vec<Script>,
    /// Per-buffer record of the edits scripts have made, keyed by buffer entity.
    changes_by_buffer: HashMap<Entity<Buffer>, BufferChanges>,
    /// Used by Lua callbacks running in the background to queue work that needs the session.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Task that receives and invokes queued `ForegroundFn`s. It is stored here to keep it
    /// alive (presumably dropping a gpui `Task` cancels it).
    _invoke_foreground_fns: Task<()>,
}
  34
  35impl ScriptingSession {
  36    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
  37        let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
  38        ScriptingSession {
  39            project,
  40            scripts: Vec::new(),
  41            changes_by_buffer: HashMap::default(),
  42            foreground_fns_tx,
  43            _invoke_foreground_fns: cx.spawn(|this, cx| async move {
  44                while let Some(foreground_fn) = foreground_fns_rx.next().await {
  45                    foreground_fn.0(this.clone(), cx.clone());
  46                }
  47            }),
  48        }
  49    }
  50
  51    pub fn changed_buffers(&self) -> impl ExactSizeIterator<Item = &Entity<Buffer>> {
  52        self.changes_by_buffer.keys()
  53    }
  54
  55    pub fn run_script(
  56        &mut self,
  57        script_src: String,
  58        cx: &mut Context<Self>,
  59    ) -> (ScriptId, Task<()>) {
  60        let id = ScriptId(self.scripts.len() as u32);
  61
  62        let stdout = Arc::new(Mutex::new(String::new()));
  63
  64        let script = Script {
  65            state: ScriptState::Running {
  66                stdout: stdout.clone(),
  67            },
  68        };
  69        self.scripts.push(script);
  70
  71        let task = self.run_lua(script_src, stdout, cx);
  72
  73        let task = cx.spawn(|session, mut cx| async move {
  74            let result = task.await;
  75
  76            session
  77                .update(&mut cx, |session, _cx| {
  78                    let script = session.get_mut(id);
  79                    let stdout = script.stdout_snapshot();
  80
  81                    script.state = match result {
  82                        Ok(()) => ScriptState::Succeeded { stdout },
  83                        Err(error) => ScriptState::Failed { stdout, error },
  84                    };
  85                })
  86                .log_err();
  87        });
  88
  89        (id, task)
  90    }
  91
    /// Runs `script` in a fresh Lua interpreter on the background executor.
    ///
    /// Before running, this registers the Rust-backed globals used by the sandbox preamble
    /// (`sandbox_preamble.lua`):
    /// - `cwd`, set only when the first visible worktree's root path is valid UTF-8;
    /// - `sb_print`, `search`, `outline` and `sb_io_open`.
    ///
    /// It also stores the script source in `user_script`, then executes the preamble.
    /// Output from `sb_print` is appended to `stdout`. Callbacks that need the project
    /// queue their work through `foreground_fns_tx`.
    ///
    /// The returned task resolves to an error if interpreter setup or execution fails.
    fn run_lua(
        &mut self,
        script: String,
        stdout: Arc<Mutex<String>>,
        cx: &mut Context<Self>,
    ) -> Task<anyhow::Result<()>> {
        const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");

        // TODO Honor all worktrees instead of the first one
        let worktree_info = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .next()
            .map(|worktree| {
                let worktree = worktree.read(cx);
                (worktree.id(), worktree.abs_path())
            });

        let root_dir = worktree_info.as_ref().map(|(_, root)| root.clone());

        let fs = self.project.read(cx).fs().clone();
        let foreground_fns_tx = self.foreground_fns_tx.clone();

        let task = cx.background_spawn({
            let stdout = stdout.clone();

            async move {
                let lua = Lua::new();
                lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
                let globals = lua.globals();

                // Use the project root dir as the script's current working dir.
                if let Some(root_dir) = &root_dir {
                    if let Some(root_dir) = root_dir.to_str() {
                        globals.set("cwd", root_dir)?;
                    }
                }

                globals.set(
                    "sb_print",
                    lua.create_function({
                        let stdout = stdout.clone();
                        move |_, args: MultiValue| Self::print(args, &stdout)
                    })?,
                )?;
                // Each async callback clones its captures per call, so every returned
                // future owns its own sender and `fs` handle.
                globals.set(
                    "search",
                    lua.create_async_function({
                        let foreground_fns_tx = foreground_fns_tx.clone();
                        let fs = fs.clone();
                        move |lua, regex| {
                            let mut foreground_fns_tx = foreground_fns_tx.clone();
                            let fs = fs.clone();
                            async move {
                                Self::search(&lua, &mut foreground_fns_tx, fs, regex)
                                    .await
                                    .into_lua_err()
                            }
                        }
                    })?,
                )?;
                globals.set(
                    "outline",
                    lua.create_async_function({
                        let root_dir = root_dir.clone();
                        let foreground_fns_tx = foreground_fns_tx.clone();
                        move |_lua, path| {
                            let mut foreground_fns_tx = foreground_fns_tx.clone();
                            let root_dir = root_dir.clone();
                            async move {
                                Self::outline(root_dir, &mut foreground_fns_tx, path)
                                    .await
                                    .into_lua_err()
                            }
                        }
                    })?,
                )?;
                // This is the last use of the outer `fs`, so it is moved (not cloned) into
                // this closure.
                globals.set(
                    "sb_io_open",
                    lua.create_async_function({
                        let worktree_info = worktree_info.clone();
                        let foreground_fns_tx = foreground_fns_tx.clone();
                        move |lua, (path_str, mode)| {
                            let worktree_info = worktree_info.clone();
                            let mut foreground_fns_tx = foreground_fns_tx.clone();
                            let fs = fs.clone();
                            async move {
                                Self::io_open(
                                    &lua,
                                    worktree_info,
                                    &mut foreground_fns_tx,
                                    fs,
                                    path_str,
                                    mode,
                                )
                                .await
                            }
                        }
                    })?,
                )?;
                // The preamble reads the user's source from this global.
                globals.set("user_script", script)?;

                lua.load(SANDBOX_PREAMBLE).exec_async().await?;

                anyhow::Ok(())
            }
        });

        task
    }
 203
 204    pub fn get(&self, script_id: ScriptId) -> &Script {
 205        &self.scripts[script_id.0 as usize]
 206    }
 207
 208    fn get_mut(&mut self, script_id: ScriptId) -> &mut Script {
 209        &mut self.scripts[script_id.0 as usize]
 210    }
 211
 212    /// Sandboxed print() function in Lua.
 213    fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
 214        for (index, arg) in args.into_iter().enumerate() {
 215            // Lua's `print()` prints tab characters between each argument.
 216            if index > 0 {
 217                stdout.lock().push('\t');
 218            }
 219
 220            // If the argument's to_string() fails, have the whole function call fail.
 221            stdout.lock().push_str(&arg.to_string()?);
 222        }
 223        stdout.lock().push('\n');
 224
 225        Ok(())
 226    }
 227
 228    /// Sandboxed io.open() function in Lua.
 229    async fn io_open(
 230        lua: &Lua,
 231        worktree_info: Option<(WorktreeId, Arc<Path>)>,
 232        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 233        fs: Arc<dyn Fs>,
 234        path_str: String,
 235        mode: Option<String>,
 236    ) -> mlua::Result<(Option<Table>, String)> {
 237        let (worktree_id, root_dir) = worktree_info
 238            .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
 239
 240        let mode = mode.unwrap_or_else(|| "r".to_string());
 241
 242        // Parse the mode string to determine read/write permissions
 243        let read_perm = mode.contains('r');
 244        let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
 245        let append = mode.contains('a');
 246        let truncate = mode.contains('w');
 247
 248        // This will be the Lua value returned from the `open` function.
 249        let file = lua.create_table()?;
 250
 251        // Store file metadata in the file
 252        file.set("__mode", mode.clone())?;
 253        file.set("__read_perm", read_perm)?;
 254        file.set("__write_perm", write_perm)?;
 255
 256        let path = match Self::parse_abs_path_in_root_dir(&root_dir, &path_str) {
 257            Ok(path) => path,
 258            Err(err) => return Ok((None, format!("{err}"))),
 259        };
 260
 261        let project_path = ProjectPath {
 262            worktree_id,
 263            path: Path::new(&path_str).into(),
 264        };
 265
 266        // flush / close method
 267        let flush_fn = {
 268            let project_path = project_path.clone();
 269            let foreground_tx = foreground_tx.clone();
 270            lua.create_async_function(move |_lua, file_userdata: mlua::Table| {
 271                let project_path = project_path.clone();
 272                let mut foreground_tx = foreground_tx.clone();
 273                async move {
 274                    Self::io_file_flush(file_userdata, project_path, &mut foreground_tx).await
 275                }
 276            })?
 277        };
 278        file.set("flush", flush_fn.clone())?;
 279        // We don't really hold files open, so we only need to flush on close
 280        file.set("close", flush_fn)?;
 281
 282        // If it's a directory, give it a custom read() and return early.
 283        if fs.is_dir(&path).await {
 284            return Self::io_file_dir(lua, fs, file, &path).await;
 285        }
 286
 287        let mut file_content = Vec::new();
 288
 289        if !truncate {
 290            // Try to read existing content if we're not truncating
 291            match Self::read_buffer(project_path.clone(), foreground_tx).await {
 292                Ok(content) => file_content = content.into_bytes(),
 293                Err(e) => return Ok((None, format!("Error reading file: {}", e))),
 294            }
 295        }
 296
 297        // If in append mode, position should be at the end
 298        let position = if append { file_content.len() } else { 0 };
 299        file.set("__position", position)?;
 300        file.set(
 301            "__content",
 302            lua.create_userdata(FileContent(Arc::new(Mutex::new(file_content))))?,
 303        )?;
 304
 305        // Create file methods
 306
 307        // read method
 308        let read_fn = lua.create_function(Self::io_file_read)?;
 309        file.set("read", read_fn)?;
 310
 311        // lines method
 312        let lines_fn = lua.create_function(Self::io_file_lines)?;
 313        file.set("lines", lines_fn)?;
 314
 315        // write method
 316        let write_fn = lua.create_function(Self::io_file_write)?;
 317        file.set("write", write_fn)?;
 318
 319        // If we got this far, the file was opened successfully
 320        Ok((Some(file), String::new()))
 321    }
 322
    /// Returns the current text of the project buffer at `project_path`, opening the buffer
    /// via the project if necessary.
    ///
    /// The buffer is reached through the session entity, so the work is handed to
    /// `run_foreground_fn` (defined elsewhere). Presumably that runs the closure on the
    /// foreground task and returns its result. Errors if that hand-off fails, if
    /// `session.update` fails (e.g. the session entity was released), or if the buffer
    /// cannot be opened or read.
    async fn read_buffer(
        project_path: ProjectPath,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<String> {
        // The two `?`s unwrap the hand-off result and the `session.update` result. The final
        // `.await` waits for the spawned open-and-read task.
        Self::run_foreground_fn(
            "read file from buffer",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    let open_buffer_task = session
                        .project
                        .update(cx, |project, cx| project.open_buffer(project_path, cx));

                    cx.spawn(|_, cx| async move {
                        let buffer = open_buffer_task.await?;

                        let text = buffer.read_with(&cx, |buffer, _cx| buffer.text())?;
                        Ok(text)
                    })
                })
            }),
        )
        .await??
        .await
    }
 348
 349    async fn io_file_flush(
 350        file_userdata: mlua::Table,
 351        project_path: ProjectPath,
 352        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 353    ) -> mlua::Result<bool> {
 354        let write_perm = file_userdata.get::<bool>("__write_perm")?;
 355
 356        if write_perm {
 357            let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 358            let content_ref = content.borrow::<FileContent>()?;
 359            let text = {
 360                let mut content_vec = content_ref.0.lock();
 361                let content_vec = std::mem::take(&mut *content_vec);
 362                String::from_utf8(content_vec).into_lua_err()?
 363            };
 364
 365            Self::write_to_buffer(project_path, text, foreground_tx)
 366                .await
 367                .into_lua_err()?;
 368        }
 369
 370        Ok(true)
 371    }
 372
    /// Replaces the contents of the project buffer at `project_path` with `text`, saves the
    /// buffer, and records the change in `changes_by_buffer`.
    ///
    /// The new text is applied as a diff in its own transaction, whose edit ids are
    /// accumulated in the buffer's `BufferChanges`. After a successful save, the diff's base
    /// text is recomputed from a branch of the buffer with every recorded edit undone. The
    /// diff therefore covers all script changes to that buffer, not just this write.
    async fn write_to_buffer(
        project_path: ProjectPath,
        text: String,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<()> {
        Self::run_foreground_fn(
            "write to buffer",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    let open_buffer_task = session
                        .project
                        .update(cx, |project, cx| project.open_buffer(project_path, cx));

                    cx.spawn(move |session, mut cx| async move {
                        let buffer = open_buffer_task.await?;

                        let diff = buffer
                            .update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
                            .await;

                        // The transaction is finalized before and after applying the diff.
                        // Presumably this makes the second call return a transaction holding
                        // only this write's edits.
                        let edit_ids = buffer.update(&mut cx, |buffer, cx| {
                            buffer.finalize_last_transaction();
                            buffer.apply_diff(diff, cx);
                            let transaction = buffer.finalize_last_transaction();
                            transaction
                                .map_or(Vec::new(), |transaction| transaction.edit_ids.clone())
                        })?;

                        session
                            .update(&mut cx, {
                                let buffer = buffer.clone();

                                |session, cx| {
                                    session
                                        .project
                                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                                }
                            })?
                            .await?;

                        let snapshot = buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;

                        // If we saved successfully, mark buffer as changed
                        let buffer_without_changes =
                            buffer.update(&mut cx, |buffer, cx| buffer.branch(cx))?;
                        session
                            .update(&mut cx, |session, cx| {
                                let changed_buffer = session
                                    .changes_by_buffer
                                    .entry(buffer)
                                    .or_insert_with(|| BufferChanges {
                                        diff: cx.new(|cx| BufferDiff::new(&snapshot, cx)),
                                        edit_ids: Vec::new(),
                                    });
                                changed_buffer.edit_ids.extend(edit_ids);
                                // Undo every recorded script edit on the branch to recover
                                // the text as it was before any script touched it. The
                                // `u32::MAX` undo count presumably means "fully undone".
                                let operations_to_undo = changed_buffer
                                    .edit_ids
                                    .iter()
                                    .map(|edit_id| (*edit_id, u32::MAX))
                                    .collect::<HashMap<_, _>>();
                                buffer_without_changes.update(cx, |buffer, cx| {
                                    buffer.undo_operations(operations_to_undo, cx);
                                });
                                changed_buffer.diff.update(cx, |diff, cx| {
                                    diff.set_base_text(buffer_without_changes, snapshot.text, cx)
                                })
                            })?
                            .await?;

                        Ok(())
                    })
                })
            }),
        )
        .await??
        .await
    }
 451
 452    async fn io_file_dir(
 453        lua: &Lua,
 454        fs: Arc<dyn Fs>,
 455        file: Table,
 456        path: &Path,
 457    ) -> mlua::Result<(Option<Table>, String)> {
 458        // Create a special directory handle
 459        file.set("__is_directory", true)?;
 460
 461        // Store directory entries
 462        let entries = match fs.read_dir(&path).await {
 463            Ok(entries) => {
 464                let mut entry_names = Vec::new();
 465
 466                // Process the stream of directory entries
 467                pin_mut!(entries);
 468                while let Some(Ok(entry_result)) = entries.next().await {
 469                    if let Some(file_name) = entry_result.file_name() {
 470                        entry_names.push(file_name.to_string_lossy().into_owned());
 471                    }
 472                }
 473
 474                entry_names
 475            }
 476            Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
 477        };
 478
 479        // Save the list of entries
 480        file.set("__dir_entries", entries)?;
 481        file.set("__dir_position", 0usize)?;
 482
 483        // Create a directory-specific read function
 484        let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
 485            let position = file_userdata.get::<usize>("__dir_position")?;
 486            let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
 487
 488            if position >= entries.len() {
 489                return Ok(None); // No more entries
 490            }
 491
 492            let entry = entries[position].clone();
 493            file_userdata.set("__dir_position", position + 1)?;
 494
 495            Ok(Some(entry))
 496        })?;
 497        file.set("read", read_fn)?;
 498
 499        // If we got this far, the directory was opened successfully
 500        return Ok((Some(file), String::new()));
 501    }
 502
 503    fn io_file_read(
 504        lua: &Lua,
 505        (file_userdata, format): (Table, Option<mlua::Value>),
 506    ) -> mlua::Result<Option<mlua::String>> {
 507        let read_perm = file_userdata.get::<bool>("__read_perm")?;
 508        if !read_perm {
 509            return Err(mlua::Error::runtime("File not open for reading"));
 510        }
 511
 512        let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 513        let position = file_userdata.get::<usize>("__position")?;
 514        let content_ref = content.borrow::<FileContent>()?;
 515        let content = content_ref.0.lock();
 516
 517        if position >= content.len() {
 518            return Ok(None); // EOF
 519        }
 520
 521        let (result, new_position) = match Self::io_file_read_format(format)? {
 522            FileReadFormat::All => {
 523                // Read entire file from current position
 524                let result = content[position..].to_vec();
 525                (Some(result), content.len())
 526            }
 527            FileReadFormat::Line => {
 528                if let Some(next_newline_ix) = content[position..].iter().position(|c| *c == b'\n')
 529                {
 530                    let mut line = content[position..position + next_newline_ix].to_vec();
 531                    if line.ends_with(b"\r") {
 532                        line.pop();
 533                    }
 534                    (Some(line), position + next_newline_ix + 1)
 535                } else if position < content.len() {
 536                    let line = content[position..].to_vec();
 537                    (Some(line), content.len())
 538                } else {
 539                    (None, position) // EOF
 540                }
 541            }
 542            FileReadFormat::LineWithLineFeed => {
 543                if position < content.len() {
 544                    let next_line_ix = content[position..]
 545                        .iter()
 546                        .position(|c| *c == b'\n')
 547                        .map_or(content.len(), |ix| position + ix + 1);
 548                    let line = content[position..next_line_ix].to_vec();
 549                    (Some(line), next_line_ix)
 550                } else {
 551                    (None, position) // EOF
 552                }
 553            }
 554            FileReadFormat::Bytes(n) => {
 555                let end = std::cmp::min(position + n, content.len());
 556                let result = content[position..end].to_vec();
 557                (Some(result), end)
 558            }
 559        };
 560
 561        // Update the position in the file userdata
 562        if new_position != position {
 563            file_userdata.set("__position", new_position)?;
 564        }
 565
 566        // Convert the result to a Lua string
 567        match result {
 568            Some(bytes) => Ok(Some(lua.create_string(bytes)?)),
 569            None => Ok(None),
 570        }
 571    }
 572
 573    fn io_file_lines(lua: &Lua, file_userdata: Table) -> mlua::Result<mlua::Function> {
 574        let read_perm = file_userdata.get::<bool>("__read_perm")?;
 575        if !read_perm {
 576            return Err(mlua::Error::runtime("File not open for reading"));
 577        }
 578
 579        lua.create_function::<_, _, mlua::Value>(move |lua, _: ()| {
 580            file_userdata.call_method("read", lua.create_string("*l")?)
 581        })
 582    }
 583
 584    fn io_file_read_format(format: Option<mlua::Value>) -> mlua::Result<FileReadFormat> {
 585        let format = match format {
 586            Some(mlua::Value::String(s)) => {
 587                let lossy_string = s.to_string_lossy();
 588                let format_str: &str = lossy_string.as_ref();
 589
 590                // Only consider the first 2 bytes, since it's common to pass e.g. "*all"  instead of "*a"
 591                match &format_str[0..2] {
 592                    "*a" => FileReadFormat::All,
 593                    "*l" => FileReadFormat::Line,
 594                    "*L" => FileReadFormat::LineWithLineFeed,
 595                    "*n" => {
 596                        // Try to parse as a number (number of bytes to read)
 597                        match format_str.parse::<usize>() {
 598                            Ok(n) => FileReadFormat::Bytes(n),
 599                            Err(_) => {
 600                                return Err(mlua::Error::runtime(format!(
 601                                    "Invalid format: {}",
 602                                    format_str
 603                                )))
 604                            }
 605                        }
 606                    }
 607                    _ => {
 608                        return Err(mlua::Error::runtime(format!(
 609                            "Unsupported format: {}",
 610                            format_str
 611                        )))
 612                    }
 613                }
 614            }
 615            Some(mlua::Value::Number(n)) => FileReadFormat::Bytes(n as usize),
 616            Some(mlua::Value::Integer(n)) => FileReadFormat::Bytes(n as usize),
 617            Some(value) => {
 618                return Err(mlua::Error::runtime(format!(
 619                    "Invalid file format {:?}",
 620                    value
 621                )))
 622            }
 623            None => FileReadFormat::Line, // Default is to read a line
 624        };
 625
 626        Ok(format)
 627    }
 628
 629    fn io_file_write(
 630        _lua: &Lua,
 631        (file_userdata, text): (Table, mlua::String),
 632    ) -> mlua::Result<bool> {
 633        let write_perm = file_userdata.get::<bool>("__write_perm")?;
 634        if !write_perm {
 635            return Err(mlua::Error::runtime("File not open for writing"));
 636        }
 637
 638        let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
 639        let position = file_userdata.get::<usize>("__position")?;
 640        let content_ref = content.borrow::<FileContent>()?;
 641        let mut content_vec = content_ref.0.lock();
 642
 643        let bytes = text.as_bytes();
 644
 645        // Ensure the vector has enough capacity
 646        if position + bytes.len() > content_vec.len() {
 647            content_vec.resize(position + bytes.len(), 0);
 648        }
 649
 650        // Write the bytes
 651        for (i, &byte) in bytes.iter().enumerate() {
 652            content_vec[position + i] = byte;
 653        }
 654
 655        // Update position
 656        let new_position = position + bytes.len();
 657        file_userdata.set("__position", new_position)?;
 658
 659        Ok(true)
 660    }
 661
 662    async fn search(
 663        lua: &Lua,
 664        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 665        fs: Arc<dyn Fs>,
 666        regex: String,
 667    ) -> anyhow::Result<Table> {
 668        // TODO: Allow specification of these options.
 669        let search_query = SearchQuery::regex(
 670            &regex,
 671            false,
 672            false,
 673            false,
 674            PathMatcher::default(),
 675            PathMatcher::default(),
 676            None,
 677        );
 678        let search_query = match search_query {
 679            Ok(query) => query,
 680            Err(e) => return Err(anyhow!("Invalid search query: {}", e)),
 681        };
 682
 683        // TODO: Should use `search_query.regex`. The tool description should also be updated,
 684        // as it specifies standard regex.
 685        let search_regex = match Regex::new(&regex) {
 686            Ok(re) => re,
 687            Err(e) => return Err(anyhow!("Invalid regex: {}", e)),
 688        };
 689
 690        let mut abs_paths_rx = Self::find_search_candidates(search_query, foreground_tx).await?;
 691
 692        let mut search_results: Vec<Table> = Vec::new();
 693        while let Some(path) = abs_paths_rx.next().await {
 694            // Skip files larger than 1MB
 695            if let Ok(Some(metadata)) = fs.metadata(&path).await {
 696                if metadata.len > 1_000_000 {
 697                    continue;
 698                }
 699            }
 700
 701            // Attempt to read the file as text
 702            if let Ok(content) = fs.load(&path).await {
 703                let mut matches = Vec::new();
 704
 705                // Find all regex matches in the content
 706                for capture in search_regex.find_iter(&content) {
 707                    matches.push(capture.as_str().to_string());
 708                }
 709
 710                // If we found matches, create a result entry
 711                if !matches.is_empty() {
 712                    let result_entry = lua.create_table()?;
 713                    result_entry.set("path", path.to_string_lossy().to_string())?;
 714
 715                    let matches_table = lua.create_table()?;
 716                    for (ix, m) in matches.iter().enumerate() {
 717                        matches_table.set(ix + 1, m.clone())?;
 718                    }
 719                    result_entry.set("matches", matches_table)?;
 720
 721                    search_results.push(result_entry);
 722                }
 723            }
 724        }
 725
 726        // Create a table to hold our results
 727        let results_table = lua.create_table()?;
 728        for (ix, entry) in search_results.into_iter().enumerate() {
 729            results_table.set(ix + 1, entry)?;
 730        }
 731
 732        Ok(results_table)
 733    }
 734
    /// Starts finding files that may match `search_query` and returns a channel of their
    /// absolute paths.
    ///
    /// Candidates come from the worktree store, capped at 5000. A detached task converts each
    /// candidate's worktree path to an absolute path and sends it on the returned receiver.
    /// The receiver yields paths as they are found and ends when that task finishes. The task
    /// finishes when candidates run out, or early on an error such as a closed receiver.
    /// Candidates whose worktree is gone are skipped, and failed path conversions are logged
    /// and skipped.
    async fn find_search_candidates(
        search_query: SearchQuery,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
    ) -> anyhow::Result<mpsc::UnboundedReceiver<PathBuf>> {
        Self::run_foreground_fn(
            "finding search file candidates",
            foreground_tx,
            Box::new(move |session, mut cx| {
                session.update(&mut cx, |session, cx| {
                    session.project.update(cx, |project, cx| {
                        project.worktree_store().update(cx, |worktree_store, cx| {
                            // TODO: Better limit? For now this is the same as
                            // MAX_SEARCH_RESULT_FILES.
                            let limit = 5000;
                            // TODO: Providing non-empty open_entries can make this a bit more
                            // efficient as it can skip checking that these paths are textual.
                            let open_entries = HashSet::default();
                            let candidates = worktree_store.find_search_candidates(
                                search_query,
                                limit,
                                open_entries,
                                project.fs().clone(),
                                cx,
                            );
                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
                            cx.spawn(|worktree_store, cx| async move {
                                pin_mut!(candidates);

                                while let Some(project_path) = candidates.next().await {
                                    worktree_store.read_with(&cx, |worktree_store, cx| {
                                        if let Some(worktree) = worktree_store
                                            .worktree_for_id(project_path.worktree_id, cx)
                                        {
                                            if let Some(abs_path) = worktree
                                                .read(cx)
                                                .absolutize(&project_path.path)
                                                .log_err()
                                            {
                                                abs_paths_tx.unbounded_send(abs_path)?;
                                            }
                                        }
                                        anyhow::Ok(())
                                    })??;
                                }
                                anyhow::Ok(())
                            })
                            .detach();
                            abs_paths_rx
                        })
                    })
                })
            }),
        )
        .await?
    }
 790
    /// Returns a textual outline of the file at `path_str`.
    ///
    /// `path_str` is resolved against `root_dir` with [`Self::parse_abs_path_in_root_dir`]. The
    /// file is opened as a local project buffer on the session's foreground task. Each item of
    /// the buffer's outline is then rendered on its own line, indented by two spaces per depth.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - there is no `root_dir`;
    /// - the path is rejected;
    /// - the foreground work cannot run;
    /// - the buffer cannot be opened;
    /// - the buffer has no outline.
    async fn outline(
        root_dir: Option<Arc<Path>>,
        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
        path_str: String,
    ) -> anyhow::Result<String> {
        let root_dir = root_dir
            .ok_or_else(|| mlua::Error::runtime("cannot get outline without a root directory"))?;
        let path = Self::parse_abs_path_in_root_dir(&root_dir, &path_str)?;
        // The foreground closure returns a spawned task. The first `.await?` waits for the
        // closure to run and yields that task. The second `.await??` waits for the task, then
        // unwraps the entity-update result and the outline result in turn.
        let outline = Self::run_foreground_fn(
            "getting code outline",
            foreground_tx,
            Box::new(move |session, cx| {
                cx.spawn(move |mut cx| async move {
                    // TODO: This will not use file content from `fs_changes`. It will also reflect
                    // user changes that have not been saved.
                    let buffer = session
                        .update(&mut cx, |session, cx| {
                            session
                                .project
                                .update(cx, |project, cx| project.open_local_buffer(&path, cx))
                        })?
                        .await?;
                    buffer.update(&mut cx, |buffer, _cx| {
                        if let Some(outline) = buffer.snapshot().outline(None) {
                            Ok(outline)
                        } else {
                            Err(anyhow!("No outline for file {path_str}"))
                        }
                    })
                })
            }),
        )
        .await?
        .await??;

        // One line per item. An item containing a newline would break that shape, so it is
        // logged as unexpected (but still emitted).
        Ok(outline
            .items
            .into_iter()
            .map(|item| {
                if item.text.contains('\n') {
                    log::error!("Outline item unexpectedly contains newline");
                }
                format!("{}{}", "  ".repeat(item.depth), item.text)
            })
            .collect::<Vec<String>>()
            .join("\n"))
    }
 838
 839    async fn run_foreground_fn<R: Send + 'static>(
 840        description: &str,
 841        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
 842        function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> R + Send>,
 843    ) -> anyhow::Result<R> {
 844        let (response_tx, response_rx) = oneshot::channel();
 845        let send_result = foreground_tx
 846            .send(ForegroundFn(Box::new(move |this, cx| {
 847                response_tx.send(function(this, cx)).ok();
 848            })))
 849            .await;
 850        match send_result {
 851            Ok(()) => (),
 852            Err(err) => {
 853                return Err(anyhow::Error::new(err).context(format!(
 854                    "Internal error while enqueuing work for {description}"
 855                )));
 856            }
 857        }
 858        match response_rx.await {
 859            Ok(result) => Ok(result),
 860            Err(oneshot::Canceled) => Err(anyhow!(
 861                "Internal error: response oneshot was canceled while {description}."
 862            )),
 863        }
 864    }
 865
 866    fn parse_abs_path_in_root_dir(root_dir: &Path, path_str: &str) -> anyhow::Result<PathBuf> {
 867        let path = Path::new(&path_str);
 868        if path.is_absolute() {
 869            // Check if path starts with root_dir prefix without resolving symlinks
 870            if path.starts_with(&root_dir) {
 871                Ok(path.to_path_buf())
 872            } else {
 873                Err(anyhow!(
 874                    "Error: Absolute path {} is outside the current working directory",
 875                    path_str
 876                ))
 877            }
 878        } else {
 879            // TODO: Does use of `../` break sandbox - is path canonicalization needed?
 880            Ok(root_dir.join(path))
 881        }
 882    }
 883}
 884
/// Amount of data a script asks for when reading from an open file.
///
/// NOTE(review): the variants presumably correspond to Lua's `file:read` formats, as exercised
/// by `test_read_formats`:
/// - `"*a"` → [`FileReadFormat::All`]
/// - `"*l"` → [`FileReadFormat::Line`]
/// - `"*L"` → [`FileReadFormat::LineWithLineFeed`]
/// - a number → [`FileReadFormat::Bytes`]
///
/// The parsing that produces them is outside this view; confirm it there.
enum FileReadFormat {
    /// Read everything that remains.
    All,
    /// Read the next line (presumably without its trailing line feed).
    Line,
    /// Read the next line, keeping its trailing line feed.
    LineWithLineFeed,
    /// Read the given number of bytes (presumably fewer at end of file).
    Bytes(usize),
}
 891
/// Byte contents of a file opened by a script, exposed to Lua as userdata.
///
/// The `Arc<Mutex<..>>` lets several holders share and mutate the same buffer.
/// NOTE(review): presumably shared between file handles opened on the same path; confirm at the
/// construction site, which is outside this view.
struct FileContent(Arc<Mutex<Vec<u8>>>);

impl UserData for FileContent {
    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
        // FileContent doesn't have any methods so far.
    }
}
 899
/// Identifier of a script run within a [`ScriptingSession`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ScriptId(u32);

/// A script that has been started in a session, together with its current state.
pub struct Script {
    /// Execution state, including the stdout captured so far.
    pub state: ScriptState,
}
 906
/// Lifecycle state of a [`Script`].
#[derive(Debug)]
pub enum ScriptState {
    /// The script is still executing.
    ///
    /// `stdout` is shared behind a mutex, presumably so output can be appended while the script
    /// runs. [`Script::stdout_snapshot`] locks it to copy the current contents.
    Running {
        stdout: Arc<Mutex<String>>,
    },
    /// The script finished without an error.
    Succeeded {
        /// Everything the script printed.
        stdout: String,
    },
    /// The script ended with an error.
    Failed {
        /// Everything the script printed before failing.
        stdout: String,
        /// The error that ended the script.
        error: anyhow::Error,
    },
}
 920
 921impl Script {
 922    /// If exited, returns a message with the output for the LLM
 923    pub fn output_message_for_llm(&self) -> Option<String> {
 924        match &self.state {
 925            ScriptState::Running { .. } => None,
 926            ScriptState::Succeeded { stdout } => {
 927                format!("Here's the script output:\n{}", stdout).into()
 928            }
 929            ScriptState::Failed { stdout, error } => format!(
 930                "The script failed with:\n{}\n\nHere's the output it managed to print:\n{}",
 931                error, stdout
 932            )
 933            .into(),
 934        }
 935    }
 936
 937    /// Get a snapshot of the script's stdout
 938    pub fn stdout_snapshot(&self) -> String {
 939        match &self.state {
 940            ScriptState::Running { stdout } => stdout.lock().clone(),
 941            ScriptState::Succeeded { stdout } => stdout.clone(),
 942            ScriptState::Failed { stdout, .. } => stdout.clone(),
 943        }
 944    }
 945}
 946
 947#[cfg(test)]
 948mod tests {
 949    use gpui::TestAppContext;
 950    use project::FakeFs;
 951    use serde_json::json;
 952    use settings::SettingsStore;
 953    use util::path;
 954
 955    use super::*;
 956
 957    #[gpui::test]
 958    async fn test_print(cx: &mut TestAppContext) {
 959        let script = r#"
 960            print("Hello", "world!")
 961            print("Goodbye", "moon!")
 962        "#;
 963
 964        let test_session = TestSession::init(cx).await;
 965        let output = test_session.test_success(script, cx).await;
 966        assert_eq!(output, "Hello\tworld!\nGoodbye\tmoon!\n");
 967    }
 968
 969    // search
 970
 971    #[gpui::test]
 972    async fn test_search(cx: &mut TestAppContext) {
 973        let script = r#"
 974            local results = search("world")
 975            for i, result in ipairs(results) do
 976                print("File: " .. result.path)
 977                print("Matches:")
 978                for j, match in ipairs(result.matches) do
 979                    print("  " .. match)
 980                end
 981            end
 982        "#;
 983
 984        let test_session = TestSession::init(cx).await;
 985        let output = test_session.test_success(script, cx).await;
 986        assert_eq!(
 987            output,
 988            concat!("File: ", path!("/file1.txt"), "\nMatches:\n  world\n")
 989        );
 990    }
 991
 992    // io.open
 993
 994    #[gpui::test]
 995    async fn test_open_and_read_file(cx: &mut TestAppContext) {
 996        let script = r#"
 997            local file = io.open("file1.txt", "r")
 998            local content = file:read()
 999            print("Content:", content)
1000            file:close()
1001        "#;
1002
1003        let test_session = TestSession::init(cx).await;
1004        let output = test_session.test_success(script, cx).await;
1005        assert_eq!(output, "Content:\tHello world!\n");
1006        assert_eq!(test_session.diff(cx), Vec::new());
1007    }
1008
1009    #[gpui::test]
1010    async fn test_lines_iterator(cx: &mut TestAppContext) {
1011        let script = r#"
1012            -- Create a test file with multiple lines
1013            local file = io.open("lines_test.txt", "w")
1014            file:write("Line 1\nLine 2\nLine 3\nLine 4\nLine 5")
1015            file:close()
1016
1017            -- Read it back using the lines iterator
1018            local read_file = io.open("lines_test.txt", "r")
1019            local count = 0
1020            for line in read_file:lines() do
1021                count = count + 1
1022                print(count .. ": " .. line)
1023            end
1024            read_file:close()
1025
1026            print("Total lines:", count)
1027        "#;
1028
1029        let test_session = TestSession::init(cx).await;
1030        let output = test_session.test_success(script, cx).await;
1031        assert_eq!(
1032            output,
1033            "1: Line 1\n2: Line 2\n3: Line 3\n4: Line 4\n5: Line 5\nTotal lines:\t5\n"
1034        );
1035    }
1036
1037    #[gpui::test]
1038    async fn test_read_write_roundtrip(cx: &mut TestAppContext) {
1039        let script = r#"
1040            local file = io.open("file1.txt", "w")
1041            file:write("This is new content")
1042            file:close()
1043
1044            -- Read back to verify
1045            local read_file = io.open("file1.txt", "r")
1046            local content = read_file:read("*a")
1047            print("Written content:", content)
1048            read_file:close()
1049        "#;
1050
1051        let test_session = TestSession::init(cx).await;
1052        let output = test_session.test_success(script, cx).await;
1053        assert_eq!(output, "Written content:\tThis is new content\n");
1054        assert_eq!(
1055            test_session.diff(cx),
1056            vec![(
1057                PathBuf::from("file1.txt"),
1058                vec![(
1059                    "Hello world!\n".to_string(),
1060                    "This is new content".to_string()
1061                )]
1062            )]
1063        );
1064    }
1065
1066    #[gpui::test]
1067    async fn test_multiple_writes(cx: &mut TestAppContext) {
1068        let script = r#"
1069            -- Test writing to a file multiple times
1070            local file = io.open("multiwrite.txt", "w")
1071            file:write("First line\n")
1072            file:write("Second line\n")
1073            file:write("Third line")
1074            file:close()
1075
1076            -- Read back to verify
1077            local read_file = io.open("multiwrite.txt", "r")
1078            if read_file then
1079                local content = read_file:read("*a")
1080                print("Full content:", content)
1081                read_file:close()
1082            end
1083        "#;
1084
1085        let test_session = TestSession::init(cx).await;
1086        let output = test_session.test_success(script, cx).await;
1087        assert_eq!(
1088            output,
1089            "Full content:\tFirst line\nSecond line\nThird line\n"
1090        );
1091        assert_eq!(
1092            test_session.diff(cx),
1093            vec![(
1094                PathBuf::from("multiwrite.txt"),
1095                vec![(
1096                    "".to_string(),
1097                    "First line\nSecond line\nThird line".to_string()
1098                )]
1099            )]
1100        );
1101    }
1102
1103    #[gpui::test]
1104    async fn test_multiple_writes_diff_handles(cx: &mut TestAppContext) {
1105        let script = r#"
1106            -- Write to a file
1107            local file1 = io.open("multi_open.txt", "w")
1108            file1:write("Content written by first handle\n")
1109            file1:close()
1110
1111            -- Open it again and add more content
1112            local file2 = io.open("multi_open.txt", "w")
1113            file2:write("Content written by second handle\n")
1114            file2:close()
1115
1116            -- Open it a third time and read
1117            local file3 = io.open("multi_open.txt", "r")
1118            local content = file3:read("*a")
1119            print("Final content:", content)
1120            file3:close()
1121        "#;
1122
1123        let test_session = TestSession::init(cx).await;
1124        let output = test_session.test_success(script, cx).await;
1125        assert_eq!(
1126            output,
1127            "Final content:\tContent written by second handle\n\n"
1128        );
1129        assert_eq!(
1130            test_session.diff(cx),
1131            vec![(
1132                PathBuf::from("multi_open.txt"),
1133                vec![(
1134                    "".to_string(),
1135                    "Content written by second handle\n".to_string()
1136                )]
1137            )]
1138        );
1139    }
1140
1141    #[gpui::test]
1142    async fn test_append_mode(cx: &mut TestAppContext) {
1143        let script = r#"
1144            -- Append more content
1145            file = io.open("file1.txt", "a")
1146            file:write("Appended content\n")
1147            file:close()
1148
1149            -- Add even more
1150            file = io.open("file1.txt", "a")
1151            file:write("More appended content")
1152            file:close()
1153
1154            -- Read back to verify
1155            local read_file = io.open("file1.txt", "r")
1156            local content = read_file:read("*a")
1157            print("Content after appends:", content)
1158            read_file:close()
1159        "#;
1160
1161        let test_session = TestSession::init(cx).await;
1162        let output = test_session.test_success(script, cx).await;
1163        assert_eq!(
1164            output,
1165            "Content after appends:\tHello world!\nAppended content\nMore appended content\n"
1166        );
1167        assert_eq!(
1168            test_session.diff(cx),
1169            vec![(
1170                PathBuf::from("file1.txt"),
1171                vec![(
1172                    "".to_string(),
1173                    "Appended content\nMore appended content".to_string()
1174                )]
1175            )]
1176        );
1177    }
1178
1179    #[gpui::test]
1180    async fn test_read_formats(cx: &mut TestAppContext) {
1181        let script = r#"
1182            local file = io.open("multiline.txt", "w")
1183            file:write("Line 1\nLine 2\nLine 3")
1184            file:close()
1185
1186            -- Test "*a" (all)
1187            local f = io.open("multiline.txt", "r")
1188            local all = f:read("*a")
1189            print("All:", all)
1190            f:close()
1191
1192            -- Test "*l" (line)
1193            f = io.open("multiline.txt", "r")
1194            local line1 = f:read("*l")
1195            local line2 = f:read("*l")
1196            local line3 = f:read("*l")
1197            print("Line 1:", line1)
1198            print("Line 2:", line2)
1199            print("Line 3:", line3)
1200            f:close()
1201
1202            -- Test "*L" (line with newline)
1203            f = io.open("multiline.txt", "r")
1204            local line_with_nl = f:read("*L")
1205            print("Line with newline length:", #line_with_nl)
1206            print("Last char:", string.byte(line_with_nl, #line_with_nl))
1207            f:close()
1208
1209            -- Test number of bytes
1210            f = io.open("multiline.txt", "r")
1211            local bytes5 = f:read(5)
1212            print("5 bytes:", bytes5)
1213            f:close()
1214        "#;
1215
1216        let test_session = TestSession::init(cx).await;
1217        let output = test_session.test_success(script, cx).await;
1218        println!("{}", &output);
1219        assert!(output.contains("All:\tLine 1\nLine 2\nLine 3"));
1220        assert!(output.contains("Line 1:\tLine 1"));
1221        assert!(output.contains("Line 2:\tLine 2"));
1222        assert!(output.contains("Line 3:\tLine 3"));
1223        assert!(output.contains("Line with newline length:\t7"));
1224        assert!(output.contains("Last char:\t10")); // LF
1225        assert!(output.contains("5 bytes:\tLine "));
1226        assert_eq!(
1227            test_session.diff(cx),
1228            vec![(
1229                PathBuf::from("multiline.txt"),
1230                vec![("".to_string(), "Line 1\nLine 2\nLine 3".to_string())]
1231            )]
1232        );
1233    }
1234
1235    // helpers
1236
1237    struct TestSession {
1238        session: Entity<ScriptingSession>,
1239    }
1240
1241    impl TestSession {
1242        async fn init(cx: &mut TestAppContext) -> Self {
1243            let settings_store = cx.update(SettingsStore::test);
1244            cx.set_global(settings_store);
1245            cx.update(Project::init_settings);
1246            cx.update(language::init);
1247
1248            let fs = FakeFs::new(cx.executor());
1249            fs.insert_tree(
1250                path!("/"),
1251                json!({
1252                    "file1.txt": "Hello world!\n",
1253                    "file2.txt": "Goodbye moon!\n"
1254                }),
1255            )
1256            .await;
1257
1258            let project = Project::test(fs.clone(), [Path::new(path!("/"))], cx).await;
1259            let session = cx.new(|cx| ScriptingSession::new(project, cx));
1260
1261            TestSession { session }
1262        }
1263
1264        async fn test_success(&self, source: &str, cx: &mut TestAppContext) -> String {
1265            let script_id = self.run_script(source, cx).await;
1266
1267            self.session.read_with(cx, |session, _cx| {
1268                let script = session.get(script_id);
1269                let stdout = script.stdout_snapshot();
1270
1271                if let ScriptState::Failed { error, .. } = &script.state {
1272                    panic!("Script failed:\n{}\n\n{}", error, stdout);
1273                }
1274
1275                stdout
1276            })
1277        }
1278
1279        fn diff(&self, cx: &mut TestAppContext) -> Vec<(PathBuf, Vec<(String, String)>)> {
1280            self.session.read_with(cx, |session, cx| {
1281                session
1282                    .changes_by_buffer
1283                    .iter()
1284                    .map(|(buffer, changes)| {
1285                        let snapshot = buffer.read(cx).snapshot();
1286                        let diff = changes.diff.read(cx);
1287                        let hunks = diff.hunks(&snapshot, cx);
1288                        let path = buffer.read(cx).file().unwrap().path().clone();
1289                        let diffs = hunks
1290                            .map(|hunk| {
1291                                let old_text = diff
1292                                    .base_text()
1293                                    .text_for_range(hunk.diff_base_byte_range)
1294                                    .collect::<String>();
1295                                let new_text =
1296                                    snapshot.text_for_range(hunk.range).collect::<String>();
1297                                (old_text, new_text)
1298                            })
1299                            .collect();
1300                        (path.to_path_buf(), diffs)
1301                    })
1302                    .collect()
1303            })
1304        }
1305
1306        async fn run_script(&self, source: &str, cx: &mut TestAppContext) -> ScriptId {
1307            let (script_id, task) = self
1308                .session
1309                .update(cx, |session, cx| session.run_script(source.to_string(), cx));
1310
1311            task.await;
1312
1313            script_id
1314        }
1315    }
1316}