keyboard.rs

  1use anyhow::Result;
  2use collections::HashMap;
  3use windows::Win32::UI::{
  4    Input::KeyboardAndMouse::{
  5        GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MAPVK_VK_TO_VSC, MapVirtualKeyW, ToUnicode,
  6        VIRTUAL_KEY, VK_0, VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9, VK_ABNT_C1,
  7        VK_CONTROL, VK_MENU, VK_OEM_1, VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5, VK_OEM_6, VK_OEM_7,
  8        VK_OEM_8, VK_OEM_102, VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD, VK_OEM_PLUS, VK_SHIFT,
  9    },
 10    WindowsAndMessaging::KL_NAMELENGTH,
 11};
 12
 13use crate::{
 14    KeybindingKeystroke, Keystroke, Modifiers, PlatformKeyboardLayout, PlatformKeyboardMapper,
 15};
 16
/// Describes a Windows keyboard layout by its identifier and display name.
pub(crate) struct WindowsKeyboardLayout {
    /// Layout identifier: the string reported by `GetKeyboardLayoutNameW` in `new`,
    /// or `"unknown"` when built through `unknown`.
    id: String,
    /// Display name: the `Layout Text` registry value read in `new`,
    /// or `"unknown"` when built through `unknown`.
    name: String,
}
 21
