Show inline assistant errors

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs | 158 +++++++++++++++++++++++++++++++--------
1 file changed, 124 insertions(+), 34 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -40,7 +40,7 @@ use std::{
     cell::{Cell, RefCell},
     cmp, env,
     fmt::Write,
-    iter,
+    future, iter,
     ops::Range,
     path::{Path, PathBuf},
     rc::Rc,
@@ -55,7 +55,7 @@ use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
 use workspace::{
     dock::{DockPosition, Panel},
     searchable::Direction,
-    Save, ToggleZoom, Toolbar, Workspace,
+    Save, Toast, ToggleZoom, Toolbar, Workspace,
 };
 
 actions!(
@@ -290,6 +290,7 @@ impl AssistantPanel {
                 has_focus: false,
                 include_conversation: self.include_conversation_in_next_inline_assist,
                 measurements: measurements.clone(),
+                error: None,
             };
             cx.focus_self();
             assistant
@@ -331,7 +332,7 @@ impl AssistantPanel {
                 editor: editor.downgrade(),
                 range,
                 highlighted_ranges: Default::default(),
-                inline_assistant_block_id: Some(block_id),
+                inline_assistant: Some((block_id, inline_assistant.clone())),
                 code_generation: Task::ready(None),
                 transaction_id: None,
                 _subscriptions: vec![
@@ -477,7 +478,7 @@ impl AssistantPanel {
     fn hide_inline_assist(&mut self, assist_id: usize, cx: &mut ViewContext<Self>) {
         if let Some(pending_assist) = self.pending_inline_assists.get_mut(&assist_id) {
             if let Some(editor) = pending_assist.editor.upgrade(cx) {
-                if let Some(block_id) = pending_assist.inline_assistant_block_id.take() {
+                if let Some((block_id, _)) = pending_assist.inline_assistant.take() {
                     editor.update(cx, |editor, cx| {
                         editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
                     });
@@ -699,22 +700,17 @@ impl AssistantPanel {
 
         pending_assist.code_generation = cx.spawn(|this, mut cx| {
             async move {
-                let _cleanup = util::defer({
-                    let mut cx = cx.clone();
-                    let this = this.clone();
-                    move || {
-                        let _ = this.update(&mut cx, |this, cx| {
-                            this.close_inline_assist(inline_assist_id, false, cx)
-                        });
-                    }
-                });
-
                 let mut edit_start = range.start.to_offset(&snapshot);
 
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                 let diff = cx.background().spawn(async move {
                     let chunks = strip_markdown_codeblock(response.await?.filter_map(
-                        |message| async move { message.ok()?.choices.pop()?.delta.content },
+                        |message| async move {
+                            match message {
+                                Ok(mut message) => Some(Ok(message.choices.pop()?.delta.content?)),
+                                Err(error) => Some(Err(error)),
+                            }
+                        },
                     ));
                     futures::pin_mut!(chunks);
                     let mut diff = StreamingDiff::new(selected_text.to_string());
@@ -737,6 +733,7 @@ impl AssistantPanel {
                     let mut new_text = String::new();
 
                     while let Some(chunk) = chunks.next().await {
+                        let chunk = chunk?;
                         if first_chunk && (chunk.starts_with(' ') || chunk.starts_with('\t')) {
                             autoindent = false;
                         }
@@ -771,9 +768,17 @@ impl AssistantPanel {
                 });
 
                 while let Some(hunks) = hunks_rx.next().await {
-                    let editor = editor
-                        .upgrade(&cx)
-                        .ok_or_else(|| anyhow!("editor was dropped"))?;
+                    let editor = if let Some(editor) = editor.upgrade(&cx) {
+                        editor
+                    } else {
+                        break;
+                    };
+
+                    let this = if let Some(this) = this.upgrade(&cx) {
+                        this
+                    } else {
+                        break;
+                    };
 
                     this.update(&mut cx, |this, cx| {
                         let pending_assist = if let Some(pending_assist) =
@@ -840,9 +845,42 @@ impl AssistantPanel {
                         });
 
                         this.update_highlights_for_editor(&editor, cx);
+                    });
+                }
+
+                if let Err(error) = diff.await {
+                    this.update(&mut cx, |this, cx| {
+                        let pending_assist = if let Some(pending_assist) =
+                            this.pending_inline_assists.get_mut(&inline_assist_id)
+                        {
+                            pending_assist
+                        } else {
+                            return;
+                        };
+
+                        if let Some((_, inline_assistant)) =
+                            pending_assist.inline_assistant.as_ref()
+                        {
+                            inline_assistant.update(cx, |inline_assistant, cx| {
+                                inline_assistant.set_error(error, cx);
+                            });
+                        } else if let Some(workspace) = this.workspace.upgrade(cx) {
+                            workspace.update(cx, |workspace, cx| {
+                                workspace.show_toast(
+                                    Toast::new(
+                                        inline_assist_id,
+                                        format!("Inline assistant error: {}", error),
+                                    ),
+                                    cx,
+                                );
+                            })
+                        }
                     })?;
+                } else {
+                    let _ = this.update(&mut cx, |this, cx| {
+                        this.close_inline_assist(inline_assist_id, false, cx)
+                    });
                 }
-                diff.await?;
 
                 anyhow::Ok(())
             }
@@ -2856,6 +2894,7 @@ struct InlineAssistant {
     has_focus: bool,
     include_conversation: bool,
     measurements: Rc<Cell<BlockMeasurements>>,
+    error: Option<anyhow::Error>,
 }
 
 impl Entity for InlineAssistant {
@@ -2868,17 +2907,42 @@ impl View for InlineAssistant {
     }
 
     fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
+        enum ErrorIcon {}
         let theme = theme::current(cx);
 
         Flex::row()
             .with_child(
-                Button::action(ToggleIncludeConversation)
-                    .with_tooltip("Include Conversation", theme.tooltip.clone())
-                    .with_id(self.id)
-                    .with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
-                    .toggleable(self.include_conversation)
-                    .with_style(theme.assistant.inline.include_conversation.clone())
-                    .element()
+                Flex::row()
+                    .with_child(
+                        Button::action(ToggleIncludeConversation)
+                            .with_tooltip("Include Conversation", theme.tooltip.clone())
+                            .with_id(self.id)
+                            .with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
+                            .toggleable(self.include_conversation)
+                            .with_style(theme.assistant.inline.include_conversation.clone())
+                            .element()
+                            .aligned(),
+                    )
+                    .with_children(if let Some(error) = self.error.as_ref() {
+                        Some(
+                            Svg::new("icons/circle_x_mark_12.svg")
+                                .with_color(theme.assistant.error_icon.color)
+                                .constrained()
+                                .with_width(theme.assistant.error_icon.width)
+                                .contained()
+                                .with_style(theme.assistant.error_icon.container)
+                                .with_tooltip::<ErrorIcon>(
+                                    self.id,
+                                    error.to_string(),
+                                    None,
+                                    theme.tooltip.clone(),
+                                    cx,
+                                )
+                                .aligned(),
+                        )
+                    } else {
+                        None
+                    })
                     .aligned()
                     .constrained()
                     .dynamically({
@@ -2954,6 +3018,8 @@ impl InlineAssistant {
                 include_conversation: self.include_conversation,
             });
             self.confirmed = true;
+            self.error = None;
+            cx.notify();
         }
     }
 
@@ -2968,6 +3034,19 @@ impl InlineAssistant {
         });
         cx.notify();
     }
+
+    fn set_error(&mut self, error: anyhow::Error, cx: &mut ViewContext<Self>) {
+        self.error = Some(error);
+        self.confirmed = false;
+        self.prompt_editor.update(cx, |editor, cx| {
+            editor.set_read_only(false);
+            editor.set_field_editor_style(
+                Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
+                cx,
+            );
+        });
+        cx.notify();
+    }
 }
 
 // This wouldn't need to exist if we could pass parameters when rendering child views.
@@ -2982,7 +3061,7 @@ struct PendingInlineAssist {
     editor: WeakViewHandle<Editor>,
     range: Range<Anchor>,
     highlighted_ranges: Vec<Range<Anchor>>,
-    inline_assistant_block_id: Option<BlockId>,
+    inline_assistant: Option<(BlockId, ViewHandle<InlineAssistant>)>,
     code_generation: Task<Option<()>>,
     transaction_id: Option<TransactionId>,
     _subscriptions: Vec<Subscription>,
@@ -3010,23 +3089,29 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
     }
 }
 
-fn strip_markdown_codeblock(stream: impl Stream<Item = String>) -> impl Stream<Item = String> {
+fn strip_markdown_codeblock(
+    stream: impl Stream<Item = Result<String>>,
+) -> impl Stream<Item = Result<String>> {
     let mut first_line = true;
     let mut buffer = String::new();
     let mut starts_with_fenced_code_block = false;
     stream.filter_map(move |chunk| {
+        let chunk = match chunk {
+            Ok(chunk) => chunk,
+            Err(err) => return future::ready(Some(Err(err))),
+        };
         buffer.push_str(&chunk);
 
         if first_line {
             if buffer == "" || buffer == "`" || buffer == "``" {
-                return futures::future::ready(None);
+                return future::ready(None);
             } else if buffer.starts_with("```") {
                 starts_with_fenced_code_block = true;
                 if let Some(newline_ix) = buffer.find('\n') {
                     buffer.replace_range(..newline_ix + 1, "");
                     first_line = false;
                 } else {
-                    return futures::future::ready(None);
+                    return future::ready(None);
                 }
             }
         }
@@ -3050,10 +3135,10 @@ fn strip_markdown_codeblock(stream: impl Stream<Item = String>) -> impl Stream<I
         let result = if buffer.is_empty() {
             None
         } else {
-            Some(buffer.clone())
+            Some(Ok(buffer.clone()))
         };
         buffer = remainder;
-        futures::future::ready(result)
+        future::ready(result)
     })
 }
 
@@ -3434,41 +3519,46 @@ mod tests {
     async fn test_strip_markdown_codeblock() {
         assert_eq!(
             strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
+                .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
             "Lorem ipsum dolor"
         );
         assert_eq!(
             strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
+                .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
             "Lorem ipsum dolor"
         );
         assert_eq!(
             strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
+                .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
             "Lorem ipsum dolor"
         );
         assert_eq!(
             strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
+                .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
             "```js\nLorem ipsum dolor\n```"
         );
         assert_eq!(
             strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
+                .map(|chunk| chunk.unwrap())
                 .collect::<String>()
                 .await,
             "``\nLorem ipsum dolor\n```"
         );
 
-        fn chunks(text: &str, size: usize) -> impl Stream<Item = String> {
+        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
             stream::iter(
                 text.chars()
                     .collect::<Vec<_>>()
                     .chunks(size)
-                    .map(|chunk| chunk.iter().collect::<String>())
+                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                     .collect::<Vec<_>>(),
             )
         }