Show error message when requests to OpenAI fail

Antonio Scandurra and Julia Risley created

Co-Authored-By: Julia Risley <julia@zed.dev>

Change summary

crates/ai/src/assistant.rs        | 124 ++++++++++++++++++++++++--------
crates/theme/src/theme.rs         |   1 
styles/src/styleTree/assistant.ts |  19 +++-
3 files changed, 106 insertions(+), 38 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -18,6 +18,7 @@ use gpui::{
 };
 use isahc::{http::StatusCode, Request, RequestExt};
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
+use serde::Deserialize;
 use settings::SettingsStore;
 use std::{borrow::Cow, cell::RefCell, io, rc::Rc, sync::Arc, time::Duration};
 use util::{post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
@@ -415,7 +416,7 @@ enum AssistantEvent {
 struct Assistant {
     buffer: ModelHandle<MultiBuffer>,
     messages: Vec<Message>,
-    messages_by_id: HashMap<ExcerptId, Message>,
+    messages_metadata: HashMap<ExcerptId, MessageMetadata>,
     summary: Option<String>,
     pending_summary: Task<Option<()>>,
     completion_count: usize,
@@ -443,7 +444,7 @@ impl Assistant {
         let buffer = cx.add_model(|_| MultiBuffer::new(0));
         let mut this = Self {
             messages: Default::default(),
-            messages_by_id: Default::default(),
+            messages_metadata: Default::default(),
             summary: None,
             pending_summary: Task::ready(None),
             completion_count: Default::default(),
@@ -541,16 +542,16 @@ impl Assistant {
         let api_key = self.api_key.borrow().clone();
         if let Some(api_key) = api_key {
             let stream = stream_completion(api_key, cx.background().clone(), request);
-            let response = self.push_message(Role::Assistant, cx);
+            let (excerpt_id, content) = self.push_message(Role::Assistant, cx);
             self.push_message(Role::User, cx);
-            let task = cx.spawn(|this, mut cx| {
-                async move {
+            let task = cx.spawn(|this, mut cx| async move {
+                let stream_completion = async {
                     let mut messages = stream.await?;
 
                     while let Some(message) = messages.next().await {
                         let mut message = message?;
                         if let Some(choice) = message.choices.pop() {
-                            response.content.update(&mut cx, |content, cx| {
+                            content.update(&mut cx, |content, cx| {
                                 let text: Arc<str> = choice.delta.content?.into();
                                 content.edit([(content.len()..content.len(), text)], None, cx);
                                 Some(())
@@ -565,8 +566,16 @@ impl Assistant {
                     });
 
                     anyhow::Ok(())
+                };
+
+                if let Err(error) = stream_completion.await {
+                    this.update(&mut cx, |this, cx| {
+                        if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
+                            metadata.error = Some(error.to_string().trim().into());
+                            cx.notify();
+                        }
+                    })
                 }
-                .log_err()
             });
 
             self.pending_completions.push(PendingCompletion {
@@ -596,7 +605,7 @@ impl Assistant {
                 && excerpts.contains(&message.excerpt_id)
             {
                 excerpts_to_remove.push(message.excerpt_id);
-                self.messages_by_id.remove(&message.excerpt_id);
+                self.messages_metadata.remove(&message.excerpt_id);
                 false
             } else {
                 true
@@ -611,7 +620,11 @@ impl Assistant {
         }
     }
 
-    fn push_message(&mut self, role: Role, cx: &mut ModelContext<Self>) -> Message {
+    fn push_message(
+        &mut self,
+        role: Role,
+        cx: &mut ModelContext<Self>,
+    ) -> (ExcerptId, ModelHandle<Buffer>) {
         let content = cx.add_model(|cx| {
             let mut buffer = Buffer::new(0, "", cx);
             let markdown = self.languages.language_for_name("Markdown");
@@ -643,15 +656,20 @@ impl Assistant {
                 .unwrap()
         });
 
-        let message = Message {
+        self.messages.push(Message {
             excerpt_id,
             role,
             content: content.clone(),
-            sent_at: Local::now(),
-        };
-        self.messages.push(message.clone());
-        self.messages_by_id.insert(excerpt_id, message.clone());
-        message
+        });
+        self.messages_metadata.insert(
+            excerpt_id,
+            MessageMetadata {
+                role,
+                sent_at: Local::now(),
+                error: None,
+            },
+        );
+        (excerpt_id, content)
     }
 
     fn summarize(&mut self, cx: &mut ModelContext<Self>) {
@@ -705,7 +723,7 @@ impl Assistant {
 
 struct PendingCompletion {
     id: usize,
-    _task: Task<Option<()>>,
+    _task: Task<()>,
 }
 
 enum AssistantEditorEvent {
@@ -733,9 +751,13 @@ impl AssistantEditor {
                 {
                     let assistant = assistant.clone();
                     move |_editor, params: editor::RenderExcerptHeaderParams, cx| {
-                        let style = &theme::current(cx).assistant;
-                        if let Some(message) = assistant.read(cx).messages_by_id.get(&params.id) {
-                            let sender = match message.role {
+                        enum ErrorTooltip {}
+
+                        let theme = theme::current(cx);
+                        let style = &theme.assistant;
+                        if let Some(metadata) = assistant.read(cx).messages_metadata.get(&params.id)
+                        {
+                            let sender = match metadata.role {
                                 Role::User => Label::new("You", style.user_sender.text.clone())
                                     .contained()
                                     .with_style(style.user_sender.container),
@@ -755,13 +777,29 @@ impl AssistantEditor {
                                 .with_child(sender.aligned())
                                 .with_child(
                                     Label::new(
-                                        message.sent_at.format("%I:%M%P").to_string(),
+                                        metadata.sent_at.format("%I:%M%P").to_string(),
                                         style.sent_at.text.clone(),
                                     )
                                     .contained()
                                     .with_style(style.sent_at.container)
                                     .aligned(),
                                 )
+                                .with_children(metadata.error.clone().map(|error| {
+                                    Svg::new("icons/circle_x_mark_12.svg")
+                                        .with_color(style.error_icon.color)
+                                        .constrained()
+                                        .with_width(style.error_icon.width)
+                                        .contained()
+                                        .with_style(style.error_icon.container)
+                                        .with_tooltip::<ErrorTooltip>(
+                                            params.id.into(),
+                                            error,
+                                            None,
+                                            theme.tooltip.clone(),
+                                            cx,
+                                        )
+                                        .aligned()
+                                }))
                                 .aligned()
                                 .left()
                                 .contained()
@@ -793,17 +831,18 @@ impl AssistantEditor {
         self.assistant.update(cx, |assistant, cx| {
             let editor = self.editor.read(cx);
             let newest_selection = editor.selections.newest_anchor();
-            let message = if newest_selection.head() == Anchor::min() {
-                assistant.messages.first()
+            let role = if newest_selection.head() == Anchor::min() {
+                assistant.messages.first().map(|message| message.role)
             } else if newest_selection.head() == Anchor::max() {
-                assistant.messages.last()
+                assistant.messages.last().map(|message| message.role)
             } else {
                 assistant
-                    .messages_by_id
+                    .messages_metadata
                     .get(&newest_selection.head().excerpt_id())
+                    .map(|message| message.role)
             };
 
-            if message.map_or(false, |message| message.role == Role::Assistant) {
+            if role.map_or(false, |role| role == Role::Assistant) {
                 assistant.push_message(Role::User, cx);
             } else {
                 assistant.assist(cx);
@@ -1007,12 +1046,18 @@ impl Item for AssistantEditor {
     }
 }
 
-#[derive(Clone, Debug)]
+#[derive(Debug)]
 struct Message {
     excerpt_id: ExcerptId,
     role: Role,
     content: ModelHandle<Buffer>,
+}
+
+#[derive(Debug)]
+struct MessageMetadata {
+    role: Role,
     sent_at: DateTime<Local>,
+    error: Option<String>,
 }
 
 async fn stream_completion(
@@ -1076,10 +1121,27 @@ async fn stream_completion(
         let mut body = String::new();
         response.body_mut().read_to_string(&mut body).await?;
 
-        Err(anyhow!(
-            "Failed to connect to OpenAI API: {} {}",
-            response.status(),
-            body,
-        ))
+        #[derive(Deserialize)]
+        struct OpenAIResponse {
+            error: OpenAIError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAIError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAIResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
     }
 }

crates/theme/src/theme.rs 🔗

@@ -980,6 +980,7 @@ pub struct AssistantStyle {
     pub model: Interactive<ContainedText>,
     pub remaining_tokens: ContainedText,
     pub no_remaining_tokens: ContainedText,
+    pub error_icon: Icon,
     pub api_key_editor: FieldEditor,
     pub api_key_prompt: ContainedText,
 }

styles/src/styleTree/assistant.ts 🔗

@@ -1,5 +1,5 @@
 import { ColorScheme } from "../themes/common/colorScheme"
-import { text, border, background } from "./components"
+import { text, border, background, foreground } from "./components"
 import editor from "./editor"
 
 export default function assistant(colorScheme: ColorScheme) {
@@ -14,17 +14,17 @@ export default function assistant(colorScheme: ColorScheme) {
         margin: { bottom: 6, top: 6 },
         background: editor(colorScheme).background
       },
-      user_sender: {
+      userSender: {
         ...text(layer, "sans", "default", { size: "sm", weight: "bold" }),
       },
-      assistant_sender: {
+      assistantSender: {
         ...text(layer, "sans", "accent", { size: "sm", weight: "bold" }),
       },
-      sent_at: {
+      sentAt: {
         margin: { top: 2, left: 8 },
         ...text(layer, "sans", "default", { size: "2xs" }),
       },
-      model_info_container: {
+      modelInfoContainer: {
         margin: { right: 16, top: 4 },
       },
       model: {
@@ -37,7 +37,7 @@ export default function assistant(colorScheme: ColorScheme) {
           background: background(layer, "on", "hovered"),
         }
       },
-      remaining_tokens: {
+      remainingTokens: {
         background: background(layer, "on"),
         border: border(layer, "on", { overlay: true }),
         padding: 4,
@@ -45,7 +45,7 @@ export default function assistant(colorScheme: ColorScheme) {
         cornerRadius: 4,
         ...text(layer, "sans", "positive", { size: "xs" }),
       },
-      no_remaining_tokens: {
+      noRemainingTokens: {
         background: background(layer, "on"),
         border: border(layer, "on", { overlay: true }),
         padding: 4,
@@ -53,6 +53,11 @@ export default function assistant(colorScheme: ColorScheme) {
         cornerRadius: 4,
         ...text(layer, "sans", "negative", { size: "xs" }),
       },
+      errorIcon: {
+        margin: { left: 8 },
+        color: foreground(layer, "negative"),
+        width: 12,
+      },
       apiKeyEditor: {
           background: background(layer, "on"),
           cornerRadius: 6,