util.rs

   1pub mod arc_cow;
   2pub mod command;
   3pub mod fs;
   4pub mod markdown;
   5pub mod paths;
   6pub mod serde;
   7pub mod size;
   8#[cfg(any(test, feature = "test-support"))]
   9pub mod test;
  10pub mod time;
  11
  12use anyhow::Result;
  13use futures::Future;
  14use itertools::Either;
  15use regex::Regex;
  16use std::sync::{LazyLock, OnceLock};
  17use std::{
  18    borrow::Cow,
  19    cmp::{self, Ordering},
  20    env,
  21    ops::{AddAssign, Range, RangeInclusive},
  22    panic::Location,
  23    pin::Pin,
  24    task::{Context, Poll},
  25    time::Instant,
  26};
  27use unicase::UniCase;
  28
  29#[cfg(unix)]
  30use anyhow::Context as _;
  31
  32pub use take_until::*;
  33#[cfg(any(test, feature = "test-support"))]
  34pub use util_macros::{line_endings, separator, uri};
  35
/// Panics when `debug_assertions` are enabled; otherwise logs the message as an error.
///
/// Takes the same format arguments as `panic!`. With `debug_assertions` disabled, the
/// formatted message is emitted through `log::error!` together with the result of
/// `Backtrace::capture()`, and execution continues.
#[macro_export]
macro_rules! debug_panic {
    ( $($fmt_arg:tt)* ) => {
        if cfg!(debug_assertions) {
            panic!( $($fmt_arg)* );
        } else {
            // Captured at the expansion site so the log entry points at the caller.
            let backtrace = std::backtrace::Backtrace::capture();
            log::error!("{}\n{:?}", format_args!($($fmt_arg)*), backtrace);
        }
    };
}
  47
/// A macro that prepends "C:" to a path literal on Windows and replaces every `/`
/// separator with `\`.
/// On non-Windows platforms it returns the path literal as is.
///
/// This is the Windows variant; it is only compiled for tests or with the
/// `test-support` feature enabled.
///
/// # Examples
/// ```rust
/// use util::path;
///
/// let path = path!("/Users/user/file.txt");
/// #[cfg(target_os = "windows")]
/// assert_eq!(path, "C:\\Users\\user\\file.txt");
/// #[cfg(not(target_os = "windows"))]
/// assert_eq!(path, "/Users/user/file.txt");
/// ```
#[cfg(all(any(test, feature = "test-support"), target_os = "windows"))]
#[macro_export]
macro_rules! path {
    ($path:literal) => {
        concat!("C:", util::separator!($path))
    };
}
  69
/// A macro that prepends "C:" to a path literal on Windows and replaces every `/`
/// separator with `\`.
/// On non-Windows platforms it returns the path literal as is.
///
/// This is the non-Windows variant, which expands to the literal unchanged; it is only
/// compiled for tests or with the `test-support` feature enabled.
///
/// # Examples
/// ```rust
/// use util::path;
///
/// let path = path!("/Users/user/file.txt");
/// #[cfg(target_os = "windows")]
/// assert_eq!(path, "C:\\Users\\user\\file.txt");
/// #[cfg(not(target_os = "windows"))]
/// assert_eq!(path, "/Users/user/file.txt");
/// ```
#[cfg(all(any(test, feature = "test-support"), not(target_os = "windows")))]
#[macro_export]
macro_rules! path {
    ($path:literal) => {
        $path
    };
}
  91
  92pub fn truncate(s: &str, max_chars: usize) -> &str {
  93    match s.char_indices().nth(max_chars) {
  94        None => s,
  95        Some((idx, _)) => &s[..idx],
  96    }
  97}
  98
  99/// Removes characters from the end of the string if its length is greater than `max_chars` and
 100/// appends "..." to the string. Returns string unchanged if its length is smaller than max_chars.
 101pub fn truncate_and_trailoff(s: &str, max_chars: usize) -> String {
 102    debug_assert!(max_chars >= 5);
 103
 104    // If the string's byte length is <= max_chars, walking the string can be skipped since the
 105    // number of chars is <= the number of bytes.
 106    if s.len() <= max_chars {
 107        return s.to_string();
 108    }
 109    let truncation_ix = s.char_indices().map(|(i, _)| i).nth(max_chars);
 110    match truncation_ix {
 111        Some(index) => s[..index].to_string() + "",
 112        _ => s.to_string(),
 113    }
 114}
 115
 116/// Removes characters from the front of the string if its length is greater than `max_chars` and
 117/// prepends the string with "...". Returns string unchanged if its length is smaller than max_chars.
 118pub fn truncate_and_remove_front(s: &str, max_chars: usize) -> String {
 119    debug_assert!(max_chars >= 5);
 120
 121    // If the string's byte length is <= max_chars, walking the string can be skipped since the
 122    // number of chars is <= the number of bytes.
 123    if s.len() <= max_chars {
 124        return s.to_string();
 125    }
 126    let suffix_char_length = max_chars.saturating_sub(1);
 127    let truncation_ix = s
 128        .char_indices()
 129        .map(|(i, _)| i)
 130        .nth_back(suffix_char_length);
 131    match truncation_ix {
 132        Some(index) if index > 0 => "".to_string() + &s[index..],
 133        _ => s.to_string(),
 134    }
 135}
 136
 137/// Takes only `max_lines` from the string and, if there were more than `max_lines-1`, appends a
 138/// a newline and "..." to the string, so that `max_lines` are returned.
 139/// Returns string unchanged if its length is smaller than max_lines.
 140pub fn truncate_lines_and_trailoff(s: &str, max_lines: usize) -> String {
 141    let mut lines = s.lines().take(max_lines).collect::<Vec<_>>();
 142    if lines.len() > max_lines - 1 {
 143        lines.pop();
 144        lines.join("\n") + "\n"
 145    } else {
 146        lines.join("\n")
 147    }
 148}
 149
 150/// Truncates the string at a character boundary, such that the result is less than `max_bytes` in
 151/// length.
 152pub fn truncate_to_byte_limit(s: &str, max_bytes: usize) -> &str {
 153    if s.len() < max_bytes {
 154        return s;
 155    }
 156
 157    for i in (0..max_bytes).rev() {
 158        if s.is_char_boundary(i) {
 159            return &s[..i];
 160        }
 161    }
 162
 163    ""
 164}
 165
 166/// Takes a prefix of complete lines which fit within the byte limit. If the first line is longer
 167/// than the limit, truncates at a character boundary.
 168pub fn truncate_lines_to_byte_limit(s: &str, max_bytes: usize) -> &str {
 169    if s.len() < max_bytes {
 170        return s;
 171    }
 172
 173    for i in (0..max_bytes).rev() {
 174        if s.is_char_boundary(i) {
 175            if s.as_bytes()[i] == b'\n' {
 176                // Since the i-th character is \n, valid to slice at i + 1.
 177                return &s[..i + 1];
 178            }
 179        }
 180    }
 181
 182    truncate_to_byte_limit(s, max_bytes)
 183}
 184
