1#![cfg(target_os = "windows")]
2
3use anyhow::{Context as _, Result, bail};
4use gpui::{App, AppContext as _, DismissEvent, Global, actions};
5use std::fmt::Write as _;
6use std::io::{BufRead, BufReader, Write};
7use std::path::{Path, PathBuf};
8use std::time::Duration;
9use util::{ResultExt as _, defer};
10use windows::Win32::Foundation::{VARIANT_BOOL, VARIANT_FALSE};
11use windows::Win32::System::Com::{CLSCTX_INPROC_SERVER, COINIT_MULTITHREADED, CoInitializeEx};
12use windows_core::{BSTR, Interface};
13use workspace::notifications::simple_message_notification::MessageNotification;
14use workspace::notifications::{NotificationId, show_app_notification};
15use wprcontrol::*;
16
// Actions exposed by this module; their handlers are registered in `init`.
// The `///` lines below are part of the macro input and describe each action.
actions!(
    zed,
    [
        /// Starts recording an ETW (Event Tracing for Windows) trace.
        RecordEtwTrace,
        /// Starts recording an ETW (Event Tracing for Windows) trace with heap tracing.
        RecordEtwTraceWithHeapTracing,
        /// Saves an in-progress ETW trace to disk.
        SaveEtwTrace,
        /// Cancels an in-progress ETW trace without saving.
        CancelEtwTrace,
    ]
);
30
31struct EtwNotification;
32
33struct EtwSessionHandle {
34 writer: net::OwnedWriteHalf,
35 _listener: net::UnixListener,
36 socket_path: PathBuf,
37}
38
39impl Drop for EtwSessionHandle {
40 fn drop(&mut self) {
41 let _ = std::fs::remove_file(&self.socket_path);
42 }
43}
44
45struct GlobalEtwSession(Option<EtwSessionHandle>);
46
47impl Global for GlobalEtwSession {}
48
49fn has_active_etw_session(cx: &App) -> bool {
50 cx.global::<GlobalEtwSession>().0.is_some()
51}
52
53fn show_etw_notification(cx: &mut App, message: impl Into<gpui::SharedString>) {
54 let message = message.into();
55 show_app_notification(NotificationId::unique::<EtwNotification>(), cx, move |cx| {
56 cx.new(|cx| MessageNotification::new(message.clone(), cx))
57 });
58}
59
60fn show_etw_notification_with_action(
61 cx: &mut App,
62 message: impl Into<gpui::SharedString>,
63 button_label: impl Into<gpui::SharedString>,
64 on_click: impl Fn(&mut gpui::Window, &mut gpui::Context<MessageNotification>)
65 + Send
66 + Sync
67 + 'static,
68) {
69 let message = message.into();
70 let button_label = button_label.into();
71 let on_click = std::sync::Arc::new(on_click);
72 show_app_notification(NotificationId::unique::<EtwNotification>(), cx, move |cx| {
73 let message = message.clone();
74 let button_label = button_label.clone();
75 cx.new(|cx| {
76 MessageNotification::new(message, cx)
77 .primary_message(button_label)
78 .primary_on_click_arc(on_click.clone())
79 })
80 });
81}
82
83fn show_etw_status_notification(cx: &mut App, status: Result<StatusMessage>, output_path: PathBuf) {
84 match status {
85 Ok(StatusMessage::Stopped) => {
86 let display_path = output_path.display().to_string();
87 show_etw_notification_with_action(
88 cx,
89 format!("ETW trace saved to {display_path}"),
90 "Show in File Manager",
91 move |_window, cx| {
92 cx.reveal_path(&output_path);
93 cx.emit(DismissEvent);
94 },
95 );
96 }
97 Ok(StatusMessage::TimedOut) => {
98 let display_path = output_path.display().to_string();
99 show_etw_notification_with_action(
100 cx,
101 format!("ETW recording timed out. Trace saved to {display_path}"),
102 "Show in File Manager",
103 move |_window, cx| {
104 cx.reveal_path(&output_path);
105 cx.emit(DismissEvent);
106 },
107 );
108 }
109 Ok(StatusMessage::Cancelled) => {
110 show_etw_notification(cx, "ETW recording cancelled");
111 }
112 Ok(_) => {
113 show_etw_notification(cx, "ETW recording ended unexpectedly");
114 }
115 Err(error) => {
116 show_etw_notification(cx, format!("Failed to complete ETW recording: {error:#}"));
117 }
118 }
119}
120
121pub fn init(cx: &mut App) {
122 cx.set_global(GlobalEtwSession(None));
123
124 cx.on_action(|_: &RecordEtwTrace, cx: &mut App| {
125 start_etw_recording(cx, None);
126 });
127
128 cx.on_action(|_: &RecordEtwTraceWithHeapTracing, cx: &mut App| {
129 start_etw_recording(cx, Some(std::process::id()));
130 });
131
132 cx.on_action(|_: &SaveEtwTrace, cx: &mut App| {
133 let session = cx.global_mut::<GlobalEtwSession>().0.as_mut();
134 let Some(session) = session else {
135 show_etw_notification(cx, "No active ETW recording to stop");
136 return;
137 };
138 match send_json(&mut session.writer, &Command::Save) {
139 Ok(()) => {
140 show_etw_notification(cx, "Stopping ETW recording...");
141 }
142 Err(error) => {
143 show_etw_notification(cx, format!("Failed to stop ETW recording: {error:#}"));
144 }
145 }
146 });
147
148 cx.on_action(|_: &CancelEtwTrace, cx: &mut App| {
149 let session = cx.global_mut::<GlobalEtwSession>().0.as_mut();
150 let Some(session) = session else {
151 show_etw_notification(cx, "No active ETW recording to cancel");
152 return;
153 };
154 match send_json(&mut session.writer, &Command::Cancel) {
155 Ok(()) => {
156 show_etw_notification(cx, "Cancelling ETW recording...");
157 }
158 Err(error) => {
159 show_etw_notification(cx, format!("Failed to cancel ETW recording: {error:#}"));
160 }
161 }
162 });
163}
164
/// Starts an ETW recording in response to one of the `Record*` actions.
///
/// Flow:
/// 1. If a session is already stored in [`GlobalEtwSession`], shows a notice and returns.
/// 2. Prompts for the output path (suggested file name `zed-trace.etl`).
/// 3. Runs [`launch_etw_recording`] on a background thread.
/// 4. Spawns a task that waits for the recorder's final status, clears the
///    global session, and shows the outcome.
/// 5. Stores the session handle globally so `SaveEtwTrace` / `CancelEtwTrace`
///    can reach the recorder.
///
/// `heap_pid` is forwarded unchanged to [`launch_etw_recording`].
///
/// NOTE(review): the "already in progress" check runs before the save dialog,
/// but the handle is only stored after the subprocess has started. Invoking
/// the action again inside that window appears able to launch a second
/// recorder that reuses the same pid-based socket path — confirm whether a
/// "starting" state is needed.
fn start_etw_recording(cx: &mut App, heap_pid: Option<u32>) {
    if has_active_etw_session(cx) {
        show_etw_notification(cx, "ETW recording is already in progress");
        return;
    }
    let save_dialog = cx.prompt_for_new_path(&PathBuf::default(), Some("zed-trace.etl"));
    cx.spawn(async move |cx| {
        let output_path = match save_dialog.await {
            Ok(Ok(Some(path))) => path,
            // No path chosen (presumably the dialog was dismissed): nothing to do.
            Ok(Ok(None)) => return,
            Ok(Err(error)) => {
                cx.update(|cx| {
                    show_etw_notification(cx, format!("Failed to pick save location: {error:#}"));
                });
                return;
            }
            // The prompt's result channel was dropped without an answer.
            Err(_) => return,
        };

        // Launching blocks (UAC prompt, socket accept), so keep it off the foreground thread.
        let result = cx
            .background_spawn(async move { launch_etw_recording(heap_pid, &output_path) })
            .await;

        let EtwSession {
            output_path,
            stream,
            listener,
            socket_path,
        } = match result {
            Ok(session) => session,
            Err(error) => {
                cx.update(|cx| {
                    show_etw_notification(cx, format!("Failed to start ETW recording: {error:#}"));
                });
                return;
            }
        };

        // Reads (final status) and writes (commands) are handled by different owners.
        let (read_half, write_half) = stream.into_inner().into_split();

        // Waits for the recorder's final status message, then tears the session down.
        // NOTE(review): this relies on the `cx.update` below storing the handle
        // before this task clears it; presumably guaranteed because the status
        // only arrives after a command or the recorder's timeout.
        cx.spawn(async |cx| {
            let status = cx
                .background_spawn(async move {
                    recv_json(&mut BufReader::new(read_half))
                        .context("Receive status from subprocess")
                })
                .await;
            cx.update(|cx| {
                cx.global_mut::<GlobalEtwSession>().0 = None;
                show_etw_status_notification(cx, status, output_path);
            });
        })
        .detach();

        cx.update(|cx| {
            cx.global_mut::<GlobalEtwSession>().0 = Some(EtwSessionHandle {
                writer: write_half,
                _listener: listener,
                socket_path,
            });
            show_etw_notification(cx, "ETW recording started");
        });
    })
    .detach();
}
230
/// How long the elevated recorder waits for a `Save`/`Cancel` command. If no
/// command can be read in time, the recorder saves the trace anyway and
/// reports `TimedOut`.
const RECORDING_TIMEOUT: Duration = Duration::from_secs(60);

