util.rs

   1pub mod arc_cow;
   2pub mod command;
   3pub mod fs;
   4pub mod markdown;
   5pub mod paths;
   6pub mod serde;
   7pub mod size;
   8#[cfg(any(test, feature = "test-support"))]
   9pub mod test;
  10pub mod time;
  11
  12use anyhow::Result;
  13use futures::Future;
  14use itertools::Either;
  15use regex::Regex;
  16use std::sync::{LazyLock, OnceLock};
  17use std::{
  18    borrow::Cow,
  19    cmp::{self, Ordering},
  20    env,
  21    ops::{AddAssign, Range, RangeInclusive},
  22    panic::Location,
  23    pin::Pin,
  24    task::{Context, Poll},
  25    time::Instant,
  26};
  27use unicase::UniCase;
  28
  29#[cfg(unix)]
  30use anyhow::{Context as _, anyhow};
  31
  32pub use take_until::*;
  33#[cfg(any(test, feature = "test-support"))]
  34pub use util_macros::{line_endings, separator, uri};
  35
/// Panics in debug builds; in release builds, logs the formatted message at error level together
/// with a captured backtrace instead of panicking.
///
/// Accepts the same arguments as [`panic!`] / [`format_args!`]. The crate invoking the macro must
/// itself depend on `log`, because the release branch expands to a plain `log::error!` call.
#[macro_export]
macro_rules! debug_panic {
    ( $($fmt_arg:tt)* ) => {
        if cfg!(debug_assertions) {
            panic!( $($fmt_arg)* );
        } else {
            // `Backtrace::capture` only records frames when backtraces are enabled through
            // `RUST_BACKTRACE` / `RUST_LIB_BACKTRACE`; otherwise it prints as disabled.
            let backtrace = std::backtrace::Backtrace::capture();
            log::error!("{}\n{:?}", format_args!($($fmt_arg)*), backtrace);
        }
    };
}
  47
  48/// A macro to add "C:" to the beginning of a path literal on Windows, and replace all
  49/// the separator from `/` to `\`.
  50/// But on non-Windows platforms, it will return the path literal as is.
  51///
  52/// # Examples
  53/// ```rust
  54/// use util::path;
  55///
  56/// let path = path!("/Users/user/file.txt");
  57/// #[cfg(target_os = "windows")]
  58/// assert_eq!(path, "C:\\Users\\user\\file.txt");
  59/// #[cfg(not(target_os = "windows"))]
  60/// assert_eq!(path, "/Users/user/file.txt");
  61/// ```
  62#[cfg(all(any(test, feature = "test-support"), target_os = "windows"))]
  63#[macro_export]
  64macro_rules! path {
  65    ($path:literal) => {
  66        concat!("C:", util::separator!($path))
  67    };
  68}
  69
