Detailed changes
@@ -43,6 +43,7 @@ use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
};
use util::ResultExt as _;
+use util::markdown::MarkdownString;
use workspace::{OpenOptions, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
@@ -769,7 +770,7 @@ impl ActiveThread {
this.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
- &tool_use.input,
+ &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
tool_use.status.text(),
cx,
);
@@ -870,7 +871,7 @@ impl ActiveThread {
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_label: impl Into<SharedString>,
- tool_input: &serde_json::Value,
+ tool_input: &str,
tool_output: SharedString,
cx: &mut Context<Self>,
) {
@@ -893,11 +894,10 @@ impl ActiveThread {
this.replace(tool_label, cx);
});
rendered.input.update(cx, |this, cx| {
- let input = format!(
- "```json\n{}\n```",
- serde_json::to_string_pretty(tool_input).unwrap_or_default()
+ this.replace(
+ MarkdownString::code_block("json", tool_input).to_string(),
+ cx,
);
- this.replace(input, cx);
});
rendered.output.update(cx, |this, cx| {
this.replace(tool_output, cx);
@@ -988,7 +988,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
- &tool_use.input,
+ &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
"".into(),
cx,
);
@@ -1002,7 +1002,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text.clone(),
- input,
+ &serde_json::to_string_pretty(&input).unwrap_or_default(),
"".into(),
cx,
);
@@ -1014,7 +1014,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
- &tool_use.input,
+ &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
self.thread
.read(cx)
.output_for_tool(&tool_use.id)
@@ -1026,6 +1026,23 @@ impl ActiveThread {
}
ThreadEvent::CheckpointChanged => cx.notify(),
ThreadEvent::ReceivedTextChunk => {}
+ ThreadEvent::InvalidToolInput {
+ tool_use_id,
+ ui_text,
+ invalid_input_json,
+ } => {
+ self.render_tool_use_markdown(
+ tool_use_id.clone(),
+ ui_text.clone(),
+ invalid_input_json,
+ self.thread
+ .read(cx)
+ .output_for_tool(tool_use_id)
+ .map(|output| output.clone().into())
+ .unwrap_or("".into()),
+ cx,
+ );
+ }
}
}
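
The call sites above now serialize the tool input once with `serde_json::to_string_pretty` and hand `render_tool_use_markdown` a plain `&str`, which ends up wrapped in a fenced JSON block via `MarkdownString::code_block`. A minimal standalone sketch of that convention (the fence is assembled with `repeat` only to keep this example self-delimiting; the exact output of `MarkdownString::code_block` is assumed, not quoted):

```rust
fn main() {
    let input = serde_json::json!({ "path": "src/main.rs", "line": 42 });
    // Empty string on serialization failure, matching `unwrap_or_default()` above.
    let pretty = serde_json::to_string_pretty(&input).unwrap_or_default();
    let fence = "`".repeat(3); // keeps a literal fence out of this example
    let block = format!("{fence}json\n{pretty}\n{fence}");
    assert!(block.contains("\"path\": \"src/main.rs\""));
    println!("{block}");
}
```
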
@@ -5,7 +5,9 @@ use anyhow::Result;
use client::telemetry::Telemetry;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
-use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
+use futures::{
+ SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
+};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
@@ -508,7 +510,9 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
- let chunks = StripInvalidSpans::new(stream?.stream);
+ let chunks = StripInvalidSpans::new(
+ stream?.stream.map_err(|error| error.into()),
+ );
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
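
Because `LanguageModelTextStream.stream` now carries `LanguageModelCompletionError` instead of `anyhow::Error` (see the `language_model` hunks further down), consumers like `StripInvalidSpans` that still expect anyhow errors get a one-line adapter via `TryStreamExt::map_err`. A self-contained sketch of that adapter, assuming only `futures`, `anyhow`, and `thiserror`:

```rust
use futures::{StreamExt, TryStreamExt, stream};

#[derive(thiserror::Error, Debug)]
#[error("completion failed")]
struct CompletionError;

fn main() {
    futures::executor::block_on(async {
        let typed = stream::iter(vec![
            Ok::<String, CompletionError>("chunk".into()),
            Err(CompletionError),
        ]);
        // Any `std::error::Error + Send + Sync + 'static` converts into
        // `anyhow::Error`, so `map_err` is all the downstream code needs.
        let mut chunks = typed.map_err(anyhow::Error::from).boxed();
        while let Some(item) = chunks.next().await {
            let _: Result<String, anyhow::Error> = item;
        }
    });
}
```
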
@@ -17,10 +17,10 @@ use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
use language_model::{
- ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
- LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
- LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
+ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
+ LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
+ LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
TokenUsage,
};
@@ -1275,9 +1275,30 @@ impl Thread {
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
}
- let event = event?;
-
thread.update(cx, |thread, cx| {
+ let event = match event {
+ Ok(event) => event,
+ Err(LanguageModelCompletionError::BadInputJson {
+ id,
+ tool_name,
+ raw_input: invalid_input_json,
+ json_parse_error,
+ }) => {
+ thread.receive_invalid_tool_json(
+ id,
+ tool_name,
+ invalid_input_json,
+ json_parse_error,
+ window,
+ cx,
+ );
+ return Ok(());
+ }
+ Err(LanguageModelCompletionError::Other(error)) => {
+ return Err(error);
+ }
+ };
+
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id = Some(thread.insert_message(
@@ -1390,7 +1411,8 @@ impl Thread {
cx.notify();
thread.auto_capture_telemetry(cx);
- })?;
+ Ok(())
+ })??;
smol::future::yield_now().await;
}
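
The closure handed to `thread.update` now returns `Result<()>` so the `BadInputJson` arm can `return Ok(())` without tearing down the whole completion loop, and the trailing `??` unwraps both layers. A tiny sketch of that pattern, under the assumption that `thread.update` yields `Result<Result<(), anyhow::Error>, anyhow::Error>`:

```rust
fn drive() -> anyhow::Result<()> {
    let update_result: anyhow::Result<anyhow::Result<()>> = Ok(Ok(()));
    update_result??; // first `?`: the update call; second `?`: the closure body
    Ok(())
}

fn main() {
    assert!(drive().is_ok());
}
```
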
@@ -1681,6 +1703,41 @@ impl Thread {
pending_tool_uses
}
+ pub fn receive_invalid_tool_json(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ tool_name: Arc<str>,
+ invalid_json: Arc<str>,
+ error: String,
+ window: Option<AnyWindowHandle>,
+ cx: &mut Context<Thread>,
+ ) {
+ log::error!("The model returned invalid input JSON: {invalid_json}");
+
+ let pending_tool_use = self.tool_use.insert_tool_output(
+ tool_use_id.clone(),
+ tool_name,
+ Err(anyhow!("Error parsing input JSON: {error}")),
+ cx,
+ );
+ let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
+ pending_tool_use.ui_text.clone()
+ } else {
+ log::error!(
+ "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
+ );
+ format!("Unknown tool {}", tool_use_id).into()
+ };
+
+ cx.emit(ThreadEvent::InvalidToolInput {
+ tool_use_id: tool_use_id.clone(),
+ ui_text,
+ invalid_input_json: invalid_json,
+ });
+
+ self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
+ }
+
pub fn run_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
@@ -2282,6 +2339,11 @@ pub enum ThreadEvent {
ui_text: Arc<str>,
input: serde_json::Value,
},
+ InvalidToolInput {
+ tool_use_id: LanguageModelToolUseId,
+ ui_text: Arc<str>,
+ invalid_input_json: Arc<str>,
+ },
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),
@@ -22,7 +22,7 @@ use feature_flags::{
};
use fs::Fs;
use futures::{
- SinkExt, Stream, StreamExt,
+ SinkExt, Stream, StreamExt, TryStreamExt as _,
channel::mpsc,
future::{BoxFuture, LocalBoxFuture},
join,
@@ -3056,7 +3056,8 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
- let chunks = StripInvalidSpans::new(stream?.stream);
+ let chunks =
+ StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
@@ -253,6 +253,9 @@ impl ExampleContext {
}
});
}
+ ThreadEvent::InvalidToolInput { .. } => {
+ println!("{log_prefix} invalid tool input");
+ }
ThreadEvent::ToolConfirmationNeeded => {
panic!(
"{}Bug: Tool confirmation should not be required in eval",
@@ -1,7 +1,7 @@
use crate::{
- AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
- LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest,
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
};
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@@ -168,7 +168,12 @@ impl LanguageModel for FakeLanguageModel {
&self,
request: LanguageModelRequest,
_: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ > {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move {
@@ -76,6 +76,19 @@ pub enum LanguageModelCompletionEvent {
UsageUpdate(TokenUsage),
}
+#[derive(Error, Debug)]
+pub enum LanguageModelCompletionError {
+ #[error("received bad input JSON")]
+ BadInputJson {
+ id: LanguageModelToolUseId,
+ tool_name: Arc<str>,
+ raw_input: Arc<str>,
+ json_parse_error: String,
+ },
+ #[error(transparent)]
+ Other(#[from] anyhow::Error),
+}
+
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
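
The `#[error(transparent)] Other(#[from] anyhow::Error)` variant is what keeps the rest of this diff small: provider code that already produces `anyhow::Error` converts with `.into()` or `?`, while only malformed tool input takes the structured path. A hedged, self-contained sketch of that two-variant shape (fields trimmed for brevity):

```rust
use thiserror::Error;

#[derive(Error, Debug)]
enum CompletionError {
    #[error("received bad input JSON")]
    BadInputJson { json_parse_error: String },
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

fn parse_input(raw: &str) -> Result<serde_json::Value, CompletionError> {
    if raw.is_empty() {
        // goes through the generated `From<anyhow::Error>` impl
        return Err(anyhow::anyhow!("empty response").into());
    }
    serde_json::from_str(raw).map_err(|error| CompletionError::BadInputJson {
        json_parse_error: error.to_string(),
    })
}

fn main() {
    assert!(matches!(
        parse_input("{not json}"),
        Err(CompletionError::BadInputJson { .. })
    ));
    assert!(parse_input(r#"{"ok": true}"#).is_ok());
}
```
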
@@ -193,7 +206,7 @@ pub struct LanguageModelToolUse {
pub struct LanguageModelTextStream {
pub message_id: Option<String>,
- pub stream: BoxStream<'static, Result<String>>,
+ pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
// Has complete token usage after the stream has finished
pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
@@ -246,7 +259,12 @@ pub trait LanguageModel: Send + Sync {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ >;
fn stream_completion_with_usage(
&self,
@@ -255,7 +273,7 @@ pub trait LanguageModel: Send + Sync {
) -> BoxFuture<
'static,
Result<(
- BoxStream<'static, Result<LanguageModelCompletionEvent>>,
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>,
)>,
> {
@@ -12,10 +12,10 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
- LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
- LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
- RateLimiter, Role,
+ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
+ LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@@ -27,7 +27,7 @@ use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::{ResultExt, maybe};
+use util::ResultExt;
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: &str = "Anthropic";
@@ -448,7 +448,12 @@ impl LanguageModel for AnthropicModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ > {
let request = into_anthropic(
request,
self.model.request_id().into(),
@@ -626,7 +631,7 @@ pub fn into_anthropic(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse {
id: String,
name: String,
@@ -740,30 +745,32 @@ pub fn map_to_language_model_completion_events(
Event::ContentBlockStop { index } => {
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
let input_json = tool_use.input_json.trim();
+ let input_value = if input_json.is_empty() {
+ Ok(serde_json::Value::Object(serde_json::Map::default()))
+ } else {
+ serde_json::Value::from_str(input_json)
+ };
+ let event_result = match input_value {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id.into(),
+ name: tool_use.name.into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_use.input_json.clone(),
+ },
+ )),
+ Err(json_parse_err) => {
+ Err(LanguageModelCompletionError::BadInputJson {
+ id: tool_use.id.into(),
+ tool_name: tool_use.name.into(),
+ raw_input: input_json.into(),
+ json_parse_error: json_parse_err.to_string(),
+ })
+ }
+ };
- return Some((
- vec![maybe!({
- Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_use.id.into(),
- name: tool_use.name.into(),
- is_input_complete: true,
- input: if input_json.is_empty() {
- serde_json::Value::Object(
- serde_json::Map::default(),
- )
- } else {
- serde_json::Value::from_str(
- input_json
- )
- .map_err(|err| anyhow!("Error parsing tool call input JSON: {err:?} - JSON string was: {input_json:?}"))?
- },
- raw_input: tool_use.input_json.clone(),
- },
- ))
- })],
- state,
- ));
+ return Some((vec![event_result], state));
}
}
Event::MessageStart { message } => {
@@ -810,14 +817,19 @@ pub fn map_to_language_model_completion_events(
}
Event::Error { error } => {
return Some((
- vec![Err(anyhow!(AnthropicError::ApiError(error)))],
+ vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ AnthropicError::ApiError(error)
+ )))],
state,
));
}
_ => {}
},
Err(err) => {
- return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
+ return Some((
+ vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
+ state,
+ ));
}
}
}
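
One detail worth calling out in the `ContentBlockStop` arm above: an empty accumulated input string maps to an empty JSON object rather than a parse error (presumably because a zero-argument tool may stream no input at all), so only genuinely malformed text becomes `BadInputJson`. The rule in isolation:

```rust
fn parse_tool_input(raw: &str) -> Result<serde_json::Value, serde_json::Error> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        // No input streamed: treat as `{}` instead of failing.
        Ok(serde_json::Value::Object(serde_json::Map::default()))
    } else {
        trimmed.parse() // `serde_json::Value` implements `FromStr`
    }
}

fn main() {
    assert!(parse_tool_input("  ").unwrap().as_object().unwrap().is_empty());
    assert!(parse_tool_input(r#"{"path": "a.rs"}"#).is_ok());
    assert!(parse_tool_input("{oops").is_err());
}
```
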
@@ -32,9 +32,10 @@ use gpui_tokio::Tokio;
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
- LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
- LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
+ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
+ RateLimiter, Role, TokenUsage,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -542,7 +543,12 @@ impl LanguageModel for BedrockModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ > {
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
// Get region - from credentials or directly from settings
let region = state
@@ -780,7 +786,7 @@ pub fn get_bedrock_tokens(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
handle: Handle,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse {
id: String,
name: String,
@@ -971,7 +977,7 @@ pub fn map_to_language_model_completion_events(
_ => {}
},
- Err(err) => return Some((Some(Err(anyhow!(err))), state)),
+ Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
}
}
None
@@ -10,11 +10,11 @@ use futures::{
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
- AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
- LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
- LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
- ZED_CLOUD_PROVIDER_ID,
+ AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+ LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
+ ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -745,7 +745,12 @@ impl LanguageModel for CloudLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ > {
self.stream_completion_with_usage(request, cx)
.map(|result| result.map(|(stream, _)| stream))
.boxed()
@@ -758,7 +763,7 @@ impl LanguageModel for CloudLanguageModel {
) -> BoxFuture<
'static,
Result<(
- BoxStream<'static, Result<LanguageModelCompletionEvent>>,
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>,
)>,
> {
@@ -17,16 +17,16 @@ use gpui::{
Transformation, percentage, svg,
};
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
- LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role,
+ StopReason,
};
use settings::SettingsStore;
use std::time::Duration;
use strum::IntoEnumIterator;
use ui::prelude::*;
-use util::maybe;
use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
@@ -242,7 +242,12 @@ impl LanguageModel for CopilotChatLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ > {
if let Some(message) = request.messages.last() {
if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str =
@@ -285,7 +290,7 @@ impl LanguageModel for CopilotChatLanguageModel {
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
is_streaming: bool,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)]
struct RawToolCall {
id: String,
@@ -309,7 +314,7 @@ pub fn map_to_language_model_completion_events(
Ok(event) => {
let Some(choice) = event.choices.first() else {
return Some((
- vec![Err(anyhow!("Response contained no choices"))],
+ vec![Err(anyhow!("Response contained no choices").into())],
state,
));
};
@@ -322,7 +327,7 @@ pub fn map_to_language_model_completion_events(
let Some(delta) = delta else {
return Some((
- vec![Err(anyhow!("Response contained no delta"))],
+ vec![Err(anyhow!("Response contained no delta").into())],
state,
));
};
@@ -361,20 +366,26 @@ pub fn map_to_language_model_completion_events(
}
Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map(
- |(_, tool_call)| {
- maybe!({
- Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_call.id.into(),
- name: tool_call.name.as_str().into(),
- is_input_complete: true,
- raw_input: tool_call.arguments.clone(),
- input: serde_json::Value::from_str(
- &tool_call.arguments,
- )?,
- },
- ))
- })
+ |(_, tool_call)| match serde_json::Value::from_str(
+ &tool_call.arguments,
+ ) {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.clone().into(),
+ name: tool_call.name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments.clone(),
+ },
+ )),
+ Err(error) => {
+ Err(LanguageModelCompletionError::BadInputJson {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.as_str().into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ })
+ }
},
));
@@ -393,7 +404,7 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
- Err(err) => return Some((vec![Err(err)], state)),
+ Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
}
}
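
The same drain-and-match shape replaces the old `maybe!` block here and in the OpenAI mapper near the end of this diff: arguments accumulate as a raw string per tool-call index, and when the model reports `finish_reason == "tool_calls"` a failed parse is surfaced with the tool id, name, raw input, and parser message instead of a bare `anyhow!` string. A reduced sketch with hypothetical local types (`RawToolCall` and `ToolEvent` stand in for the real ones):

```rust
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

enum ToolEvent {
    Use { name: String, input: serde_json::Value },
    BadInputJson { id: String, raw_input: String, json_parse_error: String },
}

fn finish(call: RawToolCall) -> ToolEvent {
    match serde_json::from_str(&call.arguments) {
        Ok(input) => ToolEvent::Use { name: call.name, input },
        Err(error) => ToolEvent::BadInputJson {
            id: call.id,
            raw_input: call.arguments,
            json_parse_error: error.to_string(),
        },
    }
}

fn main() {
    let truncated = RawToolCall {
        id: "call_1".into(),
        name: "open_file".into(),
        arguments: r#"{"path":"#.into(), // cut off mid-stream
    };
    assert!(matches!(finish(truncated), ToolEvent::BadInputJson { .. }));
}
```
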
@@ -9,9 +9,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
- LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -324,7 +324,12 @@ impl LanguageModel for DeepSeekLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ > {
let request = into_deepseek(
request,
self.model.id().to_string(),
@@ -336,20 +341,22 @@ impl LanguageModel for DeepSeekLanguageModel {
let stream = stream.await?;
Ok(stream
.map(|result| {
- result.and_then(|response| {
- response
- .choices
- .first()
- .ok_or_else(|| anyhow!("Empty response"))
- .map(|choice| {
- choice
- .delta
- .content
- .clone()
- .unwrap_or_default()
- .map(LanguageModelCompletionEvent::Text)
- })
- })
+ result
+ .and_then(|response| {
+ response
+ .choices
+ .first()
+ .ok_or_else(|| anyhow!("Empty response"))
+ .map(|choice| {
+ choice
+ .delta
+ .content
+ .clone()
+ .unwrap_or_default()
+ .map(LanguageModelCompletionEvent::Text)
+ })
+ })
+ .map_err(LanguageModelCompletionError::Other)
})
.boxed())
}
@@ -11,8 +11,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
- AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
- LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
+ AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
+ StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -355,12 +356,19 @@ impl LanguageModel for GoogleLanguageModel {
cx: &AsyncApp,
) -> BoxFuture<
'static,
- Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ >,
+ >,
> {
let request = into_google(request, self.model.id().to_string());
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
- let response = request.await.map_err(|err| anyhow!(err))?;
+ let response = request
+ .await
+ .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
Ok(map_to_language_model_completion_events(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
@@ -471,7 +479,7 @@ pub fn into_google(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
use std::sync::atomic::{AtomicU64, Ordering};
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
@@ -492,7 +500,7 @@ pub fn map_to_language_model_completion_events(
if let Some(event) = state.events.next().await {
match event {
Ok(event) => {
- let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new();
+ let mut events: Vec<_> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut state.usage, &usage_metadata);
@@ -559,7 +567,10 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
Err(err) => {
- return Some((vec![Err(anyhow!(err))], state));
+ return Some((
+ vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
+ state,
+ ));
}
}
}
@@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
-use language_model::{AuthenticateError, LanguageModelCompletionEvent};
+use language_model::{
+ AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
+};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@@ -310,7 +312,12 @@ impl LanguageModel for LmStudioLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ > {
let request = self.to_lmstudio_request(request);
let http_client = self.http_client.clone();
@@ -364,7 +371,11 @@ impl LanguageModel for LmStudioLanguageModel {
async move {
Ok(future
.await?
- .map(|result| result.map(LanguageModelCompletionEvent::Text))
+ .map(|result| {
+ result
+ .map(LanguageModelCompletionEvent::Text)
+ .map_err(LanguageModelCompletionError::Other)
+ })
.boxed())
}
.boxed()
@@ -8,9 +8,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
- LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use futures::stream::BoxStream;
@@ -344,7 +344,12 @@ impl LanguageModel for MistralLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ > {
let request = into_mistral(
request,
self.model.id().to_string(),
@@ -356,20 +361,22 @@ impl LanguageModel for MistralLanguageModel {
let stream = stream.await?;
Ok(stream
.map(|result| {
- result.and_then(|response| {
- response
- .choices
- .first()
- .ok_or_else(|| anyhow!("Empty response"))
- .map(|choice| {
- choice
- .delta
- .content
- .clone()
- .unwrap_or_default()
- .map(LanguageModelCompletionEvent::Text)
- })
- })
+ result
+ .and_then(|response| {
+ response
+ .choices
+ .first()
+ .ok_or_else(|| anyhow!("Empty response"))
+ .map(|choice| {
+ choice
+ .delta
+ .content
+ .clone()
+ .unwrap_or_default()
+ .map(LanguageModelCompletionEvent::Text)
+ })
+ })
+ .map_err(LanguageModelCompletionError::Other)
})
.boxed())
}
@@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
-use language_model::{AuthenticateError, LanguageModelCompletionEvent};
+use language_model::{
+ AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
+};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@@ -322,7 +324,12 @@ impl LanguageModel for OllamaLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ >,
+ > {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
@@ -356,7 +363,11 @@ impl LanguageModel for OllamaLanguageModel {
async move {
Ok(future
.await?
- .map(|result| result.map(LanguageModelCompletionEvent::Text))
+ .map(|result| {
+ result
+ .map(LanguageModelCompletionEvent::Text)
+ .map_err(LanguageModelCompletionError::Other)
+ })
.boxed())
}
.boxed()
@@ -9,10 +9,10 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
- LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
- RateLimiter, Role, StopReason,
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+ LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
};
use open_ai::{Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
@@ -24,7 +24,7 @@ use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::{ResultExt, maybe};
+use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
@@ -321,7 +321,12 @@ impl LanguageModel for OpenAiLanguageModel {
cx: &AsyncApp,
) -> BoxFuture<
'static,
- Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ >,
+ >,
> {
let request = into_open_ai(request, &self.model, self.max_output_tokens());
let completions = self.stream_completion(request, cx);
@@ -419,7 +424,7 @@ pub fn into_open_ai(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)]
struct RawToolCall {
id: String,
@@ -443,7 +448,9 @@ pub fn map_to_language_model_completion_events(
Ok(event) => {
let Some(choice) = event.choices.first() else {
return Some((
- vec![Err(anyhow!("Response contained no choices"))],
+ vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ "Response contained no choices"
+ )))],
state,
));
};
@@ -484,20 +491,26 @@ pub fn map_to_language_model_completion_events(
}
Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map(
- |(_, tool_call)| {
- maybe!({
- Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_call.id.into(),
- name: tool_call.name.as_str().into(),
- is_input_complete: true,
- raw_input: tool_call.arguments.clone(),
- input: serde_json::Value::from_str(
- &tool_call.arguments,
- )?,
- },
- ))
- })
+ |(_, tool_call)| match serde_json::Value::from_str(
+ &tool_call.arguments,
+ ) {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.clone().into(),
+ name: tool_call.name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments.clone(),
+ },
+ )),
+ Err(error) => {
+ Err(LanguageModelCompletionError::BadInputJson {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.as_str().into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ })
+ }
},
));
@@ -516,7 +529,9 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
- Err(err) => return Some((vec![Err(err)], state)),
+ Err(err) => {
+ return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
+ }
}
}
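
Taken together: every provider mapper now yields `Result<LanguageModelCompletionEvent, LanguageModelCompletionError>`, and the thread loop shown earlier destructures `BadInputJson` so a malformed tool call fails that one tool use instead of the whole turn. A closing sketch of that consumer-side control flow, reusing the two-variant error shape from above:

```rust
use thiserror::Error;

#[derive(Error, Debug)]
enum CompletionError {
    #[error("received bad input JSON")]
    BadInputJson { raw_input: String, json_parse_error: String },
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

fn handle(event: Result<&str, CompletionError>) -> anyhow::Result<()> {
    match event {
        Ok(text) => println!("chunk: {text}"),
        Err(CompletionError::BadInputJson { raw_input, json_parse_error }) => {
            // Surfaced to the UI and recorded as a failed tool result,
            // but the stream keeps going.
            eprintln!("invalid tool input {raw_input:?}: {json_parse_error}");
        }
        // Anything else still ends the turn.
        Err(CompletionError::Other(error)) => return Err(error),
    }
    Ok(())
}

fn main() -> anyhow::Result<()> {
    handle(Ok("hello"))?;
    handle(Err(CompletionError::BadInputJson {
        raw_input: "{oops".into(),
        json_parse_error: "EOF while parsing an object".into(),
    }))
}
```
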