/// Lookup tables translating between key strings and Windows virtual-key codes.
///
/// All three tables are populated by `WindowsKeyboardMapper::new`, which probes
/// every entry of `CANDIDATE_VKEYS`.
pub(crate) struct WindowsKeyboardMapper {
    /// Key string -> (virtual-key code, `true` when the string is the key's shifted output).
    key_to_vkey: HashMap<String, (u16, bool)>,
    /// Virtual-key code -> key string returned by `get_key_from_vkey`.
    vkey_to_key: HashMap<u16, String>,
    /// Virtual-key code -> key string returned by `get_shifted_key`.
    vkey_to_shifted: HashMap<u16, String>,
}
 27
 28impl PlatformKeyboardLayout for WindowsKeyboardLayout {
 29    fn id(&self) -> &str {
 30        &self.id
 31    }
 32
 33    fn name(&self) -> &str {
 34        &self.name
 35    }
 36}
 37
 38impl PlatformKeyboardMapper for WindowsKeyboardMapper {
 39    fn map_key_equivalent(
 40        &self,
 41        mut keystroke: Keystroke,
 42        use_key_equivalents: bool,
 43    ) -> KeybindingKeystroke {
 44        let Some((vkey, shifted_key)) = self.get_vkey_from_key(&keystroke.key, use_key_equivalents)
 45        else {
 46            return KeybindingKeystroke::from_keystroke(keystroke);
 47        };
 48        if shifted_key && keystroke.modifiers.shift {
 49            log::warn!(
 50                "Keystroke '{}' has both shift and a shifted key, this is likely a bug",
 51                keystroke.key
 52            );
 53        }
 54
 55        let shift = shifted_key || keystroke.modifiers.shift;
 56        keystroke.modifiers.shift = false;
 57
 58        let Some(key) = self.vkey_to_key.get(&vkey).cloned() else {
 59            log::error!(
 60                "Failed to map key equivalent '{:?}' to a valid key",
 61                keystroke
 62            );
 63            return KeybindingKeystroke::from_keystroke(keystroke);
 64        };
 65
 66        keystroke.key = if shift {
 67            let Some(shifted_key) = self.vkey_to_shifted.get(&vkey).cloned() else {
 68                log::error!(
 69                    "Failed to map keystroke {:?} with virtual key '{:?}' to a shifted key",
 70                    keystroke,
 71                    vkey
 72                );
 73                return KeybindingKeystroke::from_keystroke(keystroke);
 74            };
 75            shifted_key
 76        } else {
 77            key.clone()
 78        };
 79
 80        let modifiers = Modifiers {
 81            shift,
 82            ..keystroke.modifiers
 83        };
 84
 85        KeybindingKeystroke::new(keystroke, modifiers, key)
 86    }
 87
 88    fn get_key_equivalents(&self) -> Option<&HashMap<char, char>> {
 89        None
 90    }
 91}
 92
 93impl WindowsKeyboardLayout {
 94    pub(crate) fn new() -> Result<Self> {
 95        let mut buffer = [0u16; KL_NAMELENGTH as usize]; // KL_NAMELENGTH includes the null terminator
 96        unsafe { GetKeyboardLayoutNameW(&mut buffer)? };
 97        let id = String::from_utf16_lossy(&buffer[..buffer.len() - 1]); // Remove the null terminator
 98        let entry = windows_registry::LOCAL_MACHINE.open(format!(
 99            "System\\CurrentControlSet\\Control\\Keyboard Layouts\\{id}"
100        ))?;
101        let name = entry.get_string("Layout Text")?;
102        Ok(Self { id, name })
103    }
104
105    pub(crate) fn unknown() -> Self {
106        Self {
107            id: "unknown".to_string(),
108            name: "unknown".to_string(),
109        }
110    }
111
112    pub(crate) fn uses_altgr(&self) -> bool {
113        // Check if this is a known AltGr layout by examining the layout ID
114        // The layout ID is a hex string like "00000409" (US) or "00000407" (German)
115        // Extract the language ID (last 4 bytes)
116        let id_bytes = self.id.as_bytes();
117        if id_bytes.len() >= 4 {
118            let lang_id = &id_bytes[id_bytes.len() - 4..];
119            // List of keyboard layouts that use AltGr (non-exhaustive)
120            matches!(
121                lang_id,
122                b"0407" | // German
123                b"040C" | // French
124                b"040A" | // Spanish
125                b"0415" | // Polish
126                b"0413" | // Dutch
127                b"0816" | // Portuguese
128                b"041D" | // Swedish
129                b"0414" | // Norwegian
130                b"040B" | // Finnish
131                b"041F" | // Turkish
132                b"0419" | // Russian
133                b"0405" | // Czech
134                b"040E" | // Hungarian
135                b"0424" | // Slovenian
136                b"041A" | // Croatian
137                b"041B" | // Slovak
138                b"0418" // Romanian
139            )
140        } else {
141            false
142        }
143    }
144}
145
146impl WindowsKeyboardMapper {
147    pub(crate) fn new() -> Self {
148        let mut key_to_vkey = HashMap::default();
149        let mut vkey_to_key = HashMap::default();
150        let mut vkey_to_shifted = HashMap::default();
151        for vkey in CANDIDATE_VKEYS {
152            if let Some(key) = get_key_from_vkey(*vkey) {
153                key_to_vkey.insert(key.clone(), (vkey.0, false));
154                vkey_to_key.insert(vkey.0, key);
155            }
156            let scan_code = unsafe { MapVirtualKeyW(vkey.0 as u32, MAPVK_VK_TO_VSC) };
157            if scan_code == 0 {
158                continue;
159            }
160            if let Some(shifted_key) = get_shifted_key(*vkey, scan_code) {
161                key_to_vkey.insert(shifted_key.clone(), (vkey.0, true));
162                vkey_to_shifted.insert(vkey.0, shifted_key);
163            }
164        }
165        Self {
166            key_to_vkey,
167            vkey_to_key,
168            vkey_to_shifted,
169        }
170    }
171
172    fn get_vkey_from_key(&self, key: &str, use_key_equivalents: bool) -> Option<(u16, bool)> {
173        if use_key_equivalents {
174            get_vkey_from_key_with_us_layout(key)
175        } else {
176            self.key_to_vkey.get(key).cloned()
177        }
178    }
179}
180
181pub(crate) fn get_keystroke_key(
182    vkey: VIRTUAL_KEY,
183    scan_code: u32,
184    modifiers: &mut Modifiers,
185) -> Option<String> {
186    if modifiers.shift && need_to_convert_to_shifted_key(vkey) {
187        get_shifted_key(vkey, scan_code).inspect(|_| {
188            modifiers.shift = false;
189        })
190    } else {
191        get_key_from_vkey(vkey)
192    }
193}
194
195fn get_key_from_vkey(vkey: VIRTUAL_KEY) -> Option<String> {
196    let key_data = unsafe { MapVirtualKeyW(vkey.0 as u32, MAPVK_VK_TO_CHAR) };
197    if key_data == 0 {
198        return None;
199    }
200
201    // The high word contains dead key flag, the low word contains the character
202    let key = char::from_u32(key_data & 0xFFFF)?;
203
204    Some(key.to_ascii_lowercase().to_string())
205}
206
207#[inline]
208fn need_to_convert_to_shifted_key(vkey: VIRTUAL_KEY) -> bool {
209    matches!(
210        vkey,
211        VK_OEM_3
212            | VK_OEM_MINUS
213            | VK_OEM_PLUS
214            | VK_OEM_4
215            | VK_OEM_5
216            | VK_OEM_6
217            | VK_OEM_1
218            | VK_OEM_7
219            | VK_OEM_COMMA
220            | VK_OEM_PERIOD
221            | VK_OEM_2
222            | VK_OEM_102
223            | VK_OEM_8
224            | VK_ABNT_C1
225            | VK_0
226            | VK_1
227            | VK_2
228            | VK_3
229            | VK_4
230            | VK_5
231            | VK_6
232            | VK_7
233            | VK_8
234            | VK_9
235    )
236}
237
238fn get_shifted_key(vkey: VIRTUAL_KEY, scan_code: u32) -> Option<String> {
239    generate_key_char(vkey, scan_code, false, true, false)
240}
241
242pub(crate) fn generate_key_char(
243    vkey: VIRTUAL_KEY,
244    scan_code: u32,
245    control: bool,
246    shift: bool,
247    alt: bool,
248) -> Option<String> {
249    let mut state = [0; 256];
250    if control {
251        state[VK_CONTROL.0 as usize] = 0x80;
252    }
253    if shift {
254        state[VK_SHIFT.0 as usize] = 0x80;
255    }
256    if alt {
257        state[VK_MENU.0 as usize] = 0x80;
258    }
259
260    let mut buffer = [0; 8];
261    let len = unsafe { ToUnicode(vkey.0 as u32, scan_code, Some(&state), &mut buffer, 1 << 2) };
262
263    match len {
264        len if len > 0 => String::from_utf16(&buffer[..len as usize])
265            .ok()
266            .filter(|candidate| {
267                !candidate.is_empty() && !candidate.chars().next().unwrap().is_control()
268            }),
269        len if len < 0 => String::from_utf16(&buffer[..(-len as usize)]).ok(),
270        _ => None,
271    }
272}
273
274fn get_vkey_from_key_with_us_layout(key: &str) -> Option<(u16, bool)> {
275    match key {
276        // ` => VK_OEM_3
277        "`" => Some((VK_OEM_3.0, false)),
278        "~" => Some((VK_OEM_3.0, true)),
279        "1" => Some((VK_1.0, false)),
280        "!" => Some((VK_1.0, true)),
281        "2" => Some((VK_2.0, false)),
282        "@" => Some((VK_2.0, true)),
283        "3" => Some((VK_3.0, false)),
284        "#" => Some((VK_3.0, true)),
285        "4" => Some((VK_4.0, false)),
286        "$" => Some((VK_4.0, true)),
287        "5" => Some((VK_5.0, false)),
288        "%" => Some((VK_5.0, true)),
289        "6" => Some((VK_6.0, false)),
290        "^" => Some((VK_6.0, true)),
291        "7" => Some((VK_7.0, false)),
292        "&" => Some((VK_7.0, true)),
293        "8" => Some((VK_8.0, false)),
294        "*" => Some((VK_8.0, true)),
295        "9" => Some((VK_9.0, false)),
296        "(" => Some((VK_9.0, true)),
297        "0" => Some((VK_0.0, false)),
298        ")" => Some((VK_0.0, true)),
299        "-" => Some((VK_OEM_MINUS.0, false)),
300        "_" => Some((VK_OEM_MINUS.0, true)),
301        "=" => Some((VK_OEM_PLUS.0, false)),
302        "+" => Some((VK_OEM_PLUS.0, true)),
303        "[" => Some((VK_OEM_4.0, false)),
304        "{" => Some((VK_OEM_4.0, true)),
305        "]" => Some((VK_OEM_6.0, false)),
306        "}" => Some((VK_OEM_6.0, true)),
307        "\\" => Some((VK_OEM_5.0, false)),
308        "|" => Some((VK_OEM_5.0, true)),
309        ";" => Some((VK_OEM_1.0, false)),
310        ":" => Some((VK_OEM_1.0, true)),
311        "'" => Some((VK_OEM_7.0, false)),
312        "\"" => Some((VK_OEM_7.0, true)),
313        "," => Some((VK_OEM_COMMA.0, false)),
314        "<" => Some((VK_OEM_COMMA.0, true)),
315        "." => Some((VK_OEM_PERIOD.0, false)),
316        ">" => Some((VK_OEM_PERIOD.0, true)),
317        "/" => Some((VK_OEM_2.0, false)),
318        "?" => Some((VK_OEM_2.0, true)),
319        _ => None,
320    }
321}
322
/// Virtual keys probed by `WindowsKeyboardMapper::new` when building its lookup tables:
/// the OEM punctuation keys followed by the digit row.
///
/// The tables are filled in this order, and a later insertion for the same key string
/// overrides an earlier one, so the order here is significant.
///
/// NOTE(review): `need_to_convert_to_shifted_key` currently lists exactly the same
/// keys — presumably the two are meant to stay in sync; confirm before editing either.
const CANDIDATE_VKEYS: &[VIRTUAL_KEY] = &[
    VK_OEM_3,
    VK_OEM_MINUS,
    VK_OEM_PLUS,
    VK_OEM_4,
    VK_OEM_5,
    VK_OEM_6,
    VK_OEM_1,
    VK_OEM_7,
    VK_OEM_COMMA,
    VK_OEM_PERIOD,
    VK_OEM_2,
    VK_OEM_102,
    VK_OEM_8,
    VK_ABNT_C1,
    VK_0,
    VK_1,
    VK_2,
    VK_3,
    VK_4,
    VK_5,
    VK_6,
    VK_7,
    VK_8,
    VK_9,
];
349
#[cfg(test)]
mod tests {
    use crate::{Keystroke, Modifiers, PlatformKeyboardMapper, WindowsKeyboardMapper};

