diff --git a/Cargo.lock b/Cargo.lock index 0594b5c9b5add5dfa1295187f82e8248281085e4..7504b8491b24372d13f867683c171ca2e3485fd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -456,6 +456,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_tool", + "client", "collections", "command_palette_hooks", "context_server", @@ -465,6 +466,7 @@ dependencies = [ "gpui", "language_model", "language_model_selector", + "language_models", "log", "project", "proto", diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index ff49801c46af4f8f186f81b53bbbf19e7f658f47..20e8dfbc9a9891d6778f8ceec600001636b1d5ff 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -15,6 +15,7 @@ doctest = false [dependencies] anyhow.workspace = true assistant_tool.workspace = true +client.workspace = true collections.workspace = true command_palette_hooks.workspace = true context_server.workspace = true @@ -24,6 +25,7 @@ futures.workspace = true gpui.workspace = true language_model.workspace = true language_model_selector.workspace = true +language_models.workspace = true log.workspace = true project.workspace = true proto.workspace = true diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index 7d8405dc78722c4e4d48d040407c23f148340080..4e6b6ef227c3285f9847f185f90261e689ae4965 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -2,9 +2,11 @@ use std::sync::Arc; use anyhow::Result; use assistant_tool::ToolWorkingSet; +use client::zed_urls; use gpui::{ - prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, - FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext, + prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, + FocusableView, FontWeight, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, + WindowContext, }; use language_model::{LanguageModelRegistry, Role}; use language_model_selector::LanguageModelSelector; @@ -13,7 +15,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::Workspace; use crate::message_editor::MessageEditor; -use crate::thread::{Message, Thread, ThreadEvent}; +use crate::thread::{Message, Thread, ThreadError, ThreadEvent}; use crate::thread_store::ThreadStore; use crate::{NewThread, ToggleFocus, ToggleModelSelector}; @@ -35,6 +37,7 @@ pub struct AssistantPanel { thread: Model, message_editor: View, tools: Arc, + last_error: Option, _subscriptions: Vec, } @@ -76,6 +79,7 @@ impl AssistantPanel { thread: thread.clone(), message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)), tools, + last_error: None, _subscriptions: subscriptions, } } @@ -102,6 +106,9 @@ impl AssistantPanel { cx: &mut ViewContext, ) { match event { + ThreadEvent::ShowError(error) => { + self.last_error = Some(error.clone()); + } ThreadEvent::StreamedCompletion => {} ThreadEvent::UsePendingTools => { let pending_tool_uses = self @@ -320,6 +327,152 @@ impl AssistantPanel { ) .child(v_flex().p_1p5().child(Label::new(message.text.clone()))) } + + fn render_last_error(&self, cx: &mut ViewContext) -> Option { + let last_error = self.last_error.as_ref()?; + + Some( + div() + .absolute() + .right_3() + .bottom_12() + .max_w_96() + .py_2() + .px_3() + .elevation_2(cx) + .occlude() + .child(match last_error { + ThreadError::PaymentRequired => self.render_payment_required_error(cx), + ThreadError::MaxMonthlySpendReached => { + self.render_max_monthly_spend_reached_error(cx) + } + ThreadError::Message(error_message) => { + self.render_error_message(error_message, cx) + } + }) + .into_any(), + ) + } + + fn render_payment_required_error(&self, cx: &mut ViewContext) -> AnyElement { + const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used."; + + v_flex() + .gap_0p5() + .child( + h_flex() + .gap_1p5() + .items_center() + .child(Icon::new(IconName::XCircle).color(Color::Error)) + .child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)), + ) + .child( + div() + .id("error-message") + .max_h_24() + .overflow_y_scroll() + .child(Label::new(ERROR_MESSAGE)), + ) + .child( + h_flex() + .justify_end() + .mt_1() + .child(Button::new("subscribe", "Subscribe").on_click(cx.listener( + |this, _, cx| { + this.last_error = None; + cx.open_url(&zed_urls::account_url(cx)); + cx.notify(); + }, + ))) + .child(Button::new("dismiss", "Dismiss").on_click(cx.listener( + |this, _, cx| { + this.last_error = None; + cx.notify(); + }, + ))), + ) + .into_any() + } + + fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext) -> AnyElement { + const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs."; + + v_flex() + .gap_0p5() + .child( + h_flex() + .gap_1p5() + .items_center() + .child(Icon::new(IconName::XCircle).color(Color::Error)) + .child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)), + ) + .child( + div() + .id("error-message") + .max_h_24() + .overflow_y_scroll() + .child(Label::new(ERROR_MESSAGE)), + ) + .child( + h_flex() + .justify_end() + .mt_1() + .child( + Button::new("subscribe", "Update Monthly Spend Limit").on_click( + cx.listener(|this, _, cx| { + this.last_error = None; + cx.open_url(&zed_urls::account_url(cx)); + cx.notify(); + }), + ), + ) + .child(Button::new("dismiss", "Dismiss").on_click(cx.listener( + |this, _, cx| { + this.last_error = None; + cx.notify(); + }, + ))), + ) + .into_any() + } + + fn render_error_message( + &self, + error_message: &SharedString, + cx: &mut ViewContext, + ) -> AnyElement { + v_flex() + .gap_0p5() + .child( + h_flex() + .gap_1p5() + .items_center() + .child(Icon::new(IconName::XCircle).color(Color::Error)) + .child( + Label::new("Error interacting with language model") + .weight(FontWeight::MEDIUM), + ), + ) + .child( + div() + .id("error-message") + .max_h_32() + .overflow_y_scroll() + .child(Label::new(error_message.clone())), + ) + .child( + h_flex() + .justify_end() + .mt_1() + .child(Button::new("dismiss", "Dismiss").on_click(cx.listener( + |this, _, cx| { + this.last_error = None; + cx.notify(); + }, + ))), + ) + .into_any() + } } impl Render for AssistantPanel { @@ -354,5 +507,6 @@ impl Render for AssistantPanel { .border_color(cx.theme().colors().border_variant) .child(self.message_editor.clone()), ) + .children(self.render_last_error(cx)) } } diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 0d2aab6905f62dbe5d9f5643843a64703ebfec6f..a5ab415a4d7e107408f2b128445c8ff62294b1e5 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -5,12 +5,13 @@ use assistant_tool::ToolWorkingSet; use collections::HashMap; use futures::future::Shared; use futures::{FutureExt as _, StreamExt as _}; -use gpui::{AppContext, EventEmitter, ModelContext, Task}; +use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, }; +use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError}; use serde::{Deserialize, Serialize}; use util::post_inc; @@ -210,29 +211,28 @@ impl Thread { let result = stream_completion.await; thread - .update(&mut cx, |_thread, cx| { - let error_message = if let Some(error) = result.as_ref().err() { - let error_message = error - .chain() - .map(|err| err.to_string()) - .collect::>() - .join("\n"); - Some(error_message) - } else { - None - }; - - if let Some(error_message) = error_message { - eprintln!("Completion failed: {error_message:?}"); - } - - if let Ok(stop_reason) = result { - match stop_reason { - StopReason::ToolUse => { - cx.emit(ThreadEvent::UsePendingTools); - } - StopReason::EndTurn => {} - StopReason::MaxTokens => {} + .update(&mut cx, |_thread, cx| match result.as_ref() { + Ok(stop_reason) => match stop_reason { + StopReason::ToolUse => { + cx.emit(ThreadEvent::UsePendingTools); + } + StopReason::EndTurn => {} + StopReason::MaxTokens => {} + }, + Err(error) => { + if error.is::() { + cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); + } else if error.is::() { + cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached)); + } else { + let error_message = error + .chain() + .map(|err| err.to_string()) + .collect::>() + .join("\n"); + cx.emit(ThreadEvent::ShowError(ThreadError::Message( + SharedString::from(error_message.clone()), + ))); } } }) @@ -305,8 +305,16 @@ impl Thread { } } +#[derive(Debug, Clone)] +pub enum ThreadError { + PaymentRequired, + MaxMonthlySpendReached, + Message(SharedString), +} + #[derive(Debug, Clone)] pub enum ThreadEvent { + ShowError(ThreadError), StreamedCompletion, UsePendingTools, ToolFinished { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 3c5a00bd85e682fa7f53b747e786d25ded013249..83f0b50321c4c069e592deffbc7a2a9816e5794c 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -55,7 +55,7 @@ pub enum LanguageModelCompletionEvent { StartMessage { message_id: String }, } -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum StopReason { EndTurn,