// Exercises `truncate_lines_to_byte_limit` with limits that fall past the end of the text,
// exactly on a line end, in the middle of a line, and before the first newline.
#[test]
fn test_truncate_lines_to_byte_limit() {
    let text = "Line 1\nLine 2\nLine 3\nLine 4";

    // Limit that includes all lines
    assert_eq!(truncate_lines_to_byte_limit(text, 100), text);

    // Exactly the first line
    assert_eq!(truncate_lines_to_byte_limit(text, 7), "Line 1\n");

    // Limit between lines
    assert_eq!(truncate_lines_to_byte_limit(text, 13), "Line 1\n");
    assert_eq!(truncate_lines_to_byte_limit(text, 20), "Line 1\nLine 2\n");

    // Limit before first newline: falls back to truncating at a character boundary
    assert_eq!(truncate_lines_to_byte_limit(text, 6), "Line ");

    // Test with non-ASCII characters
    let text_utf8 = "Line 1\nLíne 2\nLine 3";
    assert_eq!(
        truncate_lines_to_byte_limit(text_utf8, 15),
        "Line 1\nLíne 2\n"
    );
}
 209
 210pub fn post_inc<T: From<u8> + AddAssign<T> + Copy>(value: &mut T) -> T {
 211    let prev = *value;
 212    *value += T::from(1);
 213    prev
 214}
 215
 216/// Extend a sorted vector with a sorted sequence of items, maintaining the vector's sort order and
 217/// enforcing a maximum length. This also de-duplicates items. Sort the items according to the given callback. Before calling this,
 218/// both `vec` and `new_items` should already be sorted according to the `cmp` comparator.
 219pub fn extend_sorted<T, I, F>(vec: &mut Vec<T>, new_items: I, limit: usize, mut cmp: F)
 220where
 221    I: IntoIterator<Item = T>,
 222    F: FnMut(&T, &T) -> Ordering,
 223{
 224    let mut start_index = 0;
 225    for new_item in new_items {
 226        if let Err(i) = vec[start_index..].binary_search_by(|m| cmp(m, &new_item)) {
 227            let index = start_index + i;
 228            if vec.len() < limit {
 229                vec.insert(index, new_item);
 230            } else if index < vec.len() {
 231                vec.pop();
 232                vec.insert(index, new_item);
 233            }
 234            start_index = index;
 235        }
 236    }
 237}
 238
 239pub fn truncate_to_bottom_n_sorted_by<T, F>(items: &mut Vec<T>, limit: usize, compare: &F)
 240where
 241    F: Fn(&T, &T) -> Ordering,
 242{
 243    if limit == 0 {
 244        items.truncate(0);
 245    }
 246    if items.len() <= limit {
 247        items.sort_by(compare);
 248        return;
 249    }
 250    // When limit is near to items.len() it may be more efficient to sort the whole list and
 251    // truncate, rather than always doing selection first as is done below. It's hard to analyze
 252    // where the threshold for this should be since the quickselect style algorithm used by
 253    // `select_nth_unstable_by` makes the prefix partially sorted, and so its work is not wasted -
 254    // the expected number of comparisons needed by `sort_by` is less than it is for some arbitrary
 255    // unsorted input.
 256    items.select_nth_unstable_by(limit, compare);
 257    items.truncate(limit);
 258    items.sort_by(compare);
 259}
 260
 261#[cfg(unix)]
 262fn load_shell_from_passwd() -> Result<()> {
 263    let buflen = match unsafe { libc::sysconf(libc::_SC_GETPW_R_SIZE_MAX) } {
 264        n if n < 0 => 1024,
 265        n => n as usize,
 266    };
 267    let mut buffer = Vec::with_capacity(buflen);
 268
 269    let mut pwd: std::mem::MaybeUninit<libc::passwd> = std::mem::MaybeUninit::uninit();
 270    let mut result: *mut libc::passwd = std::ptr::null_mut();
 271
 272    let uid = unsafe { libc::getuid() };
 273    let status = unsafe {
 274        libc::getpwuid_r(
 275            uid,
 276            pwd.as_mut_ptr(),
 277            buffer.as_mut_ptr() as *mut libc::c_char,
 278            buflen,
 279            &mut result,
 280        )
 281    };
 282    let entry = unsafe { pwd.assume_init() };
 283
 284    anyhow::ensure!(
 285        status == 0,
 286        "call to getpwuid_r failed. uid: {}, status: {}",
 287        uid,
 288        status
 289    );
 290    anyhow::ensure!(!result.is_null(), "passwd entry for uid {} not found", uid);
 291    anyhow::ensure!(
 292        entry.pw_uid == uid,
 293        "passwd entry has different uid ({}) than getuid ({}) returned",
 294        entry.pw_uid,
 295        uid,
 296    );
 297
 298    let shell = unsafe { std::ffi::CStr::from_ptr(entry.pw_shell).to_str().unwrap() };
 299    if env::var("SHELL").map_or(true, |shell_env| shell_env != shell) {
 300        log::info!(
 301            "updating SHELL environment variable to value from passwd entry: {:?}",
 302            shell,
 303        );
 304        unsafe { env::set_var("SHELL", shell) };
 305    }
 306
 307    Ok(())
 308}
 309