    /// Builds a keystroke with the given modifiers and key and no `key_char`.
    fn keystroke(modifiers: Modifiers, key: &str) -> Keystroke {
        Keystroke {
            modifiers,
            key: key.to_string(),
            key_char: None,
        }
    }

    /// Exercises `map_key_equivalent` with `use_key_equivalents = true`.
    #[test]
    fn test_keyboard_mapper() {
        let mapper = WindowsKeyboardMapper::new();

        // Normal case: a key outside the US table passes through unchanged.
        let input = keystroke(Modifiers::control(), "a");
        let mapped = mapper.map_key_equivalent(input.clone(), true);
        assert_eq!(*mapped.inner(), input);
        assert_eq!(mapped.key(), "a");
        assert_eq!(*mapped.modifiers(), Modifiers::control());

        // Shifted case, ctrl-$
        let input = keystroke(Modifiers::control(), "$");
        let mapped = mapper.map_key_equivalent(input.clone(), true);
        assert_eq!(*mapped.inner(), input);
        assert_eq!(mapped.key(), "4");
        assert_eq!(*mapped.modifiers(), Modifiers::control_shift());

        // Shifted case, but shift is true
        let input = keystroke(Modifiers::control_shift(), "$");
        let mapped = mapper.map_key_equivalent(input, true);
        assert_eq!(mapped.inner().modifiers, Modifiers::control());
        assert_eq!(mapped.key(), "4");
        assert_eq!(*mapped.modifiers(), Modifiers::control_shift());

        // Windows style: shift plus the unshifted key
        let input = keystroke(Modifiers::control_shift(), "4");
        let mapped = mapper.map_key_equivalent(input, true);
        assert_eq!(mapped.inner().modifiers, Modifiers::control());
        assert_eq!(mapped.inner().key, "$");
        assert_eq!(mapped.key(), "4");
        assert_eq!(*mapped.modifiers(), Modifiers::control_shift());
    }
}