diff --git a/Cargo.lock b/Cargo.lock index cbc494f9dc0fc1858a846fabe168b3538de4dbe5..3fccd850ae697925330d15ed6b72804c39f4795e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -629,13 +629,17 @@ version = "0.1.0" dependencies = [ "anyhow", "chrono", + "collections", "futures 0.3.32", "http_client", + "language_model_core", + "log", "schemars", "serde", "serde_json", "strum 0.27.2", "thiserror 2.0.17", + "tiktoken-rs", ] [[package]] @@ -2903,7 +2907,6 @@ dependencies = [ "http_client", "http_client_tls", "httparse", - "language_model", "log", "objc2-foundation", "parking_lot", @@ -2959,6 +2962,7 @@ dependencies = [ "http_client", "parking_lot", "serde_json", + "smol", "thiserror 2.0.17", "yawc", ] @@ -5162,6 +5166,7 @@ dependencies = [ "buffer_diff", "client", "clock", + "cloud_api_client", "cloud_api_types", "cloud_llm_client", "collections", @@ -5641,7 +5646,7 @@ dependencies = [ name = "env_var" version = "0.1.0" dependencies = [ - "gpui", + "gpui_shared_string", ] [[package]] @@ -7468,11 +7473,13 @@ dependencies = [ "anyhow", "futures 0.3.32", "http_client", + "language_model_core", + "log", "schemars", "serde", "serde_json", - "settings", "strum 0.27.2", + "tiktoken-rs", ] [[package]] @@ -7541,6 +7548,7 @@ dependencies = [ "getrandom 0.3.4", "gpui_macros", "gpui_platform", + "gpui_shared_string", "gpui_util", "gpui_web", "http_client", @@ -7710,6 +7718,16 @@ dependencies = [ "gpui_windows", ] +[[package]] +name = "gpui_shared_string" +version = "0.1.0" +dependencies = [ + "derive_more", + "gpui_util", + "schemars", + "serde", +] + [[package]] name = "gpui_tokio" version = "0.1.0" @@ -9358,7 +9376,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "gpui", + "gpui_shared_string", "log", "lsp", "parking_lot", @@ -9397,12 +9415,8 @@ dependencies = [ name = "language_model" version = "0.1.0" dependencies = [ - "anthropic", "anyhow", "base64 0.22.1", - "cloud_api_client", - "cloud_api_types", - "cloud_llm_client", "collections", "credentials_provider", "env_var", @@ -9411,16 +9425,31 @@ dependencies = [ "http_client", "icons", "image", + "language_model_core", "log", - "open_ai", - "open_router", "parking_lot", + "serde", + "serde_json", + "thiserror 2.0.17", + "util", +] + +[[package]] +name = "language_model_core" +version = "0.1.0" +dependencies = [ + "anyhow", + "cloud_llm_client", + "futures 0.3.32", + "gpui_shared_string", + "http_client", + "partial-json-fixer", "schemars", "serde", "serde_json", "smol", + "strum 0.27.2", "thiserror 2.0.17", - "util", ] [[package]] @@ -9436,8 +9465,8 @@ dependencies = [ "base64 0.22.1", "bedrock", "client", + "cloud_api_client", "cloud_api_types", - "cloud_llm_client", "collections", "component", "convert_case 0.8.0", @@ -9456,6 +9485,7 @@ dependencies = [ "http_client", "language", "language_model", + "language_models_cloud", "lmstudio", "log", "menu", @@ -9464,17 +9494,14 @@ dependencies = [ "open_ai", "open_router", "opencode", - "partial-json-fixer", "pretty_assertions", "release_channel", "schemars", - "semver", "serde", "serde_json", "settings", "smol", "strum 0.27.2", - "thiserror 2.0.17", "tiktoken-rs", "tokio", "ui", @@ -9484,6 +9511,28 @@ dependencies = [ "x_ai", ] +[[package]] +name = "language_models_cloud" +version = "0.1.0" +dependencies = [ + "anthropic", + "anyhow", + "cloud_llm_client", + "futures 0.3.32", + "google_ai", + "gpui", + "http_client", + "language_model", + "open_ai", + "schemars", + "semver", + "serde", + "serde_json", + "smol", + "thiserror 2.0.17", + "x_ai", +] + [[package]] name = "language_onboarding" 
version = "0.1.0" @@ -11631,16 +11680,19 @@ name = "open_ai" version = "0.1.0" dependencies = [ "anyhow", + "collections", "futures 0.3.32", "http_client", + "language_model_core", "log", + "pretty_assertions", "rand 0.9.2", "schemars", "serde", "serde_json", - "settings", "strum 0.27.2", "thiserror 2.0.17", + "tiktoken-rs", ] [[package]] @@ -11672,6 +11724,7 @@ dependencies = [ "anyhow", "futures 0.3.32", "http_client", + "language_model_core", "schemars", "serde", "serde_json", @@ -15801,6 +15854,7 @@ dependencies = [ "collections", "derive_more", "gpui", + "language_model_core", "log", "schemars", "serde", @@ -20180,6 +20234,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", + "cloud_api_client", "cloud_api_types", "cloud_llm_client", "futures 0.3.32", @@ -21783,9 +21838,11 @@ name = "x_ai" version = "0.1.0" dependencies = [ "anyhow", + "language_model_core", "schemars", "serde", "strum 0.27.2", + "tiktoken-rs", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4c75dafae5df4d63815e0da5cabb95ccdad25e9d..5a7fc9caaf982953168855671bebbcf4f010df03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,6 +87,7 @@ members = [ "crates/google_ai", "crates/grammars", "crates/gpui", + "crates/gpui_shared_string", "crates/gpui_linux", "crates/gpui_macos", "crates/gpui_macros", @@ -110,7 +111,9 @@ members = [ "crates/language_core", "crates/language_extension", "crates/language_model", + "crates/language_model_core", "crates/language_models", + "crates/language_models_cloud", "crates/language_onboarding", "crates/language_selector", "crates/language_tools", @@ -335,6 +338,7 @@ go_to_line = { path = "crates/go_to_line" } google_ai = { path = "crates/google_ai" } grammars = { path = "crates/grammars" } gpui = { path = "crates/gpui", default-features = false } +gpui_shared_string = { path = "crates/gpui_shared_string" } gpui_linux = { path = "crates/gpui_linux", default-features = false } gpui_macos = { path = "crates/gpui_macos", default-features = false } gpui_macros = { path = "crates/gpui_macros" } @@ -361,7 +365,9 @@ language = { path = "crates/language" } language_core = { path = "crates/language_core" } language_extension = { path = "crates/language_extension" } language_model = { path = "crates/language_model" } +language_model_core = { path = "crates/language_model_core" } language_models = { path = "crates/language_models" } +language_models_cloud = { path = "crates/language_models_cloud" } language_onboarding = { path = "crates/language_onboarding" } language_selector = { path = "crates/language_selector" } language_tools = { path = "crates/language_tools" } diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index 0086a82f4e79c9924502202873ceb2b25d2e66fb..9b013f111e7eaa981652d8868dfcf3c098d9dc7e 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -5,7 +5,7 @@ use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; use indoc::formatdoc; use language::Point; -use language_model::{LanguageModelImage, LanguageModelToolResultContent}; +use language_model::{LanguageModelImage, LanguageModelImageExt, LanguageModelToolResultContent}; use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 5f452bc9c0e2e9c2322042583295894a5866b053..e56db9df927ab3cdf838587f1cb4f9514eb5a758 100644 --- 
a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -325,7 +325,7 @@ impl AcpConnection { // Use the one the agent provides if we have one .map(|info| info.name.into()) // Otherwise, just use the name - .unwrap_or_else(|| agent_id.0.to_string().into()); + .unwrap_or_else(|| agent_id.0.clone()); let session_list = if response .agent_capabilities diff --git a/crates/agent_ui/src/agent_registry_ui.rs b/crates/agent_ui/src/agent_registry_ui.rs index 78b4e3a5a3965c72b96d4ec201139b1d8e510fb2..e19afdecc390268cefbd7be4e5d0759aa2a29c19 100644 --- a/crates/agent_ui/src/agent_registry_ui.rs +++ b/crates/agent_ui/src/agent_registry_ui.rs @@ -382,7 +382,7 @@ impl AgentRegistryPage { self.install_button(agent, install_status, supports_current_platform, cx); let repository_button = agent.repository().map(|repository| { - let repository_for_tooltip: SharedString = repository.to_string().into(); + let repository_for_tooltip = repository.clone(); let repository_for_click = repository.to_string(); IconButton::new( diff --git a/crates/agent_ui/src/mention_set.rs b/crates/agent_ui/src/mention_set.rs index 1b2ec0ad2fd460b4eec5a8b757bdd3058d4a3704..880257e3f942bf71d1d51b1e661d911474aa786b 100644 --- a/crates/agent_ui/src/mention_set.rs +++ b/crates/agent_ui/src/mention_set.rs @@ -18,7 +18,7 @@ use gpui::{ use http_client::{AsyncBody, HttpClientWithUrl}; use itertools::Either; use language::Buffer; -use language_model::LanguageModelImage; +use language_model::{LanguageModelImage, LanguageModelImageExt}; use multi_buffer::MultiBufferRow; use postage::stream::Stream as _; use project::{Project, ProjectItem, ProjectPath, Worktree}; diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml index 1e2587435489dea6952c697b0e0a4cf627226728..458f9bfae7da4736c4e54e42f08b5e3a926ed30a 100644 --- a/crates/anthropic/Cargo.toml +++ b/crates/anthropic/Cargo.toml @@ -18,12 +18,16 @@ path = "src/anthropic.rs" [dependencies] anyhow.workspace = true chrono.workspace = true +collections.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true +log.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true strum.workspace = true thiserror.workspace = true +tiktoken-rs.workspace = true diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 5d7790b86b09853e22436252fcde1bebf5feff9b..48fa318d7c1d87e63725cef836baf9c945966206 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -12,6 +12,7 @@ use strum::{EnumIter, EnumString}; use thiserror::Error; pub mod batches; +pub mod completion; pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com"; @@ -1026,6 +1027,89 @@ pub async fn count_tokens( } } +// -- Conversions from/to `language_model_core` types -- + +impl From<language_model_core::Speed> for Speed { + fn from(speed: language_model_core::Speed) -> Self { + match speed { + language_model_core::Speed::Standard => Speed::Standard, + language_model_core::Speed::Fast => Speed::Fast, + } + } +} + +impl From<AnthropicError> for language_model_core::LanguageModelCompletionError { + fn from(error: AnthropicError) -> Self { + let provider = language_model_core::ANTHROPIC_PROVIDER_NAME; + match error { + AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, + AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, + AnthropicError::HttpSend(error) => Self::HttpSend { provider, error }, + AnthropicError::DeserializeResponse(error) => { + Self::DeserializeResponse { provider, error } + } + AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, + AnthropicError::HttpResponseError { + status_code, + message, + } => Self::HttpResponseError { + provider, + status_code, + message, + }, + AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { + provider, + retry_after: Some(retry_after), + }, + AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded { + provider, + retry_after, + }, + AnthropicError::ApiError(api_error) => api_error.into(), + } + } +} + +impl From<ApiError> for language_model_core::LanguageModelCompletionError { + fn from(error: ApiError) -> Self { + use ApiErrorCode::*; + let provider = language_model_core::ANTHROPIC_PROVIDER_NAME; + match error.code() { + Some(code) => match code { + InvalidRequestError => Self::BadRequestFormat { + provider, + message: error.message, + }, + AuthenticationError => Self::AuthenticationError { + provider, + message: error.message, + }, + PermissionError => Self::PermissionError { + provider, + message: error.message, + }, + NotFoundError => Self::ApiEndpointNotFound { provider }, + RequestTooLarge => Self::PromptTooLarge { + tokens: language_model_core::parse_prompt_too_long(&error.message), + }, + RateLimitError => Self::RateLimitExceeded { + provider, + retry_after: None, + }, + ApiError => Self::ApiInternalServerError { + provider, + message: error.message, + }, + OverloadedError => Self::ServerOverloaded { + provider, + retry_after: None, + }, + }, + None => Self::Other(error.into()), + } + } +} + #[test] fn test_match_window_exceeded() { let error = ApiError {
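Reviewer note, not part of the patch: with these two `From` impls, a caller can funnel every Anthropic failure through `language_model_core::LanguageModelCompletionError` and branch on the normalized variants. A minimal sketch of that pattern (the retry policy shown here is hypothetical, not Zed's):

fn retry_delay(error: AnthropicError) -> Option<std::time::Duration> {
    // The conversion above produces exactly two retryable variants:
    // rate limiting and server overload.
    match language_model_core::LanguageModelCompletionError::from(error) {
        language_model_core::LanguageModelCompletionError::RateLimitExceeded {
            retry_after, ..
        } => retry_after.or(Some(std::time::Duration::from_secs(30))), // fallback delay is made up
        language_model_core::LanguageModelCompletionError::ServerOverloaded {
            retry_after, ..
        } => retry_after,
        _ => None,
    }
}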
diff --git a/crates/anthropic/src/completion.rs b/crates/anthropic/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..a6175a4f7c24b3b724734b2edef48ef8acfaa159 --- /dev/null +++ b/crates/anthropic/src/completion.rs @@ -0,0 +1,765 @@ +use anyhow::Result; +use collections::HashMap; +use futures::{Stream, StreamExt}; +use language_model_core::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, + Role, StopReason, TokenUsage, + util::{fix_streamed_json, parse_tool_arguments}, +}; +use std::pin::Pin; +use std::str::FromStr; + +use crate::{ + AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta, + CountTokensRequest, Event, ImageSource, Message, RequestContent, ResponseContent, + StringOrContents, Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage, +}; + +fn to_anthropic_content(content: MessageContent) -> Option<RequestContent> { + match content { + MessageContent::Text(text) => { + let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { + text.trim_end().to_string() + } else { + text + }; + if !text.is_empty() { + Some(RequestContent::Text { + text, + cache_control: None, + }) + } else { + None + } + } + MessageContent::Thinking { + text: thinking, + signature, + } => { + if let Some(signature) = signature + && !thinking.is_empty() + { + Some(RequestContent::Thinking { + thinking, + signature, + cache_control: None, + }) + } else { + None + } + } + MessageContent::RedactedThinking(data) => { + if !data.is_empty() { + Some(RequestContent::RedactedThinking { data }) + } else { + None + } + } + MessageContent::Image(image) => Some(RequestContent::Image { + source: ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + cache_control: None, + }), + MessageContent::ToolUse(tool_use) => Some(RequestContent::ToolUse { + id: tool_use.id.to_string(), + name: tool_use.name.to_string(), + input: tool_use.input, + cache_control: None, + }), + MessageContent::ToolResult(tool_result) => Some(RequestContent::ToolResult { + tool_use_id: tool_result.tool_use_id.to_string(), + is_error: tool_result.is_error, + content: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ToolResultContent::Plain(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ToolResultContent::Multipart(vec![ToolResultPart::Image { + source: ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }]) + } + }, + cache_control: None, + }), + } +} + +/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest. +pub fn into_anthropic_count_tokens_request( + request: LanguageModelRequest, + model: String, + mode: AnthropicModelMode, +) -> CountTokensRequest { + let mut new_messages: Vec<Message> = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let anthropic_message_content: Vec<RequestContent> = message + .content + .into_iter() + .filter_map(to_anthropic_content) + .collect(); + let anthropic_role = match message.role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if anthropic_message_content.is_empty() { + continue; + } + + if let Some(last_message) = new_messages.last_mut() + && last_message.role == anthropic_role + { + last_message.content.extend(anthropic_message_content); + continue; + } + + new_messages.push(Message { + role: anthropic_role, + content: anthropic_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + CountTokensRequest { + model, + messages: new_messages, + system: if system_message.is_empty() { + None + } else { + Some(StringOrContents::String(system_message)) + }, + thinking: if request.thinking_allowed { + match mode { + AnthropicModelMode::Thinking { budget_tokens } => { + Some(Thinking::Enabled { budget_tokens }) + } + AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive), + AnthropicModelMode::Default => None, + } + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + eager_input_streaming: tool.use_input_streaming, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => ToolChoice::Auto, + LanguageModelToolChoice::Any => ToolChoice::Any, + LanguageModelToolChoice::None => ToolChoice::None, + }), + } +}
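Reviewer note, not part of the patch: a small illustration of the conversion above. Consecutive same-role messages are merged into one Anthropic message (field values here are placeholders; the struct literal mirrors the tests at the bottom of this file):

use language_model_core::{LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role};

let user = |text: &str| LanguageModelRequestMessage {
    role: Role::User,
    content: vec![MessageContent::Text(text.to_string())],
    cache: false,
    reasoning_details: None,
};
let request = LanguageModelRequest {
    messages: vec![user("First"), user("Second")],
    thread_id: None,
    prompt_id: None,
    intent: None,
    stop: vec![],
    temperature: None,
    tools: vec![],
    tool_choice: None,
    thinking_allowed: false,
    thinking_effort: None,
    speed: None,
};
let count_request = into_anthropic_count_tokens_request(
    request,
    "claude-sonnet-4-5".to_string(),
    AnthropicModelMode::Default,
);
// The two consecutive user messages were merged into a single message
// with two content blocks.
assert_eq!(count_request.messages.len(), 1);
assert_eq!(count_request.messages[0].content.len(), 2);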
+/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable, +/// or by providers (like Zed Cloud) that don't have direct Anthropic API access. +pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> { + let messages = request.messages; + let mut tokens_from_images = 0; + let mut string_messages = Vec::with_capacity(messages.len()); + + for message in messages { + let mut string_contents = String::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + string_contents.push_str(&text); + } + MessageContent::Thinking { .. } => { + // Thinking blocks are not included in the input token count. + } + MessageContent::RedactedThinking(_) => { + // Thinking blocks are not included in the input token count. + } + MessageContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + MessageContent::ToolUse(_tool_use) => { + // TODO: Estimate token usage from tool uses. + } + MessageContent::ToolResult(tool_result) => match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + string_contents.push_str(text); + } + LanguageModelToolResultContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + }, + } + } + + if !string_contents.is_empty() { + string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(string_contents), + name: None, + function_call: None, + }); + } + } + + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. + tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) + .map(|tokens| (tokens + tokens_from_images) as u64) +}
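Reviewer note, not part of the patch: per the doc comment, this estimator is the offline path. A sketch of the fallback wiring I'd expect at a call site (the API-side count is abstracted to a plain `Result`, since its call signature isn't shown in this diff):

fn count_tokens_or_estimate(
    api_count: anyhow::Result<u64>,
    request: LanguageModelRequest,
) -> anyhow::Result<u64> {
    match api_count {
        Ok(tokens) => Ok(tokens),
        // API unavailable: fall back to the local tiktoken-based estimate.
        // This uses GPT-4's tokenizer as a stand-in, so it is approximate.
        Err(_) => count_anthropic_tokens_with_tiktoken(request),
    }
}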
+ +pub fn into_anthropic( + request: LanguageModelRequest, + model: String, + default_temperature: f32, + max_output_tokens: u64, + mode: AnthropicModelMode, +) -> crate::Request { + let mut new_messages: Vec<Message> = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let mut anthropic_message_content: Vec<RequestContent> = message + .content + .into_iter() + .filter_map(to_anthropic_content) + .collect(); + let anthropic_role = match message.role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if anthropic_message_content.is_empty() { + continue; + } + + if let Some(last_message) = new_messages.last_mut() + && last_message.role == anthropic_role + { + last_message.content.extend(anthropic_message_content); + continue; + } + + // Mark the last segment of the message as cached + if message.cache { + let cache_control_value = Some(CacheControl { + cache_type: CacheControlType::Ephemeral, + }); + for message_content in anthropic_message_content.iter_mut().rev() { + match message_content { + RequestContent::RedactedThinking { .. } => { + // Caching is not possible, fallback to next message + } + RequestContent::Text { cache_control, .. } + | RequestContent::Thinking { cache_control, .. } + | RequestContent::Image { cache_control, .. } + | RequestContent::ToolUse { cache_control, .. } + | RequestContent::ToolResult { cache_control, .. } => { + *cache_control = cache_control_value; + break; + } + } + } + } + + new_messages.push(Message { + role: anthropic_role, + content: anthropic_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + crate::Request { + model, + messages: new_messages, + max_tokens: max_output_tokens, + system: if system_message.is_empty() { + None + } else { + Some(StringOrContents::String(system_message)) + }, + thinking: if request.thinking_allowed { + match mode { + AnthropicModelMode::Thinking { budget_tokens } => { + Some(Thinking::Enabled { budget_tokens }) + } + AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive), + AnthropicModelMode::Default => None, + } + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + eager_input_streaming: tool.use_input_streaming, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => ToolChoice::Auto, + LanguageModelToolChoice::Any => ToolChoice::Any, + LanguageModelToolChoice::None => ToolChoice::None, + }), + metadata: None, + output_config: if request.thinking_allowed + && matches!(mode, AnthropicModelMode::AdaptiveThinking) + { + request.thinking_effort.as_deref().and_then(|effort| { + let effort = match effort { + "low" => Some(crate::Effort::Low), + "medium" => Some(crate::Effort::Medium), + "high" => Some(crate::Effort::High), + "max" => Some(crate::Effort::Max), + _ => None, + }; + effort.map(|effort| crate::OutputConfig { + effort: Some(effort), + }) + }) + } else { + None + }, + stop_sequences: Vec::new(), + speed: request.speed.map(Into::into), + temperature: request.temperature.or(Some(default_temperature)), + top_k: None, + top_p: None, + } +} + +pub struct AnthropicEventMapper { + tool_uses_by_index: HashMap<usize, RawToolUse>, + usage: Usage, + stop_reason: StopReason, +} + +impl AnthropicEventMapper { + pub fn new() -> Self { + Self { + tool_uses_by_index: HashMap::default(), + usage: Usage::default(), + stop_reason: StopReason::EndTurn, + } + } + + pub fn map_stream( + mut self, + events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>, + ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(error.into())], + }) + }) + } + + pub fn map_event( + &mut self, + event: Event, + ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { + match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ResponseContent::Thinking { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ResponseContent::RedactedThinking { data } => { + vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] + } + ResponseContent::ToolUse { id, name, ..
} => { + self.tool_uses_by_index.insert( + index, + RawToolUse { + id, + name, + input_json: String::new(), + }, + ); + Vec::new() + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ContentDelta::ThinkingDelta { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ContentDelta::SignatureDelta { signature } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "".to_string(), + signature: Some(signature), + })] + } + ContentDelta::InputJsonDelta { partial_json } => { + if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { + tool_use.input_json.push_str(&partial_json); + + // Try to convert invalid (incomplete) JSON into + // valid JSON that serde can accept, e.g. by closing + // unclosed delimiters. This way, we can update the + // UI with whatever has been streamed back so far. + if let Ok(input) = + serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json)) + { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.clone().into(), + name: tool_use.name.clone().into(), + is_input_complete: false, + raw_input: tool_use.input_json.clone(), + input, + thought_signature: None, + }, + ))]; + } + } + vec![] + } + }, + Event::ContentBlockStop { index } => { + if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { + let input_json = tool_use.input_json.trim(); + let event_result = match parse_tool_arguments(input_json) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.into(), + name: tool_use.name.into(), + is_input_complete: true, + input, + raw_input: tool_use.input_json.clone(), + thought_signature: None, + }, + )), + Err(json_parse_err) => { + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_use.id.into(), + tool_name: tool_use.name.into(), + raw_input: input_json.into(), + json_parse_error: json_parse_err.to_string(), + }) + } + }; + + vec![event_result] + } else { + Vec::new() + } + } + Event::MessageStart { message } => { + update_usage(&mut self.usage, &message.usage); + vec![ + Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( + &self.usage, + ))), + Ok(LanguageModelCompletionEvent::StartMessage { + message_id: message.id, + }), + ] + } + Event::MessageDelta { delta, usage } => { + update_usage(&mut self.usage, &usage); + if let Some(stop_reason) = delta.stop_reason.as_deref() { + self.stop_reason = match stop_reason { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + "refusal" => StopReason::Refusal, + _ => { + log::error!("Unexpected anthropic stop_reason: {stop_reason}"); + StopReason::EndTurn + } + }; + } + vec![Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))] + } + Event::MessageStop => { + vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] + } + Event::Error { error } => { + vec![Err(error.into())] + } + _ => Vec::new(), + } + } +} + +struct RawToolUse { + id: String, + name: String, + input_json: String, +} + +/// Updates usage data by preferring counts from `new`. 
+fn update_usage(usage: &mut Usage, new: &Usage) { + if let Some(input_tokens) = new.input_tokens { + usage.input_tokens = Some(input_tokens); + } + if let Some(output_tokens) = new.output_tokens { + usage.output_tokens = Some(output_tokens); + } + if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens { + usage.cache_creation_input_tokens = Some(cache_creation_input_tokens); + } + if let Some(cache_read_input_tokens) = new.cache_read_input_tokens { + usage.cache_read_input_tokens = Some(cache_read_input_tokens); + } +} + +fn convert_usage(usage: &Usage) -> TokenUsage { + TokenUsage { + input_tokens: usage.input_tokens.unwrap_or(0), + output_tokens: usage.output_tokens.unwrap_or(0), + cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), + cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::AnthropicModelMode; + use language_model_core::{LanguageModelImage, LanguageModelRequestMessage, MessageContent}; + + #[test] + fn test_cache_control_only_on_last_segment() { + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![ + MessageContent::Text("Some prompt".to_string()), + MessageContent::Image(LanguageModelImage::empty()), + MessageContent::Image(LanguageModelImage::empty()), + MessageContent::Image(LanguageModelImage::empty()), + MessageContent::Image(LanguageModelImage::empty()), + ], + cache: true, + reasoning_details: None, + }], + thread_id: None, + prompt_id: None, + intent: None, + stop: vec![], + temperature: None, + tools: vec![], + tool_choice: None, + thinking_allowed: true, + thinking_effort: None, + speed: None, + }; + + let anthropic_request = into_anthropic( + request, + "claude-3-5-sonnet".to_string(), + 0.7, + 4096, + AnthropicModelMode::Default, + ); + + assert_eq!(anthropic_request.messages.len(), 1); + + let message = &anthropic_request.messages[0]; + assert_eq!(message.content.len(), 5); + + assert!(matches!( + message.content[0], + RequestContent::Text { + cache_control: None, + .. + } + )); + for i in 1..3 { + assert!(matches!( + message.content[i], + RequestContent::Image { + cache_control: None, + .. + } + )); + } + + assert!(matches!( + message.content[4], + RequestContent::Image { + cache_control: Some(CacheControl { + cache_type: CacheControlType::Ephemeral, + }), + .. 
+ } + )); + } + + fn request_with_assistant_content(assistant_content: Vec<MessageContent>) -> crate::Request { + let mut request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text("Hello".to_string())], + cache: false, + reasoning_details: None, + }], + thinking_effort: None, + thread_id: None, + prompt_id: None, + intent: None, + stop: vec![], + temperature: None, + tools: vec![], + tool_choice: None, + thinking_allowed: true, + speed: None, + }; + request.messages.push(LanguageModelRequestMessage { + role: Role::Assistant, + content: assistant_content, + cache: false, + reasoning_details: None, + }); + into_anthropic( + request, + "claude-sonnet-4-5".to_string(), + 1.0, + 16000, + AnthropicModelMode::Thinking { + budget_tokens: Some(10000), + }, + ) + } + + #[test] + fn test_unsigned_thinking_blocks_stripped() { + let result = request_with_assistant_content(vec![ + MessageContent::Thinking { + text: "Cancelled mid-think, no signature".to_string(), + signature: None, + }, + MessageContent::Text("Some response text".to_string()), + ]); + + let assistant_message = result + .messages + .iter() + .find(|m| m.role == crate::Role::Assistant) + .expect("assistant message should still exist"); + + assert_eq!( + assistant_message.content.len(), + 1, + "Only the text content should remain; unsigned thinking block should be stripped" + ); + assert!(matches!( + &assistant_message.content[0], + RequestContent::Text { text, .. } if text == "Some response text" + )); + } + + #[test] + fn test_signed_thinking_blocks_preserved() { + let result = request_with_assistant_content(vec![ + MessageContent::Thinking { + text: "Completed thinking".to_string(), + signature: Some("valid-signature".to_string()), + }, + MessageContent::Text("Response".to_string()), + ]); + + let assistant_message = result + .messages + .iter() + .find(|m| m.role == crate::Role::Assistant) + .expect("assistant message should exist"); + + assert_eq!( + assistant_message.content.len(), + 2, + "Both the signed thinking block and text should be preserved" + ); + assert!(matches!( + &assistant_message.content[0], + RequestContent::Thinking { thinking, signature, ..
} + if thinking == "Completed thinking" && signature == "valid-signature" + )); + } + + #[test] + fn test_only_unsigned_thinking_block_omits_entire_message() { + let result = request_with_assistant_content(vec![MessageContent::Thinking { + text: "Cancelled before any text or signature".to_string(), + signature: None, + }]); + + let assistant_messages: Vec<_> = result + .messages + .iter() + .filter(|m| m.role == crate::Role::Assistant) + .collect(); + + assert_eq!( + assistant_messages.len(), + 0, + "An assistant message whose only content was an unsigned thinking block \ + should be omitted entirely" + ); + } +} diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 7bbaccb22e0e6c7508240186103e216f83be2f0c..532fe38f7df1f686730ed862a81806e9a531e156 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -36,7 +36,6 @@ gpui_tokio.workspace = true http_client.workspace = true http_client_tls.workspace = true httparse = "1.10" -language_model.workspace = true log.workspace = true parking_lot.workspace = true paths.workspace = true diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index dfd9963a0ee52d167f8d4edb0b850f4debed7fd4..05ca974f80438542b232262dd375e0e38ab4327c 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -14,6 +14,7 @@ use async_tungstenite::tungstenite::{ http::{HeaderValue, Request, StatusCode}, }; use clock::SystemClock; +use cloud_api_client::LlmApiToken; use cloud_api_client::websocket_protocol::MessageToClient; use cloud_api_client::{ClientApiError, CloudApiClient}; use cloud_api_types::OrganizationId; @@ -26,7 +27,6 @@ use futures::{ }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env}; -use language_model::LlmApiToken; use parking_lot::{Mutex, RwLock}; use postage::watch; use proxy::connect_proxy_stream; diff --git a/crates/client/src/llm_token.rs b/crates/client/src/llm_token.rs index f62aa6dd4dc3462bc3a0f6f46c35f0e4e5499816..70457679e4b965e3251ae4861d3052bfa41fd65a 100644 --- a/crates/client/src/llm_token.rs +++ b/crates/client/src/llm_token.rs @@ -1,10 +1,10 @@ use super::{Client, UserStore}; +use cloud_api_client::LlmApiToken; use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; use gpui::{ App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription, }; -use language_model::LlmApiToken; use std::sync::Arc; pub trait NeedsLlmTokenRefresh { diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml index 78c684e3e54ee29a5f3f3ae5620d4a52b445f92e..cf293d83f848e1266dec977c0925af7f66608ce6 100644 --- a/crates/cloud_api_client/Cargo.toml +++ b/crates/cloud_api_client/Cargo.toml @@ -20,5 +20,6 @@ gpui_tokio.workspace = true http_client.workspace = true parking_lot.workspace = true serde_json.workspace = true +smol.workspace = true thiserror.workspace = true yawc.workspace = true diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index 13d67838b216f4990f15ec22c1701aa7aef9dbf2..8c605bb3490ef5c7aea6e96045680338e8344a83 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -1,3 +1,4 @@ +mod llm_token; mod websocket; use std::sync::Arc; @@ -18,6 +19,8 @@ use yawc::WebSocket; use crate::websocket::Connection; +pub use llm_token::LlmApiToken; 
+ struct Credentials { user_id: u32, access_token: String, diff --git a/crates/cloud_api_client/src/llm_token.rs b/crates/cloud_api_client/src/llm_token.rs new file mode 100644 index 0000000000000000000000000000000000000000..711e0d51b89bf34db255d7cb1e58483c9de340fc --- /dev/null +++ b/crates/cloud_api_client/src/llm_token.rs @@ -0,0 +1,74 @@ +use std::sync::Arc; + +use cloud_api_types::OrganizationId; +use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; + +use crate::{ClientApiError, CloudApiClient}; + +#[derive(Clone, Default)] +pub struct LlmApiToken(Arc<RwLock<Option<String>>>); + +impl LlmApiToken { + pub async fn acquire( + &self, + client: &CloudApiClient, + system_id: Option<String>, + organization_id: Option<OrganizationId>, + ) -> Result<String, ClientApiError> { + let lock = self.0.upgradable_read().await; + if let Some(token) = lock.as_ref() { + Ok(token.to_string()) + } else { + Self::fetch( + RwLockUpgradableReadGuard::upgrade(lock).await, + client, + system_id, + organization_id, + ) + .await + } + } + + pub async fn refresh( + &self, + client: &CloudApiClient, + system_id: Option<String>, + organization_id: Option<OrganizationId>, + ) -> Result<String, ClientApiError> { + Self::fetch(self.0.write().await, client, system_id, organization_id).await + } + + /// Clears the existing token before attempting to fetch a new one. + /// + /// Used when switching organizations so that a failed refresh doesn't + /// leave a token for the wrong organization. + pub async fn clear_and_refresh( + &self, + client: &CloudApiClient, + system_id: Option<String>, + organization_id: Option<OrganizationId>, + ) -> Result<String, ClientApiError> { + let mut lock = self.0.write().await; + *lock = None; + Self::fetch(lock, client, system_id, organization_id).await + } + + async fn fetch( + mut lock: RwLockWriteGuard<'_, Option<String>>, + client: &CloudApiClient, + system_id: Option<String>, + organization_id: Option<OrganizationId>, + ) -> Result<String, ClientApiError> { + let result = client.create_llm_token(system_id, organization_id).await; + match result { + Ok(response) => { + *lock = Some(response.token.0.clone()); + Ok(response.token.0) + } + Err(err) => { + *lock = None; + Err(err) + } + } + } +}
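Reviewer note, not part of the patch: the expected call pattern for the moved type, as I read it — `acquire` reuses the cached token, `refresh` forces a new one after the server reports expiry (the surrounding retry decision is hypothetical):

use cloud_api_client::{ClientApiError, CloudApiClient, LlmApiToken};

async fn bearer_token(
    client: &CloudApiClient,
    token: &LlmApiToken,
    expired: bool,
) -> Result<String, ClientApiError> {
    if expired {
        // e.g. after a response carrying EXPIRED_LLM_TOKEN_HEADER_NAME
        token.refresh(client, None, None).await
    } else {
        // Fetches once, then serves the cached token to later callers.
        token.acquire(client, None, None).await
    }
}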
diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml index a7b4f925a9302296e8fe25a14177a583e5f44b33..7cc59f255abeb27c6e35a2064654d8eca1a581fe 100644 --- a/crates/cloud_llm_client/Cargo.toml +++ b/crates/cloud_llm_client/Cargo.toml @@ -7,6 +7,7 @@ license = "Apache-2.0" [features] test-support = [] +predict-edits = ["dep:zeta_prompt"] [lints] workspace = true @@ -20,6 +21,6 @@ serde = { workspace = true, features = ["derive", "rc"] } serde_json.workspace = true strum = { workspace = true, features = ["derive"] } uuid = { workspace = true, features = ["serde"] } -zeta_prompt.workspace = true +zeta_prompt = { workspace = true, optional = true } diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 35eb3f2b80dd400558b1f027781f5b8cf63bb6cb..ac8bdd462a9c4754ef42a6afa41f1bef8b5bbe6a 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "predict-edits")] pub mod predict_edits_v3; use std::str::FromStr; diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 7dc807998760a8e65d373164eec5c7663171e5d0..327ef1cf6003eb959bd0926d67d2b0ed3b4ab0ba 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -2846,11 +2846,11 @@ impl CollabPanel { } }; - Some(channel.name.as_ref()) + Some(channel.name.clone()) }); if let Some(name) = channel_name { - SharedString::from(name.to_string()) + name } else { SharedString::from("Current Call") } } diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index eabb1641fd4fbec7b2f8ef0ba399a8fe9600dfa3..87ad4e42e7826cdda4fc6a8c31a27afe888830f0 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -21,8 +21,9 @@ heapless.workspace = true buffer_diff.workspace = true client.workspace = true clock.workspace = true +cloud_api_client.workspace = true cloud_api_types.workspace = true -cloud_llm_client.workspace = true +cloud_llm_client = { workspace = true, features = ["predict-edits"] } collections.workspace = true copilot.workspace = true copilot_ui.workspace = true diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 280427df006b510e1854ffb40cd7f995fcd9fdc6..2d90e13fb9b45aedd354f753502cd4e616ae3bcd 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,5 +1,6 @@ use anyhow::Result; use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token}; +use cloud_api_client::LlmApiToken; use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody}; use cloud_llm_client::predict_edits_v3::{ PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, @@ -31,7 +32,6 @@ use heapless::Vec as ArrayVec; use language::language_settings::all_language_settings; use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; -use language_model::LlmApiToken; use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; diff --git a/crates/edit_prediction/src/ollama.rs b/crates/edit_prediction/src/ollama.rs index 0250ec44a46cf081c6badc6fa11a9c34ebb65c4a..0ae90dd9f6eca4bfe9f87950a5a66916d8894df4 100644 --- a/crates/edit_prediction/src/ollama.rs +++ b/crates/edit_prediction/src/ollama.rs @@ -57,7 +57,7 @@ pub fn fetch_models(cx: &mut App) -> Vec<SharedString> { let mut models: Vec<SharedString> = provider .provided_models(cx) .into_iter() - .map(|model| SharedString::from(model.id().0.to_string())) + .map(|model| model.id().0) .collect(); models.sort(); models diff --git a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs index c5e97fd87eaad9b98aeb9b946a9a69b1c1071db2..1a574e9389715ce888f8b8c5ec8be921ceab4a38 100644 --- a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs +++ b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs @@ -177,7 +177,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { BufferEditPrediction::Local { prediction } => prediction, BufferEditPrediction::Jump { prediction } => { return Some(edit_prediction_types::EditPrediction::Jump { - id: Some(prediction.id.to_string().into()), + id: Some(prediction.id.0.clone()), snapshot: prediction.snapshot.clone(), target: prediction.edits.first().unwrap().0.start, }); @@ -228,7 +228,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { } Some(edit_prediction_types::EditPrediction::Local { - id: Some(prediction.id.to_string().into()), + id: Some(prediction.id.0.clone()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), cursor_position: prediction.cursor_position, edit_preview: Some(prediction.edit_preview.clone()),
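Reviewer note, not part of the patch: several hunks in this change (acp.rs, agent_registry_ui.rs, collab_panel.rs, ollama.rs, zed_edit_prediction_delegate.rs) replace `x.to_string().into()` round-trips with a direct `.clone()`. Assuming the extracted `gpui_shared_string::SharedString` keeps gpui's reference-counted representation, the clone is cheap and skips an intermediate `String` allocation:

use gpui_shared_string::SharedString;

let name: SharedString = "my_new_remote".into(); // borrowed static str
let alias = name.clone(); // at most a refcount bump; no new String built
assert_eq!(name, alias);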
diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 323ee3de41902b2140f95da22b0e37fb98d31fd5..a999fed2baf990273f0801bac15573b3aed0cc78 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -22,7 +22,7 @@ http_client.workspace = true chrono.workspace = true clap = "4" client.workspace = true -cloud_llm_client.workspace= true +cloud_llm_client = { workspace = true, features = ["predict-edits"] } collections.workspace = true db.workspace = true debug_adapter_extension.workspace = true diff --git a/crates/env_var/Cargo.toml b/crates/env_var/Cargo.toml index 2cbbd08c7833d3e57a09766d42ffffe35c620a93..3c879a2f49184e19a131046320d767931e1ca8ec 100644 --- a/crates/env_var/Cargo.toml +++ b/crates/env_var/Cargo.toml @@ -12,4 +12,4 @@ workspace = true path = "src/env_var.rs" [dependencies] -gpui.workspace = true +gpui_shared_string.workspace = true diff --git a/crates/env_var/src/env_var.rs b/crates/env_var/src/env_var.rs index 79f671e0147ebfaad4ab76a123cc477dc7e55cb7..cb436e95e0e734e4b7d8d271199246e1558a074d 100644 --- a/crates/env_var/src/env_var.rs +++ b/crates/env_var/src/env_var.rs @@ -1,4 +1,4 @@ -use gpui::SharedString; +use gpui_shared_string::SharedString; #[derive(Clone)] pub struct EnvVar { diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 83c8119a077ac1c024dbb3b3df948f762b072ec1..2bf4a1991f7a302ed73fe098e8914fedd0f9eb2a 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -1906,7 +1906,7 @@ mod tests { assert_eq!( remotes, vec![Remote { - name: SharedString::from("my_new_remote".to_string()) + name: SharedString::from("my_new_remote") }] ); } diff --git a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml index 81e05e4836529e9b73b58b72683a7e72a4d5c984..d91d28851997723835ba85be343a453918301c71 100644 --- a/crates/google_ai/Cargo.toml +++ b/crates/google_ai/Cargo.toml @@ -18,8 +18,10 @@ schemars = ["dep:schemars"] anyhow.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true +log.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true -settings.workspace = true strum.workspace = true +tiktoken-rs.workspace = true diff --git a/crates/google_ai/src/completion.rs b/crates/google_ai/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..3a15fdaa0187e52cb82dc8c71b5b861eb797f1a8 --- /dev/null +++ b/crates/google_ai/src/completion.rs @@ -0,0 +1,492 @@ +use anyhow::Result; +use futures::{Stream, StreamExt}; +use language_model_core::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, + StopReason, TokenUsage, +}; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{self, AtomicU64}; + +use crate::{ + Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration, + GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode, + InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig, + UsageMetadata, +}; + +pub fn into_google( + mut request: LanguageModelRequest, + model_id: String, + mode: GoogleModelMode, +) -> crate::GenerateContentRequest { + fn map_content(content: Vec<MessageContent>) -> Vec<Part> { + content + .into_iter() + .flat_map(|content| match content { + MessageContent::Text(text) => { + if !text.is_empty() { + vec![Part::TextPart(TextPart { text })] + } else { + vec![] + } + }
+ MessageContent::Thinking { + text: _, + signature: Some(signature), + } => { + if !signature.is_empty() { + vec![Part::ThoughtPart(crate::ThoughtPart { + thought: true, + thought_signature: signature, + })] + } else { + vec![] + } + } + MessageContent::Thinking { .. } => { + vec![] + } + MessageContent::RedactedThinking(_) => vec![], + MessageContent::Image(image) => { + vec![Part::InlineDataPart(InlineDataPart { + inline_data: GenerativeContentBlob { + mime_type: "image/png".to_string(), + data: image.source.to_string(), + }, + })] + } + MessageContent::ToolUse(tool_use) => { + // Normalize empty string signatures to None + let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty()); + + vec![Part::FunctionCallPart(crate::FunctionCallPart { + function_call: crate::FunctionCall { + name: tool_use.name.to_string(), + args: tool_use.input, + }, + thought_signature, + })] + } + MessageContent::ToolResult(tool_result) => { + match tool_result.content { + language_model_core::LanguageModelToolResultContent::Text(text) => { + vec![Part::FunctionResponsePart(crate::FunctionResponsePart { + function_response: crate::FunctionResponse { + name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object + response: serde_json::json!({ + "output": text + }), + }, + })] + } + language_model_core::LanguageModelToolResultContent::Image(image) => { + vec![ + Part::FunctionResponsePart(crate::FunctionResponsePart { + function_response: crate::FunctionResponse { + name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object + response: serde_json::json!({ + "output": "Tool responded with an image" + }), + }, + }), + Part::InlineDataPart(InlineDataPart { + inline_data: GenerativeContentBlob { + mime_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }), + ] + } + } + } + }) + .collect() + } + + let system_instructions = if request + .messages + .first() + .is_some_and(|msg| matches!(msg.role, Role::System)) + { + let message = request.messages.remove(0); + Some(SystemInstruction { + parts: map_content(message.content), + }) + } else { + None + }; + + crate::GenerateContentRequest { + model: ModelName { model_id }, + system_instruction: system_instructions, + contents: request + .messages + .into_iter() + .filter_map(|message| { + let parts = map_content(message.content); + if parts.is_empty() { + None + } else { + Some(Content { + parts, + role: match message.role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Model, + Role::System => crate::Role::User, // Google AI doesn't have a system role + }, + }) + } + }) + .collect(), + generation_config: Some(GenerationConfig { + candidate_count: Some(1), + stop_sequences: Some(request.stop), + max_output_tokens: None, + temperature: request.temperature.map(|t| t as f64).or(Some(1.0)), + thinking_config: match (request.thinking_allowed, mode) { + (true, GoogleModelMode::Thinking { budget_tokens }) => { + budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget }) + } + _ => None, + }, + top_p: None, + top_k: None, + }), + safety_settings: None, + tools: (!request.tools.is_empty()).then(|| { + vec![crate::Tool { + function_declarations: request + .tools + .into_iter() + .map(|tool| FunctionDeclaration { + name: tool.name, + description: tool.description, + parameters: tool.input_schema, + }) + .collect(), + }] + }), + tool_config: request.tool_choice.map(|choice| ToolConfig { + function_calling_config: FunctionCallingConfig { + mode: match choice { + 
LanguageModelToolChoice::Auto => FunctionCallingMode::Auto, + LanguageModelToolChoice::Any => FunctionCallingMode::Any, + LanguageModelToolChoice::None => FunctionCallingMode::None, + }, + allowed_function_names: None, + }, + }), + } +} + +pub struct GoogleEventMapper { + usage: UsageMetadata, + stop_reason: StopReason, +} + +impl GoogleEventMapper { + pub fn new() -> Self { + Self { + usage: UsageMetadata::default(), + stop_reason: StopReason::EndTurn, + } + } + + pub fn map_stream( + mut self, + events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>, + ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> + { + events + .map(Some) + .chain(futures::stream::once(async { None })) + .flat_map(move |event| { + futures::stream::iter(match event { + Some(Ok(event)) => self.map_event(event), + Some(Err(error)) => { + vec![Err(LanguageModelCompletionError::from(error))] + } + None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))], + }) + }) + } + + pub fn map_event( + &mut self, + event: GenerateContentResponse, + ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { + static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); + + let mut events: Vec<_> = Vec::new(); + let mut wants_to_use_tool = false; + if let Some(usage_metadata) = event.usage_metadata { + update_usage(&mut self.usage, &usage_metadata); + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))) + } + + if let Some(prompt_feedback) = event.prompt_feedback + && let Some(block_reason) = prompt_feedback.block_reason.as_deref() + { + self.stop_reason = match block_reason { + "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => { + StopReason::Refusal + } + _ => { + log::error!("Unexpected Google block_reason: {block_reason}"); + StopReason::Refusal + } + }; + events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))); + + return events; + } + + if let Some(candidates) = event.candidates { + for candidate in candidates { + if let Some(finish_reason) = candidate.finish_reason.as_deref() { + self.stop_reason = match finish_reason { + "STOP" => StopReason::EndTurn, + "MAX_TOKENS" => StopReason::MaxTokens, + _ => { + log::error!("Unexpected google finish_reason: {finish_reason}"); + StopReason::EndTurn + } + }; + } + candidate + .content + .parts + .into_iter() + .for_each(|part| match part { + Part::TextPart(text_part) => { + events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text))) + } + Part::InlineDataPart(_) => {} + Part::FunctionCallPart(function_call_part) => { + wants_to_use_tool = true; + let name: Arc<str> = function_call_part.function_call.name.into(); + let next_tool_id = + TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst); + let id: LanguageModelToolUseId = + format!("{}-{}", name, next_tool_id).into(); + + // Normalize empty string signatures to None + let thought_signature = function_call_part + .thought_signature + .filter(|s| !s.is_empty()); + + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id, + name, + is_input_complete: true, + raw_input: function_call_part.function_call.args.to_string(), + input: function_call_part.function_call.args, + thought_signature, + }, + ))); + } + Part::FunctionResponsePart(_) => {} + Part::ThoughtPart(part) => { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries? + signature: Some(part.thought_signature), + })); + } + }); + } + } + + // Even when Gemini wants to use a Tool, the API + // responds with `finish_reason: STOP` + if wants_to_use_tool { + self.stop_reason = StopReason::ToolUse; + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + events + } +}
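Reviewer note, not part of the patch: how I'd expect `GoogleEventMapper` to be driven, using only items already imported in this file (the event source is a placeholder for the provider's SSE stream):

use futures::StreamExt as _;

async fn collect_text(
    events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> String {
    let mapper = GoogleEventMapper::new();
    let mut stream = std::pin::pin!(mapper.map_stream(events));
    let mut text = String::new();
    while let Some(event) = stream.next().await {
        // Stop, UsageUpdate, and ToolUse events are ignored in this sketch.
        if let Ok(LanguageModelCompletionEvent::Text(chunk)) = event {
            text.push_str(&chunk);
        }
    }
    text
}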
+ +/// Count tokens for a Google AI model using tiktoken. This is synchronous; +/// callers should spawn it on a background thread if needed. +pub fn count_google_tokens(request: LanguageModelRequest) -> Result<u64> { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::<Vec<_>>(); + + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. + tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) +} + +fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) { + if let Some(prompt_token_count) = new.prompt_token_count { + usage.prompt_token_count = Some(prompt_token_count); + } + if let Some(cached_content_token_count) = new.cached_content_token_count { + usage.cached_content_token_count = Some(cached_content_token_count); + } + if let Some(candidates_token_count) = new.candidates_token_count { + usage.candidates_token_count = Some(candidates_token_count); + } + if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count { + usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count); + } + if let Some(thoughts_token_count) = new.thoughts_token_count { + usage.thoughts_token_count = Some(thoughts_token_count); + } + if let Some(total_token_count) = new.total_token_count { + usage.total_token_count = Some(total_token_count); + } +} + +fn convert_usage(usage: &UsageMetadata) -> TokenUsage { + let prompt_tokens = usage.prompt_token_count.unwrap_or(0); + let cached_tokens = usage.cached_content_token_count.unwrap_or(0); + let input_tokens = prompt_tokens - cached_tokens; + let output_tokens = usage.candidates_token_count.unwrap_or(0); + + TokenUsage { + input_tokens, + output_tokens, + cache_read_input_tokens: cached_tokens, + cache_creation_input_tokens: 0, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse, + Part, Role as GoogleRole, + }; + use serde_json::json; + + #[test] + fn test_function_call_with_signature_creates_tool_use_with_signature() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("test_signature_123".to_string()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + assert_eq!(events.len(), 2); + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert_eq!(tool_use.name.as_ref(), "test_function"); + assert_eq!(
tool_use.thought_signature.as_deref(), + Some("test_signature_123") + ); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_function_call_without_signature_has_none() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: None, + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + assert_eq!(events.len(), 2); + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert!(tool_use.thought_signature.is_none()); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_empty_string_signature_normalized_to_none() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("".to_string()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert!(tool_use.thought_signature.is_none()); + } else { + panic!("Expected ToolUse event"); + } + } +} diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 7659be8ab44da35efd16389c4abd0bf99d8cf3a4..5770c9a020b04bf280908993911b67ec3a5b980f 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -3,8 +3,9 @@ use std::mem; use anyhow::{Result, anyhow, bail}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +pub use language_model_core::ModelMode as GoogleModelMode; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -pub use settings::ModelMode as GoogleModelMode; +pub mod completion; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 915f0fc03e2cc5beaf40c810654724295c41cde8..efb4817ef0e0c037bc08d0c5a8ad702705cb996d 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -56,6 +56,7 @@ etagere = "0.2" futures.workspace = true futures-concurrency.workspace = true gpui_macros.workspace = true +gpui_shared_string.workspace = true http_client.workspace = true image.workspace = true inventory.workspace = true diff --git a/crates/gpui/src/gpui.rs b/crates/gpui/src/gpui.rs index 6d7d801cd42c3639d7892295a660319d21b05dfa..dbb57f46efc37678c07dfd4f02bb3faebc60c9a3 100644 --- a/crates/gpui/src/gpui.rs +++ b/crates/gpui/src/gpui.rs @@ -39,7 +39,6 @@ pub mod profiler; #[expect(missing_docs)] pub mod queue; mod scene; -mod shared_string; mod shared_uri; mod style; mod styled; @@ -92,6 +91,7 @@ pub use global::*; pub use gpui_macros::{ AppContext, 
diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs
index 7659be8ab44da35efd16389c4abd0bf99d8cf3a4..5770c9a020b04bf280908993911b67ec3a5b980f 100644
--- a/crates/google_ai/src/google_ai.rs
+++ b/crates/google_ai/src/google_ai.rs
@@ -3,8 +3,9 @@ use std::mem;
 use anyhow::{Result, anyhow, bail};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+pub use language_model_core::ModelMode as GoogleModelMode;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
-pub use settings::ModelMode as GoogleModelMode;
+
+pub mod completion;
 
 pub const API_URL: &str = "https://generativelanguage.googleapis.com";
diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml
index 915f0fc03e2cc5beaf40c810654724295c41cde8..efb4817ef0e0c037bc08d0c5a8ad702705cb996d 100644
--- a/crates/gpui/Cargo.toml
+++ b/crates/gpui/Cargo.toml
@@ -56,6 +56,7 @@ etagere = "0.2"
 futures.workspace = true
 futures-concurrency.workspace = true
 gpui_macros.workspace = true
+gpui_shared_string.workspace = true
 http_client.workspace = true
 image.workspace = true
 inventory.workspace = true
diff --git a/crates/gpui/src/gpui.rs b/crates/gpui/src/gpui.rs
index 6d7d801cd42c3639d7892295a660319d21b05dfa..dbb57f46efc37678c07dfd4f02bb3faebc60c9a3 100644
--- a/crates/gpui/src/gpui.rs
+++ b/crates/gpui/src/gpui.rs
@@ -39,7 +39,6 @@ pub mod profiler;
 #[expect(missing_docs)]
 pub mod queue;
 mod scene;
-mod shared_string;
 mod shared_uri;
 mod style;
 mod styled;
@@ -92,6 +91,7 @@ pub use global::*;
 pub use gpui_macros::{
     AppContext, IntoElement, Render, VisualContext, property_test, register_action, test,
 };
+pub use gpui_shared_string::*;
 pub use gpui_util::arc_cow::ArcCow;
 pub use http_client;
 pub use input::*;
@@ -106,7 +106,6 @@ pub use profiler::*;
 pub use queue::{PriorityQueueReceiver, PriorityQueueSender};
 pub use refineable::*;
 pub use scene::*;
-pub use shared_string::*;
 pub use shared_uri::*;
 use std::{any::Any, future::Future};
 pub use style::*;
diff --git a/crates/gpui/src/text_system/line.rs b/crates/gpui/src/text_system/line.rs
index 7b5714188ff97d0169806ac5da9f039f9be2c16a..611c979bc29f488fa18386c7b319a7310b6ce1c6 100644
--- a/crates/gpui/src/text_system/line.rs
+++ b/crates/gpui/src/text_system/line.rs
@@ -882,7 +882,7 @@ mod tests {
                 ],
                 len: 6,
             }),
-            text: SharedString::new("abcdef".to_string()),
+            text: "abcdef".into(),
             decoration_runs: SmallVec::new(),
         };
diff --git a/crates/gpui_shared_string/Cargo.toml b/crates/gpui_shared_string/Cargo.toml
new file mode 100644
index 0000000000000000000000000000000000000000..4f7735b4f88253de7cd62d30445153d2a6284751
--- /dev/null
+++ b/crates/gpui_shared_string/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "gpui_shared_string"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+
+[lib]
+path = "gpui_shared_string.rs"
+
+[dependencies]
+derive_more.workspace = true
+gpui_util.workspace = true
+schemars.workspace = true
+serde.workspace = true
+
+[lints]
+workspace = true
diff --git a/crates/gpui_shared_string/LICENSE-APACHE b/crates/gpui_shared_string/LICENSE-APACHE
new file mode 120000
index 0000000000000000000000000000000000000000..1cd601d0a3affae83854be02a0afdec3b7a9ec4d
--- /dev/null
+++ b/crates/gpui_shared_string/LICENSE-APACHE
@@ -0,0 +1 @@
+../../LICENSE-APACHE
\ No newline at end of file
diff --git a/crates/gpui/src/shared_string.rs b/crates/gpui_shared_string/gpui_shared_string.rs
similarity index 100%
rename from crates/gpui/src/shared_string.rs
rename to crates/gpui_shared_string/gpui_shared_string.rs
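With `SharedString` extracted into its own leaf crate, string-heavy crates can drop their dependency on all of gpui. A small sketch of what the decoupled usage looks like (hypothetical function; the constructors mirror `SharedString::new_static` and the `"abcdef".into()` call visible elsewhere in this patch):

use gpui_shared_string::SharedString;

// `new_static` is const and allocation-free, as in the provider-ID constants
// defined later in this patch.
const PLAIN_TEXT: SharedString = SharedString::new_static("Plain Text");

fn label_for(path: &str) -> SharedString {
    // `From<String>` makes owned strings cheap to clone and hand around.
    format!("file: {path}").into()
}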
diff --git a/crates/language_core/Cargo.toml b/crates/language_core/Cargo.toml
index 4861632b4663c860706525c65cd8607133b3ec71..cd1143f61d3af1d3b72bb5bd3a23e53b27aa9aba 100644
--- a/crates/language_core/Cargo.toml
+++ b/crates/language_core/Cargo.toml
@@ -10,7 +10,7 @@ path = "src/language_core.rs"
 [dependencies]
 anyhow.workspace = true
 collections.workspace = true
-gpui.workspace = true
+gpui_shared_string.workspace = true
 log.workspace = true
 lsp.workspace = true
 parking_lot.workspace = true
@@ -22,8 +22,6 @@ toml.workspace = true
 tree-sitter.workspace = true
 util.workspace = true
 
-[dev-dependencies]
-gpui = { workspace = true, features = ["test-support"] }
-
 [features]
 test-support = []
diff --git a/crates/language_core/src/diagnostic.rs b/crates/language_core/src/diagnostic.rs
index 9a468a14b863a94ef23e00c3e15edd9fa2d8b09a..00abcb61d1b1290dd96c69b31296eebfd3900348 100644
--- a/crates/language_core/src/diagnostic.rs
+++ b/crates/language_core/src/diagnostic.rs
@@ -1,4 +1,4 @@
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
 use lsp::{DiagnosticSeverity, NumberOrString};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
diff --git a/crates/language_core/src/grammar.rs b/crates/language_core/src/grammar.rs
index 54e9a3f1b3309718436b206874802779925a9d04..44f73ac6dea235a522393b5b0bd10729999b45bf 100644
--- a/crates/language_core/src/grammar.rs
+++ b/crates/language_core/src/grammar.rs
@@ -4,7 +4,7 @@ use crate::{
 };
 use anyhow::{Context as _, Result};
 use collections::HashMap;
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
 use lsp::LanguageServerName;
 use parking_lot::Mutex;
 use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
diff --git a/crates/language_core/src/language_config.rs b/crates/language_core/src/language_config.rs
index f412af418b7948b40e3bdac5a3a649d12d008e8a..89474dbad9171d37cfb1b7f55f70a137eeb535d5 100644
--- a/crates/language_core/src/language_config.rs
+++ b/crates/language_core/src/language_config.rs
@@ -1,6 +1,6 @@
 use crate::LanguageName;
 use collections::{HashMap, HashSet, IndexSet};
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
 use lsp::LanguageServerName;
 use regex::Regex;
 use schemars::{JsonSchema, SchemaGenerator, json_schema};
diff --git a/crates/language_core/src/language_name.rs b/crates/language_core/src/language_name.rs
index 764b54a48a566ad98212de3e22bce6aca9a1e393..14528435d9103b4faad3e055ea69bbdaf372113c 100644
--- a/crates/language_core/src/language_name.rs
+++ b/crates/language_core/src/language_name.rs
@@ -1,4 +1,4 @@
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use std::{
diff --git a/crates/language_core/src/lsp_adapter.rs b/crates/language_core/src/lsp_adapter.rs
index 03012f71143428b49ea9d75a03b0118b50e413b4..8f449637b306c2a33a76cb5b356d0280903f4187 100644
--- a/crates/language_core/src/lsp_adapter.rs
+++ b/crates/language_core/src/lsp_adapter.rs
@@ -1,4 +1,4 @@
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
 use serde::{Deserialize, Serialize};
 
 /// Converts a value into an LSP position.
diff --git a/crates/language_core/src/manifest.rs b/crates/language_core/src/manifest.rs
index 1e762ff6e7c364eef02eea16ce9e1ecaaa198554..864f89e6cee65b0dff7c4462c99940c32ba0830f 100644
--- a/crates/language_core/src/manifest.rs
+++ b/crates/language_core/src/manifest.rs
@@ -1,6 +1,6 @@
 use std::borrow::Borrow;
 
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
 
 #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
 pub struct ManifestName(SharedString);
diff --git a/crates/language_core/src/toolchain.rs b/crates/language_core/src/toolchain.rs
index a021cb86bd36295a065b16281209c5fc3b63cffc..78bd69917fbc0f66af454ba262c1eb3b7c357290 100644
--- a/crates/language_core/src/toolchain.rs
+++ b/crates/language_core/src/toolchain.rs
@@ -6,7 +6,7 @@
 use std::{path::Path, sync::Arc};
 
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
 use util::rel_path::RelPath;
 
 use crate::{LanguageName, ManifestName};
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 4712d86dff6c44f9cdd8576a08349ccfa7d0ecca..d679588138ccec0f8d9fd830d26d13f2f65d44a3 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -16,13 +16,9 @@ doctest = false
 test-support = []
 
 [dependencies]
-anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
 credentials_provider.workspace = true
 base64.workspace = true
-cloud_api_client.workspace = true
-cloud_api_types.workspace = true
-cloud_llm_client.workspace = true
 collections.workspace = true
 env_var.workspace = true
 futures.workspace = true
@@ -30,14 +26,11 @@ gpui.workspace = true
 http_client.workspace = true
 icons.workspace = true
 image.workspace = true
+language_model_core.workspace = true
 log.workspace = true
-open_ai = { workspace = true, features = ["schemars"] }
-open_router.workspace = true
 parking_lot.workspace = true
-schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
-smol.workspace = true
 thiserror.workspace = true
 util.workspace = true
diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs
index 50037f31facbac446de7ecf38536d1e4a24c7867..cee65c21e575e7c96579c271805386527a29d4da 100644
--- a/crates/language_model/src/fake_provider.rs
+++ b/crates/language_model/src/fake_provider.rs
@@ -5,11 +5,10 @@ use crate::{
     LanguageModelRequest, LanguageModelToolChoice,
 };
 use anyhow::anyhow;
-use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
+use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream, stream::StreamExt};
 use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
 use http_client::Result;
 use parking_lot::Mutex;
-use smol::stream::StreamExt;
 use std::sync::{
     Arc,
     atomic::{AtomicBool, Ordering::SeqCst},
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 3f309b7b1d4152c54324efaaf0ad3bdb7035eea4..60e8228fec52ffee763e19541f042ce47246dad2 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -1,380 +1,31 @@
 mod api_key;
 mod model;
-mod provider;
-mod rate_limiter;
 mod registry;
 mod request;
-mod role;
-pub mod tool_schema;
 
 #[cfg(any(test, feature = "test-support"))]
 pub mod fake_provider;
 
-use anyhow::{Result, anyhow};
-use cloud_llm_client::CompletionRequestStatus;
+pub use language_model_core::*;
+
+use anyhow::Result;
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
-use http_client::{StatusCode, http};
+use gpui::{AnyView, App, AsyncApp, Task, Window};
 use icons::IconName;
 use parking_lot::Mutex;
-use serde::{Deserialize, Serialize};
-use std::ops::{Add, Sub};
-use std::str::FromStr;
 use std::sync::Arc;
-use std::time::Duration;
-use std::{fmt, io};
-use thiserror::Error;
-use util::serde::is_default;
 
 pub use crate::api_key::{ApiKey, ApiKeyState};
 pub use crate::model::*;
-pub use crate::rate_limiter::*;
 pub use crate::registry::*;
-pub use crate::request::*;
-pub use crate::role::*;
-pub use crate::tool_schema::LanguageModelToolSchemaFormat;
+pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui};
 pub use env_var::{EnvVar, env_var};
-pub use provider::*;
 
 pub fn init(cx: &mut App) {
     registry::init(cx);
 }
 
-#[derive(Clone, Debug)]
-pub struct LanguageModelCacheConfiguration {
-    pub max_cache_anchors: usize,
-    pub should_speculate: bool,
-    pub min_total_token: u64,
-}
-
-/// A completion event from a language model.
-#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
-pub enum LanguageModelCompletionEvent {
-    Queued {
-        position: usize,
-    },
-    Started,
-    Stop(StopReason),
-    Text(String),
-    Thinking {
-        text: String,
-        signature: Option<String>,
-    },
-    RedactedThinking {
-        data: String,
-    },
-    ToolUse(LanguageModelToolUse),
-    ToolUseJsonParseError {
-        id: LanguageModelToolUseId,
-        tool_name: Arc<str>,
-        raw_input: Arc<str>,
-        json_parse_error: String,
-    },
-    StartMessage {
-        message_id: String,
-    },
-    ReasoningDetails(serde_json::Value),
-    UsageUpdate(TokenUsage),
-}
-
-impl LanguageModelCompletionEvent {
-    pub fn from_completion_request_status(
-        status: CompletionRequestStatus,
-        upstream_provider: LanguageModelProviderName,
-    ) -> Result<Option<Self>, LanguageModelCompletionError> {
-        match status {
-            CompletionRequestStatus::Queued { position } => {
-                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
-            }
-            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
-            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
-            CompletionRequestStatus::Failed {
-                code,
-                message,
-                request_id: _,
-                retry_after,
-            } => Err(LanguageModelCompletionError::from_cloud_failure(
-                upstream_provider,
-                code,
-                message,
-                retry_after.map(Duration::from_secs_f64),
-            )),
-        }
-    }
-}
-
-#[derive(Error, Debug)]
-pub enum LanguageModelCompletionError {
-    #[error("prompt too large for context window")]
-    PromptTooLarge { tokens: Option<u64> },
-    #[error("missing {provider} API key")]
-    NoApiKey { provider: LanguageModelProviderName },
-    #[error("{provider}'s API rate limit exceeded")]
-    RateLimitExceeded {
-        provider: LanguageModelProviderName,
-        retry_after: Option<Duration>,
-    },
-    #[error("{provider}'s API servers are overloaded right now")]
-    ServerOverloaded {
-        provider: LanguageModelProviderName,
-        retry_after: Option<Duration>,
-    },
-    #[error("{provider}'s API server reported an internal server error: {message}")]
-    ApiInternalServerError {
-        provider: LanguageModelProviderName,
-        message: String,
-    },
-    #[error("{message}")]
-    UpstreamProviderError {
-        message: String,
-        status: StatusCode,
-        retry_after: Option<Duration>,
-    },
-    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
-    HttpResponseError {
-        provider: LanguageModelProviderName,
-        status_code: StatusCode,
-        message: String,
-    },
-
-    // Client errors
-    #[error("invalid request format to {provider}'s API: {message}")]
-    BadRequestFormat {
-        provider: LanguageModelProviderName,
-        message: String,
-    },
-    #[error("authentication error with {provider}'s API: {message}")]
-    AuthenticationError {
-        provider: LanguageModelProviderName,
-        message: String,
-    },
-    #[error("Permission error with {provider}'s API: {message}")]
-    PermissionError {
-        provider: LanguageModelProviderName,
-        message: String,
-    },
-    #[error("language model provider API endpoint not found")]
-    ApiEndpointNotFound { provider: LanguageModelProviderName },
-    #[error("I/O error reading response from {provider}'s API")]
-    ApiReadResponseError {
-        provider: LanguageModelProviderName,
-        #[source]
-        error: io::Error,
-    },
-    #[error("error serializing request to {provider} API")]
-    SerializeRequest {
-        provider: LanguageModelProviderName,
-        #[source]
-        error: serde_json::Error,
-    },
-    #[error("error building request body to {provider} API")]
-    BuildRequestBody {
-        provider: LanguageModelProviderName,
-        #[source]
-        error: http::Error,
-    },
-    #[error("error sending HTTP request to {provider} API")]
-    HttpSend {
-        provider: LanguageModelProviderName,
-        #[source]
-        error: anyhow::Error,
-    },
-    #[error("error deserializing {provider} API response")]
-    DeserializeResponse {
-        provider: LanguageModelProviderName,
-        #[source]
-        error: serde_json::Error,
-    },
-
-    #[error("stream from {provider} ended unexpectedly")]
-    StreamEndedUnexpectedly { provider: LanguageModelProviderName },
-
-    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
-    #[error(transparent)]
-    Other(#[from] anyhow::Error),
-}
-
-impl LanguageModelCompletionError {
-    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
-        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
-        let upstream_status = error_json
-            .get("upstream_status")
-            .and_then(|v| v.as_u64())
-            .and_then(|status| u16::try_from(status).ok())
-            .and_then(|status| StatusCode::from_u16(status).ok())?;
-        let inner_message = error_json
-            .get("message")
-            .and_then(|v| v.as_str())
-            .unwrap_or(message)
-            .to_string();
-        Some((upstream_status, inner_message))
-    }
-
-    pub fn from_cloud_failure(
-        upstream_provider: LanguageModelProviderName,
-        code: String,
-        message: String,
-        retry_after: Option<Duration>,
-    ) -> Self {
-        if let Some(tokens) = parse_prompt_too_long(&message) {
-            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
-            // to be reported. This is a temporary workaround to handle this in the case where the
-            // token limit has been exceeded.
-            Self::PromptTooLarge {
-                tokens: Some(tokens),
-            }
-        } else if code == "upstream_http_error" {
-            if let Some((upstream_status, inner_message)) =
-                Self::parse_upstream_error_json(&message)
-            {
-                return Self::from_http_status(
-                    upstream_provider,
-                    upstream_status,
-                    inner_message,
-                    retry_after,
-                );
-            }
-            anyhow!("completion request failed, code: {code}, message: {message}").into()
-        } else if let Some(status_code) = code
-            .strip_prefix("upstream_http_")
-            .and_then(|code| StatusCode::from_str(code).ok())
-        {
-            Self::from_http_status(upstream_provider, status_code, message, retry_after)
-        } else if let Some(status_code) = code
-            .strip_prefix("http_")
-            .and_then(|code| StatusCode::from_str(code).ok())
-        {
-            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
-        } else {
-            anyhow!("completion request failed, code: {code}, message: {message}").into()
-        }
-    }
-
-    pub fn from_http_status(
-        provider: LanguageModelProviderName,
-        status_code: StatusCode,
-        message: String,
-        retry_after: Option<Duration>,
-    ) -> Self {
-        match status_code {
-            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
-            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
-            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
-            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
-            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
-                tokens: parse_prompt_too_long(&message),
-            },
-            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
-                provider,
-                retry_after,
-            },
-            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
-            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
-                provider,
-                retry_after,
-            },
-            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
-                provider,
-                retry_after,
-            },
-            _ => Self::HttpResponseError {
-                provider,
-                status_code,
-                message,
-            },
-        }
-    }
-}
-
-#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum StopReason {
-    EndTurn,
-    MaxTokens,
-    ToolUse,
-    Refusal,
-}
-
-#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
-pub struct TokenUsage {
-    #[serde(default, skip_serializing_if = "is_default")]
-    pub input_tokens: u64,
-    #[serde(default, skip_serializing_if = "is_default")]
-    pub output_tokens: u64,
-    #[serde(default, skip_serializing_if = "is_default")]
-    pub cache_creation_input_tokens: u64,
-    #[serde(default, skip_serializing_if = "is_default")]
-    pub cache_read_input_tokens: u64,
-}
-
-impl TokenUsage {
-    pub fn total_tokens(&self) -> u64 {
-        self.input_tokens
-            + self.output_tokens
-            + self.cache_read_input_tokens
-            + self.cache_creation_input_tokens
-    }
-}
-
-impl Add for TokenUsage {
-    type Output = Self;
-
-    fn add(self, other: Self) -> Self {
-        Self {
-            input_tokens: self.input_tokens + other.input_tokens,
-            output_tokens: self.output_tokens + other.output_tokens,
-            cache_creation_input_tokens: self.cache_creation_input_tokens
-                + other.cache_creation_input_tokens,
-            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
-        }
-    }
-}
-
-impl Sub for TokenUsage {
-    type Output = Self;
-
-    fn sub(self, other: Self) -> Self {
-        Self {
-            input_tokens: self.input_tokens - other.input_tokens,
-            output_tokens: self.output_tokens - other.output_tokens,
-            cache_creation_input_tokens: self.cache_creation_input_tokens
-                - other.cache_creation_input_tokens,
-            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
-        }
-    }
-}
-
-#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
-pub struct LanguageModelToolUseId(Arc<str>);
-
-impl fmt::Display for LanguageModelToolUseId {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "{}", self.0)
-    }
-}
-
-impl<T> From<T> for LanguageModelToolUseId
-where
-    T: Into<Arc<str>>,
-{
-    fn from(value: T) -> Self {
-        Self(value.into())
-    }
-}
-
-#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
-pub struct LanguageModelToolUse {
-    pub id: LanguageModelToolUseId,
-    pub name: Arc<str>,
-    pub raw_input: String,
-    pub input: serde_json::Value,
-    pub is_input_complete: bool,
-    /// Thought signature the model sent us. Some models require that this
-    /// signature be preserved and sent back in conversation history for validation.
-    pub thought_signature: Option<String>,
-}
-
 pub struct LanguageModelTextStream {
     pub message_id: Option<String>,
     pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
@@ -392,13 +43,6 @@ impl Default for LanguageModelTextStream {
     }
 }
 
-#[derive(Debug, Clone)]
-pub struct LanguageModelEffortLevel {
-    pub name: SharedString,
-    pub value: SharedString,
-    pub is_default: bool,
-}
-
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;
@@ -605,7 +249,7 @@ pub trait LanguageModel: Send + Sync {
 }
 
 impl std::fmt::Debug for dyn LanguageModel {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("")
             .field("id", &self.id())
             .field("name", &self.name())
@@ -619,17 +263,6 @@ impl std::fmt::Debug for dyn LanguageModel {
     }
 }
 
-/// An error that occurred when trying to authenticate the language model provider.
-#[derive(Debug, Error)]
-pub enum AuthenticateError {
-    #[error("connection refused")]
-    ConnectionRefused,
-    #[error("credentials not found")]
-    CredentialsNotFound,
-    #[error(transparent)]
-    Other(#[from] anyhow::Error),
-}
-
 /// Either a built-in icon name or a path to an external SVG.
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum IconOrSvg {
@@ -692,18 +325,6 @@ pub trait LanguageModelProviderState: 'static {
     }
 }
 
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
-pub struct LanguageModelId(pub SharedString);
-
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
-pub struct LanguageModelName(pub SharedString);
-
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
-pub struct LanguageModelProviderId(pub SharedString);
-
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
-pub struct LanguageModelProviderName(pub SharedString);
-
 #[derive(Clone, Debug, PartialEq)]
 pub enum LanguageModelCostInfo {
     /// Cost per 1,000 input and output tokens
@@ -741,245 +362,3 @@ impl LanguageModelCostInfo {
         }
     }
 }
-
-impl LanguageModelProviderId {
-    pub const fn new(id: &'static str) -> Self {
-        Self(SharedString::new_static(id))
-    }
-}
-
-impl LanguageModelProviderName {
-    pub const fn new(id: &'static str) -> Self {
-        Self(SharedString::new_static(id))
-    }
-}
-
-impl fmt::Display for LanguageModelProviderId {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "{}", self.0)
-    }
-}
-
-impl fmt::Display for LanguageModelProviderName {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "{}", self.0)
-    }
-}
-
-impl From<String> for LanguageModelId {
-    fn from(value: String) -> Self {
-        Self(SharedString::from(value))
-    }
-}
-
-impl From<String> for LanguageModelName {
-    fn from(value: String) -> Self {
-        Self(SharedString::from(value))
-    }
-}
-
-impl From<String> for LanguageModelProviderId {
-    fn from(value: String) -> Self {
-        Self(SharedString::from(value))
-    }
-}
-
-impl From<String> for LanguageModelProviderName {
-    fn from(value: String) -> Self {
-        Self(SharedString::from(value))
-    }
-}
-
-impl From<Arc<str>> for LanguageModelProviderId {
-    fn from(value: Arc<str>) -> Self {
-        Self(SharedString::from(value))
-    }
-}
-
-impl From<Arc<str>> for LanguageModelProviderName {
-    fn from(value: Arc<str>) -> Self {
-        Self(SharedString::from(value))
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_from_cloud_failure_with_upstream_http_error() {
-        let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
-            "upstream_http_error".to_string(),
-            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
-            None,
-        );
-
-        match error {
-            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
-                assert_eq!(provider.0, "anthropic");
-            }
-            _ => panic!(
-                "Expected ServerOverloaded error for 503 status, got: {:?}",
-                error
-            ),
-        }
-
-        let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
-            "upstream_http_error".to_string(),
-            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
-            None,
-        );
-
-        match error {
-            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
-                assert_eq!(provider.0, "anthropic");
-                assert_eq!(message, "Internal server error");
-            }
-            _ => panic!(
-                "Expected ApiInternalServerError for 500 status, got: {:?}",
-                error
-            ),
-        }
-    }
-
-    #[test]
-    fn test_from_cloud_failure_with_standard_format() {
-        let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
-            "upstream_http_503".to_string(),
-            "Service unavailable".to_string(),
-            None,
-        );
-
-        match error {
-            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
-                assert_eq!(provider.0, "anthropic");
-            }
-            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
-        }
-    }
-
-    #[test]
-    fn test_upstream_http_error_connection_timeout() {
-        let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
-            "upstream_http_error".to_string(),
-            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
-            None,
-        );
-
-        match error {
-            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
-                assert_eq!(provider.0, "anthropic");
-            }
-            _ => panic!(
-                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
-                error
-            ),
-        }
-
-        let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
-            "upstream_http_error".to_string(),
-            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
-            None,
-        );
-
-        match error {
-            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
-                assert_eq!(provider.0, "anthropic");
-                assert_eq!(
-                    message,
-                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
-                );
-            }
-            _ => panic!(
-                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
-                error
-            ),
-        }
-    }
-
-    #[test]
-    fn test_language_model_tool_use_serializes_with_signature() {
-        use serde_json::json;
-
-        let tool_use = LanguageModelToolUse {
-            id: LanguageModelToolUseId::from("test_id"),
-            name: "test_tool".into(),
-            raw_input: json!({"arg": "value"}).to_string(),
-            input: json!({"arg": "value"}),
-            is_input_complete: true,
-            thought_signature: Some("test_signature".to_string()),
-        };
-
-        let serialized = serde_json::to_value(&tool_use).unwrap();
-
-        assert_eq!(serialized["id"], "test_id");
-        assert_eq!(serialized["name"], "test_tool");
-        assert_eq!(serialized["thought_signature"], "test_signature");
-    }
-
-    #[test]
-    fn test_language_model_tool_use_deserializes_with_missing_signature() {
-        use serde_json::json;
-
-        let json = json!({
-            "id": "test_id",
-            "name": "test_tool",
-            "raw_input": "{\"arg\":\"value\"}",
-            "input": {"arg": "value"},
-            "is_input_complete": true
-        });
-
-        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
-
-        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
-        assert_eq!(tool_use.name.as_ref(), "test_tool");
-        assert_eq!(tool_use.thought_signature, None);
-    }
-
-    #[test]
-    fn test_language_model_tool_use_round_trip_with_signature() {
-        use serde_json::json;
-
-        let original = LanguageModelToolUse {
-            id: LanguageModelToolUseId::from("round_trip_id"),
-            name: "round_trip_tool".into(),
-            raw_input: json!({"key": "value"}).to_string(),
-            input: json!({"key": "value"}),
-            is_input_complete: true,
-            thought_signature: Some("round_trip_sig".to_string()),
-        };
-
-        let serialized = serde_json::to_value(&original).unwrap();
-        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
-
-        assert_eq!(deserialized.id, original.id);
-        assert_eq!(deserialized.name, original.name);
-        assert_eq!(deserialized.thought_signature, original.thought_signature);
-    }
-
-    #[test]
-    fn test_language_model_tool_use_round_trip_without_signature() {
-        use serde_json::json;
-
-        let original = LanguageModelToolUse {
-            id: LanguageModelToolUseId::from("no_sig_id"),
-            name: "no_sig_tool".into(),
-            raw_input: json!({"arg": "value"}).to_string(),
-            input: json!({"arg": "value"}),
-            is_input_complete: true,
-            thought_signature: None,
-        };
-
-        let serialized = serde_json::to_value(&original).unwrap();
-        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
-
-        assert_eq!(deserialized.id, original.id);
-        assert_eq!(deserialized.name, original.name);
-        assert_eq!(deserialized.thought_signature, None);
-    }
-}
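Because `language_model` now re-exports the core crate wholesale (`pub use language_model_core::*;`), downstream imports keep resolving through the old path. A hedged sketch of an unchanged call site (hypothetical function):

use language_model::{LanguageModelCompletionEvent, StopReason};

// These types now live in language_model_core, but the facade re-export
// means this match compiles exactly as it did before the split.
fn describe(event: &LanguageModelCompletionEvent) -> String {
    match event {
        LanguageModelCompletionEvent::Text(text) => format!("text chunk: {text}"),
        LanguageModelCompletionEvent::Stop(StopReason::ToolUse) => "stopped to run a tool".to_string(),
        _ => "other event".to_string(),
    }
}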
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs
index db926aab1f70a46a4e70b1b67c2c9e4c4f465c2c..8cd71928b10fb1e86f3df40ca118305c198c094f 100644
--- a/crates/language_model/src/model/cloud_model.rs
+++ b/crates/language_model/src/model/cloud_model.rs
@@ -1,10 +1,5 @@
 use std::fmt;
-use std::sync::Arc;
 
-use cloud_api_client::ClientApiError;
-use cloud_api_client::CloudApiClient;
-use cloud_api_types::OrganizationId;
-use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 use thiserror::Error;
 
 #[derive(Error, Debug)]
@@ -18,71 +13,3 @@ impl fmt::Display for PaymentRequiredError {
         )
     }
 }
-
-#[derive(Clone, Default)]
-pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
-
-impl LlmApiToken {
-    pub async fn acquire(
-        &self,
-        client: &CloudApiClient,
-        system_id: Option<String>,
-        organization_id: Option<OrganizationId>,
-    ) -> Result<String, ClientApiError> {
-        let lock = self.0.upgradable_read().await;
-        if let Some(token) = lock.as_ref() {
-            Ok(token.to_string())
-        } else {
-            Self::fetch(
-                RwLockUpgradableReadGuard::upgrade(lock).await,
-                client,
-                system_id,
-                organization_id,
-            )
-            .await
-        }
-    }
-
-    pub async fn refresh(
-        &self,
-        client: &CloudApiClient,
-        system_id: Option<String>,
-        organization_id: Option<OrganizationId>,
-    ) -> Result<String, ClientApiError> {
-        Self::fetch(self.0.write().await, client, system_id, organization_id).await
-    }
-
-    /// Clears the existing token before attempting to fetch a new one.
-    ///
-    /// Used when switching organizations so that a failed refresh doesn't
-    /// leave a token for the wrong organization.
-    pub async fn clear_and_refresh(
-        &self,
-        client: &CloudApiClient,
-        system_id: Option<String>,
-        organization_id: Option<OrganizationId>,
-    ) -> Result<String, ClientApiError> {
-        let mut lock = self.0.write().await;
-        *lock = None;
-        Self::fetch(lock, client, system_id, organization_id).await
-    }
-
-    async fn fetch(
-        mut lock: RwLockWriteGuard<'_, Option<String>>,
-        client: &CloudApiClient,
-        system_id: Option<String>,
-        organization_id: Option<OrganizationId>,
-    ) -> Result<String, ClientApiError> {
-        let result = client.create_llm_token(system_id, organization_id).await;
-        match result {
-            Ok(response) => {
-                *lock = Some(response.token.0.clone());
-                Ok(response.token.0)
-            }
-            Err(err) => {
-                *lock = None;
-                Err(err)
-            }
-        }
-    }
-}
diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs
deleted file mode 100644
index 707d8e2d618894e2898e253450dbfbb5e9483bba..0000000000000000000000000000000000000000
--- a/crates/language_model/src/provider.rs
+++ /dev/null
@@ -1,12 +0,0 @@
-pub mod anthropic;
-pub mod google;
-pub mod open_ai;
-pub mod open_router;
-pub mod x_ai;
-pub mod zed;
-
-pub use anthropic::*;
-pub use google::*;
-pub use open_ai::*;
-pub use x_ai::*;
-pub use zed::*;
diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs
deleted file mode 100644
index 0878be2070fdbb9e57145684f59c962a32bb9fd2..0000000000000000000000000000000000000000
--- a/crates/language_model/src/provider/anthropic.rs
+++ /dev/null
@@ -1,80 +0,0 @@
-use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName};
-use anthropic::AnthropicError;
-pub use anthropic::parse_prompt_too_long;
-
-pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
-    LanguageModelProviderId::new("anthropic");
-pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
-    LanguageModelProviderName::new("Anthropic");
-
-impl From<AnthropicError> for LanguageModelCompletionError {
-    fn from(error: AnthropicError) -> Self {
-        let provider = ANTHROPIC_PROVIDER_NAME;
-        match error {
-            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
-            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
-            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
-            AnthropicError::DeserializeResponse(error) => {
-                Self::DeserializeResponse { provider, error }
-            }
-            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
-            AnthropicError::HttpResponseError {
-                status_code,
-                message,
-            } => Self::HttpResponseError {
-                provider,
-                status_code,
-                message,
-            },
-            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
-                provider,
-                retry_after: Some(retry_after),
-            },
-            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
-                provider,
-                retry_after,
-            },
-            AnthropicError::ApiError(api_error) => api_error.into(),
-        }
-    }
-}
-
-impl From<anthropic::ApiError> for LanguageModelCompletionError {
-    fn from(error: anthropic::ApiError) -> Self {
-        use anthropic::ApiErrorCode::*;
-        let provider = ANTHROPIC_PROVIDER_NAME;
-        match error.code() {
-            Some(code) => match code {
-                InvalidRequestError => Self::BadRequestFormat {
-                    provider,
-                    message: error.message,
-                },
-                AuthenticationError => Self::AuthenticationError {
-                    provider,
-                    message: error.message,
-                },
-                PermissionError => Self::PermissionError {
-                    provider,
-                    message: error.message,
-                },
-                NotFoundError => Self::ApiEndpointNotFound { provider },
-                RequestTooLarge => Self::PromptTooLarge {
-                    tokens: parse_prompt_too_long(&error.message),
-                },
-                RateLimitError => Self::RateLimitExceeded {
-                    provider,
-                    retry_after: None,
-                },
-                ApiError => Self::ApiInternalServerError {
-                    provider,
-                    message: error.message,
-                },
-                OverloadedError => Self::ServerOverloaded {
-                    provider,
-                    retry_after: None,
-                },
-            },
-            None => Self::Other(error.into()),
-        }
-    }
-}
diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs
deleted file mode 100644
index 1caee496b519f395dd10744b127bc29ee893849f..0000000000000000000000000000000000000000
--- a/crates/language_model/src/provider/google.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-use crate::{LanguageModelProviderId, LanguageModelProviderName};
-
-pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
-pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
-    LanguageModelProviderName::new("Google AI");
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
deleted file mode 100644
index 3796eb9a3aef78628c52d92e92fabb3812249e04..0000000000000000000000000000000000000000
--- a/crates/language_model/src/provider/open_ai.rs
+++ /dev/null
@@ -1,28 +0,0 @@
-use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName};
-use http_client::http;
-use std::time::Duration;
-
-pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
-pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
-    LanguageModelProviderName::new("OpenAI");
-
-impl From<open_ai::RequestError> for LanguageModelCompletionError {
-    fn from(error: open_ai::RequestError) -> Self {
-        match error {
-            open_ai::RequestError::HttpResponseError {
-                provider,
-                status_code,
-                body,
-                headers,
-            } => {
-                let retry_after = headers
-                    .get(http::header::RETRY_AFTER)
-                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
-                    .map(Duration::from_secs);
-
-                Self::from_http_status(provider.into(), status_code, body, retry_after)
-            }
-            open_ai::RequestError::Other(e) => Self::Other(e),
-        }
-    }
-}
diff --git a/crates/language_model/src/provider/open_router.rs b/crates/language_model/src/provider/open_router.rs
deleted file mode 100644
index 809e22f1fec0f2d205caa3ebbcb0baaf129b062c..0000000000000000000000000000000000000000
--- a/crates/language_model/src/provider/open_router.rs
+++ /dev/null
@@ -1,69 +0,0 @@
-use crate::{LanguageModelCompletionError, LanguageModelProviderName};
-use http_client::StatusCode;
-use open_router::OpenRouterError;
-
-impl From<OpenRouterError> for LanguageModelCompletionError {
-    fn from(error: OpenRouterError) -> Self {
-        let provider = LanguageModelProviderName::new("OpenRouter");
-        match error {
-            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
-            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
-            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
-            OpenRouterError::DeserializeResponse(error) => {
-                Self::DeserializeResponse { provider, error }
-            }
-            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
-            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
-                provider,
-                retry_after: Some(retry_after),
-            },
-            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
-                provider,
-                retry_after,
-            },
-            OpenRouterError::ApiError(api_error) => api_error.into(),
-        }
-    }
-}
-
-impl From<open_router::ApiError> for LanguageModelCompletionError {
-    fn from(error: open_router::ApiError) -> Self {
-        use open_router::ApiErrorCode::*;
-        let provider = LanguageModelProviderName::new("OpenRouter");
-        match error.code {
-            InvalidRequestError => Self::BadRequestFormat {
-                provider,
-                message: error.message,
-            },
-            AuthenticationError => Self::AuthenticationError {
-                provider,
-                message: error.message,
-            },
-            PaymentRequiredError => Self::AuthenticationError {
-                provider,
-                message: format!("Payment required: {}", error.message),
-            },
-            PermissionError => Self::PermissionError {
-                provider,
-                message: error.message,
-            },
-            RequestTimedOut => Self::HttpResponseError {
-                provider,
-                status_code: StatusCode::REQUEST_TIMEOUT,
-                message: error.message,
-            },
-            RateLimitError => Self::RateLimitExceeded {
-                provider,
-                retry_after: None,
-            },
-            ApiError => Self::ApiInternalServerError {
-                provider,
-                message: error.message,
-            },
-            OverloadedError => Self::ServerOverloaded {
-                provider,
-                retry_after: None,
-            },
-        }
-    }
-}
diff --git a/crates/language_model/src/provider/x_ai.rs b/crates/language_model/src/provider/x_ai.rs
deleted file mode 100644
index 3d0f794fa4087a4beeb4a9b6253d016a9b592f0e..0000000000000000000000000000000000000000
--- a/crates/language_model/src/provider/x_ai.rs
+++ /dev/null
@@ -1,4 +0,0 @@
-use crate::{LanguageModelProviderId, LanguageModelProviderName};
-
-pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
-pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
diff --git a/crates/language_model/src/provider/zed.rs b/crates/language_model/src/provider/zed.rs
deleted file mode 100644
index 0ba793e99aad1caa25f049a96faf02c16e8970fa..0000000000000000000000000000000000000000
--- a/crates/language_model/src/provider/zed.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-use crate::{LanguageModelProviderId, LanguageModelProviderName};
-
-pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
-pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
-    LanguageModelProviderName::new("Zed");
diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs
index bf14fbb0b5804505b33074e6e4cbcc36ddf21fab..680078808ab33cc2a90caead8b304326beccf11b 100644
--- a/crates/language_model/src/registry.rs
+++ b/crates/language_model/src/registry.rs
@@ -1,6 +1,6 @@
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
-    LanguageModelProviderState,
+    LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
 };
 use collections::{BTreeMap, HashSet};
 use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
@@ -101,7 +101,7 @@ impl ConfiguredModel {
     }
 
     pub fn is_provided_by_zed(&self) -> bool {
-        self.provider.id() == crate::provider::ZED_CLOUD_PROVIDER_ID
+        self.provider.id() == ZED_CLOUD_PROVIDER_ID
     }
 }
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index 9a5e96078cd4d952185261c79032c5c5fdf30060..ef73864fe3e2f5b58e73dec848c686123a61fcde 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -4,78 +4,13 @@ use std::sync::Arc;
 use anyhow::Result;
 use base64::write::EncoderWriter;
 use gpui::{
-    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
-    point, px, size,
+    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, Size, Task, point, px, size,
 };
 use image::GenericImageView as _;
 use image::codecs::png::PngEncoder;
-use serde::{Deserialize, Serialize};
 use util::ResultExt;
 
-use crate::role::Role;
-use crate::{LanguageModelToolUse, LanguageModelToolUseId};
-
-#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
-pub struct LanguageModelImage {
-    /// A base64-encoded PNG image.
-    pub source: SharedString,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub size: Option<Size<DevicePixels>>,
-}
-
-impl LanguageModelImage {
-    pub fn len(&self) -> usize {
-        self.source.len()
-    }
-
-    pub fn is_empty(&self) -> bool {
-        self.source.is_empty()
-    }
-
-    // Parse Self from a JSON object with case-insensitive field names
-    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
-        let mut source = None;
-        let mut size_obj = None;
-
-        // Find source and size fields (case-insensitive)
-        for (k, v) in obj.iter() {
-            match k.to_lowercase().as_str() {
-                "source" => source = v.as_str(),
-                "size" => size_obj = v.as_object(),
-                _ => {}
-            }
-        }
-
-        let source = source?;
-        let size_obj = size_obj?;
-
-        let mut width = None;
-        let mut height = None;
-
-        // Find width and height in size object (case-insensitive)
-        for (k, v) in size_obj.iter() {
-            match k.to_lowercase().as_str() {
-                "width" => width = v.as_i64().map(|w| w as i32),
-                "height" => height = v.as_i64().map(|h| h as i32),
-                _ => {}
-            }
-        }
-
-        Some(Self {
-            size: Some(size(DevicePixels(width?), DevicePixels(height?))),
-            source: SharedString::from(source.to_string()),
-        })
-    }
-}
-
-impl std::fmt::Debug for LanguageModelImage {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("LanguageModelImage")
-            .field("source", &format!("<{} bytes>", self.source.len()))
-            .field("size", &self.size)
-            .finish()
-    }
-}
+use language_model_core::{ImageSize, LanguageModelImage};
 
 /// Anthropic wants uploaded images to be smaller than this in both dimensions.
 const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;
@@ -90,18 +25,16 @@ const DEFAULT_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
 /// `DEFAULT_IMAGE_MAX_BYTES`.
 const MAX_IMAGE_DOWNSCALE_PASSES: usize = 8;
 
-impl LanguageModelImage {
-    // All language model images are encoded as PNGs.
-    pub const FORMAT: ImageFormat = ImageFormat::Png;
+/// Extension trait for `LanguageModelImage` that provides GPUI-dependent functionality.
+pub trait LanguageModelImageExt {
+    const FORMAT: ImageFormat;
+    fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>>;
+}
 
-    pub fn empty() -> Self {
-        Self {
-            source: "".into(),
-            size: None,
-        }
-    }
+impl LanguageModelImageExt for LanguageModelImage {
+    const FORMAT: ImageFormat = ImageFormat::Png;
 
-    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
+    fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>> {
         cx.background_spawn(async move {
             let image_bytes = Cursor::new(data.bytes());
             let dynamic_image = match data.format() {
@@ -186,28 +119,14 @@ impl LanguageModelImageExt for LanguageModelImage {
             let source = unsafe { String::from_utf8_unchecked(base64_image) };
 
             Some(LanguageModelImage {
-                size: Some(image_size),
+                size: Some(ImageSize {
+                    width: width as i32,
+                    height: height as i32,
+                }),
                 source: source.into(),
             })
         })
     }
-
-    pub fn estimate_tokens(&self) -> usize {
-        let Some(size) = self.size.as_ref() else {
-            return 0;
-        };
-        let width = size.width.0.unsigned_abs() as usize;
-        let height = size.height.0.unsigned_abs() as usize;
-
-        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
-        // Note that there are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
-        // so this method is more of a rough guess.
-        (width * height) / 750
-    }
-
-    pub fn to_base64_url(&self) -> String {
-        format!("data:image/png;base64,{}", self.source)
-    }
 }
 
 fn encode_png_bytes(image: &image::DynamicImage) -> Result<Vec<u8>> {
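Moving `from_image` onto an extension trait keeps the `LanguageModelImage` data type itself free of gpui types while preserving the familiar call syntax. The one visible cost is that the trait must be in scope at the call site; a minimal sketch (hypothetical caller):

use std::sync::Arc;

use gpui::{App, Image, Task};
use language_model::{LanguageModelImage, LanguageModelImageExt};

fn encode_for_model(image: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>> {
    // Reads like an inherent associated function, but resolves through
    // LanguageModelImageExt; removing that import breaks compilation.
    LanguageModelImage::from_image(image, cx)
}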
@@ -228,512 +147,85 @@ fn encode_bytes_as_base64(bytes: &[u8]) -> Result<Vec<u8>> {
     Ok(base64_image)
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
-pub struct LanguageModelToolResult {
-    pub tool_use_id: LanguageModelToolUseId,
-    pub tool_name: Arc<str>,
-    pub is_error: bool,
-    /// The tool output formatted for presenting to the model
-    pub content: LanguageModelToolResultContent,
-    /// The raw tool output, if available, often for debugging or extra state for replay
-    pub output: Option<serde_json::Value>,
-}
-
-#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
-pub enum LanguageModelToolResultContent {
-    Text(Arc<str>),
-    Image(LanguageModelImage),
-}
-
-impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: serde::Deserializer<'de>,
-    {
-        use serde::de::Error;
-
-        let value = serde_json::Value::deserialize(deserializer)?;
-
-        // Models can provide these responses in several styles. Try each in order.
-
-        // 1. Try as plain string
-        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
-            return Ok(Self::Text(Arc::from(text)));
-        }
-
-        // 2. Try as object
-        if let Some(obj) = value.as_object() {
-            // get a JSON field case-insensitively
-            fn get_field<'a>(
-                obj: &'a serde_json::Map<String, serde_json::Value>,
-                field: &str,
-            ) -> Option<&'a serde_json::Value> {
-                obj.iter()
-                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
-                    .map(|(_, v)| v)
-            }
-
-            // Accept wrapped text format: { "type": "text", "text": "..." }
-            if let (Some(type_value), Some(text_value)) =
-                (get_field(obj, "type"), get_field(obj, "text"))
-                && let Some(type_str) = type_value.as_str()
-                && type_str.to_lowercase() == "text"
-                && let Some(text) = text_value.as_str()
-            {
-                return Ok(Self::Text(Arc::from(text)));
-            }
-
-            // Check for wrapped Text variant: { "text": "..." }
-            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
-                && obj.len() == 1
-            {
-                // Only one field, and it's "text" (case-insensitive)
-                if let Some(text) = value.as_str() {
-                    return Ok(Self::Text(Arc::from(text)));
-                }
-            }
-
-            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
-            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
-                && obj.len() == 1
-            {
-                // Only one field, and it's "image" (case-insensitive)
-                // Try to parse the nested image object
-                if let Some(image_obj) = value.as_object()
-                    && let Some(image) = LanguageModelImage::from_json(image_obj)
-                {
-                    return Ok(Self::Image(image));
-                }
-            }
-
-            // Try as direct Image (object with "source" and "size" fields)
-            if let Some(image) = LanguageModelImage::from_json(obj) {
-                return Ok(Self::Image(image));
-            }
-        }
-
-        // If none of the variants match, return an error with the problematic JSON
-        Err(D::Error::custom(format!(
-            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
-             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
-            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
-        )))
-    }
-}
-
-impl LanguageModelToolResultContent {
-    pub fn to_str(&self) -> Option<&str> {
-        match self {
-            Self::Text(text) => Some(text),
-            Self::Image(_) => None,
-        }
-    }
-
-    pub fn is_empty(&self) -> bool {
-        match self {
-            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
-            Self::Image(_) => false,
-        }
-    }
-}
-
-impl From<&str> for LanguageModelToolResultContent {
-    fn from(value: &str) -> Self {
-        Self::Text(Arc::from(value))
-    }
-}
-
-impl From<String> for LanguageModelToolResultContent {
-    fn from(value: String) -> Self {
-        Self::Text(Arc::from(value))
-    }
-}
-
-impl From<LanguageModelImage> for LanguageModelToolResultContent {
-    fn from(image: LanguageModelImage) -> Self {
-        Self::Image(image)
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
-pub enum MessageContent {
-    Text(String),
-    Thinking {
-        text: String,
-        signature: Option<String>,
-    },
-    RedactedThinking(String),
-    Image(LanguageModelImage),
-    ToolUse(LanguageModelToolUse),
-    ToolResult(LanguageModelToolResult),
-}
-
-impl MessageContent {
-    pub fn to_str(&self) -> Option<&str> {
-        match self {
-            MessageContent::Text(text) => Some(text.as_str()),
-            MessageContent::Thinking { text, .. } => Some(text.as_str()),
-            MessageContent::RedactedThinking(_) => None,
-            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
-            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
-        }
-    }
-
-    pub fn is_empty(&self) -> bool {
-        match self {
-            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
-            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
-            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
-            MessageContent::RedactedThinking(_)
-            | MessageContent::ToolUse(_)
-            | MessageContent::Image(_) => false,
-        }
-    }
-}
-
-impl From<String> for MessageContent {
-    fn from(value: String) -> Self {
-        MessageContent::Text(value)
-    }
-}
-
-impl From<&str> for MessageContent {
-    fn from(value: &str) -> Self {
-        MessageContent::Text(value.to_string())
-    }
-}
-
-#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
-pub struct LanguageModelRequestMessage {
-    pub role: Role,
-    pub content: Vec<MessageContent>,
-    pub cache: bool,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub reasoning_details: Option<serde_json::Value>,
-}
-
-impl LanguageModelRequestMessage {
-    pub fn string_contents(&self) -> String {
-        let mut buffer = String::new();
-        for string in self.content.iter().filter_map(|content| content.to_str()) {
-            buffer.push_str(string);
-        }
-
-        buffer
-    }
-
-    pub fn contents_empty(&self) -> bool {
-        self.content.iter().all(|content| content.is_empty())
-    }
-}
-
-#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
-pub struct LanguageModelRequestTool {
-    pub name: String,
-    pub description: String,
-    pub input_schema: serde_json::Value,
-    pub use_input_streaming: bool,
-}
-
-#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
-pub enum LanguageModelToolChoice {
-    Auto,
-    Any,
-    None,
-}
-
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum CompletionIntent {
-    UserPrompt,
-    Subagent,
-    ToolResults,
-    ThreadSummarization,
-    ThreadContextSummarization,
-    CreateFile,
-    EditFile,
-    InlineAssist,
-    TerminalInlineAssist,
-    GenerateGitCommitMessage,
-}
-
-#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
-pub struct LanguageModelRequest {
-    pub thread_id: Option<String>,
-    pub prompt_id: Option<String>,
-    pub intent: Option<CompletionIntent>,
-    pub messages: Vec<LanguageModelRequestMessage>,
-    pub tools: Vec<LanguageModelRequestTool>,
-    pub tool_choice: Option<LanguageModelToolChoice>,
-    pub stop: Vec<String>,
-    pub temperature: Option<f32>,
-    pub thinking_allowed: bool,
-    pub thinking_effort: Option<String>,
-    pub speed: Option<Speed>,
-}
-
-#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)]
-#[serde(rename_all = "snake_case")]
-pub enum Speed {
-    #[default]
-    Standard,
-    Fast,
-}
-
-impl Speed {
-    pub fn toggle(self) -> Self {
-        match self {
-            Speed::Standard => Speed::Fast,
-            Speed::Fast => Speed::Standard,
-        }
+/// Convert a core `ImageSize` to a gpui `Size<DevicePixels>`.
+pub fn image_size_to_gpui(size: ImageSize) -> Size<DevicePixels> {
+    Size {
+        width: DevicePixels(size.width),
+        height: DevicePixels(size.height),
     }
 }
 
-impl From<Speed> for anthropic::Speed {
-    fn from(speed: Speed) -> Self {
-        match speed {
-            Speed::Standard => anthropic::Speed::Standard,
-            Speed::Fast => anthropic::Speed::Fast,
-        }
+/// Convert a gpui `Size<DevicePixels>` to a core `ImageSize`.
+pub fn gpui_size_to_image_size(size: Size<DevicePixels>) -> ImageSize {
+    ImageSize {
+        width: size.width.0,
+        height: size.height.0,
     }
 }
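The two conversion helpers are inverses of each other. A small round-trip sketch (hypothetical; values arbitrary, and it assumes gpui's `Size`/`DevicePixels` support `==` comparison):

use gpui::{DevicePixels, Size, size};
use language_model::{gpui_size_to_image_size, image_size_to_gpui};

fn round_trip_demo() {
    let gpui_size: Size<DevicePixels> = size(DevicePixels(1024), DevicePixels(768));
    let core_size = gpui_size_to_image_size(gpui_size);
    assert_eq!(core_size.width, 1024);
    // Converting back restores the original gpui value.
    assert_eq!(image_size_to_gpui(core_size), gpui_size);
}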
- let (w, h) = png_dimensions(&encoded_png); + let (w, h) = png_dimensions(&decoded_png); assert!( - w < 4096 || h < 4096, - "expected image to be downscaled in at least one dimension; got {w}x{h}" - ); - } - - #[test] - fn test_language_model_tool_result_content_deserialization() { - let json = r#""This is plain text""#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("This is plain text".into()) - ); - - let json = r#"{"type": "text", "text": "This is wrapped text"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("This is wrapped text".into()) - ); - - let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Case insensitive".into()) - ); - - let json = r#"{"Text": "Wrapped variant"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Wrapped variant".into()) - ); - - let json = r#"{"text": "Lowercase wrapped"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Lowercase wrapped".into()) + w < 4096 && h < 4096, + "Dimensions should have shrunk: got {}×{}", + w, + h ); - - // Test image deserialization - let json = r#"{ - "source": "base64encodedimagedata", - "size": { - "width": 100, - "height": 200 - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "base64encodedimagedata"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 100); - assert_eq!(size.height.0, 200); - } - _ => panic!("Expected Image variant"), - } - - // Test wrapped Image variant - let json = r#"{ - "Image": { - "source": "wrappedimagedata", - "size": { - "width": 50, - "height": 75 - } - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "wrappedimagedata"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 50); - assert_eq!(size.height.0, 75); - } - _ => panic!("Expected Image variant"), - } - - // Test wrapped Image variant with case insensitive - let json = r#"{ - "image": { - "Source": "caseinsensitive", - "SIZE": { - "width": 30, - "height": 40 - } - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "caseinsensitive"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 30); - assert_eq!(size.height.0, 40); - } - _ => panic!("Expected Image variant"), - } - - // Test that wrapped text with wrong type fails - let json = r#"{"type": "blahblah", "text": "This should fail"}"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - // Test that malformed JSON fails - let json = r#"{"invalid": "structure"}"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - // Test edge cases - let json = r#""""#; // Empty string - let result: 
LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!(result, LanguageModelToolResultContent::Text("".into())); - - // Test with extra fields in wrapped text (should be ignored) - let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into())); - - // Test direct image with case-insensitive fields - let json = r#"{ - "SOURCE": "directimage", - "Size": { - "width": 200, - "height": 300 - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "directimage"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 200); - assert_eq!(size.height.0, 300); - } - _ => panic!("Expected Image variant"), - } - - // Test that multiple fields prevent wrapped variant interpretation - let json = r#"{"Text": "not wrapped", "extra": "field"}"#; - let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json); - assert!(result.is_err()); - - // Test wrapped text with uppercase TEXT variant - let json = r#"{"TEXT": "Uppercase variant"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Uppercase variant".into()) - ); - - // Test that numbers and other JSON values fail gracefully - let json = r#"123"#; - let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json); - assert!(result.is_err()); - - let json = r#"null"#; - let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json); - assert!(result.is_err()); - - let json = r#"[1, 2, 3]"#; - let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json); - assert!(result.is_err()); } } diff --git a/crates/language_model_core/Cargo.toml b/crates/language_model_core/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..7a6de00f3e4a774537d93e2f77ea9107845a7c50 --- /dev/null +++ b/crates/language_model_core/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "language_model_core" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/language_model_core.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +cloud_llm_client.workspace = true +futures.workspace = true +gpui_shared_string.workspace = true +http_client.workspace = true +partial-json-fixer.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +strum.workspace = true +thiserror.workspace = true diff --git a/crates/language_model_core/LICENSE-GPL b/crates/language_model_core/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/language_model_core/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/language_model_core/src/language_model_core.rs b/crates/language_model_core/src/language_model_core.rs new file mode 100644 index 0000000000000000000000000000000000000000..5f932690869a2c17ec1c89cbe9401bcdef6e1e73 --- /dev/null +++ b/crates/language_model_core/src/language_model_core.rs @@ -0,0 +1,658 @@ +mod provider; +mod rate_limiter; +mod request; +mod role; +pub mod tool_schema; +pub mod util; + +use anyhow::{Result, anyhow}; +use cloud_llm_client::CompletionRequestStatus; +use 
http_client::{StatusCode, http}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::ops::{Add, Sub}; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use std::{fmt, io}; +use thiserror::Error; +fn is_default<T: Default + PartialEq>(value: &T) -> bool { + *value == T::default() +} + +pub use crate::provider::*; +pub use crate::rate_limiter::*; +pub use crate::request::*; +pub use crate::role::*; +pub use crate::tool_schema::LanguageModelToolSchemaFormat; +pub use crate::util::{fix_streamed_json, parse_prompt_too_long, parse_tool_arguments}; +pub use gpui_shared_string::SharedString; + +#[derive(Clone, Debug)] +pub struct LanguageModelCacheConfiguration { + pub max_cache_anchors: usize, + pub should_speculate: bool, + pub min_total_token: u64, +} + +/// A completion event from a language model. +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub enum LanguageModelCompletionEvent { + Queued { + position: usize, + }, + Started, + Stop(StopReason), + Text(String), + Thinking { + text: String, + signature: Option<String>, + }, + RedactedThinking { + data: String, + }, + ToolUse(LanguageModelToolUse), + ToolUseJsonParseError { + id: LanguageModelToolUseId, + tool_name: Arc<str>, + raw_input: Arc<str>, + json_parse_error: String, + }, + StartMessage { + message_id: String, + }, + ReasoningDetails(serde_json::Value), + UsageUpdate(TokenUsage), +} + +impl LanguageModelCompletionEvent { + pub fn from_completion_request_status( + status: CompletionRequestStatus, + upstream_provider: LanguageModelProviderName, + ) -> Result<Option<LanguageModelCompletionEvent>, LanguageModelCompletionError> { + match status { + CompletionRequestStatus::Queued { position } => { + Ok(Some(LanguageModelCompletionEvent::Queued { position })) + } + CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)), + CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None), + CompletionRequestStatus::Failed { + code, + message, + request_id: _, + retry_after, + } => Err(LanguageModelCompletionError::from_cloud_failure( + upstream_provider, + code, + message, + retry_after.map(Duration::from_secs_f64), + )), + } + } +}
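+ +// For example, `CompletionRequestStatus::Queued { position: 3 }` maps to +// `Ok(Some(LanguageModelCompletionEvent::Queued { position: 3 }))`, while +// `Unknown` and `StreamEnded` map to `Ok(None)`, i.e. there is no event to surface.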
+ +#[derive(Error, Debug)] +pub enum LanguageModelCompletionError { + #[error("prompt too large for context window")] + PromptTooLarge { tokens: Option<u64> }, + #[error("missing {provider} API key")] + NoApiKey { provider: LanguageModelProviderName }, + #[error("{provider}'s API rate limit exceeded")] + RateLimitExceeded { + provider: LanguageModelProviderName, + retry_after: Option<Duration>, + }, + #[error("{provider}'s API servers are overloaded right now")] + ServerOverloaded { + provider: LanguageModelProviderName, + retry_after: Option<Duration>, + }, + #[error("{provider}'s API server reported an internal server error: {message}")] + ApiInternalServerError { + provider: LanguageModelProviderName, + message: String, + }, + #[error("{message}")] + UpstreamProviderError { + message: String, + status: StatusCode, + retry_after: Option<Duration>, + }, + #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")] + HttpResponseError { + provider: LanguageModelProviderName, + status_code: StatusCode, + message: String, + }, + #[error("invalid request format to {provider}'s API: {message}")] + BadRequestFormat { + provider: LanguageModelProviderName, + message: String, + }, + #[error("authentication error with {provider}'s API: {message}")] + AuthenticationError { + provider: LanguageModelProviderName, + message: String, + }, + #[error("Permission error with {provider}'s API: {message}")] + PermissionError { + provider: LanguageModelProviderName, + message: String, + }, + #[error("language model provider API endpoint not found")] + ApiEndpointNotFound { provider: LanguageModelProviderName }, + #[error("I/O error reading response from {provider}'s API")] + ApiReadResponseError { + provider: LanguageModelProviderName, + #[source] + error: io::Error, + }, + #[error("error serializing request to {provider} API")] + SerializeRequest { + provider: LanguageModelProviderName, + #[source] + error: serde_json::Error, + }, + #[error("error building request body to {provider} API")] + BuildRequestBody { + provider: LanguageModelProviderName, + #[source] + error: http::Error, + }, + #[error("error sending HTTP request to {provider} API")] + HttpSend { + provider: LanguageModelProviderName, + #[source] + error: anyhow::Error, + }, + #[error("error deserializing {provider} API response")] + DeserializeResponse { + provider: LanguageModelProviderName, + #[source] + error: serde_json::Error, + }, + #[error("stream from {provider} ended unexpectedly")] + StreamEndedUnexpectedly { provider: LanguageModelProviderName }, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl LanguageModelCompletionError { + fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> { + let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?; + let upstream_status = error_json + .get("upstream_status") + .and_then(|v| v.as_u64()) + .and_then(|status| u16::try_from(status).ok()) + .and_then(|status| StatusCode::from_u16(status).ok())?; + let inner_message = error_json + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or(message) + .to_string(); + Some((upstream_status, inner_message)) + } + + pub fn from_cloud_failure( + upstream_provider: LanguageModelProviderName, + code: String, + message: String, + retry_after: Option<Duration>, + ) -> Self { + if let Some(tokens) = parse_prompt_too_long(&message) { + Self::PromptTooLarge { + tokens: Some(tokens), + } + } else if code == "upstream_http_error" { + if let Some((upstream_status, inner_message)) = + Self::parse_upstream_error_json(&message) + { + return Self::from_http_status( + upstream_provider, + upstream_status, + inner_message, + retry_after, + ); + } + anyhow!("completion request failed, code: {code}, message: {message}").into() + } else if let Some(status_code) = code + .strip_prefix("upstream_http_") + .and_then(|code| StatusCode::from_str(code).ok()) + { + Self::from_http_status(upstream_provider, status_code, message, retry_after) + } else if let Some(status_code) = code + .strip_prefix("http_") + .and_then(|code| StatusCode::from_str(code).ok()) + { + Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after) + } else { + anyhow!("completion request failed, code: {code}, message: {message}").into() + } + }
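+ + // For example, per the branches above: code "upstream_http_503" with an + // upstream provider of "anthropic" is routed through `from_http_status` and + // becomes `ServerOverloaded`, while a bare "http_429" is attributed to Zed's + // own API and becomes `RateLimitExceeded { provider: "Zed", .. }`.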
+ + pub fn from_http_status( + provider: LanguageModelProviderName, + status_code: StatusCode, + message: String, + retry_after: Option<Duration>, + ) -> Self { + match status_code { + StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message }, + StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message }, + StatusCode::FORBIDDEN => Self::PermissionError { provider, message }, + StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider }, + StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge { + tokens: parse_prompt_too_long(&message), + }, + StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { + provider, + retry_after, + }, + StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message }, + StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { + provider, + retry_after, + }, + _ if status_code.as_u16() == 529 => Self::ServerOverloaded { + provider, + retry_after, + }, + _ => Self::HttpResponseError { + provider, + status_code, + message, + }, + } + } +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum StopReason { + EndTurn, + MaxTokens, + ToolUse, + Refusal, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] +pub struct TokenUsage { + #[serde(default, skip_serializing_if = "is_default")] + pub input_tokens: u64, + #[serde(default, skip_serializing_if = "is_default")] + pub output_tokens: u64, + #[serde(default, skip_serializing_if = "is_default")] + pub cache_creation_input_tokens: u64, + #[serde(default, skip_serializing_if = "is_default")] + pub cache_read_input_tokens: u64, +} + +impl TokenUsage { + pub fn total_tokens(&self) -> u64 { + self.input_tokens + + self.output_tokens + + self.cache_read_input_tokens + + self.cache_creation_input_tokens + } +} + +impl Add for TokenUsage { + type Output = Self; + + fn add(self, other: Self) -> Self { + Self { + input_tokens: self.input_tokens + other.input_tokens, + output_tokens: self.output_tokens + other.output_tokens, + cache_creation_input_tokens: self.cache_creation_input_tokens + + other.cache_creation_input_tokens, + cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens, + } + } +} + +impl Sub for TokenUsage { + type Output = Self; + + fn sub(self, other: Self) -> Self { + Self { + input_tokens: self.input_tokens - other.input_tokens, + output_tokens: self.output_tokens - other.output_tokens, + cache_creation_input_tokens: self.cache_creation_input_tokens + - other.cache_creation_input_tokens, + cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens, + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelToolUseId(Arc<str>); + +impl fmt::Display for LanguageModelToolUseId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl<T> From<T> for LanguageModelToolUseId +where + T: Into<Arc<str>>, +{ + fn from(value: T) -> Self { + Self(value.into()) + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelToolUse { + pub id: LanguageModelToolUseId, + pub name: Arc<str>, + pub raw_input: String, + pub input: serde_json::Value, + pub is_input_complete: bool, + /// Thought signature the model sent us. Some models require that this + /// signature be preserved and sent back in conversation history for validation. + pub thought_signature: Option<String>, +} + +#[derive(Debug, Clone)] +pub struct LanguageModelEffortLevel { + pub name: SharedString, + pub value: SharedString, + pub is_default: bool, +} + +/// An error that occurred when trying to authenticate the language model provider. 
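+/// +/// A hypothetical caller might handle these variants like so (sketch only; +/// `prompt_for_api_key`, `schedule_retry`, and `report` are illustrative +/// stand-ins, not APIs of this crate): +/// +/// ```ignore +/// match authenticate_result { +///     Err(AuthenticateError::CredentialsNotFound) => prompt_for_api_key(), +///     Err(AuthenticateError::ConnectionRefused) => schedule_retry(), +///     Err(AuthenticateError::Other(error)) => report(error), +///     Ok(()) => {} +/// } +/// ```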
+#[derive(Debug, Error)] +pub enum AuthenticateError { + #[error("connection refused")] + ConnectionRefused, + #[error("credentials not found")] + CredentialsNotFound, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)] +pub struct LanguageModelId(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct LanguageModelName(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct LanguageModelProviderId(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct LanguageModelProviderName(pub SharedString); + +impl LanguageModelProviderId { + pub const fn new(id: &'static str) -> Self { + Self(SharedString::new_static(id)) + } +} + +impl LanguageModelProviderName { + pub const fn new(id: &'static str) -> Self { + Self(SharedString::new_static(id)) + } +} + +impl fmt::Display for LanguageModelProviderId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl fmt::Display for LanguageModelProviderName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<String> for LanguageModelId { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From<String> for LanguageModelName { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From<String> for LanguageModelProviderId { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From<String> for LanguageModelProviderName { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From<Arc<str>> for LanguageModelProviderId { + fn from(value: Arc<str>) -> Self { + Self(SharedString::from(value)) + } +} + +impl From<Arc<str>> for LanguageModelProviderName { + fn from(value: Arc<str>) -> Self { + Self(SharedString::from(value)) + } +} + +/// Settings-layer–free model mode enum. +/// +/// Mirrors the shape of `settings_content::ModelMode` but lives here so that +/// crates below the settings layer can reference it. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ModelMode { + #[default] + Default, + Thinking { + budget_tokens: Option<u32>, + }, +} + +/// Settings-layer–free reasoning-effort enum. +/// +/// Mirrors the shape of `settings_content::OpenAiReasoningEffort` but lives +/// here so that crates below the settings layer can reference it. +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, strum::EnumString, +)] +#[serde(rename_all = "lowercase")] +#[strum(serialize_all = "lowercase")] +pub enum ReasoningEffort { + Minimal, + Low, + Medium, + High, + XHigh, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_cloud_failure_with_upstream_http_error() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. 
} => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!(message, "Internal server error"); + } + _ => panic!( + "Expected ApiInternalServerError for 500 status, got: {:?}", + error + ), + } + } + + #[test] + fn test_from_cloud_failure_with_standard_format() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_503".to_string(), + "Service unavailable".to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!("Expected ServerOverloaded error for upstream_http_503"), + } + } + + #[test] + fn test_upstream_http_error_connection_timeout() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!( + message, + "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. 
reset reason: connection timeout" + ); + } + _ => panic!( + "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}", + error + ), + } + } + + #[test] + fn test_language_model_tool_use_serializes_with_signature() { + use serde_json::json; + + let tool_use = LanguageModelToolUse { + id: LanguageModelToolUseId::from("test_id"), + name: "test_tool".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: Some("test_signature".to_string()), + }; + + let serialized = serde_json::to_value(&tool_use).unwrap(); + + assert_eq!(serialized["id"], "test_id"); + assert_eq!(serialized["name"], "test_tool"); + assert_eq!(serialized["thought_signature"], "test_signature"); + } + + #[test] + fn test_language_model_tool_use_deserializes_with_missing_signature() { + use serde_json::json; + + let json = json!({ + "id": "test_id", + "name": "test_tool", + "raw_input": "{\"arg\":\"value\"}", + "input": {"arg": "value"}, + "is_input_complete": true + }); + + let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap(); + + assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id")); + assert_eq!(tool_use.name.as_ref(), "test_tool"); + assert_eq!(tool_use.thought_signature, None); + } + + #[test] + fn test_language_model_tool_use_round_trip_with_signature() { + use serde_json::json; + + let original = LanguageModelToolUse { + id: LanguageModelToolUseId::from("round_trip_id"), + name: "round_trip_tool".into(), + raw_input: json!({"key": "value"}).to_string(), + input: json!({"key": "value"}), + is_input_complete: true, + thought_signature: Some("round_trip_sig".to_string()), + }; + + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.id, original.id); + assert_eq!(deserialized.name, original.name); + assert_eq!(deserialized.thought_signature, original.thought_signature); + } + + #[test] + fn test_language_model_tool_use_round_trip_without_signature() { + use serde_json::json; + + let original = LanguageModelToolUse { + id: LanguageModelToolUseId::from("no_sig_id"), + name: "no_sig_tool".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: None, + }; + + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.id, original.id); + assert_eq!(deserialized.name, original.name); + assert_eq!(deserialized.thought_signature, None); + } +} diff --git a/crates/language_model_core/src/provider.rs b/crates/language_model_core/src/provider.rs new file mode 100644 index 0000000000000000000000000000000000000000..da8b208147ad1d5b58a35888dfd07c821965097c --- /dev/null +++ b/crates/language_model_core/src/provider.rs @@ -0,0 +1,21 @@ +use crate::{LanguageModelProviderId, LanguageModelProviderName}; + +pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = + LanguageModelProviderId::new("anthropic"); +pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Anthropic"); + +pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai"); +pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("OpenAI"); + +pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = 
LanguageModelProviderId::new("google"); +pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Google AI"); + +pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); +pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); + +pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev"); +pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Zed"); diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model_core/src/rate_limiter.rs similarity index 100% rename from crates/language_model/src/rate_limiter.rs rename to crates/language_model_core/src/rate_limiter.rs diff --git a/crates/language_model_core/src/request.rs b/crates/language_model_core/src/request.rs new file mode 100644 index 0000000000000000000000000000000000000000..48f7f00522bc3dd5c06747d662761efb003886c0 --- /dev/null +++ b/crates/language_model_core/src/request.rs @@ -0,0 +1,463 @@ +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +use crate::role::Role; +use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString}; + +/// Dimensions of a `LanguageModelImage` +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ImageSize { + pub width: i32, + pub height: i32, +} + +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub struct LanguageModelImage { + /// A base64-encoded PNG image. + pub source: SharedString, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub size: Option<ImageSize>, +} + +impl LanguageModelImage { + pub fn len(&self) -> usize { + self.source.len() + } + + pub fn is_empty(&self) -> bool { + self.source.is_empty() + } + + pub fn empty() -> Self { + Self { + source: "".into(), + size: None, + } + } + + /// Parse Self from a JSON object with case-insensitive field names + pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> { + let mut source = None; + let mut size_obj = None; + + for (k, v) in obj.iter() { + match k.to_lowercase().as_str() { + "source" => source = v.as_str(), + "size" => size_obj = v.as_object(), + _ => {} + } + } + + let source = source?; + let size_obj = size_obj?; + + let mut width = None; + let mut height = None; + + for (k, v) in size_obj.iter() { + match k.to_lowercase().as_str() { + "width" => width = v.as_i64().map(|w| w as i32), + "height" => height = v.as_i64().map(|h| h as i32), + _ => {} + } + } + + Some(Self { + size: Some(ImageSize { + width: width?, + height: height?, + }), + source: SharedString::from(source.to_string()), + }) + } + + pub fn estimate_tokens(&self) -> usize { + let Some(size) = self.size.as_ref() else { + return 0; + }; + let width = size.width.unsigned_abs() as usize; + let height = size.height.unsigned_abs() as usize; + + // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs + (width * height) / 750 + }
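+ + // Worked example (illustrative): a 1000x1000 image comes out to + // 1_000_000 / 750 = 1333 estimated tokens under the formula above.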
+ + pub fn to_base64_url(&self) -> String { + format!("data:image/png;base64,{}", self.source) + } +} + +impl std::fmt::Debug for LanguageModelImage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LanguageModelImage") + .field("source", &format!("<{} bytes>", self.source.len())) + .field("size", &self.size) + .finish() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub struct LanguageModelToolResult { + pub tool_use_id: LanguageModelToolUseId, + pub tool_name: Arc<str>, + pub is_error: bool, + /// The tool output formatted for presenting to the model + pub content: LanguageModelToolResultContent, + /// The raw tool output, if available, often for debugging or extra state for replay + pub output: Option<serde_json::Value>, +} + +#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)] +pub enum LanguageModelToolResultContent { + Text(Arc<str>), + Image(LanguageModelImage), +} + +impl<'de> Deserialize<'de> for LanguageModelToolResultContent { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + let value = serde_json::Value::deserialize(deserializer)?; + + // 1. Try as plain string + if let Ok(text) = serde_json::from_value::<String>(value.clone()) { + return Ok(Self::Text(Arc::from(text))); + } + + // 2. Try as object + if let Some(obj) = value.as_object() { + fn get_field<'a>( + obj: &'a serde_json::Map<String, serde_json::Value>, + field: &str, + ) -> Option<&'a serde_json::Value> { + obj.iter() + .find(|(k, _)| k.to_lowercase() == field.to_lowercase()) + .map(|(_, v)| v) + } + + // Accept wrapped text format: { "type": "text", "text": "..." } + if let (Some(type_value), Some(text_value)) = + (get_field(obj, "type"), get_field(obj, "text")) + && let Some(type_str) = type_value.as_str() + && type_str.to_lowercase() == "text" + && let Some(text) = text_value.as_str() + { + return Ok(Self::Text(Arc::from(text))); + } + + // Check for wrapped Text variant: { "text": "..." } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") + && obj.len() == 1 + { + if let Some(text) = value.as_str() { + return Ok(Self::Text(Arc::from(text))); + } + } + + // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") + && obj.len() == 1 + { + if let Some(image_obj) = value.as_object() + && let Some(image) = LanguageModelImage::from_json(image_obj) + { + return Ok(Self::Image(image)); + } + } + + // Try as direct Image + if let Some(image) = LanguageModelImage::from_json(obj) { + return Ok(Self::Image(image)); + } + } + + Err(D::Error::custom(format!( + "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \ + an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. 
Got: {}", + serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()) + ))) + } +} + +impl LanguageModelToolResultContent { + pub fn to_str(&self) -> Option<&str> { + match self { + Self::Text(text) => Some(text), + Self::Image(_) => None, + } + } + + pub fn is_empty(&self) -> bool { + match self { + Self::Text(text) => text.chars().all(|c| c.is_whitespace()), + Self::Image(_) => false, + } + } +} + +impl From<&str> for LanguageModelToolResultContent { + fn from(value: &str) -> Self { + Self::Text(Arc::from(value)) + } +} + +impl From<String> for LanguageModelToolResultContent { + fn from(value: String) -> Self { + Self::Text(Arc::from(value)) + } +} + +impl From<LanguageModelImage> for LanguageModelToolResultContent { + fn from(image: LanguageModelImage) -> Self { + Self::Image(image) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub enum MessageContent { + Text(String), + Thinking { + text: String, + signature: Option<String>, + }, + RedactedThinking(String), + Image(LanguageModelImage), + ToolUse(LanguageModelToolUse), + ToolResult(LanguageModelToolResult), +} + +impl MessageContent { + pub fn to_str(&self) -> Option<&str> { + match self { + MessageContent::Text(text) => Some(text.as_str()), + MessageContent::Thinking { text, .. } => Some(text.as_str()), + MessageContent::RedactedThinking(_) => None, + MessageContent::ToolResult(tool_result) => tool_result.content.to_str(), + MessageContent::ToolUse(_) | MessageContent::Image(_) => None, + } + } + + pub fn is_empty(&self) -> bool { + match self { + MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), + MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()), + MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(), + MessageContent::RedactedThinking(_) + | MessageContent::ToolUse(_) + | MessageContent::Image(_) => false, + } + } +} + +impl From<String> for MessageContent { + fn from(value: String) -> Self { + MessageContent::Text(value) + } +} + +impl From<&str> for MessageContent { + fn from(value: &str) -> Self { + MessageContent::Text(value.to_string()) + } +} + +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] +pub struct LanguageModelRequestMessage { + pub role: Role, + pub content: Vec<MessageContent>, + pub cache: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_details: Option<serde_json::Value>, +} + +impl LanguageModelRequestMessage { + pub fn string_contents(&self) -> String { + let mut buffer = String::new(); + for string in self.content.iter().filter_map(|content| content.to_str()) { + buffer.push_str(string); + } + buffer + } + + pub fn contents_empty(&self) -> bool { + self.content.iter().all(|content| content.is_empty()) + } +} + +#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelRequestTool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, + pub use_input_streaming: bool, +} + +#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] +pub enum LanguageModelToolChoice { + Auto, + Any, + None, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionIntent { + UserPrompt, + Subagent, + ToolResults, + ThreadSummarization, + ThreadContextSummarization, + CreateFile, + EditFile, + InlineAssist, + TerminalInlineAssist, + GenerateGitCommitMessage, +}
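+ +// Illustrative construction (sketch): a minimal one-shot request can rely on +// `Default` and fill in only the messages, e.g. `LanguageModelRequest { +// messages: vec![LanguageModelRequestMessage { role: Role::User, content: +// vec!["Hello".into()], cache: false, reasoning_details: None }], +// ..Default::default() }`.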
+ +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct LanguageModelRequest { + pub thread_id: Option<String>, + pub prompt_id: Option<String>, + pub intent: Option<CompletionIntent>, + pub messages: Vec<LanguageModelRequestMessage>, + pub tools: Vec<LanguageModelRequestTool>, + pub tool_choice: Option<LanguageModelToolChoice>, + pub stop: Vec<String>, + pub temperature: Option<f32>, + pub thinking_allowed: bool, + pub thinking_effort: Option<String>, + pub speed: Option<Speed>, +} + +#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum Speed { + #[default] + Standard, + Fast, +} + +impl Speed { + pub fn toggle(self) -> Self { + match self { + Speed::Standard => Speed::Fast, + Speed::Fast => Speed::Standard, + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct LanguageModelResponseMessage { + pub role: Option<Role>, + pub content: Option<String>, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_language_model_tool_result_content_deserialization() { + // Test plain string + let json = serde_json::json!("hello world"); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello world")) + ); + + // Test wrapped text format: { "type": "text", "text": "..." } + let json = serde_json::json!({"type": "text", "text": "hello"}); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello")) + ); + + // Test single-field text object: { "text": "..." } + let json = serde_json::json!({"text": "hello"}); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello")) + ); + + // Test case-insensitive type field + let json = serde_json::json!({"Type": "Text", "Text": "hello"}); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello")) + ); + + // Test image object + let json = serde_json::json!({ + "source": "base64encodedimagedata", + "size": {"width": 100, "height": 200} + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "base64encodedimagedata"); + let size = image.size.expect("size"); + assert_eq!(size.width, 100); + assert_eq!(size.height, 200); + } + _ => panic!("Expected Image variant"), + } + + // Test wrapped image: { "image": { "source": "...", "size": ... 
} } + let json = serde_json::json!({ + "image": { + "source": "wrappedimagedata", + "size": {"width": 50, "height": 75} + } + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "wrappedimagedata"); + let size = image.size.expect("size"); + assert_eq!(size.width, 50); + assert_eq!(size.height, 75); + } + _ => panic!("Expected Image variant"), + } + + // Test case insensitive + let json = serde_json::json!({ + "Source": "caseinsensitive", + "Size": {"Width": 30, "Height": 40} + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "caseinsensitive"); + let size = image.size.expect("size"); + assert_eq!(size.width, 30); + assert_eq!(size.height, 40); + } + _ => panic!("Expected Image variant"), + } + + // Test direct image object + let json = serde_json::json!({ + "source": "directimage", + "size": {"width": 200, "height": 300} + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "directimage"); + let size = image.size.expect("size"); + assert_eq!(size.width, 200); + assert_eq!(size.height, 300); + } + _ => panic!("Expected Image variant"), + } + } +} diff --git a/crates/language_model/src/role.rs b/crates/language_model_core/src/role.rs similarity index 100% rename from crates/language_model/src/role.rs rename to crates/language_model_core/src/role.rs diff --git a/crates/language_model/src/tool_schema.rs b/crates/language_model_core/src/tool_schema.rs similarity index 92% rename from crates/language_model/src/tool_schema.rs rename to crates/language_model_core/src/tool_schema.rs index 878870482a7527bf815797d16e03ad8edc79642e..0e82b2f41081469c6c04d16765e8336eb903fd94 100644 --- a/crates/language_model/src/tool_schema.rs +++ b/crates/language_model_core/src/tool_schema.rs @@ -77,8 +77,6 @@ pub fn adapt_schema_to_format( } fn preprocess_json_schema(json: &mut Value) -> Result<()> { - // `additionalProperties` defaults to `false` unless explicitly specified. - // This prevents models from hallucinating tool parameters. 
if let Value::Object(obj) = json && matches!(obj.get("type"), Some(Value::String(s)) if s == "object") { @@ -86,7 +84,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> { obj.insert("additionalProperties".to_string(), Value::Bool(false)); } - // OpenAI API requires non-missing `properties` if !obj.contains_key("properties") { obj.insert("properties".to_string(), Value::Object(Default::default())); } @@ -94,7 +91,6 @@ Ok(()) } -/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { if let Value::Object(obj) = json { const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"]; @@ -108,9 +104,7 @@ const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [ ("format", |value| value.is_string()), - // Gemini doesn't support `additionalProperties` in any form (boolean or schema object) ("additionalProperties", |_| true), - // Gemini doesn't support `propertyNames` ("propertyNames", |_| true), ("exclusiveMinimum", |value| value.is_number()), ("exclusiveMaximum", |value| value.is_number()), @@ -124,7 +118,6 @@ } } - // If a type is not specified for an input parameter, add a default type if matches!(obj.get("description"), Some(Value::String(_))) && !obj.contains_key("type") && !(obj.contains_key("anyOf") @@ -134,7 +127,6 @@ obj.insert("type".to_string(), Value::String("string".to_string())); } - // Handle oneOf -> anyOf conversion if let Some(subschemas) = obj.get_mut("oneOf") && subschemas.is_array() { @@ -143,7 +135,6 @@ obj.insert("anyOf".to_string(), subschemas_clone); } - // Recursively process all nested objects and arrays for (_, value) in obj.iter_mut() { if let Value::Object(_) | Value::Array(_) = value { adapt_to_json_schema_subset(value)?; @@ -178,7 +169,6 @@ mod tests { }) ); - // Ensure that we do not add a type if it is an object let mut json = json!({ "description": { "value": "abc", @@ -221,7 +211,6 @@ }) ); - // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property) let mut json = json!({ "description": "A test field", "type": "integer", @@ -239,7 +228,6 @@ }) ); - // additionalProperties as an object schema is also unsupported by Gemini let mut json = json!({ "type": "object", "properties": { diff --git a/crates/language_models/src/provider/util.rs b/crates/language_model_core/src/util.rs similarity index 88% rename from crates/language_models/src/provider/util.rs rename to crates/language_model_core/src/util.rs index 76a02b6de40a3e36c7c506f11a6f6d34d2aaca3e..3db2e0b76fd76070aa4d30e97c525fa8f3460c9d 100644 --- a/crates/language_models/src/provider/util.rs +++ b/crates/language_model_core/src/util.rs @@ -38,13 +38,22 @@ fn strip_trailing_incomplete_escape(json: &str) -> &str { } } +/// Parses a "prompt is too long: N tokens ..." message and extracts the token count. +pub fn parse_prompt_too_long(message: &str) -> Option<u64> { + message + .strip_prefix("prompt is too long: ")? + .split_once(" tokens")? 
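+ // e.g. a message like "prompt is too long: 200007 tokens > 200000 maximum" + // (suffix shape illustrative) yields ("200007", ...) here, and the leading + // number parses below to Some(200007).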
+ .0 + .parse() + .ok() +} + #[cfg(test)] mod tests { use super::*; #[test] fn test_fix_streamed_json_strips_incomplete_escape() { - // Trailing `\` inside a string — incomplete escape sequence let fixed = fix_streamed_json(r#"{"text": "hello\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "hello"); @@ -52,7 +61,6 @@ mod tests { #[test] fn test_fix_streamed_json_preserves_complete_escape() { - // `\\` is a complete escape (literal backslash) let fixed = fix_streamed_json(r#"{"text": "hello\\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "hello\\"); @@ -60,7 +68,6 @@ mod tests { #[test] fn test_fix_streamed_json_strips_escape_after_complete_escape() { - // `\\\` = complete `\\` (literal backslash) + incomplete `\` let fixed = fix_streamed_json(r#"{"text": "hello\\\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "hello\\"); @@ -75,12 +82,10 @@ mod tests { #[test] fn test_fix_streamed_json_newline_escape_boundary() { - // Simulates a stream boundary landing between `\` and `n` let fixed = fix_streamed_json(r#"{"text": "line1\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "line1"); - // Next chunk completes the escape let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "line1\nline2"); @@ -88,8 +93,6 @@ mod tests { #[test] fn test_fix_streamed_json_incremental_delta_correctness() { - // This is the actual scenario that causes the bug: - // chunk 1 ends mid-escape, chunk 2 completes it. 
let chunk1 = r#"{"replacement_text": "fn foo() {\"#; let fixed1 = fix_streamed_json(chunk1); let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json"); @@ -102,7 +105,6 @@ mod tests { let text2 = parsed2["replacement_text"].as_str().expect("string"); assert_eq!(text2, "fn foo() {\n return bar;\n}"); - // The delta should be the newline + rest, with no spurious backslash let delta = &text2[text1.len()..]; assert_eq!(delta, "\n return bar;\n}"); } diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 4ebfce695e587265ea39077c67c84ce9b01e5352..60670114529b07dca78202cc438ff5e243acaeee 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -21,8 +21,8 @@ aws_http_client.workspace = true base64.workspace = true bedrock = { workspace = true, features = ["schemars"] } client.workspace = true +cloud_api_client.workspace = true cloud_api_types.workspace = true -cloud_llm_client.workspace = true collections.workspace = true component.workspace = true convert_case.workspace = true @@ -41,6 +41,7 @@ gpui_tokio.workspace = true http_client.workspace = true language.workspace = true language_model.workspace = true +language_models_cloud.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true menu.workspace = true @@ -49,16 +50,13 @@ ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } opencode = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } -partial-json-fixer.workspace = true release_channel.workspace = true schemars.workspace = true -semver.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true strum.workspace = true -thiserror.workspace = true tiktoken-rs.workspace = true tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } ui.workspace = true @@ -70,4 +68,3 @@ x_ai = { workspace = true, features = ["schemars"] } [dev-dependencies] language_model = { workspace = true, features = ["test-support"] } pretty_assertions.workspace = true - diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index d3c433974599399160e602b8f201b9fd0af874cb..35a1e90e4483ba03e1ded8ce8c7519fc0fa7a746 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -11,7 +11,7 @@ pub mod open_ai; pub mod open_ai_compatible; pub mod open_router; pub mod opencode; -mod util; + pub mod vercel; pub mod vercel_ai_gateway; pub mod x_ai; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index c1b8bc1a3bb1b602b67ae5563d8acc3b05a94d47..58de77d573293345ec2120695866c824f10c6108 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1,13 +1,10 @@ pub mod telemetry; -use anthropic::{ - ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, CountTokensRequest, Event, - ResponseContent, ToolResultContent, ToolResultPart, Usage, -}; +use anthropic::{ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode}; use anyhow::Result; -use collections::{BTreeMap, HashMap}; +use collections::BTreeMap; use credentials_provider::CredentialsProvider; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; +use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use 
gpui::{AnyView, App, AsyncApp, Context, Entity, Task}; use http_client::HttpClient; use language_model::{ @@ -16,20 +13,19 @@ use language_model::{ LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, env_var, + LanguageModelToolChoice, RateLimiter, env_var, }; use settings::{Settings, SettingsStore}; -use std::pin::Pin; -use std::str::FromStr; use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; - +pub use anthropic::completion::{ + AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, + into_anthropic_count_tokens_request, +}; pub use settings::AnthropicAvailableModel as AvailableModel; const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID; @@ -249,228 +245,6 @@ pub struct AnthropicModel { request_limiter: RateLimiter, } -fn to_anthropic_content(content: MessageContent) -> Option<anthropic::RequestContent> { - match content { - MessageContent::Text(text) => { - let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { - text.trim_end().to_string() - } else { - text - }; - if !text.is_empty() { - Some(anthropic::RequestContent::Text { - text, - cache_control: None, - }) - } else { - None - } - } - MessageContent::Thinking { - text: thinking, - signature, - } => { - if let Some(signature) = signature - && !thinking.is_empty() - { - Some(anthropic::RequestContent::Thinking { - thinking, - signature, - cache_control: None, - }) - } else { - None - } - } - MessageContent::RedactedThinking(data) => { - if !data.is_empty() { - Some(anthropic::RequestContent::RedactedThinking { data }) - } else { - None - } - } - MessageContent::Image(image) => Some(anthropic::RequestContent::Image { - source: anthropic::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - cache_control: None, - }), - MessageContent::ToolUse(tool_use) => Some(anthropic::RequestContent::ToolUse { - id: tool_use.id.to_string(), - name: tool_use.name.to_string(), - input: tool_use.input, - cache_control: None, - }), - MessageContent::ToolResult(tool_result) => Some(anthropic::RequestContent::ToolResult { - tool_use_id: tool_result.tool_use_id.to_string(), - is_error: tool_result.is_error, - content: match tool_result.content { - LanguageModelToolResultContent::Text(text) => { - ToolResultContent::Plain(text.to_string()) - } - LanguageModelToolResultContent::Image(image) => { - ToolResultContent::Multipart(vec![ToolResultPart::Image { - source: anthropic::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - }]) - } - }, - cache_control: None, - }), - } -} - -/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest. 
-pub fn into_anthropic_count_tokens_request( - request: LanguageModelRequest, - model: String, - mode: AnthropicModelMode, -) -> CountTokensRequest { - let mut new_messages: Vec<anthropic::Message> = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages { - if message.contents_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - let anthropic_message_content: Vec<anthropic::RequestContent> = message - .content - .into_iter() - .filter_map(to_anthropic_content) - .collect(); - let anthropic_role = match message.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("System role should never occur here"), - }; - if anthropic_message_content.is_empty() { - continue; - } - - if let Some(last_message) = new_messages.last_mut() - && last_message.role == anthropic_role - { - last_message.content.extend(anthropic_message_content); - continue; - } - - new_messages.push(anthropic::Message { - role: anthropic_role, - content: anthropic_message_content, - }); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.string_contents()); - } - } - } - - CountTokensRequest { - model, - messages: new_messages, - system: if system_message.is_empty() { - None - } else { - Some(anthropic::StringOrContents::String(system_message)) - }, - thinking: if request.thinking_allowed { - match mode { - AnthropicModelMode::Thinking { budget_tokens } => { - Some(anthropic::Thinking::Enabled { budget_tokens }) - } - AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive), - AnthropicModelMode::Default => None, - } - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| anthropic::Tool { - name: tool.name, - description: tool.description, - input_schema: tool.input_schema, - eager_input_streaming: tool.use_input_streaming, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto, - LanguageModelToolChoice::Any => anthropic::ToolChoice::Any, - LanguageModelToolChoice::None => anthropic::ToolChoice::None, - }), - } -} - -/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable, -/// or by providers (like Zed Cloud) that don't have direct Anthropic API access. -pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> { - let messages = request.messages; - let mut tokens_from_images = 0; - let mut string_messages = Vec::with_capacity(messages.len()); - - for message in messages { - let mut string_contents = String::new(); - - for content in message.content { - match content { - MessageContent::Text(text) => { - string_contents.push_str(&text); - } - MessageContent::Thinking { .. } => { - // Thinking blocks are not included in the input token count. - } - MessageContent::RedactedThinking(_) => { - // Thinking blocks are not included in the input token count. - } - MessageContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - MessageContent::ToolUse(_tool_use) => { - // TODO: Estimate token usage from tool uses. 
- } - MessageContent::ToolResult(tool_result) => match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - string_contents.push_str(text); - } - LanguageModelToolResultContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - }, - } - } - - if !string_contents.is_empty() { - string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(string_contents), - name: None, - function_call: None, - }); - } - } - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. - tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) - .map(|tokens| (tokens + tokens_from_images) as u64) -} - impl AnthropicModel { fn stream_completion( &self, @@ -617,10 +391,13 @@ impl LanguageModel for AnthropicModel { ) }); + let background = cx.background_executor().clone(); async move { // If no API key, fall back to tiktoken estimation let Some(api_key) = api_key else { - return count_anthropic_tokens_with_tiktoken(request); + return background + .spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) + .await; }; let count_request = @@ -634,7 +411,9 @@ impl LanguageModel for AnthropicModel { log::error!( "Anthropic count_tokens API failed, falling back to tiktoken: {err:?}" ); - count_anthropic_tokens_with_tiktoken(request) + background + .spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) + .await } } } @@ -678,345 +457,6 @@ impl LanguageModel for AnthropicModel { } } -pub fn into_anthropic( - request: LanguageModelRequest, - model: String, - default_temperature: f32, - max_output_tokens: u64, - mode: AnthropicModelMode, -) -> anthropic::Request { - let mut new_messages: Vec<anthropic::Message> = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages { - if message.contents_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - let mut anthropic_message_content: Vec<anthropic::RequestContent> = message - .content - .into_iter() - .filter_map(to_anthropic_content) - .collect(); - let anthropic_role = match message.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("System role should never occur here"), - }; - if anthropic_message_content.is_empty() { - continue; - } - - if let Some(last_message) = new_messages.last_mut() - && last_message.role == anthropic_role - { - last_message.content.extend(anthropic_message_content); - continue; - } - - // Mark the last segment of the message as cached - if message.cache { - let cache_control_value = Some(anthropic::CacheControl { - cache_type: anthropic::CacheControlType::Ephemeral, - }); - for message_content in anthropic_message_content.iter_mut().rev() { - match message_content { - anthropic::RequestContent::RedactedThinking { .. } => { - // Caching is not possible, fallback to next message - } - anthropic::RequestContent::Text { cache_control, .. } - | anthropic::RequestContent::Thinking { cache_control, .. } - | anthropic::RequestContent::Image { cache_control, .. } - | anthropic::RequestContent::ToolUse { cache_control, .. } - | anthropic::RequestContent::ToolResult { cache_control, .. 
} => { - *cache_control = cache_control_value; - break; - } - } - } - } - - new_messages.push(anthropic::Message { - role: anthropic_role, - content: anthropic_message_content, - }); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.string_contents()); - } - } - } - - anthropic::Request { - model, - messages: new_messages, - max_tokens: max_output_tokens, - system: if system_message.is_empty() { - None - } else { - Some(anthropic::StringOrContents::String(system_message)) - }, - thinking: if request.thinking_allowed { - match mode { - AnthropicModelMode::Thinking { budget_tokens } => { - Some(anthropic::Thinking::Enabled { budget_tokens }) - } - AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive), - AnthropicModelMode::Default => None, - } - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| anthropic::Tool { - name: tool.name, - description: tool.description, - input_schema: tool.input_schema, - eager_input_streaming: tool.use_input_streaming, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto, - LanguageModelToolChoice::Any => anthropic::ToolChoice::Any, - LanguageModelToolChoice::None => anthropic::ToolChoice::None, - }), - metadata: None, - output_config: if request.thinking_allowed - && matches!(mode, AnthropicModelMode::AdaptiveThinking) - { - request.thinking_effort.as_deref().and_then(|effort| { - let effort = match effort { - "low" => Some(anthropic::Effort::Low), - "medium" => Some(anthropic::Effort::Medium), - "high" => Some(anthropic::Effort::High), - "max" => Some(anthropic::Effort::Max), - _ => None, - }; - effort.map(|effort| anthropic::OutputConfig { - effort: Some(effort), - }) - }) - } else { - None - }, - stop_sequences: Vec::new(), - speed: request.speed.map(From::from), - temperature: request.temperature.or(Some(default_temperature)), - top_k: None, - top_p: None, - } -} - -pub struct AnthropicEventMapper { - tool_uses_by_index: HashMap<usize, RawToolUse>, - usage: Usage, - stop_reason: StopReason, -} - -impl AnthropicEventMapper { - pub fn new() -> Self { - Self { - tool_uses_by_index: HashMap::default(), - usage: Usage::default(), - stop_reason: StopReason::EndTurn, - } - } - - pub fn map_stream( - mut self, - events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>, - ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(error.into())], - }) - }) - } - - pub fn map_event( - &mut self, - event: Event, - ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { - match event { - Event::ContentBlockStart { - index, - content_block, - } => match content_block { - ResponseContent::Text { text } => { - vec![Ok(LanguageModelCompletionEvent::Text(text))] - } - ResponseContent::Thinking { thinking } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })] - } - ResponseContent::RedactedThinking { data } => { - vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] - } - ResponseContent::ToolUse { id, name, .. 
} => { - self.tool_uses_by_index.insert( - index, - RawToolUse { - id, - name, - input_json: String::new(), - }, - ); - Vec::new() - } - }, - Event::ContentBlockDelta { index, delta } => match delta { - ContentDelta::TextDelta { text } => { - vec![Ok(LanguageModelCompletionEvent::Text(text))] - } - ContentDelta::ThinkingDelta { thinking } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })] - } - ContentDelta::SignatureDelta { signature } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: "".to_string(), - signature: Some(signature), - })] - } - ContentDelta::InputJsonDelta { partial_json } => { - if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { - tool_use.input_json.push_str(&partial_json); - - // Try to convert invalid (incomplete) JSON into - // valid JSON that serde can accept, e.g. by closing - // unclosed delimiters. This way, we can update the - // UI with whatever has been streamed back so far. - if let Ok(input) = - serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json)) - { - return vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.clone().into(), - name: tool_use.name.clone().into(), - is_input_complete: false, - raw_input: tool_use.input_json.clone(), - input, - thought_signature: None, - }, - ))]; - } - } - vec![] - } - }, - Event::ContentBlockStop { index } => { - if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { - let input_json = tool_use.input_json.trim(); - let event_result = match parse_tool_arguments(input_json) { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.into(), - name: tool_use.name.into(), - is_input_complete: true, - input, - raw_input: tool_use.input_json.clone(), - thought_signature: None, - }, - )), - Err(json_parse_err) => { - Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: tool_use.id.into(), - tool_name: tool_use.name.into(), - raw_input: input_json.into(), - json_parse_error: json_parse_err.to_string(), - }) - } - }; - - vec![event_result] - } else { - Vec::new() - } - } - Event::MessageStart { message } => { - update_usage(&mut self.usage, &message.usage); - vec![ - Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( - &self.usage, - ))), - Ok(LanguageModelCompletionEvent::StartMessage { - message_id: message.id, - }), - ] - } - Event::MessageDelta { delta, usage } => { - update_usage(&mut self.usage, &usage); - if let Some(stop_reason) = delta.stop_reason.as_deref() { - self.stop_reason = match stop_reason { - "end_turn" => StopReason::EndTurn, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - "refusal" => StopReason::Refusal, - _ => { - log::error!("Unexpected anthropic stop_reason: {stop_reason}"); - StopReason::EndTurn - } - }; - } - vec![Ok(LanguageModelCompletionEvent::UsageUpdate( - convert_usage(&self.usage), - ))] - } - Event::MessageStop => { - vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] - } - Event::Error { error } => { - vec![Err(error.into())] - } - _ => Vec::new(), - } - } -} - -struct RawToolUse { - id: String, - name: String, - input_json: String, -} - -/// Updates usage data by preferring counts from `new`. 
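The `InputJsonDelta` arm above leans on `fix_streamed_json` (re-exported from `language_model::util` in this PR, backed by the `partial-json-fixer` dependency) to close unfinished delimiters so the UI can preview tool input mid-stream. A minimal sketch of that repair idea, not the crate's actual implementation:

    // Close any unfinished strings, objects, and arrays in `partial` so
    // that serde_json can parse whatever has streamed in so far.
    fn close_open_delimiters(partial: &str) -> String {
        let mut stack = Vec::new();
        let mut in_string = false;
        let mut escaped = false;
        for ch in partial.chars() {
            if in_string {
                if escaped {
                    escaped = false;
                } else {
                    match ch {
                        '\\' => escaped = true,
                        '"' => in_string = false,
                        _ => {}
                    }
                }
            } else {
                match ch {
                    '"' => in_string = true,
                    '{' => stack.push('}'),
                    '[' => stack.push(']'),
                    '}' | ']' => {
                        stack.pop();
                    }
                    _ => {}
                }
            }
        }
        let mut fixed = partial.to_string();
        if in_string {
            fixed.push('"');
        }
        while let Some(closer) = stack.pop() {
            fixed.push(closer);
        }
        fixed
    }

With this, the partial input `{"path": "src/ma` becomes `{"path": "src/ma"}`, which serde accepts and the UI can render.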
-fn update_usage(usage: &mut Usage, new: &Usage) {
-    if let Some(input_tokens) = new.input_tokens {
-        usage.input_tokens = Some(input_tokens);
-    }
-    if let Some(output_tokens) = new.output_tokens {
-        usage.output_tokens = Some(output_tokens);
-    }
-    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
-        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
-    }
-    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
-        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
-    }
-}
-
-fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
-    language_model::TokenUsage {
-        input_tokens: usage.input_tokens.unwrap_or(0),
-        output_tokens: usage.output_tokens.unwrap_or(0),
-        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
-        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
-    }
-}
-
 struct ConfigurationView {
     api_key_editor: Entity<InputField>,
     state: Entity<State>,
@@ -1157,192 +597,3 @@ impl Render for ConfigurationView {
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use anthropic::AnthropicModelMode;
-    use language_model::{LanguageModelRequestMessage, MessageContent};
-
-    #[test]
-    fn test_cache_control_only_on_last_segment() {
-        let request = LanguageModelRequest {
-            messages: vec![LanguageModelRequestMessage {
-                role: Role::User,
-                content: vec![
-                    MessageContent::Text("Some prompt".to_string()),
-                    MessageContent::Image(language_model::LanguageModelImage::empty()),
-                    MessageContent::Image(language_model::LanguageModelImage::empty()),
-                    MessageContent::Image(language_model::LanguageModelImage::empty()),
-                    MessageContent::Image(language_model::LanguageModelImage::empty()),
-                ],
-                cache: true,
-                reasoning_details: None,
-            }],
-            thread_id: None,
-            prompt_id: None,
-            intent: None,
-            stop: vec![],
-            temperature: None,
-            tools: vec![],
-            tool_choice: None,
-            thinking_allowed: true,
-            thinking_effort: None,
-            speed: None,
-        };
-
-        let anthropic_request = into_anthropic(
-            request,
-            "claude-3-5-sonnet".to_string(),
-            0.7,
-            4096,
-            AnthropicModelMode::Default,
-        );
-
-        assert_eq!(anthropic_request.messages.len(), 1);
-
-        let message = &anthropic_request.messages[0];
-        assert_eq!(message.content.len(), 5);
-
-        assert!(matches!(
-            message.content[0],
-            anthropic::RequestContent::Text {
-                cache_control: None,
-                ..
-            }
-        ));
-        // Every image except the last one should be uncached.
-        for i in 1..4 {
-            assert!(matches!(
-                message.content[i],
-                anthropic::RequestContent::Image {
-                    cache_control: None,
-                    ..
-                }
-            ));
-        }
-
-        assert!(matches!(
-            message.content[4],
-            anthropic::RequestContent::Image {
-                cache_control: Some(anthropic::CacheControl {
-                    cache_type: anthropic::CacheControlType::Ephemeral,
-                }),
-                ..
- } - )); - } - - fn request_with_assistant_content( - assistant_content: Vec, - ) -> anthropic::Request { - let mut request = LanguageModelRequest { - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text("Hello".to_string())], - cache: false, - reasoning_details: None, - }], - thinking_effort: None, - thread_id: None, - prompt_id: None, - intent: None, - stop: vec![], - temperature: None, - tools: vec![], - tool_choice: None, - thinking_allowed: true, - speed: None, - }; - request.messages.push(LanguageModelRequestMessage { - role: Role::Assistant, - content: assistant_content, - cache: false, - reasoning_details: None, - }); - into_anthropic( - request, - "claude-sonnet-4-5".to_string(), - 1.0, - 16000, - AnthropicModelMode::Thinking { - budget_tokens: Some(10000), - }, - ) - } - - #[test] - fn test_unsigned_thinking_blocks_stripped() { - let result = request_with_assistant_content(vec![ - MessageContent::Thinking { - text: "Cancelled mid-think, no signature".to_string(), - signature: None, - }, - MessageContent::Text("Some response text".to_string()), - ]); - - let assistant_message = result - .messages - .iter() - .find(|m| m.role == anthropic::Role::Assistant) - .expect("assistant message should still exist"); - - assert_eq!( - assistant_message.content.len(), - 1, - "Only the text content should remain; unsigned thinking block should be stripped" - ); - assert!(matches!( - &assistant_message.content[0], - anthropic::RequestContent::Text { text, .. } if text == "Some response text" - )); - } - - #[test] - fn test_signed_thinking_blocks_preserved() { - let result = request_with_assistant_content(vec![ - MessageContent::Thinking { - text: "Completed thinking".to_string(), - signature: Some("valid-signature".to_string()), - }, - MessageContent::Text("Response".to_string()), - ]); - - let assistant_message = result - .messages - .iter() - .find(|m| m.role == anthropic::Role::Assistant) - .expect("assistant message should exist"); - - assert_eq!( - assistant_message.content.len(), - 2, - "Both the signed thinking block and text should be preserved" - ); - assert!(matches!( - &assistant_message.content[0], - anthropic::RequestContent::Thinking { thinking, signature, .. 
} - if thinking == "Completed thinking" && signature == "valid-signature" - )); - } - - #[test] - fn test_only_unsigned_thinking_block_omits_entire_message() { - let result = request_with_assistant_content(vec![MessageContent::Thinking { - text: "Cancelled before any text or signature".to_string(), - signature: None, - }]); - - let assistant_messages: Vec<_> = result - .messages - .iter() - .filter(|m| m.role == anthropic::Role::Assistant) - .collect(); - - assert_eq!( - assistant_messages.len(), - 0, - "An assistant message whose only content was an unsigned thinking block \ - should be omitted entirely" - ); - } -} diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 4320763e2c5c6de7f3fe9238d7a4991565c3bfcd..80c758769cd990c00f5942433143bf6fb2216b7c 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -48,7 +48,7 @@ use ui_input::InputField; use util::ResultExt; use crate::AllLanguageModelSettings; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; actions!(bedrock, [Tab, TabPrev]); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 29623cc998ad0fe933e9a29c45c651f7be010b07..294b44ecae9941481e26c2341018ce584d68b3ec 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,107 +1,93 @@ use ai_onboarding::YoungAccountBanner; -use anthropic::AnthropicModelMode; -use anyhow::{Context as _, Result, anyhow}; -use client::{ - Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls, -}; -use cloud_api_types::{OrganizationId, Plan}; -use cloud_llm_client::{ - CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, - CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, - CountTokensBody, CountTokensResponse, ListModelsResponse, - SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME, -}; -use futures::{ - AsyncBufReadExt, FutureExt, Stream, StreamExt, - future::BoxFuture, - stream::{self, BoxStream}, -}; -use google_ai::GoogleModelMode; -use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task}; -use http_client::http::{HeaderMap, HeaderValue}; -use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode}; +use anyhow::Result; +use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls}; +use cloud_api_client::LlmApiToken; +use cloud_api_types::OrganizationId; +use cloud_api_types::Plan; +use futures::StreamExt; +use futures::future::BoxFuture; +use gpui::AsyncApp; +use gpui::{AnyElement, AnyView, App, Context, Entity, Subscription, Task}; use language_model::{ - ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, AuthenticateError, GOOGLE_PROVIDER_ID, - GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel, - LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID, - OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, - 
ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME, + AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID, + ZED_CLOUD_PROVIDER_NAME, }; +use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider}; use release_channel::AppVersion; -use schemars::JsonSchema; -use semver::Version; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; + use settings::SettingsStore; pub use settings::ZedDotDevAvailableModel as AvailableModel; pub use settings::ZedDotDevAvailableProvider as AvailableProvider; -use smol::io::{AsyncReadExt, BufReader}; -use std::collections::VecDeque; -use std::pin::Pin; -use std::str::FromStr; use std::sync::Arc; -use std::task::Poll; -use std::time::Duration; -use thiserror::Error; use ui::{TintColor, prelude::*}; -use crate::provider::anthropic::{ - AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, -}; -use crate::provider::google::{GoogleEventMapper, into_google}; -use crate::provider::open_ai::{ - OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai, - into_open_ai_response, -}; -use crate::provider::x_ai::count_xai_tokens; - const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME; -#[derive(Default, Clone, Debug, PartialEq)] -pub struct ZedDotDevSettings { - pub available_models: Vec, -} -#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] -#[serde(tag = "type", rename_all = "lowercase")] -pub enum ModelMode { - #[default] - Default, - Thinking { - /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. 
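Because the removed `ModelMode` enum around this hunk carries `#[serde(tag = "type", rename_all = "lowercase")]`, its settings form was an internally tagged JSON object. A self-contained round-trip sketch, assuming `u64` for the budget (the generic parameter is elided in this diff):

    use serde::Deserialize;

    #[derive(Debug, Default, PartialEq, Deserialize)]
    #[serde(tag = "type", rename_all = "lowercase")]
    enum ModelMode {
        #[default]
        Default,
        Thinking { budget_tokens: Option<u64> },
    }

    fn main() {
        // The tag field selects the variant; remaining keys fill its fields.
        let mode: ModelMode =
            serde_json::from_str(r#"{ "type": "thinking", "budget_tokens": 4096 }"#).unwrap();
        assert_eq!(mode, ModelMode::Thinking { budget_tokens: Some(4096) });
    }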
- budget_tokens: Option, - }, +struct ClientTokenProvider { + client: Arc, + llm_api_token: LlmApiToken, + user_store: Entity, } -impl From for AnthropicModelMode { - fn from(value: ModelMode) -> Self { - match value { - ModelMode::Default => AnthropicModelMode::Default, - ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens }, - } +impl CloudLlmTokenProvider for ClientTokenProvider { + type AuthContext = Option; + + fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext { + self.user_store.read_with(cx, |user_store, _| { + user_store + .current_organization() + .map(|organization| organization.id.clone()) + }) } + + fn acquire_token( + &self, + organization_id: Self::AuthContext, + ) -> BoxFuture<'static, Result> { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + Box::pin(async move { + client + .acquire_llm_token(&llm_api_token, organization_id) + .await + }) + } + + fn refresh_token( + &self, + organization_id: Self::AuthContext, + ) -> BoxFuture<'static, Result> { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + Box::pin(async move { + client + .refresh_llm_token(&llm_api_token, organization_id) + .await + }) + } +} + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct ZedDotDevSettings { + pub available_models: Vec, } pub struct CloudLanguageModelProvider { - client: Arc, state: Entity, _maintain_client_status: Task<()>, } pub struct State { client: Arc, - llm_api_token: LlmApiToken, user_store: Entity, status: client::Status, - models: Vec>, - default_model: Option>, - default_fast_model: Option>, - recommended_models: Vec>, + provider: Entity>, _user_store_subscription: Subscription, _settings_subscription: Subscription, _llm_token_subscription: Subscription, + _provider_subscription: Subscription, } impl State { @@ -112,16 +98,26 @@ impl State { cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - let llm_api_token = global_llm_token(cx); + let token_provider = Arc::new(ClientTokenProvider { + client: client.clone(), + llm_api_token: global_llm_token(cx), + user_store: user_store.clone(), + }); + + let provider = cx.new(|cx| { + CloudModelProvider::new( + token_provider.clone(), + client.http_client(), + Some(AppVersion::global(cx)), + ) + }); + Self { client: client.clone(), - llm_api_token, user_store: user_store.clone(), status, - models: Vec::new(), - default_model: None, - default_fast_model: None, - recommended_models: Vec::new(), + _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()), + provider, _user_store_subscription: cx.subscribe( &user_store, move |this, _user_store, event, cx| match event { @@ -131,19 +127,7 @@ impl State { return; } - let client = this.client.clone(); - let llm_api_token = this.llm_api_token.clone(); - let organization_id = this - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - cx.spawn(async move |this, cx| { - let response = - Self::fetch_models(client, llm_api_token, organization_id).await?; - this.update(cx, |this, cx| this.update_models(response, cx)) - }) - .detach_and_log_err(cx); + this.refresh_models(cx); } _ => {} }, @@ -154,21 +138,7 @@ impl State { _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, move |this, _listener, _event, cx| { - let client = this.client.clone(); - let llm_api_token = this.llm_api_token.clone(); - let organization_id = this - .user_store - .read(cx) - 
.current_organization() - .map(|organization| organization.id.clone()); - cx.spawn(async move |this, cx| { - let response = - Self::fetch_models(client, llm_api_token, organization_id).await?; - this.update(cx, |this, cx| { - this.update_models(response, cx); - }) - }) - .detach_and_log_err(cx); + this.refresh_models(cx); }, ), } @@ -186,74 +156,10 @@ impl State { }) } - fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context) { - let mut models = Vec::new(); - - for model in response.models { - models.push(Arc::new(model.clone())); - } - - self.default_model = models - .iter() - .find(|model| { - response - .default_model - .as_ref() - .is_some_and(|default_model_id| &model.id == default_model_id) - }) - .cloned(); - self.default_fast_model = models - .iter() - .find(|model| { - response - .default_fast_model - .as_ref() - .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id) - }) - .cloned(); - self.recommended_models = response - .recommended_models - .iter() - .filter_map(|id| models.iter().find(|model| &model.id == id)) - .cloned() - .collect(); - self.models = models; - cx.notify(); - } - - async fn fetch_models( - client: Arc, - llm_api_token: LlmApiToken, - organization_id: Option, - ) -> Result { - let http_client = &client.http_client(); - let token = client - .acquire_llm_token(&llm_api_token, organization_id) - .await?; - - let request = http_client::Request::builder() - .method(Method::GET) - .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true") - .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref()) - .header("Authorization", format!("Bearer {token}")) - .body(AsyncBody::empty())?; - let mut response = http_client - .send(request) - .await - .context("failed to send list models request")?; - - if response.status().is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - Ok(serde_json::from_str(&body)?) 
- } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - anyhow::bail!( - "error listing models.\nStatus: {:?}\nBody: {body}", - response.status(), - ); - } + fn refresh_models(&mut self, cx: &mut Context) { + self.provider.update(cx, |provider, cx| { + provider.refresh_models(cx).detach_and_log_err(cx); + }); } } @@ -281,27 +187,10 @@ impl CloudLanguageModelProvider { }); Self { - client, state, _maintain_client_status: maintain_client_status, } } - - fn create_language_model( - &self, - model: Arc, - llm_api_token: LlmApiToken, - user_store: Entity, - ) -> Arc { - Arc::new(CloudLanguageModel { - id: LanguageModelId(SharedString::from(model.id.0.clone())), - model, - llm_api_token, - user_store, - client: self.client.clone(), - request_limiter: RateLimiter::new(4), - }) - } } impl LanguageModelProviderState for CloudLanguageModelProvider { @@ -327,45 +216,35 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn default_model(&self, cx: &App) -> Option> { let state = self.state.read(cx); - let default_model = state.default_model.clone()?; - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - Some(self.create_language_model(default_model, llm_api_token, user_store)) + let provider = state.provider.read(cx); + let model = provider.default_model()?; + Some(provider.create_model(model)) } fn default_fast_model(&self, cx: &App) -> Option> { let state = self.state.read(cx); - let default_fast_model = state.default_fast_model.clone()?; - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - Some(self.create_language_model(default_fast_model, llm_api_token, user_store)) + let provider = state.provider.read(cx); + let model = provider.default_fast_model()?; + Some(provider.create_model(model)) } fn recommended_models(&self, cx: &App) -> Vec> { let state = self.state.read(cx); - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - state - .recommended_models + let provider = state.provider.read(cx); + provider + .recommended_models() .iter() - .cloned() - .map(|model| { - self.create_language_model(model, llm_api_token.clone(), user_store.clone()) - }) + .map(|model| provider.create_model(model)) .collect() } fn provided_models(&self, cx: &App) -> Vec> { let state = self.state.read(cx); - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - state - .models + let provider = state.provider.read(cx); + provider + .models() .iter() - .cloned() - .map(|model| { - self.create_language_model(model, llm_api_token.clone(), user_store.clone()) - }) + .map(|model| provider.create_model(model)) .collect() } @@ -393,700 +272,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider { } } -pub struct CloudLanguageModel { - id: LanguageModelId, - model: Arc, - llm_api_token: LlmApiToken, - user_store: Entity, - client: Arc, - request_limiter: RateLimiter, -} - -struct PerformLlmCompletionResponse { - response: Response, - includes_status_messages: bool, -} - -impl CloudLanguageModel { - async fn perform_llm_completion( - client: Arc, - llm_api_token: LlmApiToken, - organization_id: Option, - app_version: Option, - body: CompletionBody, - ) -> Result { - let http_client = &client.http_client(); - - let mut token = client - .acquire_llm_token(&llm_api_token, organization_id.clone()) - .await?; - let mut refreshed_token = false; - - loop { - let request = http_client::Request::builder() - 
.method(Method::POST)
-                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
-                .when_some(app_version.as_ref(), |builder, app_version| {
-                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
-                })
-                .header("Content-Type", "application/json")
-                .header("Authorization", format!("Bearer {token}"))
-                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
-                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
-                .body(serde_json::to_string(&body)?.into())?;
-
-            let mut response = http_client.send(request).await?;
-            let status = response.status();
-            if status.is_success() {
-                let includes_status_messages = response
-                    .headers()
-                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
-                    .is_some();
-
-                return Ok(PerformLlmCompletionResponse {
-                    response,
-                    includes_status_messages,
-                });
-            }
-
-            if !refreshed_token && response.needs_llm_token_refresh() {
-                token = client
-                    .refresh_llm_token(&llm_api_token, organization_id.clone())
-                    .await?;
-                refreshed_token = true;
-                continue;
-            }
-
-            if status == StatusCode::PAYMENT_REQUIRED {
-                return Err(anyhow!(PaymentRequiredError));
-            }
-
-            let mut body = String::new();
-            let headers = response.headers().clone();
-            response.body_mut().read_to_string(&mut body).await?;
-            return Err(anyhow!(ApiError {
-                status,
-                body,
-                headers
-            }));
-        }
-    }
-}
-
-#[derive(Debug, Error)]
-#[error("cloud language model request failed with status {status}: {body}")]
-struct ApiError {
-    status: StatusCode,
-    body: String,
-    headers: HeaderMap,
-}
-
-/// Represents error responses from Zed's cloud API.
-///
-/// Example JSON for an upstream HTTP error:
-/// ```json
-/// {
-///   "code": "upstream_http_error",
-///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
-///   "upstream_status": 503
-/// }
-/// ```
-#[derive(Debug, serde::Deserialize)]
-struct CloudApiError {
-    code: String,
-    message: String,
-    #[serde(default)]
-    #[serde(deserialize_with = "deserialize_optional_status_code")]
-    upstream_status: Option<StatusCode>,
-    #[serde(default)]
-    retry_after: Option<f64>,
-}
-
-fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
-where
-    D: serde::Deserializer<'de>,
-{
-    let opt: Option<u16> = Option::deserialize(deserializer)?;
-    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
-}
-
-impl From<ApiError> for LanguageModelCompletionError {
-    fn from(error: ApiError) -> Self {
-        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
-            if cloud_error.code.starts_with("upstream_http_") {
-                let status = if let Some(status) = cloud_error.upstream_status {
-                    status
-                } else if cloud_error.code.ends_with("_error") {
-                    error.status
-                } else {
-                    // If there's a status code embedded in the code string
-                    // (e.g. "upstream_http_429"), use that; otherwise fall
-                    // back to the status of the response itself.
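A compact, standalone illustration of the status-resolution order described above (hypothetical helper; it folds the `_error` suffix branch into the final fallback, which yields the same result, since "error" never parses as a number):

    use http::StatusCode;

    // Resolution order: explicit `upstream_status` field, then a status
    // embedded in the error code string, then the response's own status.
    fn resolve_upstream_status(
        code: &str,
        upstream_status: Option<StatusCode>,
        response_status: StatusCode,
    ) -> StatusCode {
        upstream_status
            .or_else(|| {
                code.strip_prefix("upstream_http_")?
                    .parse::<u16>()
                    .ok()
                    .and_then(|n| StatusCode::from_u16(n).ok())
            })
            .unwrap_or(response_status)
    }

    fn main() {
        let status = resolve_upstream_status(
            "upstream_http_429",
            None,
            StatusCode::INTERNAL_SERVER_ERROR,
        );
        assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
    }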
- cloud_error - .code - .strip_prefix("upstream_http_") - .and_then(|code_str| code_str.parse::().ok()) - .and_then(|code| StatusCode::from_u16(code).ok()) - .unwrap_or(error.status) - }; - - return LanguageModelCompletionError::UpstreamProviderError { - message: cloud_error.message, - status, - retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), - }; - } - - return LanguageModelCompletionError::from_http_status( - PROVIDER_NAME, - error.status, - cloud_error.message, - None, - ); - } - - let retry_after = None; - LanguageModelCompletionError::from_http_status( - PROVIDER_NAME, - error.status, - error.body, - retry_after, - ) - } -} - -impl LanguageModel for CloudLanguageModel { - fn id(&self) -> LanguageModelId { - self.id.clone() - } - - fn name(&self) -> LanguageModelName { - LanguageModelName::from(self.model.display_name.clone()) - } - - fn provider_id(&self) -> LanguageModelProviderId { - PROVIDER_ID - } - - fn provider_name(&self) -> LanguageModelProviderName { - PROVIDER_NAME - } - - fn upstream_provider_id(&self) -> LanguageModelProviderId { - use cloud_llm_client::LanguageModelProvider::*; - match self.model.provider { - Anthropic => ANTHROPIC_PROVIDER_ID, - OpenAi => OPEN_AI_PROVIDER_ID, - Google => GOOGLE_PROVIDER_ID, - XAi => X_AI_PROVIDER_ID, - } - } - - fn upstream_provider_name(&self) -> LanguageModelProviderName { - use cloud_llm_client::LanguageModelProvider::*; - match self.model.provider { - Anthropic => ANTHROPIC_PROVIDER_NAME, - OpenAi => OPEN_AI_PROVIDER_NAME, - Google => GOOGLE_PROVIDER_NAME, - XAi => X_AI_PROVIDER_NAME, - } - } - - fn is_latest(&self) -> bool { - self.model.is_latest - } - - fn supports_tools(&self) -> bool { - self.model.supports_tools - } - - fn supports_images(&self) -> bool { - self.model.supports_images - } - - fn supports_thinking(&self) -> bool { - self.model.supports_thinking - } - - fn supports_fast_mode(&self) -> bool { - self.model.supports_fast_mode - } - - fn supported_effort_levels(&self) -> Vec { - self.model - .supported_effort_levels - .iter() - .map(|effort_level| LanguageModelEffortLevel { - name: effort_level.name.clone().into(), - value: effort_level.value.clone().into(), - is_default: effort_level.is_default.unwrap_or(false), - }) - .collect() - } - - fn supports_streaming_tools(&self) -> bool { - self.model.supports_streaming_tools - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - match choice { - LanguageModelToolChoice::Auto - | LanguageModelToolChoice::Any - | LanguageModelToolChoice::None => true, - } - } - - fn supports_split_token_display(&self) -> bool { - use cloud_llm_client::LanguageModelProvider::*; - matches!(self.model.provider, OpenAi | XAi) - } - - fn telemetry_id(&self) -> String { - format!("zed.dev/{}", self.model.id) - } - - fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { - match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic - | cloud_llm_client::LanguageModelProvider::OpenAi => { - LanguageModelToolSchemaFormat::JsonSchema - } - cloud_llm_client::LanguageModelProvider::Google - | cloud_llm_client::LanguageModelProvider::XAi => { - LanguageModelToolSchemaFormat::JsonSchemaSubset - } - } - } - - fn max_token_count(&self) -> u64 { - self.model.max_token_count as u64 - } - - fn max_output_tokens(&self) -> Option { - Some(self.model.max_output_tokens as u64) - } - - fn cache_configuration(&self) -> Option { - match &self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { - 
Some(LanguageModelCacheConfiguration { - min_total_token: 2_048, - should_speculate: true, - max_cache_anchors: 4, - }) - } - cloud_llm_client::LanguageModelProvider::OpenAi - | cloud_llm_client::LanguageModelProvider::XAi - | cloud_llm_client::LanguageModelProvider::Google => None, - } - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => cx - .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) - .boxed(), - cloud_llm_client::LanguageModelProvider::OpenAi => { - let model = match open_ai::Model::from_id(&self.model.id.0) { - Ok(model) => model, - Err(err) => return async move { Err(anyhow!(err)) }.boxed(), - }; - count_open_ai_tokens(request, model, cx) - } - cloud_llm_client::LanguageModelProvider::XAi => { - let model = match x_ai::Model::from_id(&self.model.id.0) { - Ok(model) => model, - Err(err) => return async move { Err(anyhow!(err)) }.boxed(), - }; - count_xai_tokens(request, model, cx) - } - cloud_llm_client::LanguageModelProvider::Google => { - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = self - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - let model_id = self.model.id.to_string(); - let generate_content_request = - into_google(request, model_id.clone(), GoogleModelMode::Default); - async move { - let http_client = &client.http_client(); - let token = client - .acquire_llm_token(&llm_api_token, organization_id) - .await?; - - let request_body = CountTokensBody { - provider: cloud_llm_client::LanguageModelProvider::Google, - model: model_id, - provider_request: serde_json::to_value(&google_ai::CountTokensRequest { - generate_content_request, - })?, - }; - let request = http_client::Request::builder() - .method(Method::POST) - .uri( - http_client - .build_zed_llm_url("/count_tokens", &[])? 
- .as_ref(), - ) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {token}")) - .body(serde_json::to_string(&request_body)?.into())?; - let mut response = http_client.send(request).await?; - let status = response.status(); - let headers = response.headers().clone(); - let mut response_body = String::new(); - response - .body_mut() - .read_to_string(&mut response_body) - .await?; - - if status.is_success() { - let response_body: CountTokensResponse = - serde_json::from_str(&response_body)?; - - Ok(response_body.tokens as u64) - } else { - Err(anyhow!(ApiError { - status, - body: response_body, - headers - })) - } - } - .boxed() - } - } - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream<'static, Result>, - LanguageModelCompletionError, - >, - > { - let thread_id = request.thread_id.clone(); - let prompt_id = request.prompt_id.clone(); - let app_version = Some(cx.update(|cx| AppVersion::global(cx))); - let user_store = self.user_store.clone(); - let organization_id = cx.update(|cx| { - user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()) - }); - let thinking_allowed = request.thinking_allowed; - let enable_thinking = thinking_allowed && self.model.supports_thinking; - let provider_name = provider_name(&self.model.provider); - match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { - let effort = request - .thinking_effort - .as_ref() - .and_then(|effort| anthropic::Effort::from_str(effort).ok()); - - let mut request = into_anthropic( - request, - self.model.id.to_string(), - 1.0, - self.model.max_output_tokens as u64, - if enable_thinking { - AnthropicModelMode::Thinking { - budget_tokens: Some(4_096), - } - } else { - AnthropicModelMode::Default - }, - ); - - if enable_thinking && effort.is_some() { - request.thinking = Some(anthropic::Thinking::Adaptive); - request.output_config = Some(anthropic::OutputConfig { effort }); - } - - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = organization_id.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::Anthropic, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await - .map_err(|err| match err.downcast::() { - Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)), - Err(err) => anyhow!(err), - })?; - - let mut mapper = AnthropicEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::OpenAi => { - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = organization_id.clone(); - let effort = request - .thinking_effort - .as_ref() - .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok()); - - let mut request = into_open_ai_response( - request, - &self.model.id.0, - self.model.supports_parallel_tool_calls, - true, - None, - None, - ); 
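The guard just below only attaches a reasoning config when thinking is enabled and the effort string parses. The shape of that check, distilled into a runnable sketch (local stand-in types, not the `open_ai` crate's):

    use std::str::FromStr;

    #[derive(Debug, PartialEq)]
    enum ReasoningEffort {
        Low,
        Medium,
        High,
    }

    impl FromStr for ReasoningEffort {
        type Err = ();
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "low" => Ok(Self::Low),
                "medium" => Ok(Self::Medium),
                "high" => Ok(Self::High),
                _ => Err(()),
            }
        }
    }

    // Attach reasoning only when the model supports thinking *and* the
    // requested effort is recognized; otherwise leave the request as-is.
    fn reasoning_for(enable_thinking: bool, effort: Option<&str>) -> Option<ReasoningEffort> {
        if !enable_thinking {
            return None;
        }
        effort.and_then(|e| ReasoningEffort::from_str(e).ok())
    }

    fn main() {
        assert_eq!(reasoning_for(true, Some("high")), Some(ReasoningEffort::High));
        assert_eq!(reasoning_for(false, Some("high")), None);
        assert_eq!(reasoning_for(true, Some("ultra")), None);
    }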
- - if enable_thinking && let Some(effort) = effort { - request.reasoning = Some(open_ai::responses::ReasoningConfig { - effort, - summary: Some(open_ai::responses::ReasoningSummaryMode::Auto), - }); - } - - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::OpenAi, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = OpenAiResponseEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::XAi => { - let client = self.client.clone(); - let request = into_open_ai( - request, - &self.model.id.0, - self.model.supports_parallel_tool_calls, - false, - None, - None, - ); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = organization_id.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::XAi, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = OpenAiEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::Google => { - let client = self.client.clone(); - let request = - into_google(request, self.model.id.to_string(), GoogleModelMode::Default); - let llm_api_token = self.llm_api_token.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::Google, - model: request.model.model_id.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = GoogleEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - } - } -} - -fn map_cloud_completion_events( - stream: Pin>> + Send>>, - provider: &LanguageModelProviderName, - mut map_callback: F, -) -> BoxStream<'static, Result> -where - T: DeserializeOwned + 'static, - F: FnMut(T) -> Vec> - + Send - + 'static, -{ - let provider = provider.clone(); - let mut stream = stream.fuse(); - - let mut saw_stream_ended = false; - - let mut done = false; - let mut pending = VecDeque::new(); - - stream::poll_fn(move |cx| { - loop { - if let Some(item) = pending.pop_front() { - return Poll::Ready(Some(item)); 
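The adapter around this hunk drains a `VecDeque` one item per poll before polling the inner stream again, a standard way to flatten batched events. The bare mechanic, separated out as a generic sketch without the error mapping or stream-ended bookkeeping:

    use std::collections::VecDeque;
    use std::task::Poll;

    use futures::{Stream, StreamExt, stream};

    // Flatten a stream of batches into a stream of items, draining the
    // pending queue before polling the inner stream again.
    fn flatten_batches<T, S>(inner: S) -> impl Stream<Item = T>
    where
        S: Stream<Item = Vec<T>> + Unpin,
    {
        let mut inner = inner.fuse();
        let mut pending = VecDeque::new();
        let mut done = false;
        stream::poll_fn(move |cx| loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }
            if done {
                return Poll::Ready(None);
            }
            match inner.poll_next_unpin(cx) {
                Poll::Ready(Some(batch)) => pending.extend(batch),
                Poll::Ready(None) => done = true,
                Poll::Pending => return Poll::Pending,
            }
        })
    }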
- } - - if done { - return Poll::Ready(None); - } - - match stream.poll_next_unpin(cx) { - Poll::Ready(Some(event)) => { - let items = match event { - Err(error) => { - vec![Err(LanguageModelCompletionError::from(error))] - } - Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => { - saw_stream_ended = true; - vec![] - } - Ok(CompletionEvent::Status(status)) => { - LanguageModelCompletionEvent::from_completion_request_status( - status, - provider.clone(), - ) - .transpose() - .map(|event| vec![event]) - .unwrap_or_default() - } - Ok(CompletionEvent::Event(event)) => map_callback(event), - }; - pending.extend(items); - } - Poll::Ready(None) => { - done = true; - - if !saw_stream_ended { - return Poll::Ready(Some(Err( - LanguageModelCompletionError::StreamEndedUnexpectedly { - provider: provider.clone(), - }, - ))); - } - } - Poll::Pending => return Poll::Pending, - } - } - }) - .boxed() -} - -fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName { - match provider { - cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME, - } -} - -fn response_lines( - response: Response, - includes_status_messages: bool, -) -> impl Stream>> { - futures::stream::try_unfold( - (String::new(), BufReader::new(response.into_body())), - move |(mut line, mut body)| async move { - match body.read_line(&mut line).await { - Ok(0) => Ok(None), - Ok(_) => { - let event = if includes_status_messages { - serde_json::from_str::>(&line)? - } else { - CompletionEvent::Event(serde_json::from_str::(&line)?) - }; - - line.clear(); - Ok(Some((event, (line, body)))) - } - Err(e) => Err(e.into()), - } - }, - ) -} - #[derive(IntoElement, RegisterComponent)] struct ZedAiConfiguration { is_connected: bool, @@ -1281,155 +466,3 @@ impl Component for ZedAiConfiguration { ) } } - -#[cfg(test)] -mod tests { - use super::*; - use http_client::http::{HeaderMap, StatusCode}; - use language_model::LanguageModelCompletionError; - - #[test] - fn test_api_error_conversion_with_upstream_http_error() { - // upstream_http_error with 503 status should become ServerOverloaded - let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { message, .. 
} => { - assert_eq!( - message, - "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout" - ); - } - _ => panic!( - "Expected UpstreamProviderError for upstream 503, got: {:?}", - completion_error - ), - } - - // upstream_http_error with 500 status should become ApiInternalServerError - let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { message, .. } => { - assert_eq!( - message, - "Received an error from the OpenAI API: internal server error" - ); - } - _ => panic!( - "Expected UpstreamProviderError for upstream 500, got: {:?}", - completion_error - ), - } - - // upstream_http_error with 429 status should become RateLimitExceeded - let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { message, .. } => { - assert_eq!( - message, - "Received an error from the Google API: rate limit exceeded" - ); - } - _ => panic!( - "Expected UpstreamProviderError for upstream 429, got: {:?}", - completion_error - ), - } - - // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed - let error_body = "Regular internal server error"; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::ApiInternalServerError { provider, message } => { - assert_eq!(provider, PROVIDER_NAME); - assert_eq!(message, "Regular internal server error"); - } - _ => panic!( - "Expected ApiInternalServerError for regular 500, got: {:?}", - completion_error - ), - } - - // upstream_http_429 format should be converted to UpstreamProviderError - let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { - message, - status, - retry_after, - } => { - assert_eq!(message, "Upstream Anthropic rate limit exceeded."); - assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); - assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5))); - } - _ => panic!( - "Expected UpstreamProviderError for upstream_http_429, got: {:?}", - completion_error - ), - } - - // Invalid JSON in error body should fall back to regular error handling - let error_body = "Not JSON at all"; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - 
headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::ApiInternalServerError { provider, .. } => { - assert_eq!(provider, PROVIDER_NAME); - } - _ => panic!( - "Expected ApiInternalServerError for invalid JSON, got: {:?}", - completion_error - ), - } - } -} diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index a2d39e1945e2791d9d5c998cc717a07498ebc157..a77e3f880be18d8f9f0e97ec8717c32bc780e267 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -32,7 +32,7 @@ use ui::prelude::*; use util::debug_panic; use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic}; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat"); const PROVIDER_NAME: LanguageModelProviderName = @@ -268,15 +268,15 @@ impl LanguageModel for CopilotChatLanguageModel { levels .iter() .map(|level| { - let name: SharedString = match level.as_str() { + let name = match level.as_str() { "low" => "Low".into(), "medium" => "Medium".into(), "high" => "High".into(), - _ => SharedString::from(level.clone()), + _ => language_model::SharedString::from(level.clone()), }; LanguageModelEffortLevel { name, - value: SharedString::from(level.clone()), + value: language_model::SharedString::from(level.clone()), is_default: level == "high", } }) diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 0cfb1af425c7cb0279d98fa124a589437f1bb1a1..f3dccd5cc1a2e1a5ddfe2bc6b43901f2b549e532 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek"); diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 244f7835a85ff67f0c4826321910ea13516371cb..92278839c6ff5119849f8881409928686f055331 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -1,32 +1,25 @@ use anyhow::{Context as _, Result}; use collections::BTreeMap; use credentials_provider::CredentialsProvider; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; -use google_ai::{ - FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction, - ThinkingConfig, UsageMetadata, -}; +use futures::{FutureExt, StreamExt, future::BoxFuture}; +pub use google_ai::completion::{GoogleEventMapper, count_google_tokens, into_google}; +use google_ai::{GenerateContentResponse, GoogleModelMode}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelToolChoice, 
LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason, }; use language_model::{ GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, + LanguageModelProviderState, LanguageModelRequest, RateLimiter, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; pub use settings::GoogleAvailableModel as AvailableModel; use settings::{Settings, SettingsStore}; -use std::pin::Pin; -use std::sync::{ - Arc, LazyLock, - atomic::{self, AtomicU64}, -}; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; @@ -394,369 +387,6 @@ impl LanguageModel for GoogleLanguageModel { } } -pub fn into_google( - mut request: LanguageModelRequest, - model_id: String, - mode: GoogleModelMode, -) -> google_ai::GenerateContentRequest { - fn map_content(content: Vec) -> Vec { - content - .into_iter() - .flat_map(|content| match content { - language_model::MessageContent::Text(text) => { - if !text.is_empty() { - vec![Part::TextPart(google_ai::TextPart { text })] - } else { - vec![] - } - } - language_model::MessageContent::Thinking { - text: _, - signature: Some(signature), - } => { - if !signature.is_empty() { - vec![Part::ThoughtPart(google_ai::ThoughtPart { - thought: true, - thought_signature: signature, - })] - } else { - vec![] - } - } - language_model::MessageContent::Thinking { .. } => { - vec![] - } - language_model::MessageContent::RedactedThinking(_) => vec![], - language_model::MessageContent::Image(image) => { - vec![Part::InlineDataPart(google_ai::InlineDataPart { - inline_data: google_ai::GenerativeContentBlob { - mime_type: "image/png".to_string(), - data: image.source.to_string(), - }, - })] - } - language_model::MessageContent::ToolUse(tool_use) => { - // Normalize empty string signatures to None - let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty()); - - vec![Part::FunctionCallPart(google_ai::FunctionCallPart { - function_call: google_ai::FunctionCall { - name: tool_use.name.to_string(), - args: tool_use.input, - }, - thought_signature, - })] - } - language_model::MessageContent::ToolResult(tool_result) => { - match tool_result.content { - language_model::LanguageModelToolResultContent::Text(text) => { - vec![Part::FunctionResponsePart( - google_ai::FunctionResponsePart { - function_response: google_ai::FunctionResponse { - name: tool_result.tool_name.to_string(), - // The API expects a valid JSON object - response: serde_json::json!({ - "output": text - }), - }, - }, - )] - } - language_model::LanguageModelToolResultContent::Image(image) => { - vec![ - Part::FunctionResponsePart(google_ai::FunctionResponsePart { - function_response: google_ai::FunctionResponse { - name: tool_result.tool_name.to_string(), - // The API expects a valid JSON object - response: serde_json::json!({ - "output": "Tool responded with an image" - }), - }, - }), - Part::InlineDataPart(google_ai::InlineDataPart { - inline_data: google_ai::GenerativeContentBlob { - mime_type: "image/png".to_string(), - data: image.source.to_string(), - }, - }), - ] - } - } - } - }) - .collect() - } - - let system_instructions = if request - .messages - .first() - .is_some_and(|msg| matches!(msg.role, Role::System)) - { - let message = request.messages.remove(0); - 
Some(SystemInstruction { - parts: map_content(message.content), - }) - } else { - None - }; - - google_ai::GenerateContentRequest { - model: google_ai::ModelName { model_id }, - system_instruction: system_instructions, - contents: request - .messages - .into_iter() - .filter_map(|message| { - let parts = map_content(message.content); - if parts.is_empty() { - None - } else { - Some(google_ai::Content { - parts, - role: match message.role { - Role::User => google_ai::Role::User, - Role::Assistant => google_ai::Role::Model, - Role::System => google_ai::Role::User, // Google AI doesn't have a system role - }, - }) - } - }) - .collect(), - generation_config: Some(google_ai::GenerationConfig { - candidate_count: Some(1), - stop_sequences: Some(request.stop), - max_output_tokens: None, - temperature: request.temperature.map(|t| t as f64).or(Some(1.0)), - thinking_config: match (request.thinking_allowed, mode) { - (true, GoogleModelMode::Thinking { budget_tokens }) => { - budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget }) - } - _ => None, - }, - top_p: None, - top_k: None, - }), - safety_settings: None, - tools: (!request.tools.is_empty()).then(|| { - vec![google_ai::Tool { - function_declarations: request - .tools - .into_iter() - .map(|tool| FunctionDeclaration { - name: tool.name, - description: tool.description, - parameters: tool.input_schema, - }) - .collect(), - }] - }), - tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig { - function_calling_config: google_ai::FunctionCallingConfig { - mode: match choice { - LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto, - LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any, - LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None, - }, - allowed_function_names: None, - }, - }), - } -} - -pub struct GoogleEventMapper { - usage: UsageMetadata, - stop_reason: StopReason, -} - -impl GoogleEventMapper { - pub fn new() -> Self { - Self { - usage: UsageMetadata::default(), - stop_reason: StopReason::EndTurn, - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events - .map(Some) - .chain(futures::stream::once(async { None })) - .flat_map(move |event| { - futures::stream::iter(match event { - Some(Ok(event)) => self.map_event(event), - Some(Err(error)) => { - vec![Err(LanguageModelCompletionError::from(error))] - } - None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))], - }) - }) - } - - pub fn map_event( - &mut self, - event: GenerateContentResponse, - ) -> Vec> { - static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); - - let mut events: Vec<_> = Vec::new(); - let mut wants_to_use_tool = false; - if let Some(usage_metadata) = event.usage_metadata { - update_usage(&mut self.usage, &usage_metadata); - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - convert_usage(&self.usage), - ))) - } - - if let Some(prompt_feedback) = event.prompt_feedback - && let Some(block_reason) = prompt_feedback.block_reason.as_deref() - { - self.stop_reason = match block_reason { - "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => { - StopReason::Refusal - } - _ => { - log::error!("Unexpected Google block_reason: {block_reason}"); - StopReason::Refusal - } - }; - events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))); - - return events; - } - - if let Some(candidates) = event.candidates { - for candidate in candidates { - if let Some(finish_reason) = candidate.finish_reason.as_deref() { 
-                    self.stop_reason = match finish_reason {
-                        "STOP" => StopReason::EndTurn,
-                        "MAX_TOKENS" => StopReason::MaxTokens,
-                        _ => {
-                            log::error!("Unexpected google finish_reason: {finish_reason}");
-                            StopReason::EndTurn
-                        }
-                    };
-                }
-                candidate
-                    .content
-                    .parts
-                    .into_iter()
-                    .for_each(|part| match part {
-                        Part::TextPart(text_part) => {
-                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
-                        }
-                        Part::InlineDataPart(_) => {}
-                        Part::FunctionCallPart(function_call_part) => {
-                            wants_to_use_tool = true;
-                            let name: Arc<str> = function_call_part.function_call.name.into();
-                            let next_tool_id =
-                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
-                            let id: LanguageModelToolUseId =
-                                format!("{}-{}", name, next_tool_id).into();
-
-                            // Normalize empty string signatures to None
-                            let thought_signature = function_call_part
-                                .thought_signature
-                                .filter(|s| !s.is_empty());
-
-                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
-                                LanguageModelToolUse {
-                                    id,
-                                    name,
-                                    is_input_complete: true,
-                                    raw_input: function_call_part.function_call.args.to_string(),
-                                    input: function_call_part.function_call.args,
-                                    thought_signature,
-                                },
-                            )));
-                        }
-                        Part::FunctionResponsePart(_) => {}
-                        Part::ThoughtPart(part) => {
-                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
-                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
-                                signature: Some(part.thought_signature),
-                            }));
-                        }
-                    });
-            }
-        }
-
-        // Even when Gemini wants to use a Tool, the API
-        // responds with `finish_reason: STOP`
-        if wants_to_use_tool {
-            self.stop_reason = StopReason::ToolUse;
-            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
-        }
-        events
-    }
-}
-
-pub fn count_google_tokens(
-    request: LanguageModelRequest,
-    cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
-    // We can't count tokens via the GoogleLanguageModelProvider here because
-    // GitHub Copilot doesn't have direct access to google_ai, so we fall back
-    // to a tokenizer from tiktoken_rs.
-    cx.background_spawn(async move {
-        let messages = request
-            .messages
-            .into_iter()
-            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-                role: match message.role {
-                    Role::User => "user".into(),
-                    Role::Assistant => "assistant".into(),
-                    Role::System => "system".into(),
-                },
-                content: Some(message.string_contents()),
-                name: None,
-                function_call: None,
-            })
-            .collect::<Vec<_>>();
-
-        // Tiktoken doesn't yet support these models, so we manually use the
-        // same tokenizer as GPT-4.
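For reference, `tiktoken_rs::num_tokens_from_messages` takes a model name plus OpenAI-style chat messages, which is why the request is first flattened into that shape above. A minimal standalone call (the message content is illustrative):

    use tiktoken_rs::ChatCompletionRequestMessage;

    fn main() {
        let messages = vec![ChatCompletionRequestMessage {
            role: "user".into(),
            content: Some("Hello there".into()),
            name: None,
            function_call: None,
        }];
        // GPT-4's tokenizer stands in for models tiktoken doesn't know
        // about, matching the fallback used throughout this file.
        let tokens = tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
            .expect("tokenizer should be available");
        println!("{tokens} tokens");
    }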
- tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) - }) - .boxed() -} - -fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) { - if let Some(prompt_token_count) = new.prompt_token_count { - usage.prompt_token_count = Some(prompt_token_count); - } - if let Some(cached_content_token_count) = new.cached_content_token_count { - usage.cached_content_token_count = Some(cached_content_token_count); - } - if let Some(candidates_token_count) = new.candidates_token_count { - usage.candidates_token_count = Some(candidates_token_count); - } - if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count { - usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count); - } - if let Some(thoughts_token_count) = new.thoughts_token_count { - usage.thoughts_token_count = Some(thoughts_token_count); - } - if let Some(total_token_count) = new.total_token_count { - usage.total_token_count = Some(total_token_count); - } -} - -fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage { - let prompt_tokens = usage.prompt_token_count.unwrap_or(0); - let cached_tokens = usage.cached_content_token_count.unwrap_or(0); - let input_tokens = prompt_tokens - cached_tokens; - let output_tokens = usage.candidates_token_count.unwrap_or(0); - - language_model::TokenUsage { - input_tokens, - output_tokens, - cache_read_input_tokens: cached_tokens, - cache_creation_input_tokens: 0, - } -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, @@ -895,428 +525,3 @@ impl Render for ConfigurationView { } } } - -#[cfg(test)] -mod tests { - use super::*; - use google_ai::{ - Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse, - Part, Role as GoogleRole, TextPart, - }; - use language_model::{LanguageModelToolUseId, MessageContent, Role}; - use serde_json::json; - - #[test] - fn test_function_call_with_signature_creates_tool_use_with_signature() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("test_signature_123".to_string()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - assert_eq!(events.len(), 2); // ToolUse event + Stop event - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.name.as_ref(), "test_function"); - assert_eq!( - tool_use.thought_signature.as_deref(), - Some("test_signature_123") - ); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_function_call_without_signature_has_none() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: None, - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - 
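// [Editor's aside — sketch of the usage arithmetic above, not part of the
// patch.] Google reports `prompt_token_count` inclusive of cached tokens,
// so `convert_usage` subtracts the cached count to recover uncached input
// tokens. The struct below is local to the example:
#[derive(Default)]
struct GoogleUsage {
    prompt_tokens: u64,
    cached_tokens: u64,
    candidate_tokens: u64,
}

/// Returns (input_tokens, output_tokens, cache_read_tokens).
fn split_usage(u: &GoogleUsage) -> (u64, u64, u64) {
    (u.prompt_tokens - u.cached_tokens, u.candidate_tokens, u.cached_tokens)
}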
prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.thought_signature, None); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_empty_string_signature_normalized_to_none() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("".to_string()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.thought_signature, None); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_parallel_function_calls_preserve_signatures() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![ - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "function_1".to_string(), - args: json!({"arg": "value1"}), - }, - thought_signature: Some("signature_1".to_string()), - }), - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "function_2".to_string(), - args: json!({"arg": "value2"}), - }, - thought_signature: None, - }), - ], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.name.as_ref(), "function_1"); - assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1")); - } else { - panic!("Expected ToolUse event for function_1"); - } - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { - assert_eq!(tool_use.name.as_ref(), "function_2"); - assert_eq!(tool_use.thought_signature, None); - } else { - panic!("Expected ToolUse event for function_2"); - } - } - - #[test] - fn test_tool_use_with_signature_converts_to_function_call_part() { - let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_function".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: Some("test_signature_456".to_string()), - }; - - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - assert_eq!(request.contents[0].parts.len(), 1); - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.function_call.name, 
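// [Editor's aside — not part of the patch.] The rule these tests pin down
// is tiny but easy to get wrong on both the inbound and outbound paths: an
// empty-string thought signature is treated the same as an absent one.
fn normalize_signature(sig: Option<String>) -> Option<String> {
    sig.filter(|s| !s.is_empty())
}
// normalize_signature(None)               == None
// normalize_signature(Some("".into()))    == None
// normalize_signature(Some("sig".into())) == Some("sig")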
"test_function"); - assert_eq!( - fc_part.thought_signature.as_deref(), - Some("test_signature_456") - ); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_tool_use_without_signature_omits_field() { - let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_function".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: None, - }; - - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - assert_eq!(request.contents[0].parts.len(), 1); - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.thought_signature, None); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_empty_signature_in_tool_use_normalized_to_none() { - let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_function".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: Some("".to_string()), - }; - - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.thought_signature, None); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_round_trip_preserves_signature() { - let mut mapper = GoogleEventMapper::new(); - - // Simulate receiving a response from Google with a signature - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("round_trip_sig".to_string()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - tool_use.clone() - } else { - panic!("Expected ToolUse event"); - }; - - // Convert back to Google format - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - // Verify signature is preserved - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig")); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn 
test_mixed_text_and_function_call_with_signature() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![ - Part::TextPart(TextPart { - text: "I'll help with that.".to_string(), - }), - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "helper_function".to_string(), - args: json!({"query": "help"}), - }, - thought_signature: Some("mixed_sig".to_string()), - }), - ], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event - - if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] { - assert_eq!(text, "I'll help with that."); - } else { - panic!("Expected Text event"); - } - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { - assert_eq!(tool_use.name.as_ref(), "helper_function"); - assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig")); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_special_characters_in_signature_preserved() { - let mut mapper = GoogleEventMapper::new(); - - let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some(signature_with_special_chars.clone()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!( - tool_use.thought_signature.as_deref(), - Some(signature_with_special_chars.as_str()) - ); - } else { - panic!("Expected ToolUse event"); - } - } -} diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 0d60fef16791087e35bac7d846b2ec99821d5470..a541da8cd8092d5d0fa43af1217c31833f10cdeb 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -28,7 +28,7 @@ use ui::{ use ui_input::InputField; use crate::AllLanguageModelSettings; -use crate::provider::util::parse_tool_arguments; +use language_model::util::parse_tool_arguments; const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download"; const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models"; diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 4cd1375fe50cd792a3a7bc8c85ba7b5b5af9520a..5fef40b2b1badbc77133ebe67fbe0f1fe5521259 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const 
PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral"); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 6a2313487f4a1922cdc2aa20d23ede01c4b7d158..358a0ec5a6d517064be93d973f08eceb894ab665 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,41 +1,33 @@ -use anyhow::{Result, anyhow}; -use collections::{BTreeMap, HashMap}; +use anyhow::Result; +use collections::BTreeMap; use credentials_provider::CredentialsProvider; -use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, - LanguageModelCompletionEvent, LanguageModelId, LanguageModelImage, LanguageModelName, - LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, - LanguageModelToolUseId, MessageContent, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, - RateLimiter, Role, StopReason, TokenUsage, env_var, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, LanguageModelToolChoice, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, + RateLimiter, env_var, }; use menu; -use open_ai::responses::{ - ResponseFunctionCallItem, ResponseFunctionCallOutputContent, ResponseFunctionCallOutputItem, - ResponseInputContent, ResponseInputItem, ResponseMessageItem, -}; use open_ai::{ - ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, - responses::{ - Request as ResponseRequest, ResponseOutputItem, ResponseSummary as ResponsesSummary, - ResponseUsage as ResponsesUsage, StreamEvent as ResponsesStreamEvent, stream_response, - }, + OPEN_AI_API_URL, ResponseStreamEvent, + responses::{Request as ResponseRequest, StreamEvent as ResponsesStreamEvent, stream_response}, stream_completion, }; use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore}; -use std::pin::Pin; use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +pub use open_ai::completion::{ + OpenAiEventMapper, OpenAiResponseEventMapper, collect_tiktoken_messages, count_open_ai_tokens, + into_open_ai, into_open_ai_response, +}; const PROVIDER_ID: LanguageModelProviderId = OPEN_AI_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = OPEN_AI_PROVIDER_NAME; @@ -189,7 +181,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { max_tokens: model.max_tokens, max_output_tokens: model.max_output_tokens, max_completion_tokens: model.max_completion_tokens, - reasoning_effort: model.reasoning_effort.clone(), + reasoning_effort: model.reasoning_effort, supports_chat_completions: model.capabilities.chat_completions, }, ); @@ -382,7 +374,9 @@ impl LanguageModel for OpenAiLanguageModel { request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { 
- count_open_ai_tokens(request, self.model.clone(), cx) + let model = self.model.clone(); + cx.background_spawn(async move { count_open_ai_tokens(request, model) }) + .boxed() } fn stream_completion( @@ -433,853 +427,6 @@ impl LanguageModel for OpenAiLanguageModel { } } -pub fn into_open_ai( - request: LanguageModelRequest, - model_id: &str, - supports_parallel_tool_calls: bool, - supports_prompt_cache_key: bool, - max_output_tokens: Option, - reasoning_effort: Option, -) -> open_ai::Request { - let stream = !model_id.starts_with("o1-"); - - let mut messages = Vec::new(); - for message in request.messages { - for content in message.content { - match content { - MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { - let should_add = if message.role == Role::User { - // Including whitespace-only user messages can cause error with OpenAI compatible APIs - // See https://github.com/zed-industries/zed/issues/40097 - !text.trim().is_empty() - } else { - !text.is_empty() - }; - if should_add { - add_message_content_part( - open_ai::MessagePart::Text { text }, - message.role, - &mut messages, - ); - } - } - MessageContent::RedactedThinking(_) => {} - MessageContent::Image(image) => { - add_message_content_part( - open_ai::MessagePart::Image { - image_url: ImageUrl { - url: image.to_base64_url(), - detail: None, - }, - }, - message.role, - &mut messages, - ); - } - MessageContent::ToolUse(tool_use) => { - let tool_call = open_ai::ToolCall { - id: tool_use.id.to_string(), - content: open_ai::ToolCallContent::Function { - function: open_ai::FunctionContent { - name: tool_use.name.to_string(), - arguments: serde_json::to_string(&tool_use.input) - .unwrap_or_default(), - }, - }, - }; - - if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. 
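// [Editor's aside — sketch of the filtering rule above, not part of the
// patch.] Whitespace-only *user* text is dropped entirely (it can trigger
// errors from OpenAI-compatible servers; see the zed issue #40097 cited
// above), while assistant/system text is only skipped when fully empty:
fn should_add_text(is_user: bool, text: &str) -> bool {
    if is_user {
        !text.trim().is_empty()
    } else {
        !text.is_empty()
    }
}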
}) = - messages.last_mut() - { - tool_calls.push(tool_call); - } else { - messages.push(open_ai::RequestMessage::Assistant { - content: None, - tool_calls: vec![tool_call], - }); - } - } - MessageContent::ToolResult(tool_result) => { - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - vec![open_ai::MessagePart::Text { - text: text.to_string(), - }] - } - LanguageModelToolResultContent::Image(image) => { - vec![open_ai::MessagePart::Image { - image_url: ImageUrl { - url: image.to_base64_url(), - detail: None, - }, - }] - } - }; - - messages.push(open_ai::RequestMessage::Tool { - content: content.into(), - tool_call_id: tool_result.tool_use_id.to_string(), - }); - } - } - } - } - - open_ai::Request { - model: model_id.into(), - messages, - stream, - stream_options: if stream { - Some(open_ai::StreamOptions::default()) - } else { - None - }, - stop: request.stop, - temperature: request.temperature.or(Some(1.0)), - max_completion_tokens: max_output_tokens, - parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() { - Some(supports_parallel_tool_calls) - } else { - None - }, - prompt_cache_key: if supports_prompt_cache_key { - request.thread_id - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| open_ai::ToolDefinition::Function { - function: open_ai::FunctionDefinition { - name: tool.name, - description: Some(tool.description), - parameters: Some(tool.input_schema), - }, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto, - LanguageModelToolChoice::Any => open_ai::ToolChoice::Required, - LanguageModelToolChoice::None => open_ai::ToolChoice::None, - }), - reasoning_effort, - } -} - -pub fn into_open_ai_response( - request: LanguageModelRequest, - model_id: &str, - supports_parallel_tool_calls: bool, - supports_prompt_cache_key: bool, - max_output_tokens: Option, - reasoning_effort: Option, -) -> ResponseRequest { - let stream = !model_id.starts_with("o1-"); - - let LanguageModelRequest { - thread_id, - prompt_id: _, - intent: _, - messages, - tools, - tool_choice, - stop: _, - temperature, - thinking_allowed: _, - thinking_effort: _, - speed: _, - } = request; - - let mut input_items = Vec::new(); - for (index, message) in messages.into_iter().enumerate() { - append_message_to_response_items(message, index, &mut input_items); - } - - let tools: Vec<_> = tools - .into_iter() - .map(|tool| open_ai::responses::ToolDefinition::Function { - name: tool.name, - description: Some(tool.description), - parameters: Some(tool.input_schema), - strict: None, - }) - .collect(); - - ResponseRequest { - model: model_id.into(), - input: input_items, - stream, - temperature, - top_p: None, - max_output_tokens, - parallel_tool_calls: if tools.is_empty() { - None - } else { - Some(supports_parallel_tool_calls) - }, - tool_choice: tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto, - LanguageModelToolChoice::Any => open_ai::ToolChoice::Required, - LanguageModelToolChoice::None => open_ai::ToolChoice::None, - }), - tools, - prompt_cache_key: if supports_prompt_cache_key { - thread_id - } else { - None - }, - reasoning: reasoning_effort.map(|effort| open_ai::responses::ReasoningConfig { - effort, - summary: Some(open_ai::responses::ReasoningSummaryMode::Auto), - }), - } -} - -fn append_message_to_response_items( - message: LanguageModelRequestMessage, - index: usize, - input_items: 
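// [Editor's aside — sketch of the request gating in the chat-completions
// path above, not part of the patch.] Three small conditionals shape the
// outgoing request: o1-prefixed models don't stream, parallel tool calls
// are only sent when supported *and* tools are present, and the thread id
// doubles as the prompt cache key only when the model supports it:
fn request_flags(
    model_id: &str,
    supports_parallel: bool,
    tool_count: usize,
    supports_cache_key: bool,
    thread_id: Option<String>,
) -> (bool, Option<bool>, Option<String>) {
    let stream = !model_id.starts_with("o1-");
    let parallel = (supports_parallel && tool_count > 0).then_some(supports_parallel);
    let cache_key = if supports_cache_key { thread_id } else { None };
    (stream, parallel, cache_key)
}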
&mut Vec, -) { - let mut content_parts: Vec = Vec::new(); - - for content in message.content { - match content { - MessageContent::Text(text) => { - push_response_text_part(&message.role, text, &mut content_parts); - } - MessageContent::Thinking { text, .. } => { - push_response_text_part(&message.role, text, &mut content_parts); - } - MessageContent::RedactedThinking(_) => {} - MessageContent::Image(image) => { - push_response_image_part(&message.role, image, &mut content_parts); - } - MessageContent::ToolUse(tool_use) => { - flush_response_parts(&message.role, index, &mut content_parts, input_items); - let call_id = tool_use.id.to_string(); - input_items.push(ResponseInputItem::FunctionCall(ResponseFunctionCallItem { - call_id, - name: tool_use.name.to_string(), - arguments: tool_use.raw_input, - })); - } - MessageContent::ToolResult(tool_result) => { - flush_response_parts(&message.role, index, &mut content_parts, input_items); - input_items.push(ResponseInputItem::FunctionCallOutput( - ResponseFunctionCallOutputItem { - call_id: tool_result.tool_use_id.to_string(), - output: match tool_result.content { - LanguageModelToolResultContent::Text(text) => { - ResponseFunctionCallOutputContent::Text(text.to_string()) - } - LanguageModelToolResultContent::Image(image) => { - ResponseFunctionCallOutputContent::List(vec![ - ResponseInputContent::Image { - image_url: image.to_base64_url(), - }, - ]) - } - }, - }, - )); - } - } - } - - flush_response_parts(&message.role, index, &mut content_parts, input_items); -} - -fn push_response_text_part( - role: &Role, - text: impl Into, - parts: &mut Vec, -) { - let text = text.into(); - if text.trim().is_empty() { - return; - } - - match role { - Role::Assistant => parts.push(ResponseInputContent::OutputText { - text, - annotations: Vec::new(), - }), - _ => parts.push(ResponseInputContent::Text { text }), - } -} - -fn push_response_image_part( - role: &Role, - image: LanguageModelImage, - parts: &mut Vec, -) { - match role { - Role::Assistant => parts.push(ResponseInputContent::OutputText { - text: "[image omitted]".to_string(), - annotations: Vec::new(), - }), - _ => parts.push(ResponseInputContent::Image { - image_url: image.to_base64_url(), - }), - } -} - -fn flush_response_parts( - role: &Role, - _index: usize, - parts: &mut Vec, - input_items: &mut Vec, -) { - if parts.is_empty() { - return; - } - - let item = ResponseInputItem::Message(ResponseMessageItem { - role: match role { - Role::User => open_ai::Role::User, - Role::Assistant => open_ai::Role::Assistant, - Role::System => open_ai::Role::System, - }, - content: parts.clone(), - }); - - input_items.push(item); - parts.clear(); -} - -fn add_message_content_part( - new_part: open_ai::MessagePart, - role: Role, - messages: &mut Vec, -) { - match (role, messages.last_mut()) { - (Role::User, Some(open_ai::RequestMessage::User { content })) - | ( - Role::Assistant, - Some(open_ai::RequestMessage::Assistant { - content: Some(content), - .. - }), - ) - | (Role::System, Some(open_ai::RequestMessage::System { content, .. 
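// [Editor's aside — sketch of the buffering pattern above, not part of the
// patch.] Text and image parts are accumulated per message and flushed as
// one message item whenever a tool call or tool result interrupts the run
// (plus a final flush at the end), preserving the original ordering:
fn flush_parts(parts: &mut Vec<String>, items: &mut Vec<Vec<String>>) {
    if parts.is_empty() {
        return;
    }
    items.push(std::mem::take(parts)); // drain the buffer into one item
}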
})) => { - content.push_part(new_part); - } - _ => { - messages.push(match role { - Role::User => open_ai::RequestMessage::User { - content: open_ai::MessageContent::from(vec![new_part]), - }, - Role::Assistant => open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::from(vec![new_part])), - tool_calls: Vec::new(), - }, - Role::System => open_ai::RequestMessage::System { - content: open_ai::MessageContent::from(vec![new_part]), - }, - }); - } - } -} - -pub struct OpenAiEventMapper { - tool_calls_by_index: HashMap, -} - -impl OpenAiEventMapper { - pub fn new() -> Self { - Self { - tool_calls_by_index: HashMap::default(), - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], - }) - }) - } - - pub fn map_event( - &mut self, - event: ResponseStreamEvent, - ) -> Vec> { - let mut events = Vec::new(); - if let Some(usage) = event.usage { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: usage.prompt_tokens, - output_tokens: usage.completion_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }))); - } - - let Some(choice) = event.choices.first() else { - return events; - }; - - if let Some(delta) = choice.delta.as_ref() { - if let Some(reasoning_content) = delta.reasoning_content.clone() { - if !reasoning_content.is_empty() { - events.push(Ok(LanguageModelCompletionEvent::Thinking { - text: reasoning_content, - signature: None, - })); - } - } - if let Some(content) = delta.content.clone() { - if !content.is_empty() { - events.push(Ok(LanguageModelCompletionEvent::Text(content))); - } - } - - if let Some(tool_calls) = delta.tool_calls.as_ref() { - for tool_call in tool_calls { - let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); - - if let Some(tool_id) = tool_call.id.clone() { - entry.id = tool_id; - } - - if let Some(function) = tool_call.function.as_ref() { - if let Some(name) = function.name.clone() { - entry.name = name; - } - - if let Some(arguments) = function.arguments.clone() { - entry.arguments.push_str(&arguments); - } - } - - if !entry.id.is_empty() && !entry.name.is_empty() { - if let Ok(input) = serde_json::from_str::( - &fix_streamed_json(&entry.arguments), - ) { - events.push(Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: entry.id.clone().into(), - name: entry.name.as_str().into(), - is_input_complete: false, - input, - raw_input: entry.arguments.clone(), - thought_signature: None, - }, - ))); - } - } - } - } - } - - match choice.finish_reason.as_deref() { - Some("stop") => { - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); - } - Some("tool_calls") => { - events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { - match parse_tool_arguments(&tool_call.arguments) { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_call.id.clone().into(), - name: tool_call.name.as_str().into(), - is_input_complete: true, - input, - raw_input: tool_call.arguments.clone(), - thought_signature: None, - }, - )), - Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: tool_call.id.into(), - tool_name: tool_call.name.into(), - raw_input: tool_call.arguments.clone().into(), - json_parse_error: error.to_string(), - }), - } - })); - - 
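// [Editor's aside — illustrative sketch, not part of the patch.] The mapper
// above accumulates streamed tool-call fragments keyed by choice index;
// once both id and name are known it speculatively parses the (possibly
// truncated) JSON so partial input can be surfaced. `repair_json` below is
// a stand-in for `fix_streamed_json`, which is assumed to close off
// dangling JSON:
use std::collections::HashMap;

#[derive(Default)]
struct PendingCall {
    id: String,
    name: String,
    arguments: String,
}

fn on_fragment(
    calls: &mut HashMap<usize, PendingCall>,
    index: usize,
    fragment: &str,
) -> Option<serde_json::Value> {
    let entry = calls.entry(index).or_default();
    entry.arguments.push_str(fragment);
    if entry.id.is_empty() || entry.name.is_empty() {
        return None;
    }
    serde_json::from_str(&repair_json(&entry.arguments)).ok()
}

fn repair_json(s: &str) -> String {
    s.to_string() // placeholder; the real helper repairs truncated JSON
}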
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); - } - Some(stop_reason) => { - log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); - } - None => {} - } - - events - } -} - -#[derive(Default)] -struct RawToolCall { - id: String, - name: String, - arguments: String, -} - -pub struct OpenAiResponseEventMapper { - function_calls_by_item: HashMap, - pending_stop_reason: Option, -} - -#[derive(Default)] -struct PendingResponseFunctionCall { - call_id: String, - name: Arc, - arguments: String, -} - -impl OpenAiResponseEventMapper { - pub fn new() -> Self { - Self { - function_calls_by_item: HashMap::default(), - pending_stop_reason: None, - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], - }) - }) - } - - pub fn map_event( - &mut self, - event: ResponsesStreamEvent, - ) -> Vec> { - match event { - ResponsesStreamEvent::OutputItemAdded { item, .. } => { - let mut events = Vec::new(); - - match &item { - ResponseOutputItem::Message(message) => { - if let Some(id) = &message.id { - events.push(Ok(LanguageModelCompletionEvent::StartMessage { - message_id: id.clone(), - })); - } - } - ResponseOutputItem::FunctionCall(function_call) => { - if let Some(item_id) = function_call.id.clone() { - let call_id = function_call - .call_id - .clone() - .or_else(|| function_call.id.clone()) - .unwrap_or_else(|| item_id.clone()); - let entry = PendingResponseFunctionCall { - call_id, - name: Arc::::from( - function_call.name.clone().unwrap_or_default(), - ), - arguments: function_call.arguments.clone(), - }; - self.function_calls_by_item.insert(item_id, entry); - } - } - ResponseOutputItem::Reasoning(_) | ResponseOutputItem::Unknown => {} - } - events - } - ResponsesStreamEvent::ReasoningSummaryTextDelta { delta, .. } => { - if delta.is_empty() { - Vec::new() - } else { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: delta, - signature: None, - })] - } - } - ResponsesStreamEvent::OutputTextDelta { delta, .. } => { - if delta.is_empty() { - Vec::new() - } else { - vec![Ok(LanguageModelCompletionEvent::Text(delta))] - } - } - ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { - if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { - entry.arguments.push_str(&delta); - if let Ok(input) = serde_json::from_str::( - &fix_streamed_json(&entry.arguments), - ) { - return vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: LanguageModelToolUseId::from(entry.call_id.clone()), - name: entry.name.clone(), - is_input_complete: false, - input, - raw_input: entry.arguments.clone(), - thought_signature: None, - }, - ))]; - } - } - Vec::new() - } - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id, arguments, .. 
- } => { - if let Some(mut entry) = self.function_calls_by_item.remove(&item_id) { - if !arguments.is_empty() { - entry.arguments = arguments; - } - let raw_input = entry.arguments.clone(); - self.pending_stop_reason = Some(StopReason::ToolUse); - match parse_tool_arguments(&entry.arguments) { - Ok(input) => { - vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: LanguageModelToolUseId::from(entry.call_id.clone()), - name: entry.name.clone(), - is_input_complete: true, - input, - raw_input, - thought_signature: None, - }, - ))] - } - Err(error) => { - vec![Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: LanguageModelToolUseId::from(entry.call_id.clone()), - tool_name: entry.name.clone(), - raw_input: Arc::::from(raw_input), - json_parse_error: error.to_string(), - })] - } - } - } else { - Vec::new() - } - } - ResponsesStreamEvent::Completed { response } => { - self.handle_completion(response, StopReason::EndTurn) - } - ResponsesStreamEvent::Incomplete { response } => { - let reason = response - .status_details - .as_ref() - .and_then(|details| details.reason.as_deref()); - let stop_reason = match reason { - Some("max_output_tokens") => StopReason::MaxTokens, - Some("content_filter") => { - self.pending_stop_reason = Some(StopReason::Refusal); - StopReason::Refusal - } - _ => self - .pending_stop_reason - .take() - .unwrap_or(StopReason::EndTurn), - }; - - let mut events = Vec::new(); - if self.pending_stop_reason.is_none() { - events.extend(self.emit_tool_calls_from_output(&response.output)); - } - if let Some(usage) = response.usage.as_ref() { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - token_usage_from_response_usage(usage), - ))); - } - events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); - events - } - ResponsesStreamEvent::Failed { response } => { - let message = response - .status_details - .and_then(|details| details.error) - .map(|error| error.to_string()) - .unwrap_or_else(|| "response failed".to_string()); - vec![Err(LanguageModelCompletionError::Other(anyhow!(message)))] - } - ResponsesStreamEvent::Error { error } - | ResponsesStreamEvent::GenericError { error } => { - vec![Err(LanguageModelCompletionError::Other(anyhow!( - error.message - )))] - } - ResponsesStreamEvent::ReasoningSummaryPartAdded { summary_index, .. } => { - if summary_index > 0 { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: "\n\n".to_string(), - signature: None, - })] - } else { - Vec::new() - } - } - ResponsesStreamEvent::OutputTextDone { .. } - | ResponsesStreamEvent::OutputItemDone { .. } - | ResponsesStreamEvent::ContentPartAdded { .. } - | ResponsesStreamEvent::ContentPartDone { .. } - | ResponsesStreamEvent::ReasoningSummaryTextDone { .. } - | ResponsesStreamEvent::ReasoningSummaryPartDone { .. } - | ResponsesStreamEvent::Created { .. } - | ResponsesStreamEvent::InProgress { .. 
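// [Editor's aside — simplified sketch of the stop-reason resolution above,
// not part of the patch.] For an `Incomplete` response, an explicit
// truncation or content-filter reason wins; otherwise any tool-use stop
// recorded earlier in the stream is consumed, defaulting to end-of-turn:
#[derive(Clone, Copy, Debug, PartialEq)]
enum Stop { EndTurn, MaxTokens, Refusal, ToolUse }

fn resolve_incomplete_stop(reason: Option<&str>, pending: &mut Option<Stop>) -> Stop {
    match reason {
        Some("max_output_tokens") => Stop::MaxTokens,
        Some("content_filter") => Stop::Refusal,
        _ => pending.take().unwrap_or(Stop::EndTurn),
    }
}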
} - | ResponsesStreamEvent::Unknown => Vec::new(), - } - } - - fn handle_completion( - &mut self, - response: ResponsesSummary, - default_reason: StopReason, - ) -> Vec> { - let mut events = Vec::new(); - - if self.pending_stop_reason.is_none() { - events.extend(self.emit_tool_calls_from_output(&response.output)); - } - - if let Some(usage) = response.usage.as_ref() { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - token_usage_from_response_usage(usage), - ))); - } - - let stop_reason = self.pending_stop_reason.take().unwrap_or(default_reason); - events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); - events - } - - fn emit_tool_calls_from_output( - &mut self, - output: &[ResponseOutputItem], - ) -> Vec> { - let mut events = Vec::new(); - for item in output { - if let ResponseOutputItem::FunctionCall(function_call) = item { - let Some(call_id) = function_call - .call_id - .clone() - .or_else(|| function_call.id.clone()) - else { - log::error!( - "Function call item missing both call_id and id: {:?}", - function_call - ); - continue; - }; - let name: Arc = Arc::from(function_call.name.clone().unwrap_or_default()); - let arguments = &function_call.arguments; - self.pending_stop_reason = Some(StopReason::ToolUse); - match parse_tool_arguments(arguments) { - Ok(input) => { - events.push(Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: LanguageModelToolUseId::from(call_id.clone()), - name: name.clone(), - is_input_complete: true, - input, - raw_input: arguments.clone(), - thought_signature: None, - }, - ))); - } - Err(error) => { - events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: LanguageModelToolUseId::from(call_id.clone()), - tool_name: name.clone(), - raw_input: Arc::::from(arguments.clone()), - json_parse_error: error.to_string(), - })); - } - } - } - } - events - } -} - -fn token_usage_from_response_usage(usage: &ResponsesUsage) -> TokenUsage { - TokenUsage { - input_tokens: usage.input_tokens.unwrap_or_default(), - output_tokens: usage.output_tokens.unwrap_or_default(), - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } -} - -pub(crate) fn collect_tiktoken_messages( - request: LanguageModelRequest, -) -> Vec { - request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>() -} - -pub fn count_open_ai_tokens( - request: LanguageModelRequest, - model: Model, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = collect_tiktoken_messages(request); - match model { - Model::Custom { max_tokens, .. } => { - let model = if max_tokens >= 100_000 { - // If the max tokens is 100k or more, it likely uses the o200k_base tokenizer - "gpt-4o" - } else { - // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are - // supported with this tiktoken method - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model, &messages) - } - // Currently supported by tiktoken_rs - // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch - // arm with an override. We enumerate all supported models here so that we can check if new - // models are supported yet or not. 
- Model::ThreePointFiveTurbo - | Model::Four - | Model::FourTurbo - | Model::FourOmniMini - | Model::FourPointOneNano - | Model::O1 - | Model::O3 - | Model::O3Mini - | Model::Five - | Model::FiveCodex - | Model::FiveMini - | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), - // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated tiktoken support; use gpt-5 tokenizer - Model::FivePointOne - | Model::FivePointTwo - | Model::FivePointTwoCodex - | Model::FivePointThreeCodex - | Model::FivePointFour - | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages), - } - .map(|tokens| tokens as u64) - }) - .boxed() -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, @@ -1459,874 +606,3 @@ impl Render for ConfigurationView { } } } - -#[cfg(test)] -mod tests { - use futures::{StreamExt, executor::block_on}; - use gpui::TestAppContext; - use language_model::{ - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - }; - use open_ai::responses::{ - ReasoningSummaryPart, ResponseFunctionToolCall, ResponseOutputItem, ResponseOutputMessage, - ResponseReasoningItem, ResponseStatusDetails, ResponseSummary, ResponseUsage, - StreamEvent as ResponsesStreamEvent, - }; - use pretty_assertions::assert_eq; - use serde_json::json; - - use super::*; - - fn map_response_events(events: Vec) -> Vec { - block_on(async { - OpenAiResponseEventMapper::new() - .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok)))) - .collect::>() - .await - .into_iter() - .map(Result::unwrap) - .collect() - }) - } - - fn response_item_message(id: &str) -> ResponseOutputItem { - ResponseOutputItem::Message(ResponseOutputMessage { - id: Some(id.to_string()), - role: Some("assistant".to_string()), - status: Some("in_progress".to_string()), - content: vec![], - }) - } - - fn response_item_function_call(id: &str, args: Option<&str>) -> ResponseOutputItem { - ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { - id: Some(id.to_string()), - status: Some("in_progress".to_string()), - name: Some("get_weather".to_string()), - call_id: Some("call_123".to_string()), - arguments: args.map(|s| s.to_string()).unwrap_or_default(), - }) - } - - #[gpui::test] - fn tiktoken_rs_support(cx: &TestAppContext) { - let request = LanguageModelRequest { - thread_id: None, - prompt_id: None, - intent: None, - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text("message".into())], - cache: false, - reasoning_details: None, - }], - tools: vec![], - tool_choice: None, - stop: vec![], - temperature: None, - thinking_allowed: true, - thinking_effort: None, - speed: None, - }; - - // Validate that all models are supported by tiktoken-rs - for model in Model::iter() { - let count = cx - .foreground_executor() - .block_on(count_open_ai_tokens( - request.clone(), - model, - &cx.app.borrow(), - )) - .unwrap(); - assert!(count > 0); - } - } - - #[test] - fn responses_stream_maps_text_and_usage() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_message("msg_123"), - }, - ResponsesStreamEvent::OutputTextDelta { - item_id: "msg_123".into(), - output_index: 0, - content_index: Some(0), - delta: "Hello".into(), - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary { - usage: Some(ResponseUsage { - input_tokens: Some(5), - output_tokens: Some(3), - total_tokens: Some(8), - }), - 
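// [Editor's aside — a string-keyed approximation of the enum match above,
// not part of the patch.] tiktoken-rs only exposes the cl100k_base and
// o200k_base tokenizers through num_tokens_from_messages, so custom models
// are routed by context-window size and newer GPT-5.x models fall back to
// the gpt-5 tokenizer:
fn tokenizer_model(model_id: &str, custom_max_tokens: Option<u64>) -> &'static str {
    match (model_id, custom_max_tokens) {
        // A context of >=100k tokens strongly suggests an o200k_base model.
        (_, Some(max)) if max >= 100_000 => "gpt-4o",
        (_, Some(_)) => "gpt-4",
        (id, None) if id.starts_with("gpt-5.") => "gpt-5", // no dedicated support yet
        _ => "gpt-4", // conservative default for this sketch
    }
}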
..Default::default() - }, - }, - ]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_123" - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Text(ref text) if text == "Hello" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: 5, - output_tokens: 3, - .. - }) - )); - assert!(matches!( - mapped[3], - LanguageModelCompletionEvent::Stop(StopReason::EndTurn) - )); - } - - #[test] - fn into_open_ai_response_builds_complete_payload() { - let tool_call_id = LanguageModelToolUseId::from("call-42"); - let tool_input = json!({ "city": "Boston" }); - let tool_arguments = serde_json::to_string(&tool_input).unwrap(); - let tool_use = LanguageModelToolUse { - id: tool_call_id.clone(), - name: Arc::from("get_weather"), - raw_input: tool_arguments.clone(), - input: tool_input, - is_input_complete: true, - thought_signature: None, - }; - let tool_result = LanguageModelToolResult { - tool_use_id: tool_call_id, - tool_name: Arc::from("get_weather"), - is_error: false, - content: LanguageModelToolResultContent::Text(Arc::from("Sunny")), - output: Some(json!({ "forecast": "Sunny" })), - }; - let user_image = LanguageModelImage { - source: SharedString::from("aGVsbG8="), - size: None, - }; - let expected_image_url = user_image.to_base64_url(); - - let request = LanguageModelRequest { - thread_id: Some("thread-123".into()), - prompt_id: None, - intent: None, - messages: vec![ - LanguageModelRequestMessage { - role: Role::System, - content: vec![MessageContent::Text("System context".into())], - cache: false, - reasoning_details: None, - }, - LanguageModelRequestMessage { - role: Role::User, - content: vec![ - MessageContent::Text("Please check the weather.".into()), - MessageContent::Image(user_image), - ], - cache: false, - reasoning_details: None, - }, - LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![ - MessageContent::Text("Looking that up.".into()), - MessageContent::ToolUse(tool_use), - ], - cache: false, - reasoning_details: None, - }, - LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolResult(tool_result)], - cache: false, - reasoning_details: None, - }, - ], - tools: vec![LanguageModelRequestTool { - name: "get_weather".into(), - description: "Fetches the weather".into(), - input_schema: json!({ "type": "object" }), - use_input_streaming: false, - }], - tool_choice: Some(LanguageModelToolChoice::Any), - stop: vec!["".into()], - temperature: None, - thinking_allowed: false, - thinking_effort: None, - speed: None, - }; - - let response = into_open_ai_response( - request, - "custom-model", - true, - true, - Some(2048), - Some(ReasoningEffort::Low), - ); - - let serialized = serde_json::to_value(&response).unwrap(); - let expected = json!({ - "model": "custom-model", - "input": [ - { - "type": "message", - "role": "system", - "content": [ - { "type": "input_text", "text": "System context" } - ] - }, - { - "type": "message", - "role": "user", - "content": [ - { "type": "input_text", "text": "Please check the weather." 
}, - { "type": "input_image", "image_url": expected_image_url } - ] - }, - { - "type": "message", - "role": "assistant", - "content": [ - { "type": "output_text", "text": "Looking that up.", "annotations": [] } - ] - }, - { - "type": "function_call", - "call_id": "call-42", - "name": "get_weather", - "arguments": tool_arguments - }, - { - "type": "function_call_output", - "call_id": "call-42", - "output": "Sunny" - } - ], - "stream": true, - "max_output_tokens": 2048, - "parallel_tool_calls": true, - "tool_choice": "required", - "tools": [ - { - "type": "function", - "name": "get_weather", - "description": "Fetches the weather", - "parameters": { "type": "object" } - } - ], - "prompt_cache_key": "thread-123", - "reasoning": { "effort": "low", "summary": "auto" } - }); - - assert_eq!(serialized, expected); - } - - #[test] - fn responses_stream_maps_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"city\":\"Bos")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "ton\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{\"city\":\"Boston\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 3); - // First event is the partial tool use (from FunctionCallArgumentsDelta) - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: false, - .. - }) - )); - // Second event is the complete tool use (from FunctionCallArgumentsDone) - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - ref id, - ref name, - ref raw_input, - is_input_complete: true, - .. - }) if id.to_string() == "call_123" - && name.as_ref() == "get_weather" - && raw_input == "{\"city\":\"Boston\"}" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_uses_max_tokens_stop_reason() { - let events = vec![ResponsesStreamEvent::Incomplete { - response: ResponseSummary { - status_details: Some(ResponseStatusDetails { - reason: Some("max_output_tokens".into()), - r#type: Some("incomplete".into()), - error: None, - }), - usage: Some(ResponseUsage { - input_tokens: Some(10), - output_tokens: Some(20), - total_tokens: Some(30), - }), - ..Default::default() - }, - }]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: 10, - output_tokens: 20, - .. 
- }) - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) - )); - } - - #[test] - fn responses_stream_handles_multiple_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn1", Some("{\"city\":\"NYC\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn1".into(), - output_index: 0, - arguments: "{\"city\":\"NYC\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::OutputItemAdded { - output_index: 1, - sequence_number: None, - item: response_item_function_call("item_fn2", Some("{\"city\":\"LA\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn2".into(), - output_index: 1, - arguments: "{\"city\":\"LA\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 3); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"city\":\"NYC\"}" - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"city\":\"LA\"}" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_handles_mixed_text_and_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_message("msg_123"), - }, - ResponsesStreamEvent::OutputTextDelta { - item_id: "msg_123".into(), - output_index: 0, - content_index: Some(0), - delta: "Let me check that".into(), - }, - ResponsesStreamEvent::OutputItemAdded { - output_index: 1, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"query\":\"test\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 1, - arguments: "{\"query\":\"test\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::StartMessage { .. } - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Text(ref text) if text == "Let me check that" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"query\":\"test\"}" - )); - assert!(matches!( - mapped[3], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_handles_json_parse_error() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{invalid json")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{invalid json".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUseJsonParseError { - ref raw_input, - .. 
- } if raw_input.as_ref() == "{invalid json" - )); - } - - #[test] - fn responses_stream_handles_incomplete_function_call() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"city\":")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "\"Boston\"".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Incomplete { - response: ResponseSummary { - status_details: Some(ResponseStatusDetails { - reason: Some("max_output_tokens".into()), - r#type: Some("incomplete".into()), - error: None, - }), - output: vec![response_item_function_call( - "item_fn", - Some("{\"city\":\"Boston\"}"), - )], - ..Default::default() - }, - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 3); - // First event is the partial tool use (from FunctionCallArgumentsDelta) - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: false, - .. - }) - )); - // Second event is the complete tool use (from the Incomplete response output) - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - ref raw_input, - is_input_complete: true, - .. - }) - if raw_input == "{\"city\":\"Boston\"}" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) - )); - } - - #[test] - fn responses_stream_incomplete_does_not_duplicate_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"city\":\"Boston\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{\"city\":\"Boston\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Incomplete { - response: ResponseSummary { - status_details: Some(ResponseStatusDetails { - reason: Some("max_output_tokens".into()), - r#type: Some("incomplete".into()), - error: None, - }), - output: vec![response_item_function_call( - "item_fn", - Some("{\"city\":\"Boston\"}"), - )], - ..Default::default() - }, - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 2); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"city\":\"Boston\"}" - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) - )); - } - - #[test] - fn responses_stream_handles_empty_tool_arguments() { - // Test that tools with no arguments (empty string) are handled correctly - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 2); - - // Should produce a ToolUse event with an empty object - assert!(matches!( - &mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - id, - name, - raw_input, - input, - .. 
- }) if id.to_string() == "call_123" - && name.as_ref() == "get_weather" - && raw_input == "" - && input.is_object() - && input.as_object().unwrap().is_empty() - )); - - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_emits_partial_tool_use_events() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { - id: Some("item_fn".to_string()), - status: Some("in_progress".to_string()), - name: Some("get_weather".to_string()), - call_id: Some("call_abc".to_string()), - arguments: String::new(), - }), - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "{\"city\":\"Bos".into(), - sequence_number: None, - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "ton\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{\"city\":\"Boston\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - // Two partial events + one complete event + Stop - assert!(mapped.len() >= 3); - - // The last complete ToolUse event should have is_input_complete: true - let complete_tool_use = mapped.iter().find(|e| { - matches!( - e, - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: true, - .. - }) - ) - }); - assert!( - complete_tool_use.is_some(), - "should have a complete tool use event" - ); - - // All ToolUse events before the final one should have is_input_complete: false - let tool_uses: Vec<_> = mapped - .iter() - .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_))) - .collect(); - assert!( - tool_uses.len() >= 2, - "should have at least one partial and one complete event" - ); - - let last = tool_uses.last().unwrap(); - assert!(matches!( - last, - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: true, - .. 
- }) - )); - } - - #[test] - fn responses_stream_maps_reasoning_summary_deltas() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_123".into()), - summary: vec![], - }), - }, - ResponsesStreamEvent::ReasoningSummaryPartAdded { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 0, - }, - ResponsesStreamEvent::ReasoningSummaryTextDelta { - item_id: "rs_123".into(), - output_index: 0, - delta: "Thinking about".into(), - }, - ResponsesStreamEvent::ReasoningSummaryTextDelta { - item_id: "rs_123".into(), - output_index: 0, - delta: " the answer".into(), - }, - ResponsesStreamEvent::ReasoningSummaryTextDone { - item_id: "rs_123".into(), - output_index: 0, - text: "Thinking about the answer".into(), - }, - ResponsesStreamEvent::ReasoningSummaryPartDone { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 0, - }, - ResponsesStreamEvent::ReasoningSummaryPartAdded { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 1, - }, - ResponsesStreamEvent::ReasoningSummaryTextDelta { - item_id: "rs_123".into(), - output_index: 0, - delta: "Second part".into(), - }, - ResponsesStreamEvent::ReasoningSummaryTextDone { - item_id: "rs_123".into(), - output_index: 0, - text: "Second part".into(), - }, - ResponsesStreamEvent::ReasoningSummaryPartDone { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 1, - }, - ResponsesStreamEvent::OutputItemDone { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_123".into()), - summary: vec![ - ReasoningSummaryPart::SummaryText { - text: "Thinking about the answer".into(), - }, - ReasoningSummaryPart::SummaryText { - text: "Second part".into(), - }, - ], - }), - }, - ResponsesStreamEvent::OutputItemAdded { - output_index: 1, - sequence_number: None, - item: response_item_message("msg_456"), - }, - ResponsesStreamEvent::OutputTextDelta { - item_id: "msg_456".into(), - output_index: 1, - content_index: Some(0), - delta: "The answer is 42".into(), - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - - let thinking_events: Vec<_> = mapped - .iter() - .filter(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })) - .collect(); - assert_eq!( - thinking_events.len(), - 4, - "expected 4 thinking events (2 deltas + separator + second delta), got {:?}", - thinking_events, - ); - - assert!(matches!( - &thinking_events[0], - LanguageModelCompletionEvent::Thinking { text, .. } if text == "Thinking about" - )); - assert!(matches!( - &thinking_events[1], - LanguageModelCompletionEvent::Thinking { text, .. } if text == " the answer" - )); - assert!( - matches!( - &thinking_events[2], - LanguageModelCompletionEvent::Thinking { text, .. } if text == "\n\n" - ), - "expected separator between summary parts" - ); - assert!(matches!( - &thinking_events[3], - LanguageModelCompletionEvent::Thinking { text, .. 
} if text == "Second part" - )); - - assert!(mapped.iter().any(|e| matches!( - e, - LanguageModelCompletionEvent::Text(t) if t == "The answer is 42" - ))); - } - - #[test] - fn responses_stream_maps_reasoning_from_done_only() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_789".into()), - summary: vec![], - }), - }, - ResponsesStreamEvent::OutputItemDone { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_789".into()), - summary: vec![ReasoningSummaryPart::SummaryText { - text: "Summary without deltas".into(), - }], - }), - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - - assert!( - !mapped - .iter() - .any(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })), - "OutputItemDone reasoning should not produce Thinking events (no delta/done text events)" - ); - } -} diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 1c3268749c3340826cd2f50d29e80eecfa1826d4..7a3126f8f33beb7851ea914cfe063b76f8b4443f 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -402,7 +402,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { self.model.capabilities.parallel_tool_calls, self.model.capabilities.prompt_cache_key, self.max_output_tokens(), - self.model.reasoning_effort.clone(), + self.model.reasoning_effort, ); let completions = self.stream_completion(request, cx); async move { @@ -417,7 +417,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { self.model.capabilities.parallel_tool_calls, self.model.capabilities.prompt_cache_key, self.max_output_tokens(), - self.model.reasoning_effort.clone(), + self.model.reasoning_effort, ); let completions = self.stream_response(request, cx); async move { diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 09c8eb768d12c61ed1dc86a1251ad52114be6162..fba3a6938aecf1db80680e014e408e4d59c42ff7 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 88189864c7b4b650a24afb2b872c1d6105cf9782..e95bc1ba72fabcf9632b2ed2efd94254fb1313cd 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -9,7 +9,7 @@ use language_model::{ LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, - Role, env_var, + env_var, }; use open_ai::ResponseStreamEvent; pub 
use settings::XaiAvailableModel as AvailableModel; @@ -19,7 +19,8 @@ use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use x_ai::{Model, XAI_API_URL}; +use x_ai::XAI_API_URL; +pub use x_ai::completion::count_xai_tokens; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); @@ -320,7 +321,9 @@ impl LanguageModel for XAiLanguageModel { request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { - count_xai_tokens(request, self.model.clone(), cx) + let model = self.model.clone(); + cx.background_spawn(async move { count_xai_tokens(request, model) }) + .boxed() } fn stream_completion( @@ -354,37 +357,6 @@ impl LanguageModel for XAiLanguageModel { } } -pub fn count_xai_tokens( - request: LanguageModelRequest, - model: Model, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - let model_name = if model.max_token_count() >= 100_000 { - "gpt-4o" - } else { - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64) - }) - .boxed() -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, diff --git a/crates/language_models_cloud/Cargo.toml b/crates/language_models_cloud/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..b08acc5ecd5c2a718e936378c2dbfbc3d1c32df0 --- /dev/null +++ b/crates/language_models_cloud/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "language_models_cloud" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/language_models_cloud.rs" + +[dependencies] +anthropic = { workspace = true, features = ["schemars"] } +anyhow.workspace = true +cloud_llm_client.workspace = true +futures.workspace = true +google_ai = { workspace = true, features = ["schemars"] } +gpui.workspace = true +http_client.workspace = true +language_model.workspace = true +open_ai = { workspace = true, features = ["schemars"] } +schemars.workspace = true +semver.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +thiserror.workspace = true +x_ai = { workspace = true, features = ["schemars"] } + +[dev-dependencies] +language_model = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models_cloud/LICENSE-GPL b/crates/language_models_cloud/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/language_models_cloud/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/language_models_cloud/src/language_models_cloud.rs b/crates/language_models_cloud/src/language_models_cloud.rs new file mode 100644 index 0000000000000000000000000000000000000000..24c8ec87d5c672dbc18b20164f2fe28c9b46b2e1 --- /dev/null +++ b/crates/language_models_cloud/src/language_models_cloud.rs @@ -0,0 +1,1059 @@ +use anthropic::AnthropicModelMode; +use 
anyhow::{Context as _, Result, anyhow};
+use cloud_llm_client::{
+    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
+    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
+    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
+    OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
+    ZED_VERSION_HEADER_NAME,
+};
+use futures::{
+    AsyncBufReadExt, FutureExt, Stream, StreamExt,
+    future::BoxFuture,
+    stream::{self, BoxStream},
+};
+use google_ai::GoogleModelMode;
+use gpui::{App, AppContext, AsyncApp, Context, Task};
+use http_client::http::{HeaderMap, HeaderValue};
+use http_client::{
+    AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
+};
+use language_model::{
+    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
+    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
+    LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
+    LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
+    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
+    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
+};
+
+use schemars::JsonSchema;
+use semver::Version;
+use serde::{Deserialize, Serialize, de::DeserializeOwned};
+use smol::io::{AsyncReadExt, BufReader};
+use std::collections::VecDeque;
+use std::pin::Pin;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::task::Poll;
+use std::time::Duration;
+use thiserror::Error;
+
+use anthropic::completion::{
+    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
+};
+use google_ai::completion::{GoogleEventMapper, into_google};
+use open_ai::completion::{
+    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
+    into_open_ai_response,
+};
+use x_ai::completion::count_xai_tokens;
+
+const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
+
+/// Trait for acquiring and refreshing LLM authentication tokens.
+pub trait CloudLlmTokenProvider: Send + Sync {
+    type AuthContext: Clone + Send + 'static;
+
+    fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext;
+    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
+    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
+}
+
+#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+    #[default]
+    Default,
+    Thinking {
+        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
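+        ///
+        /// Note: `ModelMode` converts into `AnthropicModelMode` via the `From`
+        /// impl below, so `Thinking` selects Anthropic's extended-thinking mode.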
+        budget_tokens: Option<u32>,
+    },
+}
+
+impl From<ModelMode> for AnthropicModelMode {
+    fn from(value: ModelMode) -> Self {
+        match value {
+            ModelMode::Default => AnthropicModelMode::Default,
+            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
+        }
+    }
+}
+
+pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
+    pub id: LanguageModelId,
+    pub model: Arc<cloud_llm_client::LanguageModel>,
+    pub token_provider: Arc<TP>,
+    pub http_client: Arc<HttpClientWithUrl>,
+    pub app_version: Option<Version>,
+    pub request_limiter: RateLimiter,
+}
+
+pub struct PerformLlmCompletionResponse {
+    pub response: Response<AsyncBody>,
+    pub includes_status_messages: bool,
+}
+
+impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
+    pub async fn perform_llm_completion(
+        http_client: &HttpClientWithUrl,
+        token_provider: &TP,
+        auth_context: TP::AuthContext,
+        app_version: Option<Version>,
+        body: CompletionBody,
+    ) -> Result<PerformLlmCompletionResponse> {
+        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
+        let mut refreshed_token = false;
+
+        loop {
+            let request = http_client::Request::builder()
+                .method(Method::POST)
+                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
+                .when_some(app_version.as_ref(), |builder, app_version| {
+                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
+                })
+                .header("Content-Type", "application/json")
+                .header("Authorization", format!("Bearer {token}"))
+                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
+                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
+                .body(serde_json::to_string(&body)?.into())?;
+
+            let mut response = http_client.send(request).await?;
+            let status = response.status();
+            if status.is_success() {
+                let includes_status_messages = response
+                    .headers()
+                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
+                    .is_some();
+
+                return Ok(PerformLlmCompletionResponse {
+                    response,
+                    includes_status_messages,
+                });
+            }
+
+            if !refreshed_token && needs_llm_token_refresh(&response) {
+                token = token_provider.refresh_token(auth_context.clone()).await?;
+                refreshed_token = true;
+                continue;
+            }
+
+            if status == StatusCode::PAYMENT_REQUIRED {
+                return Err(anyhow!(PaymentRequiredError));
+            }
+
+            let mut body = String::new();
+            let headers = response.headers().clone();
+            response.body_mut().read_to_string(&mut body).await?;
+            return Err(anyhow!(ApiError {
+                status,
+                body,
+                headers
+            }));
+        }
+    }
+}
+
+fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
+    response
+        .headers()
+        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+        .is_some()
+        || response
+            .headers()
+            .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
+            .is_some()
+}
+
+#[derive(Debug, Error)]
+#[error("cloud language model request failed with status {status}: {body}")]
+struct ApiError {
+    status: StatusCode,
+    body: String,
+    headers: HeaderMap,
+}
+
+/// Represents error responses from Zed's cloud API.
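+///
+/// Codes of the form `upstream_http_*` indicate that the failure originated
+/// with the upstream model provider rather than with Zed's cloud API itself;
+/// the `From<ApiError>` impl below surfaces those as
+/// `LanguageModelCompletionError::UpstreamProviderError`.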
+///
+/// Example JSON for an upstream HTTP error:
+/// ```json
+/// {
+///   "code": "upstream_http_error",
+///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
+///   "upstream_status": 503
+/// }
+/// ```
+#[derive(Debug, serde::Deserialize)]
+struct CloudApiError {
+    code: String,
+    message: String,
+    #[serde(default)]
+    #[serde(deserialize_with = "deserialize_optional_status_code")]
+    upstream_status: Option<StatusCode>,
+    #[serde(default)]
+    retry_after: Option<f64>,
+}
+
+fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
+    let opt: Option<u16> = Option::deserialize(deserializer)?;
+    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
+}
+
+impl From<ApiError> for LanguageModelCompletionError {
+    fn from(error: ApiError) -> Self {
+        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
+            if cloud_error.code.starts_with("upstream_http_") {
+                let status = if let Some(status) = cloud_error.upstream_status {
+                    status
+                } else if cloud_error.code.ends_with("_error") {
+                    error.status
+                } else {
+                    // The upstream status may be embedded in the code string
+                    // itself (e.g. "upstream_http_429"); fall back to the
+                    // original response status if it can't be parsed out.
+                    cloud_error
+                        .code
+                        .strip_prefix("upstream_http_")
+                        .and_then(|code_str| code_str.parse::<u16>().ok())
+                        .and_then(|code| StatusCode::from_u16(code).ok())
+                        .unwrap_or(error.status)
+                };
+
+                return LanguageModelCompletionError::UpstreamProviderError {
+                    message: cloud_error.message,
+                    status,
+                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
+                };
+            }
+
+            return LanguageModelCompletionError::from_http_status(
+                PROVIDER_NAME,
+                error.status,
+                cloud_error.message,
+                None,
+            );
+        }
+
+        let retry_after = None;
+        LanguageModelCompletionError::from_http_status(
+            PROVIDER_NAME,
+            error.status,
+            error.body,
+            retry_after,
+        )
+    }
+}
+
+impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
+    fn id(&self) -> LanguageModelId {
+        self.id.clone()
+    }
+
+    fn name(&self) -> LanguageModelName {
+        LanguageModelName::from(self.model.display_name.clone())
+    }
+
+    fn provider_id(&self) -> LanguageModelProviderId {
+        PROVIDER_ID
+    }
+
+    fn provider_name(&self) -> LanguageModelProviderName {
+        PROVIDER_NAME
+    }
+
+    fn upstream_provider_id(&self) -> LanguageModelProviderId {
+        use cloud_llm_client::LanguageModelProvider::*;
+        match self.model.provider {
+            Anthropic => ANTHROPIC_PROVIDER_ID,
+            OpenAi => OPEN_AI_PROVIDER_ID,
+            Google => GOOGLE_PROVIDER_ID,
+            XAi => X_AI_PROVIDER_ID,
+        }
+    }
+
+    fn upstream_provider_name(&self) -> LanguageModelProviderName {
+        use cloud_llm_client::LanguageModelProvider::*;
+        match self.model.provider {
+            Anthropic => ANTHROPIC_PROVIDER_NAME,
+            OpenAi => OPEN_AI_PROVIDER_NAME,
+            Google => GOOGLE_PROVIDER_NAME,
+            XAi => X_AI_PROVIDER_NAME,
+        }
+    }
+
+    fn is_latest(&self) -> bool {
+        self.model.is_latest
+    }
+
+    fn supports_tools(&self) -> bool {
+        self.model.supports_tools
+    }
+
+    fn supports_images(&self) -> bool {
+        self.model.supports_images
+    }
+
+    fn supports_thinking(&self) -> bool {
+        self.model.supports_thinking
+    }
+
+    fn supports_fast_mode(&self) -> bool {
+        self.model.supports_fast_mode
+    }
+
+    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
+        self.model
+            .supported_effort_levels
+            .iter()
+            .map(|effort_level| LanguageModelEffortLevel {
+                name: effort_level.name.clone().into(),
+                value: effort_level.value.clone().into(),
+                is_default: effort_level.is_default.unwrap_or(false),
+
}) + .collect() + } + + fn supports_streaming_tools(&self) -> bool { + self.model.supports_streaming_tools + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => true, + } + } + + fn supports_split_token_display(&self) -> bool { + use cloud_llm_client::LanguageModelProvider::*; + matches!(self.model.provider, OpenAi | XAi) + } + + fn telemetry_id(&self) -> String { + format!("zed.dev/{}", self.model.id) + } + + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + match self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic + | cloud_llm_client::LanguageModelProvider::OpenAi => { + LanguageModelToolSchemaFormat::JsonSchema + } + cloud_llm_client::LanguageModelProvider::Google + | cloud_llm_client::LanguageModelProvider::XAi => { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } + } + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count as u64 + } + + fn max_output_tokens(&self) -> Option { + Some(self.model.max_output_tokens as u64) + } + + fn cache_configuration(&self) -> Option { + match &self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic => { + Some(LanguageModelCacheConfiguration { + min_total_token: 2_048, + should_speculate: true, + max_cache_anchors: 4, + }) + } + cloud_llm_client::LanguageModelProvider::OpenAi + | cloud_llm_client::LanguageModelProvider::XAi + | cloud_llm_client::LanguageModelProvider::Google => None, + } + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + match self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic => cx + .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) + .boxed(), + cloud_llm_client::LanguageModelProvider::OpenAi => { + let model = match open_ai::Model::from_id(&self.model.id.0) { + Ok(model) => model, + Err(err) => return async move { Err(anyhow!(err)) }.boxed(), + }; + cx.background_spawn(async move { count_open_ai_tokens(request, model) }) + .boxed() + } + cloud_llm_client::LanguageModelProvider::XAi => { + let model = match x_ai::Model::from_id(&self.model.id.0) { + Ok(model) => model, + Err(err) => return async move { Err(anyhow!(err)) }.boxed(), + }; + cx.background_spawn(async move { count_xai_tokens(request, model) }) + .boxed() + } + cloud_llm_client::LanguageModelProvider::Google => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let model_id = self.model.id.to_string(); + let generate_content_request = + into_google(request, model_id.clone(), GoogleModelMode::Default); + let auth_context = token_provider.auth_context(&cx.to_async()); + async move { + let token = token_provider.acquire_token(auth_context).await?; + + let request_body = CountTokensBody { + provider: cloud_llm_client::LanguageModelProvider::Google, + model: model_id, + provider_request: serde_json::to_value(&google_ai::CountTokensRequest { + generate_content_request, + })?, + }; + let request = http_client::Request::builder() + .method(Method::POST) + .uri( + http_client + .build_zed_llm_url("/count_tokens", &[])? 
+ .as_ref(), + ) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {token}")) + .body(serde_json::to_string(&request_body)?.into())?; + let mut response = http_client.send(request).await?; + let status = response.status(); + let headers = response.headers().clone(); + let mut response_body = String::new(); + response + .body_mut() + .read_to_string(&mut response_body) + .await?; + + if status.is_success() { + let response_body: CountTokensResponse = + serde_json::from_str(&response_body)?; + + Ok(response_body.tokens as u64) + } else { + Err(anyhow!(ApiError { + status, + body: response_body, + headers + })) + } + } + .boxed() + } + } + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { + let thread_id = request.thread_id.clone(); + let prompt_id = request.prompt_id.clone(); + let app_version = self.app_version.clone(); + let thinking_allowed = request.thinking_allowed; + let enable_thinking = thinking_allowed && self.model.supports_thinking; + let provider_name = provider_name(&self.model.provider); + match self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic => { + let effort = request + .thinking_effort + .as_ref() + .and_then(|effort| anthropic::Effort::from_str(effort).ok()); + + let mut request = into_anthropic( + request, + self.model.id.to_string(), + 1.0, + self.model.max_output_tokens as u64, + if enable_thinking { + AnthropicModelMode::Thinking { + budget_tokens: Some(4_096), + } + } else { + AnthropicModelMode::Default + }, + ); + + if enable_thinking && effort.is_some() { + request.thinking = Some(anthropic::Thinking::Adaptive); + request.output_config = Some(anthropic::OutputConfig { effort }); + } + + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let auth_context = token_provider.auth_context(cx); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + &http_client, + &*token_provider, + auth_context, + app_version, + CompletionBody { + thread_id, + prompt_id, + provider: cloud_llm_client::LanguageModelProvider::Anthropic, + model: request.model.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await + .map_err(|err| match err.downcast::() { + Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)), + Err(err) => anyhow!(err), + })?; + + let mut mapper = AnthropicEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + cloud_llm_client::LanguageModelProvider::OpenAi => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let effort = request + .thinking_effort + .as_ref() + .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok()); + + let mut request = into_open_ai_response( + request, + &self.model.id.0, + self.model.supports_parallel_tool_calls, + true, + None, + None, + ); + + if enable_thinking && let Some(effort) = effort { + request.reasoning = Some(open_ai::responses::ReasoningConfig { + effort, + summary: Some(open_ai::responses::ReasoningSummaryMode::Auto), + }); + } + + let auth_context = 
token_provider.auth_context(cx);
+                let future = self.request_limiter.stream(async move {
+                    let PerformLlmCompletionResponse {
+                        response,
+                        includes_status_messages,
+                    } = Self::perform_llm_completion(
+                        &http_client,
+                        &*token_provider,
+                        auth_context,
+                        app_version,
+                        CompletionBody {
+                            thread_id,
+                            prompt_id,
+                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+                            model: request.model.clone(),
+                            provider_request: serde_json::to_value(&request)
+                                .map_err(|e| anyhow!(e))?,
+                        },
+                    )
+                    .await?;
+
+                    let mut mapper = OpenAiResponseEventMapper::new();
+                    Ok(map_cloud_completion_events(
+                        Box::pin(response_lines(response, includes_status_messages)),
+                        &provider_name,
+                        move |event| mapper.map_event(event),
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
+            }
+            cloud_llm_client::LanguageModelProvider::XAi => {
+                let http_client = self.http_client.clone();
+                let token_provider = self.token_provider.clone();
+                let request = into_open_ai(
+                    request,
+                    &self.model.id.0,
+                    self.model.supports_parallel_tool_calls,
+                    false,
+                    None,
+                    None,
+                );
+                let auth_context = token_provider.auth_context(cx);
+                let future = self.request_limiter.stream(async move {
+                    let PerformLlmCompletionResponse {
+                        response,
+                        includes_status_messages,
+                    } = Self::perform_llm_completion(
+                        &http_client,
+                        &*token_provider,
+                        auth_context,
+                        app_version,
+                        CompletionBody {
+                            thread_id,
+                            prompt_id,
+                            provider: cloud_llm_client::LanguageModelProvider::XAi,
+                            model: request.model.clone(),
+                            provider_request: serde_json::to_value(&request)
+                                .map_err(|e| anyhow!(e))?,
+                        },
+                    )
+                    .await?;
+
+                    let mut mapper = OpenAiEventMapper::new();
+                    Ok(map_cloud_completion_events(
+                        Box::pin(response_lines(response, includes_status_messages)),
+                        &provider_name,
+                        move |event| mapper.map_event(event),
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
+            }
+            cloud_llm_client::LanguageModelProvider::Google => {
+                let http_client = self.http_client.clone();
+                let token_provider = self.token_provider.clone();
+                let request =
+                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
+                let auth_context = token_provider.auth_context(cx);
+                let future = self.request_limiter.stream(async move {
+                    let PerformLlmCompletionResponse {
+                        response,
+                        includes_status_messages,
+                    } = Self::perform_llm_completion(
+                        &http_client,
+                        &*token_provider,
+                        auth_context,
+                        app_version,
+                        CompletionBody {
+                            thread_id,
+                            prompt_id,
+                            provider: cloud_llm_client::LanguageModelProvider::Google,
+                            model: request.model.model_id.clone(),
+                            provider_request: serde_json::to_value(&request)
+                                .map_err(|e| anyhow!(e))?,
+                        },
+                    )
+                    .await?;
+
+                    let mut mapper = GoogleEventMapper::new();
+                    Ok(map_cloud_completion_events(
+                        Box::pin(response_lines(response, includes_status_messages)),
+                        &provider_name,
+                        move |event| mapper.map_event(event),
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
+            }
+        }
+    }
+}
+
+pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
+    token_provider: Arc<TP>,
+    http_client: Arc<HttpClientWithUrl>,
+    app_version: Option<Version>,
+    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
+    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
+    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
+    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
+}
+
+impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
+    pub fn new(
+        token_provider: Arc<TP>,
+        http_client: Arc<HttpClientWithUrl>,
+        app_version: Option<Version>,
+    ) -> Self {
+        Self {
+            token_provider,
+            http_client,
+            app_version,
+            models: Vec::new(),
+            default_model: None,
+            default_fast_model: None,
+            recommended_models: Vec::new(),
+        }
+    }
+
+    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let http_client = self.http_client.clone();
+        let token_provider = self.token_provider.clone();
+        cx.spawn(async move |this, cx| {
+            let auth_context = token_provider.auth_context(cx);
+            let response =
+                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
+            this.update(cx, |this, cx| {
+                this.update_models(response);
+                cx.notify();
+            })
+        })
+    }
+
+    async fn fetch_models_request(
+        http_client: &HttpClientWithUrl,
+        token_provider: &TP,
+        auth_context: TP::AuthContext,
+    ) -> Result<ListModelsResponse> {
+        let token = token_provider.acquire_token(auth_context).await?;
+
+        let request = http_client::Request::builder()
+            .method(Method::GET)
+            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
+            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
+            .header("Authorization", format!("Bearer {token}"))
+            .body(AsyncBody::empty())?;
+        let mut response = http_client
+            .send(request)
+            .await
+            .context("failed to send list models request")?;
+
+        if response.status().is_success() {
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
+            Ok(serde_json::from_str(&body)?)
+        } else {
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
+            anyhow::bail!(
+                "error listing models.\nStatus: {:?}\nBody: {body}",
+                response.status(),
+            );
+        }
+    }
+
+    pub fn update_models(&mut self, response: ListModelsResponse) {
+        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();
+
+        self.default_model = models
+            .iter()
+            .find(|model| {
+                response
+                    .default_model
+                    .as_ref()
+                    .is_some_and(|default_model_id| &model.id == default_model_id)
+            })
+            .cloned();
+        self.default_fast_model = models
+            .iter()
+            .find(|model| {
+                response
+                    .default_fast_model
+                    .as_ref()
+                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
+            })
+            .cloned();
+        self.recommended_models = response
+            .recommended_models
+            .iter()
+            .filter_map(|id| models.iter().find(|model| &model.id == id))
+            .cloned()
+            .collect();
+        self.models = models;
+    }
+
+    pub fn create_model(
+        &self,
+        model: &Arc<cloud_llm_client::LanguageModel>,
+    ) -> Arc<dyn LanguageModel> {
+        Arc::new(CloudLanguageModel::<TP> {
+            id: LanguageModelId::from(model.id.0.to_string()),
+            model: model.clone(),
+            token_provider: self.token_provider.clone(),
+            http_client: self.http_client.clone(),
+            app_version: self.app_version.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
+
+    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
+        &self.models
+    }
+
+    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
+        self.default_model.as_ref()
+    }
+
+    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
+        self.default_fast_model.as_ref()
+    }
+
+    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
+        &self.recommended_models
+    }
+}
+
+pub fn map_cloud_completion_events<T, F>(
+    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
+    provider: &LanguageModelProviderName,
+    mut map_callback: F,
+) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+where
+    T: DeserializeOwned + 'static,
+    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+        + Send
+        + 'static,
+{
+    let provider = provider.clone();
+    let mut stream = stream.fuse();
+
+    let mut saw_stream_ended = false;
+
+    let mut done = false;
+    let mut pending = VecDeque::new();
+
+    stream::poll_fn(move |cx| {
+        loop {
+            if let Some(item) = pending.pop_front() {
+                return Poll::Ready(Some(item));
+            }
+
+            if done {
+                return Poll::Ready(None);
+            }
+
+            match stream.poll_next_unpin(cx) {
+                Poll::Ready(Some(event)) => {
+                    let items = match event {
+                        Err(error) => {
+                            vec![Err(LanguageModelCompletionError::from(error))]
+                        }
+                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
+                            saw_stream_ended = true;
+                            vec![]
+                        }
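+                        // Statuses that carry no completion event
+                        // (from_completion_request_status yields Ok(None))
+                        // contribute no items to the output stream.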
+                        Ok(CompletionEvent::Status(status)) => {
+                            LanguageModelCompletionEvent::from_completion_request_status(
+                                status,
+                                provider.clone(),
+                            )
+                            .transpose()
+                            .map(|event| vec![event])
+                            .unwrap_or_default()
+                        }
+                        Ok(CompletionEvent::Event(event)) => map_callback(event),
+                    };
+                    pending.extend(items);
+                }
+                Poll::Ready(None) => {
+                    done = true;
+
+                    if !saw_stream_ended {
+                        return Poll::Ready(Some(Err(
+                            LanguageModelCompletionError::StreamEndedUnexpectedly {
+                                provider: provider.clone(),
+                            },
+                        )));
+                    }
+                }
+                Poll::Pending => return Poll::Pending,
+            }
+        }
+    })
+    .boxed()
+}
+
+pub fn provider_name(
+    provider: &cloud_llm_client::LanguageModelProvider,
+) -> LanguageModelProviderName {
+    match provider {
+        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
+        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
+        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
+        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
+    }
+}
+
+pub fn response_lines<T: DeserializeOwned>(
+    response: Response<AsyncBody>,
+    includes_status_messages: bool,
+) -> impl Stream<Item = Result<CompletionEvent<T>>> {
+    futures::stream::try_unfold(
+        (String::new(), BufReader::new(response.into_body())),
+        move |(mut line, mut body)| async move {
+            match body.read_line(&mut line).await {
+                Ok(0) => Ok(None),
+                Ok(_) => {
+                    let event = if includes_status_messages {
+                        serde_json::from_str::<CompletionEvent<T>>(&line)?
+                    } else {
+                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
+                    };
+
+                    line.clear();
+                    Ok(Some((event, (line, body))))
+                }
+                Err(e) => Err(e.into()),
+            }
+        },
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use http_client::http::{HeaderMap, StatusCode};
+    use language_model::LanguageModelCompletionError;
+
+    #[test]
+    fn test_api_error_conversion_with_upstream_http_error() {
+        // upstream_http_error with an upstream 503 is surfaced as UpstreamProviderError
+        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
+                assert_eq!(
+                    message,
+                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
+                );
+            }
+            _ => panic!(
+                "Expected UpstreamProviderError for upstream 503, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // upstream_http_error with an upstream 500 is likewise surfaced as UpstreamProviderError
+        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
+                assert_eq!(
+                    message,
+                    "Received an error from the OpenAI API: internal server error"
+                );
+            }
+            _ => panic!(
+                "Expected UpstreamProviderError for upstream 500, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // upstream_http_error with an upstream 429 is also surfaced as UpstreamProviderError
+        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
+                assert_eq!(
+                    message,
+                    "Received an error from the Google API: rate limit exceeded"
+                );
+            }
+            _ => panic!(
+                "Expected UpstreamProviderError for upstream 429, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
+        let error_body = "Regular internal server error";
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
+                assert_eq!(provider, PROVIDER_NAME);
+                assert_eq!(message, "Regular internal server error");
+            }
+            _ => panic!(
+                "Expected ApiInternalServerError for regular 500, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // upstream_http_429 format should be converted to UpstreamProviderError
+        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::UpstreamProviderError {
+                message,
+                status,
+                retry_after,
+            } => {
+                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
+                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
+                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
+            }
+            _ => panic!(
+                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // Invalid JSON in error body should fall back to regular error handling
+        let error_body = "Not JSON at all";
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::ApiInternalServerError { provider, ..
} => { + assert_eq!(provider, PROVIDER_NAME); + } + _ => panic!( + "Expected ApiInternalServerError for invalid JSON, got: {:?}", + completion_error + ), + } + } +} diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml index 3de3a4dc3fcb8c9519f4c67be7cead75401f6281..9a73e73196fa225691fa68e2ca839a19783bc3ca 100644 --- a/crates/open_ai/Cargo.toml +++ b/crates/open_ai/Cargo.toml @@ -17,13 +17,18 @@ schemars = ["dep:schemars"] [dependencies] anyhow.workspace = true +collections.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true rand.workspace = true schemars = { workspace = true, optional = true } log.workspace = true serde.workspace = true serde_json.workspace = true -settings.workspace = true strum.workspace = true thiserror.workspace = true +tiktoken-rs.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/open_ai/src/completion.rs b/crates/open_ai/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..81fa79d35ee134ef4fee7618aec17d34e9382cec --- /dev/null +++ b/crates/open_ai/src/completion.rs @@ -0,0 +1,1693 @@ +use anyhow::{Result, anyhow}; +use collections::HashMap; +use futures::{Stream, StreamExt}; +use language_model_core::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice, + LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, + Role, StopReason, TokenUsage, + util::{fix_streamed_json, parse_tool_arguments}, +}; +use std::pin::Pin; +use std::sync::Arc; + +use crate::responses::{ + Request as ResponseRequest, ResponseFunctionCallItem, ResponseFunctionCallOutputContent, + ResponseFunctionCallOutputItem, ResponseInputContent, ResponseInputItem, ResponseMessageItem, + ResponseOutputItem, ResponseSummary as ResponsesSummary, ResponseUsage as ResponsesUsage, + StreamEvent as ResponsesStreamEvent, +}; +use crate::{ + FunctionContent, FunctionDefinition, ImageUrl, MessagePart, Model, ReasoningEffort, + ResponseStreamEvent, ToolCall, ToolCallContent, +}; + +pub fn into_open_ai( + request: LanguageModelRequest, + model_id: &str, + supports_parallel_tool_calls: bool, + supports_prompt_cache_key: bool, + max_output_tokens: Option, + reasoning_effort: Option, +) -> crate::Request { + let stream = !model_id.starts_with("o1-"); + + let mut messages = Vec::new(); + for message in request.messages { + for content in message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. 
} => { + let should_add = if message.role == Role::User { + // Including whitespace-only user messages can cause error with OpenAI compatible APIs + // See https://github.com/zed-industries/zed/issues/40097 + !text.trim().is_empty() + } else { + !text.is_empty() + }; + if should_add { + add_message_content_part( + MessagePart::Text { text }, + message.role, + &mut messages, + ); + } + } + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(image) => { + add_message_content_part( + MessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + detail: None, + }, + }, + message.role, + &mut messages, + ); + } + MessageContent::ToolUse(tool_use) => { + let tool_call = ToolCall { + id: tool_use.id.to_string(), + content: ToolCallContent::Function { + function: FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + }, + }, + }; + + if let Some(crate::RequestMessage::Assistant { tool_calls, .. }) = + messages.last_mut() + { + tool_calls.push(tool_call); + } else { + messages.push(crate::RequestMessage::Assistant { + content: None, + tool_calls: vec![tool_call], + }); + } + } + MessageContent::ToolResult(tool_result) => { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + vec![MessagePart::Text { + text: text.to_string(), + }] + } + LanguageModelToolResultContent::Image(image) => { + vec![MessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + detail: None, + }, + }] + } + }; + + messages.push(crate::RequestMessage::Tool { + content: content.into(), + tool_call_id: tool_result.tool_use_id.to_string(), + }); + } + } + } + } + + crate::Request { + model: model_id.into(), + messages, + stream, + stream_options: if stream { + Some(crate::StreamOptions::default()) + } else { + None + }, + stop: request.stop, + temperature: request.temperature.or(Some(1.0)), + max_completion_tokens: max_output_tokens, + parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() { + Some(supports_parallel_tool_calls) + } else { + None + }, + prompt_cache_key: if supports_prompt_cache_key { + request.thread_id + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| crate::ToolDefinition::Function { + function: FunctionDefinition { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => crate::ToolChoice::Auto, + LanguageModelToolChoice::Any => crate::ToolChoice::Required, + LanguageModelToolChoice::None => crate::ToolChoice::None, + }), + reasoning_effort, + } +} + +pub fn into_open_ai_response( + request: LanguageModelRequest, + model_id: &str, + supports_parallel_tool_calls: bool, + supports_prompt_cache_key: bool, + max_output_tokens: Option, + reasoning_effort: Option, +) -> ResponseRequest { + let stream = !model_id.starts_with("o1-"); + + let LanguageModelRequest { + thread_id, + prompt_id: _, + intent: _, + messages, + tools, + tool_choice, + stop: _, + temperature, + thinking_allowed: _, + thinking_effort: _, + speed: _, + } = request; + + let mut input_items = Vec::new(); + for (index, message) in messages.into_iter().enumerate() { + append_message_to_response_items(message, index, &mut input_items); + } + + let tools: Vec<_> = tools + .into_iter() + .map(|tool| crate::responses::ToolDefinition::Function { + name: tool.name, + description: 
Some(tool.description), + parameters: Some(tool.input_schema), + strict: None, + }) + .collect(); + + ResponseRequest { + model: model_id.into(), + input: input_items, + stream, + temperature, + top_p: None, + max_output_tokens, + parallel_tool_calls: if tools.is_empty() { + None + } else { + Some(supports_parallel_tool_calls) + }, + tool_choice: tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => crate::ToolChoice::Auto, + LanguageModelToolChoice::Any => crate::ToolChoice::Required, + LanguageModelToolChoice::None => crate::ToolChoice::None, + }), + tools, + prompt_cache_key: if supports_prompt_cache_key { + thread_id + } else { + None + }, + reasoning: reasoning_effort.map(|effort| crate::responses::ReasoningConfig { + effort, + summary: Some(crate::responses::ReasoningSummaryMode::Auto), + }), + } +} + +fn append_message_to_response_items( + message: LanguageModelRequestMessage, + index: usize, + input_items: &mut Vec, +) { + let mut content_parts: Vec = Vec::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + push_response_text_part(&message.role, text, &mut content_parts); + } + MessageContent::Thinking { text, .. } => { + push_response_text_part(&message.role, text, &mut content_parts); + } + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(image) => { + push_response_image_part(&message.role, image, &mut content_parts); + } + MessageContent::ToolUse(tool_use) => { + flush_response_parts(&message.role, index, &mut content_parts, input_items); + let call_id = tool_use.id.to_string(); + input_items.push(ResponseInputItem::FunctionCall(ResponseFunctionCallItem { + call_id, + name: tool_use.name.to_string(), + arguments: tool_use.raw_input, + })); + } + MessageContent::ToolResult(tool_result) => { + flush_response_parts(&message.role, index, &mut content_parts, input_items); + input_items.push(ResponseInputItem::FunctionCallOutput( + ResponseFunctionCallOutputItem { + call_id: tool_result.tool_use_id.to_string(), + output: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ResponseFunctionCallOutputContent::Text(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ResponseFunctionCallOutputContent::List(vec![ + ResponseInputContent::Image { + image_url: image.to_base64_url(), + }, + ]) + } + }, + }, + )); + } + } + } + + flush_response_parts(&message.role, index, &mut content_parts, input_items); +} + +fn push_response_text_part( + role: &Role, + text: impl Into, + parts: &mut Vec, +) { + let text = text.into(); + if text.trim().is_empty() { + return; + } + + match role { + Role::Assistant => parts.push(ResponseInputContent::OutputText { + text, + annotations: Vec::new(), + }), + _ => parts.push(ResponseInputContent::Text { text }), + } +} + +fn push_response_image_part( + role: &Role, + image: LanguageModelImage, + parts: &mut Vec, +) { + match role { + Role::Assistant => parts.push(ResponseInputContent::OutputText { + text: "[image omitted]".to_string(), + annotations: Vec::new(), + }), + _ => parts.push(ResponseInputContent::Image { + image_url: image.to_base64_url(), + }), + } +} + +fn flush_response_parts( + role: &Role, + _index: usize, + parts: &mut Vec, + input_items: &mut Vec, +) { + if parts.is_empty() { + return; + } + + let item = ResponseInputItem::Message(ResponseMessageItem { + role: match role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Assistant, + Role::System => crate::Role::System, + }, + content: 
parts.clone(), + }); + + input_items.push(item); + parts.clear(); +} + +fn add_message_content_part( + new_part: MessagePart, + role: Role, + messages: &mut Vec, +) { + match (role, messages.last_mut()) { + (Role::User, Some(crate::RequestMessage::User { content })) + | ( + Role::Assistant, + Some(crate::RequestMessage::Assistant { + content: Some(content), + .. + }), + ) + | (Role::System, Some(crate::RequestMessage::System { content, .. })) => { + content.push_part(new_part); + } + _ => { + messages.push(match role { + Role::User => crate::RequestMessage::User { + content: crate::MessageContent::from(vec![new_part]), + }, + Role::Assistant => crate::RequestMessage::Assistant { + content: Some(crate::MessageContent::from(vec![new_part])), + tool_calls: Vec::new(), + }, + Role::System => crate::RequestMessage::System { + content: crate::MessageContent::from(vec![new_part]), + }, + }); + } + } +} + +pub struct OpenAiEventMapper { + tool_calls_by_index: HashMap, +} + +impl OpenAiEventMapper { + pub fn new() -> Self { + Self { + tool_calls_by_index: HashMap::default(), + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], + }) + }) + } + + pub fn map_event( + &mut self, + event: ResponseStreamEvent, + ) -> Vec> { + let mut events = Vec::new(); + if let Some(usage) = event.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }))); + } + + let Some(choice) = event.choices.first() else { + return events; + }; + + if let Some(delta) = choice.delta.as_ref() { + if let Some(reasoning_content) = delta.reasoning_content.clone() { + if !reasoning_content.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text: reasoning_content, + signature: None, + })); + } + } + if let Some(content) = delta.content.clone() { + if !content.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + } + + if let Some(tool_calls) = delta.tool_calls.as_ref() { + for tool_call in tool_calls { + let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; + } + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); + } + } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &fix_streamed_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))); + } + } + } + } + } + + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + Some("tool_calls") => { + events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { + match parse_tool_arguments(&tool_call.arguments) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: 
tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + thought_signature: None, + }, + )), + Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_call.id.into(), + tool_name: tool_call.name.into(), + raw_input: tool_call.arguments.clone().into(), + json_parse_error: error.to_string(), + }), + } + })); + + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + Some(stop_reason) => { + log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + None => {} + } + + events + } +} + +#[derive(Default)] +struct RawToolCall { + id: String, + name: String, + arguments: String, +} + +pub struct OpenAiResponseEventMapper { + function_calls_by_item: HashMap, + pending_stop_reason: Option, +} + +#[derive(Default)] +struct PendingResponseFunctionCall { + call_id: String, + name: Arc, + arguments: String, +} + +impl OpenAiResponseEventMapper { + pub fn new() -> Self { + Self { + function_calls_by_item: HashMap::default(), + pending_stop_reason: None, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], + }) + }) + } + + pub fn map_event( + &mut self, + event: ResponsesStreamEvent, + ) -> Vec> { + match event { + ResponsesStreamEvent::OutputItemAdded { item, .. } => { + let mut events = Vec::new(); + + match &item { + ResponseOutputItem::Message(message) => { + if let Some(id) = &message.id { + events.push(Ok(LanguageModelCompletionEvent::StartMessage { + message_id: id.clone(), + })); + } + } + ResponseOutputItem::FunctionCall(function_call) => { + if let Some(item_id) = function_call.id.clone() { + let call_id = function_call + .call_id + .clone() + .or_else(|| function_call.id.clone()) + .unwrap_or_else(|| item_id.clone()); + let entry = PendingResponseFunctionCall { + call_id, + name: Arc::::from( + function_call.name.clone().unwrap_or_default(), + ), + arguments: function_call.arguments.clone(), + }; + self.function_calls_by_item.insert(item_id, entry); + } + } + ResponseOutputItem::Reasoning(_) | ResponseOutputItem::Unknown => {} + } + events + } + ResponsesStreamEvent::ReasoningSummaryTextDelta { delta, .. } => { + if delta.is_empty() { + Vec::new() + } else { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: delta, + signature: None, + })] + } + } + ResponsesStreamEvent::OutputTextDelta { delta, .. } => { + if delta.is_empty() { + Vec::new() + } else { + vec![Ok(LanguageModelCompletionEvent::Text(delta))] + } + } + ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { + if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { + entry.arguments.push_str(&delta); + if let Ok(input) = serde_json::from_str::( + &fix_streamed_json(&entry.arguments), + ) { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + name: entry.name.clone(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))]; + } + } + Vec::new() + } + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id, arguments, .. 
+ } => { + if let Some(mut entry) = self.function_calls_by_item.remove(&item_id) { + if !arguments.is_empty() { + entry.arguments = arguments; + } + let raw_input = entry.arguments.clone(); + self.pending_stop_reason = Some(StopReason::ToolUse); + match parse_tool_arguments(&entry.arguments) { + Ok(input) => { + vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + name: entry.name.clone(), + is_input_complete: true, + input, + raw_input, + thought_signature: None, + }, + ))] + } + Err(error) => { + vec![Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + tool_name: entry.name.clone(), + raw_input: Arc::::from(raw_input), + json_parse_error: error.to_string(), + })] + } + } + } else { + Vec::new() + } + } + ResponsesStreamEvent::Completed { response } => { + self.handle_completion(response, StopReason::EndTurn) + } + ResponsesStreamEvent::Incomplete { response } => { + let reason = response + .status_details + .as_ref() + .and_then(|details| details.reason.as_deref()); + let stop_reason = match reason { + Some("max_output_tokens") => StopReason::MaxTokens, + Some("content_filter") => { + self.pending_stop_reason = Some(StopReason::Refusal); + StopReason::Refusal + } + _ => self + .pending_stop_reason + .take() + .unwrap_or(StopReason::EndTurn), + }; + + let mut events = Vec::new(); + if self.pending_stop_reason.is_none() { + events.extend(self.emit_tool_calls_from_output(&response.output)); + } + if let Some(usage) = response.usage.as_ref() { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + token_usage_from_response_usage(usage), + ))); + } + events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); + events + } + ResponsesStreamEvent::Failed { response } => { + let message = response + .status_details + .and_then(|details| details.error) + .map(|error| error.to_string()) + .unwrap_or_else(|| "response failed".to_string()); + vec![Err(LanguageModelCompletionError::Other(anyhow!(message)))] + } + ResponsesStreamEvent::Error { error } + | ResponsesStreamEvent::GenericError { error } => { + vec![Err(LanguageModelCompletionError::Other(anyhow!( + error.message + )))] + } + ResponsesStreamEvent::ReasoningSummaryPartAdded { summary_index, .. } => { + if summary_index > 0 { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "\n\n".to_string(), + signature: None, + })] + } else { + Vec::new() + } + } + ResponsesStreamEvent::OutputTextDone { .. } + | ResponsesStreamEvent::OutputItemDone { .. } + | ResponsesStreamEvent::ContentPartAdded { .. } + | ResponsesStreamEvent::ContentPartDone { .. } + | ResponsesStreamEvent::ReasoningSummaryTextDone { .. } + | ResponsesStreamEvent::ReasoningSummaryPartDone { .. } + | ResponsesStreamEvent::Created { .. } + | ResponsesStreamEvent::InProgress { .. 
+            | ResponsesStreamEvent::InProgress { .. }
+            | ResponsesStreamEvent::Unknown => Vec::new(),
+        }
+    }
+
+    fn handle_completion(
+        &mut self,
+        response: ResponseSummary,
+        default_reason: StopReason,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let mut events = Vec::new();
+
+        if self.pending_stop_reason.is_none() {
+            events.extend(self.emit_tool_calls_from_output(&response.output));
+        }
+
+        if let Some(usage) = response.usage.as_ref() {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                token_usage_from_response_usage(usage),
+            )));
+        }
+
+        let stop_reason = self.pending_stop_reason.take().unwrap_or(default_reason);
+        events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason)));
+        events
+    }
+
+    fn emit_tool_calls_from_output(
+        &mut self,
+        output: &[ResponseOutputItem],
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let mut events = Vec::new();
+        for item in output {
+            if let ResponseOutputItem::FunctionCall(function_call) = item {
+                let Some(call_id) = function_call
+                    .call_id
+                    .clone()
+                    .or_else(|| function_call.id.clone())
+                else {
+                    log::error!(
+                        "Function call item missing both call_id and id: {:?}",
+                        function_call
+                    );
+                    continue;
+                };
+                let name: Arc<str> = Arc::from(function_call.name.clone().unwrap_or_default());
+                let arguments = &function_call.arguments;
+                self.pending_stop_reason = Some(StopReason::ToolUse);
+                match parse_tool_arguments(arguments) {
+                    Ok(input) => {
+                        events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: LanguageModelToolUseId::from(call_id.clone()),
+                                name: name.clone(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: arguments.clone(),
+                                thought_signature: None,
+                            },
+                        )));
+                    }
+                    Err(error) => {
+                        events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                            id: LanguageModelToolUseId::from(call_id.clone()),
+                            tool_name: name.clone(),
+                            raw_input: Arc::<str>::from(arguments.clone()),
+                            json_parse_error: error.to_string(),
+                        }));
+                    }
+                }
+            }
+        }
+        events
+    }
+}
+
+fn token_usage_from_response_usage(usage: &ResponseUsage) -> TokenUsage {
+    TokenUsage {
+        input_tokens: usage.input_tokens.unwrap_or_default(),
+        output_tokens: usage.output_tokens.unwrap_or_default(),
+        cache_creation_input_tokens: 0,
+        cache_read_input_tokens: 0,
+    }
+}
+
+pub fn collect_tiktoken_messages(
+    request: LanguageModelRequest,
+) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
+    request
+        .messages
+        .into_iter()
+        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+            role: match message.role {
+                Role::User => "user".into(),
+                Role::Assistant => "assistant".into(),
+                Role::System => "system".into(),
+            },
+            content: Some(message.string_contents()),
+            name: None,
+            function_call: None,
+        })
+        .collect::<Vec<_>>()
+}
+
+/// Count tokens for an OpenAI model. This is synchronous; callers should spawn
+/// it on a background thread if needed.
+pub fn count_open_ai_tokens(request: LanguageModelRequest, model: Model) -> Result<u64> {
+    let messages = collect_tiktoken_messages(request);
+    match model {
+        Model::Custom { max_tokens, .. } => {
+            let model = if max_tokens >= 100_000 {
+                // If the max tokens is 100k or more, it likely uses the o200k_base tokenizer.
+                "gpt-4o"
+            } else {
+                // Otherwise fall back to gpt-4, since only cl100k_base and o200k_base are
+                // supported with this tiktoken method.
+                "gpt-4"
+            };
+            tiktoken_rs::num_tokens_from_messages(model, &messages)
+        }
+        // Currently supported by tiktoken_rs. Sometimes tiktoken-rs is behind on model
+        // support; if that is the case, add a new match arm with an override. We enumerate
+        // all supported models here so that we can check whether new models are supported yet.
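+        //
+        // For example, if tiktoken-rs did not yet know about a (hypothetical)
+        // `Model::SixPreview`, an override arm could map it onto the closest
+        // known tokenizer:
+        //
+        //     Model::SixPreview => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),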
+        Model::ThreePointFiveTurbo
+        | Model::Four
+        | Model::FourTurbo
+        | Model::FourOmniMini
+        | Model::FourPointOneNano
+        | Model::O1
+        | Model::O3
+        | Model::O3Mini
+        | Model::Five
+        | Model::FiveCodex
+        | Model::FiveMini
+        | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
+        // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated
+        // tiktoken support yet; use the gpt-5 tokenizer for them.
+        Model::FivePointOne
+        | Model::FivePointTwo
+        | Model::FivePointTwoCodex
+        | Model::FivePointThreeCodex
+        | Model::FivePointFour
+        | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),
+    }
+    .map(|tokens| tokens as u64)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::responses::{
+        ReasoningSummaryPart, ResponseFunctionToolCall, ResponseOutputItem, ResponseOutputMessage,
+        ResponseReasoningItem, ResponseStatusDetails, ResponseSummary, ResponseUsage,
+        StreamEvent as ResponsesStreamEvent,
+    };
+    use futures::{StreamExt, executor::block_on};
+    use language_model_core::{
+        LanguageModelImage, LanguageModelRequestMessage, LanguageModelRequestTool,
+        LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
+        LanguageModelToolUseId, SharedString,
+    };
+    use pretty_assertions::assert_eq;
+    use serde_json::json;
+
+    use super::*;
+
+    fn map_response_events(
+        events: Vec<ResponsesStreamEvent>,
+    ) -> Vec<LanguageModelCompletionEvent> {
+        block_on(async {
+            OpenAiResponseEventMapper::new()
+                .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
+                .collect::<Vec<_>>()
+                .await
+                .into_iter()
+                .map(Result::unwrap)
+                .collect()
+        })
+    }
+
+    fn response_item_message(id: &str) -> ResponseOutputItem {
+        ResponseOutputItem::Message(ResponseOutputMessage {
+            id: Some(id.to_string()),
+            role: Some("assistant".to_string()),
+            status: Some("in_progress".to_string()),
+            content: vec![],
+        })
+    }
+
+    fn response_item_function_call(id: &str, args: Option<&str>) -> ResponseOutputItem {
+        ResponseOutputItem::FunctionCall(ResponseFunctionToolCall {
+            id: Some(id.to_string()),
+            status: Some("in_progress".to_string()),
+            name: Some("get_weather".to_string()),
+            call_id: Some("call_123".to_string()),
+            arguments: args.map(|s| s.to_string()).unwrap_or_default(),
+        })
+    }
+
+    #[test]
+    fn tiktoken_rs_support() {
+        let request = LanguageModelRequest {
+            thread_id: None,
+            prompt_id: None,
+            intent: None,
+            messages: vec![LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![MessageContent::Text("message".into())],
+                cache: false,
+                reasoning_details: None,
+            }],
+            tools: vec![],
+            tool_choice: None,
+            stop: vec![],
+            temperature: None,
+            thinking_allowed: true,
+            thinking_effort: None,
+            speed: None,
+        };
+
+        // Validate that all models are supported by tiktoken-rs.
+        for model in <Model as strum::IntoEnumIterator>::iter() {
+            let count = count_open_ai_tokens(request.clone(), model).unwrap();
+            assert!(count > 0);
+        }
+    }
+
+    #[test]
+    fn responses_stream_maps_text_and_usage() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: response_item_message("msg_123"),
+            },
+            ResponsesStreamEvent::OutputTextDelta {
+                item_id: "msg_123".into(),
+                output_index: 0,
+                content_index: Some(0),
+                delta: "Hello".into(),
+            },
+            ResponsesStreamEvent::Completed {
+                response: ResponseSummary {
+                    usage: Some(ResponseUsage {
+                        input_tokens: Some(5),
+                        output_tokens: Some(3),
+                        total_tokens: Some(8),
+                    }),
+                    ..Default::default()
+                },
+            },
+        ];
+
+        let mapped = map_response_events(events);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_123"
+        ));
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::Text(ref text) if text == "Hello"
+        ));
+        assert!(matches!(
+            mapped[2],
+            LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: 5,
+                output_tokens: 3,
+                ..
+            })
+        ));
+        assert!(matches!(
+            mapped[3],
+            LanguageModelCompletionEvent::Stop(StopReason::EndTurn)
+        ));
+    }
+
+    #[test]
+    fn into_open_ai_response_builds_complete_payload() {
+        let tool_call_id = LanguageModelToolUseId::from("call-42");
+        let tool_input = json!({ "city": "Boston" });
+        let tool_arguments = serde_json::to_string(&tool_input).unwrap();
+        let tool_use = LanguageModelToolUse {
+            id: tool_call_id.clone(),
+            name: Arc::from("get_weather"),
+            raw_input: tool_arguments.clone(),
+            input: tool_input,
+            is_input_complete: true,
+            thought_signature: None,
+        };
+        let tool_result = LanguageModelToolResult {
+            tool_use_id: tool_call_id,
+            tool_name: Arc::from("get_weather"),
+            is_error: false,
+            content: LanguageModelToolResultContent::Text(Arc::from("Sunny")),
+            output: Some(json!({ "forecast": "Sunny" })),
+        };
+        let user_image = LanguageModelImage {
+            source: SharedString::from("aGVsbG8="),
+            size: None,
+        };
+        let expected_image_url = user_image.to_base64_url();
+
+        let request = LanguageModelRequest {
+            thread_id: Some("thread-123".into()),
+            prompt_id: None,
+            intent: None,
+            messages: vec![
+                LanguageModelRequestMessage {
+                    role: Role::System,
+                    content: vec![MessageContent::Text("System context".into())],
+                    cache: false,
+                    reasoning_details: None,
+                },
+                LanguageModelRequestMessage {
+                    role: Role::User,
+                    content: vec![
+                        MessageContent::Text("Please check the weather.".into()),
+                        MessageContent::Image(user_image),
+                    ],
+                    cache: false,
+                    reasoning_details: None,
+                },
+                LanguageModelRequestMessage {
+                    role: Role::Assistant,
+                    content: vec![
+                        MessageContent::Text("Looking that up.".into()),
+                        MessageContent::ToolUse(tool_use),
+                    ],
+                    cache: false,
+                    reasoning_details: None,
+                },
+                LanguageModelRequestMessage {
+                    role: Role::Assistant,
+                    content: vec![MessageContent::ToolResult(tool_result)],
+                    cache: false,
+                    reasoning_details: None,
+                },
+            ],
+            tools: vec![LanguageModelRequestTool {
+                name: "get_weather".into(),
+                description: "Fetches the weather".into(),
+                input_schema: json!({ "type": "object" }),
+                use_input_streaming: false,
+            }],
+            tool_choice: Some(LanguageModelToolChoice::Any),
+            stop: vec!["".into()],
+            temperature: None,
+            thinking_allowed: false,
+            thinking_effort: None,
+            speed: None,
+        };
+
+        let response = into_open_ai_response(
+            request,
+            "custom-model",
+            true,
+            true,
+            Some(2048),
+            Some(ReasoningEffort::Low),
+        );
+
+        let serialized = serde_json::to_value(&response).unwrap();
+        let expected = json!({
+            "model": "custom-model",
+            "input": [
+                {
+                    "type": "message",
+                    "role": "system",
+                    "content": [
+                        { "type": "input_text", "text": "System context" }
+                    ]
+                },
+                {
+                    "type": "message",
+                    "role": "user",
+                    "content": [
}, + { "type": "input_image", "image_url": expected_image_url } + ] + }, + { + "type": "message", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "Looking that up.", "annotations": [] } + ] + }, + { + "type": "function_call", + "call_id": "call-42", + "name": "get_weather", + "arguments": tool_arguments + }, + { + "type": "function_call_output", + "call_id": "call-42", + "output": "Sunny" + } + ], + "stream": true, + "max_output_tokens": 2048, + "parallel_tool_calls": true, + "tool_choice": "required", + "tools": [ + { + "type": "function", + "name": "get_weather", + "description": "Fetches the weather", + "parameters": { "type": "object" } + } + ], + "prompt_cache_key": "thread-123", + "reasoning": { "effort": "low", "summary": "auto" } + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn responses_stream_maps_tool_calls() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{\"city\":\"Bos")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "ton\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{\"city\":\"Boston\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: false, + .. + }) + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + ref id, + ref name, + ref raw_input, + is_input_complete: true, + .. + }) if id.to_string() == "call_123" + && name.as_ref() == "get_weather" + && raw_input == "{\"city\":\"Boston\"}" + )); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_uses_max_tokens_stop_reason() { + let events = vec![ResponsesStreamEvent::Incomplete { + response: ResponseSummary { + status_details: Some(ResponseStatusDetails { + reason: Some("max_output_tokens".into()), + r#type: Some("incomplete".into()), + error: None, + }), + usage: Some(ResponseUsage { + input_tokens: Some(10), + output_tokens: Some(20), + total_tokens: Some(30), + }), + ..Default::default() + }, + }]; + + let mapped = map_response_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: 10, + output_tokens: 20, + .. 
+            })
+        ));
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_handles_multiple_tool_calls() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: response_item_function_call("item_fn1", Some("{\"city\":\"NYC\"}")),
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDone {
+                item_id: "item_fn1".into(),
+                output_index: 0,
+                arguments: "{\"city\":\"NYC\"}".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 1,
+                sequence_number: None,
+                item: response_item_function_call("item_fn2", Some("{\"city\":\"LA\"}")),
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDone {
+                item_id: "item_fn2".into(),
+                output_index: 1,
+                arguments: "{\"city\":\"LA\"}".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::Completed {
+                response: ResponseSummary::default(),
+            },
+        ];
+
+        let mapped = map_response_events(events);
+        assert_eq!(mapped.len(), 3);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. })
+                if raw_input == "{\"city\":\"NYC\"}"
+        ));
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. })
+                if raw_input == "{\"city\":\"LA\"}"
+        ));
+        assert!(matches!(
+            mapped[2],
+            LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_handles_mixed_text_and_tool_calls() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: response_item_message("msg_123"),
+            },
+            ResponsesStreamEvent::OutputTextDelta {
+                item_id: "msg_123".into(),
+                output_index: 0,
+                content_index: Some(0),
+                delta: "Let me check that".into(),
+            },
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 1,
+                sequence_number: None,
+                item: response_item_function_call("item_fn", Some("{\"query\":\"test\"}")),
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDone {
+                item_id: "item_fn".into(),
+                output_index: 1,
+                arguments: "{\"query\":\"test\"}".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::Completed {
+                response: ResponseSummary::default(),
+            },
+        ];
+
+        let mapped = map_response_events(events);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::StartMessage { .. }
+        ));
+        assert!(
+            matches!(mapped[1], LanguageModelCompletionEvent::Text(ref text) if text == "Let me check that")
+        );
+        assert!(
+            matches!(mapped[2], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) if raw_input == "{\"query\":\"test\"}")
+        );
+        assert!(matches!(
+            mapped[3],
+            LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_handles_json_parse_error() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: response_item_function_call("item_fn", Some("{invalid json")),
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDone {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                arguments: "{invalid json".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::Completed {
+                response: ResponseSummary::default(),
+            },
+        ];
+
+        let mapped = map_response_events(events);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::ToolUseJsonParseError { ref raw_input, .. }
+                if raw_input.as_ref() == "{invalid json"
+        ));
+    }
+
+    #[test]
+    fn responses_stream_handles_incomplete_function_call() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: response_item_function_call("item_fn", Some("{\"city\":")),
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDelta {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                delta: "\"Boston\"".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::Incomplete {
+                response: ResponseSummary {
+                    status_details: Some(ResponseStatusDetails {
+                        reason: Some("max_output_tokens".into()),
+                        r#type: Some("incomplete".into()),
+                        error: None,
+                    }),
+                    output: vec![response_item_function_call(
+                        "item_fn",
+                        Some("{\"city\":\"Boston\"}"),
+                    )],
+                    ..Default::default()
+                },
+            },
+        ];
+
+        let mapped = map_response_events(events);
+        assert_eq!(mapped.len(), 3);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                is_input_complete: false,
+                ..
+            })
+        ));
+        assert!(
+            matches!(mapped[1], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, is_input_complete: true, .. }) if raw_input == "{\"city\":\"Boston\"}")
+        );
+        assert!(matches!(
+            mapped[2],
+            LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_incomplete_does_not_duplicate_tool_calls() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: response_item_function_call("item_fn", Some("{\"city\":\"Boston\"}")),
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDone {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                arguments: "{\"city\":\"Boston\"}".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::Incomplete {
+                response: ResponseSummary {
+                    status_details: Some(ResponseStatusDetails {
+                        reason: Some("max_output_tokens".into()),
+                        r#type: Some("incomplete".into()),
+                        error: None,
+                    }),
+                    output: vec![response_item_function_call(
+                        "item_fn",
+                        Some("{\"city\":\"Boston\"}"),
+                    )],
+                    ..Default::default()
+                },
+            },
+        ];
+
+        let mapped = map_response_events(events);
+        assert_eq!(mapped.len(), 2);
+        assert!(
+            matches!(mapped[0], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) if raw_input == "{\"city\":\"Boston\"}")
+        );
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_handles_empty_tool_arguments() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: response_item_function_call("item_fn", Some("")),
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDone {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                arguments: "".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::Completed {
+                response: ResponseSummary::default(),
+            },
+        ];
+
+        let mapped = map_response_events(events);
+        assert_eq!(mapped.len(), 2);
+        assert!(matches!(
+            &mapped[0],
+            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                id, name, raw_input, input, ..
+            }) if id.to_string() == "call_123"
+                && name.as_ref() == "get_weather"
+                && raw_input == ""
+                && input.is_object()
+                && input.as_object().unwrap().is_empty()
+        ));
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_emits_partial_tool_use_events() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: ResponseOutputItem::FunctionCall(
+                    crate::responses::ResponseFunctionToolCall {
+                        id: Some("item_fn".to_string()),
+                        status: Some("in_progress".to_string()),
+                        name: Some("get_weather".to_string()),
+                        call_id: Some("call_abc".to_string()),
+                        arguments: String::new(),
+                    },
+                ),
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDelta {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                delta: "{\"city\":\"Bos".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDelta {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                delta: "ton\"}".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDone {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                arguments: "{\"city\":\"Boston\"}".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::Completed {
+                response: ResponseSummary::default(),
+            },
+        ];
+
+        let mapped = map_response_events(events);
+        assert!(mapped.len() >= 3);
+
+        let complete_tool_use = mapped.iter().find(|e| {
+            matches!(
+                e,
+                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                    is_input_complete: true,
+                    ..
+                })
+            )
+        });
+        assert!(
+            complete_tool_use.is_some(),
+            "should have a complete tool use event"
+        );
+
+        let tool_uses: Vec<_> = mapped
+            .iter()
+            .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_)))
+            .collect();
+        assert!(
+            tool_uses.len() >= 2,
+            "should have at least one partial and one complete event"
+        );
+        assert!(matches!(
+            tool_uses.last().unwrap(),
+            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                is_input_complete: true,
+                ..
+            })
+        ));
+    }
+
+    #[test]
+    fn responses_stream_maps_reasoning_summary_deltas() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
+                    id: Some("rs_123".into()),
+                    summary: vec![],
+                }),
+            },
+            ResponsesStreamEvent::ReasoningSummaryPartAdded {
+                item_id: "rs_123".into(),
+                output_index: 0,
+                summary_index: 0,
+            },
+            ResponsesStreamEvent::ReasoningSummaryTextDelta {
+                item_id: "rs_123".into(),
+                output_index: 0,
+                delta: "Thinking about".into(),
+            },
+            ResponsesStreamEvent::ReasoningSummaryTextDelta {
+                item_id: "rs_123".into(),
+                output_index: 0,
+                delta: " the answer".into(),
+            },
+            ResponsesStreamEvent::ReasoningSummaryTextDone {
+                item_id: "rs_123".into(),
+                output_index: 0,
+                text: "Thinking about the answer".into(),
+            },
+            ResponsesStreamEvent::ReasoningSummaryPartDone {
+                item_id: "rs_123".into(),
+                output_index: 0,
+                summary_index: 0,
+            },
+            ResponsesStreamEvent::ReasoningSummaryPartAdded {
+                item_id: "rs_123".into(),
+                output_index: 0,
+                summary_index: 1,
+            },
+            ResponsesStreamEvent::ReasoningSummaryTextDelta {
+                item_id: "rs_123".into(),
+                output_index: 0,
+                delta: "Second part".into(),
+            },
+            ResponsesStreamEvent::ReasoningSummaryTextDone {
+                item_id: "rs_123".into(),
+                output_index: 0,
+                text: "Second part".into(),
+            },
+            ResponsesStreamEvent::ReasoningSummaryPartDone {
+                item_id: "rs_123".into(),
+                output_index: 0,
+                summary_index: 1,
+            },
+            ResponsesStreamEvent::OutputItemDone {
+                output_index: 0,
+                sequence_number: None,
+                item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
+                    id: Some("rs_123".into()),
+                    summary: vec![
+                        ReasoningSummaryPart::SummaryText {
+                            text: "Thinking about the answer".into(),
+                        },
+                        ReasoningSummaryPart::SummaryText {
+                            text: "Second part".into(),
+                        },
+                    ],
+                }),
+            },
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 1,
+                sequence_number: None,
+                item: response_item_message("msg_456"),
+            },
+            ResponsesStreamEvent::OutputTextDelta {
+                item_id: "msg_456".into(),
+                output_index: 1,
+                content_index: Some(0),
+                delta: "The answer is 42".into(),
+            },
+            ResponsesStreamEvent::Completed {
+                response: ResponseSummary::default(),
+            },
+        ];
+
+        let mapped = map_response_events(events);
+
+        let thinking_events: Vec<_> = mapped
+            .iter()
+            .filter(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. }))
+            .collect();
+        assert_eq!(
+            thinking_events.len(),
+            4,
+            "expected 4 thinking events, got {:?}",
+            thinking_events
+        );
+        assert!(
+            matches!(&thinking_events[0], LanguageModelCompletionEvent::Thinking { text, .. } if text == "Thinking about")
+        );
+        assert!(
+            matches!(&thinking_events[1], LanguageModelCompletionEvent::Thinking { text, .. } if text == " the answer")
+        );
+        assert!(
+            matches!(&thinking_events[2], LanguageModelCompletionEvent::Thinking { text, .. } if text == "\n\n"),
+            "expected separator between summary parts"
+        );
+        assert!(
} if text == "Second part") + ); + + assert!(mapped.iter().any( + |e| matches!(e, LanguageModelCompletionEvent::Text(t) if t == "The answer is 42") + )); + } + + #[test] + fn responses_stream_maps_reasoning_from_done_only() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_789".into()), + summary: vec![], + }), + }, + ResponsesStreamEvent::OutputItemDone { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_789".into()), + summary: vec![ReasoningSummaryPart::SummaryText { + text: "Summary without deltas".into(), + }], + }), + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert!( + !mapped + .iter() + .any(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })), + "OutputItemDone reasoning should not produce Thinking events" + ); + } +} diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index c4a3e078d76eb028b90e5b80fe95b1281b795f34..5423d9c5dcaa13589a8a7d658548b42fd467f67f 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,4 +1,5 @@ pub mod batches; +pub mod completion; pub mod responses; use anyhow::{Context as _, Result, anyhow}; @@ -7,9 +8,9 @@ use http_client::{ AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode, http::{HeaderMap, HeaderValue}, }; +pub use language_model_core::ReasoningEffort; use serde::{Deserialize, Serialize}; use serde_json::Value; -pub use settings::OpenAiReasoningEffort as ReasoningEffort; use std::{convert::TryFrom, future::Future}; use strum::EnumIter; use thiserror::Error; @@ -717,3 +718,26 @@ pub fn embed<'a>( Ok(response) } } + +// -- Conversions to `language_model_core` types -- + +impl From for language_model_core::LanguageModelCompletionError { + fn from(error: RequestError) -> Self { + match error { + RequestError::HttpResponseError { + provider, + status_code, + body, + headers, + } => { + let retry_after = headers + .get(http_client::http::header::RETRY_AFTER) + .and_then(|val| val.to_str().ok()?.parse::().ok()) + .map(std::time::Duration::from_secs); + + Self::from_http_status(provider.into(), status_code, body, retry_after) + } + RequestError::Other(e) => Self::Other(e), + } + } +} diff --git a/crates/open_router/Cargo.toml b/crates/open_router/Cargo.toml index cccb92c33b05b8fff0e5e78277c9f7fa29844ace..2cc5d3d00e2eb5d755cef971be51a315bcdf254f 100644 --- a/crates/open_router/Cargo.toml +++ b/crates/open_router/Cargo.toml @@ -19,6 +19,7 @@ schemars = ["dep:schemars"] anyhow.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs index 9841c7b1ae19a57878fd8e84625bc4058b809613..b94631f9a0e6764ab5cfe487e7851a820fa80b1d 100644 --- a/crates/open_router/src/open_router.rs +++ b/crates/open_router/src/open_router.rs @@ -744,3 +744,71 @@ impl ApiErrorCode { } } } + +// -- Conversions to `language_model_core` types -- + +impl From for language_model_core::LanguageModelCompletionError { + fn from(error: OpenRouterError) -> Self { + let provider = language_model_core::LanguageModelProviderName::new("OpenRouter"); + match error { + 
+            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
+            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
+            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
+            OpenRouterError::DeserializeResponse(error) => {
+                Self::DeserializeResponse { provider, error }
+            }
+            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
+                provider,
+                retry_after: Some(retry_after),
+            },
+            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
+                provider,
+                retry_after,
+            },
+            OpenRouterError::ApiError(api_error) => api_error.into(),
+        }
+    }
+}
+
+impl From<ApiError> for language_model_core::LanguageModelCompletionError {
+    fn from(error: ApiError) -> Self {
+        use ApiErrorCode::*;
+        let provider = language_model_core::LanguageModelProviderName::new("OpenRouter");
+        match error.code {
+            InvalidRequestError => Self::BadRequestFormat {
+                provider,
+                message: error.message,
+            },
+            AuthenticationError => Self::AuthenticationError {
+                provider,
+                message: error.message,
+            },
+            PaymentRequiredError => Self::AuthenticationError {
+                provider,
+                message: format!("Payment required: {}", error.message),
+            },
+            PermissionError => Self::PermissionError {
+                provider,
+                message: error.message,
+            },
+            RequestTimedOut => Self::HttpResponseError {
+                provider,
+                status_code: http_client::StatusCode::REQUEST_TIMEOUT,
+                message: error.message,
+            },
+            RateLimitError => Self::RateLimitExceeded {
+                provider,
+                retry_after: None,
+            },
+            ApiError => Self::ApiInternalServerError {
+                provider,
+                message: error.message,
+            },
+            OverloadedError => Self::ServerOverloaded {
+                provider,
+                retry_after: None,
+            },
+        }
+    }
+}
diff --git a/crates/project/src/prettier_store.rs b/crates/project/src/prettier_store.rs
index b66f2d5e0c041e104cf109a48b6bad249b492b88..faa2cca79866f31682a497eebab819b75e778ffb 100644
--- a/crates/project/src/prettier_store.rs
+++ b/crates/project/src/prettier_store.rs
@@ -412,7 +412,7 @@ impl PrettierStore {
             prettier_store
                 .update(cx, |prettier_store, cx| {
                     let name = if is_default {
-                        LanguageServerName("prettier (default)".to_string().into())
+                        LanguageServerName("prettier (default)".into())
                     } else {
                         let worktree_path = worktree_id
                             .and_then(|id| {
diff --git a/crates/settings_content/Cargo.toml b/crates/settings_content/Cargo.toml
index b3599e9eef3b7ac5680f441369a7cbdc98a5d043..59cccb4167ed64a2ece8ae5a73ac570ca7dabd97 100644
--- a/crates/settings_content/Cargo.toml
+++ b/crates/settings_content/Cargo.toml
@@ -19,6 +19,7 @@ anyhow.workspace = true
 collections.workspace = true
 derive_more.workspace = true
 gpui.workspace = true
+language_model_core.workspace = true
 log.workspace = true
 schemars.workspace = true
 serde.workspace = true
diff --git a/crates/settings_content/src/language_model.rs b/crates/settings_content/src/language_model.rs
index 4b72c2ad3f47d834dfa38555d80a8646e3940f51..00ecf42537459496102495c51628b54405968214 100644
--- a/crates/settings_content/src/language_model.rs
+++ b/crates/settings_content/src/language_model.rs
@@ -1,8 +1,8 @@
+use crate::merge_from::MergeFrom;
 use collections::HashMap;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings_macros::{MergeFrom, with_fallible_options};
-use strum::EnumString;
 
 use std::sync::Arc;
@@ -237,15 +237,12 @@ pub struct OpenAiAvailableModel {
     pub capabilities: OpenAiModelCapabilities,
 }
 
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, EnumString, JsonSchema, MergeFrom)]
-#[serde(rename_all = "lowercase")]
-#[strum(serialize_all = "lowercase")]
-pub enum OpenAiReasoningEffort {
-    Minimal,
-    Low,
-    Medium,
-    High,
-    XHigh,
+pub use language_model_core::ReasoningEffort as OpenAiReasoningEffort;
+
+impl MergeFrom for OpenAiReasoningEffort {
+    fn merge_from(&mut self, other: &Self) {
+        *self = *other;
+    }
 }
@@ -479,15 +476,10 @@ pub struct LanguageModelCacheConfiguration {
     pub min_total_token: u64,
 }
 
-#[derive(
-    Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, MergeFrom,
-)]
-#[serde(tag = "type", rename_all = "lowercase")]
-pub enum ModelMode {
-    #[default]
-    Default,
-    Thinking {
-        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
-        budget_tokens: Option<u32>,
-    },
+pub use language_model_core::ModelMode;
+
+impl MergeFrom for ModelMode {
+    fn merge_from(&mut self, other: &Self) {
+        *self = *other;
+    }
 }
diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml
index ff264edcb150063237c633de746b2f6b9f6f250c..e2bbc1aeb2dd5718596b905788b4a88826357401 100644
--- a/crates/web_search_providers/Cargo.toml
+++ b/crates/web_search_providers/Cargo.toml
@@ -14,6 +14,7 @@ path = "src/web_search_providers.rs"
 [dependencies]
 anyhow.workspace = true
 client.workspace = true
+cloud_api_client.workspace = true
 cloud_api_types.workspace = true
 cloud_llm_client.workspace = true
 futures.workspace = true
diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs
index 11227d8fb5c7152dc5b7e03b95fadea6cb714717..16707003c49921bce6244b69d0e7387f935ed8e1 100644
--- a/crates/web_search_providers/src/cloud.rs
+++ b/crates/web_search_providers/src/cloud.rs
@@ -2,12 +2,12 @@ use std::sync::Arc;
 
 use anyhow::{Context as _, Result};
 use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
+use cloud_api_client::LlmApiToken;
 use cloud_api_types::OrganizationId;
 use cloud_llm_client::{WebSearchBody, WebSearchResponse};
 use futures::AsyncReadExt as _;
 use gpui::{App, AppContext, Context, Entity, Task};
 use http_client::{HttpClient, Method};
-use language_model::LlmApiToken;
 use web_search::{WebSearchProvider, WebSearchProviderId};
 
 pub struct CloudWebSearchProvider {
diff --git a/crates/x_ai/Cargo.toml b/crates/x_ai/Cargo.toml
index 8ff020df8c1ccaf284157d8b46ddaa0e678b3cd7..2d1c9d0ecebeb8a1e0965b0ac914603b41383f00 100644
--- a/crates/x_ai/Cargo.toml
+++ b/crates/x_ai/Cargo.toml
@@ -17,6 +17,8 @@ schemars = ["dep:schemars"]
 
 [dependencies]
 anyhow.workspace = true
+language_model_core.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 strum.workspace = true
+tiktoken-rs.workspace = true
diff --git a/crates/x_ai/src/completion.rs b/crates/x_ai/src/completion.rs
new file mode 100644
index 0000000000000000000000000000000000000000..aad03d227eb82768c972283f7e1617ea7486f22f
--- /dev/null
+++ b/crates/x_ai/src/completion.rs
@@ -0,0 +1,30 @@
+use anyhow::Result;
+use language_model_core::{LanguageModelRequest, Role};
+
+use crate::Model;
+
+/// Count tokens for an xAI model using tiktoken. This is synchronous;
+/// callers should spawn it on a background thread if needed.
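+///
+/// Illustrative usage (the model variant named here is hypothetical):
+/// ```ignore
+/// let tokens = count_xai_tokens(request, Model::Grok3)?;
+/// ```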
+pub fn count_xai_tokens(request: LanguageModelRequest, model: Model) -> Result<u64> {
+    let messages = request
+        .messages
+        .into_iter()
+        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+            role: match message.role {
+                Role::User => "user".into(),
+                Role::Assistant => "assistant".into(),
+                Role::System => "system".into(),
+            },
+            content: Some(message.string_contents()),
+            name: None,
+            function_call: None,
+        })
+        .collect::<Vec<_>>();
+
+    let model_name = if model.max_token_count() >= 100_000 {
+        "gpt-4o"
+    } else {
+        "gpt-4"
+    };
+    tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
+}
diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs
index 1abb2b53771fa1e29e2979560e9f394744b26158..fd141a1723a28d235311d5d875bf4cc0388cab61 100644
--- a/crates/x_ai/src/x_ai.rs
+++ b/crates/x_ai/src/x_ai.rs
@@ -1,3 +1,5 @@
+pub mod completion;
+
 use anyhow::Result;
 use serde::{Deserialize, Serialize};
 use strum::EnumIter;