/// Runs the user's shell as an interactive login shell, captures the output of
/// `/usr/bin/env`, and copies every variable it reports into this process's environment.
///
/// `SHELL` is first refreshed from the passwd entry; a failure there is only logged.
///
/// # Errors
///
/// Returns an error when `SHELL` is not set, when the shell cannot be spawned, or when it
/// exits with a non-success status.
#[cfg(unix)]
pub fn load_login_shell_environment() -> Result<()> {
    load_shell_from_passwd().log_err();

    // Printed right before the `env` output so that anything the shell's startup files
    // write to stdout can be skipped when parsing.
    let marker = "ZED_LOGIN_SHELL_START";
    let shell = env::var("SHELL").context(
        "SHELL environment variable is not assigned so we can't source login environment variables",
    )?;

    // If possible, we want to `cd` in the user's `$HOME` to trigger programs
    // such as direnv, asdf, mise, ... to adjust the PATH. These tools often hook
    // into shell's `cd` command (and hooks) to manipulate env.
    // We do this so that we get the env a user would have when spawning a shell
    // in home directory.
    // NOTE(review): `home` is interpolated between single quotes without escaping, so a
    // HOME containing `'` would produce a malformed command — confirm whether that matters.
    let shell_cmd_prefix = std::env::var_os("HOME")
        .and_then(|home| home.into_string().ok())
        .map(|home| format!("cd '{home}';"));

    let shell_cmd = format!(
        "{}printf '%s' {marker}; /usr/bin/env;",
        shell_cmd_prefix.as_deref().unwrap_or("")
    );

    // `-l -i -c`: login + interactive shell executing the single command string.
    let output = set_pre_exec_to_start_new_session(
        std::process::Command::new(&shell).args(["-l", "-i", "-c", &shell_cmd]),
    )
    .output()
    .context("failed to spawn login shell to source login environment variables")?;
    anyhow::ensure!(output.status.success(), "login shell exited with error");

    let stdout = String::from_utf8_lossy(&output.stdout);

    // Without the marker in stdout nothing is applied and the call still succeeds.
    if let Some(env_output_start) = stdout.find(marker) {
        let env_output = &stdout[env_output_start + marker.len()..];

        parse_env_output(env_output, |key, value| unsafe { env::set_var(key, value) });

        log::info!(
            "set environment variables from shell:{}, path:{}",
            shell,
            env::var("PATH").unwrap_or_default(),
        );
    }

    Ok(())
}
 356
/// Configures the process to start a new session, to prevent interactive shells from taking control
/// of the terminal.
///
/// On non-Windows targets this registers a `pre_exec` hook that calls `setsid` in the child
/// before it executes the program; on Windows the command is returned untouched.
///
/// Returns the same `command` reference so the call can be chained.
///
/// For more details: https://registerspill.thorstenball.com/p/how-to-lose-control-of-your-shell
pub fn set_pre_exec_to_start_new_session(
    command: &mut std::process::Command,
) -> &mut std::process::Command {
    // SAFETY: code in pre_exec should be signal safe; the hook below only calls `setsid`.
    // https://man7.org/linux/man-pages/man7/signal-safety.7.html
    #[cfg(not(target_os = "windows"))]
    unsafe {
        use std::os::unix::process::CommandExt;
        command.pre_exec(|| {
            // The return value of `setsid` is intentionally not inspected here.
            libc::setsid();
            Ok(())
        });
    };
    command
}
 376
 377/// Parse the result of calling `usr/bin/env` with no arguments
 378pub fn parse_env_output(env: &str, mut f: impl FnMut(String, String)) {
 379    let mut current_key: Option<String> = None;
 380    let mut current_value: Option<String> = None;
 381
 382    for line in env.split_terminator('\n') {
 383        if let Some(separator_index) = line.find('=') {
 384            if !line[..separator_index].is_empty() {
 385                if let Some((key, value)) = Option::zip(current_key.take(), current_value.take()) {
 386                    f(key, value)
 387                }
 388                current_key = Some(line[..separator_index].to_string());
 389                current_value = Some(line[separator_index + 1..].to_string());
 390                continue;
 391            };
 392        }
 393        if let Some(value) = current_value.as_mut() {
 394            value.push('\n');
 395            value.push_str(line);
 396        }
 397    }
 398    if let Some((key, value)) = Option::zip(current_key.take(), current_value.take()) {
 399        f(key, value)
 400    }
 401}
 402
 403pub fn merge_json_value_into(source: serde_json::Value, target: &mut serde_json::Value) {
 404    use serde_json::Value;
 405
 406    match (source, target) {
 407        (Value::Object(source), Value::Object(target)) => {
 408            for (key, value) in source {
 409                if let Some(target) = target.get_mut(&key) {
 410                    merge_json_value_into(value, target);
 411                } else {
 412                    target.insert(key, value);
 413                }
 414            }
 415        }
 416
 417        (Value::Array(source), Value::Array(target)) => {
 418            for value in source {
 419                target.push(value);
 420            }
 421        }
 422
 423        (source, target) => *target = source,
 424    }
 425}
 426
 427pub fn merge_non_null_json_value_into(source: serde_json::Value, target: &mut serde_json::Value) {
 428    use serde_json::Value;
 429    if let Value::Object(source_object) = source {
 430        let target_object = if let Value::Object(target) = target {
 431            target
 432        } else {
 433            *target = Value::Object(Default::default());
 434            target.as_object_mut().unwrap()
 435        };
 436        for (key, value) in source_object {
 437            if let Some(target) = target_object.get_mut(&key) {
 438                merge_non_null_json_value_into(value, target);
 439            } else if !value.is_null() {
 440                target_object.insert(key, value);
 441            }
 442        }
 443    } else if !source.is_null() {
 444        *target = source
 445    }
 446}
 447
 448pub fn measure<R>(label: &str, f: impl FnOnce() -> R) -> R {
 449    static ZED_MEASUREMENTS: OnceLock<bool> = OnceLock::new();
 450    let zed_measurements = ZED_MEASUREMENTS.get_or_init(|| {
 451        env::var("ZED_MEASUREMENTS")
 452            .map(|measurements| measurements == "1" || measurements == "true")
 453            .unwrap_or(false)
 454    });
 455
 456    if *zed_measurements {
 457        let start = Instant::now();
 458        let result = f();
 459        let elapsed = start.elapsed();
 460        eprintln!("{}: {:?}", label, elapsed);
 461        result
 462    } else {
 463        f()
 464    }
 465}
 466
