etw_tracing.rs

  1#![cfg(target_os = "windows")]
  2
  3use anyhow::{Context as _, Result, bail};
  4use gpui::{App, AppContext as _, DismissEvent, Global, actions};
  5use std::fmt::Write as _;
  6use std::io::{BufRead, BufReader, Write};
  7use std::path::{Path, PathBuf};
  8use std::time::Duration;
  9use util::{ResultExt as _, defer};
 10use windows::Win32::Foundation::{VARIANT_BOOL, VARIANT_FALSE};
 11use windows::Win32::System::Com::{CLSCTX_INPROC_SERVER, COINIT_MULTITHREADED, CoInitializeEx};
 12use windows_core::{BSTR, Interface};
 13use workspace::notifications::simple_message_notification::MessageNotification;
 14use workspace::notifications::{NotificationId, show_app_notification};
 15use wprcontrol::*;
 16
// Actions for controlling ETW recordings; their handlers are registered in `init`.
actions!(
    zed,
    [
        /// Starts recording an ETW (Event Tracing for Windows) trace.
        RecordEtwTrace,
        /// Starts recording an ETW (Event Tracing for Windows) trace with heap tracing.
        RecordEtwTraceWithHeapTracing,
        /// Saves an in-progress ETW trace to disk.
        SaveEtwTrace,
        /// Cancels an in-progress ETW trace without saving.
        CancelEtwTrace,
    ]
);
 30
/// Marker type from which every ETW notification derives the same `NotificationId`.
///
/// NOTE(review): presumably this makes a newer ETW notification replace an older one; confirm
/// against `workspace::notifications`.
struct EtwNotification;
 32
 33struct EtwSessionHandle {
 34    writer: net::OwnedWriteHalf,
 35    _listener: net::UnixListener,
 36    socket_path: PathBuf,
 37}
 38
 39impl Drop for EtwSessionHandle {
 40    fn drop(&mut self) {
 41        let _ = std::fs::remove_file(&self.socket_path);
 42    }
 43}
 44
/// App-wide slot holding the in-progress ETW recording, if any.
///
/// It is filled once the elevated recorder has reported that it started, and reset to `None` when
/// the recorder sends its final status.
struct GlobalEtwSession(Option<EtwSessionHandle>);

