keyboard.rs

  1use anyhow::Result;
  2use collections::HashMap;
  3use windows::Win32::UI::{
  4    Input::KeyboardAndMouse::{
  5        GetKeyboardLayoutNameW, MAPVK_VK_TO_CHAR, MAPVK_VK_TO_VSC, MapVirtualKeyW, ToUnicode,
  6        VIRTUAL_KEY, VK_0, VK_1, VK_2, VK_3, VK_4, VK_5, VK_6, VK_7, VK_8, VK_9, VK_ABNT_C1,
  7        VK_CONTROL, VK_MENU, VK_OEM_1, VK_OEM_2, VK_OEM_3, VK_OEM_4, VK_OEM_5, VK_OEM_6, VK_OEM_7,
  8        VK_OEM_8, VK_OEM_102, VK_OEM_COMMA, VK_OEM_MINUS, VK_OEM_PERIOD, VK_OEM_PLUS, VK_SHIFT,
  9    },
 10    WindowsAndMessaging::KL_NAMELENGTH,
 11};
 12
 13use crate::{
 14    KeybindingKeystroke, Keystroke, Modifiers, PlatformKeyboardLayout, PlatformKeyboardMapper,
 15};
 16
 17pub(crate) struct WindowsKeyboardLayout {
 18    id: String,
 19    name: String,
 20}
 21
 22pub(crate) struct WindowsKeyboardMapper {
 23    key_to_vkey: HashMap<String, (u16, bool)>,
 24    vkey_to_key: HashMap<u16, String>,
 25    vkey_to_shifted: HashMap<u16, String>,
 26}
 27
 28impl PlatformKeyboardLayout for WindowsKeyboardLayout {
 29    fn id(&self) -> &str {
 30        &self.id
 31    }
 32
 33    fn name(&self) -> &str {
 34        &self.name
 35    }
 36}
 37
 38impl PlatformKeyboardMapper for WindowsKeyboardMapper {
 39    fn map_key_equivalent(
 40        &self,
 41        mut keystroke: Keystroke,
 42        use_key_equivalents: bool,
 43    ) -> KeybindingKeystroke {
 44        let Some((vkey, shifted_key)) = self.get_vkey_from_key(&keystroke.key, use_key_equivalents)
 45        else {
 46            return KeybindingKeystroke::from_keystroke(keystroke);
 47        };
 48        if shifted_key && keystroke.modifiers.shift {
 49            log::warn!(
 50                "Keystroke '{}' has both shift and a shifted key, this is likely a bug",
 51                keystroke.key
 52            );
 53        }
 54
 55        let shift = shifted_key || keystroke.modifiers.shift;
 56        keystroke.modifiers.shift = false;
 57
 58        let Some(key) = self.vkey_to_key.get(&vkey).cloned() else {
 59            log::error!(
 60                "Failed to map key equivalent '{:?}' to a valid key",
 61                keystroke
 62            );
 63            return KeybindingKeystroke::from_keystroke(keystroke);
 64        };
 65
 66        keystroke.key = if shift {
 67            let Some(shifted_key) = self.vkey_to_shifted.get(&vkey).cloned() else {
 68                log::error!(
 69                    "Failed to map keystroke {:?} with virtual key '{:?}' to a shifted key",
 70                    keystroke,
 71                    vkey
 72                );
 73                return KeybindingKeystroke::from_keystroke(keystroke);
 74            };
 75            shifted_key
 76        } else {
 77            key.clone()
 78        };
 79
 80        let modifiers = Modifiers {
 81            shift,
 82            ..keystroke.modifiers
 83        };
 84
 85        KeybindingKeystroke::new(keystroke, modifiers, key)
 86    }
 87
 88    fn get_key_equivalents(&self) -> Option<&HashMap<char, char>> {
 89        None
 90    }
 91}
 92
 93impl WindowsKeyboardLayout {
 94    pub(crate) fn new() -> Result<Self> {
 95        let mut buffer = [0u16; KL_NAMELENGTH as usize]; // KL_NAMELENGTH includes the null terminator
 96        unsafe { GetKeyboardLayoutNameW(&mut buffer)? };
 97        let id = String::from_utf16_lossy(&buffer[..buffer.len() - 1]); // Remove the null terminator
 98        let entry = windows_registry::LOCAL_MACHINE.open(format!(
 99            "System\\CurrentControlSet\\Control\\Keyboard Layouts\\{id}"
100        ))?;
101        let name = entry.get_string("Layout Text")?;
102        Ok(Self { id, name })
103    }
104
105    pub(crate) fn unknown() -> Self {
106        Self {
107            id: "unknown".to_string(),
108            name: "unknown".to_string(),
109        }
110    }
111}
112
113impl WindowsKeyboardMapper {
114    pub(crate) fn new() -> Self {
115        let mut key_to_vkey = HashMap::default();
116        let mut vkey_to_key = HashMap::default();
117        let mut vkey_to_shifted = HashMap::default();
118        for vkey in CANDIDATE_VKEYS {
119            if let Some(key) = get_key_from_vkey(*vkey) {
120                key_to_vkey.insert(key.clone(), (vkey.0, false));
121                vkey_to_key.insert(vkey.0, key);
122            }
123            let scan_code = unsafe { MapVirtualKeyW(vkey.0 as u32, MAPVK_VK_TO_VSC) };
124            if scan_code == 0 {
125                continue;
126            }
127            if let Some(shifted_key) = get_shifted_key(*vkey, scan_code) {
128                key_to_vkey.insert(shifted_key.clone(), (vkey.0, true));
129                vkey_to_shifted.insert(vkey.0, shifted_key);
130            }
131        }
132        Self {
133            key_to_vkey,
134            vkey_to_key,
135            vkey_to_shifted,
136        }
137    }
138
139    fn get_vkey_from_key(&self, key: &str, use_key_equivalents: bool) -> Option<(u16, bool)> {
140        if use_key_equivalents {
141            get_vkey_from_key_with_us_layout(key)
142        } else {
143            self.key_to_vkey.get(key).cloned()
144        }
145    }
146}
147
148pub(crate) fn get_keystroke_key(
149    vkey: VIRTUAL_KEY,
150    scan_code: u32,
151    modifiers: &mut Modifiers,
152) -> Option<String> {
153    if modifiers.shift && need_to_convert_to_shifted_key(vkey) {
154        get_shifted_key(vkey, scan_code).inspect(|_| {
155            modifiers.shift = false;
156        })
157    } else {
158        get_key_from_vkey(vkey)
159    }
160}
161
162fn get_key_from_vkey(vkey: VIRTUAL_KEY) -> Option<String> {
163    let key_data = unsafe { MapVirtualKeyW(vkey.0 as u32, MAPVK_VK_TO_CHAR) };
164    if key_data == 0 {
165        return None;
166    }
167
168    // The high word contains dead key flag, the low word contains the character
169    let key = char::from_u32(key_data & 0xFFFF)?;
170
171    Some(key.to_ascii_lowercase().to_string())
172}
173
174#[inline]
175fn need_to_convert_to_shifted_key(vkey: VIRTUAL_KEY) -> bool {
176    matches!(
177        vkey,
178        VK_OEM_3
179            | VK_OEM_MINUS
180            | VK_OEM_PLUS
181            | VK_OEM_4
182            | VK_OEM_5
183            | VK_OEM_6
184            | VK_OEM_1
185            | VK_OEM_7
186            | VK_OEM_COMMA
187            | VK_OEM_PERIOD
188            | VK_OEM_2
189            | VK_OEM_102
190            | VK_OEM_8
191            | VK_ABNT_C1
192            | VK_0
193            | VK_1
194            | VK_2
195            | VK_3
196            | VK_4
197            | VK_5
198            | VK_6
199            | VK_7
200            | VK_8
201            | VK_9
202    )
203}
204
205fn get_shifted_key(vkey: VIRTUAL_KEY, scan_code: u32) -> Option<String> {
206    generate_key_char(vkey, scan_code, false, true, false)
207}
208
209pub(crate) fn generate_key_char(
210    vkey: VIRTUAL_KEY,
211    scan_code: u32,
212    control: bool,
213    shift: bool,
214    alt: bool,
215) -> Option<String> {
216    let mut state = [0; 256];
217    if control {
218        state[VK_CONTROL.0 as usize] = 0x80;
219    }
220    if shift {
221        state[VK_SHIFT.0 as usize] = 0x80;
222    }
223    if alt {
224        state[VK_MENU.0 as usize] = 0x80;
225    }
226
227    let mut buffer = [0; 8];
228    let len = unsafe { ToUnicode(vkey.0 as u32, scan_code, Some(&state), &mut buffer, 0x5) };
229
230    match len {
231        len if len > 0 => String::from_utf16(&buffer[..len as usize])
232            .ok()
233            .filter(|candidate| {
234                !candidate.is_empty() && !candidate.chars().next().unwrap().is_control()
235            }),
236        len if len < 0 => String::from_utf16(&buffer[..(-len as usize)]).ok(),
237        _ => None,
238    }
239}
240
241fn get_vkey_from_key_with_us_layout(key: &str) -> Option<(u16, bool)> {
242    match key {
243        // ` => VK_OEM_3
244        "`" => Some((VK_OEM_3.0, false)),
245        "~" => Some((VK_OEM_3.0, true)),
246        "1" => Some((VK_1.0, false)),
247        "!" => Some((VK_1.0, true)),
248        "2" => Some((VK_2.0, false)),
249        "@" => Some((VK_2.0, true)),
250        "3" => Some((VK_3.0, false)),
251        "#" => Some((VK_3.0, true)),
252        "4" => Some((VK_4.0, false)),
253        "$" => Some((VK_4.0, true)),
254        "5" => Some((VK_5.0, false)),
255        "%" => Some((VK_5.0, true)),
256        "6" => Some((VK_6.0, false)),
257        "^" => Some((VK_6.0, true)),
258        "7" => Some((VK_7.0, false)),
259        "&" => Some((VK_7.0, true)),
260        "8" => Some((VK_8.0, false)),
261        "*" => Some((VK_8.0, true)),
262        "9" => Some((VK_9.0, false)),
263        "(" => Some((VK_9.0, true)),
264        "0" => Some((VK_0.0, false)),
265        ")" => Some((VK_0.0, true)),
266        "-" => Some((VK_OEM_MINUS.0, false)),
267        "_" => Some((VK_OEM_MINUS.0, true)),
268        "=" => Some((VK_OEM_PLUS.0, false)),
269        "+" => Some((VK_OEM_PLUS.0, true)),
270        "[" => Some((VK_OEM_4.0, false)),
271        "{" => Some((VK_OEM_4.0, true)),
272        "]" => Some((VK_OEM_6.0, false)),
273        "}" => Some((VK_OEM_6.0, true)),
274        "\\" => Some((VK_OEM_5.0, false)),
275        "|" => Some((VK_OEM_5.0, true)),
276        ";" => Some((VK_OEM_1.0, false)),
277        ":" => Some((VK_OEM_1.0, true)),
278        "'" => Some((VK_OEM_7.0, false)),
279        "\"" => Some((VK_OEM_7.0, true)),
280        "," => Some((VK_OEM_COMMA.0, false)),
281        "<" => Some((VK_OEM_COMMA.0, true)),
282        "." => Some((VK_OEM_PERIOD.0, false)),
283        ">" => Some((VK_OEM_PERIOD.0, true)),
284        "/" => Some((VK_OEM_2.0, false)),
285        "?" => Some((VK_OEM_2.0, true)),
286        _ => None,
287    }
288}
289
290const CANDIDATE_VKEYS: &[VIRTUAL_KEY] = &[
291    VK_OEM_3,
292    VK_OEM_MINUS,
293    VK_OEM_PLUS,
294    VK_OEM_4,
295    VK_OEM_5,
296    VK_OEM_6,
297    VK_OEM_1,
298    VK_OEM_7,
299    VK_OEM_COMMA,
300    VK_OEM_PERIOD,
301    VK_OEM_2,
302    VK_OEM_102,
303    VK_OEM_8,
304    VK_ABNT_C1,
305    VK_0,
306    VK_1,
307    VK_2,
308    VK_3,
309    VK_4,
310    VK_5,
311    VK_6,
312    VK_7,
313    VK_8,
314    VK_9,
315];
316
#[cfg(test)]
mod tests {
    use crate::{Keystroke, Modifiers, PlatformKeyboardMapper, WindowsKeyboardMapper};