/// A macro that, on non-Windows platforms, returns the path literal unchanged.
/// (The Windows definition of this macro instead prefixes the literal with "C:" and replaces
/// every `/` separator with `\`.)
///
/// # Examples
/// ```rust
/// use util::path;
///
/// let path = path!("/Users/user/file.txt");
/// #[cfg(target_os = "windows")]
/// assert_eq!(path, "C:\\Users\\user\\file.txt");
/// #[cfg(not(target_os = "windows"))]
/// assert_eq!(path, "/Users/user/file.txt");
/// ```
#[cfg(all(any(test, feature = "test-support"), not(target_os = "windows")))]
#[macro_export]
macro_rules! path {
    ($path:literal) => {
        $path
    };
}
  91
  92pub fn truncate(s: &str, max_chars: usize) -> &str {
  93    match s.char_indices().nth(max_chars) {
  94        None => s,
  95        Some((idx, _)) => &s[..idx],
  96    }
  97}
  98
  99/// Removes characters from the end of the string if its length is greater than `max_chars` and
 100/// appends "..." to the string. Returns string unchanged if its length is smaller than max_chars.
 101pub fn truncate_and_trailoff(s: &str, max_chars: usize) -> String {
 102    debug_assert!(max_chars >= 5);
 103
 104    // If the string's byte length is <= max_chars, walking the string can be skipped since the
 105    // number of chars is <= the number of bytes.
 106    if s.len() <= max_chars {
 107        return s.to_string();
 108    }
 109    let truncation_ix = s.char_indices().map(|(i, _)| i).nth(max_chars);
 110    match truncation_ix {
 111        Some(index) => s[..index].to_string() + "",
 112        _ => s.to_string(),
 113    }
 114}
 115
 116/// Removes characters from the front of the string if its length is greater than `max_chars` and
 117/// prepends the string with "...". Returns string unchanged if its length is smaller than max_chars.
 118pub fn truncate_and_remove_front(s: &str, max_chars: usize) -> String {
 119    debug_assert!(max_chars >= 5);
 120
 121    // If the string's byte length is <= max_chars, walking the string can be skipped since the
 122    // number of chars is <= the number of bytes.
 123    if s.len() <= max_chars {
 124        return s.to_string();
 125    }
 126    let suffix_char_length = max_chars.saturating_sub(1);
 127    let truncation_ix = s
 128        .char_indices()
 129        .map(|(i, _)| i)
 130        .nth_back(suffix_char_length);
 131    match truncation_ix {
 132        Some(index) if index > 0 => "".to_string() + &s[index..],
 133        _ => s.to_string(),
 134    }
 135}
 136
 137/// Takes only `max_lines` from the string and, if there were more than `max_lines-1`, appends a
 138/// a newline and "..." to the string, so that `max_lines` are returned.
 139/// Returns string unchanged if its length is smaller than max_lines.
 140pub fn truncate_lines_and_trailoff(s: &str, max_lines: usize) -> String {
 141    let mut lines = s.lines().take(max_lines).collect::<Vec<_>>();
 142    if lines.len() > max_lines - 1 {
 143        lines.pop();
 144        lines.join("\n") + "\n"
 145    } else {
 146        lines.join("\n")
 147    }
 148}
 149
 150/// Truncates the string at a character boundary, such that the result is less than `max_bytes` in
 151/// length.
 152pub fn truncate_to_byte_limit(s: &str, max_bytes: usize) -> &str {
 153    if s.len() < max_bytes {
 154        return s;
 155    }
 156
 157    for i in (0..max_bytes).rev() {
 158        if s.is_char_boundary(i) {
 159            return &s[..i];
 160        }
 161    }
 162
 163    ""
 164}
 165
 166/// Takes a prefix of complete lines which fit within the byte limit. If the first line is longer
 167/// than the limit, truncates at a character boundary.
 168pub fn truncate_lines_to_byte_limit(s: &str, max_bytes: usize) -> &str {
 169    if s.len() < max_bytes {
 170        return s;
 171    }
 172
 173    for i in (0..max_bytes).rev() {
 174        if s.is_char_boundary(i) {
 175            if s.as_bytes()[i] == b'\n' {
 176                // Since the i-th character is \n, valid to slice at i + 1.
 177                return &s[..i + 1];
 178            }
 179        }
 180    }
 181
 182    truncate_to_byte_limit(s, max_bytes)
 183}
 184
 185#[test]
 186fn test_truncate_lines_to_byte_limit() {
 187    let text = "Line 1\nLine 2\nLine 3\nLine 4";
 188
 189    // Limit that includes all lines
 190    assert_eq!(truncate_lines_to_byte_limit(text, 100), text);
 191
 192    // Exactly the first line
 193    assert_eq!(truncate_lines_to_byte_limit(text, 7), "Line 1\n");
 194
 195    // Limit between lines
 196    assert_eq!(truncate_lines_to_byte_limit(text, 13), "Line 1\n");
 197    assert_eq!(truncate_lines_to_byte_limit(text, 20), "Line 1\nLine 2\n");
 198
 199    // Limit before first newline
 200    assert_eq!(truncate_lines_to_byte_limit(text, 6), "Line ");
 201
 202    // Test with non-ASCII characters
 203    let text_utf8 = "Line 1\nLíne 2\nLine 3";
 204    assert_eq!(
 205        truncate_lines_to_byte_limit(text_utf8, 15),
 206        "Line 1\nLíne 2\n"
 207    );
 208}
 209
 210pub fn post_inc<T: From<u8> + AddAssign<T> + Copy>(value: &mut T) -> T {
 211    let prev = *value;
 212    *value += T::from(1);
 213    prev
 214}
 215
 216/// Extend a sorted vector with a sorted sequence of items, maintaining the vector's sort order and
 217/// enforcing a maximum length. This also de-duplicates items. Sort the items according to the given callback. Before calling this,
 218/// both `vec` and `new_items` should already be sorted according to the `cmp` comparator.
 219pub fn extend_sorted<T, I, F>(vec: &mut Vec<T>, new_items: I, limit: usize, mut cmp: F)
 220where
 221    I: IntoIterator<Item = T>,
 222    F: FnMut(&T, &T) -> Ordering,
 223{
 224    let mut start_index = 0;
 225    for new_item in new_items {
 226        if let Err(i) = vec[start_index..].binary_search_by(|m| cmp(m, &new_item)) {
 227            let index = start_index + i;
 228            if vec.len() < limit {
 229                vec.insert(index, new_item);
 230            } else if index < vec.len() {
 231                vec.pop();
 232                vec.insert(index, new_item);
 233            }
 234            start_index = index;
 235        }
 236    }
 237}
 238
 239pub fn truncate_to_bottom_n_sorted_by<T, F>(items: &mut Vec<T>, limit: usize, compare: &F)
 240where
 241    F: Fn(&T, &T) -> Ordering,
 242{
 243    if limit == 0 {
 244        items.truncate(0);
 245    }
 246    if items.len() <= limit {
 247        items.sort_by(compare);
 248        return;
 249    }
 250    // When limit is near to items.len() it may be more efficient to sort the whole list and
 251    // truncate, rather than always doing selection first as is done below. It's hard to analyze
 252    // where the threshold for this should be since the quickselect style algorithm used by
 253    // `select_nth_unstable_by` makes the prefix partially sorted, and so its work is not wasted -
 254    // the expected number of comparisons needed by `sort_by` is less than it is for some arbitrary
 255    // unsorted input.
 256    items.select_nth_unstable_by(limit, compare);
 257    items.truncate(limit);
 258    items.sort_by(compare);
 259}
 260
 261#[cfg(unix)]
 262fn load_shell_from_passwd() -> Result<()> {
 263    let buflen = match unsafe { libc::sysconf(libc::_SC_GETPW_R_SIZE_MAX) } {
 264        n if n < 0 => 1024,
 265        n => n as usize,
 266    };
 267    let mut buffer = Vec::with_capacity(buflen);
 268
 269    let mut pwd: std::mem::MaybeUninit<libc::passwd> = std::mem::MaybeUninit::uninit();
 270    let mut result: *mut libc::passwd = std::ptr::null_mut();
 271
 272    let uid = unsafe { libc::getuid() };
 273    let status = unsafe {
 274        libc::getpwuid_r(
 275            uid,
 276            pwd.as_mut_ptr(),
 277            buffer.as_mut_ptr() as *mut libc::c_char,
 278            buflen,
 279            &mut result,
 280        )
 281    };
 282    let entry = unsafe { pwd.assume_init() };
 283
 284    anyhow::ensure!(
 285        status == 0,
 286        "call to getpwuid_r failed. uid: {}, status: {}",
 287        uid,
 288        status
 289    );
 290    anyhow::ensure!(!result.is_null(), "passwd entry for uid {} not found", uid);
 291    anyhow::ensure!(
 292        entry.pw_uid == uid,
 293        "passwd entry has different uid ({}) than getuid ({}) returned",
 294        entry.pw_uid,
 295        uid,
 296    );
 297
 298    let shell = unsafe { std::ffi::CStr::from_ptr(entry.pw_shell).to_str().unwrap() };
 299    if env::var("SHELL").map_or(true, |shell_env| shell_env != shell) {
 300        log::info!(
 301            "updating SHELL environment variable to value from passwd entry: {:?}",
 302            shell,
 303        );
 304        unsafe { env::set_var("SHELL", shell) };
 305    }
 306
 307    Ok(())
 308}
 309
 310#[cfg(unix)]
 311pub fn load_login_shell_environment() -> Result<()> {
 312    load_shell_from_passwd().log_err();
 313
 314    let marker = "ZED_LOGIN_SHELL_START";
 315    let shell = env::var("SHELL").context(
 316        "SHELL environment variable is not assigned so we can't source login environment variables",
 317    )?;
 318
 319    // If possible, we want to `cd` in the user's `$HOME` to trigger programs
 320    // such as direnv, asdf, mise, ... to adjust the PATH. These tools often hook
 321    // into shell's `cd` command (and hooks) to manipulate env.
 322    // We do this so that we get the env a user would have when spawning a shell
 323    // in home directory.
 324    let shell_cmd_prefix = std::env::var_os("HOME")
 325        .and_then(|home| home.into_string().ok())
 326        .map(|home| format!("cd '{home}';"));
 327
 328    let shell_cmd = format!(
 329        "{}printf '%s' {marker}; /usr/bin/env;",
 330        shell_cmd_prefix.as_deref().unwrap_or("")
 331    );
 332
 333    let output = set_pre_exec_to_start_new_session(
 334        std::process::Command::new(&shell).args(["-l", "-i", "-c", &shell_cmd]),
 335    )
 336    .output()
 337    .context("failed to spawn login shell to source login environment variables")?;
 338    if !output.status.success() {
 339        Err(anyhow!("login shell exited with error"))?;
 340    }
 341
 342    let stdout = String::from_utf8_lossy(&output.stdout);
 343
 344    if let Some(env_output_start) = stdout.find(marker) {
 345        let env_output = &stdout[env_output_start + marker.len()..];
 346
 347        parse_env_output(env_output, |key, value| unsafe { env::set_var(key, value) });
 348
 349        log::info!(
 350            "set environment variables from shell:{}, path:{}",
 351            shell,
 352            env::var("PATH").unwrap_or_default(),
 353        );
 354    }
 355
 356    Ok(())
 357}
 358
 359/// Configures the process to start a new session, to prevent interactive shells from taking control
 360/// of the terminal.
 361///
 362/// For more details: https://registerspill.thorstenball.com/p/how-to-lose-control-of-your-shell
 363pub fn set_pre_exec_to_start_new_session(
 364    command: &mut std::process::Command,
 365) -> &mut std::process::Command {
 366    // safety: code in pre_exec should be signal safe.
 367    // https://man7.org/linux/man-pages/man7/signal-safety.7.html
 368    #[cfg(not(target_os = "windows"))]
 369    unsafe {
 370        use std::os::unix::process::CommandExt;
 371        command.pre_exec(|| {
 372            libc::setsid();
 373            Ok(())
 374        });
 375    };
 376    command
 377}
 378
 379/// Parse the result of calling `usr/bin/env` with no arguments
 380pub fn parse_env_output(env: &str, mut f: impl FnMut(String, String)) {
 381    let mut current_key: Option<String> = None;
 382    let mut current_value: Option<String> = None;
 383
 384    for line in env.split_terminator('\n') {
 385        if let Some(separator_index) = line.find('=') {
 386            if !line[..separator_index].is_empty() {
 387                if let Some((key, value)) = Option::zip(current_key.take(), current_value.take()) {
 388                    f(key, value)
 389                }
 390                current_key = Some(line[..separator_index].to_string());
 391                current_value = Some(line[separator_index + 1..].to_string());
 392                continue;
 393            };
 394        }
 395        if let Some(value) = current_value.as_mut() {
 396            value.push('\n');
 397            value.push_str(line);
 398        }
 399    }
 400    if let Some((key, value)) = Option::zip(current_key.take(), current_value.take()) {
 401        f(key, value)
 402    }
 403}
 404
 405pub fn merge_json_value_into(source: serde_json::Value, target: &mut serde_json::Value) {
 406    use serde_json::Value;
 407
 408    match (source, target) {
 409        (Value::Object(source), Value::Object(target)) => {
 410            for (key, value) in source {
 411                if let Some(target) = target.get_mut(&key) {
 412                    merge_json_value_into(value, target);
 413                } else {
 414                    target.insert(key, value);
 415                }
 416            }
 417        }
 418
 419        (Value::Array(source), Value::Array(target)) => {
 420            for value in source {
 421                target.push(value);
 422            }
 423        }
 424
 425        (source, target) => *target = source,
 426    }
 427}
 428
 429pub fn merge_non_null_json_value_into(source: serde_json::Value, target: &mut serde_json::Value) {
 430    use serde_json::Value;
 431    if let Value::Object(source_object) = source {
 432        let target_object = if let Value::Object(target) = target {
 433            target
 434        } else {
 435            *target = Value::Object(Default::default());
 436            target.as_object_mut().unwrap()
 437        };
 438        for (key, value) in source_object {
 439            if let Some(target) = target_object.get_mut(&key) {
 440                merge_non_null_json_value_into(value, target);
 441            } else if !value.is_null() {
 442                target_object.insert(key, value);
 443            }
 444        }
 445    } else if !source.is_null() {
 446        *target = source
 447    }
 448}
 449
 450pub fn measure<R>(label: &str, f: impl FnOnce() -> R) -> R {
 451    static ZED_MEASUREMENTS: OnceLock<bool> = OnceLock::new();
 452    let zed_measurements = ZED_MEASUREMENTS.get_or_init(|| {
 453        env::var("ZED_MEASUREMENTS")
 454            .map(|measurements| measurements == "1" || measurements == "true")
 455            .unwrap_or(false)
 456    });
 457
 458    if *zed_measurements {
 459        let start = Instant::now();
 460        let result = f();
 461        let elapsed = start.elapsed();
 462        eprintln!("{}: {:?}", label, elapsed);
 463        result
 464    } else {
 465        f()
 466    }
 467}
 468
 469pub fn iterate_expanded_and_wrapped_usize_range(
 470    range: Range<usize>,
 471    additional_before: usize,
 472    additional_after: usize,
 473    wrap_length: usize,
 474) -> impl Iterator<Item = usize> {
 475    let start_wraps = range.start < additional_before;
 476    let end_wraps = wrap_length < range.end + additional_after;
 477    if start_wraps && end_wraps {
 478        Either::Left(0..wrap_length)
 479    } else if start_wraps {
 480        let wrapped_start = (range.start + wrap_length).saturating_sub(additional_before);
 481        if wrapped_start <= range.end {
 482            Either::Left(0..wrap_length)
 483        } else {
 484            Either::Right((0..range.end + additional_after).chain(wrapped_start..wrap_length))
 485        }
 486    } else if end_wraps {
 487        let wrapped_end = range.end + additional_after - wrap_length;
 488        if range.start <= wrapped_end {
 489            Either::Left(0..wrap_length)
 490        } else {
 491            Either::Right((0..wrapped_end).chain(range.start - additional_before..wrap_length))
 492        }
 493    } else {
 494        Either::Left((range.start - additional_before)..(range.end + additional_after))
 495    }
 496}
 497
 498#[cfg(target_os = "windows")]
 499pub fn get_windows_system_shell() -> String {
 500    use std::path::PathBuf;
 501
 502    fn find_pwsh_in_programfiles(find_alternate: bool, find_preview: bool) -> Option<PathBuf> {
 503        #[cfg(target_pointer_width = "64")]
 504        let env_var = if find_alternate {
 505            "ProgramFiles(x86)"
 506        } else {
 507            "ProgramFiles"
 508        };
 509
 510        #[cfg(target_pointer_width = "32")]
 511        let env_var = if find_alternate {
 512            "ProgramW6432"
 513        } else {
 514            "ProgramFiles"
 515        };
 516
 517        let install_base_dir = PathBuf::from(std::env::var_os(env_var)?).join("PowerShell");
 518        install_base_dir
 519            .read_dir()
 520            .ok()?
 521            .filter_map(Result::ok)
 522            .filter(|entry| matches!(entry.file_type(), Ok(ft) if ft.is_dir()))
 523            .filter_map(|entry| {
 524                let dir_name = entry.file_name();
 525                let dir_name = dir_name.to_string_lossy();
 526
 527                let version = if find_preview {
 528                    let dash_index = dir_name.find('-')?;
 529                    if &dir_name[dash_index + 1..] != "preview" {
 530                        return None;
 531                    };
 532                    dir_name[..dash_index].parse::<u32>().ok()?
 533                } else {
 534                    dir_name.parse::<u32>().ok()?
 535                };
 536
 537                let exe_path = entry.path().join("pwsh.exe");
 538                if exe_path.exists() {
 539                    Some((version, exe_path))
 540                } else {
 541                    None
 542                }
 543            })
 544            .max_by_key(|(version, _)| *version)
 545            .map(|(_, path)| path)
 546    }
 547
 548    fn find_pwsh_in_msix(find_preview: bool) -> Option<PathBuf> {
 549        let msix_app_dir =
 550            PathBuf::from(std::env::var_os("LOCALAPPDATA")?).join("Microsoft\\WindowsApps");
 551        if !msix_app_dir.exists() {
 552            return None;
 553        }
 554
 555        let prefix = if find_preview {
 556            "Microsoft.PowerShellPreview_"
 557        } else {
 558            "Microsoft.PowerShell_"
 559        };
 560        msix_app_dir
 561            .read_dir()
 562            .ok()?
 563            .filter_map(|entry| {
 564                let entry = entry.ok()?;
 565                if !matches!(entry.file_type(), Ok(ft) if ft.is_dir()) {
 566                    return None;
 567                }
 568
 569                if !entry.file_name().to_string_lossy().starts_with(prefix) {
 570                    return None;
 571                }
 572
 573                let exe_path = entry.path().join("pwsh.exe");
 574                exe_path.exists().then_some(exe_path)
 575            })
 576            .next()
 577    }
 578
 579    fn find_pwsh_in_scoop() -> Option<PathBuf> {
 580        let pwsh_exe =
 581            PathBuf::from(std::env::var_os("USERPROFILE")?).join("scoop\\shims\\pwsh.exe");
 582        pwsh_exe.exists().then_some(pwsh_exe)
 583    }
 584
 585    static SYSTEM_SHELL: LazyLock<String> = LazyLock::new(|| {
 586        find_pwsh_in_programfiles(false, false)
 587            .or_else(|| find_pwsh_in_programfiles(true, false))
 588            .or_else(|| find_pwsh_in_msix(false))
 589            .or_else(|| find_pwsh_in_programfiles(false, true))
 590            .or_else(|| find_pwsh_in_msix(true))
 591            .or_else(|| find_pwsh_in_programfiles(true, true))
 592            .or_else(find_pwsh_in_scoop)
 593            .map(|p| p.to_string_lossy().to_string())
 594            .unwrap_or("powershell.exe".to_string())
 595    });
 596
 597    (*SYSTEM_SHELL).clone()
 598}
 599
 600pub trait ResultExt<E> {
 601    type Ok;
 602
 603    fn log_err(self) -> Option<Self::Ok>;
 604    /// Assert that this result should never be an error in development or tests.
 605    fn debug_assert_ok(self, reason: &str) -> Self;
 606    fn warn_on_err(self) -> Option<Self::Ok>;
 607    fn log_with_level(self, level: log::Level) -> Option<Self::Ok>;
 608    fn anyhow(self) -> anyhow::Result<Self::Ok>
 609    where
 610        E: Into<anyhow::Error>;
 611}
 612
 613impl<T, E> ResultExt<E> for Result<T, E>
 614where
 615    E: std::fmt::Debug,
 616{
 617    type Ok = T;
 618
 619    #[track_caller]
 620    fn log_err(self) -> Option<T> {
 621        self.log_with_level(log::Level::Error)
 622    }
 623
 624    #[track_caller]
 625    fn debug_assert_ok(self, reason: &str) -> Self {
 626        if let Err(error) = &self {
 627            debug_panic!("{reason} - {error:?}");
 628        }
 629        self
 630    }
 631
 632    #[track_caller]
 633    fn warn_on_err(self) -> Option<T> {
 634        self.log_with_level(log::Level::Warn)
 635    }
 636
 637    #[track_caller]
 638    fn log_with_level(self, level: log::Level) -> Option<T> {
 639        match self {
 640            Ok(value) => Some(value),
 641            Err(error) => {
 642                log_error_with_caller(*Location::caller(), error, level);
 643                None
 644            }
 645        }
 646    }
 647
 648    fn anyhow(self) -> anyhow::Result<T>
 649    where
 650        E: Into<anyhow::Error>,
 651    {
 652        self.map_err(Into::into)
 653    }
 654}
 655
 656fn log_error_with_caller<E>(caller: core::panic::Location<'_>, error: E, level: log::Level)
 657where
 658    E: std::fmt::Debug,
 659{
 660    #[cfg(not(target_os = "windows"))]
 661    let file = caller.file();
 662    #[cfg(target_os = "windows")]
 663    let file = caller.file().replace('\\', "/");
 664    // In this codebase, the first segment of the file path is
 665    // the 'crates' folder, followed by the crate name.
 666    let target = file.split('/').nth(1);
 667
 668    log::logger().log(
 669        &log::Record::builder()
 670            .target(target.unwrap_or(""))
 671            .module_path(target)
 672            .args(format_args!("{:?}", error))
 673            .file(Some(caller.file()))
 674            .line(Some(caller.line()))
 675            .level(level)
 676            .build(),
 677    );
 678}
 679
 680pub fn log_err<E: std::fmt::Debug>(error: &E) {
 681    log_error_with_caller(*Location::caller(), error, log::Level::Warn);
 682}
 683
 684pub trait TryFutureExt {
 685    fn log_err(self) -> LogErrorFuture<Self>
 686    where
 687        Self: Sized;
 688
 689    fn log_tracked_err(self, location: core::panic::Location<'static>) -> LogErrorFuture<Self>
 690    where
 691        Self: Sized;
 692
 693    fn warn_on_err(self) -> LogErrorFuture<Self>
 694    where
 695        Self: Sized;
 696    fn unwrap(self) -> UnwrapFuture<Self>
 697    where
 698        Self: Sized;
 699}
 700
 701impl<F, T, E> TryFutureExt for F
 702where
 703    F: Future<Output = Result<T, E>>,
 704    E: std::fmt::Debug,
 705{
 706    #[track_caller]
 707    fn log_err(self) -> LogErrorFuture<Self>
 708    where
 709        Self: Sized,
 710    {
 711        let location = Location::caller();
 712        LogErrorFuture(self, log::Level::Error, *location)
 713    }
 714
 715    fn log_tracked_err(self, location: core::panic::Location<'static>) -> LogErrorFuture<Self>
 716    where
 717        Self: Sized,
 718    {
 719        LogErrorFuture(self, log::Level::Error, location)
 720    }
 721
 722    #[track_caller]
 723    fn warn_on_err(self) -> LogErrorFuture<Self>
 724    where
 725        Self: Sized,
 726    {
 727        let location = Location::caller();
 728        LogErrorFuture(self, log::Level::Warn, *location)
 729    }
 730
 731    fn unwrap(self) -> UnwrapFuture<Self>
 732    where
 733        Self: Sized,
 734    {
 735        UnwrapFuture(self)
 736    }
 737}
 738
 739#[must_use]
 740pub struct LogErrorFuture<F>(F, log::Level, core::panic::Location<'static>);
 741
 742impl<F, T, E> Future for LogErrorFuture<F>
 743where
 744    F: Future<Output = Result<T, E>>,
 745    E: std::fmt::Debug,
 746{
 747    type Output = Option<T>;
 748
 749    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 750        let level = self.1;
 751        let location = self.2;
 752        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
 753        match inner.poll(cx) {
 754            Poll::Ready(output) => Poll::Ready(match output {
 755                Ok(output) => Some(output),
 756                Err(error) => {
 757                    log_error_with_caller(location, error, level);
 758                    None
 759                }
 760            }),
 761            Poll::Pending => Poll::Pending,
 762        }
 763    }
 764}
 765
 766pub struct UnwrapFuture<F>(F);
 767
 768impl<F, T, E> Future for UnwrapFuture<F>
 769where
 770    F: Future<Output = Result<T, E>>,
 771    E: std::fmt::Debug,
 772{
 773    type Output = T;
 774
 775    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 776        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
 777        match inner.poll(cx) {
 778            Poll::Ready(result) => Poll::Ready(result.unwrap()),
 779            Poll::Pending => Poll::Pending,
 780        }
 781    }
 782}
 783
 784pub struct Deferred<F: FnOnce()>(Option<F>);
 785
 786impl<F: FnOnce()> Deferred<F> {
 787    /// Drop without running the deferred function.
 788    pub fn abort(mut self) {
 789        self.0.take();
 790    }
 791}
 792
 793impl<F: FnOnce()> Drop for Deferred<F> {
 794    fn drop(&mut self) {
 795        if let Some(f) = self.0.take() {
 796            f()
 797        }
 798    }
 799}
 800
 801/// Run the given function when the returned value is dropped (unless it's cancelled).
 802#[must_use]
 803pub fn defer<F: FnOnce()>(f: F) -> Deferred<F> {
 804    Deferred(Some(f))
 805}
 806
#[cfg(any(test, feature = "test-support"))]
mod rng {
    use rand::{Rng, seq::SliceRandom};
    /// An endless iterator of random characters for randomized tests (`next` always returns
    /// `Some`).
    ///
    /// By default it mixes whitespace, two-byte Greek letters, three-byte symbols, four-byte
    /// emoji and ASCII lowercase letters, so generated text exercises multi-byte UTF-8
    /// handling. In "simple text" mode it only yields `'\n'` and ASCII lowercase letters.
    ///
    /// Note: the exact sequence of RNG calls determines the output for a given seed, so
    /// reordering the sampling below changes every seeded test that uses this iterator.
    pub struct RandomCharIter<T: Rng> {
        rng: T,
        simple_text: bool,
    }

    impl<T: Rng> RandomCharIter<T> {
        /// Creates the iterator. Simple-text mode is enabled when the `SIMPLE_TEXT`
        /// environment variable is set to a non-empty value.
        pub fn new(rng: T) -> Self {
            Self {
                rng,
                simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
            }
        }

        /// Forces simple-text mode (only `'\n'` and `'a'..='z'`).
        pub fn with_simple_text(mut self) -> Self {
            self.simple_text = true;
            self
        }
    }

    impl<T: Rng> Iterator for RandomCharIter<T> {
        type Item = char;

        fn next(&mut self) -> Option<Self::Item> {
            if self.simple_text {
                // ~5% newlines, otherwise a lowercase ASCII letter ('a'..='z').
                return if self.rng.gen_range(0..100) < 5 {
                    Some('\n')
                } else {
                    Some(self.rng.gen_range(b'a'..b'z' + 1).into())
                };
            }

            match self.rng.gen_range(0..100) {
                // whitespace (~20%)
                0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
                // two-byte greek letters, 'α'..='ω' (~13%)
                20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
                // three-byte characters (~13%)
                33..=45 => ['✋', '✅', '❌', '❎', '⭐']
                    .choose(&mut self.rng)
                    .copied(),
                // four-byte characters (~13%)
                46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
                // ascii letters, 'a'..='z' (remaining ~41%)
                _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
            }
        }
    }
}
 858#[cfg(any(test, feature = "test-support"))]
 859pub use rng::RandomCharIter;
 860/// Get an embedded file as a string.
 861pub fn asset_str<A: rust_embed::RustEmbed>(path: &str) -> Cow<'static, str> {
 862    match A::get(path).expect(path).data {
 863        Cow::Borrowed(bytes) => Cow::Borrowed(std::str::from_utf8(bytes).unwrap()),
 864        Cow::Owned(bytes) => Cow::Owned(String::from_utf8(bytes).unwrap()),
 865    }
 866}
 867
/// Expands to an immediately-invoked closure wrapping the given block, so the `?` operator can be
/// used inside it even in functions that do not return an `Option` or `Result`. The block's value
/// (typically an `Option` or `Result`) becomes the value of the macro.
///
/// Accepts a normal block, an async block, or an async move block. The async forms evaluate to a
/// future (the closure returns `async { .. }`), which still has to be awaited.
///
/// Because the block runs inside a closure, a `return` in it returns from the closure, not from
/// the enclosing function.
#[macro_export]
macro_rules! maybe {
    ($block:block) => {
        (|| $block)()
    };
    (async $block:block) => {
        (|| async $block)()
    };
    (async move $block:block) => {
        (|| async move $block)()
    };
}
 884
 885pub trait RangeExt<T> {
 886    fn sorted(&self) -> Self;
 887    fn to_inclusive(&self) -> RangeInclusive<T>;
 888    fn overlaps(&self, other: &Range<T>) -> bool;
 889    fn contains_inclusive(&self, other: &Range<T>) -> bool;
 890}
 891
 892impl<T: Ord + Clone> RangeExt<T> for Range<T> {
 893    fn sorted(&self) -> Self {
 894        cmp::min(&self.start, &self.end).clone()..cmp::max(&self.start, &self.end).clone()
 895    }
 896
 897    fn to_inclusive(&self) -> RangeInclusive<T> {
 898        self.start.clone()..=self.end.clone()
 899    }
 900
 901    fn overlaps(&self, other: &Range<T>) -> bool {
 902        self.start < other.end && other.start < self.end
 903    }
 904
 905    fn contains_inclusive(&self, other: &Range<T>) -> bool {
 906        self.start <= other.start && other.end <= self.end
 907    }
 908}
 909
 910impl<T: Ord + Clone> RangeExt<T> for RangeInclusive<T> {
 911    fn sorted(&self) -> Self {
 912        cmp::min(self.start(), self.end()).clone()..=cmp::max(self.start(), self.end()).clone()
 913    }
 914
 915    fn to_inclusive(&self) -> RangeInclusive<T> {
 916        self.clone()
 917    }
 918
 919    fn overlaps(&self, other: &Range<T>) -> bool {
 920        self.start() < &other.end && &other.start <= self.end()
 921    }
 922
 923    fn contains_inclusive(&self, other: &Range<T>) -> bool {
 924        self.start() <= &other.start && &other.end <= self.end()
 925    }
 926}
 927
 928/// A way to sort strings with starting numbers numerically first, falling back to alphanumeric one,
 929/// case-insensitive.
 930///
 931/// This is useful for turning regular alphanumerically sorted sequences as `1-abc, 10, 11-def, .., 2, 21-abc`
 932/// into `1-abc, 2, 10, 11-def, .., 21-abc`
 933#[derive(Debug, PartialEq, Eq)]
 934pub struct NumericPrefixWithSuffix<'a>(Option<u64>, &'a str);
 935
 936impl<'a> NumericPrefixWithSuffix<'a> {
 937    pub fn from_numeric_prefixed_str(str: &'a str) -> Self {
 938        let i = str.chars().take_while(|c| c.is_ascii_digit()).count();
 939        let (prefix, remainder) = str.split_at(i);
 940
 941        let prefix = prefix.parse().ok();
 942        Self(prefix, remainder)
 943    }
 944}
 945
 946/// When dealing with equality, we need to consider the case of the strings to achieve strict equality
 947/// to handle cases like "a" < "A" instead of "a" == "A".
 948impl Ord for NumericPrefixWithSuffix<'_> {
 949    fn cmp(&self, other: &Self) -> Ordering {
 950        match (self.0, other.0) {
 951            (None, None) => UniCase::new(self.1)
 952                .cmp(&UniCase::new(other.1))
 953                .then_with(|| self.1.cmp(other.1).reverse()),
 954            (None, Some(_)) => Ordering::Greater,
 955            (Some(_), None) => Ordering::Less,
 956            (Some(a), Some(b)) => a.cmp(&b).then_with(|| {
 957                UniCase::new(self.1)
 958                    .cmp(&UniCase::new(other.1))
 959                    .then_with(|| self.1.cmp(other.1).reverse())
 960            }),
 961        }
 962    }
 963}
 964
