get magic palette working

Mikayla Maki created

Change summary

crates/magic_palette/src/magic_palette.rs | 309 +++++++++++++-----------
1 file changed, 167 insertions(+), 142 deletions(-)

Detailed changes

crates/magic_palette/src/magic_palette.rs 🔗

@@ -4,8 +4,8 @@ use cloud_llm_client::CompletionIntent;
 use command_palette::humanize_action_name;
 use futures::StreamExt as _;
 use gpui::{
-    Action, AppContext as _, DismissEvent, Entity, EventEmitter, Focusable, IntoElement, Task,
-    WeakEntity,
+    Action, AppContext as _, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
+    IntoElement, Task, WeakEntity,
 };
 use language_model::{
     ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
@@ -13,8 +13,8 @@ use language_model::{
 use picker::{Picker, PickerDelegate};
 use settings::Settings as _;
 use ui::{
-    App, Context, InteractiveElement, ListItem, ParentElement as _, Render, Styled as _, Window,
-    div, rems,
+    App, Context, InteractiveElement, KeyBinding, Label, ListItem, ListItemSpacing,
+    ParentElement as _, Render, Styled as _, Toggleable as _, Window, div, h_flex, rems,
 };
 use util::ResultExt;
 use workspace::{ModalView, Workspace};
@@ -25,9 +25,16 @@ pub fn init(cx: &mut App) {
 
 gpui::actions!(magic_palette, [Toggle]);
 
+fn format_prompt(query: &str, actions: &str) -> String {
+    format!(
+        "Match the query: \"{query}\" to relevant actions. Return 5-10 action names, most relevant first, one per line.
+        Actions:
+        {actions}"
+    )
+}
+
 struct MagicPalette {
     picker: Entity<Picker<MagicPaletteDelegate>>,
-    matches: Vec<Command>,
 }
 
 impl ModalView for MagicPalette {}
@@ -35,7 +42,7 @@ impl ModalView for MagicPalette {}
 impl EventEmitter<DismissEvent> for MagicPalette {}
 
 impl Focusable for MagicPalette {
-    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
+    fn focus_handle(&self, cx: &App) -> FocusHandle {
         self.picker.focus_handle(cx)
     }
 }
@@ -52,22 +59,29 @@ impl MagicPalette {
     }
 
     fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
+        let Some(previous_focus_handle) = window.focused(cx) else {
+            return;
+        };
+
         if agent_settings::AgentSettings::get_global(cx).enabled(cx) {
-            workspace.toggle_modal(window, cx, |window, cx| MagicPalette::new(window, cx));
+            workspace.toggle_modal(window, cx, |window, cx| {
+                MagicPalette::new(previous_focus_handle, window, cx)
+            });
         }
     }
 
-    fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
+    fn new(
+        previous_focus_handle: FocusHandle,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Self {
         let this = cx.weak_entity();
-        let delegate = MagicPaletteDelegate::new(this);
+        let delegate = MagicPaletteDelegate::new(this, previous_focus_handle);
         let picker = cx.new(|cx| {
             let picker = Picker::uniform_list(delegate, window, cx);
             picker
         });
-        Self {
-            picker,
-            matches: vec![],
-        }
+        Self { picker }
     }
 }
 
@@ -86,27 +100,24 @@ struct Command {
     action: Box<dyn Action>,
 }
 
-enum MagicPaletteMode {
-    WriteQuery,
-    SelectResult(Vec<Command>),
-}
-
 struct MagicPaletteDelegate {
     query: String,
-    llm_generation_task: Task<Result<()>>,
+    llm_generation_task: Option<Task<Result<()>>>,
     magic_palette: WeakEntity<MagicPalette>,
-    mode: MagicPaletteMode,
+    matches: Vec<Command>,
     selected_index: usize,
+    previous_focus_handle: FocusHandle,
 }
 
 impl MagicPaletteDelegate {
-    fn new(magic_palette: WeakEntity<MagicPalette>) -> Self {
+    fn new(magic_palette: WeakEntity<MagicPalette>, previous_focus_handle: FocusHandle) -> Self {
         Self {
             query: String::new(),
-            llm_generation_task: Task::ready(Ok(())),
+            llm_generation_task: None,
             magic_palette,
-            mode: MagicPaletteMode::WriteQuery,
+            matches: vec![],
             selected_index: 0,
+            previous_focus_handle,
         }
     }
 }
@@ -115,10 +126,7 @@ impl PickerDelegate for MagicPaletteDelegate {
     type ListItem = ListItem;
 
     fn match_count(&self) -> usize {
-        match &self.mode {
-            MagicPaletteMode::WriteQuery => 0,
-            MagicPaletteMode::SelectResult(commands) => commands.len(),
-        }
+        self.matches.len()
     }
 
     fn selected_index(&self) -> usize {
@@ -138,6 +146,14 @@ impl PickerDelegate for MagicPaletteDelegate {
         "Ask Zed AI what actions you want to perform...".into()
     }
 
+    fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option<ui::SharedString> {
+        if self.llm_generation_task.is_some() {
+            Some("Generating...".into())
+        } else {
+            None
+        }
+    }
+
     fn update_matches(
         &mut self,
         query: String,
@@ -154,118 +170,116 @@ impl PickerDelegate for MagicPaletteDelegate {
         window: &mut Window,
         cx: &mut Context<picker::Picker<Self>>,
     ) {
-        match &self.mode {
-            MagicPaletteMode::WriteQuery => {
-                let Some(ConfiguredModel { provider, model }) =
-                    LanguageModelRegistry::read_global(cx).commit_message_model()
-                else {
-                    return;
-                };
-                let temperature = AgentSettings::temperature_for_model(&model, cx);
-                let query = self.query.clone();
-                self.llm_generation_task = cx.spawn_in(window, async move |this, cx| {
-                    let actions = cx.update(|_, cx| cx.action_documentation().clone())?;
-
-                    if let Some(task) = cx.update(|_, cx| {
-                        if !provider.is_authenticated(cx) {
-                            Some(provider.authenticate(cx))
-                        } else {
-                            None
-                        }
-                    })? {
-                        task.await.log_err();
-                    };
-
-                    let actions = actions
-                        .into_iter()
-                        .map(|(name, descriptiopn)| format!("{} – {}", name, descriptiopn))
-                        .collect::<Vec<String>>();
-                    let actions = actions.join("\n");
-                    let prompt = format!(
-                        "You are helping a user find the most relevant actions in Zed editor based on their natural language query.
-
-User query: \"{query}\"
-
-Available actions in Zed:
-{actions}
-
-Instructions:
-1. Analyze the user's query to understand their intent
-2. Match the query against the available actions, considering:
-   - Exact keyword matches
-   - Semantic similarity (e.g., \"open file\" matches \"workspace::Open\")
-   - Common synonyms and alternative phrasings
-   - Partial matches where relevant
-3. Return the top 5-10 most relevant actions in order of relevance
-4. Return each action name exactly as shown in the list above
-5. If no good matches exist, return the closest alternatives
-
-Format your response as a simple list of action names, one per line, with no additional text or explanation."
-                    );
-                    dbg!(&prompt);
-
-                    let request = LanguageModelRequest {
-                        thread_id: None,
-                        prompt_id: None,
-                        intent: Some(CompletionIntent::GenerateGitCommitMessage),
-                        mode: None,
-                        messages: vec![LanguageModelRequestMessage {
-                            role: Role::User,
-                            content: vec![prompt.into()],
-                            cache: false,
-                        }],
-                        tools: Vec::new(),
-                        tool_choice: None,
-                        stop: Vec::new(),
-                        temperature,
-                        thinking_allowed: false,
-                    };
-
-                    let stream = model.stream_completion_text(request, cx);
-                    dbg!("pinging stream");
-                    let mut messages = stream.await?;
-                    let mut buffer = String::new();
-                    while let Some(Ok(message)) = messages.stream.next().await {
-                        buffer.push_str(&message);
+        if self.matches.is_empty() {
+            let Some(ConfiguredModel { provider, model }) =
+                LanguageModelRegistry::read_global(cx).commit_message_model()
+            else {
+                return;
+            };
+            let temperature = AgentSettings::temperature_for_model(&model, cx);
+            let query = self.query.clone();
+            cx.notify();
+            self.llm_generation_task = Some(cx.spawn_in(window, async move |this, cx| {
+                let actions = cx.update(|_, cx| cx.action_documentation().clone())?;
+
+                if let Some(task) = cx.update(|_, cx| {
+                    if !provider.is_authenticated(cx) {
+                        Some(provider.authenticate(cx))
+                    } else {
+                        None
                     }
+                })? {
+                    task.await.log_err();
+                };
 
-                    dbg!(&buffer);
-
-                    // Split result by `\n` and for each string, call `cx.build_action`.
-                    let commands = cx.update(move |_window, cx| {
-                        let mut commands: Vec<Command> = vec![];
-
-                        for name in buffer.lines() {
-                            dbg!(name);
+                let actions = actions
+                    .into_iter()
+                    .filter(|(action, _)| !action.starts_with("vim") && !action.starts_with("dev"))
+                    .map(|(name, description)| {
+                        let short = description
+                            .split_whitespace()
+                            .take(5)
+                            .collect::<Vec<_>>()
+                            .join(" ");
+
+                        format!("{} | {}", name, short)
+                    })
+                    .collect::<Vec<String>>();
+                let actions = actions.join("\n");
+                let prompt = format_prompt(&query, &actions);
+                println!("{}", prompt);
+
+                let request = LanguageModelRequest {
+                    thread_id: None,
+                    prompt_id: None,
+                    intent: Some(CompletionIntent::GenerateGitCommitMessage),
+                    mode: None,
+                    messages: vec![LanguageModelRequestMessage {
+                        role: Role::User,
+                        content: vec![prompt.into()],
+                        cache: false,
+                    }],
+                    tools: Vec::new(),
+                    tool_choice: None,
+                    stop: Vec::new(),
+                    temperature,
+                    thinking_allowed: false,
+                };
 
-                            let action = cx.build_action(name, None);
-                            match action {
-                                Ok(action) => {
-                                    commands.push(Command { action: action, name: humanize_action_name(name) })
-                                    },
-                                Err(err) => {
-                                    log::error!("Failed to build action: {}", err);
-                                }
+                let stream = model.stream_completion_text(request, cx);
+                dbg!("pinging stream");
+                let mut messages = stream.await?;
+                let mut buffer = String::new();
+                while let Some(Ok(message)) = messages.stream.next().await {
+                    buffer.push_str(&message);
+                }
+
+                // Split result by `\n` and for each string, call `cx.build_action`.
+                let commands = cx.update(move |_window, cx| {
+                    let mut commands: Vec<Command> = vec![];
+
+                    for name in buffer.lines() {
+                        dbg!(name);
+
+                        let action = cx.build_action(name, None);
+                        match action {
+                            Ok(action) => commands.push(Command {
+                                action: action,
+                                name: humanize_action_name(name),
+                            }),
+                            Err(err) => {
+                                log::error!("Failed to build action: {}", err);
                             }
                         }
-
-                        commands
-                    });
-
-                    dbg!(&commands);
-                    if let Ok(commands) = commands {
-                        let _ = this.update(cx, |this, cx| {
-                            let _ = this.delegate.magic_palette.update(cx, |magic_palette, _| {
-                                magic_palette.matches = commands;
-                            });
-                        });
                     }
 
-                    //
-                    Ok(())
-                });
-            }
-            MagicPaletteMode::SelectResult(commands) => todo!(),
+                    commands
+                })?;
+
+                this.update(cx, |this, cx| {
+                    this.delegate.matches = commands;
+                    this.delegate.llm_generation_task = None;
+                    this.delegate.selected_index = 0;
+                    cx.notify();
+                })?;
+
+                Ok(())
+            }));
+        } else {
+            let command = self.matches.swap_remove(self.selected_index);
+            telemetry::event!(
+                "Action Invoked",
+                source = "magic palette",
+                action = command.name
+            );
+            self.matches.clear();
+            self.query.clear();
+            self.llm_generation_task.take();
+
+            let action = command.action;
+            window.focus(&self.previous_focus_handle);
+            self.dismissed(window, cx);
+            window.dispatch_action(action, cx);
         }
     }
 
@@ -281,17 +295,28 @@ Format your response as a simple list of action names, one per line, with no add
         &self,
         ix: usize,
         selected: bool,
-        window: &mut Window,
+        _window: &mut Window,
         cx: &mut Context<picker::Picker<Self>>,
     ) -> Option<Self::ListItem> {
-        None
-    }
-
-    fn confirm_input(
-        &mut self,
-        _secondary: bool,
-        _window: &mut Window,
-        _: &mut Context<picker::Picker<Self>>,
-    ) {
+        let command = self.matches.get(ix)?;
+
+        Some(
+            ListItem::new(ix)
+                .inset(true)
+                .spacing(ListItemSpacing::Sparse)
+                .toggle_state(selected)
+                .child(
+                    h_flex()
+                        .w_full()
+                        .py_px()
+                        .justify_between()
+                        .child(Label::new(command.name.clone()))
+                        .child(KeyBinding::for_action_in(
+                            &*command.action,
+                            &self.previous_focus_handle,
+                            cx,
+                        )),
+                ),
+        )
     }
 }