session.rs

  1use anyhow::Result;
  2use collections::{HashMap, HashSet};
  3use futures::{
  4    channel::{mpsc, oneshot},
  5    pin_mut, SinkExt, StreamExt,
  6};
  7use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  8use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods};
  9use parking_lot::Mutex;
 10use project::{search::SearchQuery, Fs, Project};
 11use regex::Regex;
 12use std::{
 13    cell::RefCell,
 14    path::{Path, PathBuf},
 15    sync::Arc,
 16};
 17use util::{paths::PathMatcher, ResultExt};
 18
 19pub struct ScriptOutput {
 20    pub stdout: String,
 21}
 22
/// A boxed, one-shot closure sent over the session's channel; the task spawned
/// in `Session::new` calls it with a weak handle to the session and an
/// `AsyncApp`.
// NOTE(review): the name suggests these run on gpui's foreground thread — confirm.
struct ForegroundFn(Box<dyn FnOnce(WeakEntity<Session>, AsyncApp) + Send>);
 24
/// State shared by the scripts run through `Session::run_script`.
pub struct Session {
    /// Project whose first visible worktree and `Fs` are handed to scripts.
    project: Entity<Project>,
    // TODO Remove this
    /// File contents written by scripts, keyed by path; writes are recorded
    /// here instead of being written to disk.
    fs_changes: Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
    /// Sending half of the channel used to queue `ForegroundFn`s.
    foreground_fns_tx: mpsc::Sender<ForegroundFn>,
    /// Task that receives the queued `ForegroundFn`s and calls each one in turn.
    _invoke_foreground_fns: Task<()>,
}
 32
 33impl Session {
 34    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 35        let (foreground_fns_tx, mut foreground_fns_rx) = mpsc::channel(128);
 36        Session {
 37            project,
 38            fs_changes: Arc::new(Mutex::new(HashMap::default())),
 39            foreground_fns_tx,
 40            _invoke_foreground_fns: cx.spawn(|this, cx| async move {
 41                while let Some(foreground_fn) = foreground_fns_rx.next().await {
 42                    foreground_fn.0(this.clone(), cx.clone());
 43                }
 44            }),
 45        }
 46    }
 47
 48    /// Runs a Lua script in a sandboxed environment and returns the printed lines
 49    pub fn run_script(
 50        &mut self,
 51        script: String,
 52        cx: &mut Context<Self>,
 53    ) -> Task<Result<ScriptOutput>> {
 54        const SANDBOX_PREAMBLE: &str = include_str!("sandbox_preamble.lua");
 55
 56        // TODO Remove fs_changes
 57        let fs_changes = self.fs_changes.clone();
 58        // TODO Honor all worktrees instead of the first one
 59        let root_dir = self
 60            .project
 61            .read(cx)
 62            .visible_worktrees(cx)
 63            .next()
 64            .map(|worktree| worktree.read(cx).abs_path());
 65        let fs = self.project.read(cx).fs().clone();
 66        let foreground_fns_tx = self.foreground_fns_tx.clone();
 67        cx.background_spawn(async move {
 68            let lua = Lua::new();
 69            lua.set_memory_limit(2 * 1024 * 1024 * 1024)?; // 2 GB
 70            let globals = lua.globals();
 71            let stdout = Arc::new(Mutex::new(String::new()));
 72            globals.set(
 73                "sb_print",
 74                lua.create_function({
 75                    let stdout = stdout.clone();
 76                    move |_, args: MultiValue| Self::print(args, &stdout)
 77                })?,
 78            )?;
 79            globals.set(
 80                "search",
 81                lua.create_async_function({
 82                    let foreground_fns_tx = foreground_fns_tx.clone();
 83                    let fs = fs.clone();
 84                    move |lua, regex| {
 85                        Self::search(lua, foreground_fns_tx.clone(), fs.clone(), regex)
 86                    }
 87                })?,
 88            )?;
 89            globals.set(
 90                "sb_io_open",
 91                lua.create_function({
 92                    let fs_changes = fs_changes.clone();
 93                    let root_dir = root_dir.clone();
 94                    move |lua, (path_str, mode)| {
 95                        Self::io_open(&lua, &fs_changes, root_dir.as_ref(), path_str, mode)
 96                    }
 97                })?,
 98            )?;
 99            globals.set("user_script", script)?;
100
101            lua.load(SANDBOX_PREAMBLE).exec_async().await?;
102
103            // Drop Lua instance to decrement reference count.
104            drop(lua);
105
106            let stdout = Arc::try_unwrap(stdout)
107                .expect("no more references to stdout")
108                .into_inner();
109            Ok(ScriptOutput { stdout })
110        })
111    }
112
113    /// Sandboxed print() function in Lua.
114    fn print(args: MultiValue, stdout: &Mutex<String>) -> mlua::Result<()> {
115        for (index, arg) in args.into_iter().enumerate() {
116            // Lua's `print()` prints tab characters between each argument.
117            if index > 0 {
118                stdout.lock().push('\t');
119            }
120
121            // If the argument's to_string() fails, have the whole function call fail.
122            stdout.lock().push_str(&arg.to_string()?);
123        }
124        stdout.lock().push('\n');
125
126        Ok(())
127    }
128
129    /// Sandboxed io.open() function in Lua.
130    fn io_open(
131        lua: &Lua,
132        fs_changes: &Arc<Mutex<HashMap<PathBuf, Vec<u8>>>>,
133        root_dir: Option<&Arc<Path>>,
134        path_str: String,
135        mode: Option<String>,
136    ) -> mlua::Result<(Option<Table>, String)> {
137        let root_dir = root_dir
138            .ok_or_else(|| mlua::Error::runtime("cannot open file without a root directory"))?;
139
140        let mode = mode.unwrap_or_else(|| "r".to_string());
141
142        // Parse the mode string to determine read/write permissions
143        let read_perm = mode.contains('r');
144        let write_perm = mode.contains('w') || mode.contains('a') || mode.contains('+');
145        let append = mode.contains('a');
146        let truncate = mode.contains('w');
147
148        // This will be the Lua value returned from the `open` function.
149        let file = lua.create_table()?;
150
151        // Store file metadata in the file
152        file.set("__path", path_str.clone())?;
153        file.set("__mode", mode.clone())?;
154        file.set("__read_perm", read_perm)?;
155        file.set("__write_perm", write_perm)?;
156
157        // Sandbox the path; it must be within root_dir
158        let path: PathBuf = {
159            let rust_path = Path::new(&path_str);
160
161            // Get absolute path
162            if rust_path.is_absolute() {
163                // Check if path starts with root_dir prefix without resolving symlinks
164                if !rust_path.starts_with(&root_dir) {
165                    return Ok((
166                        None,
167                        format!(
168                            "Error: Absolute path {} is outside the current working directory",
169                            path_str
170                        ),
171                    ));
172                }
173                rust_path.to_path_buf()
174            } else {
175                // Make relative path absolute relative to cwd
176                root_dir.join(rust_path)
177            }
178        };
179
180        // close method
181        let close_fn = {
182            let fs_changes = fs_changes.clone();
183            lua.create_function(move |_lua, file_userdata: mlua::Table| {
184                let write_perm = file_userdata.get::<bool>("__write_perm")?;
185                let path = file_userdata.get::<String>("__path")?;
186
187                if write_perm {
188                    // When closing a writable file, record the content
189                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
190                    let content_ref = content.borrow::<FileContent>()?;
191                    let content_vec = content_ref.0.borrow();
192
193                    // Don't actually write to disk; instead, just update fs_changes.
194                    let path_buf = PathBuf::from(&path);
195                    fs_changes
196                        .lock()
197                        .insert(path_buf.clone(), content_vec.clone());
198                }
199
200                Ok(true)
201            })?
202        };
203        file.set("close", close_fn)?;
204
205        // If it's a directory, give it a custom read() and return early.
206        if path.is_dir() {
207            // TODO handle the case where we changed it in the in-memory fs
208
209            // Create a special directory handle
210            file.set("__is_directory", true)?;
211
212            // Store directory entries
213            let entries = match std::fs::read_dir(&path) {
214                Ok(entries) => {
215                    let mut entry_names = Vec::new();
216                    for entry in entries.flatten() {
217                        entry_names.push(entry.file_name().to_string_lossy().into_owned());
218                    }
219                    entry_names
220                }
221                Err(e) => return Ok((None, format!("Error reading directory: {}", e))),
222            };
223
224            // Save the list of entries
225            file.set("__dir_entries", entries)?;
226            file.set("__dir_position", 0usize)?;
227
228            // Create a directory-specific read function
229            let read_fn = lua.create_function(|_lua, file_userdata: mlua::Table| {
230                let position = file_userdata.get::<usize>("__dir_position")?;
231                let entries = file_userdata.get::<Vec<String>>("__dir_entries")?;
232
233                if position >= entries.len() {
234                    return Ok(None); // No more entries
235                }
236
237                let entry = entries[position].clone();
238                file_userdata.set("__dir_position", position + 1)?;
239
240                Ok(Some(entry))
241            })?;
242            file.set("read", read_fn)?;
243
244            // If we got this far, the directory was opened successfully
245            return Ok((Some(file), String::new()));
246        }
247
248        let fs_changes_map = fs_changes.lock();
249
250        let is_in_changes = fs_changes_map.contains_key(&path);
251        let file_exists = is_in_changes || path.exists();
252        let mut file_content = Vec::new();
253
254        if file_exists && !truncate {
255            if is_in_changes {
256                file_content = fs_changes_map.get(&path).unwrap().clone();
257            } else {
258                // Try to read existing content if file exists and we're not truncating
259                match std::fs::read(&path) {
260                    Ok(content) => file_content = content,
261                    Err(e) => return Ok((None, format!("Error reading file: {}", e))),
262                }
263            }
264        }
265
266        drop(fs_changes_map); // Unlock the fs_changes mutex.
267
268        // If in append mode, position should be at the end
269        let position = if append && file_exists {
270            file_content.len()
271        } else {
272            0
273        };
274        file.set("__position", position)?;
275        file.set(
276            "__content",
277            lua.create_userdata(FileContent(RefCell::new(file_content)))?,
278        )?;
279
280        // Create file methods
281
282        // read method
283        let read_fn = {
284            lua.create_function(
285                |_lua, (file_userdata, format): (mlua::Table, Option<mlua::Value>)| {
286                    let read_perm = file_userdata.get::<bool>("__read_perm")?;
287                    if !read_perm {
288                        return Err(mlua::Error::runtime("File not open for reading"));
289                    }
290
291                    let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
292                    let mut position = file_userdata.get::<usize>("__position")?;
293                    let content_ref = content.borrow::<FileContent>()?;
294                    let content_vec = content_ref.0.borrow();
295
296                    if position >= content_vec.len() {
297                        return Ok(None); // EOF
298                    }
299
300                    match format {
301                        Some(mlua::Value::String(s)) => {
302                            let lossy_string = s.to_string_lossy();
303                            let format_str: &str = lossy_string.as_ref();
304
305                            // Only consider the first 2 bytes, since it's common to pass e.g. "*all"  instead of "*a"
306                            match &format_str[0..2] {
307                                "*a" => {
308                                    // Read entire file from current position
309                                    let result = String::from_utf8_lossy(&content_vec[position..])
310                                        .to_string();
311                                    position = content_vec.len();
312                                    file_userdata.set("__position", position)?;
313                                    Ok(Some(result))
314                                }
315                                "*l" => {
316                                    // Read next line
317                                    let mut line = Vec::new();
318                                    let mut found_newline = false;
319
320                                    while position < content_vec.len() {
321                                        let byte = content_vec[position];
322                                        position += 1;
323
324                                        if byte == b'\n' {
325                                            found_newline = true;
326                                            break;
327                                        }
328
329                                        // Skip \r in \r\n sequence but add it if it's alone
330                                        if byte == b'\r' {
331                                            if position < content_vec.len()
332                                                && content_vec[position] == b'\n'
333                                            {
334                                                position += 1;
335                                                found_newline = true;
336                                                break;
337                                            }
338                                        }
339
340                                        line.push(byte);
341                                    }
342
343                                    file_userdata.set("__position", position)?;
344
345                                    if !found_newline
346                                        && line.is_empty()
347                                        && position >= content_vec.len()
348                                    {
349                                        return Ok(None); // EOF
350                                    }
351
352                                    let result = String::from_utf8_lossy(&line).to_string();
353                                    Ok(Some(result))
354                                }
355                                "*n" => {
356                                    // Try to parse as a number (number of bytes to read)
357                                    match format_str.parse::<usize>() {
358                                        Ok(n) => {
359                                            let end =
360                                                std::cmp::min(position + n, content_vec.len());
361                                            let bytes = &content_vec[position..end];
362                                            let result = String::from_utf8_lossy(bytes).to_string();
363                                            position = end;
364                                            file_userdata.set("__position", position)?;
365                                            Ok(Some(result))
366                                        }
367                                        Err(_) => Err(mlua::Error::runtime(format!(
368                                            "Invalid format: {}",
369                                            format_str
370                                        ))),
371                                    }
372                                }
373                                "*L" => {
374                                    // Read next line keeping the end of line
375                                    let mut line = Vec::new();
376
377                                    while position < content_vec.len() {
378                                        let byte = content_vec[position];
379                                        position += 1;
380
381                                        line.push(byte);
382
383                                        if byte == b'\n' {
384                                            break;
385                                        }
386
387                                        // If we encounter a \r, add it and check if the next is \n
388                                        if byte == b'\r'
389                                            && position < content_vec.len()
390                                            && content_vec[position] == b'\n'
391                                        {
392                                            line.push(content_vec[position]);
393                                            position += 1;
394                                            break;
395                                        }
396                                    }
397
398                                    file_userdata.set("__position", position)?;
399
400                                    if line.is_empty() && position >= content_vec.len() {
401                                        return Ok(None); // EOF
402                                    }
403
404                                    let result = String::from_utf8_lossy(&line).to_string();
405                                    Ok(Some(result))
406                                }
407                                _ => Err(mlua::Error::runtime(format!(
408                                    "Unsupported format: {}",
409                                    format_str
410                                ))),
411                            }
412                        }
413                        Some(mlua::Value::Number(n)) => {
414                            // Read n bytes
415                            let n = n as usize;
416                            let end = std::cmp::min(position + n, content_vec.len());
417                            let bytes = &content_vec[position..end];
418                            let result = String::from_utf8_lossy(bytes).to_string();
419                            position = end;
420                            file_userdata.set("__position", position)?;
421                            Ok(Some(result))
422                        }
423                        Some(_) => Err(mlua::Error::runtime("Invalid format")),
424                        None => {
425                            // Default is to read a line
426                            let mut line = Vec::new();
427                            let mut found_newline = false;
428
429                            while position < content_vec.len() {
430                                let byte = content_vec[position];
431                                position += 1;
432
433                                if byte == b'\n' {
434                                    found_newline = true;
435                                    break;
436                                }
437
438                                // Handle \r\n
439                                if byte == b'\r' {
440                                    if position < content_vec.len()
441                                        && content_vec[position] == b'\n'
442                                    {
443                                        position += 1;
444                                        found_newline = true;
445                                        break;
446                                    }
447                                }
448
449                                line.push(byte);
450                            }
451
452                            file_userdata.set("__position", position)?;
453
454                            if !found_newline && line.is_empty() && position >= content_vec.len() {
455                                return Ok(None); // EOF
456                            }
457
458                            let result = String::from_utf8_lossy(&line).to_string();
459                            Ok(Some(result))
460                        }
461                    }
462                },
463            )?
464        };
465        file.set("read", read_fn)?;
466
467        // write method
468        let write_fn = {
469            let fs_changes = fs_changes.clone();
470
471            lua.create_function(move |_lua, (file_userdata, text): (mlua::Table, String)| {
472                let write_perm = file_userdata.get::<bool>("__write_perm")?;
473                if !write_perm {
474                    return Err(mlua::Error::runtime("File not open for writing"));
475                }
476
477                let content = file_userdata.get::<mlua::AnyUserData>("__content")?;
478                let position = file_userdata.get::<usize>("__position")?;
479                let content_ref = content.borrow::<FileContent>()?;
480                let mut content_vec = content_ref.0.borrow_mut();
481
482                let bytes = text.as_bytes();
483
484                // Ensure the vector has enough capacity
485                if position + bytes.len() > content_vec.len() {
486                    content_vec.resize(position + bytes.len(), 0);
487                }
488
489                // Write the bytes
490                for (i, &byte) in bytes.iter().enumerate() {
491                    content_vec[position + i] = byte;
492                }
493
494                // Update position
495                let new_position = position + bytes.len();
496                file_userdata.set("__position", new_position)?;
497
498                // Update fs_changes
499                let path = file_userdata.get::<String>("__path")?;
500                let path_buf = PathBuf::from(path);
501                fs_changes.lock().insert(path_buf, content_vec.clone());
502
503                Ok(true)
504            })?
505        };
506        file.set("write", write_fn)?;
507
508        // If we got this far, the file was opened successfully
509        Ok((Some(file), String::new()))
510    }
511
512    async fn search(
513        lua: Lua,
514        mut foreground_tx: mpsc::Sender<ForegroundFn>,
515        fs: Arc<dyn Fs>,
516        regex: String,
517    ) -> mlua::Result<Table> {
518        // TODO: Allow specification of these options.
519        let search_query = SearchQuery::regex(
520            &regex,
521            false,
522            false,
523            false,
524            PathMatcher::default(),
525            PathMatcher::default(),
526            None,
527        );
528        let search_query = match search_query {
529            Ok(query) => query,
530            Err(e) => return Err(mlua::Error::runtime(format!("Invalid search query: {}", e))),
531        };
532
533        // TODO: Should use `search_query.regex`. The tool description should also be updated,
534        // as it specifies standard regex.
535        let search_regex = match Regex::new(&regex) {
536            Ok(re) => re,
537            Err(e) => return Err(mlua::Error::runtime(format!("Invalid regex: {}", e))),
538        };
539
540        let mut abs_paths_rx =
541            Self::find_search_candidates(search_query, &mut foreground_tx).await?;
542
543        let mut search_results: Vec<Table> = Vec::new();
544        while let Some(path) = abs_paths_rx.next().await {
545            // Skip files larger than 1MB
546            if let Ok(Some(metadata)) = fs.metadata(&path).await {
547                if metadata.len > 1_000_000 {
548                    continue;
549                }
550            }
551
552            // Attempt to read the file as text
553            if let Ok(content) = fs.load(&path).await {
554                let mut matches = Vec::new();
555
556                // Find all regex matches in the content
557                for capture in search_regex.find_iter(&content) {
558                    matches.push(capture.as_str().to_string());
559                }
560
561                // If we found matches, create a result entry
562                if !matches.is_empty() {
563                    let result_entry = lua.create_table()?;
564                    result_entry.set("path", path.to_string_lossy().to_string())?;
565
566                    let matches_table = lua.create_table()?;
567                    for (ix, m) in matches.iter().enumerate() {
568                        matches_table.set(ix + 1, m.clone())?;
569                    }
570                    result_entry.set("matches", matches_table)?;
571
572                    search_results.push(result_entry);
573                }
574            }
575        }
576
577        // Create a table to hold our results
578        let results_table = lua.create_table()?;
579        for (ix, entry) in search_results.into_iter().enumerate() {
580            results_table.set(ix + 1, entry)?;
581        }
582
583        Ok(results_table)
584    }
585
586    async fn find_search_candidates(
587        search_query: SearchQuery,
588        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
589    ) -> mlua::Result<mpsc::UnboundedReceiver<PathBuf>> {
590        Self::run_foreground_fn(
591            "finding search file candidates",
592            foreground_tx,
593            Box::new(move |session, mut cx| {
594                session.update(&mut cx, |session, cx| {
595                    session.project.update(cx, |project, cx| {
596                        project.worktree_store().update(cx, |worktree_store, cx| {
597                            // TODO: Better limit? For now this is the same as
598                            // MAX_SEARCH_RESULT_FILES.
599                            let limit = 5000;
600                            // TODO: Providing non-empty open_entries can make this a bit more
601                            // efficient as it can skip checking that these paths are textual.
602                            let open_entries = HashSet::default();
603                            let candidates = worktree_store.find_search_candidates(
604                                search_query,
605                                limit,
606                                open_entries,
607                                project.fs().clone(),
608                                cx,
609                            );
610                            let (abs_paths_tx, abs_paths_rx) = mpsc::unbounded();
611                            cx.spawn(|worktree_store, cx| async move {
612                                pin_mut!(candidates);
613
614                                while let Some(project_path) = candidates.next().await {
615                                    worktree_store.read_with(&cx, |worktree_store, cx| {
616                                        if let Some(worktree) = worktree_store
617                                            .worktree_for_id(project_path.worktree_id, cx)
618                                        {
619                                            if let Some(abs_path) = worktree
620                                                .read(cx)
621                                                .absolutize(&project_path.path)
622                                                .log_err()
623                                            {
624                                                abs_paths_tx.unbounded_send(abs_path)?;
625                                            }
626                                        }
627                                        anyhow::Ok(())
628                                    })??;
629                                }
630                                anyhow::Ok(())
631                            })
632                            .detach();
633                            abs_paths_rx
634                        })
635                    })
636                })
637            }),
638        )
639        .await
640    }
641
642    async fn run_foreground_fn<R: Send + 'static>(
643        description: &str,
644        foreground_tx: &mut mpsc::Sender<ForegroundFn>,
645        function: Box<dyn FnOnce(WeakEntity<Self>, AsyncApp) -> anyhow::Result<R> + Send>,
646    ) -> mlua::Result<R> {
647        let (response_tx, response_rx) = oneshot::channel();
648        let send_result = foreground_tx
649            .send(ForegroundFn(Box::new(move |this, cx| {
650                response_tx.send(function(this, cx)).ok();
651            })))
652            .await;
653        match send_result {
654            Ok(()) => (),
655            Err(err) => {
656                return Err(mlua::Error::runtime(format!(
657                    "Internal error while enqueuing work for {description}: {err}"
658                )))
659            }
660        }
661        match response_rx.await {
662            Ok(Ok(result)) => Ok(result),
663            Ok(Err(err)) => Err(mlua::Error::runtime(format!(
664                "Error while {description}: {err}"
665            ))),
666            Err(oneshot::Canceled) => Err(mlua::Error::runtime(format!(
667                "Internal error: response oneshot was canceled while {description}."
668            ))),
669        }
670    }
671}
672
/// In-memory content of a file opened by `Session::io_open`, stored in the
/// file table's `__content` field as Lua userdata. The `RefCell` lets the
/// `write` callback mutate the bytes through a shared userdata borrow.
struct FileContent(RefCell<Vec<u8>>);