// Delegates to the `Ord` implementation, so this always returns `Some`.
impl PartialOrd for NumericPrefixWithSuffix<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
 970
 971/// Capitalizes the first character of a string.
 972///
 973/// This function takes a string slice as input and returns a new `String` with the first character
 974/// capitalized.
 975///
 976/// # Examples
 977///
 978/// ```
 979/// use util::capitalize;
 980///
 981/// assert_eq!(capitalize("hello"), "Hello");
 982/// assert_eq!(capitalize("WORLD"), "WORLD");
 983/// assert_eq!(capitalize(""), "");
 984/// ```
 985pub fn capitalize(str: &str) -> String {
 986    let mut chars = str.chars();
 987    match chars.next() {
 988        None => String::new(),
 989        Some(first_char) => first_char.to_uppercase().collect::<String>() + chars.as_str(),
 990    }
 991}
 992
 993fn emoji_regex() -> &'static Regex {
 994    static EMOJI_REGEX: LazyLock<Regex> =
 995        LazyLock::new(|| Regex::new("(\\p{Emoji}|\u{200D})").unwrap());
 996    &EMOJI_REGEX
 997}
 998
 999/// Returns true if the given string consists of emojis only.
1000/// E.g. "👨‍👩‍👧‍👧👋" will return true, but "👋!" will return false.
1001pub fn word_consists_of_emojis(s: &str) -> bool {
1002    let mut prev_end = 0;
1003    for capture in emoji_regex().find_iter(s) {
1004        if capture.start() != prev_end {
1005            return false;
1006        }
1007        prev_end = capture.end();
1008    }
1009    prev_end == s.len()
1010}
1011
/// Shorthand for [`Default::default`], with the target type inferred from context.
pub fn default<D: Default>() -> D {
    D::default()
}
1015
/// Returns the system shell program.
///
/// On Windows this defers to `get_windows_system_shell()`. On other platforms it reads the
/// `SHELL` environment variable, falling back to `/bin/sh` when the variable is unset, is not
/// valid Unicode, or is empty.
pub fn get_system_shell() -> String {
    #[cfg(target_os = "windows")]
    {
        get_windows_system_shell()
    }

    #[cfg(not(target_os = "windows"))]
    {
        std::env::var("SHELL")
            .ok()
            // An empty `SHELL` names no program; treat it like an unset variable.
            .filter(|shell| !shell.is_empty())
            .unwrap_or_else(|| "/bin/sh".to_string())
    }
}
1027
/// Outcome of a request made over a connection: either a connection-level failure (timeout
/// or reset) or the result produced for the request.
#[derive(Debug)]
pub enum ConnectionResult<O> {
    /// The request timed out.
    Timeout,
    /// The server reset the connection.
    ConnectionReset,
    /// The result produced for the request, which may itself be an error.
    Result(anyhow::Result<O>),
}
1034
1035impl<O> ConnectionResult<O> {
1036    pub fn into_response(self) -> anyhow::Result<O> {
1037        match self {
1038            ConnectionResult::Timeout => anyhow::bail!("Request timed out"),
1039            ConnectionResult::ConnectionReset => anyhow::bail!("Server reset the connection"),
1040            ConnectionResult::Result(r) => r,
1041        }
1042    }
1043}
1044
1045impl<O> From<anyhow::Result<O>> for ConnectionResult<O> {
1046    fn from(result: anyhow::Result<O>) -> Self {
1047        ConnectionResult::Result(result)
1048    }
1049}
1050
// Unit tests for helpers brought in from the parent module via `use super::*`.
#[cfg(test)]
mod tests {
    use super::*;