impl Global for GlobalEtwSession {}
 48
 49fn has_active_etw_session(cx: &App) -> bool {
 50    cx.global::<GlobalEtwSession>().0.is_some()
 51}
 52
 53fn show_etw_notification(cx: &mut App, message: impl Into<gpui::SharedString>) {
 54    let message = message.into();
 55    show_app_notification(NotificationId::unique::<EtwNotification>(), cx, move |cx| {
 56        cx.new(|cx| MessageNotification::new(message.clone(), cx))
 57    });
 58}
 59
 60fn show_etw_notification_with_action(
 61    cx: &mut App,
 62    message: impl Into<gpui::SharedString>,
 63    button_label: impl Into<gpui::SharedString>,
 64    on_click: impl Fn(&mut gpui::Window, &mut gpui::Context<MessageNotification>)
 65    + Send
 66    + Sync
 67    + 'static,
 68) {
 69    let message = message.into();
 70    let button_label = button_label.into();
 71    let on_click = std::sync::Arc::new(on_click);
 72    show_app_notification(NotificationId::unique::<EtwNotification>(), cx, move |cx| {
 73        let message = message.clone();
 74        let button_label = button_label.clone();
 75        cx.new(|cx| {
 76            MessageNotification::new(message, cx)
 77                .primary_message(button_label)
 78                .primary_on_click_arc(on_click.clone())
 79        })
 80    });
 81}
 82
 83fn show_etw_status_notification(cx: &mut App, status: Result<StatusMessage>, output_path: PathBuf) {
 84    match status {
 85        Ok(StatusMessage::Stopped) => {
 86            let display_path = output_path.display().to_string();
 87            show_etw_notification_with_action(
 88                cx,
 89                format!("ETW trace saved to {display_path}"),
 90                "Show in File Manager",
 91                move |_window, cx| {
 92                    cx.reveal_path(&output_path);
 93                    cx.emit(DismissEvent);
 94                },
 95            );
 96        }
 97        Ok(StatusMessage::TimedOut) => {
 98            let display_path = output_path.display().to_string();
 99            show_etw_notification_with_action(
100                cx,
101                format!("ETW recording timed out. Trace saved to {display_path}"),
102                "Show in File Manager",
103                move |_window, cx| {
104                    cx.reveal_path(&output_path);
105                    cx.emit(DismissEvent);
106                },
107            );
108        }
109        Ok(StatusMessage::Cancelled) => {
110            show_etw_notification(cx, "ETW recording cancelled");
111        }
112        Ok(_) => {
113            show_etw_notification(cx, "ETW recording ended unexpectedly");
114        }
115        Err(error) => {
116            show_etw_notification(cx, format!("Failed to complete ETW recording: {error:#}"));
117        }
118    }
119}
120
121pub fn init(cx: &mut App) {
122    cx.set_global(GlobalEtwSession(None));
123
124    cx.on_action(|_: &RecordEtwTrace, cx: &mut App| {
125        start_etw_recording(cx, None);
126    });
127
128    cx.on_action(|_: &RecordEtwTraceWithHeapTracing, cx: &mut App| {
129        start_etw_recording(cx, Some(std::process::id()));
130    });
131
132    cx.on_action(|_: &SaveEtwTrace, cx: &mut App| {
133        let session = cx.global_mut::<GlobalEtwSession>().0.as_mut();
134        let Some(session) = session else {
135            show_etw_notification(cx, "No active ETW recording to stop");
136            return;
137        };
138        match send_json(&mut session.writer, &Command::Save) {
139            Ok(()) => {
140                show_etw_notification(cx, "Stopping ETW recording...");
141            }
142            Err(error) => {
143                show_etw_notification(cx, format!("Failed to stop ETW recording: {error:#}"));
144            }
145        }
146    });
147
148    cx.on_action(|_: &CancelEtwTrace, cx: &mut App| {
149        let session = cx.global_mut::<GlobalEtwSession>().0.as_mut();
150        let Some(session) = session else {
151            show_etw_notification(cx, "No active ETW recording to cancel");
152            return;
153        };
154        match send_json(&mut session.writer, &Command::Cancel) {
155            Ok(()) => {
156                show_etw_notification(cx, "Cancelling ETW recording...");
157            }
158            Err(error) => {
159                show_etw_notification(cx, format!("Failed to cancel ETW recording: {error:#}"));
160            }
161        }
162    });
163}
164
/// Starts an ETW recording.
///
/// Asks the user for an output path, then launches the elevated recorder on a background thread.
/// Once the recorder reports that it started, the session is stored in [`GlobalEtwSession`].
/// `heap_pid`, when `Some`, is forwarded to the recorder to enable heap tracing for that process.
fn start_etw_recording(cx: &mut App, heap_pid: Option<u32>) {
    // NOTE(review): the session slot is only filled by the final `cx.update` below, i.e. after the
    // save dialog and the launch have both finished. A second invocation in that window passes this
    // check, and both launches would share the same PID-derived socket path (see
    // `launch_etw_recording`) — confirm whether concurrent starts should be rejected.
    if has_active_etw_session(cx) {
        show_etw_notification(cx, "ETW recording is already in progress");
        return;
    }
    let save_dialog = cx.prompt_for_new_path(&PathBuf::default(), Some("zed-trace.etl"));
    cx.spawn(async move |cx| {
        let output_path = match save_dialog.await {
            Ok(Ok(Some(path))) => path,
            // No path chosen (presumably the dialog was dismissed): nothing to do.
            Ok(Ok(None)) => return,
            Ok(Err(error)) => {
                cx.update(|cx| {
                    show_etw_notification(cx, format!("Failed to pick save location: {error:#}"));
                });
                return;
            }
            // NOTE(review): presumably the dialog's response channel was dropped; ignored silently.
            Err(_) => return,
        };

        // `launch_etw_recording` blocks until the subprocess connects and reports `Started`, so it
        // runs off the main thread.
        let result = cx
            .background_spawn(async move { launch_etw_recording(heap_pid, &output_path) })
            .await;

        let EtwSession {
            output_path,
            stream,
            listener,
            socket_path,
        } = match result {
            Ok(session) => session,
            Err(error) => {
                cx.update(|cx| {
                    show_etw_notification(cx, format!("Failed to start ETW recording: {error:#}"));
                });
                return;
            }
        };

        // The read half waits for the recorder's final status; the write half is kept in the
        // session so the Save/Cancel actions can send commands.
        let (read_half, write_half) = stream.into_inner().into_split();

        // Wait for the final status, then clear the session and report the outcome.
        cx.spawn(async |cx| {
            let status = cx
                .background_spawn(async move {
                    recv_json(&mut BufReader::new(read_half))
                        .context("Receive status from subprocess")
                })
                .await;
            cx.update(|cx| {
                cx.global_mut::<GlobalEtwSession>().0 = None;
                show_etw_status_notification(cx, status, output_path);
            });
        })
        .detach();

        cx.update(|cx| {
            cx.global_mut::<GlobalEtwSession>().0 = Some(EtwSessionHandle {
                writer: write_half,
                _listener: listener,
                socket_path,
            });
            show_etw_notification(cx, "ETW recording started");
        });
    })
    .detach();
}
230
/// How long the elevated recorder waits for a Save/Cancel command before saving the trace on its
/// own and reporting it as timed out.
const RECORDING_TIMEOUT: Duration = Duration::from_secs(60);