/// Iterates over the indices of `range` grown by `additional_before` at the start and
/// `additional_after` at the end, wrapping around within `0..wrap_length`.
///
/// When the expansion wraps far enough to cover everything, `0..wrap_length` is yielded.
/// When it wraps on one side only, the indices from the beginning of the space are yielded
/// first, followed by the indices up to `wrap_length`.
pub fn iterate_expanded_and_wrapped_usize_range(
    range: Range<usize>,
    additional_before: usize,
    additional_after: usize,
    wrap_length: usize,
) -> impl Iterator<Item = usize> {
    // The expanded start would fall below zero.
    let start_wraps = range.start < additional_before;
    // The expanded end would fall beyond `wrap_length`.
    let end_wraps = wrap_length < range.end + additional_after;
    if start_wraps && end_wraps {
        Either::Left(0..wrap_length)
    } else if start_wraps {
        // Position of the expanded start after wrapping around to the end of the space.
        let wrapped_start = (range.start + wrap_length).saturating_sub(additional_before);
        if wrapped_start <= range.end {
            // The wrapped part reaches back into the range itself: everything is covered.
            Either::Left(0..wrap_length)
        } else {
            Either::Right((0..range.end + additional_after).chain(wrapped_start..wrap_length))
        }
    } else if end_wraps {
        // Position of the expanded end after wrapping around to the start of the space.
        let wrapped_end = range.end + additional_after - wrap_length;
        if range.start <= wrapped_end {
            // The wrapped part reaches forward into the range itself: everything is covered.
            Either::Left(0..wrap_length)
        } else {
            Either::Right((0..wrapped_end).chain(range.start - additional_before..wrap_length))
        }
    } else {
        // No wrapping needed; neither subtraction nor addition leaves `0..=wrap_length`.
        Either::Left((range.start - additional_before)..(range.end + additional_after))
    }
}
 495
/// Returns the path of the PowerShell executable to use on Windows.
///
/// Several install locations of `pwsh.exe` are probed in a fixed order (see the
/// `SYSTEM_SHELL` initializer below); when none is found, `"powershell.exe"` is returned.
/// The lookup runs once and the result is cached for the lifetime of the process.
#[cfg(target_os = "windows")]
pub fn get_windows_system_shell() -> String {
    use std::path::PathBuf;

    /// Searches `<ProgramFiles>\PowerShell\<version>\pwsh.exe` and returns the executable
    /// from the directory with the highest numeric version.
    ///
    /// `find_alternate` selects the other program-files directory for the current pointer
    /// width; `find_preview` restricts the search to `<version>-preview` directories.
    fn find_pwsh_in_programfiles(find_alternate: bool, find_preview: bool) -> Option<PathBuf> {
        #[cfg(target_pointer_width = "64")]
        let env_var = if find_alternate {
            "ProgramFiles(x86)"
        } else {
            "ProgramFiles"
        };

        #[cfg(target_pointer_width = "32")]
        let env_var = if find_alternate {
            "ProgramW6432"
        } else {
            "ProgramFiles"
        };

        let install_base_dir = PathBuf::from(std::env::var_os(env_var)?).join("PowerShell");
        install_base_dir
            .read_dir()
            .ok()?
            .filter_map(Result::ok)
            .filter(|entry| matches!(entry.file_type(), Ok(ft) if ft.is_dir()))
            .filter_map(|entry| {
                let dir_name = entry.file_name();
                let dir_name = dir_name.to_string_lossy();

                // Directory names are either `<number>` or `<number>-preview`.
                let version = if find_preview {
                    let dash_index = dir_name.find('-')?;
                    if &dir_name[dash_index + 1..] != "preview" {
                        return None;
                    };
                    dir_name[..dash_index].parse::<u32>().ok()?
                } else {
                    dir_name.parse::<u32>().ok()?
                };

                let exe_path = entry.path().join("pwsh.exe");
                if exe_path.exists() {
                    Some((version, exe_path))
                } else {
                    None
                }
            })
            .max_by_key(|(version, _)| *version)
            .map(|(_, path)| path)
    }

    /// Searches `%LOCALAPPDATA%\Microsoft\WindowsApps` for a PowerShell (or PowerShell
    /// Preview) package directory containing `pwsh.exe` and returns the first match.
    fn find_pwsh_in_msix(find_preview: bool) -> Option<PathBuf> {
        let msix_app_dir =
            PathBuf::from(std::env::var_os("LOCALAPPDATA")?).join("Microsoft\\WindowsApps");
        if !msix_app_dir.exists() {
            return None;
        }

        let prefix = if find_preview {
            "Microsoft.PowerShellPreview_"
        } else {
            "Microsoft.PowerShell_"
        };
        msix_app_dir
            .read_dir()
            .ok()?
            .filter_map(|entry| {
                let entry = entry.ok()?;
                if !matches!(entry.file_type(), Ok(ft) if ft.is_dir()) {
                    return None;
                }

                if !entry.file_name().to_string_lossy().starts_with(prefix) {
                    return None;
                }

                let exe_path = entry.path().join("pwsh.exe");
                exe_path.exists().then_some(exe_path)
            })
            .next()
    }

    /// Returns `%USERPROFILE%\scoop\shims\pwsh.exe` when that file exists.
    fn find_pwsh_in_scoop() -> Option<PathBuf> {
        let pwsh_exe =
            PathBuf::from(std::env::var_os("USERPROFILE")?).join("scoop\\shims\\pwsh.exe");
        pwsh_exe.exists().then_some(pwsh_exe)
    }

    // Probe order: stable installs (program files, alternate program files, MSIX), then
    // preview installs, then scoop; the first hit wins.
    static SYSTEM_SHELL: LazyLock<String> = LazyLock::new(|| {
        find_pwsh_in_programfiles(false, false)
            .or_else(|| find_pwsh_in_programfiles(true, false))
            .or_else(|| find_pwsh_in_msix(false))
            .or_else(|| find_pwsh_in_programfiles(false, true))
            .or_else(|| find_pwsh_in_msix(true))
            .or_else(|| find_pwsh_in_programfiles(true, true))
            .or_else(find_pwsh_in_scoop)
            .map(|p| p.to_string_lossy().to_string())
            .unwrap_or("powershell.exe".to_string())
    });

    (*SYSTEM_SHELL).clone()
}
 597
/// Convenience methods for handling a result by logging its error instead of propagating it.
pub trait ResultExt<E> {
    /// The success type of the underlying result.
    type Ok;