    /// Builds a keystroke with the given modifiers and key and no `key_char`.
    fn keystroke(modifiers: Modifiers, key: &str) -> Keystroke {
        Keystroke {
            modifiers,
            key: key.to_string(),
            key_char: None,
        }
    }

    #[test]
    fn test_keyboard_mapper() {
        let mapper = WindowsKeyboardMapper::new();

        // Plain key: nothing to remap.
        let input = keystroke(Modifiers::control(), "a");
        let mapped = mapper.map_key_equivalent(input.clone(), true);
        assert_eq!(*mapped.inner(), input);
        assert_eq!(mapped.key(), "a");
        assert_eq!(*mapped.modifiers(), Modifiers::control());

        // Shifted character without shift held: ctrl-$ becomes ctrl-shift-4.
        let input = keystroke(Modifiers::control(), "$");
        let mapped = mapper.map_key_equivalent(input.clone(), true);
        assert_eq!(*mapped.inner(), input);
        assert_eq!(mapped.key(), "4");
        assert_eq!(*mapped.modifiers(), Modifiers::control_shift());

        // Shifted character with shift also held: shift moves out of the inner keystroke.
        let mapped = mapper.map_key_equivalent(keystroke(Modifiers::control_shift(), "$"), true);
        assert_eq!(mapped.inner().modifiers, Modifiers::control());
        assert_eq!(mapped.key(), "4");
        assert_eq!(*mapped.modifiers(), Modifiers::control_shift());

        // Windows style: ctrl-shift-4 is reported as the character it types.
        let mapped = mapper.map_key_equivalent(keystroke(Modifiers::control_shift(), "4"), true);
        assert_eq!(mapped.inner().modifiers, Modifiers::control());
        assert_eq!(mapped.inner().key, "$");
        assert_eq!(mapped.key(), "4");
        assert_eq!(*mapped.modifiers(), Modifiers::control_shift());
    }
}