/// Registers `FileContent` as Lua userdata so it can live inside a Lua table;
/// it exposes nothing to scripts.
impl UserData for FileContent {
    fn add_methods<M: UserDataMethods<Self>>(_methods: &mut M) {
        // FileContent doesn't have any methods so far.
    }
}
680
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;

    use super::*;

    /// `print` joins its arguments with tabs and ends each call with a newline.
    #[gpui::test]
    async fn test_print(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // No worktrees are needed: the script only prints.
        let project = Project::test(fs, [], cx).await;
        let session = cx.new(|cx| Session::new(project, cx));
        let script = r#"
            print("Hello", "world!")
            print("Goodbye", "moon!")
        "#;
        let output = session
            .update(cx, |session, cx| session.run_script(script.to_string(), cx))
            .await
            .unwrap();
        assert_eq!(output.stdout, "Hello\tworld!\nGoodbye\tmoon!\n");
    }

    /// `search` reports only the files that contain a match, with the matched
    /// text.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Only file1.txt contains "world".
        fs.insert_tree(
            "/",
            json!({
                "file1.txt": "Hello world!",
                "file2.txt": "Goodbye moon!"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new("/")], cx).await;
        let session = cx.new(|cx| Session::new(project, cx));
        let script = r#"
            local results = search("world")
            for i, result in ipairs(results) do
                print("File: " .. result.path)
                print("Matches:")
                for j, match in ipairs(result.matches) do
                    print("  " .. match)
                end
            end
        "#;
        let output = session
            .update(cx, |session, cx| session.run_script(script.to_string(), cx))
            .await
            .unwrap();
        assert_eq!(output.stdout, "File: /file1.txt\nMatches:\n  world\n");
    }

    /// Installs a test `SettingsStore` global and initializes the project
    /// settings; both tests call this first.
    fn init_test(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(Project::init_settings);
    }
}