    /// Converts the result into an `Option`, logging the error at `Error` level.
    fn log_err(self) -> Option<Self::Ok>;
    /// Assert that this result should never be an error in development or tests.
    fn debug_assert_ok(self, reason: &str) -> Self;
    /// Converts the result into an `Option`, logging the error at `Warn` level.
    fn warn_on_err(self) -> Option<Self::Ok>;
    /// Converts the result into an `Option`, logging the error at the given `level`.
    fn log_with_level(self, level: log::Level) -> Option<Self::Ok>;
    /// Converts the error type into `anyhow::Error`.
    fn anyhow(self) -> anyhow::Result<Self::Ok>
    where
        E: Into<anyhow::Error>;
}
 610
/// Implements [`ResultExt`] for every `Result` whose error type is `Debug`.
///
/// The logging methods are marked `#[track_caller]` so that the record produced by
/// `log_with_level` carries the location of the outermost caller.
impl<T, E> ResultExt<E> for Result<T, E>
where
    E: std::fmt::Debug,
{
    type Ok = T;

    /// Logs the error at `Error` level and returns `None`; returns `Some(value)` on success.
    #[track_caller]
    fn log_err(self) -> Option<T> {
        self.log_with_level(log::Level::Error)
    }

    /// Invokes `debug_panic!` with `reason` and the error when the result is an `Err`, then
    /// hands the result back unchanged.
    #[track_caller]
    fn debug_assert_ok(self, reason: &str) -> Self {
        if let Err(error) = &self {
            debug_panic!("{reason} - {error:?}");
        }
        self
    }

    /// Logs the error at `Warn` level and returns `None`; returns `Some(value)` on success.
    #[track_caller]
    fn warn_on_err(self) -> Option<T> {
        self.log_with_level(log::Level::Warn)
    }

    /// Logs the error at `level`, attributed to the caller's location, and returns `None`;
    /// returns `Some(value)` on success.
    #[track_caller]
    fn log_with_level(self, level: log::Level) -> Option<T> {
        match self {
            Ok(value) => Some(value),
            Err(error) => {
                log_error_with_caller(*Location::caller(), error, level);
                None
            }
        }
    }

    /// Maps the error through `Into<anyhow::Error>`.
    fn anyhow(self) -> anyhow::Result<T>
    where
        E: Into<anyhow::Error>,
    {
        self.map_err(Into::into)
    }
}
 653
