Detailed changes
@@ -629,13 +629,17 @@ version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
+ "collections",
"futures 0.3.32",
"http_client",
+ "language_model_core",
+ "log",
"schemars",
"serde",
"serde_json",
"strum 0.27.2",
"thiserror 2.0.17",
+ "tiktoken-rs",
]
[[package]]
@@ -2903,7 +2907,6 @@ dependencies = [
"http_client",
"http_client_tls",
"httparse",
- "language_model",
"log",
"objc2-foundation",
"parking_lot",
@@ -2959,6 +2962,7 @@ dependencies = [
"http_client",
"parking_lot",
"serde_json",
+ "smol",
"thiserror 2.0.17",
"yawc",
]
@@ -5162,6 +5166,7 @@ dependencies = [
"buffer_diff",
"client",
"clock",
+ "cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"collections",
@@ -5641,7 +5646,7 @@ dependencies = [
name = "env_var"
version = "0.1.0"
dependencies = [
- "gpui",
+ "gpui_shared_string",
]
[[package]]
@@ -7468,11 +7473,13 @@ dependencies = [
"anyhow",
"futures 0.3.32",
"http_client",
+ "language_model_core",
+ "log",
"schemars",
"serde",
"serde_json",
- "settings",
"strum 0.27.2",
+ "tiktoken-rs",
]
[[package]]
@@ -7541,6 +7548,7 @@ dependencies = [
"getrandom 0.3.4",
"gpui_macros",
"gpui_platform",
+ "gpui_shared_string",
"gpui_util",
"gpui_web",
"http_client",
@@ -7710,6 +7718,16 @@ dependencies = [
"gpui_windows",
]
+[[package]]
+name = "gpui_shared_string"
+version = "0.1.0"
+dependencies = [
+ "derive_more",
+ "gpui_util",
+ "schemars",
+ "serde",
+]
+
[[package]]
name = "gpui_tokio"
version = "0.1.0"
@@ -9358,7 +9376,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"collections",
- "gpui",
+ "gpui_shared_string",
"log",
"lsp",
"parking_lot",
@@ -9397,12 +9415,8 @@ dependencies = [
name = "language_model"
version = "0.1.0"
dependencies = [
- "anthropic",
"anyhow",
"base64 0.22.1",
- "cloud_api_client",
- "cloud_api_types",
- "cloud_llm_client",
"collections",
"credentials_provider",
"env_var",
@@ -9411,16 +9425,31 @@ dependencies = [
"http_client",
"icons",
"image",
+ "language_model_core",
"log",
- "open_ai",
- "open_router",
"parking_lot",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.17",
+ "util",
+]
+
+[[package]]
+name = "language_model_core"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "cloud_llm_client",
+ "futures 0.3.32",
+ "gpui_shared_string",
+ "http_client",
+ "partial-json-fixer",
"schemars",
"serde",
"serde_json",
"smol",
+ "strum 0.27.2",
"thiserror 2.0.17",
- "util",
]
[[package]]
@@ -9436,8 +9465,8 @@ dependencies = [
"base64 0.22.1",
"bedrock",
"client",
+ "cloud_api_client",
"cloud_api_types",
- "cloud_llm_client",
"collections",
"component",
"convert_case 0.8.0",
@@ -9456,6 +9485,7 @@ dependencies = [
"http_client",
"language",
"language_model",
+ "language_models_cloud",
"lmstudio",
"log",
"menu",
@@ -9464,17 +9494,14 @@ dependencies = [
"open_ai",
"open_router",
"opencode",
- "partial-json-fixer",
"pretty_assertions",
"release_channel",
"schemars",
- "semver",
"serde",
"serde_json",
"settings",
"smol",
"strum 0.27.2",
- "thiserror 2.0.17",
"tiktoken-rs",
"tokio",
"ui",
@@ -9484,6 +9511,28 @@ dependencies = [
"x_ai",
]
+[[package]]
+name = "language_models_cloud"
+version = "0.1.0"
+dependencies = [
+ "anthropic",
+ "anyhow",
+ "cloud_llm_client",
+ "futures 0.3.32",
+ "google_ai",
+ "gpui",
+ "http_client",
+ "language_model",
+ "open_ai",
+ "schemars",
+ "semver",
+ "serde",
+ "serde_json",
+ "smol",
+ "thiserror 2.0.17",
+ "x_ai",
+]
+
[[package]]
name = "language_onboarding"
version = "0.1.0"
@@ -11631,16 +11680,19 @@ name = "open_ai"
version = "0.1.0"
dependencies = [
"anyhow",
+ "collections",
"futures 0.3.32",
"http_client",
+ "language_model_core",
"log",
+ "pretty_assertions",
"rand 0.9.2",
"schemars",
"serde",
"serde_json",
- "settings",
"strum 0.27.2",
"thiserror 2.0.17",
+ "tiktoken-rs",
]
[[package]]
@@ -11672,6 +11724,7 @@ dependencies = [
"anyhow",
"futures 0.3.32",
"http_client",
+ "language_model_core",
"schemars",
"serde",
"serde_json",
@@ -15801,6 +15854,7 @@ dependencies = [
"collections",
"derive_more",
"gpui",
+ "language_model_core",
"log",
"schemars",
"serde",
@@ -20180,6 +20234,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
+ "cloud_api_client",
"cloud_api_types",
"cloud_llm_client",
"futures 0.3.32",
@@ -21783,9 +21838,11 @@ name = "x_ai"
version = "0.1.0"
dependencies = [
"anyhow",
+ "language_model_core",
"schemars",
"serde",
"strum 0.27.2",
+ "tiktoken-rs",
]
[[package]]
@@ -87,6 +87,7 @@ members = [
"crates/google_ai",
"crates/grammars",
"crates/gpui",
+ "crates/gpui_shared_string",
"crates/gpui_linux",
"crates/gpui_macos",
"crates/gpui_macros",
@@ -110,7 +111,9 @@ members = [
"crates/language_core",
"crates/language_extension",
"crates/language_model",
+ "crates/language_model_core",
"crates/language_models",
+ "crates/language_models_cloud",
"crates/language_onboarding",
"crates/language_selector",
"crates/language_tools",
@@ -335,6 +338,7 @@ go_to_line = { path = "crates/go_to_line" }
google_ai = { path = "crates/google_ai" }
grammars = { path = "crates/grammars" }
gpui = { path = "crates/gpui", default-features = false }
+gpui_shared_string = { path = "crates/gpui_shared_string" }
gpui_linux = { path = "crates/gpui_linux", default-features = false }
gpui_macos = { path = "crates/gpui_macos", default-features = false }
gpui_macros = { path = "crates/gpui_macros" }
@@ -361,7 +365,9 @@ language = { path = "crates/language" }
language_core = { path = "crates/language_core" }
language_extension = { path = "crates/language_extension" }
language_model = { path = "crates/language_model" }
+language_model_core = { path = "crates/language_model_core" }
language_models = { path = "crates/language_models" }
+language_models_cloud = { path = "crates/language_models_cloud" }
language_onboarding = { path = "crates/language_onboarding" }
language_selector = { path = "crates/language_selector" }
language_tools = { path = "crates/language_tools" }
@@ -5,7 +5,7 @@ use futures::FutureExt as _;
use gpui::{App, Entity, SharedString, Task};
use indoc::formatdoc;
use language::Point;
-use language_model::{LanguageModelImage, LanguageModelToolResultContent};
+use language_model::{LanguageModelImage, LanguageModelImageExt, LanguageModelToolResultContent};
use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -325,7 +325,7 @@ impl AcpConnection {
// Use the one the agent provides if we have one
.map(|info| info.name.into())
// Otherwise, just use the name
- .unwrap_or_else(|| agent_id.0.to_string().into());
+ .unwrap_or_else(|| agent_id.0.clone());
let session_list = if response
.agent_capabilities
@@ -382,7 +382,7 @@ impl AgentRegistryPage {
self.install_button(agent, install_status, supports_current_platform, cx);
let repository_button = agent.repository().map(|repository| {
- let repository_for_tooltip: SharedString = repository.to_string().into();
+ let repository_for_tooltip = repository.clone();
let repository_for_click = repository.to_string();
IconButton::new(
@@ -18,7 +18,7 @@ use gpui::{
use http_client::{AsyncBody, HttpClientWithUrl};
use itertools::Either;
use language::Buffer;
-use language_model::LanguageModelImage;
+use language_model::{LanguageModelImage, LanguageModelImageExt};
use multi_buffer::MultiBufferRow;
use postage::stream::Stream as _;
use project::{Project, ProjectItem, ProjectPath, Worktree};
@@ -18,12 +18,16 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
+collections.workspace = true
futures.workspace = true
http_client.workspace = true
+language_model_core.workspace = true
+log.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
+tiktoken-rs.workspace = true
@@ -12,6 +12,7 @@ use strum::{EnumIter, EnumString};
use thiserror::Error;
pub mod batches;
+pub mod completion;
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
@@ -1026,6 +1027,89 @@ pub async fn count_tokens(
}
}
+// -- Conversions from/to `language_model_core` types --
+
+impl From<language_model_core::Speed> for Speed {
+ fn from(speed: language_model_core::Speed) -> Self {
+ match speed {
+ language_model_core::Speed::Standard => Speed::Standard,
+ language_model_core::Speed::Fast => Speed::Fast,
+ }
+ }
+}
+
+impl From<AnthropicError> for language_model_core::LanguageModelCompletionError {
+ fn from(error: AnthropicError) -> Self {
+ let provider = language_model_core::ANTHROPIC_PROVIDER_NAME;
+ match error {
+ AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
+ AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
+ AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
+ AnthropicError::DeserializeResponse(error) => {
+ Self::DeserializeResponse { provider, error }
+ }
+ AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+ AnthropicError::HttpResponseError {
+ status_code,
+ message,
+ } => Self::HttpResponseError {
+ provider,
+ status_code,
+ message,
+ },
+ AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
+ provider,
+ retry_after: Some(retry_after),
+ },
+ AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
+ provider,
+ retry_after,
+ },
+ AnthropicError::ApiError(api_error) => api_error.into(),
+ }
+ }
+}
+
+impl From<ApiError> for language_model_core::LanguageModelCompletionError {
+ fn from(error: ApiError) -> Self {
+ use ApiErrorCode::*;
+ let provider = language_model_core::ANTHROPIC_PROVIDER_NAME;
+ match error.code() {
+ Some(code) => match code {
+ InvalidRequestError => Self::BadRequestFormat {
+ provider,
+ message: error.message,
+ },
+ AuthenticationError => Self::AuthenticationError {
+ provider,
+ message: error.message,
+ },
+ PermissionError => Self::PermissionError {
+ provider,
+ message: error.message,
+ },
+ NotFoundError => Self::ApiEndpointNotFound { provider },
+ RequestTooLarge => Self::PromptTooLarge {
+ tokens: language_model_core::parse_prompt_too_long(&error.message),
+ },
+ RateLimitError => Self::RateLimitExceeded {
+ provider,
+ retry_after: None,
+ },
+ ApiError => Self::ApiInternalServerError {
+ provider,
+ message: error.message,
+ },
+ OverloadedError => Self::ServerOverloaded {
+ provider,
+ retry_after: None,
+ },
+ },
+ None => Self::Other(error.into()),
+ }
+ }
+}
+
#[test]
fn test_match_window_exceeded() {
let error = ApiError {
@@ -0,0 +1,765 @@
+use anyhow::Result;
+use collections::HashMap;
+use futures::{Stream, StreamExt};
+use language_model_core::{
+ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
+ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+ Role, StopReason, TokenUsage,
+ util::{fix_streamed_json, parse_tool_arguments},
+};
+use std::pin::Pin;
+use std::str::FromStr;
+
+use crate::{
+ AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta,
+ CountTokensRequest, Event, ImageSource, Message, RequestContent, ResponseContent,
+ StringOrContents, Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage,
+};
+
+fn to_anthropic_content(content: MessageContent) -> Option<RequestContent> {
+ match content {
+ MessageContent::Text(text) => {
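+            // Trimmed because the Anthropic API rejects trailing whitespace in
+            // assistant text (e.g. prefilled final assistant messages).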
+ let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
+ text.trim_end().to_string()
+ } else {
+ text
+ };
+ if !text.is_empty() {
+ Some(RequestContent::Text {
+ text,
+ cache_control: None,
+ })
+ } else {
+ None
+ }
+ }
+ MessageContent::Thinking {
+ text: thinking,
+ signature,
+ } => {
+ if let Some(signature) = signature
+ && !thinking.is_empty()
+ {
+ Some(RequestContent::Thinking {
+ thinking,
+ signature,
+ cache_control: None,
+ })
+ } else {
+ None
+ }
+ }
+ MessageContent::RedactedThinking(data) => {
+ if !data.is_empty() {
+ Some(RequestContent::RedactedThinking { data })
+ } else {
+ None
+ }
+ }
+ MessageContent::Image(image) => Some(RequestContent::Image {
+ source: ImageSource {
+ source_type: "base64".to_string(),
+ media_type: "image/png".to_string(),
+ data: image.source.to_string(),
+ },
+ cache_control: None,
+ }),
+ MessageContent::ToolUse(tool_use) => Some(RequestContent::ToolUse {
+ id: tool_use.id.to_string(),
+ name: tool_use.name.to_string(),
+ input: tool_use.input,
+ cache_control: None,
+ }),
+ MessageContent::ToolResult(tool_result) => Some(RequestContent::ToolResult {
+ tool_use_id: tool_result.tool_use_id.to_string(),
+ is_error: tool_result.is_error,
+ content: match tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ ToolResultContent::Plain(text.to_string())
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ ToolResultContent::Multipart(vec![ToolResultPart::Image {
+ source: ImageSource {
+ source_type: "base64".to_string(),
+ media_type: "image/png".to_string(),
+ data: image.source.to_string(),
+ },
+ }])
+ }
+ },
+ cache_control: None,
+ }),
+ }
+}
+
+/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
+pub fn into_anthropic_count_tokens_request(
+ request: LanguageModelRequest,
+ model: String,
+ mode: AnthropicModelMode,
+) -> CountTokensRequest {
+ let mut new_messages: Vec<Message> = Vec::new();
+ let mut system_message = String::new();
+
+ for message in request.messages {
+ if message.contents_empty() {
+ continue;
+ }
+
+ match message.role {
+ Role::User | Role::Assistant => {
+ let anthropic_message_content: Vec<RequestContent> = message
+ .content
+ .into_iter()
+ .filter_map(to_anthropic_content)
+ .collect();
+ let anthropic_role = match message.role {
+ Role::User => crate::Role::User,
+ Role::Assistant => crate::Role::Assistant,
+ Role::System => unreachable!("System role should never occur here"),
+ };
+ if anthropic_message_content.is_empty() {
+ continue;
+ }
+
+ if let Some(last_message) = new_messages.last_mut()
+ && last_message.role == anthropic_role
+ {
+ last_message.content.extend(anthropic_message_content);
+ continue;
+ }
+
+ new_messages.push(Message {
+ role: anthropic_role,
+ content: anthropic_message_content,
+ });
+ }
+ Role::System => {
+ if !system_message.is_empty() {
+ system_message.push_str("\n\n");
+ }
+ system_message.push_str(&message.string_contents());
+ }
+ }
+ }
+
+ CountTokensRequest {
+ model,
+ messages: new_messages,
+ system: if system_message.is_empty() {
+ None
+ } else {
+ Some(StringOrContents::String(system_message))
+ },
+ thinking: if request.thinking_allowed {
+ match mode {
+ AnthropicModelMode::Thinking { budget_tokens } => {
+ Some(Thinking::Enabled { budget_tokens })
+ }
+ AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive),
+ AnthropicModelMode::Default => None,
+ }
+ } else {
+ None
+ },
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| Tool {
+ name: tool.name,
+ description: tool.description,
+ input_schema: tool.input_schema,
+ eager_input_streaming: tool.use_input_streaming,
+ })
+ .collect(),
+ tool_choice: request.tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => ToolChoice::Auto,
+ LanguageModelToolChoice::Any => ToolChoice::Any,
+ LanguageModelToolChoice::None => ToolChoice::None,
+ }),
+ }
+}
+
+/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
+/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
+pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
+ let messages = request.messages;
+ let mut tokens_from_images = 0;
+ let mut string_messages = Vec::with_capacity(messages.len());
+
+ for message in messages {
+ let mut string_contents = String::new();
+
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) => {
+ string_contents.push_str(&text);
+ }
+ MessageContent::Thinking { .. } => {
+ // Thinking blocks are not included in the input token count.
+ }
+ MessageContent::RedactedThinking(_) => {
+ // Thinking blocks are not included in the input token count.
+ }
+ MessageContent::Image(image) => {
+ tokens_from_images += image.estimate_tokens();
+ }
+ MessageContent::ToolUse(_tool_use) => {
+ // TODO: Estimate token usage from tool uses.
+ }
+ MessageContent::ToolResult(tool_result) => match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ string_contents.push_str(text);
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ tokens_from_images += image.estimate_tokens();
+ }
+ },
+ }
+ }
+
+ if !string_contents.is_empty() {
+ string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(string_contents),
+ name: None,
+ function_call: None,
+ });
+ }
+ }
+
+ // Tiktoken doesn't yet support these models, so we manually use the
+ // same tokenizer as GPT-4.
+ tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
+ .map(|tokens| (tokens + tokens_from_images) as u64)
+}
+
+pub fn into_anthropic(
+ request: LanguageModelRequest,
+ model: String,
+ default_temperature: f32,
+ max_output_tokens: u64,
+ mode: AnthropicModelMode,
+) -> crate::Request {
+ let mut new_messages: Vec<Message> = Vec::new();
+ let mut system_message = String::new();
+
+ for message in request.messages {
+ if message.contents_empty() {
+ continue;
+ }
+
+ match message.role {
+ Role::User | Role::Assistant => {
+ let mut anthropic_message_content: Vec<RequestContent> = message
+ .content
+ .into_iter()
+ .filter_map(to_anthropic_content)
+ .collect();
+ let anthropic_role = match message.role {
+ Role::User => crate::Role::User,
+ Role::Assistant => crate::Role::Assistant,
+ Role::System => unreachable!("System role should never occur here"),
+ };
+ if anthropic_message_content.is_empty() {
+ continue;
+ }
+
+ if let Some(last_message) = new_messages.last_mut()
+ && last_message.role == anthropic_role
+ {
+ last_message.content.extend(anthropic_message_content);
+ continue;
+ }
+
+ // Mark the last segment of the message as cached
+ if message.cache {
+ let cache_control_value = Some(CacheControl {
+ cache_type: CacheControlType::Ephemeral,
+ });
+ for message_content in anthropic_message_content.iter_mut().rev() {
+ match message_content {
+ RequestContent::RedactedThinking { .. } => {
+                                // Redacted thinking can't hold a cache breakpoint;
+                                // fall back to the preceding content block.
+ }
+ RequestContent::Text { cache_control, .. }
+ | RequestContent::Thinking { cache_control, .. }
+ | RequestContent::Image { cache_control, .. }
+ | RequestContent::ToolUse { cache_control, .. }
+ | RequestContent::ToolResult { cache_control, .. } => {
+ *cache_control = cache_control_value;
+ break;
+ }
+ }
+ }
+ }
+
+ new_messages.push(Message {
+ role: anthropic_role,
+ content: anthropic_message_content,
+ });
+ }
+ Role::System => {
+ if !system_message.is_empty() {
+ system_message.push_str("\n\n");
+ }
+ system_message.push_str(&message.string_contents());
+ }
+ }
+ }
+
+ crate::Request {
+ model,
+ messages: new_messages,
+ max_tokens: max_output_tokens,
+ system: if system_message.is_empty() {
+ None
+ } else {
+ Some(StringOrContents::String(system_message))
+ },
+ thinking: if request.thinking_allowed {
+ match mode {
+ AnthropicModelMode::Thinking { budget_tokens } => {
+ Some(Thinking::Enabled { budget_tokens })
+ }
+ AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive),
+ AnthropicModelMode::Default => None,
+ }
+ } else {
+ None
+ },
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| Tool {
+ name: tool.name,
+ description: tool.description,
+ input_schema: tool.input_schema,
+ eager_input_streaming: tool.use_input_streaming,
+ })
+ .collect(),
+ tool_choice: request.tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => ToolChoice::Auto,
+ LanguageModelToolChoice::Any => ToolChoice::Any,
+ LanguageModelToolChoice::None => ToolChoice::None,
+ }),
+ metadata: None,
+ output_config: if request.thinking_allowed
+ && matches!(mode, AnthropicModelMode::AdaptiveThinking)
+ {
+ request.thinking_effort.as_deref().and_then(|effort| {
+ let effort = match effort {
+ "low" => Some(crate::Effort::Low),
+ "medium" => Some(crate::Effort::Medium),
+ "high" => Some(crate::Effort::High),
+ "max" => Some(crate::Effort::Max),
+ _ => None,
+ };
+ effort.map(|effort| crate::OutputConfig {
+ effort: Some(effort),
+ })
+ })
+ } else {
+ None
+ },
+ stop_sequences: Vec::new(),
+ speed: request.speed.map(Into::into),
+ temperature: request.temperature.or(Some(default_temperature)),
+ top_k: None,
+ top_p: None,
+ }
+}
+
+pub struct AnthropicEventMapper {
+ tool_uses_by_index: HashMap<usize, RawToolUse>,
+ usage: Usage,
+ stop_reason: StopReason,
+}
+
+impl AnthropicEventMapper {
+ pub fn new() -> Self {
+ Self {
+ tool_uses_by_index: HashMap::default(),
+ usage: Usage::default(),
+ stop_reason: StopReason::EndTurn,
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(error.into())],
+ })
+ })
+ }
+
+ pub fn map_event(
+ &mut self,
+ event: Event,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ match event {
+ Event::ContentBlockStart {
+ index,
+ content_block,
+ } => match content_block {
+ ResponseContent::Text { text } => {
+ vec![Ok(LanguageModelCompletionEvent::Text(text))]
+ }
+ ResponseContent::Thinking { thinking } => {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: thinking,
+ signature: None,
+ })]
+ }
+ ResponseContent::RedactedThinking { data } => {
+ vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
+ }
+ ResponseContent::ToolUse { id, name, .. } => {
+ self.tool_uses_by_index.insert(
+ index,
+ RawToolUse {
+ id,
+ name,
+ input_json: String::new(),
+ },
+ );
+ Vec::new()
+ }
+ },
+ Event::ContentBlockDelta { index, delta } => match delta {
+ ContentDelta::TextDelta { text } => {
+ vec![Ok(LanguageModelCompletionEvent::Text(text))]
+ }
+ ContentDelta::ThinkingDelta { thinking } => {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: thinking,
+ signature: None,
+ })]
+ }
+ ContentDelta::SignatureDelta { signature } => {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: "".to_string(),
+ signature: Some(signature),
+ })]
+ }
+ ContentDelta::InputJsonDelta { partial_json } => {
+ if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
+ tool_use.input_json.push_str(&partial_json);
+
+ // Try to convert invalid (incomplete) JSON into
+ // valid JSON that serde can accept, e.g. by closing
+ // unclosed delimiters. This way, we can update the
+ // UI with whatever has been streamed back so far.
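+                        // For example, a partial `{"path": "src/ma` can be
+                        // repaired to `{"path": "src/ma"}`.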
+ if let Ok(input) =
+ serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
+ {
+ return vec![Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id.clone().into(),
+ name: tool_use.name.clone().into(),
+ is_input_complete: false,
+ raw_input: tool_use.input_json.clone(),
+ input,
+ thought_signature: None,
+ },
+ ))];
+ }
+ }
+ vec![]
+ }
+ },
+ Event::ContentBlockStop { index } => {
+ if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
+ let input_json = tool_use.input_json.trim();
+ let event_result = match parse_tool_arguments(input_json) {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id.into(),
+ name: tool_use.name.into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_use.input_json.clone(),
+ thought_signature: None,
+ },
+ )),
+ Err(json_parse_err) => {
+ Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_use.id.into(),
+ tool_name: tool_use.name.into(),
+ raw_input: input_json.into(),
+ json_parse_error: json_parse_err.to_string(),
+ })
+ }
+ };
+
+ vec![event_result]
+ } else {
+ Vec::new()
+ }
+ }
+ Event::MessageStart { message } => {
+ update_usage(&mut self.usage, &message.usage);
+ vec![
+ Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+ &self.usage,
+ ))),
+ Ok(LanguageModelCompletionEvent::StartMessage {
+ message_id: message.id,
+ }),
+ ]
+ }
+ Event::MessageDelta { delta, usage } => {
+ update_usage(&mut self.usage, &usage);
+ if let Some(stop_reason) = delta.stop_reason.as_deref() {
+ self.stop_reason = match stop_reason {
+ "end_turn" => StopReason::EndTurn,
+ "max_tokens" => StopReason::MaxTokens,
+ "tool_use" => StopReason::ToolUse,
+ "refusal" => StopReason::Refusal,
+ _ => {
+ log::error!("Unexpected anthropic stop_reason: {stop_reason}");
+ StopReason::EndTurn
+ }
+ };
+ }
+ vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
+ convert_usage(&self.usage),
+ ))]
+ }
+ Event::MessageStop => {
+ vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
+ }
+ Event::Error { error } => {
+ vec![Err(error.into())]
+ }
+ _ => Vec::new(),
+ }
+ }
+}
+
+struct RawToolUse {
+ id: String,
+ name: String,
+ input_json: String,
+}
+
+/// Updates usage data by preferring counts from `new`.
+fn update_usage(usage: &mut Usage, new: &Usage) {
+ if let Some(input_tokens) = new.input_tokens {
+ usage.input_tokens = Some(input_tokens);
+ }
+ if let Some(output_tokens) = new.output_tokens {
+ usage.output_tokens = Some(output_tokens);
+ }
+ if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
+ usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
+ }
+ if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
+ usage.cache_read_input_tokens = Some(cache_read_input_tokens);
+ }
+}
+
+fn convert_usage(usage: &Usage) -> TokenUsage {
+ TokenUsage {
+ input_tokens: usage.input_tokens.unwrap_or(0),
+ output_tokens: usage.output_tokens.unwrap_or(0),
+ cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
+ cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::AnthropicModelMode;
+ use language_model_core::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
+
+ #[test]
+ fn test_cache_control_only_on_last_segment() {
+ let request = LanguageModelRequest {
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![
+ MessageContent::Text("Some prompt".to_string()),
+ MessageContent::Image(LanguageModelImage::empty()),
+ MessageContent::Image(LanguageModelImage::empty()),
+ MessageContent::Image(LanguageModelImage::empty()),
+ MessageContent::Image(LanguageModelImage::empty()),
+ ],
+ cache: true,
+ reasoning_details: None,
+ }],
+ thread_id: None,
+ prompt_id: None,
+ intent: None,
+ stop: vec![],
+ temperature: None,
+ tools: vec![],
+ tool_choice: None,
+ thinking_allowed: true,
+ thinking_effort: None,
+ speed: None,
+ };
+
+ let anthropic_request = into_anthropic(
+ request,
+ "claude-3-5-sonnet".to_string(),
+ 0.7,
+ 4096,
+ AnthropicModelMode::Default,
+ );
+
+ assert_eq!(anthropic_request.messages.len(), 1);
+
+ let message = &anthropic_request.messages[0];
+ assert_eq!(message.content.len(), 5);
+
+ assert!(matches!(
+ message.content[0],
+ RequestContent::Text {
+ cache_control: None,
+ ..
+ }
+ ));
+        for i in 1..4 {
+ assert!(matches!(
+ message.content[i],
+ RequestContent::Image {
+ cache_control: None,
+ ..
+ }
+ ));
+ }
+
+ assert!(matches!(
+ message.content[4],
+ RequestContent::Image {
+ cache_control: Some(CacheControl {
+ cache_type: CacheControlType::Ephemeral,
+ }),
+ ..
+ }
+ ));
+ }
+
+ fn request_with_assistant_content(assistant_content: Vec<MessageContent>) -> crate::Request {
+ let mut request = LanguageModelRequest {
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::Text("Hello".to_string())],
+ cache: false,
+ reasoning_details: None,
+ }],
+ thinking_effort: None,
+ thread_id: None,
+ prompt_id: None,
+ intent: None,
+ stop: vec![],
+ temperature: None,
+ tools: vec![],
+ tool_choice: None,
+ thinking_allowed: true,
+ speed: None,
+ };
+ request.messages.push(LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: assistant_content,
+ cache: false,
+ reasoning_details: None,
+ });
+ into_anthropic(
+ request,
+ "claude-sonnet-4-5".to_string(),
+ 1.0,
+ 16000,
+ AnthropicModelMode::Thinking {
+ budget_tokens: Some(10000),
+ },
+ )
+ }
+
+ #[test]
+ fn test_unsigned_thinking_blocks_stripped() {
+ let result = request_with_assistant_content(vec![
+ MessageContent::Thinking {
+ text: "Cancelled mid-think, no signature".to_string(),
+ signature: None,
+ },
+ MessageContent::Text("Some response text".to_string()),
+ ]);
+
+ let assistant_message = result
+ .messages
+ .iter()
+ .find(|m| m.role == crate::Role::Assistant)
+ .expect("assistant message should still exist");
+
+ assert_eq!(
+ assistant_message.content.len(),
+ 1,
+ "Only the text content should remain; unsigned thinking block should be stripped"
+ );
+ assert!(matches!(
+ &assistant_message.content[0],
+ RequestContent::Text { text, .. } if text == "Some response text"
+ ));
+ }
+
+ #[test]
+ fn test_signed_thinking_blocks_preserved() {
+ let result = request_with_assistant_content(vec![
+ MessageContent::Thinking {
+ text: "Completed thinking".to_string(),
+ signature: Some("valid-signature".to_string()),
+ },
+ MessageContent::Text("Response".to_string()),
+ ]);
+
+ let assistant_message = result
+ .messages
+ .iter()
+ .find(|m| m.role == crate::Role::Assistant)
+ .expect("assistant message should exist");
+
+ assert_eq!(
+ assistant_message.content.len(),
+ 2,
+ "Both the signed thinking block and text should be preserved"
+ );
+ assert!(matches!(
+ &assistant_message.content[0],
+ RequestContent::Thinking { thinking, signature, .. }
+ if thinking == "Completed thinking" && signature == "valid-signature"
+ ));
+ }
+
+ #[test]
+ fn test_only_unsigned_thinking_block_omits_entire_message() {
+ let result = request_with_assistant_content(vec![MessageContent::Thinking {
+ text: "Cancelled before any text or signature".to_string(),
+ signature: None,
+ }]);
+
+ let assistant_messages: Vec<_> = result
+ .messages
+ .iter()
+ .filter(|m| m.role == crate::Role::Assistant)
+ .collect();
+
+ assert_eq!(
+ assistant_messages.len(),
+ 0,
+ "An assistant message whose only content was an unsigned thinking block \
+ should be omitted entirely"
+ );
+ }
+}
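
For orientation, the new `anthropic::completion` module is driven roughly as follows. This is a minimal sketch rather than code from this change: the model id, temperature, and the one-event stream are illustrative placeholders, and a real caller would feed in the events returned by the crate's streaming API.

use anthropic::completion::{AnthropicEventMapper, into_anthropic};
use anthropic::{AnthropicError, AnthropicModelMode, Event};
use futures::StreamExt;

async fn sketch(request: language_model_core::LanguageModelRequest) {
    // Convert the provider-agnostic request into an Anthropic API request.
    let _anthropic_request = into_anthropic(
        request,
        "claude-sonnet-4-5".to_string(), // illustrative model id
        1.0,                             // default_temperature
        16_000,                          // max_output_tokens
        AnthropicModelMode::Default,
    );
    // Map a (synthetic) event stream back to provider-agnostic completion events.
    let events: Vec<Result<Event, AnthropicError>> = vec![Ok(Event::MessageStop)];
    let mapped = AnthropicEventMapper::new().map_stream(Box::pin(futures::stream::iter(events)));
    futures::pin_mut!(mapped);
    while let Some(event) = mapped.next().await {
        let _ = event; // here: a single Ok(Stop(EndTurn))
    }
}
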
@@ -36,7 +36,6 @@ gpui_tokio.workspace = true
http_client.workspace = true
http_client_tls.workspace = true
httparse = "1.10"
-language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
paths.workspace = true
@@ -14,6 +14,7 @@ use async_tungstenite::tungstenite::{
http::{HeaderValue, Request, StatusCode},
};
use clock::SystemClock;
+use cloud_api_client::LlmApiToken;
use cloud_api_client::websocket_protocol::MessageToClient;
use cloud_api_client::{ClientApiError, CloudApiClient};
use cloud_api_types::OrganizationId;
@@ -26,7 +27,6 @@ use futures::{
};
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
-use language_model::LlmApiToken;
use parking_lot::{Mutex, RwLock};
use postage::watch;
use proxy::connect_proxy_stream;
@@ -1,10 +1,10 @@
use super::{Client, UserStore};
+use cloud_api_client::LlmApiToken;
use cloud_api_types::websocket_protocol::MessageToClient;
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
};
-use language_model::LlmApiToken;
use std::sync::Arc;
pub trait NeedsLlmTokenRefresh {
@@ -20,5 +20,6 @@ gpui_tokio.workspace = true
http_client.workspace = true
parking_lot.workspace = true
serde_json.workspace = true
+smol.workspace = true
thiserror.workspace = true
yawc.workspace = true
@@ -1,3 +1,4 @@
+mod llm_token;
mod websocket;
use std::sync::Arc;
@@ -18,6 +19,8 @@ use yawc::WebSocket;
use crate::websocket::Connection;
+pub use llm_token::LlmApiToken;
+
struct Credentials {
user_id: u32,
access_token: String,
@@ -0,0 +1,74 @@
+use std::sync::Arc;
+
+use cloud_api_types::OrganizationId;
+use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
+
+use crate::{ClientApiError, CloudApiClient};
+
+#[derive(Clone, Default)]
+pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
+
+impl LlmApiToken {
+ pub async fn acquire(
+ &self,
+ client: &CloudApiClient,
+ system_id: Option<String>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String, ClientApiError> {
+ let lock = self.0.upgradable_read().await;
+ if let Some(token) = lock.as_ref() {
+ Ok(token.to_string())
+ } else {
+ Self::fetch(
+ RwLockUpgradableReadGuard::upgrade(lock).await,
+ client,
+ system_id,
+ organization_id,
+ )
+ .await
+ }
+ }
+
+ pub async fn refresh(
+ &self,
+ client: &CloudApiClient,
+ system_id: Option<String>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String, ClientApiError> {
+ Self::fetch(self.0.write().await, client, system_id, organization_id).await
+ }
+
+ /// Clears the existing token before attempting to fetch a new one.
+ ///
+ /// Used when switching organizations so that a failed refresh doesn't
+ /// leave a token for the wrong organization.
+ pub async fn clear_and_refresh(
+ &self,
+ client: &CloudApiClient,
+ system_id: Option<String>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String, ClientApiError> {
+ let mut lock = self.0.write().await;
+ *lock = None;
+ Self::fetch(lock, client, system_id, organization_id).await
+ }
+
+ async fn fetch(
+ mut lock: RwLockWriteGuard<'_, Option<String>>,
+ client: &CloudApiClient,
+ system_id: Option<String>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String, ClientApiError> {
+ let result = client.create_llm_token(system_id, organization_id).await;
+ match result {
+ Ok(response) => {
+ *lock = Some(response.token.0.clone());
+ Ok(response.token.0)
+ }
+ Err(err) => {
+ *lock = None;
+ Err(err)
+ }
+ }
+ }
+}
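
The token is cached behind an upgradable read lock, so concurrent `acquire` calls that hit the cache never contend on the write lock; only a cache miss upgrades and fetches, while `refresh` and `clear_and_refresh` always take the write lock. A hypothetical caller (client construction elided; not part of this change) looks like:

use cloud_api_client::{ClientApiError, CloudApiClient, LlmApiToken};

async fn sketch(client: &CloudApiClient) -> Result<(), ClientApiError> {
    let token = LlmApiToken::default();
    // First call fetches and caches; later calls return the cached token.
    let _jwt = token.acquire(client, None, None).await?;
    // Force a refetch, e.g. after the server reports an expired token.
    let _jwt = token.refresh(client, None, None).await?;
    Ok(())
}
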
@@ -7,6 +7,7 @@ license = "Apache-2.0"
[features]
test-support = []
+predict-edits = ["dep:zeta_prompt"]
[lints]
workspace = true
@@ -20,6 +21,6 @@ serde = { workspace = true, features = ["derive", "rc"] }
serde_json.workspace = true
strum = { workspace = true, features = ["derive"] }
uuid = { workspace = true, features = ["serde"] }
-zeta_prompt.workspace = true
+zeta_prompt = { workspace = true, optional = true }
@@ -1,3 +1,4 @@
+#[cfg(feature = "predict-edits")]
pub mod predict_edits_v3;
use std::str::FromStr;
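
Downstream crates that still need `predict_edits_v3` opt in explicitly, as the edit-prediction and CLI manifests below do:

cloud_llm_client = { workspace = true, features = ["predict-edits"] }
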
@@ -2846,11 +2846,11 @@ impl CollabPanel {
}
};
- Some(channel.name.as_ref())
+ Some(channel.name.clone())
});
if let Some(name) = channel_name {
- SharedString::from(name.to_string())
+ name
} else {
SharedString::from("Current Call")
}
@@ -21,8 +21,9 @@ heapless.workspace = true
buffer_diff.workspace = true
client.workspace = true
clock.workspace = true
+cloud_api_client.workspace = true
cloud_api_types.workspace = true
-cloud_llm_client.workspace = true
+cloud_llm_client = { workspace = true, features = ["predict-edits"] }
collections.workspace = true
copilot.workspace = true
copilot_ui.workspace = true
@@ -1,5 +1,6 @@
use anyhow::Result;
use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
+use cloud_api_client::LlmApiToken;
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
@@ -31,7 +32,6 @@ use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::LlmApiToken;
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
@@ -57,7 +57,7 @@ pub fn fetch_models(cx: &mut App) -> Vec<SharedString> {
let mut models: Vec<SharedString> = provider
.provided_models(cx)
.into_iter()
- .map(|model| SharedString::from(model.id().0.to_string()))
+ .map(|model| model.id().0)
.collect();
models.sort();
models
@@ -177,7 +177,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
BufferEditPrediction::Local { prediction } => prediction,
BufferEditPrediction::Jump { prediction } => {
return Some(edit_prediction_types::EditPrediction::Jump {
- id: Some(prediction.id.to_string().into()),
+ id: Some(prediction.id.0.clone()),
snapshot: prediction.snapshot.clone(),
target: prediction.edits.first().unwrap().0.start,
});
@@ -228,7 +228,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
}
Some(edit_prediction_types::EditPrediction::Local {
- id: Some(prediction.id.to_string().into()),
+ id: Some(prediction.id.0.clone()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
cursor_position: prediction.cursor_position,
edit_preview: Some(prediction.edit_preview.clone()),
@@ -22,7 +22,7 @@ http_client.workspace = true
chrono.workspace = true
clap = "4"
client.workspace = true
-cloud_llm_client.workspace= true
+cloud_llm_client = { workspace = true, features = ["predict-edits"] }
collections.workspace = true
db.workspace = true
debug_adapter_extension.workspace = true
@@ -12,4 +12,4 @@ workspace = true
path = "src/env_var.rs"
[dependencies]
-gpui.workspace = true
+gpui_shared_string.workspace = true
@@ -1,4 +1,4 @@
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
#[derive(Clone)]
pub struct EnvVar {
@@ -1906,7 +1906,7 @@ mod tests {
assert_eq!(
remotes,
vec![Remote {
- name: SharedString::from("my_new_remote".to_string())
+ name: SharedString::from("my_new_remote")
}]
);
}
@@ -18,8 +18,10 @@ schemars = ["dep:schemars"]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
+language_model_core.workspace = true
+log.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
-settings.workspace = true
strum.workspace = true
+tiktoken-rs.workspace = true
@@ -0,0 +1,492 @@
+use anyhow::Result;
+use futures::{Stream, StreamExt};
+use language_model_core::{
+ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
+ LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
+ StopReason, TokenUsage,
+};
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{self, AtomicU64};
+
+use crate::{
+ Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration,
+ GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode,
+ InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig,
+ UsageMetadata,
+};
+
+pub fn into_google(
+ mut request: LanguageModelRequest,
+ model_id: String,
+ mode: GoogleModelMode,
+) -> crate::GenerateContentRequest {
+ fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
+ content
+ .into_iter()
+ .flat_map(|content| match content {
+ MessageContent::Text(text) => {
+ if !text.is_empty() {
+ vec![Part::TextPart(TextPart { text })]
+ } else {
+ vec![]
+ }
+ }
+ MessageContent::Thinking {
+ text: _,
+ signature: Some(signature),
+ } => {
+ if !signature.is_empty() {
+ vec![Part::ThoughtPart(crate::ThoughtPart {
+ thought: true,
+ thought_signature: signature,
+ })]
+ } else {
+ vec![]
+ }
+ }
+ MessageContent::Thinking { .. } => {
+ vec![]
+ }
+ MessageContent::RedactedThinking(_) => vec![],
+ MessageContent::Image(image) => {
+ vec![Part::InlineDataPart(InlineDataPart {
+ inline_data: GenerativeContentBlob {
+ mime_type: "image/png".to_string(),
+ data: image.source.to_string(),
+ },
+ })]
+ }
+ MessageContent::ToolUse(tool_use) => {
+ // Normalize empty string signatures to None
+ let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
+
+ vec![Part::FunctionCallPart(crate::FunctionCallPart {
+ function_call: crate::FunctionCall {
+ name: tool_use.name.to_string(),
+ args: tool_use.input,
+ },
+ thought_signature,
+ })]
+ }
+ MessageContent::ToolResult(tool_result) => {
+ match tool_result.content {
+ language_model_core::LanguageModelToolResultContent::Text(text) => {
+ vec![Part::FunctionResponsePart(crate::FunctionResponsePart {
+ function_response: crate::FunctionResponse {
+ name: tool_result.tool_name.to_string(),
+ // The API expects a valid JSON object
+ response: serde_json::json!({
+ "output": text
+ }),
+ },
+ })]
+ }
+ language_model_core::LanguageModelToolResultContent::Image(image) => {
+ vec![
+ Part::FunctionResponsePart(crate::FunctionResponsePart {
+ function_response: crate::FunctionResponse {
+ name: tool_result.tool_name.to_string(),
+ // The API expects a valid JSON object
+ response: serde_json::json!({
+ "output": "Tool responded with an image"
+ }),
+ },
+ }),
+ Part::InlineDataPart(InlineDataPart {
+ inline_data: GenerativeContentBlob {
+ mime_type: "image/png".to_string(),
+ data: image.source.to_string(),
+ },
+ }),
+ ]
+ }
+ }
+ }
+ })
+ .collect()
+ }
+
+ let system_instructions = if request
+ .messages
+ .first()
+ .is_some_and(|msg| matches!(msg.role, Role::System))
+ {
+ let message = request.messages.remove(0);
+ Some(SystemInstruction {
+ parts: map_content(message.content),
+ })
+ } else {
+ None
+ };
+
+ crate::GenerateContentRequest {
+ model: ModelName { model_id },
+ system_instruction: system_instructions,
+ contents: request
+ .messages
+ .into_iter()
+ .filter_map(|message| {
+ let parts = map_content(message.content);
+ if parts.is_empty() {
+ None
+ } else {
+ Some(Content {
+ parts,
+ role: match message.role {
+ Role::User => crate::Role::User,
+ Role::Assistant => crate::Role::Model,
+ Role::System => crate::Role::User, // Google AI doesn't have a system role
+ },
+ })
+ }
+ })
+ .collect(),
+ generation_config: Some(GenerationConfig {
+ candidate_count: Some(1),
+ stop_sequences: Some(request.stop),
+ max_output_tokens: None,
+ temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+ thinking_config: match (request.thinking_allowed, mode) {
+ (true, GoogleModelMode::Thinking { budget_tokens }) => {
+ budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
+ }
+ _ => None,
+ },
+ top_p: None,
+ top_k: None,
+ }),
+ safety_settings: None,
+ tools: (!request.tools.is_empty()).then(|| {
+ vec![crate::Tool {
+ function_declarations: request
+ .tools
+ .into_iter()
+ .map(|tool| FunctionDeclaration {
+ name: tool.name,
+ description: tool.description,
+ parameters: tool.input_schema,
+ })
+ .collect(),
+ }]
+ }),
+ tool_config: request.tool_choice.map(|choice| ToolConfig {
+ function_calling_config: FunctionCallingConfig {
+ mode: match choice {
+ LanguageModelToolChoice::Auto => FunctionCallingMode::Auto,
+ LanguageModelToolChoice::Any => FunctionCallingMode::Any,
+ LanguageModelToolChoice::None => FunctionCallingMode::None,
+ },
+ allowed_function_names: None,
+ },
+ }),
+ }
+}
+
+pub struct GoogleEventMapper {
+ usage: UsageMetadata,
+ stop_reason: StopReason,
+}
+
+impl GoogleEventMapper {
+ pub fn new() -> Self {
+ Self {
+ usage: UsageMetadata::default(),
+ stop_reason: StopReason::EndTurn,
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
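+        // A trailing `None` is appended so the mapper emits a final Stop
+        // event once the underlying stream ends.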
+ events
+ .map(Some)
+ .chain(futures::stream::once(async { None }))
+ .flat_map(move |event| {
+ futures::stream::iter(match event {
+ Some(Ok(event)) => self.map_event(event),
+ Some(Err(error)) => {
+ vec![Err(LanguageModelCompletionError::from(error))]
+ }
+ None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
+ })
+ })
+ }
+
+ pub fn map_event(
+ &mut self,
+ event: GenerateContentResponse,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+ let mut events: Vec<_> = Vec::new();
+ let mut wants_to_use_tool = false;
+ if let Some(usage_metadata) = event.usage_metadata {
+ update_usage(&mut self.usage, &usage_metadata);
+ events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+ convert_usage(&self.usage),
+ )))
+ }
+
+ if let Some(prompt_feedback) = event.prompt_feedback
+ && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
+ {
+ self.stop_reason = match block_reason {
+ "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
+ StopReason::Refusal
+ }
+ _ => {
+ log::error!("Unexpected Google block_reason: {block_reason}");
+ StopReason::Refusal
+ }
+ };
+ events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
+
+ return events;
+ }
+
+ if let Some(candidates) = event.candidates {
+ for candidate in candidates {
+ if let Some(finish_reason) = candidate.finish_reason.as_deref() {
+ self.stop_reason = match finish_reason {
+ "STOP" => StopReason::EndTurn,
+ "MAX_TOKENS" => StopReason::MaxTokens,
+ _ => {
+ log::error!("Unexpected google finish_reason: {finish_reason}");
+ StopReason::EndTurn
+ }
+ };
+ }
+ candidate
+ .content
+ .parts
+ .into_iter()
+ .for_each(|part| match part {
+ Part::TextPart(text_part) => {
+ events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
+ }
+ Part::InlineDataPart(_) => {}
+ Part::FunctionCallPart(function_call_part) => {
+ wants_to_use_tool = true;
+ let name: Arc<str> = function_call_part.function_call.name.into();
+ let next_tool_id =
+ TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
+ let id: LanguageModelToolUseId =
+ format!("{}-{}", name, next_tool_id).into();
+
+ // Normalize empty string signatures to None
+ let thought_signature = function_call_part
+ .thought_signature
+ .filter(|s| !s.is_empty());
+
+ events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id,
+ name,
+ is_input_complete: true,
+ raw_input: function_call_part.function_call.args.to_string(),
+ input: function_call_part.function_call.args,
+ thought_signature,
+ },
+ )));
+ }
+ Part::FunctionResponsePart(_) => {}
+ Part::ThoughtPart(part) => {
+ events.push(Ok(LanguageModelCompletionEvent::Thinking {
+ text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
+ signature: Some(part.thought_signature),
+ }));
+ }
+ });
+ }
+ }
+
+ // Even when Gemini wants to use a Tool, the API
+ // responds with `finish_reason: STOP`
+ if wants_to_use_tool {
+ self.stop_reason = StopReason::ToolUse;
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ }
+ events
+ }
+}
+
+/// Count tokens for a Google AI model using tiktoken. This is synchronous;
+/// callers should spawn it on a background thread if needed.
+pub fn count_google_tokens(request: LanguageModelRequest) -> Result<u64> {
+ let messages = request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::<Vec<_>>();
+
+ // Tiktoken doesn't yet support these models, so we manually use the
+ // same tokenizer as GPT-4.
+ tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
+}
+
+fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
+ if let Some(prompt_token_count) = new.prompt_token_count {
+ usage.prompt_token_count = Some(prompt_token_count);
+ }
+ if let Some(cached_content_token_count) = new.cached_content_token_count {
+ usage.cached_content_token_count = Some(cached_content_token_count);
+ }
+ if let Some(candidates_token_count) = new.candidates_token_count {
+ usage.candidates_token_count = Some(candidates_token_count);
+ }
+ if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
+ usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
+ }
+ if let Some(thoughts_token_count) = new.thoughts_token_count {
+ usage.thoughts_token_count = Some(thoughts_token_count);
+ }
+ if let Some(total_token_count) = new.total_token_count {
+ usage.total_token_count = Some(total_token_count);
+ }
+}
+
+fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
+ let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
+ let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
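+    // Gemini's prompt_token_count includes cached tokens, so subtract them
+    // to report only the uncached input.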
+ let input_tokens = prompt_tokens - cached_tokens;
+ let output_tokens = usage.candidates_token_count.unwrap_or(0);
+
+ TokenUsage {
+ input_tokens,
+ output_tokens,
+ cache_read_input_tokens: cached_tokens,
+ cache_creation_input_tokens: 0,
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
+ Part, Role as GoogleRole,
+ };
+ use serde_json::json;
+
+ #[test]
+ fn test_function_call_with_signature_creates_tool_use_with_signature() {
+ let mut mapper = GoogleEventMapper::new();
+
+ let response = GenerateContentResponse {
+ candidates: Some(vec![GenerateContentCandidate {
+ index: Some(0),
+ content: Content {
+ parts: vec![Part::FunctionCallPart(FunctionCallPart {
+ function_call: FunctionCall {
+ name: "test_function".to_string(),
+ args: json!({"arg": "value"}),
+ },
+ thought_signature: Some("test_signature_123".to_string()),
+ })],
+ role: GoogleRole::Model,
+ },
+ finish_reason: None,
+ finish_message: None,
+ safety_ratings: None,
+ citation_metadata: None,
+ }]),
+ prompt_feedback: None,
+ usage_metadata: None,
+ };
+
+ let events = mapper.map_event(response);
+ assert_eq!(events.len(), 2);
+
+ if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
+ assert_eq!(tool_use.name.as_ref(), "test_function");
+ assert_eq!(
+ tool_use.thought_signature.as_deref(),
+ Some("test_signature_123")
+ );
+ } else {
+ panic!("Expected ToolUse event");
+ }
+ }
+
+ #[test]
+ fn test_function_call_without_signature_has_none() {
+ let mut mapper = GoogleEventMapper::new();
+
+ let response = GenerateContentResponse {
+ candidates: Some(vec![GenerateContentCandidate {
+ index: Some(0),
+ content: Content {
+ parts: vec![Part::FunctionCallPart(FunctionCallPart {
+ function_call: FunctionCall {
+ name: "test_function".to_string(),
+ args: json!({"arg": "value"}),
+ },
+ thought_signature: None,
+ })],
+ role: GoogleRole::Model,
+ },
+ finish_reason: None,
+ finish_message: None,
+ safety_ratings: None,
+ citation_metadata: None,
+ }]),
+ prompt_feedback: None,
+ usage_metadata: None,
+ };
+
+ let events = mapper.map_event(response);
+ assert_eq!(events.len(), 2);
+
+ if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
+ assert!(tool_use.thought_signature.is_none());
+ } else {
+ panic!("Expected ToolUse event");
+ }
+ }
+
+ #[test]
+ fn test_empty_string_signature_normalized_to_none() {
+ let mut mapper = GoogleEventMapper::new();
+
+ let response = GenerateContentResponse {
+ candidates: Some(vec![GenerateContentCandidate {
+ index: Some(0),
+ content: Content {
+ parts: vec![Part::FunctionCallPart(FunctionCallPart {
+ function_call: FunctionCall {
+ name: "test_function".to_string(),
+ args: json!({"arg": "value"}),
+ },
+ thought_signature: Some("".to_string()),
+ })],
+ role: GoogleRole::Model,
+ },
+ finish_reason: None,
+ finish_message: None,
+ safety_ratings: None,
+ citation_metadata: None,
+ }]),
+ prompt_feedback: None,
+ usage_metadata: None,
+ };
+
+ let events = mapper.map_event(response);
+ if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
+ assert!(tool_use.thought_signature.is_none());
+ } else {
+ panic!("Expected ToolUse event");
+ }
+ }
+}
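
As with the Anthropic mapper above, a hypothetical caller would drive `GoogleEventMapper` like this (a sketch only; the model id and mode variant are illustrative, and a real caller would pass the crate's streaming response instead of an empty stream):

use futures::StreamExt;
use google_ai::completion::{GoogleEventMapper, into_google};
use google_ai::{GenerateContentResponse, GoogleModelMode};

async fn sketch(request: language_model_core::LanguageModelRequest) {
    let _google_request =
        into_google(request, "gemini-2.5-flash".to_string(), GoogleModelMode::Default);
    let events = futures::stream::empty::<anyhow::Result<GenerateContentResponse>>();
    let mapped = GoogleEventMapper::new().map_stream(Box::pin(events));
    futures::pin_mut!(mapped);
    while let Some(event) = mapped.next().await {
        let _ = event; // even an empty input yields a final Stop event
    }
}
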
@@ -3,8 +3,9 @@ use std::mem;
use anyhow::{Result, anyhow, bail};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+pub use language_model_core::ModelMode as GoogleModelMode;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
-pub use settings::ModelMode as GoogleModelMode;
+pub mod completion;
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
@@ -56,6 +56,7 @@ etagere = "0.2"
futures.workspace = true
futures-concurrency.workspace = true
gpui_macros.workspace = true
+gpui_shared_string.workspace = true
http_client.workspace = true
image.workspace = true
inventory.workspace = true
@@ -39,7 +39,6 @@ pub mod profiler;
#[expect(missing_docs)]
pub mod queue;
mod scene;
-mod shared_string;
mod shared_uri;
mod style;
mod styled;
@@ -92,6 +91,7 @@ pub use global::*;
pub use gpui_macros::{
AppContext, IntoElement, Render, VisualContext, property_test, register_action, test,
};
+pub use gpui_shared_string::*;
pub use gpui_util::arc_cow::ArcCow;
pub use http_client;
pub use input::*;
@@ -106,7 +106,6 @@ pub use profiler::*;
pub use queue::{PriorityQueueReceiver, PriorityQueueSender};
pub use refineable::*;
pub use scene::*;
-pub use shared_string::*;
pub use shared_uri::*;
use std::{any::Any, future::Future};
pub use style::*;
@@ -882,7 +882,7 @@ mod tests {
],
len: 6,
}),
- text: SharedString::new("abcdef".to_string()),
+ text: "abcdef".into(),
decoration_runs: SmallVec::new(),
};
@@ -0,0 +1,17 @@
+[package]
+name = "gpui_shared_string"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+
+[lib]
+path = "gpui_shared_string.rs"
+
+[dependencies]
+derive_more.workspace = true
+gpui_util.workspace = true
+schemars.workspace = true
+serde.workspace = true
+
+[lints]
+workspace = true
@@ -0,0 +1 @@
+../../LICENSE-APACHE
@@ -10,7 +10,7 @@ path = "src/language_core.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
-gpui.workspace = true
+gpui_shared_string.workspace = true
log.workspace = true
lsp.workspace = true
parking_lot.workspace = true
@@ -22,8 +22,6 @@ toml.workspace = true
tree-sitter.workspace = true
util.workspace = true
-[dev-dependencies]
-gpui = { workspace = true, features = ["test-support"] }
[features]
test-support = []
@@ -1,4 +1,4 @@
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
use lsp::{DiagnosticSeverity, NumberOrString};
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -4,7 +4,7 @@ use crate::{
};
use anyhow::{Context as _, Result};
use collections::HashMap;
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
use lsp::LanguageServerName;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
@@ -1,6 +1,6 @@
use crate::LanguageName;
use collections::{HashMap, HashSet, IndexSet};
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
use lsp::LanguageServerName;
use regex::Regex;
use schemars::{JsonSchema, SchemaGenerator, json_schema};
@@ -1,4 +1,4 @@
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
@@ -1,4 +1,4 @@
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
use serde::{Deserialize, Serialize};
/// Converts a value into an LSP position.
@@ -1,6 +1,6 @@
use std::borrow::Borrow;
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ManifestName(SharedString);
@@ -6,7 +6,7 @@
use std::{path::Path, sync::Arc};
-use gpui::SharedString;
+use gpui_shared_string::SharedString;
use util::rel_path::RelPath;
use crate::{LanguageName, ManifestName};
@@ -16,13 +16,9 @@ doctest = false
test-support = []
[dependencies]
-anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
credentials_provider.workspace = true
base64.workspace = true
-cloud_api_client.workspace = true
-cloud_api_types.workspace = true
-cloud_llm_client.workspace = true
collections.workspace = true
env_var.workspace = true
futures.workspace = true
@@ -30,14 +26,11 @@ gpui.workspace = true
http_client.workspace = true
icons.workspace = true
image.workspace = true
+language_model_core.workspace = true
log.workspace = true
-open_ai = { workspace = true, features = ["schemars"] }
-open_router.workspace = true
parking_lot.workspace = true
-schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
-smol.workspace = true
thiserror.workspace = true
util.workspace = true
@@ -5,11 +5,10 @@ use crate::{
LanguageModelRequest, LanguageModelToolChoice,
};
use anyhow::anyhow;
-use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
+use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream, stream::StreamExt};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
-use smol::stream::StreamExt;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering::SeqCst},
@@ -1,380 +1,31 @@
mod api_key;
mod model;
-mod provider;
-mod rate_limiter;
mod registry;
mod request;
-mod role;
-pub mod tool_schema;
#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;
-use anyhow::{Result, anyhow};
-use cloud_llm_client::CompletionRequestStatus;
+pub use language_model_core::*;
+
+use anyhow::Result;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
-use http_client::{StatusCode, http};
+use gpui::{AnyView, App, AsyncApp, Task, Window};
use icons::IconName;
use parking_lot::Mutex;
-use serde::{Deserialize, Serialize};
-use std::ops::{Add, Sub};
-use std::str::FromStr;
use std::sync::Arc;
-use std::time::Duration;
-use std::{fmt, io};
-use thiserror::Error;
-use util::serde::is_default;
pub use crate::api_key::{ApiKey, ApiKeyState};
pub use crate::model::*;
-pub use crate::rate_limiter::*;
pub use crate::registry::*;
-pub use crate::request::*;
-pub use crate::role::*;
-pub use crate::tool_schema::LanguageModelToolSchemaFormat;
+pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui};
pub use env_var::{EnvVar, env_var};
-pub use provider::*;
pub fn init(cx: &mut App) {
registry::init(cx);
}
-#[derive(Clone, Debug)]
-pub struct LanguageModelCacheConfiguration {
- pub max_cache_anchors: usize,
- pub should_speculate: bool,
- pub min_total_token: u64,
-}
-
-/// A completion event from a language model.
-#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
-pub enum LanguageModelCompletionEvent {
- Queued {
- position: usize,
- },
- Started,
- Stop(StopReason),
- Text(String),
- Thinking {
- text: String,
- signature: Option<String>,
- },
- RedactedThinking {
- data: String,
- },
- ToolUse(LanguageModelToolUse),
- ToolUseJsonParseError {
- id: LanguageModelToolUseId,
- tool_name: Arc<str>,
- raw_input: Arc<str>,
- json_parse_error: String,
- },
- StartMessage {
- message_id: String,
- },
- ReasoningDetails(serde_json::Value),
- UsageUpdate(TokenUsage),
-}
-
-impl LanguageModelCompletionEvent {
- pub fn from_completion_request_status(
- status: CompletionRequestStatus,
- upstream_provider: LanguageModelProviderName,
- ) -> Result<Option<Self>, LanguageModelCompletionError> {
- match status {
- CompletionRequestStatus::Queued { position } => {
- Ok(Some(LanguageModelCompletionEvent::Queued { position }))
- }
- CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
- CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
- CompletionRequestStatus::Failed {
- code,
- message,
- request_id: _,
- retry_after,
- } => Err(LanguageModelCompletionError::from_cloud_failure(
- upstream_provider,
- code,
- message,
- retry_after.map(Duration::from_secs_f64),
- )),
- }
- }
-}
-
-#[derive(Error, Debug)]
-pub enum LanguageModelCompletionError {
- #[error("prompt too large for context window")]
- PromptTooLarge { tokens: Option<u64> },
- #[error("missing {provider} API key")]
- NoApiKey { provider: LanguageModelProviderName },
- #[error("{provider}'s API rate limit exceeded")]
- RateLimitExceeded {
- provider: LanguageModelProviderName,
- retry_after: Option<Duration>,
- },
- #[error("{provider}'s API servers are overloaded right now")]
- ServerOverloaded {
- provider: LanguageModelProviderName,
- retry_after: Option<Duration>,
- },
- #[error("{provider}'s API server reported an internal server error: {message}")]
- ApiInternalServerError {
- provider: LanguageModelProviderName,
- message: String,
- },
- #[error("{message}")]
- UpstreamProviderError {
- message: String,
- status: StatusCode,
- retry_after: Option<Duration>,
- },
- #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
- HttpResponseError {
- provider: LanguageModelProviderName,
- status_code: StatusCode,
- message: String,
- },
-
- // Client errors
- #[error("invalid request format to {provider}'s API: {message}")]
- BadRequestFormat {
- provider: LanguageModelProviderName,
- message: String,
- },
- #[error("authentication error with {provider}'s API: {message}")]
- AuthenticationError {
- provider: LanguageModelProviderName,
- message: String,
- },
- #[error("Permission error with {provider}'s API: {message}")]
- PermissionError {
- provider: LanguageModelProviderName,
- message: String,
- },
- #[error("language model provider API endpoint not found")]
- ApiEndpointNotFound { provider: LanguageModelProviderName },
- #[error("I/O error reading response from {provider}'s API")]
- ApiReadResponseError {
- provider: LanguageModelProviderName,
- #[source]
- error: io::Error,
- },
- #[error("error serializing request to {provider} API")]
- SerializeRequest {
- provider: LanguageModelProviderName,
- #[source]
- error: serde_json::Error,
- },
- #[error("error building request body to {provider} API")]
- BuildRequestBody {
- provider: LanguageModelProviderName,
- #[source]
- error: http::Error,
- },
- #[error("error sending HTTP request to {provider} API")]
- HttpSend {
- provider: LanguageModelProviderName,
- #[source]
- error: anyhow::Error,
- },
- #[error("error deserializing {provider} API response")]
- DeserializeResponse {
- provider: LanguageModelProviderName,
- #[source]
- error: serde_json::Error,
- },
-
- #[error("stream from {provider} ended unexpectedly")]
- StreamEndedUnexpectedly { provider: LanguageModelProviderName },
-
- // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
- #[error(transparent)]
- Other(#[from] anyhow::Error),
-}
-
-impl LanguageModelCompletionError {
- fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
- let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
- let upstream_status = error_json
- .get("upstream_status")
- .and_then(|v| v.as_u64())
- .and_then(|status| u16::try_from(status).ok())
- .and_then(|status| StatusCode::from_u16(status).ok())?;
- let inner_message = error_json
- .get("message")
- .and_then(|v| v.as_str())
- .unwrap_or(message)
- .to_string();
- Some((upstream_status, inner_message))
- }
-
- pub fn from_cloud_failure(
- upstream_provider: LanguageModelProviderName,
- code: String,
- message: String,
- retry_after: Option<Duration>,
- ) -> Self {
- if let Some(tokens) = parse_prompt_too_long(&message) {
- // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
- // to be reported. This is a temporary workaround to handle this in the case where the
- // token limit has been exceeded.
- Self::PromptTooLarge {
- tokens: Some(tokens),
- }
- } else if code == "upstream_http_error" {
- if let Some((upstream_status, inner_message)) =
- Self::parse_upstream_error_json(&message)
- {
- return Self::from_http_status(
- upstream_provider,
- upstream_status,
- inner_message,
- retry_after,
- );
- }
- anyhow!("completion request failed, code: {code}, message: {message}").into()
- } else if let Some(status_code) = code
- .strip_prefix("upstream_http_")
- .and_then(|code| StatusCode::from_str(code).ok())
- {
- Self::from_http_status(upstream_provider, status_code, message, retry_after)
- } else if let Some(status_code) = code
- .strip_prefix("http_")
- .and_then(|code| StatusCode::from_str(code).ok())
- {
- Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
- } else {
- anyhow!("completion request failed, code: {code}, message: {message}").into()
- }
- }
-
- pub fn from_http_status(
- provider: LanguageModelProviderName,
- status_code: StatusCode,
- message: String,
- retry_after: Option<Duration>,
- ) -> Self {
- match status_code {
- StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
- StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
- StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
- StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
- StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
- tokens: parse_prompt_too_long(&message),
- },
- StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
- provider,
- retry_after,
- },
- StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
- StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
- provider,
- retry_after,
- },
- _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
- provider,
- retry_after,
- },
- _ => Self::HttpResponseError {
- provider,
- status_code,
- message,
- },
- }
- }
-}
-
-#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum StopReason {
- EndTurn,
- MaxTokens,
- ToolUse,
- Refusal,
-}
-
-#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
-pub struct TokenUsage {
- #[serde(default, skip_serializing_if = "is_default")]
- pub input_tokens: u64,
- #[serde(default, skip_serializing_if = "is_default")]
- pub output_tokens: u64,
- #[serde(default, skip_serializing_if = "is_default")]
- pub cache_creation_input_tokens: u64,
- #[serde(default, skip_serializing_if = "is_default")]
- pub cache_read_input_tokens: u64,
-}
-
-impl TokenUsage {
- pub fn total_tokens(&self) -> u64 {
- self.input_tokens
- + self.output_tokens
- + self.cache_read_input_tokens
- + self.cache_creation_input_tokens
- }
-}
-
-impl Add<TokenUsage> for TokenUsage {
- type Output = Self;
-
- fn add(self, other: Self) -> Self {
- Self {
- input_tokens: self.input_tokens + other.input_tokens,
- output_tokens: self.output_tokens + other.output_tokens,
- cache_creation_input_tokens: self.cache_creation_input_tokens
- + other.cache_creation_input_tokens,
- cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
- }
- }
-}
-
-impl Sub<TokenUsage> for TokenUsage {
- type Output = Self;
-
- fn sub(self, other: Self) -> Self {
- Self {
- input_tokens: self.input_tokens - other.input_tokens,
- output_tokens: self.output_tokens - other.output_tokens,
- cache_creation_input_tokens: self.cache_creation_input_tokens
- - other.cache_creation_input_tokens,
- cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
- }
- }
-}
-
-#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
-pub struct LanguageModelToolUseId(Arc<str>);
-
-impl fmt::Display for LanguageModelToolUseId {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "{}", self.0)
- }
-}
-
-impl<T> From<T> for LanguageModelToolUseId
-where
- T: Into<Arc<str>>,
-{
- fn from(value: T) -> Self {
- Self(value.into())
- }
-}
-
-#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
-pub struct LanguageModelToolUse {
- pub id: LanguageModelToolUseId,
- pub name: Arc<str>,
- pub raw_input: String,
- pub input: serde_json::Value,
- pub is_input_complete: bool,
- /// Thought signature the model sent us. Some models require that this
- /// signature be preserved and sent back in conversation history for validation.
- pub thought_signature: Option<String>,
-}
-
pub struct LanguageModelTextStream {
pub message_id: Option<String>,
pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
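The shape of this hunk matters more than its size: everything deleted here is moved rather than removed, and `pub use language_model_core::*;` turns `language_model` into a facade over the new core crate. Downstream code should keep compiling without touching its imports, e.g.:

```rust
// StopReason and TokenUsage now live in language_model_core, but the glob
// re-export keeps the old paths working.
use language_model::{StopReason, TokenUsage};

fn total_usage(events: &[(TokenUsage, StopReason)]) -> TokenUsage {
    // TokenUsage is Copy and implements Add, exactly as before the move.
    events
        .iter()
        .fold(TokenUsage::default(), |acc, (usage, _)| acc + *usage)
}
```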
@@ -392,13 +43,6 @@ impl Default for LanguageModelTextStream {
}
}
-#[derive(Debug, Clone)]
-pub struct LanguageModelEffortLevel {
- pub name: SharedString,
- pub value: SharedString,
- pub is_default: bool,
-}
-
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
@@ -605,7 +249,7 @@ pub trait LanguageModel: Send + Sync {
}
impl std::fmt::Debug for dyn LanguageModel {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("<dyn LanguageModel>")
.field("id", &self.id())
.field("name", &self.name())
@@ -619,17 +263,6 @@ impl std::fmt::Debug for dyn LanguageModel {
}
}
-/// An error that occurred when trying to authenticate the language model provider.
-#[derive(Debug, Error)]
-pub enum AuthenticateError {
- #[error("connection refused")]
- ConnectionRefused,
- #[error("credentials not found")]
- CredentialsNotFound,
- #[error(transparent)]
- Other(#[from] anyhow::Error),
-}
-
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
@@ -692,18 +325,6 @@ pub trait LanguageModelProviderState: 'static {
}
}
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
-pub struct LanguageModelId(pub SharedString);
-
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
-pub struct LanguageModelName(pub SharedString);
-
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
-pub struct LanguageModelProviderId(pub SharedString);
-
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
-pub struct LanguageModelProviderName(pub SharedString);
-
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
/// Cost per 1,000 input and output tokens
@@ -741,245 +362,3 @@ impl LanguageModelCostInfo {
}
}
}
-
-impl LanguageModelProviderId {
- pub const fn new(id: &'static str) -> Self {
- Self(SharedString::new_static(id))
- }
-}
-
-impl LanguageModelProviderName {
- pub const fn new(id: &'static str) -> Self {
- Self(SharedString::new_static(id))
- }
-}
-
-impl fmt::Display for LanguageModelProviderId {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "{}", self.0)
- }
-}
-
-impl fmt::Display for LanguageModelProviderName {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "{}", self.0)
- }
-}
-
-impl From<String> for LanguageModelId {
- fn from(value: String) -> Self {
- Self(SharedString::from(value))
- }
-}
-
-impl From<String> for LanguageModelName {
- fn from(value: String) -> Self {
- Self(SharedString::from(value))
- }
-}
-
-impl From<String> for LanguageModelProviderId {
- fn from(value: String) -> Self {
- Self(SharedString::from(value))
- }
-}
-
-impl From<String> for LanguageModelProviderName {
- fn from(value: String) -> Self {
- Self(SharedString::from(value))
- }
-}
-
-impl From<Arc<str>> for LanguageModelProviderId {
- fn from(value: Arc<str>) -> Self {
- Self(SharedString::from(value))
- }
-}
-
-impl From<Arc<str>> for LanguageModelProviderName {
- fn from(value: Arc<str>) -> Self {
- Self(SharedString::from(value))
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_from_cloud_failure_with_upstream_http_error() {
- let error = LanguageModelCompletionError::from_cloud_failure(
- String::from("anthropic").into(),
- "upstream_http_error".to_string(),
- r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
- None,
- );
-
- match error {
- LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
- assert_eq!(provider.0, "anthropic");
- }
- _ => panic!(
- "Expected ServerOverloaded error for 503 status, got: {:?}",
- error
- ),
- }
-
- let error = LanguageModelCompletionError::from_cloud_failure(
- String::from("anthropic").into(),
- "upstream_http_error".to_string(),
- r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
- None,
- );
-
- match error {
- LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
- assert_eq!(provider.0, "anthropic");
- assert_eq!(message, "Internal server error");
- }
- _ => panic!(
- "Expected ApiInternalServerError for 500 status, got: {:?}",
- error
- ),
- }
- }
-
- #[test]
- fn test_from_cloud_failure_with_standard_format() {
- let error = LanguageModelCompletionError::from_cloud_failure(
- String::from("anthropic").into(),
- "upstream_http_503".to_string(),
- "Service unavailable".to_string(),
- None,
- );
-
- match error {
- LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
- assert_eq!(provider.0, "anthropic");
- }
- _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
- }
- }
-
- #[test]
- fn test_upstream_http_error_connection_timeout() {
- let error = LanguageModelCompletionError::from_cloud_failure(
- String::from("anthropic").into(),
- "upstream_http_error".to_string(),
- r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
- None,
- );
-
- match error {
- LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
- assert_eq!(provider.0, "anthropic");
- }
- _ => panic!(
- "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
- error
- ),
- }
-
- let error = LanguageModelCompletionError::from_cloud_failure(
- String::from("anthropic").into(),
- "upstream_http_error".to_string(),
- r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
- None,
- );
-
- match error {
- LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
- assert_eq!(provider.0, "anthropic");
- assert_eq!(
- message,
- "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
- );
- }
- _ => panic!(
- "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
- error
- ),
- }
- }
-
- #[test]
- fn test_language_model_tool_use_serializes_with_signature() {
- use serde_json::json;
-
- let tool_use = LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test_id"),
- name: "test_tool".into(),
- raw_input: json!({"arg": "value"}).to_string(),
- input: json!({"arg": "value"}),
- is_input_complete: true,
- thought_signature: Some("test_signature".to_string()),
- };
-
- let serialized = serde_json::to_value(&tool_use).unwrap();
-
- assert_eq!(serialized["id"], "test_id");
- assert_eq!(serialized["name"], "test_tool");
- assert_eq!(serialized["thought_signature"], "test_signature");
- }
-
- #[test]
- fn test_language_model_tool_use_deserializes_with_missing_signature() {
- use serde_json::json;
-
- let json = json!({
- "id": "test_id",
- "name": "test_tool",
- "raw_input": "{\"arg\":\"value\"}",
- "input": {"arg": "value"},
- "is_input_complete": true
- });
-
- let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
-
- assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
- assert_eq!(tool_use.name.as_ref(), "test_tool");
- assert_eq!(tool_use.thought_signature, None);
- }
-
- #[test]
- fn test_language_model_tool_use_round_trip_with_signature() {
- use serde_json::json;
-
- let original = LanguageModelToolUse {
- id: LanguageModelToolUseId::from("round_trip_id"),
- name: "round_trip_tool".into(),
- raw_input: json!({"key": "value"}).to_string(),
- input: json!({"key": "value"}),
- is_input_complete: true,
- thought_signature: Some("round_trip_sig".to_string()),
- };
-
- let serialized = serde_json::to_value(&original).unwrap();
- let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
-
- assert_eq!(deserialized.id, original.id);
- assert_eq!(deserialized.name, original.name);
- assert_eq!(deserialized.thought_signature, original.thought_signature);
- }
-
- #[test]
- fn test_language_model_tool_use_round_trip_without_signature() {
- use serde_json::json;
-
- let original = LanguageModelToolUse {
- id: LanguageModelToolUseId::from("no_sig_id"),
- name: "no_sig_tool".into(),
- raw_input: json!({"arg": "value"}).to_string(),
- input: json!({"arg": "value"}),
- is_input_complete: true,
- thought_signature: None,
- };
-
- let serialized = serde_json::to_value(&original).unwrap();
- let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
-
- assert_eq!(deserialized.id, original.id);
- assert_eq!(deserialized.name, original.name);
- assert_eq!(deserialized.thought_signature, None);
- }
-}
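All of the newtype plumbing deleted above (the const constructors, `Display`, and the `From` conversions) reappears in the new `language_model_core` crate, whose source begins later in this section, so callers are unaffected whichever path they import from:

```rust
use language_model::{LanguageModelProviderId, LanguageModelProviderName};

// Const construction, as used by the per-provider constants in this diff.
const ANTHROPIC_ID: LanguageModelProviderId = LanguageModelProviderId::new("anthropic");

fn provider_name(raw: String) -> LanguageModelProviderName {
    // From<String> survives the move unchanged.
    LanguageModelProviderName::from(raw)
}
```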
@@ -1,10 +1,5 @@
use std::fmt;
-use std::sync::Arc;
-use cloud_api_client::ClientApiError;
-use cloud_api_client::CloudApiClient;
-use cloud_api_types::OrganizationId;
-use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error;
#[derive(Error, Debug)]
@@ -18,71 +13,3 @@ impl fmt::Display for PaymentRequiredError {
)
}
}
-
-#[derive(Clone, Default)]
-pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
-
-impl LlmApiToken {
- pub async fn acquire(
- &self,
- client: &CloudApiClient,
- system_id: Option<String>,
- organization_id: Option<OrganizationId>,
- ) -> Result<String, ClientApiError> {
- let lock = self.0.upgradable_read().await;
- if let Some(token) = lock.as_ref() {
- Ok(token.to_string())
- } else {
- Self::fetch(
- RwLockUpgradableReadGuard::upgrade(lock).await,
- client,
- system_id,
- organization_id,
- )
- .await
- }
- }
-
- pub async fn refresh(
- &self,
- client: &CloudApiClient,
- system_id: Option<String>,
- organization_id: Option<OrganizationId>,
- ) -> Result<String, ClientApiError> {
- Self::fetch(self.0.write().await, client, system_id, organization_id).await
- }
-
- /// Clears the existing token before attempting to fetch a new one.
- ///
- /// Used when switching organizations so that a failed refresh doesn't
- /// leave a token for the wrong organization.
- pub async fn clear_and_refresh(
- &self,
- client: &CloudApiClient,
- system_id: Option<String>,
- organization_id: Option<OrganizationId>,
- ) -> Result<String, ClientApiError> {
- let mut lock = self.0.write().await;
- *lock = None;
- Self::fetch(lock, client, system_id, organization_id).await
- }
-
- async fn fetch(
- mut lock: RwLockWriteGuard<'_, Option<String>>,
- client: &CloudApiClient,
- system_id: Option<String>,
- organization_id: Option<OrganizationId>,
- ) -> Result<String, ClientApiError> {
- let result = client.create_llm_token(system_id, organization_id).await;
- match result {
- Ok(response) => {
- *lock = Some(response.token.0.clone());
- Ok(response.token.0)
- }
- Err(err) => {
- *lock = None;
- Err(err)
- }
- }
- }
-}
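`LlmApiToken` leaves this crate entirely; its new home isn't visible in this part of the diff, but the caching pattern it implements is worth recording: readers share an upgradable lock, and only the cache-miss path pays for the write lock. A minimal sketch of that pattern, with a generic `fetch` future standing in for `create_llm_token`:

```rust
use smol::lock::{RwLock, RwLockUpgradableReadGuard};
use std::future::Future;
use std::sync::Arc;

#[derive(Clone, Default)]
struct CachedToken(Arc<RwLock<Option<String>>>);

impl CachedToken {
    async fn acquire<F: Future<Output = String>>(&self, fetch: F) -> String {
        // Fast path: concurrent readers can all see the cached token.
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            return token.clone();
        }
        // Slow path: upgrade to exclusive access and populate the cache.
        let mut lock = RwLockUpgradableReadGuard::upgrade(lock).await;
        let token = fetch.await;
        *lock = Some(token.clone());
        token
    }
}
```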
@@ -1,12 +0,0 @@
-pub mod anthropic;
-pub mod google;
-pub mod open_ai;
-pub mod open_router;
-pub mod x_ai;
-pub mod zed;
-
-pub use anthropic::*;
-pub use google::*;
-pub use open_ai::*;
-pub use x_ai::*;
-pub use zed::*;
@@ -1,80 +0,0 @@
-use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName};
-use anthropic::AnthropicError;
-pub use anthropic::parse_prompt_too_long;
-
-pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
- LanguageModelProviderId::new("anthropic");
-pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
- LanguageModelProviderName::new("Anthropic");
-
-impl From<AnthropicError> for LanguageModelCompletionError {
- fn from(error: AnthropicError) -> Self {
- let provider = ANTHROPIC_PROVIDER_NAME;
- match error {
- AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
- AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
- AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
- AnthropicError::DeserializeResponse(error) => {
- Self::DeserializeResponse { provider, error }
- }
- AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
- AnthropicError::HttpResponseError {
- status_code,
- message,
- } => Self::HttpResponseError {
- provider,
- status_code,
- message,
- },
- AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
- provider,
- retry_after: Some(retry_after),
- },
- AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
- provider,
- retry_after,
- },
- AnthropicError::ApiError(api_error) => api_error.into(),
- }
- }
-}
-
-impl From<anthropic::ApiError> for LanguageModelCompletionError {
- fn from(error: anthropic::ApiError) -> Self {
- use anthropic::ApiErrorCode::*;
- let provider = ANTHROPIC_PROVIDER_NAME;
- match error.code() {
- Some(code) => match code {
- InvalidRequestError => Self::BadRequestFormat {
- provider,
- message: error.message,
- },
- AuthenticationError => Self::AuthenticationError {
- provider,
- message: error.message,
- },
- PermissionError => Self::PermissionError {
- provider,
- message: error.message,
- },
- NotFoundError => Self::ApiEndpointNotFound { provider },
- RequestTooLarge => Self::PromptTooLarge {
- tokens: parse_prompt_too_long(&error.message),
- },
- RateLimitError => Self::RateLimitExceeded {
- provider,
- retry_after: None,
- },
- ApiError => Self::ApiInternalServerError {
- provider,
- message: error.message,
- },
- OverloadedError => Self::ServerOverloaded {
- provider,
- retry_after: None,
- },
- },
- None => Self::Other(error.into()),
- }
- }
-}
@@ -1,5 +0,0 @@
-use crate::{LanguageModelProviderId, LanguageModelProviderName};
-
-pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
-pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
- LanguageModelProviderName::new("Google AI");
@@ -1,28 +0,0 @@
-use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName};
-use http_client::http;
-use std::time::Duration;
-
-pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
-pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
- LanguageModelProviderName::new("OpenAI");
-
-impl From<open_ai::RequestError> for LanguageModelCompletionError {
- fn from(error: open_ai::RequestError) -> Self {
- match error {
- open_ai::RequestError::HttpResponseError {
- provider,
- status_code,
- body,
- headers,
- } => {
- let retry_after = headers
- .get(http::header::RETRY_AFTER)
- .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
- .map(Duration::from_secs);
-
- Self::from_http_status(provider.into(), status_code, body, retry_after)
- }
- open_ai::RequestError::Other(e) => Self::Other(e),
- }
- }
-}
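One behavior encoded in the deleted open_ai mapping is easy to miss: rate-limit responses are folded into `RateLimitExceeded` together with the `Retry-After` header, parsed as whole seconds (the delta-seconds form only; HTTP-date values fall out as `None`). That parse in isolation:

```rust
use http_client::http::{self, HeaderMap};
use std::time::Duration;

// Same logic as the deleted From impl: take Retry-After as an integer
// number of seconds, or None if the header is absent or unparseable.
fn retry_after(headers: &HeaderMap) -> Option<Duration> {
    headers
        .get(http::header::RETRY_AFTER)
        .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
        .map(Duration::from_secs)
}
```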
@@ -1,69 +0,0 @@
-use crate::{LanguageModelCompletionError, LanguageModelProviderName};
-use http_client::StatusCode;
-use open_router::OpenRouterError;
-
-impl From<OpenRouterError> for LanguageModelCompletionError {
- fn from(error: OpenRouterError) -> Self {
- let provider = LanguageModelProviderName::new("OpenRouter");
- match error {
- OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
- OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
- OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
- OpenRouterError::DeserializeResponse(error) => {
- Self::DeserializeResponse { provider, error }
- }
- OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
- OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
- provider,
- retry_after: Some(retry_after),
- },
- OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
- provider,
- retry_after,
- },
- OpenRouterError::ApiError(api_error) => api_error.into(),
- }
- }
-}
-
-impl From<open_router::ApiError> for LanguageModelCompletionError {
- fn from(error: open_router::ApiError) -> Self {
- use open_router::ApiErrorCode::*;
- let provider = LanguageModelProviderName::new("OpenRouter");
- match error.code {
- InvalidRequestError => Self::BadRequestFormat {
- provider,
- message: error.message,
- },
- AuthenticationError => Self::AuthenticationError {
- provider,
- message: error.message,
- },
- PaymentRequiredError => Self::AuthenticationError {
- provider,
- message: format!("Payment required: {}", error.message),
- },
- PermissionError => Self::PermissionError {
- provider,
- message: error.message,
- },
- RequestTimedOut => Self::HttpResponseError {
- provider,
- status_code: StatusCode::REQUEST_TIMEOUT,
- message: error.message,
- },
- RateLimitError => Self::RateLimitExceeded {
- provider,
- retry_after: None,
- },
- ApiError => Self::ApiInternalServerError {
- provider,
- message: error.message,
- },
- OverloadedError => Self::ServerOverloaded {
- provider,
- retry_after: None,
- },
- }
- }
-}
@@ -1,4 +0,0 @@
-use crate::{LanguageModelProviderId, LanguageModelProviderName};
-
-pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
-pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
@@ -1,5 +0,0 @@
-use crate::{LanguageModelProviderId, LanguageModelProviderName};
-
-pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
-pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
- LanguageModelProviderName::new("Zed");
@@ -1,6 +1,6 @@
use crate::{
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
- LanguageModelProviderState,
+ LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
};
use collections::{BTreeMap, HashSet};
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
@@ -101,7 +101,7 @@ impl ConfiguredModel {
}
pub fn is_provided_by_zed(&self) -> bool {
- self.provider.id() == crate::provider::ZED_CLOUD_PROVIDER_ID
+ self.provider.id() == ZED_CLOUD_PROVIDER_ID
}
}
@@ -4,78 +4,13 @@ use std::sync::Arc;
use anyhow::Result;
use base64::write::EncoderWriter;
use gpui::{
- App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
- point, px, size,
+ App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, Size, Task, point, px, size,
};
use image::GenericImageView as _;
use image::codecs::png::PngEncoder;
-use serde::{Deserialize, Serialize};
use util::ResultExt;
-use crate::role::Role;
-use crate::{LanguageModelToolUse, LanguageModelToolUseId};
-
-#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
-pub struct LanguageModelImage {
- /// A base64-encoded PNG image.
- pub source: SharedString,
- #[serde(default, skip_serializing_if = "Option::is_none")]
- pub size: Option<Size<DevicePixels>>,
-}
-
-impl LanguageModelImage {
- pub fn len(&self) -> usize {
- self.source.len()
- }
-
- pub fn is_empty(&self) -> bool {
- self.source.is_empty()
- }
-
- // Parse Self from a JSON object with case-insensitive field names
- pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
- let mut source = None;
- let mut size_obj = None;
-
- // Find source and size fields (case-insensitive)
- for (k, v) in obj.iter() {
- match k.to_lowercase().as_str() {
- "source" => source = v.as_str(),
- "size" => size_obj = v.as_object(),
- _ => {}
- }
- }
-
- let source = source?;
- let size_obj = size_obj?;
-
- let mut width = None;
- let mut height = None;
-
- // Find width and height in size object (case-insensitive)
- for (k, v) in size_obj.iter() {
- match k.to_lowercase().as_str() {
- "width" => width = v.as_i64().map(|w| w as i32),
- "height" => height = v.as_i64().map(|h| h as i32),
- _ => {}
- }
- }
-
- Some(Self {
- size: Some(size(DevicePixels(width?), DevicePixels(height?))),
- source: SharedString::from(source.to_string()),
- })
- }
-}
-
-impl std::fmt::Debug for LanguageModelImage {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_struct("LanguageModelImage")
- .field("source", &format!("<{} bytes>", self.source.len()))
- .field("size", &self.size)
- .finish()
- }
-}
+use language_model_core::{ImageSize, LanguageModelImage};
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;
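With `LanguageModelImage` moved to the core crate, its size field becomes a plain `ImageSize` (two `i32`s) rather than gpui's `Size<DevicePixels>`, so the type can be built without any UI dependency:

```rust
use language_model_core::{ImageSize, LanguageModelImage};

// `source` is still a base64-encoded PNG, exactly as before the move.
fn image_from_base64_png(source: String, width: i32, height: i32) -> LanguageModelImage {
    LanguageModelImage {
        source: source.into(),
        size: Some(ImageSize { width, height }),
    }
}
```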
@@ -90,18 +25,16 @@ const DEFAULT_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
/// `DEFAULT_IMAGE_MAX_BYTES`.
const MAX_IMAGE_DOWNSCALE_PASSES: usize = 8;
-impl LanguageModelImage {
- // All language model images are encoded as PNGs.
- pub const FORMAT: ImageFormat = ImageFormat::Png;
+/// Extension trait for `LanguageModelImage` that provides GPUI-dependent functionality.
+pub trait LanguageModelImageExt {
+ const FORMAT: ImageFormat;
+ fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>>;
+}
- pub fn empty() -> Self {
- Self {
- source: "".into(),
- size: None,
- }
- }
+impl LanguageModelImageExt for LanguageModelImage {
+ const FORMAT: ImageFormat = ImageFormat::Png;
- pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
+ fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>> {
cx.background_spawn(async move {
let image_bytes = Cursor::new(data.bytes());
let dynamic_image = match data.format() {
@@ -186,28 +119,14 @@ impl LanguageModelImage {
let source = unsafe { String::from_utf8_unchecked(base64_image) };
Some(LanguageModelImage {
- size: Some(image_size),
+ size: Some(ImageSize {
+ width: width as i32,
+ height: height as i32,
+ }),
source: source.into(),
})
})
}
-
- pub fn estimate_tokens(&self) -> usize {
- let Some(size) = self.size.as_ref() else {
- return 0;
- };
- let width = size.width.0.unsigned_abs() as usize;
- let height = size.height.0.unsigned_abs() as usize;
-
- // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
- // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
- // so this method is more of a rough guess.
- (width * height) / 750
- }
-
- pub fn to_base64_url(&self) -> String {
- format!("data:image/png;base64,{}", self.source)
- }
}
fn encode_png_bytes(image: &image::DynamicImage) -> Result<Vec<u8>> {
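`from_image` genuinely needs gpui (`App`, `Task`, and background-executor decoding), so it stays behind in this crate as the `LanguageModelImageExt` extension trait shown above. Call sites keep the old associated-function syntax as long as the trait is in scope:

```rust
use gpui::{App, Image, Task};
use language_model::LanguageModelImageExt as _; // restores the old syntax
use language_model_core::LanguageModelImage;
use std::sync::Arc;

fn convert(data: Arc<Image>, cx: &mut App) -> Task<Option<LanguageModelImage>> {
    LanguageModelImage::from_image(data, cx)
}
```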
@@ -228,512 +147,85 @@ fn encode_bytes_as_base64(bytes: &[u8]) -> Result<Vec<u8>> {
Ok(base64_image)
}
-#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
-pub struct LanguageModelToolResult {
- pub tool_use_id: LanguageModelToolUseId,
- pub tool_name: Arc<str>,
- pub is_error: bool,
- /// The tool output formatted for presenting to the model
- pub content: LanguageModelToolResultContent,
- /// The raw tool output, if available, often for debugging or extra state for replay
- pub output: Option<serde_json::Value>,
-}
-
-#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
-pub enum LanguageModelToolResultContent {
- Text(Arc<str>),
- Image(LanguageModelImage),
-}
-
-impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: serde::Deserializer<'de>,
- {
- use serde::de::Error;
-
- let value = serde_json::Value::deserialize(deserializer)?;
-
- // Models can provide these responses in several styles. Try each in order.
-
- // 1. Try as plain string
- if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
- return Ok(Self::Text(Arc::from(text)));
- }
-
- // 2. Try as object
- if let Some(obj) = value.as_object() {
- // get a JSON field case-insensitively
- fn get_field<'a>(
- obj: &'a serde_json::Map<String, serde_json::Value>,
- field: &str,
- ) -> Option<&'a serde_json::Value> {
- obj.iter()
- .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
- .map(|(_, v)| v)
- }
-
- // Accept wrapped text format: { "type": "text", "text": "..." }
- if let (Some(type_value), Some(text_value)) =
- (get_field(obj, "type"), get_field(obj, "text"))
- && let Some(type_str) = type_value.as_str()
- && type_str.to_lowercase() == "text"
- && let Some(text) = text_value.as_str()
- {
- return Ok(Self::Text(Arc::from(text)));
- }
-
- // Check for wrapped Text variant: { "text": "..." }
- if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
- && obj.len() == 1
- {
- // Only one field, and it's "text" (case-insensitive)
- if let Some(text) = value.as_str() {
- return Ok(Self::Text(Arc::from(text)));
- }
- }
-
- // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
- if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
- && obj.len() == 1
- {
- // Only one field, and it's "image" (case-insensitive)
- // Try to parse the nested image object
- if let Some(image_obj) = value.as_object()
- && let Some(image) = LanguageModelImage::from_json(image_obj)
- {
- return Ok(Self::Image(image));
- }
- }
-
- // Try as direct Image (object with "source" and "size" fields)
- if let Some(image) = LanguageModelImage::from_json(obj) {
- return Ok(Self::Image(image));
- }
- }
-
- // If none of the variants match, return an error with the problematic JSON
- Err(D::Error::custom(format!(
- "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
- an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
- serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
- )))
- }
-}
-
-impl LanguageModelToolResultContent {
- pub fn to_str(&self) -> Option<&str> {
- match self {
- Self::Text(text) => Some(text),
- Self::Image(_) => None,
- }
- }
-
- pub fn is_empty(&self) -> bool {
- match self {
- Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
- Self::Image(_) => false,
- }
- }
-}
-
-impl From<&str> for LanguageModelToolResultContent {
- fn from(value: &str) -> Self {
- Self::Text(Arc::from(value))
- }
-}
-
-impl From<String> for LanguageModelToolResultContent {
- fn from(value: String) -> Self {
- Self::Text(Arc::from(value))
- }
-}
-
-impl From<LanguageModelImage> for LanguageModelToolResultContent {
- fn from(image: LanguageModelImage) -> Self {
- Self::Image(image)
- }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
-pub enum MessageContent {
- Text(String),
- Thinking {
- text: String,
- signature: Option<String>,
- },
- RedactedThinking(String),
- Image(LanguageModelImage),
- ToolUse(LanguageModelToolUse),
- ToolResult(LanguageModelToolResult),
-}
-
-impl MessageContent {
- pub fn to_str(&self) -> Option<&str> {
- match self {
- MessageContent::Text(text) => Some(text.as_str()),
- MessageContent::Thinking { text, .. } => Some(text.as_str()),
- MessageContent::RedactedThinking(_) => None,
- MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
- MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
- }
- }
-
- pub fn is_empty(&self) -> bool {
- match self {
- MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
- MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
- MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
- MessageContent::RedactedThinking(_)
- | MessageContent::ToolUse(_)
- | MessageContent::Image(_) => false,
- }
- }
-}
-
-impl From<String> for MessageContent {
- fn from(value: String) -> Self {
- MessageContent::Text(value)
- }
-}
-
-impl From<&str> for MessageContent {
- fn from(value: &str) -> Self {
- MessageContent::Text(value.to_string())
- }
-}
-
-#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
-pub struct LanguageModelRequestMessage {
- pub role: Role,
- pub content: Vec<MessageContent>,
- pub cache: bool,
- #[serde(default, skip_serializing_if = "Option::is_none")]
- pub reasoning_details: Option<serde_json::Value>,
-}
-
-impl LanguageModelRequestMessage {
- pub fn string_contents(&self) -> String {
- let mut buffer = String::new();
- for string in self.content.iter().filter_map(|content| content.to_str()) {
- buffer.push_str(string);
- }
-
- buffer
- }
-
- pub fn contents_empty(&self) -> bool {
- self.content.iter().all(|content| content.is_empty())
- }
-}
-
-#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
-pub struct LanguageModelRequestTool {
- pub name: String,
- pub description: String,
- pub input_schema: serde_json::Value,
- pub use_input_streaming: bool,
-}
-
-#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
-pub enum LanguageModelToolChoice {
- Auto,
- Any,
- None,
-}
-
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum CompletionIntent {
- UserPrompt,
- Subagent,
- ToolResults,
- ThreadSummarization,
- ThreadContextSummarization,
- CreateFile,
- EditFile,
- InlineAssist,
- TerminalInlineAssist,
- GenerateGitCommitMessage,
-}
-
-#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
-pub struct LanguageModelRequest {
- pub thread_id: Option<String>,
- pub prompt_id: Option<String>,
- pub intent: Option<CompletionIntent>,
- pub messages: Vec<LanguageModelRequestMessage>,
- pub tools: Vec<LanguageModelRequestTool>,
- pub tool_choice: Option<LanguageModelToolChoice>,
- pub stop: Vec<String>,
- pub temperature: Option<f32>,
- pub thinking_allowed: bool,
- pub thinking_effort: Option<String>,
- pub speed: Option<Speed>,
-}
-
-#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)]
-#[serde(rename_all = "snake_case")]
-pub enum Speed {
- #[default]
- Standard,
- Fast,
-}
-
-impl Speed {
- pub fn toggle(self) -> Self {
- match self {
- Speed::Standard => Speed::Fast,
- Speed::Fast => Speed::Standard,
- }
+/// Convert a core `ImageSize` to a gpui `Size<DevicePixels>`.
+pub fn image_size_to_gpui(size: ImageSize) -> Size<DevicePixels> {
+ Size {
+ width: DevicePixels(size.width),
+ height: DevicePixels(size.height),
}
}
-impl From<Speed> for anthropic::Speed {
- fn from(speed: Speed) -> Self {
- match speed {
- Speed::Standard => anthropic::Speed::Standard,
- Speed::Fast => anthropic::Speed::Fast,
- }
+/// Convert a gpui `Size<DevicePixels>` to a core `ImageSize`.
+pub fn gpui_size_to_image_size(size: Size<DevicePixels>) -> ImageSize {
+ ImageSize {
+ width: size.width.0,
+ height: size.height.0,
}
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct LanguageModelResponseMessage {
- pub role: Option<Role>,
- pub content: Option<String>,
-}
-
#[cfg(test)]
mod tests {
use super::*;
use base64::Engine as _;
use gpui::TestAppContext;
- use image::ImageDecoder as _;
- fn base64_to_png_bytes(base64_png: &str) -> Vec<u8> {
+ fn base64_to_png_bytes(base64: &str) -> Vec<u8> {
base64::engine::general_purpose::STANDARD
- .decode(base64_png.as_bytes())
- .expect("base64 should decode")
+ .decode(base64)
+ .expect("valid base64")
}
fn png_dimensions(png_bytes: &[u8]) -> (u32, u32) {
- let decoder =
- image::codecs::png::PngDecoder::new(Cursor::new(png_bytes)).expect("png should decode");
- decoder.dimensions()
+ let img = image::load_from_memory(png_bytes).expect("valid png");
+ (img.width(), img.height())
}
fn make_noisy_png_bytes(width: u32, height: u32) -> Vec<u8> {
- // Create an RGBA image with per-pixel variance to avoid PNG compressing too well.
- let mut img = image::RgbaImage::new(width, height);
- for y in 0..height {
- for x in 0..width {
- let r = ((x ^ y) & 0xFF) as u8;
- let g = ((x.wrapping_mul(31) ^ y.wrapping_mul(17)) & 0xFF) as u8;
- let b = ((x.wrapping_mul(131) ^ y.wrapping_mul(7)) & 0xFF) as u8;
- img.put_pixel(x, y, image::Rgba([r, g, b, 0xFF]));
- }
- }
+ use image::{ImageBuffer, Rgba};
+ use std::hash::{Hash, Hasher};
+
+ let img = ImageBuffer::from_fn(width, height, |x, y| {
+ let mut hasher = std::hash::DefaultHasher::new();
+ (x, y, width, height).hash(&mut hasher);
+ let h = hasher.finish();
+ Rgba([h as u8, (h >> 8) as u8, (h >> 16) as u8, 255])
+ });
- let mut out = Vec::new();
- image::DynamicImage::ImageRgba8(img)
- .write_with_encoder(PngEncoder::new(&mut out))
- .expect("png encoding should succeed");
- out
+ let mut buf = Cursor::new(Vec::new());
+ img.write_with_encoder(PngEncoder::new(&mut buf))
+ .expect("encode");
+ buf.into_inner()
}
#[gpui::test]
async fn test_from_image_downscales_to_default_5mb_limit(cx: &mut TestAppContext) {
- // Pick a size that reliably produces a PNG > 5MB when filled with noise.
- // If this fails (image is too small), bump dimensions.
- let original_png = make_noisy_png_bytes(4096, 4096);
+ let raw_png = make_noisy_png_bytes(4096, 4096);
assert!(
- original_png.len() > DEFAULT_IMAGE_MAX_BYTES,
- "precondition failed: noisy PNG must exceed DEFAULT_IMAGE_MAX_BYTES"
+ raw_png.len() > DEFAULT_IMAGE_MAX_BYTES,
+ "Test image should exceed the 5 MB limit (actual: {} bytes)",
+ raw_png.len()
);
- let image = gpui::Image::from_bytes(ImageFormat::Png, original_png);
+ let image = Arc::new(gpui::Image::from_bytes(ImageFormat::Png, raw_png));
let lm_image = cx
- .update(|cx| LanguageModelImage::from_image(Arc::new(image), cx))
+ .update(|cx| LanguageModelImage::from_image(Arc::clone(&image), cx))
.await
- .expect("image conversion should succeed");
+ .expect("from_image should succeed");
- let encoded_png = base64_to_png_bytes(lm_image.source.as_ref());
+ let decoded_png = base64_to_png_bytes(lm_image.source.as_ref());
assert!(
- encoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES,
- "expected encoded PNG <= DEFAULT_IMAGE_MAX_BYTES, got {} bytes",
- encoded_png.len()
+ decoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES,
+ "Encoded PNG should be ≤ {} bytes after downscale, but was {} bytes",
+ DEFAULT_IMAGE_MAX_BYTES,
+ decoded_png.len()
);
- // Ensure we actually downscaled in pixels (not just re-encoded).
- let (w, h) = png_dimensions(&encoded_png);
+ let (w, h) = png_dimensions(&decoded_png);
assert!(
- w < 4096 || h < 4096,
- "expected image to be downscaled in at least one dimension; got {w}x{h}"
- );
- }
-
- #[test]
- fn test_language_model_tool_result_content_deserialization() {
- let json = r#""This is plain text""#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- assert_eq!(
- result,
- LanguageModelToolResultContent::Text("This is plain text".into())
- );
-
- let json = r#"{"type": "text", "text": "This is wrapped text"}"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- assert_eq!(
- result,
- LanguageModelToolResultContent::Text("This is wrapped text".into())
- );
-
- let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- assert_eq!(
- result,
- LanguageModelToolResultContent::Text("Case insensitive".into())
- );
-
- let json = r#"{"Text": "Wrapped variant"}"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- assert_eq!(
- result,
- LanguageModelToolResultContent::Text("Wrapped variant".into())
- );
-
- let json = r#"{"text": "Lowercase wrapped"}"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- assert_eq!(
- result,
- LanguageModelToolResultContent::Text("Lowercase wrapped".into())
+ w < 4096 && h < 4096,
+ "Dimensions should have shrunk: got {}×{}",
+ w,
+ h
);
-
- // Test image deserialization
- let json = r#"{
- "source": "base64encodedimagedata",
- "size": {
- "width": 100,
- "height": 200
- }
- }"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- match result {
- LanguageModelToolResultContent::Image(image) => {
- assert_eq!(image.source.as_ref(), "base64encodedimagedata");
- let size = image.size.expect("size");
- assert_eq!(size.width.0, 100);
- assert_eq!(size.height.0, 200);
- }
- _ => panic!("Expected Image variant"),
- }
-
- // Test wrapped Image variant
- let json = r#"{
- "Image": {
- "source": "wrappedimagedata",
- "size": {
- "width": 50,
- "height": 75
- }
- }
- }"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- match result {
- LanguageModelToolResultContent::Image(image) => {
- assert_eq!(image.source.as_ref(), "wrappedimagedata");
- let size = image.size.expect("size");
- assert_eq!(size.width.0, 50);
- assert_eq!(size.height.0, 75);
- }
- _ => panic!("Expected Image variant"),
- }
-
- // Test wrapped Image variant with case insensitive
- let json = r#"{
- "image": {
- "Source": "caseinsensitive",
- "SIZE": {
- "width": 30,
- "height": 40
- }
- }
- }"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- match result {
- LanguageModelToolResultContent::Image(image) => {
- assert_eq!(image.source.as_ref(), "caseinsensitive");
- let size = image.size.expect("size");
- assert_eq!(size.width.0, 30);
- assert_eq!(size.height.0, 40);
- }
- _ => panic!("Expected Image variant"),
- }
-
- // Test that wrapped text with wrong type fails
- let json = r#"{"type": "blahblah", "text": "This should fail"}"#;
- let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
- assert!(result.is_err());
-
- // Test that malformed JSON fails
- let json = r#"{"invalid": "structure"}"#;
- let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
- assert!(result.is_err());
-
- // Test edge cases
- let json = r#""""#; // Empty string
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- assert_eq!(result, LanguageModelToolResultContent::Text("".into()));
-
- // Test with extra fields in wrapped text (should be ignored)
- let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into()));
-
- // Test direct image with case-insensitive fields
- let json = r#"{
- "SOURCE": "directimage",
- "Size": {
- "width": 200,
- "height": 300
- }
- }"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- match result {
- LanguageModelToolResultContent::Image(image) => {
- assert_eq!(image.source.as_ref(), "directimage");
- let size = image.size.expect("size");
- assert_eq!(size.width.0, 200);
- assert_eq!(size.height.0, 300);
- }
- _ => panic!("Expected Image variant"),
- }
-
- // Test that multiple fields prevent wrapped variant interpretation
- let json = r#"{"Text": "not wrapped", "extra": "field"}"#;
- let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
- assert!(result.is_err());
-
- // Test wrapped text with uppercase TEXT variant
- let json = r#"{"TEXT": "Uppercase variant"}"#;
- let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
- assert_eq!(
- result,
- LanguageModelToolResultContent::Text("Uppercase variant".into())
- );
-
- // Test that numbers and other JSON values fail gracefully
- let json = r#"123"#;
- let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
- assert!(result.is_err());
-
- let json = r#"null"#;
- let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
- assert!(result.is_err());
-
- let json = r#"[1, 2, 3]"#;
- let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
- assert!(result.is_err());
}
}
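The two conversion helpers introduced in this file are exact inverses (`DevicePixels` is a newtype over `i32`, matching `ImageSize`'s fields), so UI code can hop between the two representations losslessly:

```rust
use gpui::{DevicePixels, Size};
use language_model::{gpui_size_to_image_size, image_size_to_gpui};

fn round_trip(size: Size<DevicePixels>) -> Size<DevicePixels> {
    // i32 -> i32 in both directions; nothing is truncated.
    image_size_to_gpui(gpui_size_to_image_size(size))
}
```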
@@ -0,0 +1,27 @@
+[package]
+name = "language_model_core"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/language_model_core.rs"
+doctest = false
+
+[dependencies]
+anyhow.workspace = true
+cloud_llm_client.workspace = true
+futures.workspace = true
+gpui_shared_string.workspace = true
+http_client.workspace = true
+partial-json-fixer.workspace = true
+schemars.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+smol.workspace = true
+strum.workspace = true
+thiserror.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
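With the manifest and license in place, the crate root below is mostly code lifted verbatim from `language_model`'s lib.rs. The practical payoff is that a consumer needing only request/response types and error classification can depend on this crate alone, e.g.:

```rust
use http_client::StatusCode;
use language_model_core::{LanguageModelCompletionError, LanguageModelProviderName};
use std::time::Duration;

fn rate_limited() -> LanguageModelCompletionError {
    // 429 maps to RateLimitExceeded and carries the retry hint through.
    LanguageModelCompletionError::from_http_status(
        LanguageModelProviderName::new("anthropic"),
        StatusCode::TOO_MANY_REQUESTS,
        "slow down".to_string(),
        Some(Duration::from_secs(30)),
    )
}
```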
@@ -0,0 +1,658 @@
+mod provider;
+mod rate_limiter;
+mod request;
+mod role;
+pub mod tool_schema;
+pub mod util;
+
+use anyhow::{Result, anyhow};
+use cloud_llm_client::CompletionRequestStatus;
+use http_client::{StatusCode, http};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use std::ops::{Add, Sub};
+use std::str::FromStr;
+use std::sync::Arc;
+use std::time::Duration;
+use std::{fmt, io};
+use thiserror::Error;
+pub use crate::provider::*;
+pub use crate::rate_limiter::*;
+pub use crate::request::*;
+pub use crate::role::*;
+pub use crate::tool_schema::LanguageModelToolSchemaFormat;
+pub use crate::util::{fix_streamed_json, parse_prompt_too_long, parse_tool_arguments};
+pub use gpui_shared_string::SharedString;
+
+fn is_default<T: Default + PartialEq>(value: &T) -> bool {
+ *value == T::default()
+}
+
+#[derive(Clone, Debug)]
+pub struct LanguageModelCacheConfiguration {
+ pub max_cache_anchors: usize,
+ pub should_speculate: bool,
+ pub min_total_token: u64,
+}
+
+/// A completion event from a language model.
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub enum LanguageModelCompletionEvent {
+ Queued {
+ position: usize,
+ },
+ Started,
+ Stop(StopReason),
+ Text(String),
+ Thinking {
+ text: String,
+ signature: Option<String>,
+ },
+ RedactedThinking {
+ data: String,
+ },
+ ToolUse(LanguageModelToolUse),
+ ToolUseJsonParseError {
+ id: LanguageModelToolUseId,
+ tool_name: Arc<str>,
+ raw_input: Arc<str>,
+ json_parse_error: String,
+ },
+ StartMessage {
+ message_id: String,
+ },
+ ReasoningDetails(serde_json::Value),
+ UsageUpdate(TokenUsage),
+}
+
+impl LanguageModelCompletionEvent {
+ pub fn from_completion_request_status(
+ status: CompletionRequestStatus,
+ upstream_provider: LanguageModelProviderName,
+ ) -> Result<Option<Self>, LanguageModelCompletionError> {
+ match status {
+ CompletionRequestStatus::Queued { position } => {
+ Ok(Some(LanguageModelCompletionEvent::Queued { position }))
+ }
+ CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
+ CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
+ CompletionRequestStatus::Failed {
+ code,
+ message,
+ request_id: _,
+ retry_after,
+ } => Err(LanguageModelCompletionError::from_cloud_failure(
+ upstream_provider,
+ code,
+ message,
+ retry_after.map(Duration::from_secs_f64),
+ )),
+ }
+ }
+}
+
+#[derive(Error, Debug)]
+pub enum LanguageModelCompletionError {
+ #[error("prompt too large for context window")]
+ PromptTooLarge { tokens: Option<u64> },
+ #[error("missing {provider} API key")]
+ NoApiKey { provider: LanguageModelProviderName },
+ #[error("{provider}'s API rate limit exceeded")]
+ RateLimitExceeded {
+ provider: LanguageModelProviderName,
+ retry_after: Option<Duration>,
+ },
+ #[error("{provider}'s API servers are overloaded right now")]
+ ServerOverloaded {
+ provider: LanguageModelProviderName,
+ retry_after: Option<Duration>,
+ },
+ #[error("{provider}'s API server reported an internal server error: {message}")]
+ ApiInternalServerError {
+ provider: LanguageModelProviderName,
+ message: String,
+ },
+ #[error("{message}")]
+ UpstreamProviderError {
+ message: String,
+ status: StatusCode,
+ retry_after: Option<Duration>,
+ },
+ #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
+ HttpResponseError {
+ provider: LanguageModelProviderName,
+ status_code: StatusCode,
+ message: String,
+ },
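+
+ // Client errors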
+ #[error("invalid request format to {provider}'s API: {message}")]
+ BadRequestFormat {
+ provider: LanguageModelProviderName,
+ message: String,
+ },
+ #[error("authentication error with {provider}'s API: {message}")]
+ AuthenticationError {
+ provider: LanguageModelProviderName,
+ message: String,
+ },
+ #[error("Permission error with {provider}'s API: {message}")]
+ PermissionError {
+ provider: LanguageModelProviderName,
+ message: String,
+ },
+ #[error("language model provider API endpoint not found")]
+ ApiEndpointNotFound { provider: LanguageModelProviderName },
+ #[error("I/O error reading response from {provider}'s API")]
+ ApiReadResponseError {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: io::Error,
+ },
+ #[error("error serializing request to {provider} API")]
+ SerializeRequest {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: serde_json::Error,
+ },
+ #[error("error building request body to {provider} API")]
+ BuildRequestBody {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: http::Error,
+ },
+ #[error("error sending HTTP request to {provider} API")]
+ HttpSend {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: anyhow::Error,
+ },
+ #[error("error deserializing {provider} API response")]
+ DeserializeResponse {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: serde_json::Error,
+ },
+ #[error("stream from {provider} ended unexpectedly")]
+ StreamEndedUnexpectedly { provider: LanguageModelProviderName },
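+
+ // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.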
+ #[error(transparent)]
+ Other(#[from] anyhow::Error),
+}
+
+impl LanguageModelCompletionError {
+ fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
+ let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
+ let upstream_status = error_json
+ .get("upstream_status")
+ .and_then(|v| v.as_u64())
+ .and_then(|status| u16::try_from(status).ok())
+ .and_then(|status| StatusCode::from_u16(status).ok())?;
+ let inner_message = error_json
+ .get("message")
+ .and_then(|v| v.as_str())
+ .unwrap_or(message)
+ .to_string();
+ Some((upstream_status, inner_message))
+ }
+
+ pub fn from_cloud_failure(
+ upstream_provider: LanguageModelProviderName,
+ code: String,
+ message: String,
+ retry_after: Option<Duration>,
+ ) -> Self {
+ if let Some(tokens) = parse_prompt_too_long(&message) {
+ Self::PromptTooLarge {
+ tokens: Some(tokens),
+ }
+ } else if code == "upstream_http_error" {
+ if let Some((upstream_status, inner_message)) =
+ Self::parse_upstream_error_json(&message)
+ {
+ return Self::from_http_status(
+ upstream_provider,
+ upstream_status,
+ inner_message,
+ retry_after,
+ );
+ }
+ anyhow!("completion request failed, code: {code}, message: {message}").into()
+ } else if let Some(status_code) = code
+ .strip_prefix("upstream_http_")
+ .and_then(|code| StatusCode::from_str(code).ok())
+ {
+ Self::from_http_status(upstream_provider, status_code, message, retry_after)
+ } else if let Some(status_code) = code
+ .strip_prefix("http_")
+ .and_then(|code| StatusCode::from_str(code).ok())
+ {
+ Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
+ } else {
+ anyhow!("completion request failed, code: {code}, message: {message}").into()
+ }
+ }
+
+ pub fn from_http_status(
+ provider: LanguageModelProviderName,
+ status_code: StatusCode,
+ message: String,
+ retry_after: Option<Duration>,
+ ) -> Self {
+ match status_code {
+ StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
+ StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
+ StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
+ StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
+ StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
+ tokens: parse_prompt_too_long(&message),
+ },
+ StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
+ provider,
+ retry_after,
+ },
+ StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
+ StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
+ provider,
+ retry_after,
+ },
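+ // 529 is a non-standard "overloaded" status (notably used by Anthropic),
+ // so it gets the same treatment as SERVICE_UNAVAILABLE.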
+ _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
+ provider,
+ retry_after,
+ },
+ _ => Self::HttpResponseError {
+ provider,
+ status_code,
+ message,
+ },
+ }
+ }
+}
+
+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum StopReason {
+ EndTurn,
+ MaxTokens,
+ ToolUse,
+ Refusal,
+}
+
+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
+pub struct TokenUsage {
+ #[serde(default, skip_serializing_if = "is_default")]
+ pub input_tokens: u64,
+ #[serde(default, skip_serializing_if = "is_default")]
+ pub output_tokens: u64,
+ #[serde(default, skip_serializing_if = "is_default")]
+ pub cache_creation_input_tokens: u64,
+ #[serde(default, skip_serializing_if = "is_default")]
+ pub cache_read_input_tokens: u64,
+}
+
+impl TokenUsage {
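+ /// Total of the input, output, and both cache counters;
+ /// e.g. 10 input + 5 output + 2 cache-read + 0 cache-creation = 17.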
+ pub fn total_tokens(&self) -> u64 {
+ self.input_tokens
+ + self.output_tokens
+ + self.cache_read_input_tokens
+ + self.cache_creation_input_tokens
+ }
+}
+
+impl Add<TokenUsage> for TokenUsage {
+ type Output = Self;
+
+ fn add(self, other: Self) -> Self {
+ Self {
+ input_tokens: self.input_tokens + other.input_tokens,
+ output_tokens: self.output_tokens + other.output_tokens,
+ cache_creation_input_tokens: self.cache_creation_input_tokens
+ + other.cache_creation_input_tokens,
+ cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
+ }
+ }
+}
+
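+// Field-wise difference; `other` must be component-wise smaller, since the
+// unsigned subtraction underflows otherwise (panicking in debug builds).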
+impl Sub<TokenUsage> for TokenUsage {
+ type Output = Self;
+
+ fn sub(self, other: Self) -> Self {
+ Self {
+ input_tokens: self.input_tokens - other.input_tokens,
+ output_tokens: self.output_tokens - other.output_tokens,
+ cache_creation_input_tokens: self.cache_creation_input_tokens
+ - other.cache_creation_input_tokens,
+ cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
+ }
+ }
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
+pub struct LanguageModelToolUseId(Arc<str>);
+
+impl fmt::Display for LanguageModelToolUseId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+impl<T> From<T> for LanguageModelToolUseId
+where
+ T: Into<Arc<str>>,
+{
+ fn from(value: T) -> Self {
+ Self(value.into())
+ }
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
+pub struct LanguageModelToolUse {
+ pub id: LanguageModelToolUseId,
+ pub name: Arc<str>,
+ pub raw_input: String,
+ pub input: serde_json::Value,
+ pub is_input_complete: bool,
+ /// Thought signature the model sent us. Some models require that this
+ /// signature be preserved and sent back in conversation history for validation.
+ pub thought_signature: Option<String>,
+}
+
+#[derive(Debug, Clone)]
+pub struct LanguageModelEffortLevel {
+ pub name: SharedString,
+ pub value: SharedString,
+ pub is_default: bool,
+}
+
+/// An error that occurred when trying to authenticate the language model provider.
+#[derive(Debug, Error)]
+pub enum AuthenticateError {
+ #[error("connection refused")]
+ ConnectionRefused,
+ #[error("credentials not found")]
+ CredentialsNotFound,
+ #[error(transparent)]
+ Other(#[from] anyhow::Error),
+}
+
+#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
+pub struct LanguageModelId(pub SharedString);
+
+#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
+pub struct LanguageModelName(pub SharedString);
+
+#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
+pub struct LanguageModelProviderId(pub SharedString);
+
+#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
+pub struct LanguageModelProviderName(pub SharedString);
+
+impl LanguageModelProviderId {
+ pub const fn new(id: &'static str) -> Self {
+ Self(SharedString::new_static(id))
+ }
+}
+
+impl LanguageModelProviderName {
+ pub const fn new(id: &'static str) -> Self {
+ Self(SharedString::new_static(id))
+ }
+}
+
+impl fmt::Display for LanguageModelProviderId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+impl fmt::Display for LanguageModelProviderName {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+impl From<String> for LanguageModelId {
+ fn from(value: String) -> Self {
+ Self(SharedString::from(value))
+ }
+}
+
+impl From<String> for LanguageModelName {
+ fn from(value: String) -> Self {
+ Self(SharedString::from(value))
+ }
+}
+
+impl From<String> for LanguageModelProviderId {
+ fn from(value: String) -> Self {
+ Self(SharedString::from(value))
+ }
+}
+
+impl From<String> for LanguageModelProviderName {
+ fn from(value: String) -> Self {
+ Self(SharedString::from(value))
+ }
+}
+
+impl From<Arc<str>> for LanguageModelProviderId {
+ fn from(value: Arc<str>) -> Self {
+ Self(SharedString::from(value))
+ }
+}
+
+impl From<Arc<str>> for LanguageModelProviderName {
+ fn from(value: Arc<str>) -> Self {
+ Self(SharedString::from(value))
+ }
+}
+
+/// Settings-layer–free model mode enum.
+///
+/// Mirrors the shape of `settings_content::ModelMode` but lives here so that
+/// crates below the settings layer can reference it.
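+///
+/// With the internal tag, the JSON forms are `{"type": "default"}` and
+/// `{"type": "thinking", "budget_tokens": 4096}`.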
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+ #[default]
+ Default,
+ Thinking {
+ budget_tokens: Option<u32>,
+ },
+}
+
+/// Settings-layer–free reasoning-effort enum.
+///
+/// Mirrors the shape of `settings_content::OpenAiReasoningEffort` but lives
+/// here so that crates below the settings layer can reference it.
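+///
+/// Serializes as a bare lowercase string (e.g. `"minimal"`, `"xhigh"`) and
+/// parses from the same strings via `strum::EnumString`.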
+#[derive(
+ Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, strum::EnumString,
+)]
+#[serde(rename_all = "lowercase")]
+#[strum(serialize_all = "lowercase")]
+pub enum ReasoningEffort {
+ Minimal,
+ Low,
+ Medium,
+ High,
+ XHigh,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_from_cloud_failure_with_upstream_http_error() {
+ let error = LanguageModelCompletionError::from_cloud_failure(
+ String::from("anthropic").into(),
+ "upstream_http_error".to_string(),
+ r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
+ None,
+ );
+
+ match error {
+ LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
+ assert_eq!(provider.0, "anthropic");
+ }
+ _ => panic!(
+ "Expected ServerOverloaded error for 503 status, got: {:?}",
+ error
+ ),
+ }
+
+ let error = LanguageModelCompletionError::from_cloud_failure(
+ String::from("anthropic").into(),
+ "upstream_http_error".to_string(),
+ r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
+ None,
+ );
+
+ match error {
+ LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
+ assert_eq!(provider.0, "anthropic");
+ assert_eq!(message, "Internal server error");
+ }
+ _ => panic!(
+ "Expected ApiInternalServerError for 500 status, got: {:?}",
+ error
+ ),
+ }
+ }
+
+ #[test]
+ fn test_from_cloud_failure_with_standard_format() {
+ let error = LanguageModelCompletionError::from_cloud_failure(
+ String::from("anthropic").into(),
+ "upstream_http_503".to_string(),
+ "Service unavailable".to_string(),
+ None,
+ );
+
+ match error {
+ LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
+ assert_eq!(provider.0, "anthropic");
+ }
+ _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
+ }
+ }
+
+ #[test]
+ fn test_upstream_http_error_connection_timeout() {
+ let error = LanguageModelCompletionError::from_cloud_failure(
+ String::from("anthropic").into(),
+ "upstream_http_error".to_string(),
+ r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
+ None,
+ );
+
+ match error {
+ LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
+ assert_eq!(provider.0, "anthropic");
+ }
+ _ => panic!(
+ "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
+ error
+ ),
+ }
+
+ let error = LanguageModelCompletionError::from_cloud_failure(
+ String::from("anthropic").into(),
+ "upstream_http_error".to_string(),
+ r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
+ None,
+ );
+
+ match error {
+ LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
+ assert_eq!(provider.0, "anthropic");
+ assert_eq!(
+ message,
+ "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
+ );
+ }
+ _ => panic!(
+ "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
+ error
+ ),
+ }
+ }
+
+ #[test]
+ fn test_language_model_tool_use_serializes_with_signature() {
+ use serde_json::json;
+
+ let tool_use = LanguageModelToolUse {
+ id: LanguageModelToolUseId::from("test_id"),
+ name: "test_tool".into(),
+ raw_input: json!({"arg": "value"}).to_string(),
+ input: json!({"arg": "value"}),
+ is_input_complete: true,
+ thought_signature: Some("test_signature".to_string()),
+ };
+
+ let serialized = serde_json::to_value(&tool_use).unwrap();
+
+ assert_eq!(serialized["id"], "test_id");
+ assert_eq!(serialized["name"], "test_tool");
+ assert_eq!(serialized["thought_signature"], "test_signature");
+ }
+
+ #[test]
+ fn test_language_model_tool_use_deserializes_with_missing_signature() {
+ use serde_json::json;
+
+ let json = json!({
+ "id": "test_id",
+ "name": "test_tool",
+ "raw_input": "{\"arg\":\"value\"}",
+ "input": {"arg": "value"},
+ "is_input_complete": true
+ });
+
+ let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
+
+ assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
+ assert_eq!(tool_use.name.as_ref(), "test_tool");
+ assert_eq!(tool_use.thought_signature, None);
+ }
+
+ #[test]
+ fn test_language_model_tool_use_round_trip_with_signature() {
+ use serde_json::json;
+
+ let original = LanguageModelToolUse {
+ id: LanguageModelToolUseId::from("round_trip_id"),
+ name: "round_trip_tool".into(),
+ raw_input: json!({"key": "value"}).to_string(),
+ input: json!({"key": "value"}),
+ is_input_complete: true,
+ thought_signature: Some("round_trip_sig".to_string()),
+ };
+
+ let serialized = serde_json::to_value(&original).unwrap();
+ let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
+
+ assert_eq!(deserialized.id, original.id);
+ assert_eq!(deserialized.name, original.name);
+ assert_eq!(deserialized.thought_signature, original.thought_signature);
+ }
+
+ #[test]
+ fn test_language_model_tool_use_round_trip_without_signature() {
+ use serde_json::json;
+
+ let original = LanguageModelToolUse {
+ id: LanguageModelToolUseId::from("no_sig_id"),
+ name: "no_sig_tool".into(),
+ raw_input: json!({"arg": "value"}).to_string(),
+ input: json!({"arg": "value"}),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+
+ let serialized = serde_json::to_value(&original).unwrap();
+ let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
+
+ assert_eq!(deserialized.id, original.id);
+ assert_eq!(deserialized.name, original.name);
+ assert_eq!(deserialized.thought_signature, None);
+ }
+}
@@ -0,0 +1,21 @@
+use crate::{LanguageModelProviderId, LanguageModelProviderName};
+
+pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
+ LanguageModelProviderId::new("anthropic");
+pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("Anthropic");
+
+pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
+pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("OpenAI");
+
+pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
+pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("Google AI");
+
+pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
+pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
+
+pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
+pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("Zed");
@@ -0,0 +1,463 @@
+use std::sync::Arc;
+
+use serde::{Deserialize, Serialize};
+
+use crate::role::Role;
+use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString};
+
+/// Dimensions of a `LanguageModelImage`
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct ImageSize {
+ pub width: i32,
+ pub height: i32,
+}
+
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
+pub struct LanguageModelImage {
+ /// A base64-encoded PNG image.
+ pub source: SharedString,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub size: Option<ImageSize>,
+}
+
+impl LanguageModelImage {
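+ /// Byte length of the base64 payload (not the decoded image size).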
+ pub fn len(&self) -> usize {
+ self.source.len()
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.source.is_empty()
+ }
+
+ pub fn empty() -> Self {
+ Self {
+ source: "".into(),
+ size: None,
+ }
+ }
+
+ /// Parse Self from a JSON object with case-insensitive field names
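+ /// (e.g. `{"Source": "...", "Size": {"Width": 100, "Height": 200}}` parses
+ /// the same as the lowercase form). Both `source` and `size` must be
+ /// present, otherwise this returns `None`.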
+ pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
+ let mut source = None;
+ let mut size_obj = None;
+
+ for (k, v) in obj.iter() {
+ match k.to_lowercase().as_str() {
+ "source" => source = v.as_str(),
+ "size" => size_obj = v.as_object(),
+ _ => {}
+ }
+ }
+
+ let source = source?;
+ let size_obj = size_obj?;
+
+ let mut width = None;
+ let mut height = None;
+
+ for (k, v) in size_obj.iter() {
+ match k.to_lowercase().as_str() {
+ "width" => width = v.as_i64().map(|w| w as i32),
+ "height" => height = v.as_i64().map(|h| h as i32),
+ _ => {}
+ }
+ }
+
+ Some(Self {
+ size: Some(ImageSize {
+ width: width?,
+ height: height?,
+ }),
+ source: SharedString::from(source.to_string()),
+ })
+ }
+
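+ /// Rough estimate using the Anthropic formula cited below; for example, a
+ /// 1500x1000 image comes out to (1500 * 1000) / 750 = 2000 tokens.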
+ pub fn estimate_tokens(&self) -> usize {
+ let Some(size) = self.size.as_ref() else {
+ return 0;
+ };
+ let width = size.width.unsigned_abs() as usize;
+ let height = size.height.unsigned_abs() as usize;
+
+ // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
+ (width * height) / 750
+ }
+
+ pub fn to_base64_url(&self) -> String {
+ format!("data:image/png;base64,{}", self.source)
+ }
+}
+
+impl std::fmt::Debug for LanguageModelImage {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("LanguageModelImage")
+ .field("source", &format!("<{} bytes>", self.source.len()))
+ .field("size", &self.size)
+ .finish()
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
+pub struct LanguageModelToolResult {
+ pub tool_use_id: LanguageModelToolUseId,
+ pub tool_name: Arc<str>,
+ pub is_error: bool,
+ /// The tool output formatted for presenting to the model
+ pub content: LanguageModelToolResultContent,
+ /// The raw tool output, if available, often for debugging or extra state for replay
+ pub output: Option<serde_json::Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
+pub enum LanguageModelToolResultContent {
+ Text(Arc<str>),
+ Image(LanguageModelImage),
+}
+
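+// Accepted wire formats, matched in order by the deserializer below: a bare
+// string, `{"type": "text", "text": ...}`, a single-key `{"text": ...}` or
+// `{"image": {...}}` wrapper, and finally a direct image object.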
+impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ use serde::de::Error;
+
+ let value = serde_json::Value::deserialize(deserializer)?;
+
+ // 1. Try as plain string
+ if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
+ return Ok(Self::Text(Arc::from(text)));
+ }
+
+ // 2. Try as object
+ if let Some(obj) = value.as_object() {
+ fn get_field<'a>(
+ obj: &'a serde_json::Map<String, serde_json::Value>,
+ field: &str,
+ ) -> Option<&'a serde_json::Value> {
+ obj.iter()
+ .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
+ .map(|(_, v)| v)
+ }
+
+ // Accept wrapped text format: { "type": "text", "text": "..." }
+ if let (Some(type_value), Some(text_value)) =
+ (get_field(obj, "type"), get_field(obj, "text"))
+ && let Some(type_str) = type_value.as_str()
+ && type_str.to_lowercase() == "text"
+ && let Some(text) = text_value.as_str()
+ {
+ return Ok(Self::Text(Arc::from(text)));
+ }
+
+ // Check for wrapped Text variant: { "text": "..." }
+ if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
+ && obj.len() == 1
+ {
+ if let Some(text) = value.as_str() {
+ return Ok(Self::Text(Arc::from(text)));
+ }
+ }
+
+ // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
+ if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
+ && obj.len() == 1
+ {
+ if let Some(image_obj) = value.as_object()
+ && let Some(image) = LanguageModelImage::from_json(image_obj)
+ {
+ return Ok(Self::Image(image));
+ }
+ }
+
+ // Try as direct Image
+ if let Some(image) = LanguageModelImage::from_json(obj) {
+ return Ok(Self::Image(image));
+ }
+ }
+
+ Err(D::Error::custom(format!(
+ "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
+ an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
+ serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
+ )))
+ }
+}
+
+impl LanguageModelToolResultContent {
+ pub fn to_str(&self) -> Option<&str> {
+ match self {
+ Self::Text(text) => Some(text),
+ Self::Image(_) => None,
+ }
+ }
+
+ pub fn is_empty(&self) -> bool {
+ match self {
+ Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
+ Self::Image(_) => false,
+ }
+ }
+}
+
+impl From<&str> for LanguageModelToolResultContent {
+ fn from(value: &str) -> Self {
+ Self::Text(Arc::from(value))
+ }
+}
+
+impl From<String> for LanguageModelToolResultContent {
+ fn from(value: String) -> Self {
+ Self::Text(Arc::from(value))
+ }
+}
+
+impl From<LanguageModelImage> for LanguageModelToolResultContent {
+ fn from(image: LanguageModelImage) -> Self {
+ Self::Image(image)
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
+pub enum MessageContent {
+ Text(String),
+ Thinking {
+ text: String,
+ signature: Option<String>,
+ },
+ RedactedThinking(String),
+ Image(LanguageModelImage),
+ ToolUse(LanguageModelToolUse),
+ ToolResult(LanguageModelToolResult),
+}
+
+impl MessageContent {
+ pub fn to_str(&self) -> Option<&str> {
+ match self {
+ MessageContent::Text(text) => Some(text.as_str()),
+ MessageContent::Thinking { text, .. } => Some(text.as_str()),
+ MessageContent::RedactedThinking(_) => None,
+ MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
+ MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
+ }
+ }
+
+ pub fn is_empty(&self) -> bool {
+ match self {
+ MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
+ MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
+ MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
+ MessageContent::RedactedThinking(_)
+ | MessageContent::ToolUse(_)
+ | MessageContent::Image(_) => false,
+ }
+ }
+}
+
+impl From<String> for MessageContent {
+ fn from(value: String) -> Self {
+ MessageContent::Text(value)
+ }
+}
+
+impl From<&str> for MessageContent {
+ fn from(value: &str) -> Self {
+ MessageContent::Text(value.to_string())
+ }
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
+pub struct LanguageModelRequestMessage {
+ pub role: Role,
+ pub content: Vec<MessageContent>,
+ pub cache: bool,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub reasoning_details: Option<serde_json::Value>,
+}
+
+impl LanguageModelRequestMessage {
+ pub fn string_contents(&self) -> String {
+ let mut buffer = String::new();
+ for string in self.content.iter().filter_map(|content| content.to_str()) {
+ buffer.push_str(string);
+ }
+ buffer
+ }
+
+ pub fn contents_empty(&self) -> bool {
+ self.content.iter().all(|content| content.is_empty())
+ }
+}
+
+#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
+pub struct LanguageModelRequestTool {
+ pub name: String,
+ pub description: String,
+ pub input_schema: serde_json::Value,
+ pub use_input_streaming: bool,
+}
+
+#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
+pub enum LanguageModelToolChoice {
+ Auto,
+ Any,
+ None,
+}
+
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum CompletionIntent {
+ UserPrompt,
+ Subagent,
+ ToolResults,
+ ThreadSummarization,
+ ThreadContextSummarization,
+ CreateFile,
+ EditFile,
+ InlineAssist,
+ TerminalInlineAssist,
+ GenerateGitCommitMessage,
+}
+
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct LanguageModelRequest {
+ pub thread_id: Option<String>,
+ pub prompt_id: Option<String>,
+ pub intent: Option<CompletionIntent>,
+ pub messages: Vec<LanguageModelRequestMessage>,
+ pub tools: Vec<LanguageModelRequestTool>,
+ pub tool_choice: Option<LanguageModelToolChoice>,
+ pub stop: Vec<String>,
+ pub temperature: Option<f32>,
+ pub thinking_allowed: bool,
+ pub thinking_effort: Option<String>,
+ pub speed: Option<Speed>,
+}
+
+#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum Speed {
+ #[default]
+ Standard,
+ Fast,
+}
+
+impl Speed {
+ pub fn toggle(self) -> Self {
+ match self {
+ Speed::Standard => Speed::Fast,
+ Speed::Fast => Speed::Standard,
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelResponseMessage {
+ pub role: Option<Role>,
+ pub content: Option<String>,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_language_model_tool_result_content_deserialization() {
+ // Test plain string
+ let json = serde_json::json!("hello world");
+ let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
+ assert_eq!(
+ content,
+ LanguageModelToolResultContent::Text(Arc::from("hello world"))
+ );
+
+ // Test wrapped text format: { "type": "text", "text": "..." }
+ let json = serde_json::json!({"type": "text", "text": "hello"});
+ let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
+ assert_eq!(
+ content,
+ LanguageModelToolResultContent::Text(Arc::from("hello"))
+ );
+
+ // Test single-field text object: { "text": "..." }
+ let json = serde_json::json!({"text": "hello"});
+ let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
+ assert_eq!(
+ content,
+ LanguageModelToolResultContent::Text(Arc::from("hello"))
+ );
+
+ // Test case-insensitive type field
+ let json = serde_json::json!({"Type": "Text", "Text": "hello"});
+ let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
+ assert_eq!(
+ content,
+ LanguageModelToolResultContent::Text(Arc::from("hello"))
+ );
+
+ // Test image object
+ let json = serde_json::json!({
+ "source": "base64encodedimagedata",
+ "size": {"width": 100, "height": 200}
+ });
+ let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
+ match content {
+ LanguageModelToolResultContent::Image(image) => {
+ assert_eq!(image.source.as_ref(), "base64encodedimagedata");
+ let size = image.size.expect("size");
+ assert_eq!(size.width, 100);
+ assert_eq!(size.height, 200);
+ }
+ _ => panic!("Expected Image variant"),
+ }
+
+ // Test wrapped image: { "image": { "source": "...", "size": ... } }
+ let json = serde_json::json!({
+ "image": {
+ "source": "wrappedimagedata",
+ "size": {"width": 50, "height": 75}
+ }
+ });
+ let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
+ match content {
+ LanguageModelToolResultContent::Image(image) => {
+ assert_eq!(image.source.as_ref(), "wrappedimagedata");
+ let size = image.size.expect("size");
+ assert_eq!(size.width, 50);
+ assert_eq!(size.height, 75);
+ }
+ _ => panic!("Expected Image variant"),
+ }
+
+ // Test case insensitive
+ let json = serde_json::json!({
+ "Source": "caseinsensitive",
+ "Size": {"Width": 30, "Height": 40}
+ });
+ let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
+ match content {
+ LanguageModelToolResultContent::Image(image) => {
+ assert_eq!(image.source.as_ref(), "caseinsensitive");
+ let size = image.size.expect("size");
+ assert_eq!(size.width, 30);
+ assert_eq!(size.height, 40);
+ }
+ _ => panic!("Expected Image variant"),
+ }
+
+ // Test direct image object
+ let json = serde_json::json!({
+ "source": "directimage",
+ "size": {"width": 200, "height": 300}
+ });
+ let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap();
+ match content {
+ LanguageModelToolResultContent::Image(image) => {
+ assert_eq!(image.source.as_ref(), "directimage");
+ let size = image.size.expect("size");
+ assert_eq!(size.width, 200);
+ assert_eq!(size.height, 300);
+ }
+ _ => panic!("Expected Image variant"),
+ }
+ }
+}
@@ -77,8 +77,6 @@ pub fn adapt_schema_to_format(
}
fn preprocess_json_schema(json: &mut Value) -> Result<()> {
// `additionalProperties` defaults to `false` unless explicitly specified.
// This prevents models from hallucinating tool parameters.
if let Value::Object(obj) = json
&& matches!(obj.get("type"), Some(Value::String(s)) if s == "object")
{
@@ -86,7 +84,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> {
obj.insert("additionalProperties".to_string(), Value::Bool(false));
}
// OpenAI API requires non-missing `properties`
if !obj.contains_key("properties") {
obj.insert("properties".to_string(), Value::Object(Default::default()));
}
@@ -94,7 +91,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> {
Ok(())
}
/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema
fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
if let Value::Object(obj) = json {
const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"];
@@ -108,9 +104,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [
("format", |value| value.is_string()),
// Gemini doesn't support `additionalProperties` in any form (boolean or schema object)
("additionalProperties", |_| true),
// Gemini doesn't support `propertyNames`
("propertyNames", |_| true),
("exclusiveMinimum", |value| value.is_number()),
("exclusiveMaximum", |value| value.is_number()),
@@ -124,7 +118,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
}
}
// If a type is not specified for an input parameter, add a default type
if matches!(obj.get("description"), Some(Value::String(_)))
&& !obj.contains_key("type")
&& !(obj.contains_key("anyOf")
@@ -134,7 +127,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
obj.insert("type".to_string(), Value::String("string".to_string()));
}
// Handle oneOf -> anyOf conversion
if let Some(subschemas) = obj.get_mut("oneOf")
&& subschemas.is_array()
{
@@ -143,7 +135,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
obj.insert("anyOf".to_string(), subschemas_clone);
}
// Recursively process all nested objects and arrays
for (_, value) in obj.iter_mut() {
if let Value::Object(_) | Value::Array(_) = value {
adapt_to_json_schema_subset(value)?;
@@ -178,7 +169,6 @@ mod tests {
})
);
// Ensure that we do not add a type if it is an object
let mut json = json!({
"description": {
"value": "abc",
@@ -221,7 +211,6 @@ mod tests {
})
);
// Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property)
let mut json = json!({
"description": "A test field",
"type": "integer",
@@ -239,7 +228,6 @@ mod tests {
})
);
// additionalProperties as an object schema is also unsupported by Gemini
let mut json = json!({
"type": "object",
"properties": {
@@ -38,13 +38,22 @@ fn strip_trailing_incomplete_escape(json: &str) -> &str {
}
}
+/// Parses a "prompt is too long: N tokens ..." message and extracts the token count.
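+///
+/// Illustrative, based on the message format above:
+///
+/// ```ignore
+/// assert_eq!(
+///     parse_prompt_too_long("prompt is too long: 210000 tokens > 200000 maximum"),
+///     Some(210000),
+/// );
+/// assert_eq!(parse_prompt_too_long("some other error"), None);
+/// ```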
+pub fn parse_prompt_too_long(message: &str) -> Option<u64> {
+ message
+ .strip_prefix("prompt is too long: ")?
+ .split_once(" tokens")?
+ .0
+ .parse()
+ .ok()
+}
+
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_fix_streamed_json_strips_incomplete_escape() {
// Trailing `\` inside a string: an incomplete escape sequence
let fixed = fix_streamed_json(r#"{"text": "hello\"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "hello");
@@ -52,7 +61,6 @@ mod tests {
#[test]
fn test_fix_streamed_json_preserves_complete_escape() {
// `\\` is a complete escape (literal backslash)
let fixed = fix_streamed_json(r#"{"text": "hello\\"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "hello\\");
@@ -60,7 +68,6 @@ mod tests {
#[test]
fn test_fix_streamed_json_strips_escape_after_complete_escape() {
// `\\\` = complete `\\` (literal backslash) + incomplete `\`
let fixed = fix_streamed_json(r#"{"text": "hello\\\"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "hello\\");
@@ -75,12 +82,10 @@ mod tests {
#[test]
fn test_fix_streamed_json_newline_escape_boundary() {
// Simulates a stream boundary landing between `\` and `n`
let fixed = fix_streamed_json(r#"{"text": "line1\"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "line1");
// Next chunk completes the escape
let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#);
let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
assert_eq!(parsed["text"], "line1\nline2");
@@ -88,8 +93,6 @@ mod tests {
#[test]
fn test_fix_streamed_json_incremental_delta_correctness() {
// This is the actual scenario that causes the bug:
// chunk 1 ends mid-escape, chunk 2 completes it.
let chunk1 = r#"{"replacement_text": "fn foo() {\"#;
let fixed1 = fix_streamed_json(chunk1);
let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json");
@@ -102,7 +105,6 @@ mod tests {
let text2 = parsed2["replacement_text"].as_str().expect("string");
assert_eq!(text2, "fn foo() {\n return bar;\n}");
// The delta should be the newline + rest, with no spurious backslash
let delta = &text2[text1.len()..];
assert_eq!(delta, "\n return bar;\n}");
}
@@ -21,8 +21,8 @@ aws_http_client.workspace = true
base64.workspace = true
bedrock = { workspace = true, features = ["schemars"] }
client.workspace = true
+cloud_api_client.workspace = true
cloud_api_types.workspace = true
-cloud_llm_client.workspace = true
collections.workspace = true
component.workspace = true
convert_case.workspace = true
@@ -41,6 +41,7 @@ gpui_tokio.workspace = true
http_client.workspace = true
language.workspace = true
language_model.workspace = true
+language_models_cloud.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
menu.workspace = true
@@ -49,16 +50,13 @@ ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
opencode = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
-partial-json-fixer.workspace = true
release_channel.workspace = true
schemars.workspace = true
-semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strum.workspace = true
-thiserror.workspace = true
tiktoken-rs.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
ui.workspace = true
@@ -70,4 +68,3 @@ x_ai = { workspace = true, features = ["schemars"] }
[dev-dependencies]
language_model = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
-
@@ -11,7 +11,7 @@ pub mod open_ai;
pub mod open_ai_compatible;
pub mod open_router;
pub mod opencode;
-mod util;
+
pub mod vercel;
pub mod vercel_ai_gateway;
pub mod x_ai;
@@ -1,13 +1,10 @@
pub mod telemetry;
-use anthropic::{
- ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, CountTokensRequest, Event,
- ResponseContent, ToolResultContent, ToolResultPart, Usage,
-};
+use anthropic::{ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode};
use anyhow::Result;
-use collections::{BTreeMap, HashMap};
+use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
-use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
+use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
use language_model::{
@@ -16,20 +13,19 @@ use language_model::{
LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
- RateLimiter, Role, StopReason, env_var,
+ LanguageModelToolChoice, RateLimiter, env_var,
};
use settings::{Settings, SettingsStore};
-use std::pin::Pin;
-use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
-
+pub use anthropic::completion::{
+ AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
+ into_anthropic_count_tokens_request,
+};
pub use settings::AnthropicAvailableModel as AvailableModel;
const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID;
@@ -249,228 +245,6 @@ pub struct AnthropicModel {
request_limiter: RateLimiter,
}
-fn to_anthropic_content(content: MessageContent) -> Option<anthropic::RequestContent> {
- match content {
- MessageContent::Text(text) => {
- let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
- text.trim_end().to_string()
- } else {
- text
- };
- if !text.is_empty() {
- Some(anthropic::RequestContent::Text {
- text,
- cache_control: None,
- })
- } else {
- None
- }
- }
- MessageContent::Thinking {
- text: thinking,
- signature,
- } => {
- if let Some(signature) = signature
- && !thinking.is_empty()
- {
- Some(anthropic::RequestContent::Thinking {
- thinking,
- signature,
- cache_control: None,
- })
- } else {
- None
- }
- }
- MessageContent::RedactedThinking(data) => {
- if !data.is_empty() {
- Some(anthropic::RequestContent::RedactedThinking { data })
- } else {
- None
- }
- }
- MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
- source: anthropic::ImageSource {
- source_type: "base64".to_string(),
- media_type: "image/png".to_string(),
- data: image.source.to_string(),
- },
- cache_control: None,
- }),
- MessageContent::ToolUse(tool_use) => Some(anthropic::RequestContent::ToolUse {
- id: tool_use.id.to_string(),
- name: tool_use.name.to_string(),
- input: tool_use.input,
- cache_control: None,
- }),
- MessageContent::ToolResult(tool_result) => Some(anthropic::RequestContent::ToolResult {
- tool_use_id: tool_result.tool_use_id.to_string(),
- is_error: tool_result.is_error,
- content: match tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- ToolResultContent::Plain(text.to_string())
- }
- LanguageModelToolResultContent::Image(image) => {
- ToolResultContent::Multipart(vec![ToolResultPart::Image {
- source: anthropic::ImageSource {
- source_type: "base64".to_string(),
- media_type: "image/png".to_string(),
- data: image.source.to_string(),
- },
- }])
- }
- },
- cache_control: None,
- }),
- }
-}
-
-/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
-pub fn into_anthropic_count_tokens_request(
- request: LanguageModelRequest,
- model: String,
- mode: AnthropicModelMode,
-) -> CountTokensRequest {
- let mut new_messages: Vec<anthropic::Message> = Vec::new();
- let mut system_message = String::new();
-
- for message in request.messages {
- if message.contents_empty() {
- continue;
- }
-
- match message.role {
- Role::User | Role::Assistant => {
- let anthropic_message_content: Vec<anthropic::RequestContent> = message
- .content
- .into_iter()
- .filter_map(to_anthropic_content)
- .collect();
- let anthropic_role = match message.role {
- Role::User => anthropic::Role::User,
- Role::Assistant => anthropic::Role::Assistant,
- Role::System => unreachable!("System role should never occur here"),
- };
- if anthropic_message_content.is_empty() {
- continue;
- }
-
- if let Some(last_message) = new_messages.last_mut()
- && last_message.role == anthropic_role
- {
- last_message.content.extend(anthropic_message_content);
- continue;
- }
-
- new_messages.push(anthropic::Message {
- role: anthropic_role,
- content: anthropic_message_content,
- });
- }
- Role::System => {
- if !system_message.is_empty() {
- system_message.push_str("\n\n");
- }
- system_message.push_str(&message.string_contents());
- }
- }
- }
-
- CountTokensRequest {
- model,
- messages: new_messages,
- system: if system_message.is_empty() {
- None
- } else {
- Some(anthropic::StringOrContents::String(system_message))
- },
- thinking: if request.thinking_allowed {
- match mode {
- AnthropicModelMode::Thinking { budget_tokens } => {
- Some(anthropic::Thinking::Enabled { budget_tokens })
- }
- AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive),
- AnthropicModelMode::Default => None,
- }
- } else {
- None
- },
- tools: request
- .tools
- .into_iter()
- .map(|tool| anthropic::Tool {
- name: tool.name,
- description: tool.description,
- input_schema: tool.input_schema,
- eager_input_streaming: tool.use_input_streaming,
- })
- .collect(),
- tool_choice: request.tool_choice.map(|choice| match choice {
- LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
- LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
- LanguageModelToolChoice::None => anthropic::ToolChoice::None,
- }),
- }
-}
-
-/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
-/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
-pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
- let messages = request.messages;
- let mut tokens_from_images = 0;
- let mut string_messages = Vec::with_capacity(messages.len());
-
- for message in messages {
- let mut string_contents = String::new();
-
- for content in message.content {
- match content {
- MessageContent::Text(text) => {
- string_contents.push_str(&text);
- }
- MessageContent::Thinking { .. } => {
- // Thinking blocks are not included in the input token count.
- }
- MessageContent::RedactedThinking(_) => {
- // Thinking blocks are not included in the input token count.
- }
- MessageContent::Image(image) => {
- tokens_from_images += image.estimate_tokens();
- }
- MessageContent::ToolUse(_tool_use) => {
- // TODO: Estimate token usage from tool uses.
- }
- MessageContent::ToolResult(tool_result) => match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- string_contents.push_str(text);
- }
- LanguageModelToolResultContent::Image(image) => {
- tokens_from_images += image.estimate_tokens();
- }
- },
- }
- }
-
- if !string_contents.is_empty() {
- string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
- role: match message.role {
- Role::User => "user".into(),
- Role::Assistant => "assistant".into(),
- Role::System => "system".into(),
- },
- content: Some(string_contents),
- name: None,
- function_call: None,
- });
- }
- }
-
- // Tiktoken doesn't yet support these models, so we manually use the
- // same tokenizer as GPT-4.
- tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
- .map(|tokens| (tokens + tokens_from_images) as u64)
-}
-
impl AnthropicModel {
fn stream_completion(
&self,
@@ -617,10 +391,13 @@ impl LanguageModel for AnthropicModel {
)
});
+ let background = cx.background_executor().clone();
async move {
// If no API key, fall back to tiktoken estimation
let Some(api_key) = api_key else {
- return count_anthropic_tokens_with_tiktoken(request);
+ return background
+ .spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
+ .await;
};
let count_request =
@@ -634,7 +411,9 @@ impl LanguageModel for AnthropicModel {
log::error!(
"Anthropic count_tokens API failed, falling back to tiktoken: {err:?}"
);
- count_anthropic_tokens_with_tiktoken(request)
+ background
+ .spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
+ .await
}
}
}
@@ -678,345 +457,6 @@ impl LanguageModel for AnthropicModel {
}
}
-pub fn into_anthropic(
- request: LanguageModelRequest,
- model: String,
- default_temperature: f32,
- max_output_tokens: u64,
- mode: AnthropicModelMode,
-) -> anthropic::Request {
- let mut new_messages: Vec<anthropic::Message> = Vec::new();
- let mut system_message = String::new();
-
- for message in request.messages {
- if message.contents_empty() {
- continue;
- }
-
- match message.role {
- Role::User | Role::Assistant => {
- let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
- .content
- .into_iter()
- .filter_map(to_anthropic_content)
- .collect();
- let anthropic_role = match message.role {
- Role::User => anthropic::Role::User,
- Role::Assistant => anthropic::Role::Assistant,
- Role::System => unreachable!("System role should never occur here"),
- };
- if anthropic_message_content.is_empty() {
- continue;
- }
-
- if let Some(last_message) = new_messages.last_mut()
- && last_message.role == anthropic_role
- {
- last_message.content.extend(anthropic_message_content);
- continue;
- }
-
- // Mark the last segment of the message as cached
- if message.cache {
- let cache_control_value = Some(anthropic::CacheControl {
- cache_type: anthropic::CacheControlType::Ephemeral,
- });
- for message_content in anthropic_message_content.iter_mut().rev() {
- match message_content {
- anthropic::RequestContent::RedactedThinking { .. } => {
- // Caching is not possible, fallback to next message
- }
- anthropic::RequestContent::Text { cache_control, .. }
- | anthropic::RequestContent::Thinking { cache_control, .. }
- | anthropic::RequestContent::Image { cache_control, .. }
- | anthropic::RequestContent::ToolUse { cache_control, .. }
- | anthropic::RequestContent::ToolResult { cache_control, .. } => {
- *cache_control = cache_control_value;
- break;
- }
- }
- }
- }
-
- new_messages.push(anthropic::Message {
- role: anthropic_role,
- content: anthropic_message_content,
- });
- }
- Role::System => {
- if !system_message.is_empty() {
- system_message.push_str("\n\n");
- }
- system_message.push_str(&message.string_contents());
- }
- }
- }
-
- anthropic::Request {
- model,
- messages: new_messages,
- max_tokens: max_output_tokens,
- system: if system_message.is_empty() {
- None
- } else {
- Some(anthropic::StringOrContents::String(system_message))
- },
- thinking: if request.thinking_allowed {
- match mode {
- AnthropicModelMode::Thinking { budget_tokens } => {
- Some(anthropic::Thinking::Enabled { budget_tokens })
- }
- AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive),
- AnthropicModelMode::Default => None,
- }
- } else {
- None
- },
- tools: request
- .tools
- .into_iter()
- .map(|tool| anthropic::Tool {
- name: tool.name,
- description: tool.description,
- input_schema: tool.input_schema,
- eager_input_streaming: tool.use_input_streaming,
- })
- .collect(),
- tool_choice: request.tool_choice.map(|choice| match choice {
- LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
- LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
- LanguageModelToolChoice::None => anthropic::ToolChoice::None,
- }),
- metadata: None,
- output_config: if request.thinking_allowed
- && matches!(mode, AnthropicModelMode::AdaptiveThinking)
- {
- request.thinking_effort.as_deref().and_then(|effort| {
- let effort = match effort {
- "low" => Some(anthropic::Effort::Low),
- "medium" => Some(anthropic::Effort::Medium),
- "high" => Some(anthropic::Effort::High),
- "max" => Some(anthropic::Effort::Max),
- _ => None,
- };
- effort.map(|effort| anthropic::OutputConfig {
- effort: Some(effort),
- })
- })
- } else {
- None
- },
- stop_sequences: Vec::new(),
- speed: request.speed.map(From::from),
- temperature: request.temperature.or(Some(default_temperature)),
- top_k: None,
- top_p: None,
- }
-}
-
-pub struct AnthropicEventMapper {
- tool_uses_by_index: HashMap<usize, RawToolUse>,
- usage: Usage,
- stop_reason: StopReason,
-}
-
-impl AnthropicEventMapper {
- pub fn new() -> Self {
- Self {
- tool_uses_by_index: HashMap::default(),
- usage: Usage::default(),
- stop_reason: StopReason::EndTurn,
- }
- }
-
- pub fn map_stream(
- mut self,
- events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
- ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
- {
- events.flat_map(move |event| {
- futures::stream::iter(match event {
- Ok(event) => self.map_event(event),
- Err(error) => vec![Err(error.into())],
- })
- })
- }
-
- pub fn map_event(
- &mut self,
- event: Event,
- ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- match event {
- Event::ContentBlockStart {
- index,
- content_block,
- } => match content_block {
- ResponseContent::Text { text } => {
- vec![Ok(LanguageModelCompletionEvent::Text(text))]
- }
- ResponseContent::Thinking { thinking } => {
- vec![Ok(LanguageModelCompletionEvent::Thinking {
- text: thinking,
- signature: None,
- })]
- }
- ResponseContent::RedactedThinking { data } => {
- vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
- }
- ResponseContent::ToolUse { id, name, .. } => {
- self.tool_uses_by_index.insert(
- index,
- RawToolUse {
- id,
- name,
- input_json: String::new(),
- },
- );
- Vec::new()
- }
- },
- Event::ContentBlockDelta { index, delta } => match delta {
- ContentDelta::TextDelta { text } => {
- vec![Ok(LanguageModelCompletionEvent::Text(text))]
- }
- ContentDelta::ThinkingDelta { thinking } => {
- vec![Ok(LanguageModelCompletionEvent::Thinking {
- text: thinking,
- signature: None,
- })]
- }
- ContentDelta::SignatureDelta { signature } => {
- vec![Ok(LanguageModelCompletionEvent::Thinking {
- text: "".to_string(),
- signature: Some(signature),
- })]
- }
- ContentDelta::InputJsonDelta { partial_json } => {
- if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
- tool_use.input_json.push_str(&partial_json);
-
- // Try to convert invalid (incomplete) JSON into
- // valid JSON that serde can accept, e.g. by closing
- // unclosed delimiters. This way, we can update the
- // UI with whatever has been streamed back so far.
- if let Ok(input) =
- serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
- {
- return vec![Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_use.id.clone().into(),
- name: tool_use.name.clone().into(),
- is_input_complete: false,
- raw_input: tool_use.input_json.clone(),
- input,
- thought_signature: None,
- },
- ))];
- }
- }
- vec![]
- }
- },
- Event::ContentBlockStop { index } => {
- if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
- let input_json = tool_use.input_json.trim();
- let event_result = match parse_tool_arguments(input_json) {
- Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_use.id.into(),
- name: tool_use.name.into(),
- is_input_complete: true,
- input,
- raw_input: tool_use.input_json.clone(),
- thought_signature: None,
- },
- )),
- Err(json_parse_err) => {
- Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
- id: tool_use.id.into(),
- tool_name: tool_use.name.into(),
- raw_input: input_json.into(),
- json_parse_error: json_parse_err.to_string(),
- })
- }
- };
-
- vec![event_result]
- } else {
- Vec::new()
- }
- }
- Event::MessageStart { message } => {
- update_usage(&mut self.usage, &message.usage);
- vec![
- Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
- &self.usage,
- ))),
- Ok(LanguageModelCompletionEvent::StartMessage {
- message_id: message.id,
- }),
- ]
- }
- Event::MessageDelta { delta, usage } => {
- update_usage(&mut self.usage, &usage);
- if let Some(stop_reason) = delta.stop_reason.as_deref() {
- self.stop_reason = match stop_reason {
- "end_turn" => StopReason::EndTurn,
- "max_tokens" => StopReason::MaxTokens,
- "tool_use" => StopReason::ToolUse,
- "refusal" => StopReason::Refusal,
- _ => {
- log::error!("Unexpected anthropic stop_reason: {stop_reason}");
- StopReason::EndTurn
- }
- };
- }
- vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
- convert_usage(&self.usage),
- ))]
- }
- Event::MessageStop => {
- vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
- }
- Event::Error { error } => {
- vec![Err(error.into())]
- }
- _ => Vec::new(),
- }
- }
-}
-
-struct RawToolUse {
- id: String,
- name: String,
- input_json: String,
-}
-
-/// Updates usage data by preferring counts from `new`.
-fn update_usage(usage: &mut Usage, new: &Usage) {
- if let Some(input_tokens) = new.input_tokens {
- usage.input_tokens = Some(input_tokens);
- }
- if let Some(output_tokens) = new.output_tokens {
- usage.output_tokens = Some(output_tokens);
- }
- if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
- usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
- }
- if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
- usage.cache_read_input_tokens = Some(cache_read_input_tokens);
- }
-}
-
-fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
- language_model::TokenUsage {
- input_tokens: usage.input_tokens.unwrap_or(0),
- output_tokens: usage.output_tokens.unwrap_or(0),
- cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
- cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
- }
-}
-
struct ConfigurationView {
api_key_editor: Entity<InputField>,
state: Entity<State>,
@@ -1157,192 +597,3 @@ impl Render for ConfigurationView {
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use anthropic::AnthropicModelMode;
- use language_model::{LanguageModelRequestMessage, MessageContent};
-
- #[test]
- fn test_cache_control_only_on_last_segment() {
- let request = LanguageModelRequest {
- messages: vec![LanguageModelRequestMessage {
- role: Role::User,
- content: vec![
- MessageContent::Text("Some prompt".to_string()),
- MessageContent::Image(language_model::LanguageModelImage::empty()),
- MessageContent::Image(language_model::LanguageModelImage::empty()),
- MessageContent::Image(language_model::LanguageModelImage::empty()),
- MessageContent::Image(language_model::LanguageModelImage::empty()),
- ],
- cache: true,
- reasoning_details: None,
- }],
- thread_id: None,
- prompt_id: None,
- intent: None,
- stop: vec![],
- temperature: None,
- tools: vec![],
- tool_choice: None,
- thinking_allowed: true,
- thinking_effort: None,
- speed: None,
- };
-
- let anthropic_request = into_anthropic(
- request,
- "claude-3-5-sonnet".to_string(),
- 0.7,
- 4096,
- AnthropicModelMode::Default,
- );
-
- assert_eq!(anthropic_request.messages.len(), 1);
-
- let message = &anthropic_request.messages[0];
- assert_eq!(message.content.len(), 5);
-
- assert!(matches!(
- message.content[0],
- anthropic::RequestContent::Text {
- cache_control: None,
- ..
- }
- ));
- for i in 1..3 {
- assert!(matches!(
- message.content[i],
- anthropic::RequestContent::Image {
- cache_control: None,
- ..
- }
- ));
- }
-
- assert!(matches!(
- message.content[4],
- anthropic::RequestContent::Image {
- cache_control: Some(anthropic::CacheControl {
- cache_type: anthropic::CacheControlType::Ephemeral,
- }),
- ..
- }
- ));
- }
-
- fn request_with_assistant_content(
- assistant_content: Vec<MessageContent>,
- ) -> anthropic::Request {
- let mut request = LanguageModelRequest {
- messages: vec![LanguageModelRequestMessage {
- role: Role::User,
- content: vec![MessageContent::Text("Hello".to_string())],
- cache: false,
- reasoning_details: None,
- }],
- thinking_effort: None,
- thread_id: None,
- prompt_id: None,
- intent: None,
- stop: vec![],
- temperature: None,
- tools: vec![],
- tool_choice: None,
- thinking_allowed: true,
- speed: None,
- };
- request.messages.push(LanguageModelRequestMessage {
- role: Role::Assistant,
- content: assistant_content,
- cache: false,
- reasoning_details: None,
- });
- into_anthropic(
- request,
- "claude-sonnet-4-5".to_string(),
- 1.0,
- 16000,
- AnthropicModelMode::Thinking {
- budget_tokens: Some(10000),
- },
- )
- }
-
- #[test]
- fn test_unsigned_thinking_blocks_stripped() {
- let result = request_with_assistant_content(vec![
- MessageContent::Thinking {
- text: "Cancelled mid-think, no signature".to_string(),
- signature: None,
- },
- MessageContent::Text("Some response text".to_string()),
- ]);
-
- let assistant_message = result
- .messages
- .iter()
- .find(|m| m.role == anthropic::Role::Assistant)
- .expect("assistant message should still exist");
-
- assert_eq!(
- assistant_message.content.len(),
- 1,
- "Only the text content should remain; unsigned thinking block should be stripped"
- );
- assert!(matches!(
- &assistant_message.content[0],
- anthropic::RequestContent::Text { text, .. } if text == "Some response text"
- ));
- }
-
- #[test]
- fn test_signed_thinking_blocks_preserved() {
- let result = request_with_assistant_content(vec![
- MessageContent::Thinking {
- text: "Completed thinking".to_string(),
- signature: Some("valid-signature".to_string()),
- },
- MessageContent::Text("Response".to_string()),
- ]);
-
- let assistant_message = result
- .messages
- .iter()
- .find(|m| m.role == anthropic::Role::Assistant)
- .expect("assistant message should exist");
-
- assert_eq!(
- assistant_message.content.len(),
- 2,
- "Both the signed thinking block and text should be preserved"
- );
- assert!(matches!(
- &assistant_message.content[0],
- anthropic::RequestContent::Thinking { thinking, signature, .. }
- if thinking == "Completed thinking" && signature == "valid-signature"
- ));
- }
-
- #[test]
- fn test_only_unsigned_thinking_block_omits_entire_message() {
- let result = request_with_assistant_content(vec![MessageContent::Thinking {
- text: "Cancelled before any text or signature".to_string(),
- signature: None,
- }]);
-
- let assistant_messages: Vec<_> = result
- .messages
- .iter()
- .filter(|m| m.role == anthropic::Role::Assistant)
- .collect();
-
- assert_eq!(
- assistant_messages.len(),
- 0,
- "An assistant message whose only content was an unsigned thinking block \
- should be omitted entirely"
- );
- }
-}
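Aside: the deleted tests above pin down a non-obvious invariant of `into_anthropic` — unsigned thinking blocks must never reach the Anthropic API. A minimal sketch of that invariant, using simplified stand-in types rather than the crate's real ones:

```rust
// Minimal sketch of the rule the removed tests assert, with simplified
// stand-in types (not the crate's real MessageContent/RequestContent).
#[derive(Debug, PartialEq)]
enum Content {
    Text(String),
    // Anthropic only accepts thinking blocks that carry a signature.
    Thinking { text: String, signature: Option<String> },
}

/// Drop unsigned thinking blocks; if nothing remains, drop the whole
/// assistant message (None) so an empty message is never sent upstream.
fn sanitize_assistant_content(content: Vec<Content>) -> Option<Vec<Content>> {
    let kept: Vec<Content> = content
        .into_iter()
        .filter(|block| !matches!(block, Content::Thinking { signature: None, .. }))
        .collect();
    if kept.is_empty() { None } else { Some(kept) }
}

#[test]
fn unsigned_thinking_is_stripped() {
    let kept = sanitize_assistant_content(vec![
        Content::Thinking { text: "cancelled".into(), signature: None },
        Content::Text("answer".into()),
    ])
    .unwrap();
    assert_eq!(kept, vec![Content::Text("answer".into())]);
}
```

The third removed test follows from the same rule: once the unsigned block is stripped, an otherwise-empty assistant message is omitted entirely rather than sent with an empty content array.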
@@ -48,7 +48,7 @@ use ui_input::InputField;
use util::ResultExt;
use crate::AllLanguageModelSettings;
-use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
+use language_model::util::{fix_streamed_json, parse_tool_arguments};
actions!(bedrock, [Tab, TabPrev]);
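Editor's note: `fix_streamed_json` and `parse_tool_arguments` now live in `language_model::util` instead of each provider's private `util` module. The real implementation isn't shown in this diff; as a rough, hypothetical sketch of what such a fixer does (name, signature, and behavior assumed here, not the actual helper):

```rust
/// Hypothetical sketch of a streamed-JSON fixer: close any unterminated
/// string and unbalanced brackets so a partial tool-call fragment like
/// `{"path": "src/ma` parses as `{"path": "src/ma"}`.
/// The real helper in language_model::util may differ.
fn fix_streamed_json_sketch(partial: &str) -> String {
    let mut out = String::from(partial);
    let mut stack = Vec::new();
    let (mut in_string, mut escaped) = (false, false);
    for ch in partial.chars() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
        } else {
            match ch {
                '"' => in_string = true,
                '{' => stack.push('}'),
                '[' => stack.push(']'),
                '}' | ']' => {
                    stack.pop();
                }
                _ => {}
            }
        }
    }
    if in_string {
        out.push('"');
    }
    while let Some(close) = stack.pop() {
        out.push(close);
    }
    out
}
```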
@@ -1,107 +1,93 @@
use ai_onboarding::YoungAccountBanner;
-use anthropic::AnthropicModelMode;
-use anyhow::{Context as _, Result, anyhow};
-use client::{
- Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls,
-};
-use cloud_api_types::{OrganizationId, Plan};
-use cloud_llm_client::{
- CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
- CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
- CountTokensBody, CountTokensResponse, ListModelsResponse,
- SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
-};
-use futures::{
- AsyncBufReadExt, FutureExt, Stream, StreamExt,
- future::BoxFuture,
- stream::{self, BoxStream},
-};
-use google_ai::GoogleModelMode;
-use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
-use http_client::http::{HeaderMap, HeaderValue};
-use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
+use anyhow::Result;
+use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls};
+use cloud_api_client::LlmApiToken;
+use cloud_api_types::{OrganizationId, Plan};
+use futures::{StreamExt, future::BoxFuture};
+use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use language_model::{
- ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, AuthenticateError, GOOGLE_PROVIDER_ID,
- GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
- LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
- LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
- LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID,
- OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
- ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
+ AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
+ ZED_CLOUD_PROVIDER_NAME,
};
+use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider};
use release_channel::AppVersion;
-use schemars::JsonSchema;
-use semver::Version;
-use serde::{Deserialize, Serialize, de::DeserializeOwned};
+
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
-use smol::io::{AsyncReadExt, BufReader};
-use std::collections::VecDeque;
-use std::pin::Pin;
-use std::str::FromStr;
use std::sync::Arc;
-use std::task::Poll;
-use std::time::Duration;
-use thiserror::Error;
use ui::{TintColor, prelude::*};
-use crate::provider::anthropic::{
- AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
-};
-use crate::provider::google::{GoogleEventMapper, into_google};
-use crate::provider::open_ai::{
- OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
- into_open_ai_response,
-};
-use crate::provider::x_ai::count_xai_tokens;
-
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
-#[derive(Default, Clone, Debug, PartialEq)]
-pub struct ZedDotDevSettings {
- pub available_models: Vec<AvailableModel>,
-}
-#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
-#[serde(tag = "type", rename_all = "lowercase")]
-pub enum ModelMode {
- #[default]
- Default,
- Thinking {
- /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
- budget_tokens: Option<u32>,
- },
+struct ClientTokenProvider {
+ client: Arc<Client>,
+ llm_api_token: LlmApiToken,
+ user_store: Entity<UserStore>,
}
-impl From<ModelMode> for AnthropicModelMode {
- fn from(value: ModelMode) -> Self {
- match value {
- ModelMode::Default => AnthropicModelMode::Default,
- ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
- }
+impl CloudLlmTokenProvider for ClientTokenProvider {
+ type AuthContext = Option<OrganizationId>;
+
+ fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext {
+ self.user_store.read_with(cx, |user_store, _| {
+ user_store
+ .current_organization()
+ .map(|organization| organization.id.clone())
+ })
+ .ok()
+ .flatten()
+ }
+
+ fn acquire_token(
+ &self,
+ organization_id: Self::AuthContext,
+ ) -> BoxFuture<'static, Result<String>> {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+ Box::pin(async move {
+ client
+ .acquire_llm_token(&llm_api_token, organization_id)
+ .await
+ })
+ }
+
+ fn refresh_token(
+ &self,
+ organization_id: Self::AuthContext,
+ ) -> BoxFuture<'static, Result<String>> {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+ Box::pin(async move {
+ client
+ .refresh_llm_token(&llm_api_token, organization_id)
+ .await
+ })
+ }
+}
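The `CloudLlmTokenProvider` trait implemented above is defined in the new `language_models_cloud` crate and isn't shown in this diff. From the impl, its shape is roughly the following (an inferred sketch, not the authoritative declaration):

```rust
// Inferred from the ClientTokenProvider impl above; the real trait in
// language_models_cloud may differ in bounds and exact signatures.
use anyhow::Result;
use futures::future::BoxFuture;
use gpui::AsyncApp;

pub trait CloudLlmTokenProvider: Send + Sync {
    /// Extra data used to scope token requests (an organization id here).
    type AuthContext;

    fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext;
    fn acquire_token(&self, ctx: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    fn refresh_token(&self, ctx: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}
```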
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct ZedDotDevSettings {
+ pub available_models: Vec<AvailableModel>,
}
pub struct CloudLanguageModelProvider {
- client: Arc<Client>,
state: Entity<State>,
_maintain_client_status: Task<()>,
}
pub struct State {
client: Arc<Client>,
- llm_api_token: LlmApiToken,
user_store: Entity<UserStore>,
status: client::Status,
- models: Vec<Arc<cloud_llm_client::LanguageModel>>,
- default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
- default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
- recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
+ provider: Entity<CloudModelProvider<ClientTokenProvider>>,
_user_store_subscription: Subscription,
_settings_subscription: Subscription,
_llm_token_subscription: Subscription,
+ _provider_subscription: Subscription,
}
impl State {
@@ -112,16 +98,26 @@ impl State {
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
- let llm_api_token = global_llm_token(cx);
+ let token_provider = Arc::new(ClientTokenProvider {
+ client: client.clone(),
+ llm_api_token: global_llm_token(cx),
+ user_store: user_store.clone(),
+ });
+
+ let provider = cx.new(|cx| {
+ CloudModelProvider::new(
+ token_provider.clone(),
+ client.http_client(),
+ Some(AppVersion::global(cx)),
+ )
+ });
+
Self {
client: client.clone(),
- llm_api_token,
user_store: user_store.clone(),
status,
- models: Vec::new(),
- default_model: None,
- default_fast_model: None,
- recommended_models: Vec::new(),
+ _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()),
+ provider,
_user_store_subscription: cx.subscribe(
&user_store,
move |this, _user_store, event, cx| match event {
@@ -131,19 +127,7 @@ impl State {
return;
}
- let client = this.client.clone();
- let llm_api_token = this.llm_api_token.clone();
- let organization_id = this
- .user_store
- .read(cx)
- .current_organization()
- .map(|organization| organization.id.clone());
- cx.spawn(async move |this, cx| {
- let response =
- Self::fetch_models(client, llm_api_token, organization_id).await?;
- this.update(cx, |this, cx| this.update_models(response, cx))
- })
- .detach_and_log_err(cx);
+ this.refresh_models(cx);
}
_ => {}
},
@@ -154,21 +138,7 @@ impl State {
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
move |this, _listener, _event, cx| {
- let client = this.client.clone();
- let llm_api_token = this.llm_api_token.clone();
- let organization_id = this
- .user_store
- .read(cx)
- .current_organization()
- .map(|organization| organization.id.clone());
- cx.spawn(async move |this, cx| {
- let response =
- Self::fetch_models(client, llm_api_token, organization_id).await?;
- this.update(cx, |this, cx| {
- this.update_models(response, cx);
- })
- })
- .detach_and_log_err(cx);
+ this.refresh_models(cx);
},
),
}
@@ -186,74 +156,10 @@ impl State {
})
}
- fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
- let mut models = Vec::new();
-
- for model in response.models {
- models.push(Arc::new(model.clone()));
- }
-
- self.default_model = models
- .iter()
- .find(|model| {
- response
- .default_model
- .as_ref()
- .is_some_and(|default_model_id| &model.id == default_model_id)
- })
- .cloned();
- self.default_fast_model = models
- .iter()
- .find(|model| {
- response
- .default_fast_model
- .as_ref()
- .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
- })
- .cloned();
- self.recommended_models = response
- .recommended_models
- .iter()
- .filter_map(|id| models.iter().find(|model| &model.id == id))
- .cloned()
- .collect();
- self.models = models;
- cx.notify();
- }
-
- async fn fetch_models(
- client: Arc<Client>,
- llm_api_token: LlmApiToken,
- organization_id: Option<OrganizationId>,
- ) -> Result<ListModelsResponse> {
- let http_client = &client.http_client();
- let token = client
- .acquire_llm_token(&llm_api_token, organization_id)
- .await?;
-
- let request = http_client::Request::builder()
- .method(Method::GET)
- .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
- .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
- .header("Authorization", format!("Bearer {token}"))
- .body(AsyncBody::empty())?;
- let mut response = http_client
- .send(request)
- .await
- .context("failed to send list models request")?;
-
- if response.status().is_success() {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- Ok(serde_json::from_str(&body)?)
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- anyhow::bail!(
- "error listing models.\nStatus: {:?}\nBody: {body}",
- response.status(),
- );
- }
+ fn refresh_models(&mut self, cx: &mut Context<Self>) {
+ self.provider.update(cx, |provider, cx| {
+ provider.refresh_models(cx).detach_and_log_err(cx);
+ });
}
}
@@ -281,27 +187,10 @@ impl CloudLanguageModelProvider {
});
Self {
- client,
state,
_maintain_client_status: maintain_client_status,
}
}
-
- fn create_language_model(
- &self,
- model: Arc<cloud_llm_client::LanguageModel>,
- llm_api_token: LlmApiToken,
- user_store: Entity<UserStore>,
- ) -> Arc<dyn LanguageModel> {
- Arc::new(CloudLanguageModel {
- id: LanguageModelId(SharedString::from(model.id.0.clone())),
- model,
- llm_api_token,
- user_store,
- client: self.client.clone(),
- request_limiter: RateLimiter::new(4),
- })
- }
}
impl LanguageModelProviderState for CloudLanguageModelProvider {
@@ -327,45 +216,35 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
let state = self.state.read(cx);
- let default_model = state.default_model.clone()?;
- let llm_api_token = state.llm_api_token.clone();
- let user_store = state.user_store.clone();
- Some(self.create_language_model(default_model, llm_api_token, user_store))
+ let provider = state.provider.read(cx);
+ let model = provider.default_model()?;
+ Some(provider.create_model(model))
}
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
let state = self.state.read(cx);
- let default_fast_model = state.default_fast_model.clone()?;
- let llm_api_token = state.llm_api_token.clone();
- let user_store = state.user_store.clone();
- Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
+ let provider = state.provider.read(cx);
+ let model = provider.default_fast_model()?;
+ Some(provider.create_model(model))
}
fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let state = self.state.read(cx);
- let llm_api_token = state.llm_api_token.clone();
- let user_store = state.user_store.clone();
- state
- .recommended_models
+ let provider = state.provider.read(cx);
+ provider
+ .recommended_models()
.iter()
- .cloned()
- .map(|model| {
- self.create_language_model(model, llm_api_token.clone(), user_store.clone())
- })
+ .map(|model| provider.create_model(model))
.collect()
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let state = self.state.read(cx);
- let llm_api_token = state.llm_api_token.clone();
- let user_store = state.user_store.clone();
- state
- .models
+ let provider = state.provider.read(cx);
+ provider
+ .models()
.iter()
- .cloned()
- .map(|model| {
- self.create_language_model(model, llm_api_token.clone(), user_store.clone())
- })
+ .map(|model| provider.create_model(model))
.collect()
}
@@ -393,700 +272,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
}
-pub struct CloudLanguageModel {
- id: LanguageModelId,
- model: Arc<cloud_llm_client::LanguageModel>,
- llm_api_token: LlmApiToken,
- user_store: Entity<UserStore>,
- client: Arc<Client>,
- request_limiter: RateLimiter,
-}
-
-struct PerformLlmCompletionResponse {
- response: Response<AsyncBody>,
- includes_status_messages: bool,
-}
-
-impl CloudLanguageModel {
- async fn perform_llm_completion(
- client: Arc<Client>,
- llm_api_token: LlmApiToken,
- organization_id: Option<OrganizationId>,
- app_version: Option<Version>,
- body: CompletionBody,
- ) -> Result<PerformLlmCompletionResponse> {
- let http_client = &client.http_client();
-
- let mut token = client
- .acquire_llm_token(&llm_api_token, organization_id.clone())
- .await?;
- let mut refreshed_token = false;
-
- loop {
- let request = http_client::Request::builder()
- .method(Method::POST)
- .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
- .when_some(app_version.as_ref(), |builder, app_version| {
- builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
- })
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {token}"))
- .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
- .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
- .body(serde_json::to_string(&body)?.into())?;
-
- let mut response = http_client.send(request).await?;
- let status = response.status();
- if status.is_success() {
- let includes_status_messages = response
- .headers()
- .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
- .is_some();
-
- return Ok(PerformLlmCompletionResponse {
- response,
- includes_status_messages,
- });
- }
-
- if !refreshed_token && response.needs_llm_token_refresh() {
- token = client
- .refresh_llm_token(&llm_api_token, organization_id.clone())
- .await?;
- refreshed_token = true;
- continue;
- }
-
- if status == StatusCode::PAYMENT_REQUIRED {
- return Err(anyhow!(PaymentRequiredError));
- }
-
- let mut body = String::new();
- let headers = response.headers().clone();
- response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow!(ApiError {
- status,
- body,
- headers
- }));
- }
- }
-}
-
-#[derive(Debug, Error)]
-#[error("cloud language model request failed with status {status}: {body}")]
-struct ApiError {
- status: StatusCode,
- body: String,
- headers: HeaderMap<HeaderValue>,
-}
-
-/// Represents error responses from Zed's cloud API.
-///
-/// Example JSON for an upstream HTTP error:
-/// ```json
-/// {
-/// "code": "upstream_http_error",
-/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
-/// "upstream_status": 503
-/// }
-/// ```
-#[derive(Debug, serde::Deserialize)]
-struct CloudApiError {
- code: String,
- message: String,
- #[serde(default)]
- #[serde(deserialize_with = "deserialize_optional_status_code")]
- upstream_status: Option<StatusCode>,
- #[serde(default)]
- retry_after: Option<f64>,
-}
-
-fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
-where
- D: serde::Deserializer<'de>,
-{
- let opt: Option<u16> = Option::deserialize(deserializer)?;
- Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
-}
-
-impl From<ApiError> for LanguageModelCompletionError {
- fn from(error: ApiError) -> Self {
- if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
- if cloud_error.code.starts_with("upstream_http_") {
- let status = if let Some(status) = cloud_error.upstream_status {
- status
- } else if cloud_error.code.ends_with("_error") {
- error.status
- } else {
- // If there's a status code in the code string (e.g. "upstream_http_429")
- // then use that; otherwise, see if the JSON contains a status code.
- cloud_error
- .code
- .strip_prefix("upstream_http_")
- .and_then(|code_str| code_str.parse::<u16>().ok())
- .and_then(|code| StatusCode::from_u16(code).ok())
- .unwrap_or(error.status)
- };
-
- return LanguageModelCompletionError::UpstreamProviderError {
- message: cloud_error.message,
- status,
- retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
- };
- }
-
- return LanguageModelCompletionError::from_http_status(
- PROVIDER_NAME,
- error.status,
- cloud_error.message,
- None,
- );
- }
-
- let retry_after = None;
- LanguageModelCompletionError::from_http_status(
- PROVIDER_NAME,
- error.status,
- error.body,
- retry_after,
- )
- }
-}
-
-impl LanguageModel for CloudLanguageModel {
- fn id(&self) -> LanguageModelId {
- self.id.clone()
- }
-
- fn name(&self) -> LanguageModelName {
- LanguageModelName::from(self.model.display_name.clone())
- }
-
- fn provider_id(&self) -> LanguageModelProviderId {
- PROVIDER_ID
- }
-
- fn provider_name(&self) -> LanguageModelProviderName {
- PROVIDER_NAME
- }
-
- fn upstream_provider_id(&self) -> LanguageModelProviderId {
- use cloud_llm_client::LanguageModelProvider::*;
- match self.model.provider {
- Anthropic => ANTHROPIC_PROVIDER_ID,
- OpenAi => OPEN_AI_PROVIDER_ID,
- Google => GOOGLE_PROVIDER_ID,
- XAi => X_AI_PROVIDER_ID,
- }
- }
-
- fn upstream_provider_name(&self) -> LanguageModelProviderName {
- use cloud_llm_client::LanguageModelProvider::*;
- match self.model.provider {
- Anthropic => ANTHROPIC_PROVIDER_NAME,
- OpenAi => OPEN_AI_PROVIDER_NAME,
- Google => GOOGLE_PROVIDER_NAME,
- XAi => X_AI_PROVIDER_NAME,
- }
- }
-
- fn is_latest(&self) -> bool {
- self.model.is_latest
- }
-
- fn supports_tools(&self) -> bool {
- self.model.supports_tools
- }
-
- fn supports_images(&self) -> bool {
- self.model.supports_images
- }
-
- fn supports_thinking(&self) -> bool {
- self.model.supports_thinking
- }
-
- fn supports_fast_mode(&self) -> bool {
- self.model.supports_fast_mode
- }
-
- fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
- self.model
- .supported_effort_levels
- .iter()
- .map(|effort_level| LanguageModelEffortLevel {
- name: effort_level.name.clone().into(),
- value: effort_level.value.clone().into(),
- is_default: effort_level.is_default.unwrap_or(false),
- })
- .collect()
- }
-
- fn supports_streaming_tools(&self) -> bool {
- self.model.supports_streaming_tools
- }
-
- fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
- match choice {
- LanguageModelToolChoice::Auto
- | LanguageModelToolChoice::Any
- | LanguageModelToolChoice::None => true,
- }
- }
-
- fn supports_split_token_display(&self) -> bool {
- use cloud_llm_client::LanguageModelProvider::*;
- matches!(self.model.provider, OpenAi | XAi)
- }
-
- fn telemetry_id(&self) -> String {
- format!("zed.dev/{}", self.model.id)
- }
-
- fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
- match self.model.provider {
- cloud_llm_client::LanguageModelProvider::Anthropic
- | cloud_llm_client::LanguageModelProvider::OpenAi => {
- LanguageModelToolSchemaFormat::JsonSchema
- }
- cloud_llm_client::LanguageModelProvider::Google
- | cloud_llm_client::LanguageModelProvider::XAi => {
- LanguageModelToolSchemaFormat::JsonSchemaSubset
- }
- }
- }
-
- fn max_token_count(&self) -> u64 {
- self.model.max_token_count as u64
- }
-
- fn max_output_tokens(&self) -> Option<u64> {
- Some(self.model.max_output_tokens as u64)
- }
-
- fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
- match &self.model.provider {
- cloud_llm_client::LanguageModelProvider::Anthropic => {
- Some(LanguageModelCacheConfiguration {
- min_total_token: 2_048,
- should_speculate: true,
- max_cache_anchors: 4,
- })
- }
- cloud_llm_client::LanguageModelProvider::OpenAi
- | cloud_llm_client::LanguageModelProvider::XAi
- | cloud_llm_client::LanguageModelProvider::Google => None,
- }
- }
-
- fn count_tokens(
- &self,
- request: LanguageModelRequest,
- cx: &App,
- ) -> BoxFuture<'static, Result<u64>> {
- match self.model.provider {
- cloud_llm_client::LanguageModelProvider::Anthropic => cx
- .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
- .boxed(),
- cloud_llm_client::LanguageModelProvider::OpenAi => {
- let model = match open_ai::Model::from_id(&self.model.id.0) {
- Ok(model) => model,
- Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
- };
- count_open_ai_tokens(request, model, cx)
- }
- cloud_llm_client::LanguageModelProvider::XAi => {
- let model = match x_ai::Model::from_id(&self.model.id.0) {
- Ok(model) => model,
- Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
- };
- count_xai_tokens(request, model, cx)
- }
- cloud_llm_client::LanguageModelProvider::Google => {
- let client = self.client.clone();
- let llm_api_token = self.llm_api_token.clone();
- let organization_id = self
- .user_store
- .read(cx)
- .current_organization()
- .map(|organization| organization.id.clone());
- let model_id = self.model.id.to_string();
- let generate_content_request =
- into_google(request, model_id.clone(), GoogleModelMode::Default);
- async move {
- let http_client = &client.http_client();
- let token = client
- .acquire_llm_token(&llm_api_token, organization_id)
- .await?;
-
- let request_body = CountTokensBody {
- provider: cloud_llm_client::LanguageModelProvider::Google,
- model: model_id,
- provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
- generate_content_request,
- })?,
- };
- let request = http_client::Request::builder()
- .method(Method::POST)
- .uri(
- http_client
- .build_zed_llm_url("/count_tokens", &[])?
- .as_ref(),
- )
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {token}"))
- .body(serde_json::to_string(&request_body)?.into())?;
- let mut response = http_client.send(request).await?;
- let status = response.status();
- let headers = response.headers().clone();
- let mut response_body = String::new();
- response
- .body_mut()
- .read_to_string(&mut response_body)
- .await?;
-
- if status.is_success() {
- let response_body: CountTokensResponse =
- serde_json::from_str(&response_body)?;
-
- Ok(response_body.tokens as u64)
- } else {
- Err(anyhow!(ApiError {
- status,
- body: response_body,
- headers
- }))
- }
- }
- .boxed()
- }
- }
- }
-
- fn stream_completion(
- &self,
- request: LanguageModelRequest,
- cx: &AsyncApp,
- ) -> BoxFuture<
- 'static,
- Result<
- BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
- LanguageModelCompletionError,
- >,
- > {
- let thread_id = request.thread_id.clone();
- let prompt_id = request.prompt_id.clone();
- let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
- let user_store = self.user_store.clone();
- let organization_id = cx.update(|cx| {
- user_store
- .read(cx)
- .current_organization()
- .map(|organization| organization.id.clone())
- })
- .ok()
- .flatten();
- let thinking_allowed = request.thinking_allowed;
- let enable_thinking = thinking_allowed && self.model.supports_thinking;
- let provider_name = provider_name(&self.model.provider);
- match self.model.provider {
- cloud_llm_client::LanguageModelProvider::Anthropic => {
- let effort = request
- .thinking_effort
- .as_ref()
- .and_then(|effort| anthropic::Effort::from_str(effort).ok());
-
- let mut request = into_anthropic(
- request,
- self.model.id.to_string(),
- 1.0,
- self.model.max_output_tokens as u64,
- if enable_thinking {
- AnthropicModelMode::Thinking {
- budget_tokens: Some(4_096),
- }
- } else {
- AnthropicModelMode::Default
- },
- );
-
- if enable_thinking && effort.is_some() {
- request.thinking = Some(anthropic::Thinking::Adaptive);
- request.output_config = Some(anthropic::OutputConfig { effort });
- }
-
- let client = self.client.clone();
- let llm_api_token = self.llm_api_token.clone();
- let organization_id = organization_id.clone();
- let future = self.request_limiter.stream(async move {
- let PerformLlmCompletionResponse {
- response,
- includes_status_messages,
- } = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- organization_id,
- app_version,
- CompletionBody {
- thread_id,
- prompt_id,
- provider: cloud_llm_client::LanguageModelProvider::Anthropic,
- model: request.model.clone(),
- provider_request: serde_json::to_value(&request)
- .map_err(|e| anyhow!(e))?,
- },
- )
- .await
- .map_err(|err| match err.downcast::<ApiError>() {
- Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
- Err(err) => anyhow!(err),
- })?;
-
- let mut mapper = AnthropicEventMapper::new();
- Ok(map_cloud_completion_events(
- Box::pin(response_lines(response, includes_status_messages)),
- &provider_name,
- move |event| mapper.map_event(event),
- ))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- }
- cloud_llm_client::LanguageModelProvider::OpenAi => {
- let client = self.client.clone();
- let llm_api_token = self.llm_api_token.clone();
- let organization_id = organization_id.clone();
- let effort = request
- .thinking_effort
- .as_ref()
- .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());
-
- let mut request = into_open_ai_response(
- request,
- &self.model.id.0,
- self.model.supports_parallel_tool_calls,
- true,
- None,
- None,
- );
-
- if enable_thinking && let Some(effort) = effort {
- request.reasoning = Some(open_ai::responses::ReasoningConfig {
- effort,
- summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
- });
- }
-
- let future = self.request_limiter.stream(async move {
- let PerformLlmCompletionResponse {
- response,
- includes_status_messages,
- } = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- organization_id,
- app_version,
- CompletionBody {
- thread_id,
- prompt_id,
- provider: cloud_llm_client::LanguageModelProvider::OpenAi,
- model: request.model.clone(),
- provider_request: serde_json::to_value(&request)
- .map_err(|e| anyhow!(e))?,
- },
- )
- .await?;
-
- let mut mapper = OpenAiResponseEventMapper::new();
- Ok(map_cloud_completion_events(
- Box::pin(response_lines(response, includes_status_messages)),
- &provider_name,
- move |event| mapper.map_event(event),
- ))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- }
- cloud_llm_client::LanguageModelProvider::XAi => {
- let client = self.client.clone();
- let request = into_open_ai(
- request,
- &self.model.id.0,
- self.model.supports_parallel_tool_calls,
- false,
- None,
- None,
- );
- let llm_api_token = self.llm_api_token.clone();
- let organization_id = organization_id.clone();
- let future = self.request_limiter.stream(async move {
- let PerformLlmCompletionResponse {
- response,
- includes_status_messages,
- } = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- organization_id,
- app_version,
- CompletionBody {
- thread_id,
- prompt_id,
- provider: cloud_llm_client::LanguageModelProvider::XAi,
- model: request.model.clone(),
- provider_request: serde_json::to_value(&request)
- .map_err(|e| anyhow!(e))?,
- },
- )
- .await?;
-
- let mut mapper = OpenAiEventMapper::new();
- Ok(map_cloud_completion_events(
- Box::pin(response_lines(response, includes_status_messages)),
- &provider_name,
- move |event| mapper.map_event(event),
- ))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- }
- cloud_llm_client::LanguageModelProvider::Google => {
- let client = self.client.clone();
- let request =
- into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
- let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let PerformLlmCompletionResponse {
- response,
- includes_status_messages,
- } = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- organization_id,
- app_version,
- CompletionBody {
- thread_id,
- prompt_id,
- provider: cloud_llm_client::LanguageModelProvider::Google,
- model: request.model.model_id.clone(),
- provider_request: serde_json::to_value(&request)
- .map_err(|e| anyhow!(e))?,
- },
- )
- .await?;
-
- let mut mapper = GoogleEventMapper::new();
- Ok(map_cloud_completion_events(
- Box::pin(response_lines(response, includes_status_messages)),
- &provider_name,
- move |event| mapper.map_event(event),
- ))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- }
- }
- }
-}
-
-fn map_cloud_completion_events<T, F>(
- stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
- provider: &LanguageModelProviderName,
- mut map_callback: F,
-) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
-where
- T: DeserializeOwned + 'static,
- F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
- + Send
- + 'static,
-{
- let provider = provider.clone();
- let mut stream = stream.fuse();
-
- let mut saw_stream_ended = false;
-
- let mut done = false;
- let mut pending = VecDeque::new();
-
- stream::poll_fn(move |cx| {
- loop {
- if let Some(item) = pending.pop_front() {
- return Poll::Ready(Some(item));
- }
-
- if done {
- return Poll::Ready(None);
- }
-
- match stream.poll_next_unpin(cx) {
- Poll::Ready(Some(event)) => {
- let items = match event {
- Err(error) => {
- vec![Err(LanguageModelCompletionError::from(error))]
- }
- Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
- saw_stream_ended = true;
- vec![]
- }
- Ok(CompletionEvent::Status(status)) => {
- LanguageModelCompletionEvent::from_completion_request_status(
- status,
- provider.clone(),
- )
- .transpose()
- .map(|event| vec![event])
- .unwrap_or_default()
- }
- Ok(CompletionEvent::Event(event)) => map_callback(event),
- };
- pending.extend(items);
- }
- Poll::Ready(None) => {
- done = true;
-
- if !saw_stream_ended {
- return Poll::Ready(Some(Err(
- LanguageModelCompletionError::StreamEndedUnexpectedly {
- provider: provider.clone(),
- },
- )));
- }
- }
- Poll::Pending => return Poll::Pending,
- }
- }
- })
- .boxed()
-}
-
-fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
- match provider {
- cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
- cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
- cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
- cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
- }
-}
-
-fn response_lines<T: DeserializeOwned>(
- response: Response<AsyncBody>,
- includes_status_messages: bool,
-) -> impl Stream<Item = Result<CompletionEvent<T>>> {
- futures::stream::try_unfold(
- (String::new(), BufReader::new(response.into_body())),
- move |(mut line, mut body)| async move {
- match body.read_line(&mut line).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event = if includes_status_messages {
- serde_json::from_str::<CompletionEvent<T>>(&line)?
- } else {
- CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
- };
-
- line.clear();
- Ok(Some((event, (line, body))))
- }
- Err(e) => Err(e.into()),
- }
- },
- )
-}
-
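One subtlety of the deleted `map_cloud_completion_events` worth preserving in review: it tracked whether a `StreamEnded` status ever arrived and, if the connection closed without one, yielded `StreamEndedUnexpectedly` instead of finishing silently. A stripped-down sketch of that pattern with stand-in types (not the crate's real ones):

```rust
use futures::{Stream, StreamExt, stream};

// Simplified stand-ins for CompletionEvent and the completion error type.
enum Event {
    Data(String),
    StreamEnded,
}

/// Pass data through; if the source finishes without an explicit
/// StreamEnded marker, emit one trailing error instead of ending cleanly.
fn require_explicit_end(
    source: impl Stream<Item = Event> + Send + Unpin + 'static,
) -> impl Stream<Item = Result<String, &'static str>> {
    stream::unfold(
        (source, false, false),
        |(mut source, mut saw_end, errored)| async move {
            if errored {
                return None;
            }
            loop {
                match source.next().await {
                    Some(Event::Data(data)) => {
                        return Some((Ok(data), (source, saw_end, false)));
                    }
                    Some(Event::StreamEnded) => saw_end = true,
                    None if saw_end => return None,
                    None => {
                        return Some((
                            Err("stream ended unexpectedly"),
                            (source, saw_end, true),
                        ));
                    }
                }
            }
        },
    )
}
```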
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
is_connected: bool,
@@ -32,7 +32,7 @@ use ui::prelude::*;
use util::debug_panic;
use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic};
-use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
+use language_model::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
const PROVIDER_NAME: LanguageModelProviderName =
@@ -268,15 +268,15 @@ impl LanguageModel for CopilotChatLanguageModel {
levels
.iter()
.map(|level| {
- let name: SharedString = match level.as_str() {
+ let name = match level.as_str() {
"low" => "Low".into(),
"medium" => "Medium".into(),
"high" => "High".into(),
- _ => SharedString::from(level.clone()),
+ _ => language_model::SharedString::from(level.clone()),
};
LanguageModelEffortLevel {
name,
- value: SharedString::from(level.clone()),
+ value: language_model::SharedString::from(level.clone()),
is_default: level == "high",
}
})
@@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
+use language_model::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
@@ -1,32 +1,25 @@
use anyhow::{Context as _, Result};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
-use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
-use google_ai::{
- FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
- ThinkingConfig, UsageMetadata,
-};
+use futures::{FutureExt, StreamExt, future::BoxFuture};
+pub use google_ai::completion::{GoogleEventMapper, count_google_tokens, into_google};
+use google_ai::{GenerateContentResponse, GoogleModelMode};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
- LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+ LanguageModelProviderState, LanguageModelRequest, RateLimiter,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
pub use settings::GoogleAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
-use std::pin::Pin;
-use std::sync::{
- Arc, LazyLock,
- atomic::{self, AtomicU64},
-};
+use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
@@ -394,369 +387,6 @@ impl LanguageModel for GoogleLanguageModel {
}
}
-pub fn into_google(
- mut request: LanguageModelRequest,
- model_id: String,
- mode: GoogleModelMode,
-) -> google_ai::GenerateContentRequest {
- fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
- content
- .into_iter()
- .flat_map(|content| match content {
- language_model::MessageContent::Text(text) => {
- if !text.is_empty() {
- vec![Part::TextPart(google_ai::TextPart { text })]
- } else {
- vec![]
- }
- }
- language_model::MessageContent::Thinking {
- text: _,
- signature: Some(signature),
- } => {
- if !signature.is_empty() {
- vec![Part::ThoughtPart(google_ai::ThoughtPart {
- thought: true,
- thought_signature: signature,
- })]
- } else {
- vec![]
- }
- }
- language_model::MessageContent::Thinking { .. } => {
- vec![]
- }
- language_model::MessageContent::RedactedThinking(_) => vec![],
- language_model::MessageContent::Image(image) => {
- vec![Part::InlineDataPart(google_ai::InlineDataPart {
- inline_data: google_ai::GenerativeContentBlob {
- mime_type: "image/png".to_string(),
- data: image.source.to_string(),
- },
- })]
- }
- language_model::MessageContent::ToolUse(tool_use) => {
- // Normalize empty string signatures to None
- let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
-
- vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
- function_call: google_ai::FunctionCall {
- name: tool_use.name.to_string(),
- args: tool_use.input,
- },
- thought_signature,
- })]
- }
- language_model::MessageContent::ToolResult(tool_result) => {
- match tool_result.content {
- language_model::LanguageModelToolResultContent::Text(text) => {
- vec![Part::FunctionResponsePart(
- google_ai::FunctionResponsePart {
- function_response: google_ai::FunctionResponse {
- name: tool_result.tool_name.to_string(),
- // The API expects a valid JSON object
- response: serde_json::json!({
- "output": text
- }),
- },
- },
- )]
- }
- language_model::LanguageModelToolResultContent::Image(image) => {
- vec![
- Part::FunctionResponsePart(google_ai::FunctionResponsePart {
- function_response: google_ai::FunctionResponse {
- name: tool_result.tool_name.to_string(),
- // The API expects a valid JSON object
- response: serde_json::json!({
- "output": "Tool responded with an image"
- }),
- },
- }),
- Part::InlineDataPart(google_ai::InlineDataPart {
- inline_data: google_ai::GenerativeContentBlob {
- mime_type: "image/png".to_string(),
- data: image.source.to_string(),
- },
- }),
- ]
- }
- }
- }
- })
- .collect()
- }
-
- let system_instructions = if request
- .messages
- .first()
- .is_some_and(|msg| matches!(msg.role, Role::System))
- {
- let message = request.messages.remove(0);
- Some(SystemInstruction {
- parts: map_content(message.content),
- })
- } else {
- None
- };
-
- google_ai::GenerateContentRequest {
- model: google_ai::ModelName { model_id },
- system_instruction: system_instructions,
- contents: request
- .messages
- .into_iter()
- .filter_map(|message| {
- let parts = map_content(message.content);
- if parts.is_empty() {
- None
- } else {
- Some(google_ai::Content {
- parts,
- role: match message.role {
- Role::User => google_ai::Role::User,
- Role::Assistant => google_ai::Role::Model,
- Role::System => google_ai::Role::User, // Google AI doesn't have a system role
- },
- })
- }
- })
- .collect(),
- generation_config: Some(google_ai::GenerationConfig {
- candidate_count: Some(1),
- stop_sequences: Some(request.stop),
- max_output_tokens: None,
- temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
- thinking_config: match (request.thinking_allowed, mode) {
- (true, GoogleModelMode::Thinking { budget_tokens }) => {
- budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
- }
- _ => None,
- },
- top_p: None,
- top_k: None,
- }),
- safety_settings: None,
- tools: (!request.tools.is_empty()).then(|| {
- vec![google_ai::Tool {
- function_declarations: request
- .tools
- .into_iter()
- .map(|tool| FunctionDeclaration {
- name: tool.name,
- description: tool.description,
- parameters: tool.input_schema,
- })
- .collect(),
- }]
- }),
- tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
- function_calling_config: google_ai::FunctionCallingConfig {
- mode: match choice {
- LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
- LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
- LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
- },
- allowed_function_names: None,
- },
- }),
- }
-}
-
-pub struct GoogleEventMapper {
- usage: UsageMetadata,
- stop_reason: StopReason,
-}
-
-impl GoogleEventMapper {
- pub fn new() -> Self {
- Self {
- usage: UsageMetadata::default(),
- stop_reason: StopReason::EndTurn,
- }
- }
-
- pub fn map_stream(
- mut self,
- events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
- ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
- {
- events
- .map(Some)
- .chain(futures::stream::once(async { None }))
- .flat_map(move |event| {
- futures::stream::iter(match event {
- Some(Ok(event)) => self.map_event(event),
- Some(Err(error)) => {
- vec![Err(LanguageModelCompletionError::from(error))]
- }
- None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
- })
- })
- }
-
- pub fn map_event(
- &mut self,
- event: GenerateContentResponse,
- ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
-
- let mut events: Vec<_> = Vec::new();
- let mut wants_to_use_tool = false;
- if let Some(usage_metadata) = event.usage_metadata {
- update_usage(&mut self.usage, &usage_metadata);
- events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
- convert_usage(&self.usage),
- )))
- }
-
- if let Some(prompt_feedback) = event.prompt_feedback
- && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
- {
- self.stop_reason = match block_reason {
- "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
- StopReason::Refusal
- }
- _ => {
- log::error!("Unexpected Google block_reason: {block_reason}");
- StopReason::Refusal
- }
- };
- events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
-
- return events;
- }
-
- if let Some(candidates) = event.candidates {
- for candidate in candidates {
- if let Some(finish_reason) = candidate.finish_reason.as_deref() {
- self.stop_reason = match finish_reason {
- "STOP" => StopReason::EndTurn,
- "MAX_TOKENS" => StopReason::MaxTokens,
- _ => {
- log::error!("Unexpected google finish_reason: {finish_reason}");
- StopReason::EndTurn
- }
- };
- }
- candidate
- .content
- .parts
- .into_iter()
- .for_each(|part| match part {
- Part::TextPart(text_part) => {
- events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
- }
- Part::InlineDataPart(_) => {}
- Part::FunctionCallPart(function_call_part) => {
- wants_to_use_tool = true;
- let name: Arc<str> = function_call_part.function_call.name.into();
- let next_tool_id =
- TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
- let id: LanguageModelToolUseId =
- format!("{}-{}", name, next_tool_id).into();
-
- // Normalize empty string signatures to None
- let thought_signature = function_call_part
- .thought_signature
- .filter(|s| !s.is_empty());
-
- events.push(Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id,
- name,
- is_input_complete: true,
- raw_input: function_call_part.function_call.args.to_string(),
- input: function_call_part.function_call.args,
- thought_signature,
- },
- )));
- }
- Part::FunctionResponsePart(_) => {}
- Part::ThoughtPart(part) => {
- events.push(Ok(LanguageModelCompletionEvent::Thinking {
- text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
- signature: Some(part.thought_signature),
- }));
- }
- });
- }
- }
-
- // Even when Gemini wants to use a Tool, the API
- // responds with `finish_reason: STOP`
- if wants_to_use_tool {
- self.stop_reason = StopReason::ToolUse;
- events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
- }
- events
- }
-}
-
-pub fn count_google_tokens(
- request: LanguageModelRequest,
- cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
- // We can't count tokens via the GoogleLanguageModelProvider because GitHub
- // Copilot has no direct access to google_ai, so we fall back to the
- // tiktoken_rs tokenizer instead.
- cx.background_spawn(async move {
- let messages = request
- .messages
- .into_iter()
- .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
- role: match message.role {
- Role::User => "user".into(),
- Role::Assistant => "assistant".into(),
- Role::System => "system".into(),
- },
- content: Some(message.string_contents()),
- name: None,
- function_call: None,
- })
- .collect::<Vec<_>>();
-
- // Tiktoken doesn't yet support these models, so we manually use the
- // same tokenizer as GPT-4.
- tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
- })
- .boxed()
-}
-
-fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
- if let Some(prompt_token_count) = new.prompt_token_count {
- usage.prompt_token_count = Some(prompt_token_count);
- }
- if let Some(cached_content_token_count) = new.cached_content_token_count {
- usage.cached_content_token_count = Some(cached_content_token_count);
- }
- if let Some(candidates_token_count) = new.candidates_token_count {
- usage.candidates_token_count = Some(candidates_token_count);
- }
- if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
- usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
- }
- if let Some(thoughts_token_count) = new.thoughts_token_count {
- usage.thoughts_token_count = Some(thoughts_token_count);
- }
- if let Some(total_token_count) = new.total_token_count {
- usage.total_token_count = Some(total_token_count);
- }
-}
-
-fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
- let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
- let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
- let input_tokens = prompt_tokens - cached_tokens;
- let output_tokens = usage.candidates_token_count.unwrap_or(0);
-
- language_model::TokenUsage {
- input_tokens,
- output_tokens,
- cache_read_input_tokens: cached_tokens,
- cache_creation_input_tokens: 0,
- }
-}
-
struct ConfigurationView {
api_key_editor: Entity<InputField>,
state: Entity<State>,
@@ -895,428 +525,3 @@ impl Render for ConfigurationView {
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use google_ai::{
- Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
- Part, Role as GoogleRole, TextPart,
- };
- use language_model::{LanguageModelToolUseId, MessageContent, Role};
- use serde_json::json;
-
- #[test]
- fn test_function_call_with_signature_creates_tool_use_with_signature() {
- let mut mapper = GoogleEventMapper::new();
-
- let response = GenerateContentResponse {
- candidates: Some(vec![GenerateContentCandidate {
- index: Some(0),
- content: Content {
- parts: vec![Part::FunctionCallPart(FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: Some("test_signature_123".to_string()),
- })],
- role: GoogleRole::Model,
- },
- finish_reason: None,
- finish_message: None,
- safety_ratings: None,
- citation_metadata: None,
- }]),
- prompt_feedback: None,
- usage_metadata: None,
- };
-
- let events = mapper.map_event(response);
-
- assert_eq!(events.len(), 2); // ToolUse event + Stop event
-
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
- assert_eq!(tool_use.name.as_ref(), "test_function");
- assert_eq!(
- tool_use.thought_signature.as_deref(),
- Some("test_signature_123")
- );
- } else {
- panic!("Expected ToolUse event");
- }
- }
-
- #[test]
- fn test_function_call_without_signature_has_none() {
- let mut mapper = GoogleEventMapper::new();
-
- let response = GenerateContentResponse {
- candidates: Some(vec![GenerateContentCandidate {
- index: Some(0),
- content: Content {
- parts: vec![Part::FunctionCallPart(FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: None,
- })],
- role: GoogleRole::Model,
- },
- finish_reason: None,
- finish_message: None,
- safety_ratings: None,
- citation_metadata: None,
- }]),
- prompt_feedback: None,
- usage_metadata: None,
- };
-
- let events = mapper.map_event(response);
-
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
- assert_eq!(tool_use.thought_signature, None);
- } else {
- panic!("Expected ToolUse event");
- }
- }
-
- #[test]
- fn test_empty_string_signature_normalized_to_none() {
- let mut mapper = GoogleEventMapper::new();
-
- let response = GenerateContentResponse {
- candidates: Some(vec![GenerateContentCandidate {
- index: Some(0),
- content: Content {
- parts: vec![Part::FunctionCallPart(FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: Some("".to_string()),
- })],
- role: GoogleRole::Model,
- },
- finish_reason: None,
- finish_message: None,
- safety_ratings: None,
- citation_metadata: None,
- }]),
- prompt_feedback: None,
- usage_metadata: None,
- };
-
- let events = mapper.map_event(response);
-
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
- assert_eq!(tool_use.thought_signature, None);
- } else {
- panic!("Expected ToolUse event");
- }
- }
-
- #[test]
- fn test_parallel_function_calls_preserve_signatures() {
- let mut mapper = GoogleEventMapper::new();
-
- let response = GenerateContentResponse {
- candidates: Some(vec![GenerateContentCandidate {
- index: Some(0),
- content: Content {
- parts: vec![
- Part::FunctionCallPart(FunctionCallPart {
- function_call: FunctionCall {
- name: "function_1".to_string(),
- args: json!({"arg": "value1"}),
- },
- thought_signature: Some("signature_1".to_string()),
- }),
- Part::FunctionCallPart(FunctionCallPart {
- function_call: FunctionCall {
- name: "function_2".to_string(),
- args: json!({"arg": "value2"}),
- },
- thought_signature: None,
- }),
- ],
- role: GoogleRole::Model,
- },
- finish_reason: None,
- finish_message: None,
- safety_ratings: None,
- citation_metadata: None,
- }]),
- prompt_feedback: None,
- usage_metadata: None,
- };
-
- let events = mapper.map_event(response);
-
- assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
-
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
- assert_eq!(tool_use.name.as_ref(), "function_1");
- assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
- } else {
- panic!("Expected ToolUse event for function_1");
- }
-
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
- assert_eq!(tool_use.name.as_ref(), "function_2");
- assert_eq!(tool_use.thought_signature, None);
- } else {
- panic!("Expected ToolUse event for function_2");
- }
- }
-
- #[test]
- fn test_tool_use_with_signature_converts_to_function_call_part() {
- let tool_use = language_model::LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test_id"),
- name: "test_function".into(),
- raw_input: json!({"arg": "value"}).to_string(),
- input: json!({"arg": "value"}),
- is_input_complete: true,
- thought_signature: Some("test_signature_456".to_string()),
- };
-
- let request = super::into_google(
- LanguageModelRequest {
- messages: vec![language_model::LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![MessageContent::ToolUse(tool_use)],
- cache: false,
- reasoning_details: None,
- }],
- ..Default::default()
- },
- "gemini-2.5-flash".to_string(),
- GoogleModelMode::Default,
- );
-
- assert_eq!(request.contents[0].parts.len(), 1);
- if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
- assert_eq!(fc_part.function_call.name, "test_function");
- assert_eq!(
- fc_part.thought_signature.as_deref(),
- Some("test_signature_456")
- );
- } else {
- panic!("Expected FunctionCallPart");
- }
- }
-
- #[test]
- fn test_tool_use_without_signature_omits_field() {
- let tool_use = language_model::LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test_id"),
- name: "test_function".into(),
- raw_input: json!({"arg": "value"}).to_string(),
- input: json!({"arg": "value"}),
- is_input_complete: true,
- thought_signature: None,
- };
-
- let request = super::into_google(
- LanguageModelRequest {
- messages: vec![language_model::LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![MessageContent::ToolUse(tool_use)],
- cache: false,
- reasoning_details: None,
- }],
- ..Default::default()
- },
- "gemini-2.5-flash".to_string(),
- GoogleModelMode::Default,
- );
-
- assert_eq!(request.contents[0].parts.len(), 1);
- if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
- assert_eq!(fc_part.thought_signature, None);
- } else {
- panic!("Expected FunctionCallPart");
- }
- }
-
- #[test]
- fn test_empty_signature_in_tool_use_normalized_to_none() {
- let tool_use = language_model::LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test_id"),
- name: "test_function".into(),
- raw_input: json!({"arg": "value"}).to_string(),
- input: json!({"arg": "value"}),
- is_input_complete: true,
- thought_signature: Some("".to_string()),
- };
-
- let request = super::into_google(
- LanguageModelRequest {
- messages: vec![language_model::LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![MessageContent::ToolUse(tool_use)],
- cache: false,
- reasoning_details: None,
- }],
- ..Default::default()
- },
- "gemini-2.5-flash".to_string(),
- GoogleModelMode::Default,
- );
-
- if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
- assert_eq!(fc_part.thought_signature, None);
- } else {
- panic!("Expected FunctionCallPart");
- }
- }
-
- #[test]
- fn test_round_trip_preserves_signature() {
- let mut mapper = GoogleEventMapper::new();
-
- // Simulate receiving a response from Google with a signature
- let response = GenerateContentResponse {
- candidates: Some(vec![GenerateContentCandidate {
- index: Some(0),
- content: Content {
- parts: vec![Part::FunctionCallPart(FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: Some("round_trip_sig".to_string()),
- })],
- role: GoogleRole::Model,
- },
- finish_reason: None,
- finish_message: None,
- safety_ratings: None,
- citation_metadata: None,
- }]),
- prompt_feedback: None,
- usage_metadata: None,
- };
-
- let events = mapper.map_event(response);
-
- let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
- tool_use.clone()
- } else {
- panic!("Expected ToolUse event");
- };
-
- // Convert back to Google format
- let request = super::into_google(
- LanguageModelRequest {
- messages: vec![language_model::LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![MessageContent::ToolUse(tool_use)],
- cache: false,
- reasoning_details: None,
- }],
- ..Default::default()
- },
- "gemini-2.5-flash".to_string(),
- GoogleModelMode::Default,
- );
-
- // Verify signature is preserved
- if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
- assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
- } else {
- panic!("Expected FunctionCallPart");
- }
- }
-
- #[test]
- fn test_mixed_text_and_function_call_with_signature() {
- let mut mapper = GoogleEventMapper::new();
-
- let response = GenerateContentResponse {
- candidates: Some(vec![GenerateContentCandidate {
- index: Some(0),
- content: Content {
- parts: vec![
- Part::TextPart(TextPart {
- text: "I'll help with that.".to_string(),
- }),
- Part::FunctionCallPart(FunctionCallPart {
- function_call: FunctionCall {
- name: "helper_function".to_string(),
- args: json!({"query": "help"}),
- },
- thought_signature: Some("mixed_sig".to_string()),
- }),
- ],
- role: GoogleRole::Model,
- },
- finish_reason: None,
- finish_message: None,
- safety_ratings: None,
- citation_metadata: None,
- }]),
- prompt_feedback: None,
- usage_metadata: None,
- };
-
- let events = mapper.map_event(response);
-
- assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
-
- if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
- assert_eq!(text, "I'll help with that.");
- } else {
- panic!("Expected Text event");
- }
-
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
- assert_eq!(tool_use.name.as_ref(), "helper_function");
- assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
- } else {
- panic!("Expected ToolUse event");
- }
- }
-
- #[test]
- fn test_special_characters_in_signature_preserved() {
- let mut mapper = GoogleEventMapper::new();
-
- let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
-
- let response = GenerateContentResponse {
- candidates: Some(vec![GenerateContentCandidate {
- index: Some(0),
- content: Content {
- parts: vec![Part::FunctionCallPart(FunctionCallPart {
- function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
- },
- thought_signature: Some(signature_with_special_chars.clone()),
- })],
- role: GoogleRole::Model,
- },
- finish_reason: None,
- finish_message: None,
- safety_ratings: None,
- citation_metadata: None,
- }]),
- prompt_feedback: None,
- usage_metadata: None,
- };
-
- let events = mapper.map_event(response);
-
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
- assert_eq!(
- tool_use.thought_signature.as_deref(),
- Some(signature_with_special_chars.as_str())
- );
- } else {
- panic!("Expected ToolUse event");
- }
- }
-}
@@ -28,7 +28,7 @@ use ui::{
use ui_input::InputField;
use crate::AllLanguageModelSettings;
-use crate::provider::util::parse_tool_arguments;
+use language_model::util::parse_tool_arguments;
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
@@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
+use language_model::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
@@ -1,41 +1,33 @@
-use anyhow::{Result, anyhow};
-use collections::{BTreeMap, HashMap};
+use anyhow::Result;
+use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
-use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
- LanguageModelCompletionEvent, LanguageModelId, LanguageModelImage, LanguageModelName,
- LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse,
- LanguageModelToolUseId, MessageContent, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME,
- RateLimiter, Role, StopReason, TokenUsage, env_var,
+ LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelRequest, LanguageModelToolChoice, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME,
+ RateLimiter, env_var,
};
use menu;
-use open_ai::responses::{
- ResponseFunctionCallItem, ResponseFunctionCallOutputContent, ResponseFunctionCallOutputItem,
- ResponseInputContent, ResponseInputItem, ResponseMessageItem,
-};
use open_ai::{
- ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent,
- responses::{
- Request as ResponseRequest, ResponseOutputItem, ResponseSummary as ResponsesSummary,
- ResponseUsage as ResponsesUsage, StreamEvent as ResponsesStreamEvent, stream_response,
- },
+ OPEN_AI_API_URL, ResponseStreamEvent,
+ responses::{Request as ResponseRequest, StreamEvent as ResponsesStreamEvent, stream_response},
stream_completion,
};
use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore};
-use std::pin::Pin;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
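+// The shared OpenAI request-building and stream-mapping helpers now live in the
+// `open_ai` crate's `completion` module; re-export them here for other callers.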
+pub use open_ai::completion::{
+ OpenAiEventMapper, OpenAiResponseEventMapper, collect_tiktoken_messages, count_open_ai_tokens,
+ into_open_ai, into_open_ai_response,
+};
const PROVIDER_ID: LanguageModelProviderId = OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = OPEN_AI_PROVIDER_NAME;
@@ -189,7 +181,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
- reasoning_effort: model.reasoning_effort.clone(),
+ reasoning_effort: model.reasoning_effort,
supports_chat_completions: model.capabilities.chat_completions,
},
);
@@ -382,7 +374,9 @@ impl LanguageModel for OpenAiLanguageModel {
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
- count_open_ai_tokens(request, self.model.clone(), cx)
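+        // count_open_ai_tokens is synchronous now; run the CPU-bound tiktoken
+        // counting on the background executor.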
+ let model = self.model.clone();
+ cx.background_spawn(async move { count_open_ai_tokens(request, model) })
+ .boxed()
}
fn stream_completion(
@@ -433,853 +427,6 @@ impl LanguageModel for OpenAiLanguageModel {
}
}
-pub fn into_open_ai(
- request: LanguageModelRequest,
- model_id: &str,
- supports_parallel_tool_calls: bool,
- supports_prompt_cache_key: bool,
- max_output_tokens: Option<u64>,
- reasoning_effort: Option<ReasoningEffort>,
-) -> open_ai::Request {
- let stream = !model_id.starts_with("o1-");
-
- let mut messages = Vec::new();
- for message in request.messages {
- for content in message.content {
- match content {
- MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
- let should_add = if message.role == Role::User {
-                        // Including whitespace-only user messages can cause errors with OpenAI-compatible APIs
- // See https://github.com/zed-industries/zed/issues/40097
- !text.trim().is_empty()
- } else {
- !text.is_empty()
- };
- if should_add {
- add_message_content_part(
- open_ai::MessagePart::Text { text },
- message.role,
- &mut messages,
- );
- }
- }
- MessageContent::RedactedThinking(_) => {}
- MessageContent::Image(image) => {
- add_message_content_part(
- open_ai::MessagePart::Image {
- image_url: ImageUrl {
- url: image.to_base64_url(),
- detail: None,
- },
- },
- message.role,
- &mut messages,
- );
- }
- MessageContent::ToolUse(tool_use) => {
- let tool_call = open_ai::ToolCall {
- id: tool_use.id.to_string(),
- content: open_ai::ToolCallContent::Function {
- function: open_ai::FunctionContent {
- name: tool_use.name.to_string(),
- arguments: serde_json::to_string(&tool_use.input)
- .unwrap_or_default(),
- },
- },
- };
-
- if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
- messages.last_mut()
- {
- tool_calls.push(tool_call);
- } else {
- messages.push(open_ai::RequestMessage::Assistant {
- content: None,
- tool_calls: vec![tool_call],
- });
- }
- }
- MessageContent::ToolResult(tool_result) => {
- let content = match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- vec![open_ai::MessagePart::Text {
- text: text.to_string(),
- }]
- }
- LanguageModelToolResultContent::Image(image) => {
- vec![open_ai::MessagePart::Image {
- image_url: ImageUrl {
- url: image.to_base64_url(),
- detail: None,
- },
- }]
- }
- };
-
- messages.push(open_ai::RequestMessage::Tool {
- content: content.into(),
- tool_call_id: tool_result.tool_use_id.to_string(),
- });
- }
- }
- }
- }
-
- open_ai::Request {
- model: model_id.into(),
- messages,
- stream,
- stream_options: if stream {
- Some(open_ai::StreamOptions::default())
- } else {
- None
- },
- stop: request.stop,
- temperature: request.temperature.or(Some(1.0)),
- max_completion_tokens: max_output_tokens,
- parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
- Some(supports_parallel_tool_calls)
- } else {
- None
- },
- prompt_cache_key: if supports_prompt_cache_key {
- request.thread_id
- } else {
- None
- },
- tools: request
- .tools
- .into_iter()
- .map(|tool| open_ai::ToolDefinition::Function {
- function: open_ai::FunctionDefinition {
- name: tool.name,
- description: Some(tool.description),
- parameters: Some(tool.input_schema),
- },
- })
- .collect(),
- tool_choice: request.tool_choice.map(|choice| match choice {
- LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
- LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
- LanguageModelToolChoice::None => open_ai::ToolChoice::None,
- }),
- reasoning_effort,
- }
-}
-
-pub fn into_open_ai_response(
- request: LanguageModelRequest,
- model_id: &str,
- supports_parallel_tool_calls: bool,
- supports_prompt_cache_key: bool,
- max_output_tokens: Option<u64>,
- reasoning_effort: Option<ReasoningEffort>,
-) -> ResponseRequest {
- let stream = !model_id.starts_with("o1-");
-
- let LanguageModelRequest {
- thread_id,
- prompt_id: _,
- intent: _,
- messages,
- tools,
- tool_choice,
- stop: _,
- temperature,
- thinking_allowed: _,
- thinking_effort: _,
- speed: _,
- } = request;
-
- let mut input_items = Vec::new();
- for (index, message) in messages.into_iter().enumerate() {
- append_message_to_response_items(message, index, &mut input_items);
- }
-
- let tools: Vec<_> = tools
- .into_iter()
- .map(|tool| open_ai::responses::ToolDefinition::Function {
- name: tool.name,
- description: Some(tool.description),
- parameters: Some(tool.input_schema),
- strict: None,
- })
- .collect();
-
- ResponseRequest {
- model: model_id.into(),
- input: input_items,
- stream,
- temperature,
- top_p: None,
- max_output_tokens,
- parallel_tool_calls: if tools.is_empty() {
- None
- } else {
- Some(supports_parallel_tool_calls)
- },
- tool_choice: tool_choice.map(|choice| match choice {
- LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
- LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
- LanguageModelToolChoice::None => open_ai::ToolChoice::None,
- }),
- tools,
- prompt_cache_key: if supports_prompt_cache_key {
- thread_id
- } else {
- None
- },
- reasoning: reasoning_effort.map(|effort| open_ai::responses::ReasoningConfig {
- effort,
- summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
- }),
- }
-}
-
-fn append_message_to_response_items(
- message: LanguageModelRequestMessage,
- index: usize,
- input_items: &mut Vec<ResponseInputItem>,
-) {
- let mut content_parts: Vec<ResponseInputContent> = Vec::new();
-
- for content in message.content {
- match content {
- MessageContent::Text(text) => {
- push_response_text_part(&message.role, text, &mut content_parts);
- }
- MessageContent::Thinking { text, .. } => {
- push_response_text_part(&message.role, text, &mut content_parts);
- }
- MessageContent::RedactedThinking(_) => {}
- MessageContent::Image(image) => {
- push_response_image_part(&message.role, image, &mut content_parts);
- }
- MessageContent::ToolUse(tool_use) => {
- flush_response_parts(&message.role, index, &mut content_parts, input_items);
- let call_id = tool_use.id.to_string();
- input_items.push(ResponseInputItem::FunctionCall(ResponseFunctionCallItem {
- call_id,
- name: tool_use.name.to_string(),
- arguments: tool_use.raw_input,
- }));
- }
- MessageContent::ToolResult(tool_result) => {
- flush_response_parts(&message.role, index, &mut content_parts, input_items);
- input_items.push(ResponseInputItem::FunctionCallOutput(
- ResponseFunctionCallOutputItem {
- call_id: tool_result.tool_use_id.to_string(),
- output: match tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- ResponseFunctionCallOutputContent::Text(text.to_string())
- }
- LanguageModelToolResultContent::Image(image) => {
- ResponseFunctionCallOutputContent::List(vec![
- ResponseInputContent::Image {
- image_url: image.to_base64_url(),
- },
- ])
- }
- },
- },
- ));
- }
- }
- }
-
- flush_response_parts(&message.role, index, &mut content_parts, input_items);
-}
-
-fn push_response_text_part(
- role: &Role,
- text: impl Into<String>,
- parts: &mut Vec<ResponseInputContent>,
-) {
- let text = text.into();
- if text.trim().is_empty() {
- return;
- }
-
- match role {
- Role::Assistant => parts.push(ResponseInputContent::OutputText {
- text,
- annotations: Vec::new(),
- }),
- _ => parts.push(ResponseInputContent::Text { text }),
- }
-}
-
-fn push_response_image_part(
- role: &Role,
- image: LanguageModelImage,
- parts: &mut Vec<ResponseInputContent>,
-) {
- match role {
- Role::Assistant => parts.push(ResponseInputContent::OutputText {
- text: "[image omitted]".to_string(),
- annotations: Vec::new(),
- }),
- _ => parts.push(ResponseInputContent::Image {
- image_url: image.to_base64_url(),
- }),
- }
-}
-
-fn flush_response_parts(
- role: &Role,
- _index: usize,
- parts: &mut Vec<ResponseInputContent>,
- input_items: &mut Vec<ResponseInputItem>,
-) {
- if parts.is_empty() {
- return;
- }
-
- let item = ResponseInputItem::Message(ResponseMessageItem {
- role: match role {
- Role::User => open_ai::Role::User,
- Role::Assistant => open_ai::Role::Assistant,
- Role::System => open_ai::Role::System,
- },
- content: parts.clone(),
- });
-
- input_items.push(item);
- parts.clear();
-}
-
-fn add_message_content_part(
- new_part: open_ai::MessagePart,
- role: Role,
- messages: &mut Vec<open_ai::RequestMessage>,
-) {
- match (role, messages.last_mut()) {
- (Role::User, Some(open_ai::RequestMessage::User { content }))
- | (
- Role::Assistant,
- Some(open_ai::RequestMessage::Assistant {
- content: Some(content),
- ..
- }),
- )
- | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
- content.push_part(new_part);
- }
- _ => {
- messages.push(match role {
- Role::User => open_ai::RequestMessage::User {
- content: open_ai::MessageContent::from(vec![new_part]),
- },
- Role::Assistant => open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::from(vec![new_part])),
- tool_calls: Vec::new(),
- },
- Role::System => open_ai::RequestMessage::System {
- content: open_ai::MessageContent::from(vec![new_part]),
- },
- });
- }
- }
-}
-
-pub struct OpenAiEventMapper {
- tool_calls_by_index: HashMap<usize, RawToolCall>,
-}
-
-impl OpenAiEventMapper {
- pub fn new() -> Self {
- Self {
- tool_calls_by_index: HashMap::default(),
- }
- }
-
- pub fn map_stream(
- mut self,
- events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
- ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
- {
- events.flat_map(move |event| {
- futures::stream::iter(match event {
- Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
- })
- })
- }
-
- pub fn map_event(
- &mut self,
- event: ResponseStreamEvent,
- ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- let mut events = Vec::new();
- if let Some(usage) = event.usage {
- events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
- input_tokens: usage.prompt_tokens,
- output_tokens: usage.completion_tokens,
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
- })));
- }
-
- let Some(choice) = event.choices.first() else {
- return events;
- };
-
- if let Some(delta) = choice.delta.as_ref() {
- if let Some(reasoning_content) = delta.reasoning_content.clone() {
- if !reasoning_content.is_empty() {
- events.push(Ok(LanguageModelCompletionEvent::Thinking {
- text: reasoning_content,
- signature: None,
- }));
- }
- }
- if let Some(content) = delta.content.clone() {
- if !content.is_empty() {
- events.push(Ok(LanguageModelCompletionEvent::Text(content)));
- }
- }
-
- if let Some(tool_calls) = delta.tool_calls.as_ref() {
- for tool_call in tool_calls {
- let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
-
- if let Some(tool_id) = tool_call.id.clone() {
- entry.id = tool_id;
- }
-
- if let Some(function) = tool_call.function.as_ref() {
- if let Some(name) = function.name.clone() {
- entry.name = name;
- }
-
- if let Some(arguments) = function.arguments.clone() {
- entry.arguments.push_str(&arguments);
- }
- }
-
- if !entry.id.is_empty() && !entry.name.is_empty() {
- if let Ok(input) = serde_json::from_str::<serde_json::Value>(
- &fix_streamed_json(&entry.arguments),
- ) {
- events.push(Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: entry.id.clone().into(),
- name: entry.name.as_str().into(),
- is_input_complete: false,
- input,
- raw_input: entry.arguments.clone(),
- thought_signature: None,
- },
- )));
- }
- }
- }
- }
- }
-
- match choice.finish_reason.as_deref() {
- Some("stop") => {
- events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
- }
- Some("tool_calls") => {
- events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
- match parse_tool_arguments(&tool_call.arguments) {
- Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_call.id.clone().into(),
- name: tool_call.name.as_str().into(),
- is_input_complete: true,
- input,
- raw_input: tool_call.arguments.clone(),
- thought_signature: None,
- },
- )),
- Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
- id: tool_call.id.into(),
- tool_name: tool_call.name.into(),
- raw_input: tool_call.arguments.clone().into(),
- json_parse_error: error.to_string(),
- }),
- }
- }));
-
- events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
- }
- Some(stop_reason) => {
- log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
- events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
- }
- None => {}
- }
-
- events
- }
-}
-
-#[derive(Default)]
-struct RawToolCall {
- id: String,
- name: String,
- arguments: String,
-}
-
-pub struct OpenAiResponseEventMapper {
- function_calls_by_item: HashMap<String, PendingResponseFunctionCall>,
- pending_stop_reason: Option<StopReason>,
-}
-
-#[derive(Default)]
-struct PendingResponseFunctionCall {
- call_id: String,
- name: Arc<str>,
- arguments: String,
-}
-
-impl OpenAiResponseEventMapper {
- pub fn new() -> Self {
- Self {
- function_calls_by_item: HashMap::default(),
- pending_stop_reason: None,
- }
- }
-
- pub fn map_stream(
- mut self,
- events: Pin<Box<dyn Send + Stream<Item = Result<ResponsesStreamEvent>>>>,
- ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
- {
- events.flat_map(move |event| {
- futures::stream::iter(match event {
- Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
- })
- })
- }
-
- pub fn map_event(
- &mut self,
- event: ResponsesStreamEvent,
- ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- match event {
- ResponsesStreamEvent::OutputItemAdded { item, .. } => {
- let mut events = Vec::new();
-
- match &item {
- ResponseOutputItem::Message(message) => {
- if let Some(id) = &message.id {
- events.push(Ok(LanguageModelCompletionEvent::StartMessage {
- message_id: id.clone(),
- }));
- }
- }
- ResponseOutputItem::FunctionCall(function_call) => {
- if let Some(item_id) = function_call.id.clone() {
- let call_id = function_call
- .call_id
- .clone()
- .or_else(|| function_call.id.clone())
- .unwrap_or_else(|| item_id.clone());
- let entry = PendingResponseFunctionCall {
- call_id,
- name: Arc::<str>::from(
- function_call.name.clone().unwrap_or_default(),
- ),
- arguments: function_call.arguments.clone(),
- };
- self.function_calls_by_item.insert(item_id, entry);
- }
- }
- ResponseOutputItem::Reasoning(_) | ResponseOutputItem::Unknown => {}
- }
- events
- }
- ResponsesStreamEvent::ReasoningSummaryTextDelta { delta, .. } => {
- if delta.is_empty() {
- Vec::new()
- } else {
- vec![Ok(LanguageModelCompletionEvent::Thinking {
- text: delta,
- signature: None,
- })]
- }
- }
- ResponsesStreamEvent::OutputTextDelta { delta, .. } => {
- if delta.is_empty() {
- Vec::new()
- } else {
- vec![Ok(LanguageModelCompletionEvent::Text(delta))]
- }
- }
- ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => {
- if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) {
- entry.arguments.push_str(&delta);
- if let Ok(input) = serde_json::from_str::<serde_json::Value>(
- &fix_streamed_json(&entry.arguments),
- ) {
- return vec![Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: LanguageModelToolUseId::from(entry.call_id.clone()),
- name: entry.name.clone(),
- is_input_complete: false,
- input,
- raw_input: entry.arguments.clone(),
- thought_signature: None,
- },
- ))];
- }
- }
- Vec::new()
- }
- ResponsesStreamEvent::FunctionCallArgumentsDone {
- item_id, arguments, ..
- } => {
- if let Some(mut entry) = self.function_calls_by_item.remove(&item_id) {
- if !arguments.is_empty() {
- entry.arguments = arguments;
- }
- let raw_input = entry.arguments.clone();
- self.pending_stop_reason = Some(StopReason::ToolUse);
- match parse_tool_arguments(&entry.arguments) {
- Ok(input) => {
- vec![Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: LanguageModelToolUseId::from(entry.call_id.clone()),
- name: entry.name.clone(),
- is_input_complete: true,
- input,
- raw_input,
- thought_signature: None,
- },
- ))]
- }
- Err(error) => {
- vec![Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
- id: LanguageModelToolUseId::from(entry.call_id.clone()),
- tool_name: entry.name.clone(),
- raw_input: Arc::<str>::from(raw_input),
- json_parse_error: error.to_string(),
- })]
- }
- }
- } else {
- Vec::new()
- }
- }
- ResponsesStreamEvent::Completed { response } => {
- self.handle_completion(response, StopReason::EndTurn)
- }
- ResponsesStreamEvent::Incomplete { response } => {
- let reason = response
- .status_details
- .as_ref()
- .and_then(|details| details.reason.as_deref());
- let stop_reason = match reason {
- Some("max_output_tokens") => StopReason::MaxTokens,
- Some("content_filter") => {
- self.pending_stop_reason = Some(StopReason::Refusal);
- StopReason::Refusal
- }
- _ => self
- .pending_stop_reason
- .take()
- .unwrap_or(StopReason::EndTurn),
- };
-
- let mut events = Vec::new();
- if self.pending_stop_reason.is_none() {
- events.extend(self.emit_tool_calls_from_output(&response.output));
- }
- if let Some(usage) = response.usage.as_ref() {
- events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
- token_usage_from_response_usage(usage),
- )));
- }
- events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason)));
- events
- }
- ResponsesStreamEvent::Failed { response } => {
- let message = response
- .status_details
- .and_then(|details| details.error)
- .map(|error| error.to_string())
- .unwrap_or_else(|| "response failed".to_string());
- vec![Err(LanguageModelCompletionError::Other(anyhow!(message)))]
- }
- ResponsesStreamEvent::Error { error }
- | ResponsesStreamEvent::GenericError { error } => {
- vec![Err(LanguageModelCompletionError::Other(anyhow!(
- error.message
- )))]
- }
- ResponsesStreamEvent::ReasoningSummaryPartAdded { summary_index, .. } => {
- if summary_index > 0 {
- vec![Ok(LanguageModelCompletionEvent::Thinking {
- text: "\n\n".to_string(),
- signature: None,
- })]
- } else {
- Vec::new()
- }
- }
- ResponsesStreamEvent::OutputTextDone { .. }
- | ResponsesStreamEvent::OutputItemDone { .. }
- | ResponsesStreamEvent::ContentPartAdded { .. }
- | ResponsesStreamEvent::ContentPartDone { .. }
- | ResponsesStreamEvent::ReasoningSummaryTextDone { .. }
- | ResponsesStreamEvent::ReasoningSummaryPartDone { .. }
- | ResponsesStreamEvent::Created { .. }
- | ResponsesStreamEvent::InProgress { .. }
- | ResponsesStreamEvent::Unknown => Vec::new(),
- }
- }
-
- fn handle_completion(
- &mut self,
- response: ResponsesSummary,
- default_reason: StopReason,
- ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- let mut events = Vec::new();
-
- if self.pending_stop_reason.is_none() {
- events.extend(self.emit_tool_calls_from_output(&response.output));
- }
-
- if let Some(usage) = response.usage.as_ref() {
- events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
- token_usage_from_response_usage(usage),
- )));
- }
-
- let stop_reason = self.pending_stop_reason.take().unwrap_or(default_reason);
- events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason)));
- events
- }
-
- fn emit_tool_calls_from_output(
- &mut self,
- output: &[ResponseOutputItem],
- ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- let mut events = Vec::new();
- for item in output {
- if let ResponseOutputItem::FunctionCall(function_call) = item {
- let Some(call_id) = function_call
- .call_id
- .clone()
- .or_else(|| function_call.id.clone())
- else {
- log::error!(
- "Function call item missing both call_id and id: {:?}",
- function_call
- );
- continue;
- };
- let name: Arc<str> = Arc::from(function_call.name.clone().unwrap_or_default());
- let arguments = &function_call.arguments;
- self.pending_stop_reason = Some(StopReason::ToolUse);
- match parse_tool_arguments(arguments) {
- Ok(input) => {
- events.push(Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: LanguageModelToolUseId::from(call_id.clone()),
- name: name.clone(),
- is_input_complete: true,
- input,
- raw_input: arguments.clone(),
- thought_signature: None,
- },
- )));
- }
- Err(error) => {
- events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
- id: LanguageModelToolUseId::from(call_id.clone()),
- tool_name: name.clone(),
- raw_input: Arc::<str>::from(arguments.clone()),
- json_parse_error: error.to_string(),
- }));
- }
- }
- }
- }
- events
- }
-}
-
-fn token_usage_from_response_usage(usage: &ResponsesUsage) -> TokenUsage {
- TokenUsage {
- input_tokens: usage.input_tokens.unwrap_or_default(),
- output_tokens: usage.output_tokens.unwrap_or_default(),
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
- }
-}
-
-pub(crate) fn collect_tiktoken_messages(
- request: LanguageModelRequest,
-) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
- request
- .messages
- .into_iter()
- .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
- role: match message.role {
- Role::User => "user".into(),
- Role::Assistant => "assistant".into(),
- Role::System => "system".into(),
- },
- content: Some(message.string_contents()),
- name: None,
- function_call: None,
- })
- .collect::<Vec<_>>()
-}
-
-pub fn count_open_ai_tokens(
- request: LanguageModelRequest,
- model: Model,
- cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
- cx.background_spawn(async move {
- let messages = collect_tiktoken_messages(request);
- match model {
- Model::Custom { max_tokens, .. } => {
- let model = if max_tokens >= 100_000 {
- // If the max tokens is 100k or more, it likely uses the o200k_base tokenizer
- "gpt-4o"
- } else {
- // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
- // supported with this tiktoken method
- "gpt-4"
- };
- tiktoken_rs::num_tokens_from_messages(model, &messages)
- }
- // Currently supported by tiktoken_rs
-            // Sometimes tiktoken-rs is behind on model support. If that is the case, add a new match
-            // arm with an override. We enumerate all supported models here so that we can check if new
- // models are supported yet or not.
- Model::ThreePointFiveTurbo
- | Model::Four
- | Model::FourTurbo
- | Model::FourOmniMini
- | Model::FourPointOneNano
- | Model::O1
- | Model::O3
- | Model::O3Mini
- | Model::Five
- | Model::FiveCodex
- | Model::FiveMini
- | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
- // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated tiktoken support; use gpt-5 tokenizer
- Model::FivePointOne
- | Model::FivePointTwo
- | Model::FivePointTwoCodex
- | Model::FivePointThreeCodex
- | Model::FivePointFour
- | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),
- }
- .map(|tokens| tokens as u64)
- })
- .boxed()
-}
-
struct ConfigurationView {
api_key_editor: Entity<InputField>,
state: Entity<State>,
@@ -1459,874 +606,3 @@ impl Render for ConfigurationView {
}
}
}
-
-#[cfg(test)]
-mod tests {
- use futures::{StreamExt, executor::block_on};
- use gpui::TestAppContext;
- use language_model::{
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
- };
- use open_ai::responses::{
- ReasoningSummaryPart, ResponseFunctionToolCall, ResponseOutputItem, ResponseOutputMessage,
- ResponseReasoningItem, ResponseStatusDetails, ResponseSummary, ResponseUsage,
- StreamEvent as ResponsesStreamEvent,
- };
- use pretty_assertions::assert_eq;
- use serde_json::json;
-
- use super::*;
-
- fn map_response_events(events: Vec<ResponsesStreamEvent>) -> Vec<LanguageModelCompletionEvent> {
- block_on(async {
- OpenAiResponseEventMapper::new()
- .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
- .collect::<Vec<_>>()
- .await
- .into_iter()
- .map(Result::unwrap)
- .collect()
- })
- }
-
- fn response_item_message(id: &str) -> ResponseOutputItem {
- ResponseOutputItem::Message(ResponseOutputMessage {
- id: Some(id.to_string()),
- role: Some("assistant".to_string()),
- status: Some("in_progress".to_string()),
- content: vec![],
- })
- }
-
- fn response_item_function_call(id: &str, args: Option<&str>) -> ResponseOutputItem {
- ResponseOutputItem::FunctionCall(ResponseFunctionToolCall {
- id: Some(id.to_string()),
- status: Some("in_progress".to_string()),
- name: Some("get_weather".to_string()),
- call_id: Some("call_123".to_string()),
- arguments: args.map(|s| s.to_string()).unwrap_or_default(),
- })
- }
-
- #[gpui::test]
- fn tiktoken_rs_support(cx: &TestAppContext) {
- let request = LanguageModelRequest {
- thread_id: None,
- prompt_id: None,
- intent: None,
- messages: vec![LanguageModelRequestMessage {
- role: Role::User,
- content: vec![MessageContent::Text("message".into())],
- cache: false,
- reasoning_details: None,
- }],
- tools: vec![],
- tool_choice: None,
- stop: vec![],
- temperature: None,
- thinking_allowed: true,
- thinking_effort: None,
- speed: None,
- };
-
- // Validate that all models are supported by tiktoken-rs
- for model in Model::iter() {
- let count = cx
- .foreground_executor()
- .block_on(count_open_ai_tokens(
- request.clone(),
- model,
- &cx.app.borrow(),
- ))
- .unwrap();
- assert!(count > 0);
- }
- }
-
- #[test]
- fn responses_stream_maps_text_and_usage() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: response_item_message("msg_123"),
- },
- ResponsesStreamEvent::OutputTextDelta {
- item_id: "msg_123".into(),
- output_index: 0,
- content_index: Some(0),
- delta: "Hello".into(),
- },
- ResponsesStreamEvent::Completed {
- response: ResponseSummary {
- usage: Some(ResponseUsage {
- input_tokens: Some(5),
- output_tokens: Some(3),
- total_tokens: Some(8),
- }),
- ..Default::default()
- },
- },
- ];
-
- let mapped = map_response_events(events);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_123"
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::Text(ref text) if text == "Hello"
- ));
- assert!(matches!(
- mapped[2],
- LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
- input_tokens: 5,
- output_tokens: 3,
- ..
- })
- ));
- assert!(matches!(
- mapped[3],
- LanguageModelCompletionEvent::Stop(StopReason::EndTurn)
- ));
- }
-
- #[test]
- fn into_open_ai_response_builds_complete_payload() {
- let tool_call_id = LanguageModelToolUseId::from("call-42");
- let tool_input = json!({ "city": "Boston" });
- let tool_arguments = serde_json::to_string(&tool_input).unwrap();
- let tool_use = LanguageModelToolUse {
- id: tool_call_id.clone(),
- name: Arc::from("get_weather"),
- raw_input: tool_arguments.clone(),
- input: tool_input,
- is_input_complete: true,
- thought_signature: None,
- };
- let tool_result = LanguageModelToolResult {
- tool_use_id: tool_call_id,
- tool_name: Arc::from("get_weather"),
- is_error: false,
- content: LanguageModelToolResultContent::Text(Arc::from("Sunny")),
- output: Some(json!({ "forecast": "Sunny" })),
- };
- let user_image = LanguageModelImage {
- source: SharedString::from("aGVsbG8="),
- size: None,
- };
- let expected_image_url = user_image.to_base64_url();
-
- let request = LanguageModelRequest {
- thread_id: Some("thread-123".into()),
- prompt_id: None,
- intent: None,
- messages: vec![
- LanguageModelRequestMessage {
- role: Role::System,
- content: vec![MessageContent::Text("System context".into())],
- cache: false,
- reasoning_details: None,
- },
- LanguageModelRequestMessage {
- role: Role::User,
- content: vec![
- MessageContent::Text("Please check the weather.".into()),
- MessageContent::Image(user_image),
- ],
- cache: false,
- reasoning_details: None,
- },
- LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![
- MessageContent::Text("Looking that up.".into()),
- MessageContent::ToolUse(tool_use),
- ],
- cache: false,
- reasoning_details: None,
- },
- LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![MessageContent::ToolResult(tool_result)],
- cache: false,
- reasoning_details: None,
- },
- ],
- tools: vec![LanguageModelRequestTool {
- name: "get_weather".into(),
- description: "Fetches the weather".into(),
- input_schema: json!({ "type": "object" }),
- use_input_streaming: false,
- }],
- tool_choice: Some(LanguageModelToolChoice::Any),
- stop: vec!["<STOP>".into()],
- temperature: None,
- thinking_allowed: false,
- thinking_effort: None,
- speed: None,
- };
-
- let response = into_open_ai_response(
- request,
- "custom-model",
- true,
- true,
- Some(2048),
- Some(ReasoningEffort::Low),
- );
-
- let serialized = serde_json::to_value(&response).unwrap();
- let expected = json!({
- "model": "custom-model",
- "input": [
- {
- "type": "message",
- "role": "system",
- "content": [
- { "type": "input_text", "text": "System context" }
- ]
- },
- {
- "type": "message",
- "role": "user",
- "content": [
- { "type": "input_text", "text": "Please check the weather." },
- { "type": "input_image", "image_url": expected_image_url }
- ]
- },
- {
- "type": "message",
- "role": "assistant",
- "content": [
- { "type": "output_text", "text": "Looking that up.", "annotations": [] }
- ]
- },
- {
- "type": "function_call",
- "call_id": "call-42",
- "name": "get_weather",
- "arguments": tool_arguments
- },
- {
- "type": "function_call_output",
- "call_id": "call-42",
- "output": "Sunny"
- }
- ],
- "stream": true,
- "max_output_tokens": 2048,
- "parallel_tool_calls": true,
- "tool_choice": "required",
- "tools": [
- {
- "type": "function",
- "name": "get_weather",
- "description": "Fetches the weather",
- "parameters": { "type": "object" }
- }
- ],
- "prompt_cache_key": "thread-123",
- "reasoning": { "effort": "low", "summary": "auto" }
- });
-
- assert_eq!(serialized, expected);
- }
-
- #[test]
- fn responses_stream_maps_tool_calls() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: response_item_function_call("item_fn", Some("{\"city\":\"Bos")),
- },
- ResponsesStreamEvent::FunctionCallArgumentsDelta {
- item_id: "item_fn".into(),
- output_index: 0,
- delta: "ton\"}".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::FunctionCallArgumentsDone {
- item_id: "item_fn".into(),
- output_index: 0,
- arguments: "{\"city\":\"Boston\"}".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::Completed {
- response: ResponseSummary::default(),
- },
- ];
-
- let mapped = map_response_events(events);
- assert_eq!(mapped.len(), 3);
- // First event is the partial tool use (from FunctionCallArgumentsDelta)
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
- is_input_complete: false,
- ..
- })
- ));
- // Second event is the complete tool use (from FunctionCallArgumentsDone)
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
- ref id,
- ref name,
- ref raw_input,
- is_input_complete: true,
- ..
- }) if id.to_string() == "call_123"
- && name.as_ref() == "get_weather"
- && raw_input == "{\"city\":\"Boston\"}"
- ));
- assert!(matches!(
- mapped[2],
- LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
- ));
- }
-
- #[test]
- fn responses_stream_uses_max_tokens_stop_reason() {
- let events = vec![ResponsesStreamEvent::Incomplete {
- response: ResponseSummary {
- status_details: Some(ResponseStatusDetails {
- reason: Some("max_output_tokens".into()),
- r#type: Some("incomplete".into()),
- error: None,
- }),
- usage: Some(ResponseUsage {
- input_tokens: Some(10),
- output_tokens: Some(20),
- total_tokens: Some(30),
- }),
- ..Default::default()
- },
- }];
-
- let mapped = map_response_events(events);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
- input_tokens: 10,
- output_tokens: 20,
- ..
- })
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
- ));
- }
-
- #[test]
- fn responses_stream_handles_multiple_tool_calls() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: response_item_function_call("item_fn1", Some("{\"city\":\"NYC\"}")),
- },
- ResponsesStreamEvent::FunctionCallArgumentsDone {
- item_id: "item_fn1".into(),
- output_index: 0,
- arguments: "{\"city\":\"NYC\"}".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 1,
- sequence_number: None,
- item: response_item_function_call("item_fn2", Some("{\"city\":\"LA\"}")),
- },
- ResponsesStreamEvent::FunctionCallArgumentsDone {
- item_id: "item_fn2".into(),
- output_index: 1,
- arguments: "{\"city\":\"LA\"}".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::Completed {
- response: ResponseSummary::default(),
- },
- ];
-
- let mapped = map_response_events(events);
- assert_eq!(mapped.len(), 3);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. })
- if raw_input == "{\"city\":\"NYC\"}"
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. })
- if raw_input == "{\"city\":\"LA\"}"
- ));
- assert!(matches!(
- mapped[2],
- LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
- ));
- }
-
- #[test]
- fn responses_stream_handles_mixed_text_and_tool_calls() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: response_item_message("msg_123"),
- },
- ResponsesStreamEvent::OutputTextDelta {
- item_id: "msg_123".into(),
- output_index: 0,
- content_index: Some(0),
- delta: "Let me check that".into(),
- },
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 1,
- sequence_number: None,
- item: response_item_function_call("item_fn", Some("{\"query\":\"test\"}")),
- },
- ResponsesStreamEvent::FunctionCallArgumentsDone {
- item_id: "item_fn".into(),
- output_index: 1,
- arguments: "{\"query\":\"test\"}".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::Completed {
- response: ResponseSummary::default(),
- },
- ];
-
- let mapped = map_response_events(events);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::StartMessage { .. }
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::Text(ref text) if text == "Let me check that"
- ));
- assert!(matches!(
- mapped[2],
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. })
- if raw_input == "{\"query\":\"test\"}"
- ));
- assert!(matches!(
- mapped[3],
- LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
- ));
- }
-
- #[test]
- fn responses_stream_handles_json_parse_error() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: response_item_function_call("item_fn", Some("{invalid json")),
- },
- ResponsesStreamEvent::FunctionCallArgumentsDone {
- item_id: "item_fn".into(),
- output_index: 0,
- arguments: "{invalid json".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::Completed {
- response: ResponseSummary::default(),
- },
- ];
-
- let mapped = map_response_events(events);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::ToolUseJsonParseError {
- ref raw_input,
- ..
- } if raw_input.as_ref() == "{invalid json"
- ));
- }
-
- #[test]
- fn responses_stream_handles_incomplete_function_call() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: response_item_function_call("item_fn", Some("{\"city\":")),
- },
- ResponsesStreamEvent::FunctionCallArgumentsDelta {
- item_id: "item_fn".into(),
- output_index: 0,
- delta: "\"Boston\"".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::Incomplete {
- response: ResponseSummary {
- status_details: Some(ResponseStatusDetails {
- reason: Some("max_output_tokens".into()),
- r#type: Some("incomplete".into()),
- error: None,
- }),
- output: vec![response_item_function_call(
- "item_fn",
- Some("{\"city\":\"Boston\"}"),
- )],
- ..Default::default()
- },
- },
- ];
-
- let mapped = map_response_events(events);
- assert_eq!(mapped.len(), 3);
- // First event is the partial tool use (from FunctionCallArgumentsDelta)
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
- is_input_complete: false,
- ..
- })
- ));
- // Second event is the complete tool use (from the Incomplete response output)
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
- ref raw_input,
- is_input_complete: true,
- ..
- })
- if raw_input == "{\"city\":\"Boston\"}"
- ));
- assert!(matches!(
- mapped[2],
- LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
- ));
- }
-
- #[test]
- fn responses_stream_incomplete_does_not_duplicate_tool_calls() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: response_item_function_call("item_fn", Some("{\"city\":\"Boston\"}")),
- },
- ResponsesStreamEvent::FunctionCallArgumentsDone {
- item_id: "item_fn".into(),
- output_index: 0,
- arguments: "{\"city\":\"Boston\"}".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::Incomplete {
- response: ResponseSummary {
- status_details: Some(ResponseStatusDetails {
- reason: Some("max_output_tokens".into()),
- r#type: Some("incomplete".into()),
- error: None,
- }),
- output: vec![response_item_function_call(
- "item_fn",
- Some("{\"city\":\"Boston\"}"),
- )],
- ..Default::default()
- },
- },
- ];
-
- let mapped = map_response_events(events);
- assert_eq!(mapped.len(), 2);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. })
- if raw_input == "{\"city\":\"Boston\"}"
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
- ));
- }
-
- #[test]
- fn responses_stream_handles_empty_tool_arguments() {
- // Test that tools with no arguments (empty string) are handled correctly
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: response_item_function_call("item_fn", Some("")),
- },
- ResponsesStreamEvent::FunctionCallArgumentsDone {
- item_id: "item_fn".into(),
- output_index: 0,
- arguments: "".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::Completed {
- response: ResponseSummary::default(),
- },
- ];
-
- let mapped = map_response_events(events);
- assert_eq!(mapped.len(), 2);
-
- // Should produce a ToolUse event with an empty object
- assert!(matches!(
- &mapped[0],
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
- id,
- name,
- raw_input,
- input,
- ..
- }) if id.to_string() == "call_123"
- && name.as_ref() == "get_weather"
- && raw_input == ""
- && input.is_object()
- && input.as_object().unwrap().is_empty()
- ));
-
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
- ));
- }
-
- #[test]
- fn responses_stream_emits_partial_tool_use_events() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: ResponseOutputItem::FunctionCall(ResponseFunctionToolCall {
- id: Some("item_fn".to_string()),
- status: Some("in_progress".to_string()),
- name: Some("get_weather".to_string()),
- call_id: Some("call_abc".to_string()),
- arguments: String::new(),
- }),
- },
- ResponsesStreamEvent::FunctionCallArgumentsDelta {
- item_id: "item_fn".into(),
- output_index: 0,
- delta: "{\"city\":\"Bos".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::FunctionCallArgumentsDelta {
- item_id: "item_fn".into(),
- output_index: 0,
- delta: "ton\"}".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::FunctionCallArgumentsDone {
- item_id: "item_fn".into(),
- output_index: 0,
- arguments: "{\"city\":\"Boston\"}".into(),
- sequence_number: None,
- },
- ResponsesStreamEvent::Completed {
- response: ResponseSummary::default(),
- },
- ];
-
- let mapped = map_response_events(events);
- // Two partial events + one complete event + Stop
- assert!(mapped.len() >= 3);
-
- // The last complete ToolUse event should have is_input_complete: true
- let complete_tool_use = mapped.iter().find(|e| {
- matches!(
- e,
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
- is_input_complete: true,
- ..
- })
- )
- });
- assert!(
- complete_tool_use.is_some(),
- "should have a complete tool use event"
- );
-
- // All ToolUse events before the final one should have is_input_complete: false
- let tool_uses: Vec<_> = mapped
- .iter()
- .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_)))
- .collect();
- assert!(
- tool_uses.len() >= 2,
- "should have at least one partial and one complete event"
- );
-
- let last = tool_uses.last().unwrap();
- assert!(matches!(
- last,
- LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
- is_input_complete: true,
- ..
- })
- ));
- }
-
- #[test]
- fn responses_stream_maps_reasoning_summary_deltas() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
- id: Some("rs_123".into()),
- summary: vec![],
- }),
- },
- ResponsesStreamEvent::ReasoningSummaryPartAdded {
- item_id: "rs_123".into(),
- output_index: 0,
- summary_index: 0,
- },
- ResponsesStreamEvent::ReasoningSummaryTextDelta {
- item_id: "rs_123".into(),
- output_index: 0,
- delta: "Thinking about".into(),
- },
- ResponsesStreamEvent::ReasoningSummaryTextDelta {
- item_id: "rs_123".into(),
- output_index: 0,
- delta: " the answer".into(),
- },
- ResponsesStreamEvent::ReasoningSummaryTextDone {
- item_id: "rs_123".into(),
- output_index: 0,
- text: "Thinking about the answer".into(),
- },
- ResponsesStreamEvent::ReasoningSummaryPartDone {
- item_id: "rs_123".into(),
- output_index: 0,
- summary_index: 0,
- },
- ResponsesStreamEvent::ReasoningSummaryPartAdded {
- item_id: "rs_123".into(),
- output_index: 0,
- summary_index: 1,
- },
- ResponsesStreamEvent::ReasoningSummaryTextDelta {
- item_id: "rs_123".into(),
- output_index: 0,
- delta: "Second part".into(),
- },
- ResponsesStreamEvent::ReasoningSummaryTextDone {
- item_id: "rs_123".into(),
- output_index: 0,
- text: "Second part".into(),
- },
- ResponsesStreamEvent::ReasoningSummaryPartDone {
- item_id: "rs_123".into(),
- output_index: 0,
- summary_index: 1,
- },
- ResponsesStreamEvent::OutputItemDone {
- output_index: 0,
- sequence_number: None,
- item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
- id: Some("rs_123".into()),
- summary: vec![
- ReasoningSummaryPart::SummaryText {
- text: "Thinking about the answer".into(),
- },
- ReasoningSummaryPart::SummaryText {
- text: "Second part".into(),
- },
- ],
- }),
- },
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 1,
- sequence_number: None,
- item: response_item_message("msg_456"),
- },
- ResponsesStreamEvent::OutputTextDelta {
- item_id: "msg_456".into(),
- output_index: 1,
- content_index: Some(0),
- delta: "The answer is 42".into(),
- },
- ResponsesStreamEvent::Completed {
- response: ResponseSummary::default(),
- },
- ];
-
- let mapped = map_response_events(events);
-
- let thinking_events: Vec<_> = mapped
- .iter()
- .filter(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. }))
- .collect();
- assert_eq!(
- thinking_events.len(),
- 4,
- "expected 4 thinking events (2 deltas + separator + second delta), got {:?}",
- thinking_events,
- );
-
- assert!(matches!(
- &thinking_events[0],
- LanguageModelCompletionEvent::Thinking { text, .. } if text == "Thinking about"
- ));
- assert!(matches!(
- &thinking_events[1],
- LanguageModelCompletionEvent::Thinking { text, .. } if text == " the answer"
- ));
- assert!(
- matches!(
- &thinking_events[2],
- LanguageModelCompletionEvent::Thinking { text, .. } if text == "\n\n"
- ),
- "expected separator between summary parts"
- );
- assert!(matches!(
- &thinking_events[3],
- LanguageModelCompletionEvent::Thinking { text, .. } if text == "Second part"
- ));
-
- assert!(mapped.iter().any(|e| matches!(
- e,
- LanguageModelCompletionEvent::Text(t) if t == "The answer is 42"
- )));
- }
-
- #[test]
- fn responses_stream_maps_reasoning_from_done_only() {
- let events = vec![
- ResponsesStreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
- id: Some("rs_789".into()),
- summary: vec![],
- }),
- },
- ResponsesStreamEvent::OutputItemDone {
- output_index: 0,
- sequence_number: None,
- item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
- id: Some("rs_789".into()),
- summary: vec![ReasoningSummaryPart::SummaryText {
- text: "Summary without deltas".into(),
- }],
- }),
- },
- ResponsesStreamEvent::Completed {
- response: ResponseSummary::default(),
- },
- ];
-
- let mapped = map_response_events(events);
-
- assert!(
- !mapped
- .iter()
- .any(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })),
- "OutputItemDone reasoning should not produce Thinking events (no delta/done text events)"
- );
- }
-}
@@ -402,7 +402,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
self.model.capabilities.parallel_tool_calls,
self.model.capabilities.prompt_cache_key,
self.max_output_tokens(),
- self.model.reasoning_effort.clone(),
+ self.model.reasoning_effort,
);
let completions = self.stream_completion(request, cx);
async move {
@@ -417,7 +417,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
self.model.capabilities.parallel_tool_calls,
self.model.capabilities.prompt_cache_key,
self.max_output_tokens(),
- self.model.reasoning_effort.clone(),
+ self.model.reasoning_effort,
);
let completions = self.stream_response(request, cx);
async move {
@@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
+use language_model::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
@@ -9,7 +9,7 @@ use language_model::{
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
- Role, env_var,
+ env_var,
};
use open_ai::ResponseStreamEvent;
pub use settings::XaiAvailableModel as AvailableModel;
@@ -19,7 +19,8 @@ use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use x_ai::{Model, XAI_API_URL};
+use x_ai::XAI_API_URL;
+pub use x_ai::completion::count_xai_tokens;
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
@@ -320,7 +321,9 @@ impl LanguageModel for XAiLanguageModel {
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
- count_xai_tokens(request, self.model.clone(), cx)
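+        // count_xai_tokens is synchronous now; spawn the tiktoken counting on
+        // the background executor, mirroring the OpenAI provider.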
+ let model = self.model.clone();
+ cx.background_spawn(async move { count_xai_tokens(request, model) })
+ .boxed()
}
fn stream_completion(
@@ -354,37 +357,6 @@ impl LanguageModel for XAiLanguageModel {
}
}
-pub fn count_xai_tokens(
- request: LanguageModelRequest,
- model: Model,
- cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
- cx.background_spawn(async move {
- let messages = request
- .messages
- .into_iter()
- .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
- role: match message.role {
- Role::User => "user".into(),
- Role::Assistant => "assistant".into(),
- Role::System => "system".into(),
- },
- content: Some(message.string_contents()),
- name: None,
- function_call: None,
- })
- .collect::<Vec<_>>();
-
- let model_name = if model.max_token_count() >= 100_000 {
- "gpt-4o"
- } else {
- "gpt-4"
- };
- tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
- })
- .boxed()
-}
-
struct ConfigurationView {
api_key_editor: Entity<InputField>,
state: Entity<State>,
@@ -0,0 +1,33 @@
+[package]
+name = "language_models_cloud"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/language_models_cloud.rs"
+
+[dependencies]
+anthropic = { workspace = true, features = ["schemars"] }
+anyhow.workspace = true
+cloud_llm_client.workspace = true
+futures.workspace = true
+google_ai = { workspace = true, features = ["schemars"] }
+gpui.workspace = true
+http_client.workspace = true
+language_model.workspace = true
+open_ai = { workspace = true, features = ["schemars"] }
+schemars.workspace = true
+semver.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+smol.workspace = true
+thiserror.workspace = true
+x_ai = { workspace = true, features = ["schemars"] }
+
+[dev-dependencies]
+language_model = { workspace = true, features = ["test-support"] }
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,1059 @@
+use anthropic::AnthropicModelMode;
+use anyhow::{Context as _, Result, anyhow};
+use cloud_llm_client::{
+ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
+ CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
+ CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
+ OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
+ ZED_VERSION_HEADER_NAME,
+};
+use futures::{
+ AsyncBufReadExt, FutureExt, Stream, StreamExt,
+ future::BoxFuture,
+ stream::{self, BoxStream},
+};
+use google_ai::GoogleModelMode;
+use gpui::{App, AppContext, AsyncApp, Context, Task};
+use http_client::http::{HeaderMap, HeaderValue};
+use http_client::{
+ AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
+};
+use language_model::{
+ ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
+ LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
+ LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
+ LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
+ OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
+ ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
+};
+
+use schemars::JsonSchema;
+use semver::Version;
+use serde::{Deserialize, Serialize, de::DeserializeOwned};
+use smol::io::{AsyncReadExt, BufReader};
+use std::collections::VecDeque;
+use std::pin::Pin;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::task::Poll;
+use std::time::Duration;
+use thiserror::Error;
+
+use anthropic::completion::{
+ AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
+};
+use google_ai::completion::{GoogleEventMapper, into_google};
+use open_ai::completion::{
+ OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
+ into_open_ai_response,
+};
+use x_ai::completion::count_xai_tokens;
+
+const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
+
+/// Trait for acquiring and refreshing LLM authentication tokens.
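+///
+/// A minimal sketch of an implementor (the type and its token source are
+/// illustrative, not part of this crate):
+///
+/// ```ignore
+/// use futures::FutureExt as _;
+///
+/// struct StaticTokenProvider(String);
+///
+/// impl CloudLlmTokenProvider for StaticTokenProvider {
+///     type AuthContext = ();
+///
+///     fn auth_context(&self, _cx: &AsyncApp) -> Self::AuthContext {}
+///
+///     fn acquire_token(&self, _auth: ()) -> BoxFuture<'static, Result<String>> {
+///         let token = self.0.clone();
+///         async move { Ok(token) }.boxed()
+///     }
+///
+///     fn refresh_token(&self, auth: ()) -> BoxFuture<'static, Result<String>> {
+///         self.acquire_token(auth)
+///     }
+/// }
+/// ```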
+pub trait CloudLlmTokenProvider: Send + Sync {
+ type AuthContext: Clone + Send + 'static;
+
+ fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext;
+ fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
+ fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
+}
+
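+/// Whether a completion runs in the provider's default mode or with extended
+/// "thinking" enabled (see the `AnthropicModelMode` conversion below).
+///
+/// Serialized internally tagged, e.g. `{"type": "thinking", "budget_tokens": 4096}`.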
+#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+ #[default]
+ Default,
+ Thinking {
+ /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
+ budget_tokens: Option<u32>,
+ },
+}
+
+impl From<ModelMode> for AnthropicModelMode {
+ fn from(value: ModelMode) -> Self {
+ match value {
+ ModelMode::Default => AnthropicModelMode::Default,
+ ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
+ }
+ }
+}
+
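+/// A language model served through the cloud LLM endpoint, generic over how auth tokens are acquired.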
+pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
+ pub id: LanguageModelId,
+ pub model: Arc<cloud_llm_client::LanguageModel>,
+ pub token_provider: Arc<TP>,
+ pub http_client: Arc<HttpClientWithUrl>,
+ pub app_version: Option<Version>,
+ pub request_limiter: RateLimiter,
+}
+
+pub struct PerformLlmCompletionResponse {
+ pub response: Response<AsyncBody>,
+ pub includes_status_messages: bool,
+}
+
+impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
+ pub async fn perform_llm_completion(
+ http_client: &HttpClientWithUrl,
+ token_provider: &TP,
+ auth_context: TP::AuthContext,
+ app_version: Option<Version>,
+ body: CompletionBody,
+ ) -> Result<PerformLlmCompletionResponse> {
+ let mut token = token_provider.acquire_token(auth_context.clone()).await?;
+ let mut refreshed_token = false;
+
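+        // Retry the request at most once after refreshing an expired or outdated LLM token.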
+ loop {
+ let request = http_client::Request::builder()
+ .method(Method::POST)
+ .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
+ .when_some(app_version.as_ref(), |builder, app_version| {
+ builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
+ })
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {token}"))
+ .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
+ .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
+ .body(serde_json::to_string(&body)?.into())?;
+
+ let mut response = http_client.send(request).await?;
+ let status = response.status();
+ if status.is_success() {
+ let includes_status_messages = response
+ .headers()
+ .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
+ .is_some();
+
+ return Ok(PerformLlmCompletionResponse {
+ response,
+ includes_status_messages,
+ });
+ }
+
+ if !refreshed_token && needs_llm_token_refresh(&response) {
+ token = token_provider.refresh_token(auth_context.clone()).await?;
+ refreshed_token = true;
+ continue;
+ }
+
+ if status == StatusCode::PAYMENT_REQUIRED {
+ return Err(anyhow!(PaymentRequiredError));
+ }
+
+ let mut body = String::new();
+ let headers = response.headers().clone();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(ApiError {
+ status,
+ body,
+ headers
+ }));
+ }
+ }
+}
+
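+/// Checks response headers for the server's expired- or outdated-token signals.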
+fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
+ response
+ .headers()
+ .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
+ || response
+ .headers()
+ .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
+}
+
+#[derive(Debug, Error)]
+#[error("cloud language model request failed with status {status}: {body}")]
+struct ApiError {
+ status: StatusCode,
+ body: String,
+ headers: HeaderMap<HeaderValue>,
+}
+
+/// Represents error responses from Zed's cloud API.
+///
+/// Example JSON for an upstream HTTP error:
+/// ```json
+/// {
+/// "code": "upstream_http_error",
+/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
+/// "upstream_status": 503
+/// }
+/// ```
+#[derive(Debug, serde::Deserialize)]
+struct CloudApiError {
+ code: String,
+ message: String,
+ #[serde(default)]
+ #[serde(deserialize_with = "deserialize_optional_status_code")]
+ upstream_status: Option<StatusCode>,
+ #[serde(default)]
+ retry_after: Option<f64>,
+}
+
+fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
+where
+ D: serde::Deserializer<'de>,
+{
+ let opt: Option<u16> = Option::deserialize(deserializer)?;
+ Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
+}
+
+impl From<ApiError> for LanguageModelCompletionError {
+ fn from(error: ApiError) -> Self {
+ if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
+ if cloud_error.code.starts_with("upstream_http_") {
+ let status = if let Some(status) = cloud_error.upstream_status {
+ status
+ } else if cloud_error.code.ends_with("_error") {
+ error.status
+ } else {
+                        // Extract the status code embedded in the code string
+                        // (e.g. "upstream_http_429"); fall back to the response status.
+ cloud_error
+ .code
+ .strip_prefix("upstream_http_")
+ .and_then(|code_str| code_str.parse::<u16>().ok())
+ .and_then(|code| StatusCode::from_u16(code).ok())
+ .unwrap_or(error.status)
+ };
+
+ return LanguageModelCompletionError::UpstreamProviderError {
+ message: cloud_error.message,
+ status,
+ retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
+ };
+ }
+
+ return LanguageModelCompletionError::from_http_status(
+ PROVIDER_NAME,
+ error.status,
+ cloud_error.message,
+ None,
+ );
+ }
+
+ let retry_after = None;
+ LanguageModelCompletionError::from_http_status(
+ PROVIDER_NAME,
+ error.status,
+ error.body,
+ retry_after,
+ )
+ }
+}
+
+impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name.clone())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn upstream_provider_id(&self) -> LanguageModelProviderId {
+ use cloud_llm_client::LanguageModelProvider::*;
+ match self.model.provider {
+ Anthropic => ANTHROPIC_PROVIDER_ID,
+ OpenAi => OPEN_AI_PROVIDER_ID,
+ Google => GOOGLE_PROVIDER_ID,
+ XAi => X_AI_PROVIDER_ID,
+ }
+ }
+
+ fn upstream_provider_name(&self) -> LanguageModelProviderName {
+ use cloud_llm_client::LanguageModelProvider::*;
+ match self.model.provider {
+ Anthropic => ANTHROPIC_PROVIDER_NAME,
+ OpenAi => OPEN_AI_PROVIDER_NAME,
+ Google => GOOGLE_PROVIDER_NAME,
+ XAi => X_AI_PROVIDER_NAME,
+ }
+ }
+
+ fn is_latest(&self) -> bool {
+ self.model.is_latest
+ }
+
+ fn supports_tools(&self) -> bool {
+ self.model.supports_tools
+ }
+
+ fn supports_images(&self) -> bool {
+ self.model.supports_images
+ }
+
+ fn supports_thinking(&self) -> bool {
+ self.model.supports_thinking
+ }
+
+ fn supports_fast_mode(&self) -> bool {
+ self.model.supports_fast_mode
+ }
+
+ fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
+ self.model
+ .supported_effort_levels
+ .iter()
+ .map(|effort_level| LanguageModelEffortLevel {
+ name: effort_level.name.clone().into(),
+ value: effort_level.value.clone().into(),
+ is_default: effort_level.is_default.unwrap_or(false),
+ })
+ .collect()
+ }
+
+ fn supports_streaming_tools(&self) -> bool {
+ self.model.supports_streaming_tools
+ }
+
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto
+ | LanguageModelToolChoice::Any
+ | LanguageModelToolChoice::None => true,
+ }
+ }
+
+ fn supports_split_token_display(&self) -> bool {
+ use cloud_llm_client::LanguageModelProvider::*;
+ matches!(self.model.provider, OpenAi | XAi)
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("zed.dev/{}", self.model.id)
+ }
+
+ fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
+ match self.model.provider {
+ cloud_llm_client::LanguageModelProvider::Anthropic
+ | cloud_llm_client::LanguageModelProvider::OpenAi => {
+ LanguageModelToolSchemaFormat::JsonSchema
+ }
+ cloud_llm_client::LanguageModelProvider::Google
+ | cloud_llm_client::LanguageModelProvider::XAi => {
+ LanguageModelToolSchemaFormat::JsonSchemaSubset
+ }
+ }
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_token_count as u64
+ }
+
+ fn max_output_tokens(&self) -> Option<u64> {
+ Some(self.model.max_output_tokens as u64)
+ }
+
+ fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+ match &self.model.provider {
+ cloud_llm_client::LanguageModelProvider::Anthropic => {
+ Some(LanguageModelCacheConfiguration {
+ min_total_token: 2_048,
+ should_speculate: true,
+ max_cache_anchors: 4,
+ })
+ }
+ cloud_llm_client::LanguageModelProvider::OpenAi
+ | cloud_llm_client::LanguageModelProvider::XAi
+ | cloud_llm_client::LanguageModelProvider::Google => None,
+ }
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result<u64>> {
+ match self.model.provider {
+ cloud_llm_client::LanguageModelProvider::Anthropic => cx
+ .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
+ .boxed(),
+ cloud_llm_client::LanguageModelProvider::OpenAi => {
+ let model = match open_ai::Model::from_id(&self.model.id.0) {
+ Ok(model) => model,
+ Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+ };
+ cx.background_spawn(async move { count_open_ai_tokens(request, model) })
+ .boxed()
+ }
+ cloud_llm_client::LanguageModelProvider::XAi => {
+ let model = match x_ai::Model::from_id(&self.model.id.0) {
+ Ok(model) => model,
+ Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+ };
+ cx.background_spawn(async move { count_xai_tokens(request, model) })
+ .boxed()
+ }
+ cloud_llm_client::LanguageModelProvider::Google => {
+ let http_client = self.http_client.clone();
+ let token_provider = self.token_provider.clone();
+ let model_id = self.model.id.to_string();
+ let generate_content_request =
+ into_google(request, model_id.clone(), GoogleModelMode::Default);
+ let auth_context = token_provider.auth_context(&cx.to_async());
+ async move {
+ let token = token_provider.acquire_token(auth_context).await?;
+
+ let request_body = CountTokensBody {
+ provider: cloud_llm_client::LanguageModelProvider::Google,
+ model: model_id,
+ provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
+ generate_content_request,
+ })?,
+ };
+ let request = http_client::Request::builder()
+ .method(Method::POST)
+ .uri(
+ http_client
+ .build_zed_llm_url("/count_tokens", &[])?
+ .as_ref(),
+ )
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {token}"))
+ .body(serde_json::to_string(&request_body)?.into())?;
+ let mut response = http_client.send(request).await?;
+ let status = response.status();
+ let headers = response.headers().clone();
+ let mut response_body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut response_body)
+ .await?;
+
+ if status.is_success() {
+ let response_body: CountTokensResponse =
+ serde_json::from_str(&response_body)?;
+
+ Ok(response_body.tokens as u64)
+ } else {
+ Err(anyhow!(ApiError {
+ status,
+ body: response_body,
+ headers
+ }))
+ }
+ }
+ .boxed()
+ }
+ }
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let thread_id = request.thread_id.clone();
+ let prompt_id = request.prompt_id.clone();
+ let app_version = self.app_version.clone();
+ let thinking_allowed = request.thinking_allowed;
+ let enable_thinking = thinking_allowed && self.model.supports_thinking;
+ let provider_name = provider_name(&self.model.provider);
+ match self.model.provider {
+ cloud_llm_client::LanguageModelProvider::Anthropic => {
+ let effort = request
+ .thinking_effort
+ .as_ref()
+ .and_then(|effort| anthropic::Effort::from_str(effort).ok());
+
+ let mut request = into_anthropic(
+ request,
+ self.model.id.to_string(),
+ 1.0,
+ self.model.max_output_tokens as u64,
+ if enable_thinking {
+ AnthropicModelMode::Thinking {
+ budget_tokens: Some(4_096),
+ }
+ } else {
+ AnthropicModelMode::Default
+ },
+ );
+
+ if enable_thinking && effort.is_some() {
+ request.thinking = Some(anthropic::Thinking::Adaptive);
+ request.output_config = Some(anthropic::OutputConfig { effort });
+ }
+
+ let http_client = self.http_client.clone();
+ let token_provider = self.token_provider.clone();
+ let auth_context = token_provider.auth_context(cx);
+ let future = self.request_limiter.stream(async move {
+ let PerformLlmCompletionResponse {
+ response,
+ includes_status_messages,
+ } = Self::perform_llm_completion(
+ &http_client,
+ &*token_provider,
+ auth_context,
+ app_version,
+ CompletionBody {
+ thread_id,
+ prompt_id,
+ provider: cloud_llm_client::LanguageModelProvider::Anthropic,
+ model: request.model.clone(),
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
+ },
+ )
+ .await
+ .map_err(|err| match err.downcast::<ApiError>() {
+ Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
+ Err(err) => anyhow!(err),
+ })?;
+
+ let mut mapper = AnthropicEventMapper::new();
+ Ok(map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_status_messages)),
+ &provider_name,
+ move |event| mapper.map_event(event),
+ ))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+ cloud_llm_client::LanguageModelProvider::OpenAi => {
+ let http_client = self.http_client.clone();
+ let token_provider = self.token_provider.clone();
+ let effort = request
+ .thinking_effort
+ .as_ref()
+ .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());
+
+ let mut request = into_open_ai_response(
+ request,
+ &self.model.id.0,
+ self.model.supports_parallel_tool_calls,
+ true,
+ None,
+ None,
+ );
+
+ if enable_thinking && let Some(effort) = effort {
+ request.reasoning = Some(open_ai::responses::ReasoningConfig {
+ effort,
+ summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
+ });
+ }
+
+ let auth_context = token_provider.auth_context(cx);
+ let future = self.request_limiter.stream(async move {
+ let PerformLlmCompletionResponse {
+ response,
+ includes_status_messages,
+ } = Self::perform_llm_completion(
+ &http_client,
+ &*token_provider,
+ auth_context,
+ app_version,
+ CompletionBody {
+ thread_id,
+ prompt_id,
+ provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+ model: request.model.clone(),
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
+ },
+ )
+ .await?;
+
+ let mut mapper = OpenAiResponseEventMapper::new();
+ Ok(map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_status_messages)),
+ &provider_name,
+ move |event| mapper.map_event(event),
+ ))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+ cloud_llm_client::LanguageModelProvider::XAi => {
+ let http_client = self.http_client.clone();
+ let token_provider = self.token_provider.clone();
+ let request = into_open_ai(
+ request,
+ &self.model.id.0,
+ self.model.supports_parallel_tool_calls,
+ false,
+ None,
+ None,
+ );
+ let auth_context = token_provider.auth_context(cx);
+ let future = self.request_limiter.stream(async move {
+ let PerformLlmCompletionResponse {
+ response,
+ includes_status_messages,
+ } = Self::perform_llm_completion(
+ &http_client,
+ &*token_provider,
+ auth_context,
+ app_version,
+ CompletionBody {
+ thread_id,
+ prompt_id,
+ provider: cloud_llm_client::LanguageModelProvider::XAi,
+ model: request.model.clone(),
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
+ },
+ )
+ .await?;
+
+ let mut mapper = OpenAiEventMapper::new();
+ Ok(map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_status_messages)),
+ &provider_name,
+ move |event| mapper.map_event(event),
+ ))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+ cloud_llm_client::LanguageModelProvider::Google => {
+ let http_client = self.http_client.clone();
+ let token_provider = self.token_provider.clone();
+ let request =
+ into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
+ let auth_context = token_provider.auth_context(cx);
+ let future = self.request_limiter.stream(async move {
+ let PerformLlmCompletionResponse {
+ response,
+ includes_status_messages,
+ } = Self::perform_llm_completion(
+ &http_client,
+ &*token_provider,
+ auth_context,
+ app_version,
+ CompletionBody {
+ thread_id,
+ prompt_id,
+ provider: cloud_llm_client::LanguageModelProvider::Google,
+ model: request.model.model_id.clone(),
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
+ },
+ )
+ .await?;
+
+ let mut mapper = GoogleEventMapper::new();
+ Ok(map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_status_messages)),
+ &provider_name,
+ move |event| mapper.map_event(event),
+ ))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+ }
+ }
+}
+
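+/// Fetches and caches the cloud model list, along with the server-designated default,
+/// default-fast, and recommended models.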
+pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
+ token_provider: Arc<TP>,
+ http_client: Arc<HttpClientWithUrl>,
+ app_version: Option<Version>,
+ models: Vec<Arc<cloud_llm_client::LanguageModel>>,
+ default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
+ default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
+ recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
+}
+
+impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
+ pub fn new(
+ token_provider: Arc<TP>,
+ http_client: Arc<HttpClientWithUrl>,
+ app_version: Option<Version>,
+ ) -> Self {
+ Self {
+ token_provider,
+ http_client,
+ app_version,
+ models: Vec::new(),
+ default_model: None,
+ default_fast_model: None,
+ recommended_models: Vec::new(),
+ }
+ }
+
+ pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let http_client = self.http_client.clone();
+ let token_provider = self.token_provider.clone();
+ cx.spawn(async move |this, cx| {
+ let auth_context = token_provider.auth_context(cx);
+ let response =
+ Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
+ this.update(cx, |this, cx| {
+ this.update_models(response);
+ cx.notify();
+ })
+ })
+ }
+
+ async fn fetch_models_request(
+ http_client: &HttpClientWithUrl,
+ token_provider: &TP,
+ auth_context: TP::AuthContext,
+ ) -> Result<ListModelsResponse> {
+ let token = token_provider.acquire_token(auth_context).await?;
+
+ let request = http_client::Request::builder()
+ .method(Method::GET)
+ .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
+ .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
+ .header("Authorization", format!("Bearer {token}"))
+ .body(AsyncBody::empty())?;
+ let mut response = http_client
+ .send(request)
+ .await
+ .context("failed to send list models request")?;
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ Ok(serde_json::from_str(&body)?)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ anyhow::bail!(
+ "error listing models.\nStatus: {:?}\nBody: {body}",
+ response.status(),
+ );
+ }
+ }
+
+ pub fn update_models(&mut self, response: ListModelsResponse) {
+ let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();
+
+ self.default_model = models
+ .iter()
+ .find(|model| {
+ response
+ .default_model
+ .as_ref()
+ .is_some_and(|default_model_id| &model.id == default_model_id)
+ })
+ .cloned();
+ self.default_fast_model = models
+ .iter()
+ .find(|model| {
+ response
+ .default_fast_model
+ .as_ref()
+ .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
+ })
+ .cloned();
+ self.recommended_models = response
+ .recommended_models
+ .iter()
+ .filter_map(|id| models.iter().find(|model| &model.id == id))
+ .cloned()
+ .collect();
+ self.models = models;
+ }
+
+ pub fn create_model(
+ &self,
+ model: &Arc<cloud_llm_client::LanguageModel>,
+ ) -> Arc<dyn LanguageModel> {
+ Arc::new(CloudLanguageModel::<TP> {
+ id: LanguageModelId::from(model.id.0.to_string()),
+ model: model.clone(),
+ token_provider: self.token_provider.clone(),
+ http_client: self.http_client.clone(),
+ app_version: self.app_version.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+
+ pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
+ &self.models
+ }
+
+ pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
+ self.default_model.as_ref()
+ }
+
+ pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
+ self.default_fast_model.as_ref()
+ }
+
+ pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
+ &self.recommended_models
+ }
+}
+
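+/// Adapts raw cloud completion events into language model completion events.
+///
+/// Status events are translated via `from_completion_request_status`, payload events go
+/// through `map_callback`, and if the stream ends without an explicit `StreamEnded`
+/// status a `StreamEndedUnexpectedly` error is emitted.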
+pub fn map_cloud_completion_events<T, F>(
+ stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
+ provider: &LanguageModelProviderName,
+ mut map_callback: F,
+) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+where
+ T: DeserializeOwned + 'static,
+ F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ + Send
+ + 'static,
+{
+ let provider = provider.clone();
+ let mut stream = stream.fuse();
+
+ let mut saw_stream_ended = false;
+
+ let mut done = false;
+ let mut pending = VecDeque::new();
+
+ stream::poll_fn(move |cx| {
+ loop {
+ if let Some(item) = pending.pop_front() {
+ return Poll::Ready(Some(item));
+ }
+
+ if done {
+ return Poll::Ready(None);
+ }
+
+ match stream.poll_next_unpin(cx) {
+ Poll::Ready(Some(event)) => {
+ let items = match event {
+ Err(error) => {
+ vec![Err(LanguageModelCompletionError::from(error))]
+ }
+ Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
+ saw_stream_ended = true;
+ vec![]
+ }
+ Ok(CompletionEvent::Status(status)) => {
+ LanguageModelCompletionEvent::from_completion_request_status(
+ status,
+ provider.clone(),
+ )
+ .transpose()
+ .map(|event| vec![event])
+ .unwrap_or_default()
+ }
+ Ok(CompletionEvent::Event(event)) => map_callback(event),
+ };
+ pending.extend(items);
+ }
+ Poll::Ready(None) => {
+ done = true;
+
+ if !saw_stream_ended {
+ return Poll::Ready(Some(Err(
+ LanguageModelCompletionError::StreamEndedUnexpectedly {
+ provider: provider.clone(),
+ },
+ )));
+ }
+ }
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+ })
+ .boxed()
+}
+
+pub fn provider_name(
+ provider: &cloud_llm_client::LanguageModelProvider,
+) -> LanguageModelProviderName {
+ match provider {
+ cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
+ cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
+ cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
+ cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
+ }
+}
+
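+/// Decodes the response body as newline-delimited JSON, yielding one completion event per line.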
+pub fn response_lines<T: DeserializeOwned>(
+ response: Response<AsyncBody>,
+ includes_status_messages: bool,
+) -> impl Stream<Item = Result<CompletionEvent<T>>> {
+ futures::stream::try_unfold(
+ (String::new(), BufReader::new(response.into_body())),
+ move |(mut line, mut body)| async move {
+ match body.read_line(&mut line).await {
+ Ok(0) => Ok(None),
+ Ok(_) => {
+ let event = if includes_status_messages {
+ serde_json::from_str::<CompletionEvent<T>>(&line)?
+ } else {
+ CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
+ };
+
+ line.clear();
+ Ok(Some((event, (line, body))))
+ }
+ Err(e) => Err(e.into()),
+ }
+ },
+ )
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use http_client::http::{HeaderMap, StatusCode};
+ use language_model::LanguageModelCompletionError;
+
+ #[test]
+ fn test_api_error_conversion_with_upstream_http_error() {
+        // upstream_http_error with an upstream 503 should become UpstreamProviderError
+ let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;
+
+ let api_error = ApiError {
+ status: StatusCode::INTERNAL_SERVER_ERROR,
+ body: error_body.to_string(),
+ headers: HeaderMap::new(),
+ };
+
+ let completion_error: LanguageModelCompletionError = api_error.into();
+
+ match completion_error {
+ LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
+ assert_eq!(
+ message,
+ "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
+ );
+ }
+ _ => panic!(
+ "Expected UpstreamProviderError for upstream 503, got: {:?}",
+ completion_error
+ ),
+ }
+
+        // upstream_http_error with an upstream 500 should become UpstreamProviderError
+ let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;
+
+ let api_error = ApiError {
+ status: StatusCode::INTERNAL_SERVER_ERROR,
+ body: error_body.to_string(),
+ headers: HeaderMap::new(),
+ };
+
+ let completion_error: LanguageModelCompletionError = api_error.into();
+
+ match completion_error {
+ LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
+ assert_eq!(
+ message,
+ "Received an error from the OpenAI API: internal server error"
+ );
+ }
+ _ => panic!(
+ "Expected UpstreamProviderError for upstream 500, got: {:?}",
+ completion_error
+ ),
+ }
+
+        // upstream_http_error with an upstream 429 should become UpstreamProviderError
+ let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;
+
+ let api_error = ApiError {
+ status: StatusCode::INTERNAL_SERVER_ERROR,
+ body: error_body.to_string(),
+ headers: HeaderMap::new(),
+ };
+
+ let completion_error: LanguageModelCompletionError = api_error.into();
+
+ match completion_error {
+ LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
+ assert_eq!(
+ message,
+ "Received an error from the Google API: rate limit exceeded"
+ );
+ }
+ _ => panic!(
+ "Expected UpstreamProviderError for upstream 429, got: {:?}",
+ completion_error
+ ),
+ }
+
+        // A plain 500 error without an upstream_http_error body should map to ApiInternalServerError for Zed
+ let error_body = "Regular internal server error";
+
+ let api_error = ApiError {
+ status: StatusCode::INTERNAL_SERVER_ERROR,
+ body: error_body.to_string(),
+ headers: HeaderMap::new(),
+ };
+
+ let completion_error: LanguageModelCompletionError = api_error.into();
+
+ match completion_error {
+ LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
+ assert_eq!(provider, PROVIDER_NAME);
+ assert_eq!(message, "Regular internal server error");
+ }
+ _ => panic!(
+ "Expected ApiInternalServerError for regular 500, got: {:?}",
+ completion_error
+ ),
+ }
+
+ // upstream_http_429 format should be converted to UpstreamProviderError
+ let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;
+
+ let api_error = ApiError {
+ status: StatusCode::INTERNAL_SERVER_ERROR,
+ body: error_body.to_string(),
+ headers: HeaderMap::new(),
+ };
+
+ let completion_error: LanguageModelCompletionError = api_error.into();
+
+ match completion_error {
+ LanguageModelCompletionError::UpstreamProviderError {
+ message,
+ status,
+ retry_after,
+ } => {
+ assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
+ assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
+ assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
+ }
+ _ => panic!(
+ "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
+ completion_error
+ ),
+ }
+
+ // Invalid JSON in error body should fall back to regular error handling
+ let error_body = "Not JSON at all";
+
+ let api_error = ApiError {
+ status: StatusCode::INTERNAL_SERVER_ERROR,
+ body: error_body.to_string(),
+ headers: HeaderMap::new(),
+ };
+
+ let completion_error: LanguageModelCompletionError = api_error.into();
+
+ match completion_error {
+ LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
+ assert_eq!(provider, PROVIDER_NAME);
+ }
+ _ => panic!(
+ "Expected ApiInternalServerError for invalid JSON, got: {:?}",
+ completion_error
+ ),
+ }
+ }
+}
@@ -17,13 +17,18 @@ schemars = ["dep:schemars"]
[dependencies]
anyhow.workspace = true
+collections.workspace = true
futures.workspace = true
http_client.workspace = true
+language_model_core.workspace = true
rand.workspace = true
schemars = { workspace = true, optional = true }
log.workspace = true
serde.workspace = true
serde_json.workspace = true
-settings.workspace = true
strum.workspace = true
thiserror.workspace = true
+tiktoken-rs.workspace = true
+
+[dev-dependencies]
+pretty_assertions.workspace = true
@@ -0,0 +1,1693 @@
+use anyhow::{Result, anyhow};
+use collections::HashMap;
+use futures::{Stream, StreamExt};
+use language_model_core::{
+ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
+ LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice,
+ LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
+ Role, StopReason, TokenUsage,
+ util::{fix_streamed_json, parse_tool_arguments},
+};
+use std::pin::Pin;
+use std::sync::Arc;
+
+use crate::responses::{
+ Request as ResponseRequest, ResponseFunctionCallItem, ResponseFunctionCallOutputContent,
+ ResponseFunctionCallOutputItem, ResponseInputContent, ResponseInputItem, ResponseMessageItem,
+ ResponseOutputItem, ResponseSummary as ResponsesSummary, ResponseUsage as ResponsesUsage,
+ StreamEvent as ResponsesStreamEvent,
+};
+use crate::{
+ FunctionContent, FunctionDefinition, ImageUrl, MessagePart, Model, ReasoningEffort,
+ ResponseStreamEvent, ToolCall, ToolCallContent,
+};
+
+pub fn into_open_ai(
+ request: LanguageModelRequest,
+ model_id: &str,
+ supports_parallel_tool_calls: bool,
+ supports_prompt_cache_key: bool,
+ max_output_tokens: Option<u64>,
+ reasoning_effort: Option<ReasoningEffort>,
+) -> crate::Request {
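+    // o1-family models (e.g. o1-preview, o1-mini) do not support streamed responses.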
+ let stream = !model_id.starts_with("o1-");
+
+ let mut messages = Vec::new();
+ for message in request.messages {
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
+ let should_add = if message.role == Role::User {
+                        // Including whitespace-only user messages can cause errors with OpenAI-compatible APIs.
+ // See https://github.com/zed-industries/zed/issues/40097
+ !text.trim().is_empty()
+ } else {
+ !text.is_empty()
+ };
+ if should_add {
+ add_message_content_part(
+ MessagePart::Text { text },
+ message.role,
+ &mut messages,
+ );
+ }
+ }
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::Image(image) => {
+ add_message_content_part(
+ MessagePart::Image {
+ image_url: ImageUrl {
+ url: image.to_base64_url(),
+ detail: None,
+ },
+ },
+ message.role,
+ &mut messages,
+ );
+ }
+ MessageContent::ToolUse(tool_use) => {
+ let tool_call = ToolCall {
+ id: tool_use.id.to_string(),
+ content: ToolCallContent::Function {
+ function: FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ },
+ },
+ };
+
+ if let Some(crate::RequestMessage::Assistant { tool_calls, .. }) =
+ messages.last_mut()
+ {
+ tool_calls.push(tool_call);
+ } else {
+ messages.push(crate::RequestMessage::Assistant {
+ content: None,
+ tool_calls: vec![tool_call],
+ });
+ }
+ }
+ MessageContent::ToolResult(tool_result) => {
+ let content = match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ vec![MessagePart::Text {
+ text: text.to_string(),
+ }]
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ vec![MessagePart::Image {
+ image_url: ImageUrl {
+ url: image.to_base64_url(),
+ detail: None,
+ },
+ }]
+ }
+ };
+
+ messages.push(crate::RequestMessage::Tool {
+ content: content.into(),
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ });
+ }
+ }
+ }
+ }
+
+ crate::Request {
+ model: model_id.into(),
+ messages,
+ stream,
+ stream_options: if stream {
+ Some(crate::StreamOptions::default())
+ } else {
+ None
+ },
+ stop: request.stop,
+ temperature: request.temperature.or(Some(1.0)),
+ max_completion_tokens: max_output_tokens,
+ parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
+ Some(supports_parallel_tool_calls)
+ } else {
+ None
+ },
+ prompt_cache_key: if supports_prompt_cache_key {
+ request.thread_id
+ } else {
+ None
+ },
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| crate::ToolDefinition::Function {
+ function: FunctionDefinition {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ },
+ })
+ .collect(),
+ tool_choice: request.tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => crate::ToolChoice::Auto,
+ LanguageModelToolChoice::Any => crate::ToolChoice::Required,
+ LanguageModelToolChoice::None => crate::ToolChoice::None,
+ }),
+ reasoning_effort,
+ }
+}
+
+pub fn into_open_ai_response(
+ request: LanguageModelRequest,
+ model_id: &str,
+ supports_parallel_tool_calls: bool,
+ supports_prompt_cache_key: bool,
+ max_output_tokens: Option<u64>,
+ reasoning_effort: Option<ReasoningEffort>,
+) -> ResponseRequest {
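+    // o1-family models (e.g. o1-preview, o1-mini) do not support streamed responses.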
+ let stream = !model_id.starts_with("o1-");
+
+ let LanguageModelRequest {
+ thread_id,
+ prompt_id: _,
+ intent: _,
+ messages,
+ tools,
+ tool_choice,
+ stop: _,
+ temperature,
+ thinking_allowed: _,
+ thinking_effort: _,
+ speed: _,
+ } = request;
+
+ let mut input_items = Vec::new();
+ for (index, message) in messages.into_iter().enumerate() {
+ append_message_to_response_items(message, index, &mut input_items);
+ }
+
+ let tools: Vec<_> = tools
+ .into_iter()
+ .map(|tool| crate::responses::ToolDefinition::Function {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ strict: None,
+ })
+ .collect();
+
+ ResponseRequest {
+ model: model_id.into(),
+ input: input_items,
+ stream,
+ temperature,
+ top_p: None,
+ max_output_tokens,
+ parallel_tool_calls: if tools.is_empty() {
+ None
+ } else {
+ Some(supports_parallel_tool_calls)
+ },
+ tool_choice: tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => crate::ToolChoice::Auto,
+ LanguageModelToolChoice::Any => crate::ToolChoice::Required,
+ LanguageModelToolChoice::None => crate::ToolChoice::None,
+ }),
+ tools,
+ prompt_cache_key: if supports_prompt_cache_key {
+ thread_id
+ } else {
+ None
+ },
+ reasoning: reasoning_effort.map(|effort| crate::responses::ReasoningConfig {
+ effort,
+ summary: Some(crate::responses::ReasoningSummaryMode::Auto),
+ }),
+ }
+}
+
+fn append_message_to_response_items(
+ message: LanguageModelRequestMessage,
+ index: usize,
+ input_items: &mut Vec<ResponseInputItem>,
+) {
+ let mut content_parts: Vec<ResponseInputContent> = Vec::new();
+
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) => {
+ push_response_text_part(&message.role, text, &mut content_parts);
+ }
+ MessageContent::Thinking { text, .. } => {
+ push_response_text_part(&message.role, text, &mut content_parts);
+ }
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::Image(image) => {
+ push_response_image_part(&message.role, image, &mut content_parts);
+ }
+ MessageContent::ToolUse(tool_use) => {
+ flush_response_parts(&message.role, index, &mut content_parts, input_items);
+ let call_id = tool_use.id.to_string();
+ input_items.push(ResponseInputItem::FunctionCall(ResponseFunctionCallItem {
+ call_id,
+ name: tool_use.name.to_string(),
+ arguments: tool_use.raw_input,
+ }));
+ }
+ MessageContent::ToolResult(tool_result) => {
+ flush_response_parts(&message.role, index, &mut content_parts, input_items);
+ input_items.push(ResponseInputItem::FunctionCallOutput(
+ ResponseFunctionCallOutputItem {
+ call_id: tool_result.tool_use_id.to_string(),
+ output: match tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ ResponseFunctionCallOutputContent::Text(text.to_string())
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ ResponseFunctionCallOutputContent::List(vec![
+ ResponseInputContent::Image {
+ image_url: image.to_base64_url(),
+ },
+ ])
+ }
+ },
+ },
+ ));
+ }
+ }
+ }
+
+ flush_response_parts(&message.role, index, &mut content_parts, input_items);
+}
+
+fn push_response_text_part(
+ role: &Role,
+ text: impl Into<String>,
+ parts: &mut Vec<ResponseInputContent>,
+) {
+ let text = text.into();
+ if text.trim().is_empty() {
+ return;
+ }
+
+ match role {
+ Role::Assistant => parts.push(ResponseInputContent::OutputText {
+ text,
+ annotations: Vec::new(),
+ }),
+ _ => parts.push(ResponseInputContent::Text { text }),
+ }
+}
+
+fn push_response_image_part(
+ role: &Role,
+ image: LanguageModelImage,
+ parts: &mut Vec<ResponseInputContent>,
+) {
+ match role {
+ Role::Assistant => parts.push(ResponseInputContent::OutputText {
+ text: "[image omitted]".to_string(),
+ annotations: Vec::new(),
+ }),
+ _ => parts.push(ResponseInputContent::Image {
+ image_url: image.to_base64_url(),
+ }),
+ }
+}
+
+fn flush_response_parts(
+ role: &Role,
+ _index: usize,
+ parts: &mut Vec<ResponseInputContent>,
+ input_items: &mut Vec<ResponseInputItem>,
+) {
+ if parts.is_empty() {
+ return;
+ }
+
+ let item = ResponseInputItem::Message(ResponseMessageItem {
+ role: match role {
+ Role::User => crate::Role::User,
+ Role::Assistant => crate::Role::Assistant,
+ Role::System => crate::Role::System,
+ },
+ content: parts.clone(),
+ });
+
+ input_items.push(item);
+ parts.clear();
+}
+
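+/// Appends a content part to the last message when its role matches; otherwise pushes a new message for the role.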
+fn add_message_content_part(
+ new_part: MessagePart,
+ role: Role,
+ messages: &mut Vec<crate::RequestMessage>,
+) {
+ match (role, messages.last_mut()) {
+ (Role::User, Some(crate::RequestMessage::User { content }))
+ | (
+ Role::Assistant,
+ Some(crate::RequestMessage::Assistant {
+ content: Some(content),
+ ..
+ }),
+ )
+ | (Role::System, Some(crate::RequestMessage::System { content, .. })) => {
+ content.push_part(new_part);
+ }
+ _ => {
+ messages.push(match role {
+ Role::User => crate::RequestMessage::User {
+ content: crate::MessageContent::from(vec![new_part]),
+ },
+ Role::Assistant => crate::RequestMessage::Assistant {
+ content: Some(crate::MessageContent::from(vec![new_part])),
+ tool_calls: Vec::new(),
+ },
+ Role::System => crate::RequestMessage::System {
+ content: crate::MessageContent::from(vec![new_part]),
+ },
+ });
+ }
+ }
+}
+
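+/// Accumulates streamed Chat Completions deltas, pairing tool-call fragments by index,
+/// and maps them to completion events.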
+pub struct OpenAiEventMapper {
+ tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl OpenAiEventMapper {
+ pub fn new() -> Self {
+ Self {
+ tool_calls_by_index: HashMap::default(),
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
+ })
+ })
+ }
+
+ pub fn map_event(
+ &mut self,
+ event: ResponseStreamEvent,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ let mut events = Vec::new();
+ if let Some(usage) = event.usage {
+ events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+ input_tokens: usage.prompt_tokens,
+ output_tokens: usage.completion_tokens,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ })));
+ }
+
+ let Some(choice) = event.choices.first() else {
+ return events;
+ };
+
+ if let Some(delta) = choice.delta.as_ref() {
+ if let Some(reasoning_content) = delta.reasoning_content.clone() {
+ if !reasoning_content.is_empty() {
+ events.push(Ok(LanguageModelCompletionEvent::Thinking {
+ text: reasoning_content,
+ signature: None,
+ }));
+ }
+ }
+ if let Some(content) = delta.content.clone() {
+ if !content.is_empty() {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+ }
+
+ if let Some(tool_calls) = delta.tool_calls.as_ref() {
+ for tool_call in tool_calls {
+ let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
+
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
+ }
+
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
+ }
+ }
+
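+                // Emit a partial tool-use event once both id and name are known and the
+                // accumulated arguments can be repaired into valid JSON, so consumers can
+                // render streaming tool input incrementally.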
+ if !entry.id.is_empty() && !entry.name.is_empty() {
+ if let Ok(input) = serde_json::from_str::<serde_json::Value>(
+ &fix_streamed_json(&entry.arguments),
+ ) {
+ events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: entry.id.clone().into(),
+ name: entry.name.as_str().into(),
+ is_input_complete: false,
+ input,
+ raw_input: entry.arguments.clone(),
+ thought_signature: None,
+ },
+ )));
+ }
+ }
+ }
+ }
+ }
+
+ match choice.finish_reason.as_deref() {
+ Some("stop") => {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ Some("tool_calls") => {
+ events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+ match parse_tool_arguments(&tool_call.arguments) {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.clone().into(),
+ name: tool_call.name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments.clone(),
+ thought_signature: None,
+ },
+ )),
+ Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.into(),
+ raw_input: tool_call.arguments.clone().into(),
+ json_parse_error: error.to_string(),
+ }),
+ }
+ }));
+
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ }
+ Some(stop_reason) => {
+ log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ None => {}
+ }
+
+ events
+ }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
+}
+
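+/// Maps Responses API stream events to completion events, tracking in-flight function calls by item id.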
+pub struct OpenAiResponseEventMapper {
+ function_calls_by_item: HashMap<String, PendingResponseFunctionCall>,
+ pending_stop_reason: Option<StopReason>,
+}
+
+#[derive(Default)]
+struct PendingResponseFunctionCall {
+ call_id: String,
+ name: Arc<str>,
+ arguments: String,
+}
+
+impl OpenAiResponseEventMapper {
+ pub fn new() -> Self {
+ Self {
+ function_calls_by_item: HashMap::default(),
+ pending_stop_reason: None,
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin<Box<dyn Send + Stream<Item = Result<ResponsesStreamEvent>>>>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
+ })
+ })
+ }
+
+ pub fn map_event(
+ &mut self,
+ event: ResponsesStreamEvent,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ match event {
+ ResponsesStreamEvent::OutputItemAdded { item, .. } => {
+ let mut events = Vec::new();
+
+ match &item {
+ ResponseOutputItem::Message(message) => {
+ if let Some(id) = &message.id {
+ events.push(Ok(LanguageModelCompletionEvent::StartMessage {
+ message_id: id.clone(),
+ }));
+ }
+ }
+ ResponseOutputItem::FunctionCall(function_call) => {
+ if let Some(item_id) = function_call.id.clone() {
+ let call_id = function_call
+ .call_id
+ .clone()
+ .or_else(|| function_call.id.clone())
+ .unwrap_or_else(|| item_id.clone());
+ let entry = PendingResponseFunctionCall {
+ call_id,
+ name: Arc::<str>::from(
+ function_call.name.clone().unwrap_or_default(),
+ ),
+ arguments: function_call.arguments.clone(),
+ };
+ self.function_calls_by_item.insert(item_id, entry);
+ }
+ }
+ ResponseOutputItem::Reasoning(_) | ResponseOutputItem::Unknown => {}
+ }
+ events
+ }
+ ResponsesStreamEvent::ReasoningSummaryTextDelta { delta, .. } => {
+ if delta.is_empty() {
+ Vec::new()
+ } else {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: delta,
+ signature: None,
+ })]
+ }
+ }
+ ResponsesStreamEvent::OutputTextDelta { delta, .. } => {
+ if delta.is_empty() {
+ Vec::new()
+ } else {
+ vec![Ok(LanguageModelCompletionEvent::Text(delta))]
+ }
+ }
+ ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => {
+ if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) {
+ entry.arguments.push_str(&delta);
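+                    // Surface partial tool input whenever the repaired JSON parses.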
+ if let Ok(input) = serde_json::from_str::<serde_json::Value>(
+ &fix_streamed_json(&entry.arguments),
+ ) {
+ return vec![Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: LanguageModelToolUseId::from(entry.call_id.clone()),
+ name: entry.name.clone(),
+ is_input_complete: false,
+ input,
+ raw_input: entry.arguments.clone(),
+ thought_signature: None,
+ },
+ ))];
+ }
+ }
+ Vec::new()
+ }
+ ResponsesStreamEvent::FunctionCallArgumentsDone {
+ item_id, arguments, ..
+ } => {
+ if let Some(mut entry) = self.function_calls_by_item.remove(&item_id) {
+ if !arguments.is_empty() {
+ entry.arguments = arguments;
+ }
+ let raw_input = entry.arguments.clone();
+ self.pending_stop_reason = Some(StopReason::ToolUse);
+ match parse_tool_arguments(&entry.arguments) {
+ Ok(input) => {
+ vec![Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: LanguageModelToolUseId::from(entry.call_id.clone()),
+ name: entry.name.clone(),
+ is_input_complete: true,
+ input,
+ raw_input,
+ thought_signature: None,
+ },
+ ))]
+ }
+ Err(error) => {
+ vec![Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: LanguageModelToolUseId::from(entry.call_id.clone()),
+ tool_name: entry.name.clone(),
+ raw_input: Arc::<str>::from(raw_input),
+ json_parse_error: error.to_string(),
+ })]
+ }
+ }
+ } else {
+ Vec::new()
+ }
+ }
+ ResponsesStreamEvent::Completed { response } => {
+ self.handle_completion(response, StopReason::EndTurn)
+ }
+ ResponsesStreamEvent::Incomplete { response } => {
+ let reason = response
+ .status_details
+ .as_ref()
+ .and_then(|details| details.reason.as_deref());
+ let stop_reason = match reason {
+ Some("max_output_tokens") => StopReason::MaxTokens,
+ Some("content_filter") => {
+ self.pending_stop_reason = Some(StopReason::Refusal);
+ StopReason::Refusal
+ }
+ _ => self
+ .pending_stop_reason
+ .take()
+ .unwrap_or(StopReason::EndTurn),
+ };
+
+ let mut events = Vec::new();
+ if self.pending_stop_reason.is_none() {
+ events.extend(self.emit_tool_calls_from_output(&response.output));
+ }
+ if let Some(usage) = response.usage.as_ref() {
+ events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+ token_usage_from_response_usage(usage),
+ )));
+ }
+ events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason)));
+ events
+ }
+ ResponsesStreamEvent::Failed { response } => {
+ let message = response
+ .status_details
+ .and_then(|details| details.error)
+ .map(|error| error.to_string())
+ .unwrap_or_else(|| "response failed".to_string());
+ vec![Err(LanguageModelCompletionError::Other(anyhow!(message)))]
+ }
+ ResponsesStreamEvent::Error { error }
+ | ResponsesStreamEvent::GenericError { error } => {
+ vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ error.message
+ )))]
+ }
+ ResponsesStreamEvent::ReasoningSummaryPartAdded { summary_index, .. } => {
+ if summary_index > 0 {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: "\n\n".to_string(),
+ signature: None,
+ })]
+ } else {
+ Vec::new()
+ }
+ }
+ ResponsesStreamEvent::OutputTextDone { .. }
+ | ResponsesStreamEvent::OutputItemDone { .. }
+ | ResponsesStreamEvent::ContentPartAdded { .. }
+ | ResponsesStreamEvent::ContentPartDone { .. }
+ | ResponsesStreamEvent::ReasoningSummaryTextDone { .. }
+ | ResponsesStreamEvent::ReasoningSummaryPartDone { .. }
+ | ResponsesStreamEvent::Created { .. }
+ | ResponsesStreamEvent::InProgress { .. }
+ | ResponsesStreamEvent::Unknown => Vec::new(),
+ }
+ }
+
+ fn handle_completion(
+ &mut self,
+ response: ResponsesSummary,
+ default_reason: StopReason,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ let mut events = Vec::new();
+
+ if self.pending_stop_reason.is_none() {
+ events.extend(self.emit_tool_calls_from_output(&response.output));
+ }
+
+ if let Some(usage) = response.usage.as_ref() {
+ events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+ token_usage_from_response_usage(usage),
+ )));
+ }
+
+ let stop_reason = self.pending_stop_reason.take().unwrap_or(default_reason);
+ events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason)));
+ events
+ }
+
+ fn emit_tool_calls_from_output(
+ &mut self,
+ output: &[ResponseOutputItem],
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ let mut events = Vec::new();
+ for item in output {
+ if let ResponseOutputItem::FunctionCall(function_call) = item {
+ let Some(call_id) = function_call
+ .call_id
+ .clone()
+ .or_else(|| function_call.id.clone())
+ else {
+ log::error!(
+ "Function call item missing both call_id and id: {:?}",
+ function_call
+ );
+ continue;
+ };
+ let name: Arc<str> = Arc::from(function_call.name.clone().unwrap_or_default());
+ let arguments = &function_call.arguments;
+ self.pending_stop_reason = Some(StopReason::ToolUse);
+ match parse_tool_arguments(arguments) {
+ Ok(input) => {
+ events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: LanguageModelToolUseId::from(call_id.clone()),
+ name: name.clone(),
+ is_input_complete: true,
+ input,
+ raw_input: arguments.clone(),
+ thought_signature: None,
+ },
+ )));
+ }
+ Err(error) => {
+ events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: LanguageModelToolUseId::from(call_id.clone()),
+ tool_name: name.clone(),
+ raw_input: Arc::<str>::from(arguments.clone()),
+ json_parse_error: error.to_string(),
+ }));
+ }
+ }
+ }
+ }
+ events
+ }
+}
+
+fn token_usage_from_response_usage(usage: &ResponsesUsage) -> TokenUsage {
+ TokenUsage {
+ input_tokens: usage.input_tokens.unwrap_or_default(),
+ output_tokens: usage.output_tokens.unwrap_or_default(),
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ }
+}
+
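+/// Flattens request messages into the shape tiktoken-rs expects for token counting.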
+pub fn collect_tiktoken_messages(
+ request: LanguageModelRequest,
+) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
+ request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::<Vec<_>>()
+}
+
+/// Count tokens for an OpenAI model. This is synchronous; callers should spawn
+/// it on a background thread if needed.
+pub fn count_open_ai_tokens(request: LanguageModelRequest, model: Model) -> Result<u64> {
+ let messages = collect_tiktoken_messages(request);
+ match model {
+ Model::Custom { max_tokens, .. } => {
+ let model = if max_tokens >= 100_000 {
+                // If max_tokens is 100k or more, the model likely uses the o200k_base tokenizer.
+                "gpt-4o"
+            } else {
+                // Otherwise fall back to gpt-4, since this tiktoken method only supports
+                // the cl100k_base and o200k_base tokenizers.
+ "gpt-4"
+ };
+ tiktoken_rs::num_tokens_from_messages(model, &messages)
+ }
+        // Models currently supported by tiktoken-rs.
+        // tiktoken-rs sometimes lags behind on model support; when it does, add a match
+        // arm with a tokenizer override. Enumerating every supported model here (instead
+        // of using a catch-all) forces a compile error whenever a new model variant is
+        // added, so its tokenizer support must be checked explicitly.
+ Model::ThreePointFiveTurbo
+ | Model::Four
+ | Model::FourTurbo
+ | Model::FourOmniMini
+ | Model::FourPointOneNano
+ | Model::O1
+ | Model::O3
+ | Model::O3Mini
+ | Model::Five
+ | Model::FiveCodex
+ | Model::FiveMini
+ | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
+        // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated tiktoken support; use the gpt-5 tokenizer.
+ Model::FivePointOne
+ | Model::FivePointTwo
+ | Model::FivePointTwoCodex
+ | Model::FivePointThreeCodex
+ | Model::FivePointFour
+ | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),
+ }
+ .map(|tokens| tokens as u64)
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::responses::{
+ ReasoningSummaryPart, ResponseFunctionToolCall, ResponseOutputItem, ResponseOutputMessage,
+ ResponseReasoningItem, ResponseStatusDetails, ResponseSummary, ResponseUsage,
+ StreamEvent as ResponsesStreamEvent,
+ };
+ use futures::{StreamExt, executor::block_on};
+ use language_model_core::{
+ LanguageModelImage, LanguageModelRequestMessage, LanguageModelRequestTool,
+ LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
+ LanguageModelToolUseId, SharedString,
+ };
+ use pretty_assertions::assert_eq;
+ use serde_json::json;
+
+ use super::*;
+
+ fn map_response_events(events: Vec<ResponsesStreamEvent>) -> Vec<LanguageModelCompletionEvent> {
+ block_on(async {
+ OpenAiResponseEventMapper::new()
+ .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
+ .collect::<Vec<_>>()
+ .await
+ .into_iter()
+ .map(Result::unwrap)
+ .collect()
+ })
+ }
+
+ fn response_item_message(id: &str) -> ResponseOutputItem {
+ ResponseOutputItem::Message(ResponseOutputMessage {
+ id: Some(id.to_string()),
+ role: Some("assistant".to_string()),
+ status: Some("in_progress".to_string()),
+ content: vec![],
+ })
+ }
+
+ fn response_item_function_call(id: &str, args: Option<&str>) -> ResponseOutputItem {
+ ResponseOutputItem::FunctionCall(ResponseFunctionToolCall {
+ id: Some(id.to_string()),
+ status: Some("in_progress".to_string()),
+ name: Some("get_weather".to_string()),
+ call_id: Some("call_123".to_string()),
+ arguments: args.map(|s| s.to_string()).unwrap_or_default(),
+ })
+ }
+
+ #[test]
+ fn tiktoken_rs_support() {
+ let request = LanguageModelRequest {
+ thread_id: None,
+ prompt_id: None,
+ intent: None,
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::Text("message".into())],
+ cache: false,
+ reasoning_details: None,
+ }],
+ tools: vec![],
+ tool_choice: None,
+ stop: vec![],
+ temperature: None,
+ thinking_allowed: true,
+ thinking_effort: None,
+ speed: None,
+ };
+
+ // Validate that all models are supported by tiktoken-rs
+ for model in <Model as strum::IntoEnumIterator>::iter() {
+ let count = count_open_ai_tokens(request.clone(), model).unwrap();
+ assert!(count > 0);
+ }
+ }
+
+ #[test]
+ fn responses_stream_maps_text_and_usage() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: response_item_message("msg_123"),
+ },
+ ResponsesStreamEvent::OutputTextDelta {
+ item_id: "msg_123".into(),
+ output_index: 0,
+ content_index: Some(0),
+ delta: "Hello".into(),
+ },
+ ResponsesStreamEvent::Completed {
+ response: ResponseSummary {
+ usage: Some(ResponseUsage {
+ input_tokens: Some(5),
+ output_tokens: Some(3),
+ total_tokens: Some(8),
+ }),
+ ..Default::default()
+ },
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_123"
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::Text(ref text) if text == "Hello"
+ ));
+ assert!(matches!(
+ mapped[2],
+ LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+ input_tokens: 5,
+ output_tokens: 3,
+ ..
+ })
+ ));
+ assert!(matches!(
+ mapped[3],
+ LanguageModelCompletionEvent::Stop(StopReason::EndTurn)
+ ));
+ }
+
+ #[test]
+ fn into_open_ai_response_builds_complete_payload() {
+ let tool_call_id = LanguageModelToolUseId::from("call-42");
+ let tool_input = json!({ "city": "Boston" });
+ let tool_arguments = serde_json::to_string(&tool_input).unwrap();
+ let tool_use = LanguageModelToolUse {
+ id: tool_call_id.clone(),
+ name: Arc::from("get_weather"),
+ raw_input: tool_arguments.clone(),
+ input: tool_input,
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ let tool_result = LanguageModelToolResult {
+ tool_use_id: tool_call_id,
+ tool_name: Arc::from("get_weather"),
+ is_error: false,
+ content: LanguageModelToolResultContent::Text(Arc::from("Sunny")),
+ output: Some(json!({ "forecast": "Sunny" })),
+ };
+ let user_image = LanguageModelImage {
+ source: SharedString::from("aGVsbG8="),
+ size: None,
+ };
+ let expected_image_url = user_image.to_base64_url();
+
+ let request = LanguageModelRequest {
+ thread_id: Some("thread-123".into()),
+ prompt_id: None,
+ intent: None,
+ messages: vec![
+ LanguageModelRequestMessage {
+ role: Role::System,
+ content: vec![MessageContent::Text("System context".into())],
+ cache: false,
+ reasoning_details: None,
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![
+ MessageContent::Text("Please check the weather.".into()),
+ MessageContent::Image(user_image),
+ ],
+ cache: false,
+ reasoning_details: None,
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![
+ MessageContent::Text("Looking that up.".into()),
+ MessageContent::ToolUse(tool_use),
+ ],
+ cache: false,
+ reasoning_details: None,
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![MessageContent::ToolResult(tool_result)],
+ cache: false,
+ reasoning_details: None,
+ },
+ ],
+ tools: vec![LanguageModelRequestTool {
+ name: "get_weather".into(),
+ description: "Fetches the weather".into(),
+ input_schema: json!({ "type": "object" }),
+ use_input_streaming: false,
+ }],
+ tool_choice: Some(LanguageModelToolChoice::Any),
+ stop: vec!["<STOP>".into()],
+ temperature: None,
+ thinking_allowed: false,
+ thinking_effort: None,
+ speed: None,
+ };
+
+ let response = into_open_ai_response(
+ request,
+ "custom-model",
+ true,
+ true,
+ Some(2048),
+ Some(ReasoningEffort::Low),
+ );
+
+ let serialized = serde_json::to_value(&response).unwrap();
+ let expected = json!({
+ "model": "custom-model",
+ "input": [
+ {
+ "type": "message",
+ "role": "system",
+ "content": [
+ { "type": "input_text", "text": "System context" }
+ ]
+ },
+ {
+ "type": "message",
+ "role": "user",
+ "content": [
+ { "type": "input_text", "text": "Please check the weather." },
+ { "type": "input_image", "image_url": expected_image_url }
+ ]
+ },
+ {
+ "type": "message",
+ "role": "assistant",
+ "content": [
+ { "type": "output_text", "text": "Looking that up.", "annotations": [] }
+ ]
+ },
+ {
+ "type": "function_call",
+ "call_id": "call-42",
+ "name": "get_weather",
+ "arguments": tool_arguments
+ },
+ {
+ "type": "function_call_output",
+ "call_id": "call-42",
+ "output": "Sunny"
+ }
+ ],
+ "stream": true,
+ "max_output_tokens": 2048,
+ "parallel_tool_calls": true,
+ "tool_choice": "required",
+ "tools": [
+ {
+ "type": "function",
+ "name": "get_weather",
+ "description": "Fetches the weather",
+ "parameters": { "type": "object" }
+ }
+ ],
+ "prompt_cache_key": "thread-123",
+ "reasoning": { "effort": "low", "summary": "auto" }
+ });
+
+ assert_eq!(serialized, expected);
+ }
+
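+ // Streamed function-call arguments should yield a partial ToolUse, then a
+ // complete one, then Stop(ToolUse).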
+ #[test]
+ fn responses_stream_maps_tool_calls() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: response_item_function_call("item_fn", Some("{\"city\":\"Bos")),
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDelta {
+ item_id: "item_fn".into(),
+ output_index: 0,
+ delta: "ton\"}".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDone {
+ item_id: "item_fn".into(),
+ output_index: 0,
+ arguments: "{\"city\":\"Boston\"}".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::Completed {
+ response: ResponseSummary::default(),
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert_eq!(mapped.len(), 3);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+ is_input_complete: false,
+ ..
+ })
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+ ref id,
+ ref name,
+ ref raw_input,
+ is_input_complete: true,
+ ..
+ }) if id.to_string() == "call_123"
+ && name.as_ref() == "get_weather"
+ && raw_input == "{\"city\":\"Boston\"}"
+ ));
+ assert!(matches!(
+ mapped[2],
+ LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+ ));
+ }
+
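+ // An incomplete response caused by max_output_tokens should report usage and Stop(MaxTokens).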
+ #[test]
+ fn responses_stream_uses_max_tokens_stop_reason() {
+ let events = vec![ResponsesStreamEvent::Incomplete {
+ response: ResponseSummary {
+ status_details: Some(ResponseStatusDetails {
+ reason: Some("max_output_tokens".into()),
+ r#type: Some("incomplete".into()),
+ error: None,
+ }),
+ usage: Some(ResponseUsage {
+ input_tokens: Some(10),
+ output_tokens: Some(20),
+ total_tokens: Some(30),
+ }),
+ ..Default::default()
+ },
+ }];
+
+ let mapped = map_response_events(events);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+ input_tokens: 10,
+ output_tokens: 20,
+ ..
+ })
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
+ ));
+ }
+
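+ // Two function calls in one response should map to two complete ToolUse events
+ // and a single Stop(ToolUse).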
+ #[test]
+ fn responses_stream_handles_multiple_tool_calls() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: response_item_function_call("item_fn1", Some("{\"city\":\"NYC\"}")),
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDone {
+ item_id: "item_fn1".into(),
+ output_index: 0,
+ arguments: "{\"city\":\"NYC\"}".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 1,
+ sequence_number: None,
+ item: response_item_function_call("item_fn2", Some("{\"city\":\"LA\"}")),
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDone {
+ item_id: "item_fn2".into(),
+ output_index: 1,
+ arguments: "{\"city\":\"LA\"}".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::Completed {
+ response: ResponseSummary::default(),
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert_eq!(mapped.len(), 3);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. })
+ if raw_input == "{\"city\":\"NYC\"}"
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. })
+ if raw_input == "{\"city\":\"LA\"}"
+ ));
+ assert!(matches!(
+ mapped[2],
+ LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+ ));
+ }
+
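+ // Text output followed by a function call should emit the message events before the tool events.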
+ #[test]
+ fn responses_stream_handles_mixed_text_and_tool_calls() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: response_item_message("msg_123"),
+ },
+ ResponsesStreamEvent::OutputTextDelta {
+ item_id: "msg_123".into(),
+ output_index: 0,
+ content_index: Some(0),
+ delta: "Let me check that".into(),
+ },
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 1,
+ sequence_number: None,
+ item: response_item_function_call("item_fn", Some("{\"query\":\"test\"}")),
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDone {
+ item_id: "item_fn".into(),
+ output_index: 1,
+ arguments: "{\"query\":\"test\"}".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::Completed {
+ response: ResponseSummary::default(),
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::StartMessage { .. }
+ ));
+ assert!(
+ matches!(mapped[1], LanguageModelCompletionEvent::Text(ref text) if text == "Let me check that")
+ );
+ assert!(
+ matches!(mapped[2], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) if raw_input == "{\"query\":\"test\"}")
+ );
+ assert!(matches!(
+ mapped[3],
+ LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+ ));
+ }
+
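+ // Unparseable tool arguments should surface as ToolUseJsonParseError rather than failing the stream.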
+ #[test]
+ fn responses_stream_handles_json_parse_error() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: response_item_function_call("item_fn", Some("{invalid json")),
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDone {
+ item_id: "item_fn".into(),
+ output_index: 0,
+ arguments: "{invalid json".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::Completed {
+ response: ResponseSummary::default(),
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::ToolUseJsonParseError { ref raw_input, .. }
+ if raw_input.as_ref() == "{invalid json"
+ ));
+ }
+
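+ // When the stream is cut off mid-call, the completed call carried in the
+ // Incomplete summary's output should still be emitted as a complete ToolUse.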
+ #[test]
+ fn responses_stream_handles_incomplete_function_call() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: response_item_function_call("item_fn", Some("{\"city\":")),
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDelta {
+ item_id: "item_fn".into(),
+ output_index: 0,
+ delta: "\"Boston\"".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::Incomplete {
+ response: ResponseSummary {
+ status_details: Some(ResponseStatusDetails {
+ reason: Some("max_output_tokens".into()),
+ r#type: Some("incomplete".into()),
+ error: None,
+ }),
+ output: vec![response_item_function_call(
+ "item_fn",
+ Some("{\"city\":\"Boston\"}"),
+ )],
+ ..Default::default()
+ },
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert_eq!(mapped.len(), 3);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+ is_input_complete: false,
+ ..
+ })
+ ));
+ assert!(
+ matches!(mapped[1], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, is_input_complete: true, .. }) if raw_input == "{\"city\":\"Boston\"}")
+ );
+ assert!(matches!(
+ mapped[2],
+ LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
+ ));
+ }
+
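+ // A call already finished via FunctionCallArgumentsDone must not be re-emitted
+ // from the Incomplete summary's output.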
+ #[test]
+ fn responses_stream_incomplete_does_not_duplicate_tool_calls() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: response_item_function_call("item_fn", Some("{\"city\":\"Boston\"}")),
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDone {
+ item_id: "item_fn".into(),
+ output_index: 0,
+ arguments: "{\"city\":\"Boston\"}".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::Incomplete {
+ response: ResponseSummary {
+ status_details: Some(ResponseStatusDetails {
+ reason: Some("max_output_tokens".into()),
+ r#type: Some("incomplete".into()),
+ error: None,
+ }),
+ output: vec![response_item_function_call(
+ "item_fn",
+ Some("{\"city\":\"Boston\"}"),
+ )],
+ ..Default::default()
+ },
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert_eq!(mapped.len(), 2);
+ assert!(
+ matches!(mapped[0], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) if raw_input == "{\"city\":\"Boston\"}")
+ );
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
+ ));
+ }
+
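+ // An empty arguments string should parse as an empty JSON object rather than erroring.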
+ #[test]
+ fn responses_stream_handles_empty_tool_arguments() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: response_item_function_call("item_fn", Some("")),
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDone {
+ item_id: "item_fn".into(),
+ output_index: 0,
+ arguments: "".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::Completed {
+ response: ResponseSummary::default(),
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert_eq!(mapped.len(), 2);
+ assert!(matches!(
+ &mapped[0],
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+ id, name, raw_input, input, ..
+ }) if id.to_string() == "call_123"
+ && name.as_ref() == "get_weather"
+ && raw_input == ""
+ && input.is_object()
+ && input.as_object().unwrap().is_empty()
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+ ));
+ }
+
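+ // Argument deltas should stream partial ToolUse events, ending with a single complete one.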
+ #[test]
+ fn responses_stream_emits_partial_tool_use_events() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: ResponseOutputItem::FunctionCall(
+ crate::responses::ResponseFunctionToolCall {
+ id: Some("item_fn".to_string()),
+ status: Some("in_progress".to_string()),
+ name: Some("get_weather".to_string()),
+ call_id: Some("call_abc".to_string()),
+ arguments: String::new(),
+ },
+ ),
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDelta {
+ item_id: "item_fn".into(),
+ output_index: 0,
+ delta: "{\"city\":\"Bos".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDelta {
+ item_id: "item_fn".into(),
+ output_index: 0,
+ delta: "ton\"}".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::FunctionCallArgumentsDone {
+ item_id: "item_fn".into(),
+ output_index: 0,
+ arguments: "{\"city\":\"Boston\"}".into(),
+ sequence_number: None,
+ },
+ ResponsesStreamEvent::Completed {
+ response: ResponseSummary::default(),
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert!(mapped.len() >= 3);
+
+ let complete_tool_use = mapped.iter().find(|e| {
+ matches!(
+ e,
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+ is_input_complete: true,
+ ..
+ })
+ )
+ });
+ assert!(
+ complete_tool_use.is_some(),
+ "should have a complete tool use event"
+ );
+
+ let tool_uses: Vec<_> = mapped
+ .iter()
+ .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_)))
+ .collect();
+ assert!(
+ tool_uses.len() >= 2,
+ "should have at least one partial and one complete event"
+ );
+ assert!(matches!(
+ tool_uses.last().unwrap(),
+ LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+ is_input_complete: true,
+ ..
+ })
+ ));
+ }
+
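+ // Reasoning summary deltas should map to Thinking events, with a "\n\n"
+ // separator between summary parts.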
+ #[test]
+ fn responses_stream_maps_reasoning_summary_deltas() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
+ id: Some("rs_123".into()),
+ summary: vec![],
+ }),
+ },
+ ResponsesStreamEvent::ReasoningSummaryPartAdded {
+ item_id: "rs_123".into(),
+ output_index: 0,
+ summary_index: 0,
+ },
+ ResponsesStreamEvent::ReasoningSummaryTextDelta {
+ item_id: "rs_123".into(),
+ output_index: 0,
+ delta: "Thinking about".into(),
+ },
+ ResponsesStreamEvent::ReasoningSummaryTextDelta {
+ item_id: "rs_123".into(),
+ output_index: 0,
+ delta: " the answer".into(),
+ },
+ ResponsesStreamEvent::ReasoningSummaryTextDone {
+ item_id: "rs_123".into(),
+ output_index: 0,
+ text: "Thinking about the answer".into(),
+ },
+ ResponsesStreamEvent::ReasoningSummaryPartDone {
+ item_id: "rs_123".into(),
+ output_index: 0,
+ summary_index: 0,
+ },
+ ResponsesStreamEvent::ReasoningSummaryPartAdded {
+ item_id: "rs_123".into(),
+ output_index: 0,
+ summary_index: 1,
+ },
+ ResponsesStreamEvent::ReasoningSummaryTextDelta {
+ item_id: "rs_123".into(),
+ output_index: 0,
+ delta: "Second part".into(),
+ },
+ ResponsesStreamEvent::ReasoningSummaryTextDone {
+ item_id: "rs_123".into(),
+ output_index: 0,
+ text: "Second part".into(),
+ },
+ ResponsesStreamEvent::ReasoningSummaryPartDone {
+ item_id: "rs_123".into(),
+ output_index: 0,
+ summary_index: 1,
+ },
+ ResponsesStreamEvent::OutputItemDone {
+ output_index: 0,
+ sequence_number: None,
+ item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
+ id: Some("rs_123".into()),
+ summary: vec![
+ ReasoningSummaryPart::SummaryText {
+ text: "Thinking about the answer".into(),
+ },
+ ReasoningSummaryPart::SummaryText {
+ text: "Second part".into(),
+ },
+ ],
+ }),
+ },
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 1,
+ sequence_number: None,
+ item: response_item_message("msg_456"),
+ },
+ ResponsesStreamEvent::OutputTextDelta {
+ item_id: "msg_456".into(),
+ output_index: 1,
+ content_index: Some(0),
+ delta: "The answer is 42".into(),
+ },
+ ResponsesStreamEvent::Completed {
+ response: ResponseSummary::default(),
+ },
+ ];
+
+ let mapped = map_response_events(events);
+
+ let thinking_events: Vec<_> = mapped
+ .iter()
+ .filter(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. }))
+ .collect();
+ assert_eq!(
+ thinking_events.len(),
+ 4,
+ "expected 4 thinking events, got {:?}",
+ thinking_events
+ );
+ assert!(
+ matches!(&thinking_events[0], LanguageModelCompletionEvent::Thinking { text, .. } if text == "Thinking about")
+ );
+ assert!(
+ matches!(&thinking_events[1], LanguageModelCompletionEvent::Thinking { text, .. } if text == " the answer")
+ );
+ assert!(
+ matches!(&thinking_events[2], LanguageModelCompletionEvent::Thinking { text, .. } if text == "\n\n"),
+ "expected separator between summary parts"
+ );
+ assert!(
+ matches!(&thinking_events[3], LanguageModelCompletionEvent::Thinking { text, .. } if text == "Second part")
+ );
+
+ assert!(mapped.iter().any(
+ |e| matches!(e, LanguageModelCompletionEvent::Text(t) if t == "The answer is 42")
+ ));
+ }
+
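+ // A reasoning summary delivered only in OutputItemDone should not be replayed as Thinking events.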
+ #[test]
+ fn responses_stream_maps_reasoning_from_done_only() {
+ let events = vec![
+ ResponsesStreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
+ id: Some("rs_789".into()),
+ summary: vec![],
+ }),
+ },
+ ResponsesStreamEvent::OutputItemDone {
+ output_index: 0,
+ sequence_number: None,
+ item: ResponseOutputItem::Reasoning(ResponseReasoningItem {
+ id: Some("rs_789".into()),
+ summary: vec![ReasoningSummaryPart::SummaryText {
+ text: "Summary without deltas".into(),
+ }],
+ }),
+ },
+ ResponsesStreamEvent::Completed {
+ response: ResponseSummary::default(),
+ },
+ ];
+
+ let mapped = map_response_events(events);
+ assert!(
+ !mapped
+ .iter()
+ .any(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })),
+ "OutputItemDone reasoning should not produce Thinking events"
+ );
+ }
+}
@@ -1,4 +1,5 @@
pub mod batches;
+pub mod completion;
pub mod responses;
use anyhow::{Context as _, Result, anyhow};
@@ -7,9 +8,9 @@ use http_client::{
AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
http::{HeaderMap, HeaderValue},
};
+pub use language_model_core::ReasoningEffort;
use serde::{Deserialize, Serialize};
use serde_json::Value;
-pub use settings::OpenAiReasoningEffort as ReasoningEffort;
use std::{convert::TryFrom, future::Future};
use strum::EnumIter;
use thiserror::Error;
@@ -717,3 +718,26 @@ pub fn embed<'a>(
Ok(response)
}
}
+
+// -- Conversions to `language_model_core` types --
+
+impl From<RequestError> for language_model_core::LanguageModelCompletionError {
+ fn from(error: RequestError) -> Self {
+ match error {
+ RequestError::HttpResponseError {
+ provider,
+ status_code,
+ body,
+ headers,
+ } => {
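+ // Retry-After may be delta-seconds or an HTTP-date; only the numeric form is handled here.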
+ let retry_after = headers
+ .get(http_client::http::header::RETRY_AFTER)
+ .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
+ .map(std::time::Duration::from_secs);
+
+ Self::from_http_status(provider.into(), status_code, body, retry_after)
+ }
+ RequestError::Other(e) => Self::Other(e),
+ }
+ }
+}
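+
+// With this conversion in place, call sites can bubble transport failures up via
+// `map_err(LanguageModelCompletionError::from)` (a usage sketch, not part of this change).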
@@ -19,6 +19,7 @@ schemars = ["dep:schemars"]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
+language_model_core.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
@@ -744,3 +744,71 @@ impl ApiErrorCode {
}
}
}
+
+// -- Conversions to `language_model_core` types --
+
+impl From<OpenRouterError> for language_model_core::LanguageModelCompletionError {
+ fn from(error: OpenRouterError) -> Self {
+ let provider = language_model_core::LanguageModelProviderName::new("OpenRouter");
+ match error {
+ OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
+ OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
+ OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
+ OpenRouterError::DeserializeResponse(error) => {
+ Self::DeserializeResponse { provider, error }
+ }
+ OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+ OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
+ provider,
+ retry_after: Some(retry_after),
+ },
+ OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
+ provider,
+ retry_after,
+ },
+ OpenRouterError::ApiError(api_error) => api_error.into(),
+ }
+ }
+}
+
+impl From<ApiError> for language_model_core::LanguageModelCompletionError {
+ fn from(error: ApiError) -> Self {
+ use ApiErrorCode::*;
+ let provider = language_model_core::LanguageModelProviderName::new("OpenRouter");
+ match error.code {
+ InvalidRequestError => Self::BadRequestFormat {
+ provider,
+ message: error.message,
+ },
+ AuthenticationError => Self::AuthenticationError {
+ provider,
+ message: error.message,
+ },
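+ // Payment-required responses are mapped onto the authentication variant,
+ // with the reason prefixed into the message.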
+ PaymentRequiredError => Self::AuthenticationError {
+ provider,
+ message: format!("Payment required: {}", error.message),
+ },
+ PermissionError => Self::PermissionError {
+ provider,
+ message: error.message,
+ },
+ RequestTimedOut => Self::HttpResponseError {
+ provider,
+ status_code: http_client::StatusCode::REQUEST_TIMEOUT,
+ message: error.message,
+ },
+ RateLimitError => Self::RateLimitExceeded {
+ provider,
+ retry_after: None,
+ },
+ ApiError => Self::ApiInternalServerError {
+ provider,
+ message: error.message,
+ },
+ OverloadedError => Self::ServerOverloaded {
+ provider,
+ retry_after: None,
+ },
+ }
+ }
+}
@@ -412,7 +412,7 @@ impl PrettierStore {
prettier_store
.update(cx, |prettier_store, cx| {
let name = if is_default {
- LanguageServerName("prettier (default)".to_string().into())
+ LanguageServerName("prettier (default)".into())
} else {
let worktree_path = worktree_id
.and_then(|id| {
@@ -19,6 +19,7 @@ anyhow.workspace = true
collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
+language_model_core.workspace = true
log.workspace = true
schemars.workspace = true
serde.workspace = true
@@ -1,8 +1,8 @@
+use crate::merge_from::MergeFrom;
use collections::HashMap;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings_macros::{MergeFrom, with_fallible_options};
-use strum::EnumString;
use std::sync::Arc;
@@ -237,15 +237,12 @@ pub struct OpenAiAvailableModel {
pub capabilities: OpenAiModelCapabilities,
}
-#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, EnumString, JsonSchema, MergeFrom)]
-#[serde(rename_all = "lowercase")]
-#[strum(serialize_all = "lowercase")]
-pub enum OpenAiReasoningEffort {
- Minimal,
- Low,
- Medium,
- High,
- XHigh,
+pub use language_model_core::ReasoningEffort as OpenAiReasoningEffort;
+
+impl MergeFrom for OpenAiReasoningEffort {
+ fn merge_from(&mut self, other: &Self) {
+ *self = *other;
+ }
}
#[with_fallible_options]
@@ -479,15 +476,10 @@ pub struct LanguageModelCacheConfiguration {
pub min_total_token: u64,
}
-#[derive(
- Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, MergeFrom,
-)]
-#[serde(tag = "type", rename_all = "lowercase")]
-pub enum ModelMode {
- #[default]
- Default,
- Thinking {
- /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
- budget_tokens: Option<u32>,
- },
+pub use language_model_core::ModelMode;
+
+impl MergeFrom for ModelMode {
+ fn merge_from(&mut self, other: &Self) {
+ *self = *other;
+ }
}
@@ -14,6 +14,7 @@ path = "src/web_search_providers.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
+cloud_api_client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
futures.workspace = true
@@ -2,12 +2,12 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token};
+use cloud_api_client::LlmApiToken;
use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Task};
use http_client::{HttpClient, Method};
-use language_model::LlmApiToken;
use web_search::{WebSearchProvider, WebSearchProviderId};
pub struct CloudWebSearchProvider {
@@ -17,6 +17,8 @@ schemars = ["dep:schemars"]
[dependencies]
anyhow.workspace = true
+language_model_core.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
strum.workspace = true
+tiktoken-rs.workspace = true
@@ -0,0 +1,30 @@
+use anyhow::Result;
+use language_model_core::{LanguageModelRequest, Role};
+
+use crate::Model;
+
+/// Count tokens for an xAI model using tiktoken. This is synchronous;
+/// callers should spawn it on a background thread if needed.
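+///
+/// A minimal usage sketch (assuming a `smol`-style background executor, as used
+/// elsewhere in this workspace):
+///
+/// ```ignore
+/// let tokens = smol::unblock(move || count_xai_tokens(request, model)).await?;
+/// ```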
+pub fn count_xai_tokens(request: LanguageModelRequest, model: Model) -> Result<u64> {
+ let messages = request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::<Vec<_>>();
+
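+ // tiktoken-rs ships no xAI tokenizers, so approximate with an OpenAI tokenizer
+ // of comparable vocabulary, chosen by context-window size.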
+ let model_name = if model.max_token_count() >= 100_000 {
+ "gpt-4o"
+ } else {
+ "gpt-4"
+ };
+ tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
+}
@@ -1,3 +1,5 @@
+pub mod completion;
+
use anyhow::Result;
use serde::{Deserialize, Serialize};
use strum::EnumIter;