/// Instance name passed to `WPRCCreateInstanceUnderInstanceName` for every
/// WPR object this module creates.
const INSTANCE_NAME: &str = "Zed";

/// Built-in WPR profiles loaded by name (via `LoadFromFile`) into every recording.
const BUILTIN_PROFILES: &[&str] = &[
    "CPU.Verbose.Memory",
    "GPU.Light.Memory",
    "DiskIO.Light.Memory",
    "FileIO.Light.Memory",
];
241
/// Generates the WPR profile XML for the custom `ZedHeap.Verbose.Memory` profile.
///
/// The profile derives from the built-in `Heap.Verbose.Memory` profile. With
/// `Some(pid)` it also declares a `ZedHeapProvider` heap event provider whose
/// `HeapProcessIds` is set to `pid`, and adds that provider to the
/// `HeapCollector_WPRHeapCollector` collector. With `None` the profile only
/// inherits its base. The trace-merge properties enable file compression.
fn heap_tracing_profile(heap_pid: Option<u32>) -> String {
    let (heap_provider, heap_collector) = match heap_pid {
        Some(pid) => (
            format!(
                r#"
    <HeapEventProvider Id="ZedHeapProvider">
      <HeapProcessIds Operation="Set">
        <HeapProcessId Value="{pid}"/>
      </HeapProcessIds>
    </HeapEventProvider>"#
            ),
            r#"
      <Collectors Operation="Add">
        <HeapEventCollectorId Value="HeapCollector_WPRHeapCollector">
          <HeapEventProviders Operation="Set">
            <HeapEventProviderId Value="ZedHeapProvider"/>
          </HeapEventProviders>
        </HeapEventCollectorId>
      </Collectors>"#
                .to_string(),
        ),
        // No process to target: emit neither the provider nor the collector override.
        None => (String::new(), String::new()),
    };

    format!(
        r#"<?xml version="1.0" encoding="utf-8"?>
<WindowsPerformanceRecorder Version="1.0" Author="Zed Industries">
  <Profiles>
    {heap_provider}

    <Profile Id="ZedHeap.Verbose.Memory" Base="Heap.Verbose.Memory" Name="ZedHeap" DetailLevel="Verbose" LoggingMode="Memory" Description="Heap tracing">
      {heap_collector}
    </Profile>
  </Profiles>

  <TraceMergeProperties>
    <TraceMergeProperty Id="TraceMerge_Default" Name="TraceMerge_Default">
      <FileCompression Value="true"/>
    </TraceMergeProperty>
  </TraceMergeProperties>
</WindowsPerformanceRecorder>"#
    )
}
285
286fn wpr_error_context(hresult: windows_core::HRESULT, source: &windows_core::IUnknown) -> String {
287 let mut out = format!("HRESULT: {hresult}");
288
289 unsafe {
290 let mut message = BSTR::new();
291 let mut description = BSTR::new();
292 let mut detail = BSTR::new();
293 if WPRCFormatError(
294 hresult,
295 Some(source),
296 &mut message,
297 Some(&mut description),
298 Some(&mut detail),
299 )
300 .is_ok()
301 {
302 for (label, value) in [
303 ("Message", &message),
304 ("Description", &description),
305 ("Detail", &detail),
306 ] {
307 if !value.is_empty() {
308 let _ = write!(out, "\n {label}: {value}");
309 }
310 }
311 }
312 }
313
314 if let Ok(info) = source.cast::<IParsingErrorInfo>() {
315 unsafe {
316 if let Ok(line) = info.GetLineNumber() {
317 let _ = write!(out, "\n Parse error at line: {line}");
318 if let Ok(col) = info.GetColumnNumber() {
319 let _ = write!(out, ", column: {col}");
320 }
321 }
322 for (label, getter) in [
323 ("Element type", info.GetElementType()),
324 ("Element ID", info.GetElementId()),
325 ("Description", info.GetDescription()),
326 ] {
327 if let Ok(value) = getter
328 && !value.is_empty()
329 {
330 let _ = write!(out, "\n {label}: {value}");
331 }
332 }
333 }
334 }
335
336 fn append_control_chain(out: &mut String, source: &windows_core::IUnknown) {
337 let Ok(info) = source.cast::<IControlErrorInfo>() else {
338 return;
339 };
340 unsafe {
341 if let Ok(object_type) = info.GetObjectType() {
342 let name = match object_type {
343 wprcontrol::ObjectType_Profile => "Profile",
344 wprcontrol::ObjectType_Collector => "Collector",
345 wprcontrol::ObjectType_Provider => "Provider",
346 _ => "Unknown",
347 };
348 let _ = write!(out, "\n Object type: {name}");
349 }
350 if let Ok(hr) = info.GetHResult() {
351 let _ = write!(out, "\n Inner HRESULT: {hr}");
352 }
353 if let Ok(desc) = info.GetDescription()
354 && !desc.is_empty()
355 {
356 let _ = write!(out, "\n Description: {desc}");
357 }
358 let mut inner = None;
359 if info.GetInnerErrorInfo(&mut inner).is_ok()
360 && let Some(inner) = inner
361 {
362 let _ = write!(out, "\n Caused by:");
363 append_control_chain(out, &inner);
364 }
365 }
366 }
367 append_control_chain(&mut out, source);
368
369 if let Ok(info) = source.cast::<windows::Win32::System::Com::IErrorInfo>() {
370 unsafe {
371 if let Ok(desc) = info.GetDescription()
372 && !desc.is_empty()
373 {
374 let _ = write!(out, "\n IErrorInfo: {desc}");
375 }
376 }
377 }
378
379 out
380}
381
382trait WprContext<T> {
383 fn wpr_context(self, source: &impl Interface) -> Result<T>;
384}
385
386impl<T> WprContext<T> for windows_core::Result<T> {
387 fn wpr_context(self, source: &impl Interface) -> Result<T> {
388 self.map_err(|e| {
389 let unknown: windows_core::IUnknown = source.cast().expect("cast to IUnknown");
390 let context = wpr_error_context(e.code(), &unknown);
391 anyhow::anyhow!("{context}")
392 })
393 }
394}
395
396fn create_wpr<T: windows_core::Interface>(clsid: &windows_core::GUID) -> Result<T> {
397 unsafe {
398 WPRCCreateInstanceUnderInstanceName::<_, T>(
399 &BSTR::from(INSTANCE_NAME),
400 clsid,
401 None,
402 CLSCTX_INPROC_SERVER.0,
403 )
404 .context("WPRCCreateInstance failed")
405 }
406}
407
408fn build_profile_collection(heap_pid: Option<u32>) -> Result<IProfileCollection> {
409 let collection: IProfileCollection = create_wpr(&CProfileCollection)?;
410
411 for profile_name in BUILTIN_PROFILES {
412 let profile: IProfile = create_wpr(&CProfile)?;
413 unsafe {
414 profile
415 .LoadFromFile(&BSTR::from(*profile_name), &BSTR::new())
416 .wpr_context(&profile)
417 .with_context(|| format!("Load built-in profile '{profile_name}'"))?;
418 collection
419 .Add(&profile, VARIANT_FALSE)
420 .wpr_context(&collection)
421 .with_context(|| format!("Add profile '{profile_name}' to collection"))?;
422 }
423 }
424
425 let heap_xml = heap_tracing_profile(heap_pid);
426 let heap_profile: IProfile = create_wpr(&CProfile)?;
427 unsafe {
428 heap_profile
429 .LoadFromString(&BSTR::from(heap_xml))
430 .wpr_context(&heap_profile)
431 .context("Load profile from XML string")?;
432 collection
433 .Add(&heap_profile, VARIANT_BOOL(0))
434 .wpr_context(&collection)
435 .context("Add ZedHeap profile to collection")?;
436 }
437
438 Ok(collection)
439}
440
441pub fn record_etw_trace(
442 heap_pid: Option<u32>,
443 output_path: &Path,
444 socket_path: &str,
445) -> Result<()> {
446 unsafe {
447 CoInitializeEx(None, COINIT_MULTITHREADED)
448 .ok()
449 .context("COM initialization failed")?;
450 }
451
452 let socket_path = Path::new(socket_path);
453 let mut stream = net::UnixStream::connect(socket_path).context("Connect to parent socket")?;
454
455 match record_etw_trace_inner(heap_pid, output_path, &mut stream) {
456 Ok(()) => Ok(()),
457 Err(e) => {
458 send_json(
459 &mut stream,
460 &StatusMessage::Error {
461 message: format!("{e:#}"),
462 },
463 )
464 .log_err();
465 Err(e)
466 }
467 }
468}
469
/// Runs one WPR recording session inside the elevated subprocess.
///
/// Builds the profile collection, starts WPR, sends [`StatusMessage::Started`]
/// over `stream`, then waits (via [`receive_command`]) for a `Save` or
/// `Cancel` and performs it, reporting `Stopped`/`TimedOut` or `Cancelled`.
///
/// # Errors
/// Returns an error if the profiles cannot be built, WPR cannot be started,
/// `Started` cannot be sent, the socket timeout cannot be applied, or the
/// final WPR `Save`/`Cancel` call fails. The final status messages are sent
/// best-effort: their failures go to `log_err()` and are not propagated.
fn record_etw_trace_inner(
    heap_pid: Option<u32>,
    output_path: &Path,
    stream: &mut net::UnixStream,
) -> Result<()> {
    let collection = build_profile_collection(heap_pid)?;
    let control_manager: IControlManager = create_wpr(&CControlManager)?;

    // Cancel any leftover sessions with the same name that might exist
    unsafe {
        _ = control_manager.Cancel(None);
    }

    unsafe {
        control_manager
            .Start(&collection)
            .wpr_context(&control_manager)
            .context("Start WPR recording")?;
    }

    // We must call Save or Cancel before returning or we'll leak the kernel buffers used to record the ETW session.
    // The guard is armed only after a successful Start and disarmed only after
    // Save/Cancel succeed, so every early `?` return below still cancels.
    let cancel_guard = defer({
        let control_manager = control_manager.clone();
        move || unsafe {
            let _ = control_manager.Cancel(None);
        }
    });

    send_json(stream, &StatusMessage::Started)?;

    // Blocks until the parent sends a command; a timeout/read failure comes back as a Save.
    let (command, timed_out) = receive_command(stream)?;

    match command {
        Command::Cancel => {
            unsafe {
                control_manager
                    .Cancel(None)
                    .wpr_context(&control_manager)
                    .context("Cancel WPR recording")?;
            }
            cancel_guard.abort();

            send_json(stream, &StatusMessage::Cancelled).log_err();
        }
        Command::Save => {
            // NOTE(review): `to_string_lossy` would alter a non-Unicode output path.
            unsafe {
                control_manager
                    .Save(
                        &BSTR::from(output_path.to_string_lossy().as_ref()),
                        &collection,
                        None,
                    )
                    .wpr_context(&control_manager)
                    .context("Stop WPR recording")?;
            }
            cancel_guard.abort();

            if timed_out {
                send_json(stream, &StatusMessage::TimedOut).log_err();
            } else {
                send_json(stream, &StatusMessage::Stopped).log_err();
            }
        }
    }

    Ok(())
}
537
538fn receive_command(stream: &mut net::UnixStream) -> Result<(Command, bool)> {
539 use std::os::windows::io::{AsRawSocket, AsSocket};
540 use windows::Win32::Networking::WinSock::{SO_RCVTIMEO, SOL_SOCKET, setsockopt};
541
542 // Set a receive timeout so read_line returns an error after `timeout`.
543 let millis = RECORDING_TIMEOUT.as_millis() as u32;
544 let socket = stream.as_socket();
545 let ret = unsafe {
546 setsockopt(
547 windows::Win32::Networking::WinSock::SOCKET(socket.as_raw_socket() as _),
548 SOL_SOCKET,
549 SO_RCVTIMEO,
550 Some(&millis.to_ne_bytes()),
551 )
552 };
553 if ret != 0 {
554 bail!("Failed to set socket receive timeout: setsockopt returned {ret}");
555 }
556
557 let mut reader = BufReader::new(&mut *stream);
558 match recv_json::<Command>(&mut reader) {
559 Ok(command) => Ok((command, false)),
560 Err(error) => {
561 log::warn!("Failed to receive ETW command, treating as timed-out Save: {error:#}");
562 Ok((Command::Save, true))
563 }
564 }
565}
566
/// A successfully started recording, as returned by [`launch_etw_recording`].
///
/// `start_etw_recording` destructures it into an `EtwSessionHandle`, which
/// takes over deleting the socket file on drop.
pub struct EtwSession {
    /// Where the trace will be written if the recording is saved.
    output_path: PathBuf,
    /// Connection to the elevated recorder; its `Started` message has already been read.
    stream: BufReader<net::UnixStream>,
    /// Listening socket the recorder connected to; kept so it stays open.
    listener: net::UnixListener,
    /// Filesystem path of the socket.
    socket_path: PathBuf,
}
573
574pub fn launch_etw_recording(heap_pid: Option<u32>, output_path: &Path) -> Result<EtwSession> {
575 let sock_path = std::env::temp_dir().join(format!("zed-etw-{}.sock", std::process::id()));
576
577 _ = std::fs::remove_file(&sock_path);
578 let listener = net::UnixListener::bind(&sock_path).context("Bind Unix socket for ETW IPC")?;
579
580 let exe_path = std::env::current_exe().context("Failed to get current exe path")?;
581 let pid_arg = heap_pid.map_or(-1i64, |pid| pid as i64);
582 let args = format!(
583 "--record-etw-trace --etw-zed-pid {} --etw-output \"{}\" --etw-socket \"{}\"",
584 pid_arg,
585 output_path.display(),
586 sock_path.display(),
587 );
588
589 use windows::Win32::UI::Shell::ShellExecuteW;
590 use windows_core::PCWSTR;
591
592 let operation: Vec<u16> = "runas\0".encode_utf16().collect();
593 let file: Vec<u16> = format!("{}\0", exe_path.to_string_lossy())
594 .encode_utf16()
595 .collect();
596 let parameters: Vec<u16> = format!("{args}\0").encode_utf16().collect();
597
598 let result = unsafe {
599 ShellExecuteW(
600 None,
601 PCWSTR(operation.as_ptr()),
602 PCWSTR(file.as_ptr()),
603 PCWSTR(parameters.as_ptr()),
604 PCWSTR::null(),
605 windows::Win32::UI::WindowsAndMessaging::SW_HIDE,
606 )
607 };
608
609 let result_code = result.0 as usize;
610 if result_code <= 32 {
611 bail!("ShellExecuteW failed to launch elevated process (code: {result_code})");
612 }
613
614 let (stream, _) = listener.accept().context("Accept subprocess connection")?;
615
616 let mut session = EtwSession {
617 output_path: output_path.to_path_buf(),
618 stream: BufReader::new(stream),
619 listener,
620 socket_path: sock_path,
621 };
622
623 let status: StatusMessage =
624 recv_json(&mut session.stream).context("Wait for Started status")?;
625
626 match status {
627 StatusMessage::Started => {}
628 StatusMessage::Error { message } => {
629 bail!("Subprocess reported error during start: {message}");
630 }
631 other => {
632 bail!("Unexpected status from subprocess: {other:?}");
633 }
634 }
635
636 Ok(session)
637}
638
/// Status messages sent from the elevated recorder to the parent process.
/// Each message is one JSON object per line, tagged by a `"type"` field.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type")]
pub enum StatusMessage {
    /// WPR recording has started; sent once, right after `Start` succeeds.
    Started,
    /// The trace was saved in response to a `Save` command.
    Stopped,
    /// No command could be read (timeout, closed socket, bad data); the trace was saved anyway.
    TimedOut,
    /// The recording was cancelled in response to a `Cancel` command.
    Cancelled,
    /// The recorder failed; `message` holds the formatted error chain.
    Error { message: String },
}

/// Commands sent from the parent process to the elevated recorder, encoded
/// the same way as [`StatusMessage`].
#[derive(Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type")]
pub enum Command {
    /// Stop recording and write the trace to the chosen output path.
    Save,
    /// Stop recording without saving a trace.
    Cancel,
}
655
656fn send_json<T: serde::Serialize>(writer: &mut impl Write, value: &T) -> Result<()> {
657 let json = serde_json::to_string(value).context("Serialize message")?;
658 writeln!(writer, "{json}").context("Write to socket")?;
659 writer.flush().context("Flush socket")?;
660 Ok(())
661}
662
663fn recv_json<T: serde::de::DeserializeOwned>(reader: &mut impl BufRead) -> Result<T> {
664 let mut line = String::new();
665 reader.read_line(&mut line).context("Read from socket")?;
666 if line.is_empty() {
667 bail!("Socket closed before a message was received");
668 }
669 serde_json::from_str(line.trim()).context("Parse message")
670}