/// Emits a log record for `error` (formatted with `{:?}`) at `level`, using `caller` for the
/// record's file and line.
///
/// The record's target and module path are set to the second `/`-separated segment of the
/// caller's file path, or to an empty target / no module path when there is no such segment.
fn log_error_with_caller<E>(caller: core::panic::Location<'_>, error: E, level: log::Level)
where
    E: std::fmt::Debug,
{
    #[cfg(not(target_os = "windows"))]
    let file = caller.file();
    // Normalize separators so the segment split below also works for Windows paths.
    #[cfg(target_os = "windows")]
    let file = caller.file().replace('\\', "/");
    // In this codebase, the first segment of the file path is
    // the 'crates' folder, followed by the crate name.
    let target = file.split('/').nth(1);

    log::logger().log(
        &log::Record::builder()
            .target(target.unwrap_or(""))
            .module_path(target)
            .args(format_args!("{:?}", error))
            // The record keeps the original, un-normalized file path.
            .file(Some(caller.file()))
            .line(Some(caller.line()))
            .level(level)
            .build(),
    );
}
 677
 678pub fn log_err<E: std::fmt::Debug>(error: &E) {
 679    log_error_with_caller(*Location::caller(), error, log::Level::Warn);
 680}
 681
/// Adapters for futures that resolve to a `Result`, mirroring the helpers on [`ResultExt`].
pub trait TryFutureExt {
    /// Wraps the future so that it resolves to an `Option`, logging an error at `Error` level.
    fn log_err(self) -> LogErrorFuture<Self>
    where
        Self: Sized;

    /// Like `log_err`, but attributes the log record to the explicitly provided `location`.
    fn log_tracked_err(self, location: core::panic::Location<'static>) -> LogErrorFuture<Self>
    where
        Self: Sized;

    /// Wraps the future so that it resolves to an `Option`, logging an error at `Warn` level.
    fn warn_on_err(self) -> LogErrorFuture<Self>
    where
        Self: Sized;
    /// Wraps the future so that it resolves to the success value, unwrapping the result.
    fn unwrap(self) -> UnwrapFuture<Self>
    where
        Self: Sized;
}
 698
/// Implements [`TryFutureExt`] for every future whose output is a `Result` with a `Debug`
/// error type.
impl<F, T, E> TryFutureExt for F
where
    F: Future<Output = Result<T, E>>,
    E: std::fmt::Debug,
{
    /// Records the caller's location now, so that an error logged when the future
    /// completes is attributed to the place where `log_err` was called.
    #[track_caller]
    fn log_err(self) -> LogErrorFuture<Self>
    where
        Self: Sized,
    {
        let location = Location::caller();
        LogErrorFuture(self, log::Level::Error, *location)
    }

    /// Same as `log_err`, but uses the supplied `location` instead of the caller's.
    fn log_tracked_err(self, location: core::panic::Location<'static>) -> LogErrorFuture<Self>
    where
        Self: Sized,
    {
        LogErrorFuture(self, log::Level::Error, location)
    }

    /// Same as `log_err`, but the error is logged at `Warn` level.
    #[track_caller]
    fn warn_on_err(self) -> LogErrorFuture<Self>
    where
        Self: Sized,
    {
        let location = Location::caller();
        LogErrorFuture(self, log::Level::Warn, *location)
    }

    /// Wraps the future in an [`UnwrapFuture`], which unwraps the result once it is ready.
    fn unwrap(self) -> UnwrapFuture<Self>
    where
        Self: Sized,
    {
        UnwrapFuture(self)
    }
}
 736
/// Future adapter that turns a `Result` output into an `Option`, logging any error.
///
/// Fields: the wrapped future, the level to log at, and the location to attribute the
/// log record to.
#[must_use]
pub struct LogErrorFuture<F>(F, log::Level, core::panic::Location<'static>);

impl<F, T, E> Future for LogErrorFuture<F>
where
    F: Future<Output = Result<T, E>>,
    E: std::fmt::Debug,
{
    type Output = Option<T>;

    /// Polls the wrapped future; on completion maps `Ok(value)` to `Some(value)` and logs
    /// `Err(error)` before resolving to `None`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // Copy the plain fields out before projecting the pin onto the inner future.
        let level = self.1;
        let location = self.2;
        // NOTE(review): this pin projection relies on field `.0` never being moved out of a
        // pinned `LogErrorFuture` — confirm no other code does so.
        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
        match inner.poll(cx) {
            Poll::Ready(output) => Poll::Ready(match output {
                Ok(output) => Some(output),
                Err(error) => {
                    log_error_with_caller(location, error, level);
                    None
                }
            }),
            Poll::Pending => Poll::Pending,
        }
    }
}
 763
/// Future adapter that unwraps the `Result` produced by the wrapped future.
pub struct UnwrapFuture<F>(F);

impl<F, T, E> Future for UnwrapFuture<F>
where
    F: Future<Output = Result<T, E>>,
    E: std::fmt::Debug,
{
    type Output = T;

    /// Polls the wrapped future and unwraps its result once ready.
    ///
    /// # Panics
    ///
    /// Panics if the wrapped future resolves to an `Err`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // NOTE(review): this pin projection relies on field `.0` never being moved out of a
        // pinned `UnwrapFuture` — confirm no other code does so.
        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
        match inner.poll(cx) {
            Poll::Ready(result) => Poll::Ready(result.unwrap()),
            Poll::Pending => Poll::Pending,
        }
    }
}
 781
 782pub struct Deferred<F: FnOnce()>(Option<F>);
 783
 784impl<F: FnOnce()> Deferred<F> {
 785    /// Drop without running the deferred function.
 786    pub fn abort(mut self) {
 787        self.0.take();
 788    }
 789}
 790
 791impl<F: FnOnce()> Drop for Deferred<F> {
 792    fn drop(&mut self) {
 793        if let Some(f) = self.0.take() {
 794            f()
 795        }
 796    }
 797}
 798
 799/// Run the given function when the returned value is dropped (unless it's cancelled).
 800#[must_use]
 801pub fn defer<F: FnOnce()>(f: F) -> Deferred<F> {
 802    Deferred(Some(f))
 803}
 804
// Random text generation helpers, available to tests and behind the `test-support` feature.
#[cfg(any(test, feature = "test-support"))]
mod rng {
    use rand::{Rng, seq::SliceRandom};
    /// Iterator producing random characters from the wrapped random number generator.
    ///
    /// In "simple text" mode only lowercase ASCII letters and newlines are produced;
    /// otherwise the output mixes whitespace, Greek letters, symbols, emoji and lowercase
    /// ASCII letters.
    pub struct RandomCharIter<T: Rng> {
        rng: T,
        simple_text: bool,
    }

    impl<T: Rng> RandomCharIter<T> {
        /// Creates the iterator; simple text mode is enabled when the `SIMPLE_TEXT`
        /// environment variable is set to a non-empty value.
        pub fn new(rng: T) -> Self {
            Self {
                rng,
                simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
            }
        }

        /// Forces simple text mode regardless of the environment.
        pub fn with_simple_text(mut self) -> Self {
            self.simple_text = true;
            self
        }
    }

    impl<T: Rng> Iterator for RandomCharIter<T> {
        type Item = char;

        fn next(&mut self) -> Option<Self::Item> {
            if self.simple_text {
                // 5 out of 100 draws yield a newline, the rest a letter in `a..=z`.
                return if self.rng.gen_range(0..100) < 5 {
                    Some('\n')
                } else {
                    Some(self.rng.gen_range(b'a'..b'z' + 1).into())
                };
            }

            // The width of each arm's range is its share of the 100 possible draws.
            match self.rng.gen_range(0..100) {
                // whitespace
                0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
                // two-byte greek letters
                20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
                // three-byte characters
                33..=45 => ['✋', '✅', '❌', '❎', '⭐']
                    .choose(&mut self.rng)
                    .copied(),
                // four-byte characters
                46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
                // ascii letters
                _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
            }
        }
    }
}
 856#[cfg(any(test, feature = "test-support"))]
 857pub use rng::RandomCharIter;
/// Get an embedded file as a string.
///
/// Borrowed asset data is returned as a borrowed string; owned data is converted into an
/// owned `String`.
///
/// # Panics
///
/// Panics with `path` as the message when the asset does not exist, and panics when the
/// asset's contents are not valid UTF-8.
pub fn asset_str<A: rust_embed::RustEmbed>(path: &str) -> Cow<'static, str> {
    match A::get(path).expect(path).data {
        Cow::Borrowed(bytes) => Cow::Borrowed(std::str::from_utf8(bytes).unwrap()),
        Cow::Owned(bytes) => Cow::Owned(String::from_utf8(bytes).unwrap()),
    }
}
 865
/// Expands to an immediately-invoked function expression. Good for using the ? operator
/// in functions which do not return an Option or Result.
///
/// Accepts a normal block, an async block, or an async move block.
///
/// For the async forms the closure is invoked immediately and evaluates to the async
/// block itself, so the result is a future that still has to be awaited.
#[macro_export]
macro_rules! maybe {
    // Synchronous block: `?` returns from the closure instead of the enclosing function.
    ($block:block) => {
        (|| $block)()
    };
    (async $block:block) => {
        (|| async $block)()
    };
    (async move $block:block) => {
        (|| async move $block)()
    };
}
 882
 883pub trait RangeExt<T> {
 884    fn sorted(&self) -> Self;
 885    fn to_inclusive(&self) -> RangeInclusive<T>;
 886    fn overlaps(&self, other: &Range<T>) -> bool;
 887    fn contains_inclusive(&self, other: &Range<T>) -> bool;
 888}
 889
 890impl<T: Ord + Clone> RangeExt<T> for Range<T> {
 891    fn sorted(&self) -> Self {
 892        cmp::min(&self.start, &self.end).clone()..cmp::max(&self.start, &self.end).clone()
 893    }
 894
 895    fn to_inclusive(&self) -> RangeInclusive<T> {
 896        self.start.clone()..=self.end.clone()
 897    }
 898
 899    fn overlaps(&self, other: &Range<T>) -> bool {
 900        self.start < other.end && other.start < self.end
 901    }
 902
 903    fn contains_inclusive(&self, other: &Range<T>) -> bool {
 904        self.start <= other.start && other.end <= self.end
 905    }
 906}
 907
 908impl<T: Ord + Clone> RangeExt<T> for RangeInclusive<T> {
 909    fn sorted(&self) -> Self {
 910        cmp::min(self.start(), self.end()).clone()..=cmp::max(self.start(), self.end()).clone()
 911    }
 912
 913    fn to_inclusive(&self) -> RangeInclusive<T> {
 914        self.clone()
 915    }
 916
 917    fn overlaps(&self, other: &Range<T>) -> bool {
 918        self.start() < &other.end && &other.start <= self.end()
 919    }
 920
 921    fn contains_inclusive(&self, other: &Range<T>) -> bool {
 922        self.start() <= &other.start && &other.end <= self.end()
 923    }
 924}
 925
 926/// A way to sort strings with starting numbers numerically first, falling back to alphanumeric one,
 927/// case-insensitive.
 928///
 929/// This is useful for turning regular alphanumerically sorted sequences as `1-abc, 10, 11-def, .., 2, 21-abc`
 930/// into `1-abc, 2, 10, 11-def, .., 21-abc`
 931#[derive(Debug, PartialEq, Eq)]
 932pub struct NumericPrefixWithSuffix<'a>(Option<u64>, &'a str);
 933
 934impl<'a> NumericPrefixWithSuffix<'a> {
 935    pub fn from_numeric_prefixed_str(str: &'a str) -> Self {
 936        let i = str.chars().take_while(|c| c.is_ascii_digit()).count();
 937        let (prefix, remainder) = str.split_at(i);
 938
 939        let prefix = prefix.parse().ok();
 940        Self(prefix, remainder)
 941    }
 942}
 943
 944/// When dealing with equality, we need to consider the case of the strings to achieve strict equality
 945/// to handle cases like "a" < "A" instead of "a" == "A".
 946impl Ord for NumericPrefixWithSuffix<'_> {
 947    fn cmp(&self, other: &Self) -> Ordering {
 948        match (self.0, other.0) {
 949            (None, None) => UniCase::new(self.1)
 950                .cmp(&UniCase::new(other.1))
 951                .then_with(|| self.1.cmp(other.1).reverse()),
 952            (None, Some(_)) => Ordering::Greater,
 953            (Some(_), None) => Ordering::Less,
 954            (Some(a), Some(b)) => a.cmp(&b).then_with(|| {
 955                UniCase::new(self.1)
 956                    .cmp(&UniCase::new(other.1))
 957                    .then_with(|| self.1.cmp(other.1).reverse())
 958            }),
 959        }
 960    }
 961}
 962
 963impl PartialOrd for NumericPrefixWithSuffix<'_> {
 964    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
 965        Some(self.cmp(other))
 966    }
 967}
 968
 969/// Capitalizes the first character of a string.
 970///
 971/// This function takes a string slice as input and returns a new `String` with the first character
 972/// capitalized.
 973///
 974/// # Examples
 975///
 976/// ```
 977/// use util::capitalize;
 978///
 979/// assert_eq!(capitalize("hello"), "Hello");
 980/// assert_eq!(capitalize("WORLD"), "WORLD");
 981/// assert_eq!(capitalize(""), "");
 982/// ```
 983pub fn capitalize(str: &str) -> String {
 984    let mut chars = str.chars();
 985    match chars.next() {
 986        None => String::new(),
 987        Some(first_char) => first_char.to_uppercase().collect::<String>() + chars.as_str(),
 988    }
 989}
 990
 991fn emoji_regex() -> &'static Regex {
 992    static EMOJI_REGEX: LazyLock<Regex> =
 993        LazyLock::new(|| Regex::new("(\\p{Emoji}|\u{200D})").unwrap());
 994    &EMOJI_REGEX
 995}
 996
 997/// Returns true if the given string consists of emojis only.
 998/// E.g. "👨‍👩‍👧‍👧👋" will return true, but "👋!" will return false.
 999pub fn word_consists_of_emojis(s: &str) -> bool {
1000    let mut prev_end = 0;
1001    for capture in emoji_regex().find_iter(s) {
1002        if capture.start() != prev_end {
1003            return false;
1004        }
1005        prev_end = capture.end();
1006    }
1007    prev_end == s.len()
1008}
1009
/// Free-function shorthand for [`Default::default`]; the concrete type is taken from the
/// call site.
pub fn default<D: Default>() -> D {
    D::default()
}
1013
/// Returns the shell to use on the current platform.
///
/// On Windows this defers to `get_windows_system_shell`. On every other target it is the
/// value of the `SHELL` environment variable, or `/bin/sh` when that variable cannot be read.
pub fn get_system_shell() -> String {
    #[cfg(target_os = "windows")]
    {
        get_windows_system_shell()
    }

    #[cfg(not(target_os = "windows"))]
    {
        match std::env::var("SHELL") {
            Ok(shell) => shell,
            Err(_) => String::from("/bin/sh"),
        }
    }
}
1025
1026#[derive(Debug)]
1027pub enum ConnectionResult<O> {
1028    Timeout,
1029    ConnectionReset,
1030    Result(anyhow::Result<O>),
1031}
1032
1033impl<O> ConnectionResult<O> {
1034    pub fn into_response(self) -> anyhow::Result<O> {
1035        match self {
1036            ConnectionResult::Timeout => anyhow::bail!("Request timed out"),
1037            ConnectionResult::ConnectionReset => anyhow::bail!("Server reset the connection"),
1038            ConnectionResult::Result(r) => r,
1039        }
1040    }
1041}
1042
1043impl<O> From<anyhow::Result<O>> for ConnectionResult<O> {
1044    fn from(result: anyhow::Result<O>) -> Self {
1045        ConnectionResult::Result(result)
1046    }
1047}
1048
// Unit tests for the helpers defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // Repeatedly merges descending batches into `vec` using a reversed comparator and
    // checks the merged contents; the expected vectors are never longer than the numeric
    // argument passed alongside the batch.
    #[test]
    fn test_extend_sorted() {
        let mut vec = vec![];

        extend_sorted(&mut vec, vec![21, 17, 13, 8, 1, 0], 5, |a, b| b.cmp(a));
        assert_eq!(vec, &[21, 17, 13, 8, 1]);

        extend_sorted(&mut vec, vec![101, 19, 17, 8, 2], 8, |a, b| b.cmp(a));
        assert_eq!(vec, &[101, 21, 19, 17, 13, 8, 2, 1]);

        extend_sorted(&mut vec, vec![1000, 19, 17, 9, 5], 8, |a, b| b.cmp(a));
        assert_eq!(vec, &[1000, 101, 21, 19, 17, 13, 9, 8]);
    }

    // Covers limits above, equal to, and below the vector length, plus a limit of zero.
    #[test]
    fn test_truncate_to_bottom_n_sorted_by() {
        let mut vec: Vec<u32> = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 10, &u32::cmp);
        assert_eq!(vec, &[1, 2, 3, 4, 5]);

        vec = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 5, &u32::cmp);
        assert_eq!(vec, &[1, 2, 3, 4, 5]);

        vec = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 4, &u32::cmp);
        assert_eq!(vec, &[1, 2, 3, 4]);

        vec = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 1, &u32::cmp);
        assert_eq!(vec, &[1]);

        vec = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 0, &u32::cmp);
        assert!(vec.is_empty());
    }

    // `maybe!` lets `?` short-circuit inside a test function that returns `()`.
    #[test]
    fn test_iife() {
        fn option_returning_function() -> Option<()> {
            None
        }

        let foo = maybe!({
            option_returning_function()?;
            Some(())
        });

        assert_eq!(foo, None);
    }

    // Inputs that fit within the limit are returned unchanged; longer ones keep their
    // leading characters and gain a trailing ellipsis. Checked for ASCII and non-ASCII text.
    #[test]
    fn test_truncate_and_trailoff() {
        assert_eq!(truncate_and_trailoff("", 5), "");
        assert_eq!(truncate_and_trailoff("aaaaaa", 7), "aaaaaa");
        assert_eq!(truncate_and_trailoff("aaaaaa", 6), "aaaaaa");
        assert_eq!(truncate_and_trailoff("aaaaaa", 5), "aaaaa…");
        assert_eq!(truncate_and_trailoff("èèèèèè", 7), "èèèèèè");
        assert_eq!(truncate_and_trailoff("èèèèèè", 6), "èèèèèè");
        assert_eq!(truncate_and_trailoff("èèèèèè", 5), "èèèèè…");
    }

    // Mirror image of the test above: the ellipsis is placed at the front instead.
    #[test]
    fn test_truncate_and_remove_front() {
        assert_eq!(truncate_and_remove_front("", 5), "");
        assert_eq!(truncate_and_remove_front("aaaaaa", 7), "aaaaaa");
        assert_eq!(truncate_and_remove_front("aaaaaa", 6), "aaaaaa");
        assert_eq!(truncate_and_remove_front("aaaaaa", 5), "…aaaaa");
        assert_eq!(truncate_and_remove_front("èèèèèè", 7), "èèèèèè");
        assert_eq!(truncate_and_remove_front("èèèèèè", 6), "èèèèèè");
        assert_eq!(truncate_and_remove_front("èèèèèè", 5), "…èèèèè");
    }

    // Only the leading run of digits becomes the numeric prefix; separators, dots and
    // later digits stay in the suffix.
    #[test]
    fn test_numeric_prefix_str_method() {
        let target = "1a";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(1), "a")
        );

        let target = "12ab";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(12), "ab")
        );

        let target = "12_ab";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(12), "_ab")
        );

        let target = "1_2ab";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(1), "_2ab")
        );

        let target = "1.2";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(1), ".2")
        );

        let target = "1.2_a";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(1), ".2_a")
        );

        let target = "12.2_a";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(12), ".2_a")
        );

        let target = "12a.2_a";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(12), "a.2_a")
        );
    }

    // Sorting by the key orders entries by numeric value rather than lexicographically,
    // and strings without a leading digit keep their full text with no prefix.
    #[test]
    fn test_numeric_prefix_with_suffix() {
        let mut sorted = vec!["1-abc", "10", "11def", "2", "21-abc"];
        sorted.sort_by_key(|s| NumericPrefixWithSuffix::from_numeric_prefixed_str(s));
        assert_eq!(sorted, ["1-abc", "2", "10", "11def", "21-abc"]);

        for numeric_prefix_less in ["numeric_prefix_less", "aaa", "~™£"] {
            assert_eq!(
                NumericPrefixWithSuffix::from_numeric_prefixed_str(numeric_prefix_less),
                NumericPrefixWithSuffix(None, numeric_prefix_less),
                "String without numeric prefix `{numeric_prefix_less}` should not be converted into NumericPrefixWithSuffix"
            )
        }
    }

    // Any non-emoji character, including leading or trailing whitespace and punctuation,
    // makes the check fail.
    #[test]
    fn test_word_consists_of_emojis() {
        let words_to_test = vec![
            ("👨‍👩‍👧‍👧👋🥒", true),
            ("👋", true),
            ("!👋", false),
            ("👋!", false),
            ("👋 ", false),
            (" 👋", false),
            ("Test", false),
        ];

        for (text, expected_result) in words_to_test {
            assert_eq!(word_consists_of_emojis(text), expected_result);
        }
    }

    // A three-line input is cut down for limits of 2 and 3, with an ellipsis as the last
    // line, and returned unchanged for a limit of 4.
    #[test]
    fn test_truncate_lines_and_trailoff() {
        let text = r#"Line 1
Line 2
Line 3"#;

        assert_eq!(
            truncate_lines_and_trailoff(text, 2),
            r#"Line 1
…"#
        );

        assert_eq!(
            truncate_lines_and_trailoff(text, 3),
            r#"Line 1
Line 2
…"#
        );

        assert_eq!(
            truncate_lines_and_trailoff(text, 4),
            r#"Line 1
Line 2
Line 3"#
        );
    }

    // Each case collects the yielded indices and compares them against an explicit
    // ascending sequence; the inline comments name which side of the range wraps.
    #[test]
    fn test_iterate_expanded_and_wrapped_usize_range() {
        // Neither wrap
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(2..4, 1, 1, 8).collect::<Vec<usize>>(),
            (1..5).collect::<Vec<usize>>()
        );
        // Start wraps
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(2..4, 3, 1, 8).collect::<Vec<usize>>(),
            ((0..5).chain(7..8)).collect::<Vec<usize>>()
        );
        // Start wraps all the way around
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(2..4, 5, 1, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
        // Start wraps all the way around and past 0
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(2..4, 10, 1, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
        // End wraps
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(3..5, 1, 4, 8).collect::<Vec<usize>>(),
            (0..1).chain(2..8).collect::<Vec<usize>>()
        );
        // End wraps all the way around
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(3..5, 1, 5, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
        // End wraps all the way around and past the end
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(3..5, 1, 10, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
        // Both start and end wrap
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(3..5, 4, 4, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
    }
}