    // `extend_sorted` merges items into an already-sorted vec (descending here, per the
    // comparator) and keeps at most `limit` elements; duplicates are not added twice.
    #[test]
    fn test_extend_sorted() {
        let mut vec = vec![];

        extend_sorted(&mut vec, vec![21, 17, 13, 8, 1, 0], 5, |a, b| b.cmp(a));
        assert_eq!(vec, &[21, 17, 13, 8, 1]);

        extend_sorted(&mut vec, vec![101, 19, 17, 8, 2], 8, |a, b| b.cmp(a));
        assert_eq!(vec, &[101, 21, 19, 17, 13, 8, 2, 1]);

        extend_sorted(&mut vec, vec![1000, 19, 17, 9, 5], 8, |a, b| b.cmp(a));
        assert_eq!(vec, &[1000, 101, 21, 19, 17, 13, 9, 8]);
    }

    // Keeps the `n` smallest elements in ascending order (all of them when `n >= len`).
    #[test]
    fn test_truncate_to_bottom_n_sorted_by() {
        let mut vec: Vec<u32> = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 10, &u32::cmp);
        assert_eq!(vec, &[1, 2, 3, 4, 5]);

        vec = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 5, &u32::cmp);
        assert_eq!(vec, &[1, 2, 3, 4, 5]);

        vec = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 4, &u32::cmp);
        assert_eq!(vec, &[1, 2, 3, 4]);

        vec = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 1, &u32::cmp);
        assert_eq!(vec, &[1]);

        vec = vec![5, 2, 3, 4, 1];
        truncate_to_bottom_n_sorted_by(&mut vec, 0, &u32::cmp);
        assert!(vec.is_empty());
    }

    // `maybe!` lets `?` short-circuit out of the block, yielding `None` here.
    #[test]
    fn test_iife() {
        fn option_returning_function() -> Option<()> {
            None
        }

        let foo = maybe!({
            option_returning_function()?;
            Some(())
        });

        assert_eq!(foo, None);
    }

    // Limits count characters, not bytes ("è" is two bytes); "…" is appended only when the
    // input is longer than the limit.
    #[test]
    fn test_truncate_and_trailoff() {
        assert_eq!(truncate_and_trailoff("", 5), "");
        assert_eq!(truncate_and_trailoff("aaaaaa", 7), "aaaaaa");
        assert_eq!(truncate_and_trailoff("aaaaaa", 6), "aaaaaa");
        assert_eq!(truncate_and_trailoff("aaaaaa", 5), "aaaaa…");
        assert_eq!(truncate_and_trailoff("èèèèèè", 7), "èèèèèè");
        assert_eq!(truncate_and_trailoff("èèèèèè", 6), "èèèèèè");
        assert_eq!(truncate_and_trailoff("èèèèèè", 5), "èèèèè…");
    }

    // Same as above, but characters are removed from the front and "…" is prepended.
    #[test]
    fn test_truncate_and_remove_front() {
        assert_eq!(truncate_and_remove_front("", 5), "");
        assert_eq!(truncate_and_remove_front("aaaaaa", 7), "aaaaaa");
        assert_eq!(truncate_and_remove_front("aaaaaa", 6), "aaaaaa");
        assert_eq!(truncate_and_remove_front("aaaaaa", 5), "…aaaaa");
        assert_eq!(truncate_and_remove_front("èèèèèè", 7), "èèèèèè");
        assert_eq!(truncate_and_remove_front("èèèèèè", 6), "èèèèèè");
        assert_eq!(truncate_and_remove_front("èèèèèè", 5), "…èèèèè");
    }

    // Only the leading run of ASCII digits is parsed; everything after it, including later
    // digits, stays in the suffix.
    #[test]
    fn test_numeric_prefix_str_method() {
        let target = "1a";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(1), "a")
        );

        let target = "12ab";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(12), "ab")
        );

        let target = "12_ab";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(12), "_ab")
        );

        let target = "1_2ab";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(1), "_2ab")
        );

        let target = "1.2";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(1), ".2")
        );

        let target = "1.2_a";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(1), ".2_a")
        );

        let target = "12.2_a";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(12), ".2_a")
        );

        let target = "12a.2_a";
        assert_eq!(
            NumericPrefixWithSuffix::from_numeric_prefixed_str(target),
            NumericPrefixWithSuffix(Some(12), "a.2_a")
        );
    }

    // Numeric prefixes sort by value rather than lexically; strings without a numeric
    // prefix keep a `None` prefix and their full text as the suffix.
    #[test]
    fn test_numeric_prefix_with_suffix() {
        let mut sorted = vec!["1-abc", "10", "11def", "2", "21-abc"];
        sorted.sort_by_key(|s| NumericPrefixWithSuffix::from_numeric_prefixed_str(s));
        assert_eq!(sorted, ["1-abc", "2", "10", "11def", "21-abc"]);

        for numeric_prefix_less in ["numeric_prefix_less", "aaa", "~™£"] {
            assert_eq!(
                NumericPrefixWithSuffix::from_numeric_prefixed_str(numeric_prefix_less),
                NumericPrefixWithSuffix(None, numeric_prefix_less),
                "String without numeric prefix `{numeric_prefix_less}` should not be converted into NumericPrefixWithSuffix"
            )
        }
    }

    // ZWJ sequences count as emoji; any other character, including whitespace, fails.
    #[test]
    fn test_word_consists_of_emojis() {
        let words_to_test = vec![
            ("👨‍👩‍👧‍👧👋🥒", true),
            ("👋", true),
            ("!👋", false),
            ("👋!", false),
            ("👋 ", false),
            (" 👋", false),
            ("Test", false),
        ];

        for (text, expected_result) in words_to_test {
            assert_eq!(word_consists_of_emojis(text), expected_result);
        }
    }

    // The trailing "…" line counts toward the line limit: text with at least `limit` lines
    // keeps `limit - 1` of them followed by "…".
    #[test]
    fn test_truncate_lines_and_trailoff() {
        let text = r#"Line 1
Line 2
Line 3"#;

        assert_eq!(
            truncate_lines_and_trailoff(text, 2),
            r#"Line 1
…"#
        );

        assert_eq!(
            truncate_lines_and_trailoff(text, 3),
            r#"Line 1
Line 2
…"#
        );

        assert_eq!(
            truncate_lines_and_trailoff(text, 4),
            r#"Line 1
Line 2
Line 3"#
        );
    }

    // Expands a range by (before, after) items, wrapping modulo the total length (8 here);
    // the result never repeats an index, even when the expansion exceeds the length.
    #[test]
    fn test_iterate_expanded_and_wrapped_usize_range() {
        // Neither wrap
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(2..4, 1, 1, 8).collect::<Vec<usize>>(),
            (1..5).collect::<Vec<usize>>()
        );
        // Start wraps
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(2..4, 3, 1, 8).collect::<Vec<usize>>(),
            ((0..5).chain(7..8)).collect::<Vec<usize>>()
        );
        // Start wraps all the way around
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(2..4, 5, 1, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
        // Start wraps all the way around and past 0
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(2..4, 10, 1, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
        // End wraps
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(3..5, 1, 4, 8).collect::<Vec<usize>>(),
            (0..1).chain(2..8).collect::<Vec<usize>>()
        );
        // End wraps all the way around
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(3..5, 1, 5, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
        // End wraps all the way around and past the end
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(3..5, 1, 10, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
        // Both start and end wrap
        assert_eq!(
            iterate_expanded_and_wrapped_usize_range(3..5, 4, 4, 8).collect::<Vec<usize>>(),
            (0..8).collect::<Vec<usize>>()
        );
    }
}