/// WPR instance name under which every WPR control object in this module is created.
const INSTANCE_NAME: &str = "Zed";

/// Built-in WPR profile names loaded into every recording's profile collection.
const BUILTIN_PROFILES: &[&str] = &[
    "CPU.Verbose.Memory",
    "GPU.Light.Memory",
    "DiskIO.Light.Memory",
    "FileIO.Light.Memory",
];
241
/// Builds the WPR profile XML for the custom `ZedHeap.Verbose.Memory` profile.
///
/// With `Some(pid)`, the XML also declares a `ZedHeapProvider` heap event provider restricted to
/// that process and attaches it to the WPR heap collector. With `None`, both fragments are empty.
fn heap_tracing_profile(heap_pid: Option<u32>) -> String {
    const HEAP_COLLECTOR: &str = r#"
      <Collectors Operation="Add">
        <HeapEventCollectorId Value="HeapCollector_WPRHeapCollector">
          <HeapEventProviders Operation="Set">
            <HeapEventProviderId Value="ZedHeapProvider"/>
          </HeapEventProviders>
        </HeapEventCollectorId>
      </Collectors>"#;

    let heap_provider = heap_pid
        .map(|pid| {
            format!(
                r#"
    <HeapEventProvider Id="ZedHeapProvider">
      <HeapProcessIds Operation="Set">
        <HeapProcessId Value="{pid}"/>
      </HeapProcessIds>
    </HeapEventProvider>"#
            )
        })
        .unwrap_or_default();
    let heap_collector = if heap_pid.is_some() { HEAP_COLLECTOR } else { "" };

    format!(
        r#"<?xml version="1.0" encoding="utf-8"?>
<WindowsPerformanceRecorder Version="1.0" Author="Zed Industries">
  <Profiles>
    {heap_provider}

    <Profile Id="ZedHeap.Verbose.Memory" Base="Heap.Verbose.Memory" Name="ZedHeap" DetailLevel="Verbose" LoggingMode="Memory" Description="Heap tracing">
      {heap_collector}
    </Profile>
  </Profiles>

  <TraceMergeProperties>
    <TraceMergeProperty Id="TraceMerge_Default" Name="TraceMerge_Default">
      <FileCompression Value="true"/>
    </TraceMergeProperty>
  </TraceMergeProperties>
</WindowsPerformanceRecorder>"#
    )
}
285
/// Renders a WPR failure as a multi-line diagnostic string.
///
/// The first line is `HRESULT: {hresult}`. Indented lines are then appended, when available, from:
/// - `WPRCFormatError` (message, description, detail),
/// - `IParsingErrorInfo` on `source` (line/column, element type and id, description),
/// - the `IControlErrorInfo` chain on `source`, following inner errors recursively,
/// - `IErrorInfo` on `source` (description).
///
/// Empty values are skipped and lookups that fail are ignored, so this always returns a string.
fn wpr_error_context(hresult: windows_core::HRESULT, source: &windows_core::IUnknown) -> String {
    let mut out = format!("HRESULT: {hresult}");

    // SAFETY: `source` is a live interface reference and the out-parameters are local BSTRs.
    unsafe {
        let mut message = BSTR::new();
        let mut description = BSTR::new();
        let mut detail = BSTR::new();
        if WPRCFormatError(
            hresult,
            Some(source),
            &mut message,
            Some(&mut description),
            Some(&mut detail),
        )
        .is_ok()
        {
            for (label, value) in [
                ("Message", &message),
                ("Description", &description),
                ("Detail", &detail),
            ] {
                if !value.is_empty() {
                    // Writing into a `String` cannot fail, so the result is ignored.
                    let _ = write!(out, "\n  {label}: {value}");
                }
            }
        }
    }

    // Presumably only available when the failure came from parsing profile XML.
    if let Ok(info) = source.cast::<IParsingErrorInfo>() {
        unsafe {
            if let Ok(line) = info.GetLineNumber() {
                let _ = write!(out, "\n  Parse error at line: {line}");
                if let Ok(col) = info.GetColumnNumber() {
                    let _ = write!(out, ", column: {col}");
                }
            }
            for (label, getter) in [
                ("Element type", info.GetElementType()),
                ("Element ID", info.GetElementId()),
                ("Description", info.GetDescription()),
            ] {
                if let Ok(value) = getter
                    && !value.is_empty()
                {
                    let _ = write!(out, "\n  {label}: {value}");
                }
            }
        }
    }

    // Appends the details of one `IControlErrorInfo`, then recurses into its inner error (if
    // any) under a "Caused by:" line.
    fn append_control_chain(out: &mut String, source: &windows_core::IUnknown) {
        let Ok(info) = source.cast::<IControlErrorInfo>() else {
            return;
        };
        unsafe {
            if let Ok(object_type) = info.GetObjectType() {
                let name = match object_type {
                    wprcontrol::ObjectType_Profile => "Profile",
                    wprcontrol::ObjectType_Collector => "Collector",
                    wprcontrol::ObjectType_Provider => "Provider",
                    _ => "Unknown",
                };
                let _ = write!(out, "\n  Object type: {name}");
            }
            if let Ok(hr) = info.GetHResult() {
                let _ = write!(out, "\n  Inner HRESULT: {hr}");
            }
            if let Ok(desc) = info.GetDescription()
                && !desc.is_empty()
            {
                let _ = write!(out, "\n  Description: {desc}");
            }
            let mut inner = None;
            if info.GetInnerErrorInfo(&mut inner).is_ok()
                && let Some(inner) = inner
            {
                let _ = write!(out, "\n  Caused by:");
                append_control_chain(out, &inner);
            }
        }
    }
    append_control_chain(&mut out, source);

    // Generic COM error information, if `source` exposes it.
    if let Ok(info) = source.cast::<windows::Win32::System::Com::IErrorInfo>() {
        unsafe {
            if let Ok(desc) = info.GetDescription()
                && !desc.is_empty()
            {
                let _ = write!(out, "\n  IErrorInfo: {desc}");
            }
        }
    }

    out
}
381
382trait WprContext<T> {
383    fn wpr_context(self, source: &impl Interface) -> Result<T>;
384}
385
386impl<T> WprContext<T> for windows_core::Result<T> {
387    fn wpr_context(self, source: &impl Interface) -> Result<T> {
388        self.map_err(|e| {
389            let unknown: windows_core::IUnknown = source.cast().expect("cast to IUnknown");
390            let context = wpr_error_context(e.code(), &unknown);
391            anyhow::anyhow!("{context}")
392        })
393    }
394}
395
396fn create_wpr<T: windows_core::Interface>(clsid: &windows_core::GUID) -> Result<T> {
397    unsafe {
398        WPRCCreateInstanceUnderInstanceName::<_, T>(
399            &BSTR::from(INSTANCE_NAME),
400            clsid,
401            None,
402            CLSCTX_INPROC_SERVER.0,
403        )
404        .context("WPRCCreateInstance failed")
405    }
406}
407
408fn build_profile_collection(heap_pid: Option<u32>) -> Result<IProfileCollection> {
409    let collection: IProfileCollection = create_wpr(&CProfileCollection)?;
410
411    for profile_name in BUILTIN_PROFILES {
412        let profile: IProfile = create_wpr(&CProfile)?;
413        unsafe {
414            profile
415                .LoadFromFile(&BSTR::from(*profile_name), &BSTR::new())
416                .wpr_context(&profile)
417                .with_context(|| format!("Load built-in profile '{profile_name}'"))?;
418            collection
419                .Add(&profile, VARIANT_FALSE)
420                .wpr_context(&collection)
421                .with_context(|| format!("Add profile '{profile_name}' to collection"))?;
422        }
423    }
424
425    let heap_xml = heap_tracing_profile(heap_pid);
426    let heap_profile: IProfile = create_wpr(&CProfile)?;
427    unsafe {
428        heap_profile
429            .LoadFromString(&BSTR::from(heap_xml))
430            .wpr_context(&heap_profile)
431            .context("Load profile from XML string")?;
432        collection
433            .Add(&heap_profile, VARIANT_BOOL(0))
434            .wpr_context(&collection)
435            .context("Add ZedHeap profile to collection")?;
436    }
437
438    Ok(collection)
439}
440
441pub fn record_etw_trace(
442    heap_pid: Option<u32>,
443    output_path: &Path,
444    socket_path: &str,
445) -> Result<()> {
446    unsafe {
447        CoInitializeEx(None, COINIT_MULTITHREADED)
448            .ok()
449            .context("COM initialization failed")?;
450    }
451
452    let socket_path = Path::new(socket_path);
453    let mut stream = net::UnixStream::connect(socket_path).context("Connect to parent socket")?;
454
455    match record_etw_trace_inner(heap_pid, output_path, &mut stream) {
456        Ok(()) => Ok(()),
457        Err(e) => {
458            send_json(
459                &mut stream,
460                &StatusMessage::Error {
461                    message: format!("{e:#}"),
462                },
463            )
464            .log_err();
465            Err(e)
466        }
467    }
468}
469
/// Runs one recording session inside the elevated subprocess.
///
/// The session proceeds as follows:
/// 1. Starts WPR with the profiles from [`build_profile_collection`].
/// 2. Sends [`StatusMessage::Started`] over `stream`.
/// 3. Waits for a [`Command`] (see [`receive_command`] for the timeout behaviour).
/// 4. On `Save`, writes the trace to `output_path`; on `Cancel`, discards it.
///
/// The final status (`Stopped`, `TimedOut` or `Cancelled`) is sent best effort.
///
/// # Errors
/// Returns an error if building the profiles, creating the control manager, starting, saving or
/// cancelling WPR fails. It also fails if `Started` cannot be sent or the receive timeout cannot
/// be configured. On every error path after `Start` succeeds, the drop guard cancels the WPR
/// session.
fn record_etw_trace_inner(
    heap_pid: Option<u32>,
    output_path: &Path,
    stream: &mut net::UnixStream,
) -> Result<()> {
    let collection = build_profile_collection(heap_pid)?;
    let control_manager: IControlManager = create_wpr(&CControlManager)?;

    // Cancel any leftover sessions with the same name that might exist
    // (the result is ignored: presumably there is usually nothing to cancel).
    unsafe {
        _ = control_manager.Cancel(None);
    }

    unsafe {
        control_manager
            .Start(&collection)
            .wpr_context(&control_manager)
            .context("Start WPR recording")?;
    }

    // We must call Save or Cancel before returning or we'll leak the kernel buffers used to record the ETW session.
    // The guard is aborted only after an explicit Save/Cancel succeeds.
    let cancel_guard = defer({
        let control_manager = control_manager.clone();
        move || unsafe {
            let _ = control_manager.Cancel(None);
        }
    });

    send_json(stream, &StatusMessage::Started)?;

    // `timed_out` is true when no command could be read and `Save` was chosen as the fallback.
    let (command, timed_out) = receive_command(stream)?;

    match command {
        Command::Cancel => {
            unsafe {
                control_manager
                    .Cancel(None)
                    .wpr_context(&control_manager)
                    .context("Cancel WPR recording")?;
            }
            cancel_guard.abort();

            send_json(stream, &StatusMessage::Cancelled).log_err();
        }
        Command::Save => {
            unsafe {
                control_manager
                    .Save(
                        &BSTR::from(output_path.to_string_lossy().as_ref()),
                        &collection,
                        None,
                    )
                    .wpr_context(&control_manager)
                    .context("Stop WPR recording")?;
            }
            cancel_guard.abort();

            if timed_out {
                send_json(stream, &StatusMessage::TimedOut).log_err();
            } else {
                send_json(stream, &StatusMessage::Stopped).log_err();
            }
        }
    }

    Ok(())
}
537
538fn receive_command(stream: &mut net::UnixStream) -> Result<(Command, bool)> {
539    use std::os::windows::io::{AsRawSocket, AsSocket};
540    use windows::Win32::Networking::WinSock::{SO_RCVTIMEO, SOL_SOCKET, setsockopt};
541
542    // Set a receive timeout so read_line returns an error after `timeout`.
543    let millis = RECORDING_TIMEOUT.as_millis() as u32;
544    let socket = stream.as_socket();
545    let ret = unsafe {
546        setsockopt(
547            windows::Win32::Networking::WinSock::SOCKET(socket.as_raw_socket() as _),
548            SOL_SOCKET,
549            SO_RCVTIMEO,
550            Some(&millis.to_ne_bytes()),
551        )
552    };
553    if ret != 0 {
554        bail!("Failed to set socket receive timeout: setsockopt returned {ret}");
555    }
556
557    let mut reader = BufReader::new(&mut *stream);
558    match recv_json::<Command>(&mut reader) {
559        Ok(command) => Ok((command, false)),
560        Err(error) => {
561            log::warn!("Failed to receive ETW command, treating as timed-out Save: {error:#}");
562            Ok((Command::Save, true))
563        }
564    }
565}
566
/// A launched recording.
///
/// Returned by [`launch_etw_recording`] once the elevated subprocess has connected and reported
/// [`StatusMessage::Started`].
pub struct EtwSession {
    /// Where the trace is written when the recording is saved.
    output_path: PathBuf,
    /// Connection to the subprocess; it has already consumed the `Started` message.
    stream: BufReader<net::UnixStream>,
    /// Listener that accepted the subprocess connection.
    listener: net::UnixListener,
    /// Path of the socket file `listener` is bound to.
    socket_path: PathBuf,
}
573
574pub fn launch_etw_recording(heap_pid: Option<u32>, output_path: &Path) -> Result<EtwSession> {
575    let sock_path = std::env::temp_dir().join(format!("zed-etw-{}.sock", std::process::id()));
576
577    _ = std::fs::remove_file(&sock_path);
578    let listener = net::UnixListener::bind(&sock_path).context("Bind Unix socket for ETW IPC")?;
579
580    let exe_path = std::env::current_exe().context("Failed to get current exe path")?;
581    let pid_arg = heap_pid.map_or(-1i64, |pid| pid as i64);
582    let args = format!(
583        "--record-etw-trace --etw-zed-pid {} --etw-output \"{}\" --etw-socket \"{}\"",
584        pid_arg,
585        output_path.display(),
586        sock_path.display(),
587    );
588
589    use windows::Win32::UI::Shell::ShellExecuteW;
590    use windows_core::PCWSTR;
591
592    let operation: Vec<u16> = "runas\0".encode_utf16().collect();
593    let file: Vec<u16> = format!("{}\0", exe_path.to_string_lossy())
594        .encode_utf16()
595        .collect();
596    let parameters: Vec<u16> = format!("{args}\0").encode_utf16().collect();
597
598    let result = unsafe {
599        ShellExecuteW(
600            None,
601            PCWSTR(operation.as_ptr()),
602            PCWSTR(file.as_ptr()),
603            PCWSTR(parameters.as_ptr()),
604            PCWSTR::null(),
605            windows::Win32::UI::WindowsAndMessaging::SW_HIDE,
606        )
607    };
608
609    let result_code = result.0 as usize;
610    if result_code <= 32 {
611        bail!("ShellExecuteW failed to launch elevated process (code: {result_code})");
612    }
613
614    let (stream, _) = listener.accept().context("Accept subprocess connection")?;
615
616    let mut session = EtwSession {
617        output_path: output_path.to_path_buf(),
618        stream: BufReader::new(stream),
619        listener,
620        socket_path: sock_path,
621    };
622
623    let status: StatusMessage =
624        recv_json(&mut session.stream).context("Wait for Started status")?;
625
626    match status {
627        StatusMessage::Started => {}
628        StatusMessage::Error { message } => {
629            bail!("Subprocess reported error during start: {message}");
630        }
631        other => {
632            bail!("Unexpected status from subprocess: {other:?}");
633        }
634    }
635
636    Ok(session)
637}
638
/// Status messages sent from the elevated recorder subprocess to Zed.
///
/// Each message is one JSON object per line, internally tagged by `"type"`, e.g.
/// `{"type":"Started"}`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type")]
pub enum StatusMessage {
    /// WPR recording has started; sent once, right after `Start` succeeds.
    Started,
    /// The trace was saved after an explicit `Save` command.
    Stopped,
    /// No command could be read (timeout, closed socket or bad JSON), so the trace was saved anyway.
    TimedOut,
    /// The recording was cancelled and nothing was saved.
    Cancelled,
    /// The recorder failed; `message` holds the formatted error chain.
    Error { message: String },
}
648
/// Commands sent from Zed to the elevated recorder subprocess.
///
/// Each command is one JSON object per line, internally tagged by `"type"`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type")]
pub enum Command {
    /// Stop recording and write the trace to the chosen output path.
    Save,
    /// Stop recording and discard the trace.
    Cancel,
}
655
656fn send_json<T: serde::Serialize>(writer: &mut impl Write, value: &T) -> Result<()> {
657    let json = serde_json::to_string(value).context("Serialize message")?;
658    writeln!(writer, "{json}").context("Write to socket")?;
659    writer.flush().context("Flush socket")?;
660    Ok(())
661}
662
663fn recv_json<T: serde::de::DeserializeOwned>(reader: &mut impl BufRead) -> Result<T> {
664    let mut line = String::new();
665    reader.read_line(&mut line).context("Read from socket")?;
666    if line.is_empty() {
667        bail!("Socket closed before a message was received");
668    }
669    serde_json::from_str(line.trim()).context("Parse message")
670}