diff --git a/Cargo.lock b/Cargo.lock index 94aca307210e19bc97c14002ba5f136edfd76778..78a209af2af33ab8d4e99e2fb3b5f200c6194b34 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -629,13 +629,17 @@ version = "0.1.0" dependencies = [ "anyhow", "chrono", + "collections", "futures 0.3.32", "http_client", + "language_model_core", + "log", "schemars", "serde", "serde_json", "strum 0.27.2", "thiserror 2.0.17", + "tiktoken-rs", ] [[package]] @@ -2903,7 +2907,6 @@ dependencies = [ "http_client", "http_client_tls", "httparse", - "language_model", "log", "objc2-foundation", "parking_lot", @@ -2959,6 +2962,7 @@ dependencies = [ "http_client", "parking_lot", "serde_json", + "smol", "thiserror 2.0.17", "yawc", ] @@ -3204,7 +3208,6 @@ dependencies = [ "anyhow", "call", "channel", - "chrono", "client", "collections", "db", @@ -3213,7 +3216,6 @@ dependencies = [ "fuzzy", "gpui", "livekit_client", - "log", "menu", "notifications", "picker", @@ -3228,7 +3230,6 @@ dependencies = [ "theme", "theme_settings", "time", - "time_format", "title_bar", "ui", "util", @@ -5161,6 +5162,7 @@ dependencies = [ "buffer_diff", "client", "clock", + "cloud_api_client", "cloud_api_types", "cloud_llm_client", "collections", @@ -5640,7 +5642,7 @@ dependencies = [ name = "env_var" version = "0.1.0" dependencies = [ - "gpui", + "gpui_shared_string", ] [[package]] @@ -6182,6 +6184,7 @@ dependencies = [ "file_icons", "futures 0.3.32", "fuzzy", + "fuzzy_nucleo", "gpui", "menu", "open_path_prompt", @@ -6739,6 +6742,15 @@ dependencies = [ "thread_local", ] +[[package]] +name = "fuzzy_nucleo" +version = "0.1.0" +dependencies = [ + "gpui", + "nucleo", + "util", +] + [[package]] name = "gaoya" version = "0.2.0" @@ -7457,11 +7469,13 @@ dependencies = [ "anyhow", "futures 0.3.32", "http_client", + "language_model_core", + "log", "schemars", "serde", "serde_json", - "settings", "strum 0.27.2", + "tiktoken-rs", ] [[package]] @@ -7530,6 +7544,7 @@ dependencies = [ "getrandom 0.3.4", "gpui_macros", "gpui_platform", + 
"gpui_shared_string", "gpui_util", "gpui_web", "http_client", @@ -7699,6 +7714,16 @@ dependencies = [ "gpui_windows", ] +[[package]] +name = "gpui_shared_string" +version = "0.1.0" +dependencies = [ + "derive_more", + "gpui_util", + "schemars", + "serde", +] + [[package]] name = "gpui_tokio" version = "0.1.0" @@ -9348,7 +9373,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "gpui", + "gpui_shared_string", "log", "lsp", "parking_lot", @@ -9387,12 +9412,8 @@ dependencies = [ name = "language_model" version = "0.1.0" dependencies = [ - "anthropic", "anyhow", "base64 0.22.1", - "cloud_api_client", - "cloud_api_types", - "cloud_llm_client", "collections", "credentials_provider", "env_var", @@ -9401,16 +9422,31 @@ dependencies = [ "http_client", "icons", "image", + "language_model_core", "log", - "open_ai", - "open_router", "parking_lot", + "serde", + "serde_json", + "thiserror 2.0.17", + "util", +] + +[[package]] +name = "language_model_core" +version = "0.1.0" +dependencies = [ + "anyhow", + "cloud_llm_client", + "futures 0.3.32", + "gpui_shared_string", + "http_client", + "partial-json-fixer", "schemars", "serde", "serde_json", "smol", + "strum 0.27.2", "thiserror 2.0.17", - "util", ] [[package]] @@ -9426,8 +9462,8 @@ dependencies = [ "base64 0.22.1", "bedrock", "client", + "cloud_api_client", "cloud_api_types", - "cloud_llm_client", "collections", "component", "convert_case 0.8.0", @@ -9446,6 +9482,7 @@ dependencies = [ "http_client", "language", "language_model", + "language_models_cloud", "lmstudio", "log", "menu", @@ -9455,19 +9492,16 @@ dependencies = [ "open_router", "opencode", "parking_lot", - "partial-json-fixer", "pretty_assertions", "rand 0.9.2", "release_channel", "schemars", - "semver", "serde", "serde_json", "settings", "sha2", "smol", "strum 0.27.2", - "thiserror 2.0.17", "tiktoken-rs", "tokio", "ui", @@ -9478,6 +9512,28 @@ dependencies = [ "x_ai", ] +[[package]] +name = "language_models_cloud" +version = "0.1.0" +dependencies = [ + 
"anthropic", + "anyhow", + "cloud_llm_client", + "futures 0.3.32", + "google_ai", + "gpui", + "http_client", + "language_model", + "open_ai", + "schemars", + "semver", + "serde", + "serde_json", + "smol", + "thiserror 2.0.17", + "x_ai", +] + [[package]] name = "language_onboarding" version = "0.1.0" @@ -11067,6 +11123,27 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nucleo" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5262af4c94921c2646c5ac6ff7900c2af9cbb08dc26a797e18130a7019c039d4" +dependencies = [ + "nucleo-matcher", + "parking_lot", + "rayon", +] + +[[package]] +name = "nucleo-matcher" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf33f538733d1a5a3494b836ba913207f14d9d4a1d3cd67030c5061bdd2cac85" +dependencies = [ + "memchr", + "unicode-segmentation", +] + [[package]] name = "num" version = "0.4.3" @@ -11604,16 +11681,19 @@ name = "open_ai" version = "0.1.0" dependencies = [ "anyhow", + "collections", "futures 0.3.32", "http_client", + "language_model_core", "log", + "pretty_assertions", "rand 0.9.2", "schemars", "serde", "serde_json", - "settings", "strum 0.27.2", "thiserror 2.0.17", + "tiktoken-rs", ] [[package]] @@ -11645,6 +11725,7 @@ dependencies = [ "anyhow", "futures 0.3.32", "http_client", + "language_model_core", "schemars", "serde", "serde_json", @@ -13207,6 +13288,7 @@ dependencies = [ "fs", "futures 0.3.32", "fuzzy", + "fuzzy_nucleo", "git", "git2", "git_hosting_providers", @@ -15773,6 +15855,7 @@ dependencies = [ "collections", "derive_more", "gpui", + "language_model_core", "log", "schemars", "serde", @@ -20152,6 +20235,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", + "cloud_api_client", "cloud_api_types", "cloud_llm_client", "futures 0.3.32", @@ -21755,9 +21839,11 @@ name = "x_ai" version = "0.1.0" dependencies = [ "anyhow", + "language_model_core", "schemars", "serde", "strum 0.27.2", + "tiktoken-rs", ] 
[[package]] diff --git a/Cargo.toml b/Cargo.toml index 5cb5b991b645ec1b78b16f48493c7c8dc1426344..5a7fc9caaf982953168855671bebbcf4f010df03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,6 +78,7 @@ members = [ "crates/fs", "crates/fs_benchmarks", "crates/fuzzy", + "crates/fuzzy_nucleo", "crates/git", "crates/git_graph", "crates/git_hosting_providers", @@ -86,6 +87,7 @@ members = [ "crates/google_ai", "crates/grammars", "crates/gpui", + "crates/gpui_shared_string", "crates/gpui_linux", "crates/gpui_macos", "crates/gpui_macros", @@ -109,7 +111,9 @@ members = [ "crates/language_core", "crates/language_extension", "crates/language_model", + "crates/language_model_core", "crates/language_models", + "crates/language_models_cloud", "crates/language_onboarding", "crates/language_selector", "crates/language_tools", @@ -325,6 +329,7 @@ file_finder = { path = "crates/file_finder" } file_icons = { path = "crates/file_icons" } fs = { path = "crates/fs" } fuzzy = { path = "crates/fuzzy" } +fuzzy_nucleo = { path = "crates/fuzzy_nucleo" } git = { path = "crates/git" } git_graph = { path = "crates/git_graph" } git_hosting_providers = { path = "crates/git_hosting_providers" } @@ -333,6 +338,7 @@ go_to_line = { path = "crates/go_to_line" } google_ai = { path = "crates/google_ai" } grammars = { path = "crates/grammars" } gpui = { path = "crates/gpui", default-features = false } +gpui_shared_string = { path = "crates/gpui_shared_string" } gpui_linux = { path = "crates/gpui_linux", default-features = false } gpui_macos = { path = "crates/gpui_macos", default-features = false } gpui_macros = { path = "crates/gpui_macros" } @@ -359,7 +365,9 @@ language = { path = "crates/language" } language_core = { path = "crates/language_core" } language_extension = { path = "crates/language_extension" } language_model = { path = "crates/language_model" } +language_model_core = { path = "crates/language_model_core" } language_models = { path = "crates/language_models" } +language_models_cloud = { path = 
"crates/language_models_cloud" } language_onboarding = { path = "crates/language_onboarding" } language_selector = { path = "crates/language_selector" } language_tools = { path = "crates/language_tools" } @@ -609,6 +617,7 @@ naga = { version = "29.0", features = ["wgsl-in"] } nanoid = "0.4" nbformat = "1.2.0" nix = "0.29" +nucleo = "0.5" num-format = "0.4.4" objc = "0.2" objc2-app-kit = { version = "0.3", default-features = false, features = [ "NSGraphics" ] } diff --git a/assets/icons/folder_open_add.svg b/assets/icons/folder_open_add.svg new file mode 100644 index 0000000000000000000000000000000000000000..d5ebbdaa8b080037a2faee0ee0fc3606eec9c6ca --- /dev/null +++ b/assets/icons/folder_open_add.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/folder_plus.svg b/assets/icons/folder_plus.svg deleted file mode 100644 index a543448ed6197043291369bee640e23b6ad729b9..0000000000000000000000000000000000000000 --- a/assets/icons/folder_plus.svg +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/assets/icons/open_new_window.svg b/assets/icons/open_new_window.svg new file mode 100644 index 0000000000000000000000000000000000000000..c81d49f9ff9edfbc965055568efc72e0214efb41 --- /dev/null +++ b/assets/icons/open_new_window.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/assets/settings/default.json b/assets/settings/default.json index 63e906e3b11206fc458f8d7353f3ecba0abeb825..97fbcd546e09beefa9ff7a67e33806f3faf561d1 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -936,16 +936,6 @@ // For example: typing `:wave:` gets replaced with `đź‘‹`. "auto_replace_emoji_shortcode": true, }, - "notification_panel": { - // Whether to show the notification panel button in the status bar. - "button": true, - // Where to dock the notification panel. Can be 'left' or 'right'. - "dock": "right", - // Default width of the notification panel. 
- "default_width": 380, - // Whether to show a badge on the notification panel icon with the count of unread notifications. - "show_count_badge": false, - }, "agent": { // Whether the inline assistant should use streaming tools, when available "inline_assistant_use_streaming_tools": true, @@ -965,6 +955,9 @@ "default_width": 640, // Default height when the agent panel is docked to the bottom. "default_height": 320, + // Maximum content width when the agent panel is wider than this value. + // Content will be centered within the panel. + "max_content_width": 850, // The default model to use when creating new threads. "default_model": { // The provider to use. diff --git a/crates/agent/src/tool_permissions.rs b/crates/agent/src/tool_permissions.rs index 58e779da59aef176464839ed6f2d6a5c16e4bc12..ff9e735b6c4181588ed5cddbd6dada7fbae5f18f 100644 --- a/crates/agent/src/tool_permissions.rs +++ b/crates/agent/src/tool_permissions.rs @@ -574,6 +574,7 @@ mod tests { flexible: true, default_width: px(300.), default_height: px(600.), + max_content_width: px(850.), default_model: None, inline_assistant_model: None, inline_assistant_use_streaming_tools: false, diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index 0086a82f4e79c9924502202873ceb2b25d2e66fb..9b013f111e7eaa981652d8868dfcf3c098d9dc7e 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -5,7 +5,7 @@ use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; use indoc::formatdoc; use language::Point; -use language_model::{LanguageModelImage, LanguageModelToolResultContent}; +use language_model::{LanguageModelImage, LanguageModelImageExt, LanguageModelToolResultContent}; use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 
5f452bc9c0e2e9c2322042583295894a5866b053..e56db9df927ab3cdf838587f1cb4f9514eb5a758 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -325,7 +325,7 @@ impl AcpConnection { // Use the one the agent provides if we have one .map(|info| info.name.into()) // Otherwise, just use the name - .unwrap_or_else(|| agent_id.0.to_string().into()); + .unwrap_or_else(|| agent_id.0.clone()); let session_list = if response .agent_capabilities diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 0c68d2f25d54f966d1cc0a93476457bbba79c959..a04de2ed3be69d3f5791419a32e427fa0c26791e 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -31,7 +31,6 @@ pub struct PanelLayout { pub(crate) outline_panel_dock: Option, pub(crate) collaboration_panel_dock: Option, pub(crate) git_panel_dock: Option, - pub(crate) notification_panel_button: Option, } impl PanelLayout { @@ -41,7 +40,6 @@ impl PanelLayout { outline_panel_dock: Some(DockSide::Right), collaboration_panel_dock: Some(DockPosition::Right), git_panel_dock: Some(DockPosition::Right), - notification_panel_button: Some(false), }; const EDITOR: Self = Self { @@ -50,7 +48,6 @@ impl PanelLayout { outline_panel_dock: Some(DockSide::Left), collaboration_panel_dock: Some(DockPosition::Left), git_panel_dock: Some(DockPosition::Left), - notification_panel_button: Some(true), }; pub fn is_agent_layout(&self) -> bool { @@ -68,7 +65,6 @@ impl PanelLayout { outline_panel_dock: content.outline_panel.as_ref().and_then(|p| p.dock), collaboration_panel_dock: content.collaboration_panel.as_ref().and_then(|p| p.dock), git_panel_dock: content.git_panel.as_ref().and_then(|p| p.dock), - notification_panel_button: content.notification_panel.as_ref().and_then(|p| p.button), } } @@ -78,7 +74,6 @@ impl PanelLayout { settings.outline_panel.get_or_insert_default().dock = self.outline_panel_dock; 
settings.collaboration_panel.get_or_insert_default().dock = self.collaboration_panel_dock; settings.git_panel.get_or_insert_default().dock = self.git_panel_dock; - settings.notification_panel.get_or_insert_default().button = self.notification_panel_button; } fn write_diff_to(&self, current_merged: &PanelLayout, settings: &mut SettingsContent) { @@ -98,10 +93,6 @@ impl PanelLayout { if self.git_panel_dock != current_merged.git_panel_dock { settings.git_panel.get_or_insert_default().dock = self.git_panel_dock; } - if self.notification_panel_button != current_merged.notification_panel_button { - settings.notification_panel.get_or_insert_default().button = - self.notification_panel_button; - } } fn backfill_to(&self, user_layout: &PanelLayout, settings: &mut SettingsContent) { @@ -121,10 +112,6 @@ impl PanelLayout { if user_layout.git_panel_dock.is_none() { settings.git_panel.get_or_insert_default().dock = self.git_panel_dock; } - if user_layout.notification_panel_button.is_none() { - settings.notification_panel.get_or_insert_default().button = - self.notification_panel_button; - } } } @@ -154,6 +141,7 @@ pub struct AgentSettings { pub sidebar_side: SidebarDockPosition, pub default_width: Pixels, pub default_height: Pixels, + pub max_content_width: Pixels, pub default_model: Option, pub inline_assistant_model: Option, pub inline_assistant_use_streaming_tools: bool, @@ -600,6 +588,7 @@ impl Settings for AgentSettings { sidebar_side: agent.sidebar_side.unwrap(), default_width: px(agent.default_width.unwrap()), default_height: px(agent.default_height.unwrap()), + max_content_width: px(agent.max_content_width.unwrap()), flexible: agent.flexible.unwrap(), default_model: Some(agent.default_model.unwrap()), inline_assistant_model: agent.inline_assistant_model, @@ -1255,7 +1244,6 @@ mod tests { assert_eq!(user_layout.outline_panel_dock, None); assert_eq!(user_layout.collaboration_panel_dock, None); assert_eq!(user_layout.git_panel_dock, None); - 
assert_eq!(user_layout.notification_panel_button, None); // User sets a combination that doesn't match either preset: // agent on the left but project panel also on the left. @@ -1478,7 +1466,6 @@ mod tests { Some(DockPosition::Left) ); assert_eq!(user_layout.git_panel_dock, Some(DockPosition::Left)); - assert_eq!(user_layout.notification_panel_button, Some(true)); // Now switch defaults to agent V2. set_agent_v2_defaults(cx); diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 41900e71e5d3ad7e5327ee7e04f73cb05eed5a5b..01b897fc63da76247b5624f8316ea06b2c1f85e5 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -28,21 +28,20 @@ use zed_actions::agent::{ use crate::thread_metadata_store::ThreadMetadataStore; use crate::{ AddContextServer, AgentDiffPane, ConversationView, CopyThreadToClipboard, CycleStartThreadIn, - Follow, InlineAssistant, LoadThreadFromClipboard, NewThread, OpenActiveThreadAsMarkdown, - OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, StartThreadIn, - ToggleNavigationMenu, ToggleNewThreadMenu, ToggleOptionsMenu, + Follow, InlineAssistant, LoadThreadFromClipboard, NewThread, NewWorktreeBranchTarget, + OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, + StartThreadIn, ToggleNavigationMenu, ToggleNewThreadMenu, ToggleOptionsMenu, agent_configuration::{AgentConfiguration, AssistantConfigurationEvent}, conversation_view::{AcpThreadViewEvent, ThreadView}, + thread_branch_picker::ThreadBranchPicker, + thread_worktree_picker::ThreadWorktreePicker, ui::EndTrialUpsell, }; use crate::{ Agent, AgentInitialContent, ExternalSourcePrompt, NewExternalAgentThread, NewNativeAgentThreadFromSummary, }; -use crate::{ - DEFAULT_THREAD_TITLE, - ui::{AcpOnboardingModal, HoldForDefault}, -}; +use crate::{DEFAULT_THREAD_TITLE, ui::AcpOnboardingModal}; use crate::{ExpandMessageEditor, ThreadHistoryView}; use crate::{ManageProfiles, 
ThreadHistoryViewEvent}; use crate::{ThreadHistory, agent_connection_store::AgentConnectionStore}; @@ -73,8 +72,8 @@ use terminal::terminal_settings::TerminalSettings; use terminal_view::{TerminalView, terminal_panel::TerminalPanel}; use theme_settings::ThemeSettings; use ui::{ - Button, Callout, CommonAnimationExt, ContextMenu, ContextMenuEntry, DocumentationSide, - PopoverMenu, PopoverMenuHandle, Tab, Tooltip, prelude::*, utils::WithRemSize, + Button, Callout, CommonAnimationExt, ContextMenu, ContextMenuEntry, PopoverMenu, + PopoverMenuHandle, Tab, Tooltip, prelude::*, utils::WithRemSize, }; use util::{ResultExt as _, debug_panic}; use workspace::{ @@ -620,7 +619,31 @@ impl StartThreadIn { fn label(&self) -> SharedString { match self { Self::LocalProject => "Current Worktree".into(), - Self::NewWorktree => "New Git Worktree".into(), + Self::NewWorktree { + worktree_name: Some(worktree_name), + .. + } => format!("New: {worktree_name}").into(), + Self::NewWorktree { .. } => "New Git Worktree".into(), + Self::LinkedWorktree { display_name, .. } => format!("From: {}", &display_name).into(), + } + } + + fn worktree_branch_label(&self, default_branch_label: SharedString) -> Option { + match self { + Self::NewWorktree { branch_target, .. 
} => match branch_target { + NewWorktreeBranchTarget::CurrentBranch => Some(default_branch_label), + NewWorktreeBranchTarget::ExistingBranch { name } => { + Some(format!("From: {name}").into()) + } + NewWorktreeBranchTarget::CreateBranch { name, from_ref } => { + if let Some(from_ref) = from_ref { + Some(format!("From: {from_ref}").into()) + } else { + Some(format!("From: {name}").into()) + } + } + }, + _ => None, } } } @@ -632,6 +655,17 @@ pub enum WorktreeCreationStatus { Error(SharedString), } +#[derive(Clone, Debug)] +enum WorktreeCreationArgs { + New { + worktree_name: Option, + branch_target: NewWorktreeBranchTarget, + }, + Linked { + worktree_path: PathBuf, + }, +} + impl ActiveView { pub fn which_font_size_used(&self) -> WhichFontSize { match self { @@ -662,7 +696,8 @@ pub struct AgentPanel { previous_view: Option, background_threads: HashMap>, new_thread_menu_handle: PopoverMenuHandle, - start_thread_in_menu_handle: PopoverMenuHandle, + start_thread_in_menu_handle: PopoverMenuHandle, + thread_branch_menu_handle: PopoverMenuHandle, agent_panel_menu_handle: PopoverMenuHandle, agent_navigation_menu_handle: PopoverMenuHandle, agent_navigation_menu: Option>, @@ -689,7 +724,7 @@ impl AgentPanel { }; let selected_agent = self.selected_agent.clone(); - let start_thread_in = Some(self.start_thread_in); + let start_thread_in = Some(self.start_thread_in.clone()); let last_active_thread = self.active_agent_thread(cx).map(|thread| { let thread = thread.read(cx); @@ -794,18 +829,21 @@ impl AgentPanel { } else if let Some(agent) = global_fallback { panel.selected_agent = agent; } - if let Some(start_thread_in) = serialized_panel.start_thread_in { + if let Some(ref start_thread_in) = serialized_panel.start_thread_in { let is_worktree_flag_enabled = cx.has_flag::(); let is_valid = match &start_thread_in { StartThreadIn::LocalProject => true, - StartThreadIn::NewWorktree => { + StartThreadIn::NewWorktree { .. 
} => { let project = panel.project.read(cx); is_worktree_flag_enabled && !project.is_via_collab() } + StartThreadIn::LinkedWorktree { path, .. } => { + is_worktree_flag_enabled && path.exists() + } }; if is_valid { - panel.start_thread_in = start_thread_in; + panel.start_thread_in = start_thread_in.clone(); } else { log::info!( "deserialized start_thread_in {:?} is no longer valid, falling back to LocalProject", @@ -979,6 +1017,7 @@ impl AgentPanel { background_threads: HashMap::default(), new_thread_menu_handle: PopoverMenuHandle::default(), start_thread_in_menu_handle: PopoverMenuHandle::default(), + thread_branch_menu_handle: PopoverMenuHandle::default(), agent_panel_menu_handle: PopoverMenuHandle::default(), agent_navigation_menu_handle: PopoverMenuHandle::default(), agent_navigation_menu: None, @@ -1948,24 +1987,43 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) { - if matches!(action, StartThreadIn::NewWorktree) && !cx.has_flag::() { - return; - } - - let new_target = match *action { + let new_target = match action { StartThreadIn::LocalProject => StartThreadIn::LocalProject, - StartThreadIn::NewWorktree => { + StartThreadIn::NewWorktree { .. } => { + if !cx.has_flag::() { + return; + } if !self.project_has_git_repository(cx) { log::error!( - "set_start_thread_in: cannot use NewWorktree without a git repository" + "set_start_thread_in: cannot use worktree mode without a git repository" ); return; } if self.project.read(cx).is_via_collab() { - log::error!("set_start_thread_in: cannot use NewWorktree in a collab project"); + log::error!( + "set_start_thread_in: cannot use worktree mode in a collab project" + ); return; } - StartThreadIn::NewWorktree + action.clone() + } + StartThreadIn::LinkedWorktree { .. 
} => { + if !cx.has_flag::() { + return; + } + if !self.project_has_git_repository(cx) { + log::error!( + "set_start_thread_in: cannot use LinkedWorktree without a git repository" + ); + return; + } + if self.project.read(cx).is_via_collab() { + log::error!( + "set_start_thread_in: cannot use LinkedWorktree in a collab project" + ); + return; + } + action.clone() } }; self.start_thread_in = new_target; @@ -1977,9 +2035,14 @@ impl AgentPanel { } fn cycle_start_thread_in(&mut self, window: &mut Window, cx: &mut Context) { - let next = match self.start_thread_in { - StartThreadIn::LocalProject => StartThreadIn::NewWorktree, - StartThreadIn::NewWorktree => StartThreadIn::LocalProject, + let next = match &self.start_thread_in { + StartThreadIn::LocalProject => StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + StartThreadIn::NewWorktree { .. } | StartThreadIn::LinkedWorktree { .. } => { + StartThreadIn::LocalProject + } }; self.set_start_thread_in(&next, window, cx); } @@ -1991,7 +2054,10 @@ impl AgentPanel { NewThreadLocation::LocalProject => StartThreadIn::LocalProject, NewThreadLocation::NewWorktree => { if self.project_has_git_repository(cx) { - StartThreadIn::NewWorktree + StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + } } else { StartThreadIn::LocalProject } @@ -2219,15 +2285,39 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) { - if self.start_thread_in == StartThreadIn::NewWorktree { - self.handle_worktree_creation_requested(content, window, cx); - } else { - cx.defer_in(window, move |_this, window, cx| { - thread_view.update(cx, |thread_view, cx| { - let editor = thread_view.message_editor.clone(); - thread_view.send_impl(editor, window, cx); + match &self.start_thread_in { + StartThreadIn::NewWorktree { + worktree_name, + branch_target, + } => { + self.handle_worktree_requested( + content, + WorktreeCreationArgs::New { + 
worktree_name: worktree_name.clone(), + branch_target: branch_target.clone(), + }, + window, + cx, + ); + } + StartThreadIn::LinkedWorktree { path, .. } => { + self.handle_worktree_requested( + content, + WorktreeCreationArgs::Linked { + worktree_path: path.clone(), + }, + window, + cx, + ); + } + StartThreadIn::LocalProject => { + cx.defer_in(window, move |_this, window, cx| { + thread_view.update(cx, |thread_view, cx| { + let editor = thread_view.message_editor.clone(); + thread_view.send_impl(editor, window, cx); + }); }); - }); + } } } @@ -2289,6 +2379,33 @@ impl AgentPanel { (git_repos, non_git_paths) } + fn resolve_worktree_branch_target( + branch_target: &NewWorktreeBranchTarget, + existing_branches: &HashSet, + occupied_branches: &HashSet, + ) -> Result<(String, bool, Option)> { + let generate_branch_name = || -> Result { + let refs: Vec<&str> = existing_branches.iter().map(|s| s.as_str()).collect(); + let mut rng = rand::rng(); + crate::branch_names::generate_branch_name(&refs, &mut rng) + .ok_or_else(|| anyhow!("Failed to generate a unique branch name")) + }; + + match branch_target { + NewWorktreeBranchTarget::CreateBranch { name, from_ref } => { + Ok((name.clone(), false, from_ref.clone())) + } + NewWorktreeBranchTarget::ExistingBranch { name } => { + if occupied_branches.contains(name) { + Ok((generate_branch_name()?, false, Some(name.clone()))) + } else { + Ok((name.clone(), true, None)) + } + } + NewWorktreeBranchTarget::CurrentBranch => Ok((generate_branch_name()?, false, None)), + } + } + /// Kicks off an async git-worktree creation for each repository. Returns: /// /// - `creation_infos`: a vec of `(repo, new_path, receiver)` tuples—the @@ -2297,7 +2414,10 @@ impl AgentPanel { /// later to remap open editor tabs into the new workspace. 
fn start_worktree_creations( git_repos: &[Entity], + worktree_name: Option, branch_name: &str, + use_existing_branch: bool, + start_point: Option, worktree_directory_setting: &str, cx: &mut Context, ) -> Result<( @@ -2311,12 +2431,27 @@ impl AgentPanel { let mut creation_infos = Vec::new(); let mut path_remapping = Vec::new(); + let worktree_name = worktree_name.unwrap_or_else(|| branch_name.to_string()); + for repo in git_repos { let (work_dir, new_path, receiver) = repo.update(cx, |repo, _cx| { let new_path = - repo.path_for_new_linked_worktree(branch_name, worktree_directory_setting)?; - let receiver = - repo.create_worktree(branch_name.to_string(), new_path.clone(), None); + repo.path_for_new_linked_worktree(&worktree_name, worktree_directory_setting)?; + let target = if use_existing_branch { + debug_assert!( + git_repos.len() == 1, + "use_existing_branch should only be true for a single repo" + ); + git::repository::CreateWorktreeTarget::ExistingBranch { + branch_name: branch_name.to_string(), + } + } else { + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: branch_name.to_string(), + base_sha: start_point.clone(), + } + }; + let receiver = repo.create_worktree(target, new_path.clone()); let work_dir = repo.work_directory_abs_path.clone(); anyhow::Ok((work_dir, new_path, receiver)) })?; @@ -2419,9 +2554,10 @@ impl AgentPanel { cx.notify(); } - fn handle_worktree_creation_requested( + fn handle_worktree_requested( &mut self, content: Vec, + args: WorktreeCreationArgs, window: &mut Window, cx: &mut Context, ) { @@ -2437,7 +2573,7 @@ impl AgentPanel { let (git_repos, non_git_paths) = self.classify_worktrees(cx); - if git_repos.is_empty() { + if matches!(args, WorktreeCreationArgs::New { .. 
}) && git_repos.is_empty() { self.set_worktree_creation_error( "No git repositories found in the project".into(), window, @@ -2446,17 +2582,31 @@ impl AgentPanel { return; } - // Kick off branch listing as early as possible so it can run - // concurrently with the remaining synchronous setup work. - let branch_receivers: Vec<_> = git_repos - .iter() - .map(|repo| repo.update(cx, |repo, _cx| repo.branches())) - .collect(); - - let worktree_directory_setting = ProjectSettings::get_global(cx) - .git - .worktree_directory - .clone(); + let (branch_receivers, worktree_receivers, worktree_directory_setting) = + if matches!(args, WorktreeCreationArgs::New { .. }) { + ( + Some( + git_repos + .iter() + .map(|repo| repo.update(cx, |repo, _cx| repo.branches())) + .collect::>(), + ), + Some( + git_repos + .iter() + .map(|repo| repo.update(cx, |repo, _cx| repo.worktrees())) + .collect::>(), + ), + Some( + ProjectSettings::get_global(cx) + .git + .worktree_directory + .clone(), + ), + ) + } else { + (None, None, None) + }; let active_file_path = self.workspace.upgrade().and_then(|workspace| { let workspace = workspace.read(cx); @@ -2476,77 +2626,124 @@ impl AgentPanel { let selected_agent = self.selected_agent(); let task = cx.spawn_in(window, async move |this, cx| { - // Await the branch listings we kicked off earlier. 
- let mut existing_branches = Vec::new(); - for result in futures::future::join_all(branch_receivers).await { - match result { - Ok(Ok(branches)) => { - for branch in branches { - existing_branches.push(branch.name().to_string()); + let (all_paths, path_remapping, has_non_git) = match args { + WorktreeCreationArgs::New { + worktree_name, + branch_target, + } => { + let branch_receivers = branch_receivers + .expect("branch receivers must be prepared for new worktree creation"); + let worktree_receivers = worktree_receivers + .expect("worktree receivers must be prepared for new worktree creation"); + let worktree_directory_setting = worktree_directory_setting + .expect("worktree directory must be prepared for new worktree creation"); + + let mut existing_branches = HashSet::default(); + for result in futures::future::join_all(branch_receivers).await { + match result { + Ok(Ok(branches)) => { + for branch in branches { + existing_branches.insert(branch.name().to_string()); + } + } + Ok(Err(err)) => { + Err::<(), _>(err).log_err(); + } + Err(_) => {} } } - Ok(Err(err)) => { - Err::<(), _>(err).log_err(); + + let mut occupied_branches = HashSet::default(); + for result in futures::future::join_all(worktree_receivers).await { + match result { + Ok(Ok(worktrees)) => { + for worktree in worktrees { + if let Some(branch_name) = worktree.branch_name() { + occupied_branches.insert(branch_name.to_string()); + } + } + } + Ok(Err(err)) => { + Err::<(), _>(err).log_err(); + } + Err(_) => {} + } } - Err(_) => {} - } - } - let existing_branch_refs: Vec<&str> = - existing_branches.iter().map(|s| s.as_str()).collect(); - let mut rng = rand::rng(); - let branch_name = - match crate::branch_names::generate_branch_name(&existing_branch_refs, &mut rng) { - Some(name) => name, - None => { - this.update_in(cx, |this, window, cx| { - this.set_worktree_creation_error( - "Failed to generate a unique branch name".into(), - window, + let (branch_name, use_existing_branch, start_point) = + match 
Self::resolve_worktree_branch_target( + &branch_target, + &existing_branches, + &occupied_branches, + ) { + Ok(target) => target, + Err(err) => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + err.to_string().into(), + window, + cx, + ); + })?; + return anyhow::Ok(()); + } + }; + + let (creation_infos, path_remapping) = + match this.update_in(cx, |_this, _window, cx| { + Self::start_worktree_creations( + &git_repos, + worktree_name, + &branch_name, + use_existing_branch, + start_point, + &worktree_directory_setting, cx, - ); - })?; - return anyhow::Ok(()); - } - }; + ) + }) { + Ok(Ok(result)) => result, + Ok(Err(err)) | Err(err) => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + format!("Failed to validate worktree directory: {err}") + .into(), + window, + cx, + ); + }) + .log_err(); + return anyhow::Ok(()); + } + }; - let (creation_infos, path_remapping) = match this.update_in(cx, |_this, _window, cx| { - Self::start_worktree_creations( - &git_repos, - &branch_name, - &worktree_directory_setting, - cx, - ) - }) { - Ok(Ok(result)) => result, - Ok(Err(err)) | Err(err) => { - this.update_in(cx, |this, window, cx| { - this.set_worktree_creation_error( - format!("Failed to validate worktree directory: {err}").into(), - window, - cx, - ); - }) - .log_err(); - return anyhow::Ok(()); - } - }; + let created_paths = + match Self::await_and_rollback_on_failure(creation_infos, cx).await { + Ok(paths) => paths, + Err(err) => { + this.update_in(cx, |this, window, cx| { + this.set_worktree_creation_error( + format!("{err}").into(), + window, + cx, + ); + })?; + return anyhow::Ok(()); + } + }; - let created_paths = match Self::await_and_rollback_on_failure(creation_infos, cx).await - { - Ok(paths) => paths, - Err(err) => { - this.update_in(cx, |this, window, cx| { - this.set_worktree_creation_error(format!("{err}").into(), window, cx); - })?; - return anyhow::Ok(()); + let mut all_paths = created_paths; + let 
has_non_git = !non_git_paths.is_empty(); + all_paths.extend(non_git_paths.iter().cloned()); + (all_paths, path_remapping, has_non_git) + } + WorktreeCreationArgs::Linked { worktree_path } => { + let mut all_paths = vec![worktree_path]; + let has_non_git = !non_git_paths.is_empty(); + all_paths.extend(non_git_paths.iter().cloned()); + (all_paths, Vec::new(), has_non_git) } }; - let mut all_paths = created_paths; - let has_non_git = !non_git_paths.is_empty(); - all_paths.extend(non_git_paths.iter().cloned()); - let app_state = match workspace.upgrade() { Some(workspace) => cx.update(|_, cx| workspace.read(cx).app_state().clone())?, None => { @@ -2562,7 +2759,7 @@ impl AgentPanel { }; let this_for_error = this.clone(); - if let Err(err) = Self::setup_new_workspace( + if let Err(err) = Self::open_worktree_workspace_and_start_thread( this, all_paths, app_state, @@ -2595,7 +2792,7 @@ impl AgentPanel { })); } - async fn setup_new_workspace( + async fn open_worktree_workspace_and_start_thread( this: WeakEntity, all_paths: Vec, app_state: Arc, @@ -2989,17 +3186,11 @@ impl AgentPanel { fn render_panel_options_menu( &self, - window: &mut Window, + _window: &mut Window, cx: &mut Context, ) -> impl IntoElement { let focus_handle = self.focus_handle(cx); - let full_screen_label = if self.is_zoomed(window, cx) { - "Disable Full Screen" - } else { - "Enable Full Screen" - }; - let conversation_view = match &self.active_view { ActiveView::AgentThread { conversation_view } => Some(conversation_view.clone()), _ => None, @@ -3075,8 +3266,7 @@ impl AgentPanel { .action("Profiles", Box::new(ManageProfiles::default())) .action("Settings", Box::new(OpenSettings)) .separator() - .action("Toggle Threads Sidebar", Box::new(ToggleWorkspaceSidebar)) - .action(full_screen_label, Box::new(ToggleZoom)); + .action("Toggle Threads Sidebar", Box::new(ToggleWorkspaceSidebar)); if has_auth_methods { menu = menu.action("Reauthenticate", Box::new(ReauthenticateAgent)) @@ -3149,25 +3339,15 @@ impl 
AgentPanel { } fn render_start_thread_in_selector(&self, cx: &mut Context) -> impl IntoElement { - use settings::{NewThreadLocation, Settings}; - let focus_handle = self.focus_handle(cx); - let has_git_repo = self.project_has_git_repository(cx); - let is_via_collab = self.project.read(cx).is_via_collab(); - let fs = self.fs.clone(); let is_creating = matches!( self.worktree_creation_status, Some(WorktreeCreationStatus::Creating) ); - let current_target = self.start_thread_in; let trigger_label = self.start_thread_in.label(); - let new_thread_location = AgentSettings::get_global(cx).new_thread_location; - let is_local_default = new_thread_location == NewThreadLocation::LocalProject; - let is_new_worktree_default = new_thread_location == NewThreadLocation::NewWorktree; - let icon = if self.start_thread_in_menu_handle.is_deployed() { IconName::ChevronUp } else { @@ -3178,13 +3358,9 @@ impl AgentPanel { .end_icon(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted)) .disabled(is_creating); - let dock_position = AgentSettings::get_global(cx).dock; - let documentation_side = match dock_position { - settings::DockPosition::Left => DocumentationSide::Right, - settings::DockPosition::Bottom | settings::DockPosition::Right => { - DocumentationSide::Left - } - }; + let project = self.project.clone(); + let current_target = self.start_thread_in.clone(); + let fs = self.fs.clone(); PopoverMenu::new("thread-target-selector") .trigger_with_tooltip(trigger_button, { @@ -3198,89 +3374,66 @@ impl AgentPanel { } }) .menu(move |window, cx| { - let is_local_selected = current_target == StartThreadIn::LocalProject; - let is_new_worktree_selected = current_target == StartThreadIn::NewWorktree; let fs = fs.clone(); + Some(cx.new(|cx| { + ThreadWorktreePicker::new(project.clone(), ¤t_target, fs, window, cx) + })) + }) + .with_handle(self.start_thread_in_menu_handle.clone()) + .anchor(Corner::TopLeft) + .offset(gpui::Point { + x: px(1.0), + y: px(1.0), + }) + } - 
Some(ContextMenu::build(window, cx, move |menu, _window, _cx| { - let new_worktree_disabled = !has_git_repo || is_via_collab; + fn render_new_worktree_branch_selector(&self, cx: &mut Context) -> impl IntoElement { + let is_creating = matches!( + self.worktree_creation_status, + Some(WorktreeCreationStatus::Creating) + ); + let default_branch_label = if self.project.read(cx).repositories(cx).len() > 1 { + SharedString::from("From: current branches") + } else { + self.project + .read(cx) + .active_repository(cx) + .and_then(|repo| { + repo.read(cx) + .branch + .as_ref() + .map(|branch| SharedString::from(format!("From: {}", branch.name()))) + }) + .unwrap_or_else(|| SharedString::from("From: HEAD")) + }; + let trigger_label = self + .start_thread_in + .worktree_branch_label(default_branch_label) + .unwrap_or_else(|| SharedString::from("From: HEAD")); + let icon = if self.thread_branch_menu_handle.is_deployed() { + IconName::ChevronUp + } else { + IconName::ChevronDown + }; + let trigger_button = Button::new("thread-branch-trigger", trigger_label) + .start_icon( + Icon::new(IconName::GitBranch) + .size(IconSize::Small) + .color(Color::Muted), + ) + .end_icon(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted)) + .disabled(is_creating); + let project = self.project.clone(); + let current_target = self.start_thread_in.clone(); - menu.header("Start Thread In…") - .item( - ContextMenuEntry::new("Current Worktree") - .toggleable(IconPosition::End, is_local_selected) - .documentation_aside(documentation_side, move |_| { - HoldForDefault::new(is_local_default) - .more_content(false) - .into_any_element() - }) - .handler({ - let fs = fs.clone(); - move |window, cx| { - if window.modifiers().secondary() { - update_settings_file(fs.clone(), cx, |settings, _| { - settings - .agent - .get_or_insert_default() - .set_new_thread_location( - NewThreadLocation::LocalProject, - ); - }); - } - window.dispatch_action( - Box::new(StartThreadIn::LocalProject), - cx, - ); - } - }), - 
) - .item({ - let entry = ContextMenuEntry::new("New Git Worktree") - .toggleable(IconPosition::End, is_new_worktree_selected) - .disabled(new_worktree_disabled) - .handler({ - let fs = fs.clone(); - move |window, cx| { - if window.modifiers().secondary() { - update_settings_file(fs.clone(), cx, |settings, _| { - settings - .agent - .get_or_insert_default() - .set_new_thread_location( - NewThreadLocation::NewWorktree, - ); - }); - } - window.dispatch_action( - Box::new(StartThreadIn::NewWorktree), - cx, - ); - } - }); - - if new_worktree_disabled { - entry.documentation_aside(documentation_side, move |_| { - let reason = if !has_git_repo { - "No git repository found in this project." - } else { - "Not available for remote/collab projects yet." - }; - Label::new(reason) - .color(Color::Muted) - .size(LabelSize::Small) - .into_any_element() - }) - } else { - entry.documentation_aside(documentation_side, move |_| { - HoldForDefault::new(is_new_worktree_default) - .more_content(false) - .into_any_element() - }) - } - }) + PopoverMenu::new("thread-branch-selector") + .trigger_with_tooltip(trigger_button, Tooltip::text("Choose Worktree Branch…")) + .menu(move |window, cx| { + Some(cx.new(|cx| { + ThreadBranchPicker::new(project.clone(), ¤t_target, window, cx) })) }) - .with_handle(self.start_thread_in_menu_handle.clone()) + .with_handle(self.thread_branch_menu_handle.clone()) .anchor(Corner::TopLeft) .offset(gpui::Point { x: px(1.0), @@ -3549,21 +3702,37 @@ impl AgentPanel { ); let is_full_screen = self.is_zoomed(window, cx); + let full_screen_button = if is_full_screen { + IconButton::new("disable-full-screen", IconName::Minimize) + .icon_size(IconSize::Small) + .tooltip(move |_, cx| Tooltip::for_action("Disable Full Screen", &ToggleZoom, cx)) + .on_click(cx.listener(move |this, _, window, cx| { + this.toggle_zoom(&ToggleZoom, window, cx); + })) + } else { + IconButton::new("enable-full-screen", IconName::Maximize) + .icon_size(IconSize::Small) + .tooltip(move |_, cx| 
Tooltip::for_action("Enable Full Screen", &ToggleZoom, cx)) + .on_click(cx.listener(move |this, _, window, cx| { + this.toggle_zoom(&ToggleZoom, window, cx); + })) + }; let use_v2_empty_toolbar = has_v2_flag && is_empty_state && !is_in_history_or_config; + let max_content_width = AgentSettings::get_global(cx).max_content_width; + let base_container = h_flex() - .id("agent-panel-toolbar") - .h(Tab::container_height(cx)) - .max_w_full() + .size_full() + // TODO: This is only until we remove Agent settings from the panel. + .when(!is_in_history_or_config, |this| { + this.max_w(max_content_width).mx_auto() + }) .flex_none() .justify_between() - .gap_2() - .bg(cx.theme().colors().tab_bar_background) - .border_b_1() - .border_color(cx.theme().colors().border); + .gap_2(); - if use_v2_empty_toolbar { + let toolbar_content = if use_v2_empty_toolbar { let (chevron_icon, icon_color, label_color) = if self.new_thread_menu_handle.is_deployed() { (IconName::ChevronUp, Color::Accent, Color::Accent) @@ -3621,6 +3790,14 @@ impl AgentPanel { .when( has_visible_worktrees && self.project_has_git_repository(cx), |this| this.child(self.render_start_thread_in_selector(cx)), + ) + .when( + has_v2_flag + && matches!( + self.start_thread_in, + StartThreadIn::NewWorktree { .. 
} + ), + |this| this.child(self.render_new_worktree_branch_selector(cx)), ), ) .child( @@ -3637,20 +3814,7 @@ impl AgentPanel { cx, )) }) - .when(is_full_screen, |this| { - this.child( - IconButton::new("disable-full-screen", IconName::Minimize) - .icon_size(IconSize::Small) - .tooltip(move |_, cx| { - Tooltip::for_action("Disable Full Screen", &ToggleZoom, cx) - }) - .on_click({ - cx.listener(move |_, _, window, cx| { - window.dispatch_action(ToggleZoom.boxed_clone(), cx); - }) - }), - ) - }) + .child(full_screen_button) .child(self.render_panel_options_menu(window, cx)), ) .into_any_element() @@ -3703,24 +3867,21 @@ impl AgentPanel { cx, )) }) - .when(is_full_screen, |this| { - this.child( - IconButton::new("disable-full-screen", IconName::Minimize) - .icon_size(IconSize::Small) - .tooltip(move |_, cx| { - Tooltip::for_action("Disable Full Screen", &ToggleZoom, cx) - }) - .on_click({ - cx.listener(move |_, _, window, cx| { - window.dispatch_action(ToggleZoom.boxed_clone(), cx); - }) - }), - ) - }) + .child(full_screen_button) .child(self.render_panel_options_menu(window, cx)), ) .into_any_element() - } + }; + + h_flex() + .id("agent-panel-toolbar") + .h(Tab::container_height(cx)) + .flex_shrink_0() + .max_w_full() + .bg(cx.theme().colors().tab_bar_background) + .border_b_1() + .border_color(cx.theme().colors().border) + .child(toolbar_content) } fn render_worktree_creation_status(&self, cx: &mut Context) -> Option { @@ -5265,13 +5426,23 @@ mod tests { // Change thread target to NewWorktree. 
panel.update_in(cx, |panel, window, cx| { - panel.set_start_thread_in(&StartThreadIn::NewWorktree, window, cx); + panel.set_start_thread_in( + &StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + window, + cx, + ); }); panel.read_with(cx, |panel, _cx| { assert_eq!( *panel.start_thread_in(), - StartThreadIn::NewWorktree, + StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, "thread target should be NewWorktree after set_thread_target" ); }); @@ -5289,7 +5460,10 @@ mod tests { loaded_panel.read_with(cx, |panel, _cx| { assert_eq!( *panel.start_thread_in(), - StartThreadIn::NewWorktree, + StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, "thread target should survive serialization round-trip" ); }); @@ -5420,6 +5594,53 @@ mod tests { ); } + #[test] + fn test_resolve_worktree_branch_target() { + let existing_branches = HashSet::from_iter([ + "main".to_string(), + "feature".to_string(), + "origin/main".to_string(), + ]); + + let resolved = AgentPanel::resolve_worktree_branch_target( + &NewWorktreeBranchTarget::CreateBranch { + name: "new-branch".to_string(), + from_ref: Some("main".to_string()), + }, + &existing_branches, + &HashSet::from_iter(["main".to_string()]), + ) + .unwrap(); + assert_eq!( + resolved, + ("new-branch".to_string(), false, Some("main".to_string())) + ); + + let resolved = AgentPanel::resolve_worktree_branch_target( + &NewWorktreeBranchTarget::ExistingBranch { + name: "feature".to_string(), + }, + &existing_branches, + &HashSet::default(), + ) + .unwrap(); + assert_eq!(resolved, ("feature".to_string(), true, None)); + + let resolved = AgentPanel::resolve_worktree_branch_target( + &NewWorktreeBranchTarget::ExistingBranch { + name: "main".to_string(), + }, + &existing_branches, + &HashSet::from_iter(["main".to_string()]), + ) + .unwrap(); + assert_eq!(resolved.1, false); + 
assert_eq!(resolved.2, Some("main".to_string())); + assert_ne!(resolved.0, "main"); + assert!(existing_branches.contains("main")); + assert!(!existing_branches.contains(&resolved.0)); + } + #[gpui::test] async fn test_worktree_creation_preserves_selected_agent(cx: &mut TestAppContext) { init_test(cx); @@ -5513,7 +5734,14 @@ mod tests { panel.selected_agent = Agent::Custom { id: CODEX_ID.into(), }; - panel.set_start_thread_in(&StartThreadIn::NewWorktree, window, cx); + panel.set_start_thread_in( + &StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + window, + cx, + ); }); // Verify the panel has the Codex agent selected. @@ -5532,7 +5760,15 @@ mod tests { "Hello from test", ))]; panel.update_in(cx, |panel, window, cx| { - panel.handle_worktree_creation_requested(content, window, cx); + panel.handle_worktree_requested( + content, + WorktreeCreationArgs::New { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + window, + cx, + ); }); // Let the async worktree creation + workspace setup complete. 
diff --git a/crates/agent_ui/src/agent_registry_ui.rs b/crates/agent_ui/src/agent_registry_ui.rs index 78b4e3a5a3965c72b96d4ec201139b1d8e510fb2..e19afdecc390268cefbd7be4e5d0759aa2a29c19 100644 --- a/crates/agent_ui/src/agent_registry_ui.rs +++ b/crates/agent_ui/src/agent_registry_ui.rs @@ -382,7 +382,7 @@ impl AgentRegistryPage { self.install_button(agent, install_status, supports_current_platform, cx); let repository_button = agent.repository().map(|repository| { - let repository_for_tooltip: SharedString = repository.to_string().into(); + let repository_for_tooltip = repository.clone(); let repository_for_click = repository.to_string(); IconButton::new( diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 5cff5bfc38d4512d659d919c6e7c4ff02fcc0caf..429bc184f5d889990599c196910ae8d0feb28da1 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -28,13 +28,16 @@ mod terminal_codegen; mod terminal_inline_assistant; #[cfg(any(test, feature = "test-support"))] pub mod test_support; +mod thread_branch_picker; mod thread_history; mod thread_history_view; mod thread_import; pub mod thread_metadata_store; +mod thread_worktree_picker; pub mod threads_archive_view; mod ui; +use std::path::PathBuf; use std::rc::Rc; use std::sync::Arc; @@ -314,16 +317,42 @@ impl Agent { } } +/// Describes which branch to use when creating a new git worktree. +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case", tag = "kind")] +pub enum NewWorktreeBranchTarget { + /// Create a new randomly named branch from the current HEAD. + /// Will match worktree name if the newly created worktree was also randomly named. + #[default] + CurrentBranch, + /// Check out an existing branch, or create a new branch from it if it's + /// already occupied by another worktree. + ExistingBranch { name: String }, + /// Create a new branch with an explicit name, optionally from a specific ref. 
+ CreateBranch { + name: String, + #[serde(default)] + from_ref: Option, + }, +} + /// Sets where new threads will run. -#[derive( - Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Action, -)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Action)] #[action(namespace = agent)] #[serde(rename_all = "snake_case", tag = "kind")] pub enum StartThreadIn { #[default] LocalProject, - NewWorktree, + NewWorktree { + /// When this is None, Zed will randomly generate a worktree name + /// otherwise, the provided name will be used. + #[serde(default)] + worktree_name: Option, + #[serde(default)] + branch_target: NewWorktreeBranchTarget, + }, + /// A linked worktree that already exists on disk. + LinkedWorktree { path: PathBuf, display_name: String }, } /// Content to initialize new external agent with. @@ -495,7 +524,6 @@ pub fn init( defaults.collaboration_panel.get_or_insert_default().dock = Some(DockPosition::Right); defaults.git_panel.get_or_insert_default().dock = Some(DockPosition::Right); - defaults.notification_panel.get_or_insert_default().button = Some(false); } else { defaults.agent.get_or_insert_default().dock = Some(DockPosition::Right); defaults.project_panel.get_or_insert_default().dock = Some(DockSide::Left); @@ -503,7 +531,6 @@ pub fn init( defaults.collaboration_panel.get_or_insert_default().dock = Some(DockPosition::Left); defaults.git_panel.get_or_insert_default().dock = Some(DockPosition::Left); - defaults.notification_panel.get_or_insert_default().button = Some(true); } }); }); @@ -713,6 +740,7 @@ mod tests { flexible: true, default_width: px(300.), default_height: px(600.), + max_content_width: px(850.), default_model: None, inline_assistant_model: None, inline_assistant_use_streaming_tools: false, diff --git a/crates/agent_ui/src/conversation_view/thread_view.rs b/crates/agent_ui/src/conversation_view/thread_view.rs index 
685621eb3c93632f1e7410bbbad22b623d5e18c7..27ebadade8047db5f2b4de63c5c3731708d9af59 100644 --- a/crates/agent_ui/src/conversation_view/thread_view.rs +++ b/crates/agent_ui/src/conversation_view/thread_view.rs @@ -869,7 +869,10 @@ impl ThreadView { .upgrade() .and_then(|workspace| workspace.read(cx).panel::(cx)) .is_some_and(|panel| { - panel.read(cx).start_thread_in() == &StartThreadIn::NewWorktree + !matches!( + panel.read(cx).start_thread_in(), + StartThreadIn::LocalProject + ) }); if intercept_first_send { @@ -3011,14 +3014,12 @@ impl ThreadView { let is_done = thread.read(cx).status() == ThreadStatus::Idle; let is_canceled_or_failed = self.is_subagent_canceled_or_failed(cx); + let max_content_width = AgentSettings::get_global(cx).max_content_width; + Some( h_flex() - .h(Tab::container_height(cx)) - .pl_2() - .pr_1p5() .w_full() - .justify_between() - .gap_1() + .h(Tab::container_height(cx)) .border_b_1() .when(is_done && is_canceled_or_failed, |this| { this.border_dashed() @@ -3027,50 +3028,61 @@ impl ThreadView { .bg(cx.theme().colors().editor_background.opacity(0.2)) .child( h_flex() - .flex_1() - .gap_2() + .size_full() + .max_w(max_content_width) + .mx_auto() + .pl_2() + .pr_1() + .flex_shrink_0() + .justify_between() + .gap_1() .child( - Icon::new(IconName::ForwardArrowUp) - .size(IconSize::Small) - .color(Color::Muted), + h_flex() + .flex_1() + .gap_2() + .child( + Icon::new(IconName::ForwardArrowUp) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child(self.title_editor.clone()) + .when(is_done && is_canceled_or_failed, |this| { + this.child(Icon::new(IconName::Close).color(Color::Error)) + }) + .when(is_done && !is_canceled_or_failed, |this| { + this.child(Icon::new(IconName::Check).color(Color::Success)) + }), ) - .child(self.title_editor.clone()) - .when(is_done && is_canceled_or_failed, |this| { - this.child(Icon::new(IconName::Close).color(Color::Error)) - }) - .when(is_done && !is_canceled_or_failed, |this| { - 
this.child(Icon::new(IconName::Check).color(Color::Success)) - }), - ) - .child( - h_flex() - .gap_0p5() - .when(!is_done, |this| { - this.child( - IconButton::new("stop_subagent", IconName::Stop) - .icon_size(IconSize::Small) - .icon_color(Color::Error) - .tooltip(Tooltip::text("Stop Subagent")) - .on_click(move |_, _, cx| { - thread.update(cx, |thread, cx| { - thread.cancel(cx).detach(); - }); - }), - ) - }) .child( - IconButton::new("minimize_subagent", IconName::Minimize) - .icon_size(IconSize::Small) - .tooltip(Tooltip::text("Minimize Subagent")) - .on_click(move |_, window, cx| { - let _ = server_view.update(cx, |server_view, cx| { - server_view.navigate_to_session( - parent_session_id.clone(), - window, - cx, - ); - }); - }), + h_flex() + .gap_0p5() + .when(!is_done, |this| { + this.child( + IconButton::new("stop_subagent", IconName::Stop) + .icon_size(IconSize::Small) + .icon_color(Color::Error) + .tooltip(Tooltip::text("Stop Subagent")) + .on_click(move |_, _, cx| { + thread.update(cx, |thread, cx| { + thread.cancel(cx).detach(); + }); + }), + ) + }) + .child( + IconButton::new("minimize_subagent", IconName::Dash) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Minimize Subagent")) + .on_click(move |_, window, cx| { + let _ = server_view.update(cx, |server_view, cx| { + server_view.navigate_to_session( + parent_session_id.clone(), + window, + cx, + ); + }); + }), + ), ), ), ) @@ -3096,6 +3108,8 @@ impl ThreadView { (IconName::Maximize, "Expand Message Editor") }; + let max_content_width = AgentSettings::get_global(cx).max_content_width; + v_flex() .on_action(cx.listener(Self::expand_message_editor)) .p_2() @@ -3110,73 +3124,80 @@ impl ThreadView { }) .child( v_flex() - .relative() - .size_full() - .when(v2_empty_state, |this| this.flex_1()) - .pt_1() - .pr_2p5() - .child(self.message_editor.clone()) - .when(!v2_empty_state, |this| { - this.child( - h_flex() - .absolute() - .top_0() - .right_0() - .opacity(0.5) - .hover(|this| this.opacity(1.0)) - 
.child( - IconButton::new("toggle-height", expand_icon) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .tooltip({ - move |_window, cx| { - Tooltip::for_action_in( - expand_tooltip, - &ExpandMessageEditor, - &focus_handle, - cx, - ) - } - }) - .on_click(cx.listener(|this, _, window, cx| { - this.expand_message_editor( - &ExpandMessageEditor, - window, - cx, - ); - })), - ), - ) - }), - ) - .child( - h_flex() - .flex_none() - .flex_wrap() - .justify_between() + .flex_1() + .w_full() + .max_w(max_content_width) + .mx_auto() .child( - h_flex() - .gap_0p5() - .child(self.render_add_context_button(cx)) - .child(self.render_follow_toggle(cx)) - .children(self.render_fast_mode_control(cx)) - .children(self.render_thinking_control(cx)), + v_flex() + .relative() + .size_full() + .when(v2_empty_state, |this| this.flex_1()) + .pt_1() + .pr_2p5() + .child(self.message_editor.clone()) + .when(!v2_empty_state, |this| { + this.child( + h_flex() + .absolute() + .top_0() + .right_0() + .opacity(0.5) + .hover(|this| this.opacity(1.0)) + .child( + IconButton::new("toggle-height", expand_icon) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip({ + move |_window, cx| { + Tooltip::for_action_in( + expand_tooltip, + &ExpandMessageEditor, + &focus_handle, + cx, + ) + } + }) + .on_click(cx.listener(|this, _, window, cx| { + this.expand_message_editor( + &ExpandMessageEditor, + window, + cx, + ); + })), + ), + ) + }), ) .child( h_flex() - .gap_1() - .children(self.render_token_usage(cx)) - .children(self.profile_selector.clone()) - .map(|this| { - // Either config_options_view OR (mode_selector + model_selector) - match self.config_options_view.clone() { - Some(config_view) => this.child(config_view), - None => this - .children(self.mode_selector.clone()) - .children(self.model_selector.clone()), - } - }) - .child(self.render_send_button(cx)), + .flex_none() + .flex_wrap() + .justify_between() + .child( + h_flex() + .gap_0p5() + 
.child(self.render_add_context_button(cx)) + .child(self.render_follow_toggle(cx)) + .children(self.render_fast_mode_control(cx)) + .children(self.render_thinking_control(cx)), + ) + .child( + h_flex() + .gap_1() + .children(self.render_token_usage(cx)) + .children(self.profile_selector.clone()) + .map(|this| { + // Either config_options_view OR (mode_selector + model_selector) + match self.config_options_view.clone() { + Some(config_view) => this.child(config_view), + None => this + .children(self.mode_selector.clone()) + .children(self.model_selector.clone()), + } + }) + .child(self.render_send_button(cx)), + ), ), ) .into_any() @@ -8556,8 +8577,12 @@ impl Render for ThreadView { let has_messages = self.list_state.item_count() > 0; let v2_empty_state = cx.has_flag::() && !has_messages; + let max_content_width = AgentSettings::get_global(cx).max_content_width; + let conversation = v_flex() - .when(!v2_empty_state, |this| this.flex_1()) + .mx_auto() + .max_w(max_content_width) + .when(!v2_empty_state, |this| this.flex_1().size_full()) .map(|this| { let this = this.when(self.resumed_without_history, |this| { this.child(Self::render_resume_notice(cx)) diff --git a/crates/agent_ui/src/mention_set.rs b/crates/agent_ui/src/mention_set.rs index 1b2ec0ad2fd460b4eec5a8b757bdd3058d4a3704..880257e3f942bf71d1d51b1e661d911474aa786b 100644 --- a/crates/agent_ui/src/mention_set.rs +++ b/crates/agent_ui/src/mention_set.rs @@ -18,7 +18,7 @@ use gpui::{ use http_client::{AsyncBody, HttpClientWithUrl}; use itertools::Either; use language::Buffer; -use language_model::LanguageModelImage; +use language_model::{LanguageModelImage, LanguageModelImageExt}; use multi_buffer::MultiBufferRow; use postage::stream::Stream as _; use project::{Project, ProjectItem, ProjectPath, Worktree}; diff --git a/crates/agent_ui/src/thread_branch_picker.rs b/crates/agent_ui/src/thread_branch_picker.rs new file mode 100644 index 
0000000000000000000000000000000000000000..d69cbb4a60054ad83d767928c880f3a43caef4f1 --- /dev/null +++ b/crates/agent_ui/src/thread_branch_picker.rs @@ -0,0 +1,695 @@ +use std::collections::{HashMap, HashSet}; + +use collections::HashSet as CollectionsHashSet; +use std::path::PathBuf; +use std::sync::Arc; + +use fuzzy::StringMatchCandidate; +use git::repository::Branch as GitBranch; +use gpui::{ + App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, + ParentElement, Render, SharedString, Styled, Task, Window, rems, +}; +use picker::{Picker, PickerDelegate, PickerEditorPosition}; +use project::Project; +use ui::{ + HighlightedLabel, Icon, IconName, Label, LabelCommon, ListItem, ListItemSpacing, Tooltip, + prelude::*, +}; +use util::ResultExt as _; + +use crate::{NewWorktreeBranchTarget, StartThreadIn}; + +pub(crate) struct ThreadBranchPicker { + picker: Entity>, + focus_handle: FocusHandle, + _subscription: gpui::Subscription, +} + +impl ThreadBranchPicker { + pub fn new( + project: Entity, + current_target: &StartThreadIn, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let project_worktree_paths: HashSet = project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).abs_path().to_path_buf()) + .collect(); + + let has_multiple_repositories = project.read(cx).repositories(cx).len() > 1; + let current_branch_name = project + .read(cx) + .active_repository(cx) + .and_then(|repo| { + repo.read(cx) + .branch + .as_ref() + .map(|branch| branch.name().to_string()) + }) + .unwrap_or_else(|| "HEAD".to_string()); + + let repository = if has_multiple_repositories { + None + } else { + project.read(cx).active_repository(cx) + }; + let branches_request = repository + .clone() + .map(|repo| repo.update(cx, |repo, _| repo.branches())); + let default_branch_request = repository + .clone() + .map(|repo| repo.update(cx, |repo, _| repo.default_branch(false))); + let worktrees_request = repository.map(|repo| 
repo.update(cx, |repo, _| repo.worktrees())); + + let (worktree_name, branch_target) = match current_target { + StartThreadIn::NewWorktree { + worktree_name, + branch_target, + } => (worktree_name.clone(), branch_target.clone()), + _ => (None, NewWorktreeBranchTarget::default()), + }; + + let delegate = ThreadBranchPickerDelegate { + matches: vec![ThreadBranchEntry::CurrentBranch], + all_branches: None, + occupied_branches: None, + selected_index: 0, + worktree_name, + branch_target, + project_worktree_paths, + current_branch_name, + default_branch_name: None, + has_multiple_repositories, + }; + + let picker = cx.new(|cx| { + Picker::list(delegate, window, cx) + .list_measure_all() + .modal(false) + .max_height(Some(rems(20.).into())) + }); + + let focus_handle = picker.focus_handle(cx); + + if let (Some(branches_request), Some(default_branch_request), Some(worktrees_request)) = + (branches_request, default_branch_request, worktrees_request) + { + let picker_handle = picker.downgrade(); + cx.spawn_in(window, async move |_this, cx| { + let branches = branches_request.await??; + let default_branch = default_branch_request.await.ok().and_then(Result::ok).flatten(); + let worktrees = worktrees_request.await??; + + let remote_upstreams: CollectionsHashSet<_> = branches + .iter() + .filter_map(|branch| { + branch + .upstream + .as_ref() + .filter(|upstream| upstream.is_remote()) + .map(|upstream| upstream.ref_name.clone()) + }) + .collect(); + + let mut occupied_branches = HashMap::new(); + for worktree in worktrees { + let Some(branch_name) = worktree.branch_name().map(ToOwned::to_owned) else { + continue; + }; + + let reason = if picker_handle + .read_with(cx, |picker, _| { + picker + .delegate + .project_worktree_paths + .contains(&worktree.path) + }) + .unwrap_or(false) + { + format!( + "This branch is already checked out in the current project worktree at {}.", + worktree.path.display() + ) + } else { + format!( + "This branch is already checked out in a linked 
worktree at {}.", + worktree.path.display() + ) + }; + + occupied_branches.insert(branch_name, reason); + } + + let mut all_branches: Vec<_> = branches + .into_iter() + .filter(|branch| !remote_upstreams.contains(&branch.ref_name)) + .collect(); + all_branches.sort_by_key(|branch| { + ( + branch.is_remote(), + !branch.is_head, + branch + .most_recent_commit + .as_ref() + .map(|commit| 0 - commit.commit_timestamp), + ) + }); + + picker_handle.update_in(cx, |picker, window, cx| { + picker.delegate.all_branches = Some(all_branches); + picker.delegate.occupied_branches = Some(occupied_branches); + picker.delegate.default_branch_name = default_branch.map(|branch| branch.to_string()); + picker.refresh(window, cx); + })?; + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + + let subscription = cx.subscribe(&picker, |_, _, _, cx| { + cx.emit(DismissEvent); + }); + + Self { + picker, + focus_handle, + _subscription: subscription, + } + } +} + +impl Focusable for ThreadBranchPicker { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for ThreadBranchPicker {} + +impl Render for ThreadBranchPicker { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .w(rems(22.)) + .elevation_3(cx) + .child(self.picker.clone()) + .on_mouse_down_out(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })) + } +} + +#[derive(Clone)] +enum ThreadBranchEntry { + CurrentBranch, + DefaultBranch, + ExistingBranch { + branch: GitBranch, + positions: Vec, + occupied_reason: Option, + }, + CreateNamed { + name: String, + }, +} + +pub(crate) struct ThreadBranchPickerDelegate { + matches: Vec, + all_branches: Option>, + occupied_branches: Option>, + selected_index: usize, + worktree_name: Option, + branch_target: NewWorktreeBranchTarget, + project_worktree_paths: HashSet, + current_branch_name: String, + default_branch_name: Option, + has_multiple_repositories: bool, +} + +impl 
ThreadBranchPickerDelegate { + fn new_worktree_action(&self, branch_target: NewWorktreeBranchTarget) -> StartThreadIn { + StartThreadIn::NewWorktree { + worktree_name: self.worktree_name.clone(), + branch_target, + } + } + + fn selected_entry_name(&self) -> Option<&str> { + match &self.branch_target { + NewWorktreeBranchTarget::CurrentBranch => None, + NewWorktreeBranchTarget::ExistingBranch { name } => Some(name), + NewWorktreeBranchTarget::CreateBranch { + from_ref: Some(from_ref), + .. + } => Some(from_ref), + NewWorktreeBranchTarget::CreateBranch { name, .. } => Some(name), + } + } + + fn prefer_create_entry(&self) -> bool { + matches!( + &self.branch_target, + NewWorktreeBranchTarget::CreateBranch { from_ref: None, .. } + ) + } + + fn fixed_matches(&self) -> Vec { + let mut matches = vec![ThreadBranchEntry::CurrentBranch]; + if !self.has_multiple_repositories + && self + .default_branch_name + .as_ref() + .is_some_and(|default_branch_name| default_branch_name != &self.current_branch_name) + { + matches.push(ThreadBranchEntry::DefaultBranch); + } + matches + } + + fn current_branch_label(&self) -> SharedString { + if self.has_multiple_repositories { + SharedString::from("New branch from: current branches") + } else { + SharedString::from(format!("New branch from: {}", self.current_branch_name)) + } + } + + fn default_branch_label(&self) -> Option { + let default_branch_name = self + .default_branch_name + .as_ref() + .filter(|name| *name != &self.current_branch_name)?; + let is_occupied = self + .occupied_branches + .as_ref() + .is_some_and(|occupied| occupied.contains_key(default_branch_name)); + let prefix = if is_occupied { + "New branch from" + } else { + "From" + }; + Some(SharedString::from(format!( + "{prefix}: {default_branch_name}" + ))) + } + + fn branch_label_prefix(&self, branch_name: &str) -> &'static str { + let is_occupied = self + .occupied_branches + .as_ref() + .is_some_and(|occupied| occupied.contains_key(branch_name)); + if is_occupied { + 
"New branch from: " + } else { + "From: " + } + } + + fn sync_selected_index(&mut self) { + let selected_entry_name = self.selected_entry_name().map(ToOwned::to_owned); + let prefer_create = self.prefer_create_entry(); + + if prefer_create { + if let Some(ref selected_entry_name) = selected_entry_name { + if let Some(index) = self.matches.iter().position(|entry| { + matches!( + entry, + ThreadBranchEntry::CreateNamed { name } if name == selected_entry_name + ) + }) { + self.selected_index = index; + return; + } + } + } else if let Some(ref selected_entry_name) = selected_entry_name { + if selected_entry_name == &self.current_branch_name { + if let Some(index) = self + .matches + .iter() + .position(|entry| matches!(entry, ThreadBranchEntry::CurrentBranch)) + { + self.selected_index = index; + return; + } + } + + if self + .default_branch_name + .as_ref() + .is_some_and(|default_branch_name| default_branch_name == selected_entry_name) + { + if let Some(index) = self + .matches + .iter() + .position(|entry| matches!(entry, ThreadBranchEntry::DefaultBranch)) + { + self.selected_index = index; + return; + } + } + + if let Some(index) = self.matches.iter().position(|entry| { + matches!( + entry, + ThreadBranchEntry::ExistingBranch { branch, .. } + if branch.name() == selected_entry_name.as_str() + ) + }) { + self.selected_index = index; + return; + } + } + + if self.matches.len() > 1 + && self + .matches + .iter() + .skip(1) + .all(|entry| matches!(entry, ThreadBranchEntry::CreateNamed { .. 
})) + { + self.selected_index = 1; + return; + } + + self.selected_index = 0; + } +} + +impl PickerDelegate for ThreadBranchPickerDelegate { + type ListItem = ListItem; + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Search branches…".into() + } + + fn editor_position(&self) -> PickerEditorPosition { + PickerEditorPosition::Start + } + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) { + self.selected_index = ix; + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + if self.has_multiple_repositories { + let mut matches = self.fixed_matches(); + + if query.is_empty() { + if let Some(name) = self.selected_entry_name().map(ToOwned::to_owned) { + if self.prefer_create_entry() { + matches.push(ThreadBranchEntry::CreateNamed { name }); + } + } + } else { + matches.push(ThreadBranchEntry::CreateNamed { + name: query.replace(' ', "-"), + }); + } + + self.matches = matches; + self.sync_selected_index(); + return Task::ready(()); + } + + let Some(all_branches) = self.all_branches.clone() else { + self.matches = self.fixed_matches(); + self.selected_index = 0; + return Task::ready(()); + }; + let occupied_branches = self.occupied_branches.clone().unwrap_or_default(); + + if query.is_empty() { + let mut matches = self.fixed_matches(); + for branch in all_branches.into_iter().filter(|branch| { + branch.name() != self.current_branch_name + && self + .default_branch_name + .as_ref() + .is_none_or(|default_branch_name| branch.name() != default_branch_name) + }) { + matches.push(ThreadBranchEntry::ExistingBranch { + occupied_reason: occupied_branches.get(branch.name()).cloned(), + branch, + positions: Vec::new(), + }); + } + + if let Some(selected_entry_name) = 
self.selected_entry_name().map(ToOwned::to_owned) { + let has_existing = matches.iter().any(|entry| { + matches!( + entry, + ThreadBranchEntry::ExistingBranch { branch, .. } + if branch.name() == selected_entry_name + ) + }); + if self.prefer_create_entry() && !has_existing { + matches.push(ThreadBranchEntry::CreateNamed { + name: selected_entry_name, + }); + } + } + + self.matches = matches; + self.sync_selected_index(); + return Task::ready(()); + } + + let candidates: Vec<_> = all_branches + .iter() + .enumerate() + .map(|(ix, branch)| StringMatchCandidate::new(ix, branch.name())) + .collect(); + let executor = cx.background_executor().clone(); + let query_clone = query.clone(); + let normalized_query = query.replace(' ', "-"); + + let task = cx.background_executor().spawn(async move { + fuzzy::match_strings( + &candidates, + &query_clone, + true, + true, + 10000, + &Default::default(), + executor, + ) + .await + }); + + let all_branches_clone = all_branches; + cx.spawn_in(window, async move |picker, cx| { + let fuzzy_matches = task.await; + + picker + .update_in(cx, |picker, _window, cx| { + let mut matches = picker.delegate.fixed_matches(); + + for candidate in &fuzzy_matches { + let branch = all_branches_clone[candidate.candidate_id].clone(); + if branch.name() == picker.delegate.current_branch_name + || picker.delegate.default_branch_name.as_ref().is_some_and( + |default_branch_name| branch.name() == default_branch_name, + ) + { + continue; + } + let occupied_reason = occupied_branches.get(branch.name()).cloned(); + matches.push(ThreadBranchEntry::ExistingBranch { + branch, + positions: candidate.positions.clone(), + occupied_reason, + }); + } + + if fuzzy_matches.is_empty() { + matches.push(ThreadBranchEntry::CreateNamed { + name: normalized_query.clone(), + }); + } + + picker.delegate.matches = matches; + if let Some(index) = + picker.delegate.matches.iter().position(|entry| { + matches!(entry, ThreadBranchEntry::ExistingBranch { .. 
}) + }) + { + picker.delegate.selected_index = index; + } else if !fuzzy_matches.is_empty() { + picker.delegate.selected_index = 0; + } else if let Some(index) = + picker.delegate.matches.iter().position(|entry| { + matches!(entry, ThreadBranchEntry::CreateNamed { .. }) + }) + { + picker.delegate.selected_index = index; + } else { + picker.delegate.sync_selected_index(); + } + cx.notify(); + }) + .log_err(); + }) + } + + fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + let Some(entry) = self.matches.get(self.selected_index) else { + return; + }; + + match entry { + ThreadBranchEntry::CurrentBranch => { + window.dispatch_action( + Box::new(self.new_worktree_action(NewWorktreeBranchTarget::CurrentBranch)), + cx, + ); + } + ThreadBranchEntry::DefaultBranch => { + let Some(default_branch_name) = self.default_branch_name.clone() else { + return; + }; + window.dispatch_action( + Box::new( + self.new_worktree_action(NewWorktreeBranchTarget::ExistingBranch { + name: default_branch_name, + }), + ), + cx, + ); + } + ThreadBranchEntry::ExistingBranch { branch, .. 
} => { + let branch_target = if branch.is_remote() { + let branch_name = branch + .ref_name + .as_ref() + .strip_prefix("refs/remotes/") + .and_then(|stripped| stripped.split_once('/').map(|(_, name)| name)) + .unwrap_or(branch.name()) + .to_string(); + NewWorktreeBranchTarget::CreateBranch { + name: branch_name, + from_ref: Some(branch.name().to_string()), + } + } else { + NewWorktreeBranchTarget::ExistingBranch { + name: branch.name().to_string(), + } + }; + window.dispatch_action(Box::new(self.new_worktree_action(branch_target)), cx); + } + ThreadBranchEntry::CreateNamed { name } => { + window.dispatch_action( + Box::new( + self.new_worktree_action(NewWorktreeBranchTarget::CreateBranch { + name: name.clone(), + from_ref: None, + }), + ), + cx, + ); + } + } + + cx.emit(DismissEvent); + } + + fn dismissed(&mut self, _window: &mut Window, _cx: &mut Context>) {} + + fn separators_after_indices(&self) -> Vec { + let fixed_count = self.fixed_matches().len(); + if self.matches.len() > fixed_count { + vec![fixed_count - 1] + } else { + Vec::new() + } + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _window: &mut Window, + _cx: &mut Context>, + ) -> Option { + let entry = self.matches.get(ix)?; + + match entry { + ThreadBranchEntry::CurrentBranch => Some( + ListItem::new("current-branch") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::GitBranch).color(Color::Muted)) + .child(Label::new(self.current_branch_label())), + ), + ThreadBranchEntry::DefaultBranch => Some( + ListItem::new("default-branch") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::GitBranch).color(Color::Muted)) + .child(Label::new(self.default_branch_label()?)), + ), + ThreadBranchEntry::ExistingBranch { + branch, + positions, + occupied_reason, + } => { + let prefix = self.branch_label_prefix(branch.name()); + let branch_name = branch.name().to_string(); + let 
full_label = format!("{prefix}{branch_name}"); + let adjusted_positions: Vec = + positions.iter().map(|&p| p + prefix.len()).collect(); + + let item = ListItem::new(SharedString::from(format!("branch-{ix}"))) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::GitBranch).color(Color::Muted)) + .child(HighlightedLabel::new(full_label, adjusted_positions).truncate()); + + Some(if let Some(reason) = occupied_reason.clone() { + item.tooltip(Tooltip::text(reason)) + } else if branch.is_remote() { + item.tooltip(Tooltip::text( + "Create a new local branch from this remote branch", + )) + } else { + item + }) + } + ThreadBranchEntry::CreateNamed { name } => Some( + ListItem::new("create-named-branch") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::Plus).color(Color::Accent)) + .child(Label::new(format!("Create Branch: \"{name}\"…"))), + ), + } + } + + fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option { + None + } +} diff --git a/crates/agent_ui/src/thread_worktree_picker.rs b/crates/agent_ui/src/thread_worktree_picker.rs new file mode 100644 index 0000000000000000000000000000000000000000..47a6a12d71822e13ab3523a3a6b0bb1ee57c7b4b --- /dev/null +++ b/crates/agent_ui/src/thread_worktree_picker.rs @@ -0,0 +1,485 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use agent_settings::AgentSettings; +use fs::Fs; +use fuzzy::StringMatchCandidate; +use git::repository::Worktree as GitWorktree; +use gpui::{ + App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, + ParentElement, Render, SharedString, Styled, Task, Window, rems, +}; +use picker::{Picker, PickerDelegate, PickerEditorPosition}; +use project::{Project, git_store::RepositoryId}; +use settings::{NewThreadLocation, Settings, update_settings_file}; +use ui::{ + HighlightedLabel, Icon, IconName, Label, LabelCommon, ListItem, ListItemSpacing, 
Tooltip, + prelude::*, +}; +use util::ResultExt as _; + +use crate::ui::HoldForDefault; +use crate::{NewWorktreeBranchTarget, StartThreadIn}; + +pub(crate) struct ThreadWorktreePicker { + picker: Entity>, + focus_handle: FocusHandle, + _subscription: gpui::Subscription, +} + +impl ThreadWorktreePicker { + pub fn new( + project: Entity, + current_target: &StartThreadIn, + fs: Arc, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let project_worktree_paths: Vec = project + .read(cx) + .visible_worktrees(cx) + .map(|wt| wt.read(cx).abs_path().to_path_buf()) + .collect(); + + let preserved_branch_target = match current_target { + StartThreadIn::NewWorktree { branch_target, .. } => branch_target.clone(), + _ => NewWorktreeBranchTarget::default(), + }; + + let delegate = ThreadWorktreePickerDelegate { + matches: vec![ + ThreadWorktreeEntry::CurrentWorktree, + ThreadWorktreeEntry::NewWorktree, + ], + all_worktrees: project + .read(cx) + .repositories(cx) + .iter() + .map(|(repo_id, repo)| (*repo_id, repo.read(cx).linked_worktrees.clone())) + .collect(), + project_worktree_paths, + selected_index: match current_target { + StartThreadIn::LocalProject => 0, + StartThreadIn::NewWorktree { .. 
} => 1, + _ => 0, + }, + project: project.clone(), + preserved_branch_target, + fs, + }; + + let picker = cx.new(|cx| { + Picker::list(delegate, window, cx) + .list_measure_all() + .modal(false) + .max_height(Some(rems(20.).into())) + }); + + let subscription = cx.subscribe(&picker, |_, _, _, cx| { + cx.emit(DismissEvent); + }); + + Self { + focus_handle: picker.focus_handle(cx), + picker, + _subscription: subscription, + } + } +} + +impl Focusable for ThreadWorktreePicker { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for ThreadWorktreePicker {} + +impl Render for ThreadWorktreePicker { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .w(rems(20.)) + .elevation_3(cx) + .child(self.picker.clone()) + .on_mouse_down_out(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })) + } +} + +#[derive(Clone)] +enum ThreadWorktreeEntry { + CurrentWorktree, + NewWorktree, + LinkedWorktree { + worktree: GitWorktree, + positions: Vec, + }, + CreateNamed { + name: String, + disabled_reason: Option, + }, +} + +pub(crate) struct ThreadWorktreePickerDelegate { + matches: Vec, + all_worktrees: Vec<(RepositoryId, Arc<[GitWorktree]>)>, + project_worktree_paths: Vec, + selected_index: usize, + preserved_branch_target: NewWorktreeBranchTarget, + project: Entity, + fs: Arc, +} + +impl ThreadWorktreePickerDelegate { + fn new_worktree_action(&self, worktree_name: Option) -> StartThreadIn { + StartThreadIn::NewWorktree { + worktree_name, + branch_target: self.preserved_branch_target.clone(), + } + } + + fn sync_selected_index(&mut self) { + if let Some(index) = self + .matches + .iter() + .position(|entry| matches!(entry, ThreadWorktreeEntry::LinkedWorktree { .. })) + { + self.selected_index = index; + } else if let Some(index) = self + .matches + .iter() + .position(|entry| matches!(entry, ThreadWorktreeEntry::CreateNamed { .. 
})) + { + self.selected_index = index; + } else { + self.selected_index = 0; + } + } +} + +impl PickerDelegate for ThreadWorktreePickerDelegate { + type ListItem = ListItem; + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Search or create worktrees…".into() + } + + fn editor_position(&self) -> PickerEditorPosition { + PickerEditorPosition::Start + } + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) { + self.selected_index = ix; + } + + fn separators_after_indices(&self) -> Vec { + if self.matches.len() > 2 { + vec![1] + } else { + Vec::new() + } + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + let has_multiple_repositories = self.all_worktrees.len() > 1; + + let linked_worktrees: Vec<_> = if has_multiple_repositories { + Vec::new() + } else { + self.all_worktrees + .iter() + .flat_map(|(_, worktrees)| worktrees.iter()) + .filter(|worktree| { + !self + .project_worktree_paths + .iter() + .any(|project_path| project_path == &worktree.path) + }) + .cloned() + .collect() + }; + + let normalized_query = query.replace(' ', "-"); + let has_named_worktree = self.all_worktrees.iter().any(|(_, worktrees)| { + worktrees + .iter() + .any(|worktree| worktree.display_name() == normalized_query) + }); + let create_named_disabled_reason = if has_multiple_repositories { + Some("Cannot create a named worktree in a project with multiple repositories".into()) + } else if has_named_worktree { + Some("A worktree with this name already exists".into()) + } else { + None + }; + + let mut matches = vec![ + ThreadWorktreeEntry::CurrentWorktree, + ThreadWorktreeEntry::NewWorktree, + ]; + + if query.is_empty() { + for worktree in &linked_worktrees { + matches.push(ThreadWorktreeEntry::LinkedWorktree { + worktree: 
worktree.clone(), + positions: Vec::new(), + }); + } + } else if linked_worktrees.is_empty() { + matches.push(ThreadWorktreeEntry::CreateNamed { + name: normalized_query, + disabled_reason: create_named_disabled_reason, + }); + } else { + let candidates: Vec<_> = linked_worktrees + .iter() + .enumerate() + .map(|(ix, worktree)| StringMatchCandidate::new(ix, worktree.display_name())) + .collect(); + + let executor = cx.background_executor().clone(); + let query_clone = query.clone(); + + let task = cx.background_executor().spawn(async move { + fuzzy::match_strings( + &candidates, + &query_clone, + true, + true, + 10000, + &Default::default(), + executor, + ) + .await + }); + + let linked_worktrees_clone = linked_worktrees; + return cx.spawn_in(window, async move |picker, cx| { + let fuzzy_matches = task.await; + + picker + .update_in(cx, |picker, _window, cx| { + let mut new_matches = vec![ + ThreadWorktreeEntry::CurrentWorktree, + ThreadWorktreeEntry::NewWorktree, + ]; + + for candidate in &fuzzy_matches { + new_matches.push(ThreadWorktreeEntry::LinkedWorktree { + worktree: linked_worktrees_clone[candidate.candidate_id].clone(), + positions: candidate.positions.clone(), + }); + } + + let has_exact_match = linked_worktrees_clone + .iter() + .any(|worktree| worktree.display_name() == query); + + if !has_exact_match { + new_matches.push(ThreadWorktreeEntry::CreateNamed { + name: normalized_query.clone(), + disabled_reason: create_named_disabled_reason.clone(), + }); + } + + picker.delegate.matches = new_matches; + picker.delegate.sync_selected_index(); + + cx.notify(); + }) + .log_err(); + }); + } + + self.matches = matches; + self.sync_selected_index(); + + Task::ready(()) + } + + fn confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context>) { + let Some(entry) = self.matches.get(self.selected_index) else { + return; + }; + + match entry { + ThreadWorktreeEntry::CurrentWorktree => { + if secondary { + update_settings_file(self.fs.clone(), cx, 
|settings, _| { + settings + .agent + .get_or_insert_default() + .set_new_thread_location(NewThreadLocation::LocalProject); + }); + } + window.dispatch_action(Box::new(StartThreadIn::LocalProject), cx); + } + ThreadWorktreeEntry::NewWorktree => { + if secondary { + update_settings_file(self.fs.clone(), cx, |settings, _| { + settings + .agent + .get_or_insert_default() + .set_new_thread_location(NewThreadLocation::NewWorktree); + }); + } + window.dispatch_action(Box::new(self.new_worktree_action(None)), cx); + } + ThreadWorktreeEntry::LinkedWorktree { worktree, .. } => { + window.dispatch_action( + Box::new(StartThreadIn::LinkedWorktree { + path: worktree.path.clone(), + display_name: worktree.display_name().to_string(), + }), + cx, + ); + } + ThreadWorktreeEntry::CreateNamed { + name, + disabled_reason: None, + } => { + window.dispatch_action(Box::new(self.new_worktree_action(Some(name.clone()))), cx); + } + ThreadWorktreeEntry::CreateNamed { + disabled_reason: Some(_), + .. + } => { + return; + } + } + + cx.emit(DismissEvent); + } + + fn dismissed(&mut self, _window: &mut Window, _cx: &mut Context>) {} + + fn render_match( + &self, + ix: usize, + selected: bool, + _window: &mut Window, + cx: &mut Context>, + ) -> Option { + let entry = self.matches.get(ix)?; + let project = self.project.read(cx); + let is_new_worktree_disabled = + project.repositories(cx).is_empty() || project.is_via_collab(); + let new_thread_location = AgentSettings::get_global(cx).new_thread_location; + let is_local_default = new_thread_location == NewThreadLocation::LocalProject; + let is_new_worktree_default = new_thread_location == NewThreadLocation::NewWorktree; + + match entry { + ThreadWorktreeEntry::CurrentWorktree => Some( + ListItem::new("current-worktree") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::Folder).color(Color::Muted)) + .child(Label::new("Current Worktree")) + 
.end_slot(HoldForDefault::new(is_local_default).more_content(false)) + .tooltip(Tooltip::text("Use the current project worktree")), + ), + ThreadWorktreeEntry::NewWorktree => { + let item = ListItem::new("new-worktree") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .disabled(is_new_worktree_disabled) + .start_slot( + Icon::new(IconName::Plus).color(if is_new_worktree_disabled { + Color::Disabled + } else { + Color::Muted + }), + ) + .child( + Label::new("New Git Worktree").color(if is_new_worktree_disabled { + Color::Disabled + } else { + Color::Default + }), + ); + + Some(if is_new_worktree_disabled { + item.tooltip(Tooltip::text("Requires a Git repository in the project")) + } else { + item.end_slot(HoldForDefault::new(is_new_worktree_default).more_content(false)) + .tooltip(Tooltip::text("Start a thread in a new Git worktree")) + }) + } + ThreadWorktreeEntry::LinkedWorktree { + worktree, + positions, + } => { + let display_name = worktree.display_name(); + let first_line = display_name.lines().next().unwrap_or(display_name); + let positions: Vec<_> = positions + .iter() + .copied() + .filter(|&pos| pos < first_line.len()) + .collect(); + + Some( + ListItem::new(SharedString::from(format!("linked-worktree-{ix}"))) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot(Icon::new(IconName::GitWorktree).color(Color::Muted)) + .child(HighlightedLabel::new(first_line.to_owned(), positions).truncate()), + ) + } + ThreadWorktreeEntry::CreateNamed { + name, + disabled_reason, + } => { + let is_disabled = disabled_reason.is_some(); + let item = ListItem::new("create-named-worktree") + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .disabled(is_disabled) + .start_slot(Icon::new(IconName::Plus).color(if is_disabled { + Color::Disabled + } else { + Color::Accent + })) + .child(Label::new(format!("Create Worktree: \"{name}\"…")).color( + if is_disabled { + Color::Disabled + } 
else { + Color::Default + }, + )); + + Some(if let Some(reason) = disabled_reason.clone() { + item.tooltip(Tooltip::text(reason)) + } else { + item + }) + } + } + } + + fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option { + None + } +} diff --git a/crates/agent_ui/src/threads_archive_view.rs b/crates/agent_ui/src/threads_archive_view.rs index 13b2aa1a37cd506c338d13db78bce751882e426a..7cb8410e5017438b0e8adde673887c13397d9abf 100644 --- a/crates/agent_ui/src/threads_archive_view.rs +++ b/crates/agent_ui/src/threads_archive_view.rs @@ -1236,6 +1236,7 @@ impl PickerDelegate for ProjectPickerDelegate { }, match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "), paths: Vec::new(), + active: false, }; Some( diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml index 1e2587435489dea6952c697b0e0a4cf627226728..458f9bfae7da4736c4e54e42f08b5e3a926ed30a 100644 --- a/crates/anthropic/Cargo.toml +++ b/crates/anthropic/Cargo.toml @@ -18,12 +18,16 @@ path = "src/anthropic.rs" [dependencies] anyhow.workspace = true chrono.workspace = true +collections.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true +log.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true strum.workspace = true thiserror.workspace = true +tiktoken-rs.workspace = true diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 5d7790b86b09853e22436252fcde1bebf5feff9b..48fa318d7c1d87e63725cef836baf9c945966206 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -12,6 +12,7 @@ use strum::{EnumIter, EnumString}; use thiserror::Error; pub mod batches; +pub mod completion; pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com"; @@ -1026,6 +1027,89 @@ pub async fn count_tokens( } } +// -- Conversions from/to `language_model_core` types -- + +impl From for Speed { + 
fn from(speed: language_model_core::Speed) -> Self { + match speed { + language_model_core::Speed::Standard => Speed::Standard, + language_model_core::Speed::Fast => Speed::Fast, + } + } +} + +impl From for language_model_core::LanguageModelCompletionError { + fn from(error: AnthropicError) -> Self { + let provider = language_model_core::ANTHROPIC_PROVIDER_NAME; + match error { + AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, + AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, + AnthropicError::HttpSend(error) => Self::HttpSend { provider, error }, + AnthropicError::DeserializeResponse(error) => { + Self::DeserializeResponse { provider, error } + } + AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, + AnthropicError::HttpResponseError { + status_code, + message, + } => Self::HttpResponseError { + provider, + status_code, + message, + }, + AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { + provider, + retry_after: Some(retry_after), + }, + AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded { + provider, + retry_after, + }, + AnthropicError::ApiError(api_error) => api_error.into(), + } + } +} + +impl From for language_model_core::LanguageModelCompletionError { + fn from(error: ApiError) -> Self { + use ApiErrorCode::*; + let provider = language_model_core::ANTHROPIC_PROVIDER_NAME; + match error.code() { + Some(code) => match code { + InvalidRequestError => Self::BadRequestFormat { + provider, + message: error.message, + }, + AuthenticationError => Self::AuthenticationError { + provider, + message: error.message, + }, + PermissionError => Self::PermissionError { + provider, + message: error.message, + }, + NotFoundError => Self::ApiEndpointNotFound { provider }, + RequestTooLarge => Self::PromptTooLarge { + tokens: language_model_core::parse_prompt_too_long(&error.message), + }, + RateLimitError => 
Self::RateLimitExceeded { + provider, + retry_after: None, + }, + ApiError => Self::ApiInternalServerError { + provider, + message: error.message, + }, + OverloadedError => Self::ServerOverloaded { + provider, + retry_after: None, + }, + }, + None => Self::Other(error.into()), + } + } +} + #[test] fn test_match_window_exceeded() { let error = ApiError { diff --git a/crates/anthropic/src/completion.rs b/crates/anthropic/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..a6175a4f7c24b3b724734b2edef48ef8acfaa159 --- /dev/null +++ b/crates/anthropic/src/completion.rs @@ -0,0 +1,765 @@ +use anyhow::Result; +use collections::HashMap; +use futures::{Stream, StreamExt}; +use language_model_core::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, + Role, StopReason, TokenUsage, + util::{fix_streamed_json, parse_tool_arguments}, +}; +use std::pin::Pin; +use std::str::FromStr; + +use crate::{ + AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta, + CountTokensRequest, Event, ImageSource, Message, RequestContent, ResponseContent, + StringOrContents, Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage, +}; + +fn to_anthropic_content(content: MessageContent) -> Option { + match content { + MessageContent::Text(text) => { + let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { + text.trim_end().to_string() + } else { + text + }; + if !text.is_empty() { + Some(RequestContent::Text { + text, + cache_control: None, + }) + } else { + None + } + } + MessageContent::Thinking { + text: thinking, + signature, + } => { + if let Some(signature) = signature + && !thinking.is_empty() + { + Some(RequestContent::Thinking { + thinking, + signature, + cache_control: None, + }) + } else { + None + } + } + MessageContent::RedactedThinking(data) => { + if !data.is_empty() 
{ + Some(RequestContent::RedactedThinking { data }) + } else { + None + } + } + MessageContent::Image(image) => Some(RequestContent::Image { + source: ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + cache_control: None, + }), + MessageContent::ToolUse(tool_use) => Some(RequestContent::ToolUse { + id: tool_use.id.to_string(), + name: tool_use.name.to_string(), + input: tool_use.input, + cache_control: None, + }), + MessageContent::ToolResult(tool_result) => Some(RequestContent::ToolResult { + tool_use_id: tool_result.tool_use_id.to_string(), + is_error: tool_result.is_error, + content: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ToolResultContent::Plain(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ToolResultContent::Multipart(vec![ToolResultPart::Image { + source: ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }]) + } + }, + cache_control: None, + }), + } +} + +/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest. 
+pub fn into_anthropic_count_tokens_request( + request: LanguageModelRequest, + model: String, + mode: AnthropicModelMode, +) -> CountTokensRequest { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let anthropic_message_content: Vec = message + .content + .into_iter() + .filter_map(to_anthropic_content) + .collect(); + let anthropic_role = match message.role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if anthropic_message_content.is_empty() { + continue; + } + + if let Some(last_message) = new_messages.last_mut() + && last_message.role == anthropic_role + { + last_message.content.extend(anthropic_message_content); + continue; + } + + new_messages.push(Message { + role: anthropic_role, + content: anthropic_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + CountTokensRequest { + model, + messages: new_messages, + system: if system_message.is_empty() { + None + } else { + Some(StringOrContents::String(system_message)) + }, + thinking: if request.thinking_allowed { + match mode { + AnthropicModelMode::Thinking { budget_tokens } => { + Some(Thinking::Enabled { budget_tokens }) + } + AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive), + AnthropicModelMode::Default => None, + } + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + eager_input_streaming: tool.use_input_streaming, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => 
ToolChoice::Auto, + LanguageModelToolChoice::Any => ToolChoice::Any, + LanguageModelToolChoice::None => ToolChoice::None, + }), + } +} + +/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable, +/// or by providers (like Zed Cloud) that don't have direct Anthropic API access. +pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result { + let messages = request.messages; + let mut tokens_from_images = 0; + let mut string_messages = Vec::with_capacity(messages.len()); + + for message in messages { + let mut string_contents = String::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + string_contents.push_str(&text); + } + MessageContent::Thinking { .. } => { + // Thinking blocks are not included in the input token count. + } + MessageContent::RedactedThinking(_) => { + // Thinking blocks are not included in the input token count. + } + MessageContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + MessageContent::ToolUse(_tool_use) => { + // TODO: Estimate token usage from tool uses. + } + MessageContent::ToolResult(tool_result) => match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + string_contents.push_str(text); + } + LanguageModelToolResultContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + }, + } + } + + if !string_contents.is_empty() { + string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(string_contents), + name: None, + function_call: None, + }); + } + } + + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. 
+ tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) + .map(|tokens| (tokens + tokens_from_images) as u64) +} + +pub fn into_anthropic( + request: LanguageModelRequest, + model: String, + default_temperature: f32, + max_output_tokens: u64, + mode: AnthropicModelMode, +) -> crate::Request { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let mut anthropic_message_content: Vec = message + .content + .into_iter() + .filter_map(to_anthropic_content) + .collect(); + let anthropic_role = match message.role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if anthropic_message_content.is_empty() { + continue; + } + + if let Some(last_message) = new_messages.last_mut() + && last_message.role == anthropic_role + { + last_message.content.extend(anthropic_message_content); + continue; + } + + // Mark the last segment of the message as cached + if message.cache { + let cache_control_value = Some(CacheControl { + cache_type: CacheControlType::Ephemeral, + }); + for message_content in anthropic_message_content.iter_mut().rev() { + match message_content { + RequestContent::RedactedThinking { .. } => { + // Caching is not possible, fallback to next message + } + RequestContent::Text { cache_control, .. } + | RequestContent::Thinking { cache_control, .. } + | RequestContent::Image { cache_control, .. } + | RequestContent::ToolUse { cache_control, .. } + | RequestContent::ToolResult { cache_control, .. 
} => { + *cache_control = cache_control_value; + break; + } + } + } + } + + new_messages.push(Message { + role: anthropic_role, + content: anthropic_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + crate::Request { + model, + messages: new_messages, + max_tokens: max_output_tokens, + system: if system_message.is_empty() { + None + } else { + Some(StringOrContents::String(system_message)) + }, + thinking: if request.thinking_allowed { + match mode { + AnthropicModelMode::Thinking { budget_tokens } => { + Some(Thinking::Enabled { budget_tokens }) + } + AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive), + AnthropicModelMode::Default => None, + } + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + eager_input_streaming: tool.use_input_streaming, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => ToolChoice::Auto, + LanguageModelToolChoice::Any => ToolChoice::Any, + LanguageModelToolChoice::None => ToolChoice::None, + }), + metadata: None, + output_config: if request.thinking_allowed + && matches!(mode, AnthropicModelMode::AdaptiveThinking) + { + request.thinking_effort.as_deref().and_then(|effort| { + let effort = match effort { + "low" => Some(crate::Effort::Low), + "medium" => Some(crate::Effort::Medium), + "high" => Some(crate::Effort::High), + "max" => Some(crate::Effort::Max), + _ => None, + }; + effort.map(|effort| crate::OutputConfig { + effort: Some(effort), + }) + }) + } else { + None + }, + stop_sequences: Vec::new(), + speed: request.speed.map(Into::into), + temperature: request.temperature.or(Some(default_temperature)), + top_k: None, + top_p: None, + } +} + +pub struct AnthropicEventMapper { + tool_uses_by_index: 
HashMap, + usage: Usage, + stop_reason: StopReason, +} + +impl AnthropicEventMapper { + pub fn new() -> Self { + Self { + tool_uses_by_index: HashMap::default(), + usage: Usage::default(), + stop_reason: StopReason::EndTurn, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(error.into())], + }) + }) + } + + pub fn map_event( + &mut self, + event: Event, + ) -> Vec> { + match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ResponseContent::Thinking { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ResponseContent::RedactedThinking { data } => { + vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] + } + ResponseContent::ToolUse { id, name, .. } => { + self.tool_uses_by_index.insert( + index, + RawToolUse { + id, + name, + input_json: String::new(), + }, + ); + Vec::new() + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ContentDelta::ThinkingDelta { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ContentDelta::SignatureDelta { signature } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "".to_string(), + signature: Some(signature), + })] + } + ContentDelta::InputJsonDelta { partial_json } => { + if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { + tool_use.input_json.push_str(&partial_json); + + // Try to convert invalid (incomplete) JSON into + // valid JSON that serde can accept, e.g. by closing + // unclosed delimiters. 
This way, we can update the + // UI with whatever has been streamed back so far. + if let Ok(input) = + serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json)) + { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.clone().into(), + name: tool_use.name.clone().into(), + is_input_complete: false, + raw_input: tool_use.input_json.clone(), + input, + thought_signature: None, + }, + ))]; + } + } + vec![] + } + }, + Event::ContentBlockStop { index } => { + if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { + let input_json = tool_use.input_json.trim(); + let event_result = match parse_tool_arguments(input_json) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.into(), + name: tool_use.name.into(), + is_input_complete: true, + input, + raw_input: tool_use.input_json.clone(), + thought_signature: None, + }, + )), + Err(json_parse_err) => { + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_use.id.into(), + tool_name: tool_use.name.into(), + raw_input: input_json.into(), + json_parse_error: json_parse_err.to_string(), + }) + } + }; + + vec![event_result] + } else { + Vec::new() + } + } + Event::MessageStart { message } => { + update_usage(&mut self.usage, &message.usage); + vec![ + Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( + &self.usage, + ))), + Ok(LanguageModelCompletionEvent::StartMessage { + message_id: message.id, + }), + ] + } + Event::MessageDelta { delta, usage } => { + update_usage(&mut self.usage, &usage); + if let Some(stop_reason) = delta.stop_reason.as_deref() { + self.stop_reason = match stop_reason { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + "refusal" => StopReason::Refusal, + _ => { + log::error!("Unexpected anthropic stop_reason: {stop_reason}"); + StopReason::EndTurn + } + }; + } + 
vec![Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))] + } + Event::MessageStop => { + vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] + } + Event::Error { error } => { + vec![Err(error.into())] + } + _ => Vec::new(), + } + } +} + +struct RawToolUse { + id: String, + name: String, + input_json: String, +} + +/// Updates usage data by preferring counts from `new`. +fn update_usage(usage: &mut Usage, new: &Usage) { + if let Some(input_tokens) = new.input_tokens { + usage.input_tokens = Some(input_tokens); + } + if let Some(output_tokens) = new.output_tokens { + usage.output_tokens = Some(output_tokens); + } + if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens { + usage.cache_creation_input_tokens = Some(cache_creation_input_tokens); + } + if let Some(cache_read_input_tokens) = new.cache_read_input_tokens { + usage.cache_read_input_tokens = Some(cache_read_input_tokens); + } +} + +fn convert_usage(usage: &Usage) -> TokenUsage { + TokenUsage { + input_tokens: usage.input_tokens.unwrap_or(0), + output_tokens: usage.output_tokens.unwrap_or(0), + cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), + cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::AnthropicModelMode; + use language_model_core::{LanguageModelImage, LanguageModelRequestMessage, MessageContent}; + + #[test] + fn test_cache_control_only_on_last_segment() { + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![ + MessageContent::Text("Some prompt".to_string()), + MessageContent::Image(LanguageModelImage::empty()), + MessageContent::Image(LanguageModelImage::empty()), + MessageContent::Image(LanguageModelImage::empty()), + MessageContent::Image(LanguageModelImage::empty()), + ], + cache: true, + reasoning_details: None, + }], + thread_id: None, + prompt_id: None, + 
intent: None, + stop: vec![], + temperature: None, + tools: vec![], + tool_choice: None, + thinking_allowed: true, + thinking_effort: None, + speed: None, + }; + + let anthropic_request = into_anthropic( + request, + "claude-3-5-sonnet".to_string(), + 0.7, + 4096, + AnthropicModelMode::Default, + ); + + assert_eq!(anthropic_request.messages.len(), 1); + + let message = &anthropic_request.messages[0]; + assert_eq!(message.content.len(), 5); + + assert!(matches!( + message.content[0], + RequestContent::Text { + cache_control: None, + .. + } + )); + for i in 1..3 { + assert!(matches!( + message.content[i], + RequestContent::Image { + cache_control: None, + .. + } + )); + } + + assert!(matches!( + message.content[4], + RequestContent::Image { + cache_control: Some(CacheControl { + cache_type: CacheControlType::Ephemeral, + }), + .. + } + )); + } + + fn request_with_assistant_content(assistant_content: Vec) -> crate::Request { + let mut request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text("Hello".to_string())], + cache: false, + reasoning_details: None, + }], + thinking_effort: None, + thread_id: None, + prompt_id: None, + intent: None, + stop: vec![], + temperature: None, + tools: vec![], + tool_choice: None, + thinking_allowed: true, + speed: None, + }; + request.messages.push(LanguageModelRequestMessage { + role: Role::Assistant, + content: assistant_content, + cache: false, + reasoning_details: None, + }); + into_anthropic( + request, + "claude-sonnet-4-5".to_string(), + 1.0, + 16000, + AnthropicModelMode::Thinking { + budget_tokens: Some(10000), + }, + ) + } + + #[test] + fn test_unsigned_thinking_blocks_stripped() { + let result = request_with_assistant_content(vec![ + MessageContent::Thinking { + text: "Cancelled mid-think, no signature".to_string(), + signature: None, + }, + MessageContent::Text("Some response text".to_string()), + ]); + + let assistant_message = result + 
.messages + .iter() + .find(|m| m.role == crate::Role::Assistant) + .expect("assistant message should still exist"); + + assert_eq!( + assistant_message.content.len(), + 1, + "Only the text content should remain; unsigned thinking block should be stripped" + ); + assert!(matches!( + &assistant_message.content[0], + RequestContent::Text { text, .. } if text == "Some response text" + )); + } + + #[test] + fn test_signed_thinking_blocks_preserved() { + let result = request_with_assistant_content(vec![ + MessageContent::Thinking { + text: "Completed thinking".to_string(), + signature: Some("valid-signature".to_string()), + }, + MessageContent::Text("Response".to_string()), + ]); + + let assistant_message = result + .messages + .iter() + .find(|m| m.role == crate::Role::Assistant) + .expect("assistant message should exist"); + + assert_eq!( + assistant_message.content.len(), + 2, + "Both the signed thinking block and text should be preserved" + ); + assert!(matches!( + &assistant_message.content[0], + RequestContent::Thinking { thinking, signature, .. 
} + if thinking == "Completed thinking" && signature == "valid-signature" + )); + } + + #[test] + fn test_only_unsigned_thinking_block_omits_entire_message() { + let result = request_with_assistant_content(vec![MessageContent::Thinking { + text: "Cancelled before any text or signature".to_string(), + signature: None, + }]); + + let assistant_messages: Vec<_> = result + .messages + .iter() + .filter(|m| m.role == crate::Role::Assistant) + .collect(); + + assert_eq!( + assistant_messages.len(), + 0, + "An assistant message whose only content was an unsigned thinking block \ + should be omitted entirely" + ); + } +} diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index 8b6113e4d5521fb3c7e27a7f2f6547c7a9db86ce..7c1e6e0e4e6ef873345c30c0af4c9e8842699c77 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -113,6 +113,10 @@ pub enum Model { MistralLarge3, #[serde(rename = "pixtral-large")] PixtralLarge, + #[serde(rename = "devstral-2-123b")] + Devstral2_123B, + #[serde(rename = "ministral-14b")] + Ministral14B, // Qwen models #[serde(rename = "qwen3-32b")] @@ -146,9 +150,27 @@ pub enum Model { #[serde(rename = "gpt-oss-120b")] GptOss120B, + // NVIDIA Nemotron models + #[serde(rename = "nemotron-super-3-120b")] + NemotronSuper3_120B, + #[serde(rename = "nemotron-nano-3-30b")] + NemotronNano3_30B, + // MiniMax models #[serde(rename = "minimax-m2")] MiniMaxM2, + #[serde(rename = "minimax-m2-1")] + MiniMaxM2_1, + #[serde(rename = "minimax-m2-5")] + MiniMaxM2_5, + + // Z.AI GLM models + #[serde(rename = "glm-5")] + GLM5, + #[serde(rename = "glm-4-7")] + GLM4_7, + #[serde(rename = "glm-4-7-flash")] + GLM4_7Flash, // Moonshot models #[serde(rename = "kimi-k2-thinking")] @@ -217,6 +239,8 @@ impl Model { Self::MagistralSmall => "magistral-small", Self::MistralLarge3 => "mistral-large-3", Self::PixtralLarge => "pixtral-large", + Self::Devstral2_123B => "devstral-2-123b", + Self::Ministral14B => "ministral-14b", Self::Qwen3_32B => 
"qwen3-32b", Self::Qwen3VL235B => "qwen3-vl-235b", Self::Qwen3_235B => "qwen3-235b", @@ -230,7 +254,14 @@ impl Model { Self::Nova2Lite => "nova-2-lite", Self::GptOss20B => "gpt-oss-20b", Self::GptOss120B => "gpt-oss-120b", + Self::NemotronSuper3_120B => "nemotron-super-3-120b", + Self::NemotronNano3_30B => "nemotron-nano-3-30b", Self::MiniMaxM2 => "minimax-m2", + Self::MiniMaxM2_1 => "minimax-m2-1", + Self::MiniMaxM2_5 => "minimax-m2-5", + Self::GLM5 => "glm-5", + Self::GLM4_7 => "glm-4-7", + Self::GLM4_7Flash => "glm-4-7-flash", Self::KimiK2Thinking => "kimi-k2-thinking", Self::KimiK2_5 => "kimi-k2-5", Self::DeepSeekR1 => "deepseek-r1", @@ -257,6 +288,8 @@ impl Model { Self::MagistralSmall => "mistral.magistral-small-2509", Self::MistralLarge3 => "mistral.mistral-large-3-675b-instruct", Self::PixtralLarge => "mistral.pixtral-large-2502-v1:0", + Self::Devstral2_123B => "mistral.devstral-2-123b", + Self::Ministral14B => "mistral.ministral-3-14b-instruct", Self::Qwen3VL235B => "qwen.qwen3-vl-235b-a22b", Self::Qwen3_32B => "qwen.qwen3-32b-v1:0", Self::Qwen3_235B => "qwen.qwen3-235b-a22b-2507-v1:0", @@ -270,7 +303,14 @@ impl Model { Self::Nova2Lite => "amazon.nova-2-lite-v1:0", Self::GptOss20B => "openai.gpt-oss-20b-1:0", Self::GptOss120B => "openai.gpt-oss-120b-1:0", + Self::NemotronSuper3_120B => "nvidia.nemotron-super-3-120b", + Self::NemotronNano3_30B => "nvidia.nemotron-nano-3-30b", Self::MiniMaxM2 => "minimax.minimax-m2", + Self::MiniMaxM2_1 => "minimax.minimax-m2.1", + Self::MiniMaxM2_5 => "minimax.minimax-m2.5", + Self::GLM5 => "zai.glm-5", + Self::GLM4_7 => "zai.glm-4.7", + Self::GLM4_7Flash => "zai.glm-4.7-flash", Self::KimiK2Thinking => "moonshot.kimi-k2-thinking", Self::KimiK2_5 => "moonshotai.kimi-k2.5", Self::DeepSeekR1 => "deepseek.r1-v1:0", @@ -297,6 +337,8 @@ impl Model { Self::MagistralSmall => "Magistral Small", Self::MistralLarge3 => "Mistral Large 3", Self::PixtralLarge => "Pixtral Large", + Self::Devstral2_123B => "Devstral 2 123B", + 
Self::Ministral14B => "Ministral 14B", Self::Qwen3VL235B => "Qwen3 VL 235B", Self::Qwen3_32B => "Qwen3 32B", Self::Qwen3_235B => "Qwen3 235B", @@ -310,7 +352,14 @@ impl Model { Self::Nova2Lite => "Amazon Nova 2 Lite", Self::GptOss20B => "GPT OSS 20B", Self::GptOss120B => "GPT OSS 120B", + Self::NemotronSuper3_120B => "Nemotron Super 3 120B", + Self::NemotronNano3_30B => "Nemotron Nano 3 30B", Self::MiniMaxM2 => "MiniMax M2", + Self::MiniMaxM2_1 => "MiniMax M2.1", + Self::MiniMaxM2_5 => "MiniMax M2.5", + Self::GLM5 => "GLM 5", + Self::GLM4_7 => "GLM 4.7", + Self::GLM4_7Flash => "GLM 4.7 Flash", Self::KimiK2Thinking => "Kimi K2 Thinking", Self::KimiK2_5 => "Kimi K2.5", Self::DeepSeekR1 => "DeepSeek R1", @@ -338,6 +387,7 @@ impl Model { Self::Llama4Scout17B | Self::Llama4Maverick17B => 128_000, Self::Gemma3_4B | Self::Gemma3_12B | Self::Gemma3_27B => 128_000, Self::MagistralSmall | Self::MistralLarge3 | Self::PixtralLarge => 128_000, + Self::Devstral2_123B | Self::Ministral14B => 256_000, Self::Qwen3_32B | Self::Qwen3VL235B | Self::Qwen3_235B @@ -349,7 +399,9 @@ impl Model { Self::NovaPremier => 1_000_000, Self::Nova2Lite => 300_000, Self::GptOss20B | Self::GptOss120B => 128_000, - Self::MiniMaxM2 => 128_000, + Self::NemotronSuper3_120B | Self::NemotronNano3_30B => 262_000, + Self::MiniMaxM2 | Self::MiniMaxM2_1 | Self::MiniMaxM2_5 => 196_000, + Self::GLM5 | Self::GLM4_7 | Self::GLM4_7Flash => 203_000, Self::KimiK2Thinking | Self::KimiK2_5 => 128_000, Self::DeepSeekR1 | Self::DeepSeekV3_1 | Self::DeepSeekV3_2 => 128_000, Self::Custom { max_tokens, .. 
} => *max_tokens, @@ -373,6 +425,7 @@ impl Model { | Self::MagistralSmall | Self::MistralLarge3 | Self::PixtralLarge => 8_192, + Self::Devstral2_123B | Self::Ministral14B => 131_000, Self::Qwen3_32B | Self::Qwen3VL235B | Self::Qwen3_235B @@ -382,7 +435,9 @@ impl Model { | Self::Qwen3Coder480B => 8_192, Self::NovaLite | Self::NovaPro | Self::NovaPremier | Self::Nova2Lite => 5_000, Self::GptOss20B | Self::GptOss120B => 16_000, - Self::MiniMaxM2 => 16_000, + Self::NemotronSuper3_120B | Self::NemotronNano3_30B => 131_000, + Self::MiniMaxM2 | Self::MiniMaxM2_1 | Self::MiniMaxM2_5 => 98_000, + Self::GLM5 | Self::GLM4_7 | Self::GLM4_7Flash => 101_000, Self::KimiK2Thinking | Self::KimiK2_5 => 16_000, Self::DeepSeekR1 | Self::DeepSeekV3_1 | Self::DeepSeekV3_2 => 16_000, Self::Custom { @@ -419,6 +474,7 @@ impl Model { | Self::ClaudeSonnet4_6 => true, Self::NovaLite | Self::NovaPro | Self::NovaPremier | Self::Nova2Lite => true, Self::MistralLarge3 | Self::PixtralLarge | Self::MagistralSmall => true, + Self::Devstral2_123B | Self::Ministral14B => true, // Gemma accepts toolConfig without error but produces unreliable tool // calls -- malformed JSON args, hallucinated tool names, dropped calls. 
Self::Qwen3_32B @@ -428,7 +484,9 @@ impl Model { | Self::Qwen3Coder30B | Self::Qwen3CoderNext | Self::Qwen3Coder480B => true, - Self::MiniMaxM2 => true, + Self::MiniMaxM2 | Self::MiniMaxM2_1 | Self::MiniMaxM2_5 => true, + Self::NemotronSuper3_120B | Self::NemotronNano3_30B => true, + Self::GLM5 | Self::GLM4_7 | Self::GLM4_7Flash => true, Self::KimiK2Thinking | Self::KimiK2_5 => true, Self::DeepSeekR1 | Self::DeepSeekV3_1 | Self::DeepSeekV3_2 => true, _ => false, diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 7bbaccb22e0e6c7508240186103e216f83be2f0c..532fe38f7df1f686730ed862a81806e9a531e156 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -36,7 +36,6 @@ gpui_tokio.workspace = true http_client.workspace = true http_client_tls.workspace = true httparse = "1.10" -language_model.workspace = true log.workspace = true parking_lot.workspace = true paths.workspace = true diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index dfd9963a0ee52d167f8d4edb0b850f4debed7fd4..05ca974f80438542b232262dd375e0e38ab4327c 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -14,6 +14,7 @@ use async_tungstenite::tungstenite::{ http::{HeaderValue, Request, StatusCode}, }; use clock::SystemClock; +use cloud_api_client::LlmApiToken; use cloud_api_client::websocket_protocol::MessageToClient; use cloud_api_client::{ClientApiError, CloudApiClient}; use cloud_api_types::OrganizationId; @@ -26,7 +27,6 @@ use futures::{ }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env}; -use language_model::LlmApiToken; use parking_lot::{Mutex, RwLock}; use postage::watch; use proxy::connect_proxy_stream; diff --git a/crates/client/src/llm_token.rs b/crates/client/src/llm_token.rs index f62aa6dd4dc3462bc3a0f6f46c35f0e4e5499816..70457679e4b965e3251ae4861d3052bfa41fd65a 100644 --- a/crates/client/src/llm_token.rs +++ 
b/crates/client/src/llm_token.rs @@ -1,10 +1,10 @@ use super::{Client, UserStore}; +use cloud_api_client::LlmApiToken; use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; use gpui::{ App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription, }; -use language_model::LlmApiToken; use std::sync::Arc; pub trait NeedsLlmTokenRefresh { diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml index 78c684e3e54ee29a5f3f3ae5620d4a52b445f92e..cf293d83f848e1266dec977c0925af7f66608ce6 100644 --- a/crates/cloud_api_client/Cargo.toml +++ b/crates/cloud_api_client/Cargo.toml @@ -20,5 +20,6 @@ gpui_tokio.workspace = true http_client.workspace = true parking_lot.workspace = true serde_json.workspace = true +smol.workspace = true thiserror.workspace = true yawc.workspace = true diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index 13d67838b216f4990f15ec22c1701aa7aef9dbf2..8c605bb3490ef5c7aea6e96045680338e8344a83 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -1,3 +1,4 @@ +mod llm_token; mod websocket; use std::sync::Arc; @@ -18,6 +19,8 @@ use yawc::WebSocket; use crate::websocket::Connection; +pub use llm_token::LlmApiToken; + struct Credentials { user_id: u32, access_token: String, diff --git a/crates/cloud_api_client/src/llm_token.rs b/crates/cloud_api_client/src/llm_token.rs new file mode 100644 index 0000000000000000000000000000000000000000..711e0d51b89bf34db255d7cb1e58483c9de340fc --- /dev/null +++ b/crates/cloud_api_client/src/llm_token.rs @@ -0,0 +1,74 @@ +use std::sync::Arc; + +use cloud_api_types::OrganizationId; +use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; + +use crate::{ClientApiError, CloudApiClient}; + +#[derive(Clone, Default)] +pub struct LlmApiToken(Arc>>); 
+ +impl LlmApiToken { + pub async fn acquire( + &self, + client: &CloudApiClient, + system_id: Option, + organization_id: Option, + ) -> Result { + let lock = self.0.upgradable_read().await; + if let Some(token) = lock.as_ref() { + Ok(token.to_string()) + } else { + Self::fetch( + RwLockUpgradableReadGuard::upgrade(lock).await, + client, + system_id, + organization_id, + ) + .await + } + } + + pub async fn refresh( + &self, + client: &CloudApiClient, + system_id: Option, + organization_id: Option, + ) -> Result { + Self::fetch(self.0.write().await, client, system_id, organization_id).await + } + + /// Clears the existing token before attempting to fetch a new one. + /// + /// Used when switching organizations so that a failed refresh doesn't + /// leave a token for the wrong organization. + pub async fn clear_and_refresh( + &self, + client: &CloudApiClient, + system_id: Option, + organization_id: Option, + ) -> Result { + let mut lock = self.0.write().await; + *lock = None; + Self::fetch(lock, client, system_id, organization_id).await + } + + async fn fetch( + mut lock: RwLockWriteGuard<'_, Option>, + client: &CloudApiClient, + system_id: Option, + organization_id: Option, + ) -> Result { + let result = client.create_llm_token(system_id, organization_id).await; + match result { + Ok(response) => { + *lock = Some(response.token.0.clone()); + Ok(response.token.0) + } + Err(err) => { + *lock = None; + Err(err) + } + } + } +} diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml index a7b4f925a9302296e8fe25a14177a583e5f44b33..7cc59f255abeb27c6e35a2064654d8eca1a581fe 100644 --- a/crates/cloud_llm_client/Cargo.toml +++ b/crates/cloud_llm_client/Cargo.toml @@ -7,6 +7,7 @@ license = "Apache-2.0" [features] test-support = [] +predict-edits = ["dep:zeta_prompt"] [lints] workspace = true @@ -20,6 +21,6 @@ serde = { workspace = true, features = ["derive", "rc"] } serde_json.workspace = true strum = { workspace = true, features = ["derive"] } uuid 
= { workspace = true, features = ["serde"] } -zeta_prompt.workspace = true +zeta_prompt = { workspace = true, optional = true } diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 35eb3f2b80dd400558b1f027781f5b8cf63bb6cb..ac8bdd462a9c4754ef42a6afa41f1bef8b5bbe6a 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "predict-edits")] pub mod predict_edits_v3; use std::str::FromStr; diff --git a/crates/collab/tests/integration/git_tests.rs b/crates/collab/tests/integration/git_tests.rs index 2fa67b072f1c3d49ef5ca1b90056fd08d57df1ba..c273005264d0a53b6a083a4013f7597a56919016 100644 --- a/crates/collab/tests/integration/git_tests.rs +++ b/crates/collab/tests/integration/git_tests.rs @@ -269,9 +269,11 @@ async fn test_remote_git_worktrees( cx_b.update(|cx| { repo_b.update(cx, |repository, _| { repository.create_worktree( - "feature-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "feature-branch".to_string(), + base_sha: Some("abc123".to_string()), + }, worktree_directory.join("feature-branch"), - Some("abc123".to_string()), ) }) }) @@ -323,9 +325,11 @@ async fn test_remote_git_worktrees( cx_b.update(|cx| { repo_b.update(cx, |repository, _| { repository.create_worktree( - "bugfix-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "bugfix-branch".to_string(), + base_sha: None, + }, worktree_directory.join("bugfix-branch"), - None, ) }) }) diff --git a/crates/collab/tests/integration/remote_editing_collaboration_tests.rs b/crates/collab/tests/integration/remote_editing_collaboration_tests.rs index 0796323fc5b3d8f6b1cbcb0e108a7d573240f446..d478402a9d66ca9fba4e8f9517cb62898754e677 100644 --- a/crates/collab/tests/integration/remote_editing_collaboration_tests.rs +++ b/crates/collab/tests/integration/remote_editing_collaboration_tests.rs @@ 
-473,9 +473,11 @@ async fn test_ssh_collaboration_git_worktrees( cx_b.update(|cx| { repo_b.update(cx, |repo, _| { repo.create_worktree( - "feature-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "feature-branch".to_string(), + base_sha: Some("abc123".to_string()), + }, worktree_directory.join("feature-branch"), - Some("abc123".to_string()), ) }) }) diff --git a/crates/collab_ui/Cargo.toml b/crates/collab_ui/Cargo.toml index efcba05456955e308e5a00e938bf3092d894efeb..920f620e0ea2d48f514c5e0af598add193f80d98 100644 --- a/crates/collab_ui/Cargo.toml +++ b/crates/collab_ui/Cargo.toml @@ -32,7 +32,6 @@ test-support = [ anyhow.workspace = true call.workspace = true channel.workspace = true -chrono.workspace = true client.workspace = true collections.workspace = true db.workspace = true @@ -41,7 +40,6 @@ futures.workspace = true fuzzy.workspace = true gpui.workspace = true livekit_client.workspace = true -log.workspace = true menu.workspace = true notifications.workspace = true picker.workspace = true @@ -56,7 +54,6 @@ telemetry.workspace = true theme.workspace = true theme_settings.workspace = true time.workspace = true -time_format.workspace = true title_bar.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 8d0cdf351163dadf0ac8cbf6a8dc04886f30f583..1cff27ac6b2f3c61f7a90c4a9ca6749d4b1e48b7 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -6,7 +6,7 @@ use crate::{CollaborationPanelSettings, channel_view::ChannelView}; use anyhow::Context as _; use call::ActiveCall; use channel::{Channel, ChannelEvent, ChannelStore}; -use client::{ChannelId, Client, Contact, User, UserStore}; +use client::{ChannelId, Client, Contact, Notification, User, UserStore}; use collections::{HashMap, HashSet}; use contact_finder::ContactFinder; use db::kvp::KeyValueStore; @@ -21,6 +21,7 @@ use gpui::{ }; use 
menu::{Cancel, Confirm, SecondaryConfirm, SelectNext, SelectPrevious}; +use notifications::{NotificationEntry, NotificationEvent, NotificationStore}; use project::{Fs, Project}; use rpc::{ ErrorCode, ErrorExt, @@ -29,19 +30,23 @@ use rpc::{ use serde::{Deserialize, Serialize}; use settings::Settings; use smallvec::SmallVec; -use std::{mem, sync::Arc}; +use std::{mem, sync::Arc, time::Duration}; use theme::ActiveTheme; use theme_settings::ThemeSettings; use ui::{ - Avatar, AvatarAvailabilityIndicator, ContextMenu, CopyButton, Facepile, HighlightedLabel, - IconButtonShape, Indicator, ListHeader, ListItem, Tab, Tooltip, prelude::*, tooltip_container, + Avatar, AvatarAvailabilityIndicator, CollabNotification, ContextMenu, CopyButton, Facepile, + HighlightedLabel, IconButtonShape, Indicator, ListHeader, ListItem, Tab, Tooltip, prelude::*, + tooltip_container, }; use util::{ResultExt, TryFutureExt, maybe}; use workspace::{ CopyRoomId, Deafen, LeaveCall, MultiWorkspace, Mute, OpenChannelNotes, OpenChannelNotesById, ScreenShare, ShareProject, Workspace, dock::{DockPosition, Panel, PanelEvent}, - notifications::{DetachAndPromptErr, NotifyResultExt}, + notifications::{ + DetachAndPromptErr, Notification as WorkspaceNotification, NotificationId, NotifyResultExt, + SuppressEvent, + }, }; const FILTER_OCCUPIED_CHANNELS_KEY: &str = "filter_occupied_channels"; @@ -87,6 +92,7 @@ struct ChannelMoveClipboard { } const COLLABORATION_PANEL_KEY: &str = "CollaborationPanel"; +const TOAST_DURATION: Duration = Duration::from_secs(5); pub fn init(cx: &mut App) { cx.observe_new(|workspace: &mut Workspace, _, _| { @@ -267,6 +273,9 @@ pub struct CollabPanel { collapsed_channels: Vec, filter_occupied_channels: bool, workspace: WeakEntity, + notification_store: Entity, + current_notification_toast: Option<(u64, Task<()>)>, + mark_as_read_tasks: HashMap>>, } #[derive(Serialize, Deserialize)] @@ -394,6 +403,9 @@ impl CollabPanel { channel_editing_state: None, selection: None, channel_store: 
ChannelStore::global(cx), + notification_store: NotificationStore::global(cx), + current_notification_toast: None, + mark_as_read_tasks: HashMap::default(), user_store: workspace.user_store().clone(), project: workspace.project().clone(), subscriptions: Vec::default(), @@ -437,6 +449,11 @@ impl CollabPanel { } }, )); + this.subscriptions.push(cx.subscribe_in( + &this.notification_store, + window, + Self::on_notification_event, + )); this }) @@ -1181,7 +1198,7 @@ impl CollabPanel { .into(); ListItem::new(project_id as usize) - .height(px(24.)) + .height(rems_from_px(24.)) .toggle_state(is_selected) .on_click(cx.listener(move |this, _, window, cx| { this.workspace @@ -1222,7 +1239,7 @@ impl CollabPanel { let id = peer_id.map_or(usize::MAX, |id| id.as_u64() as usize); ListItem::new(("screen", id)) - .height(px(24.)) + .height(rems_from_px(24.)) .toggle_state(is_selected) .start_slot( h_flex() @@ -1269,7 +1286,7 @@ impl CollabPanel { let has_channel_buffer_changed = channel_store.has_channel_buffer_changed(channel_id); ListItem::new("channel-notes") - .height(px(24.)) + .height(rems_from_px(24.)) .toggle_state(is_selected) .on_click(cx.listener(move |this, _, window, cx| { this.open_channel_notes(channel_id, window, cx); @@ -2665,26 +2682,28 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) -> AnyElement { - let entry = &self.entries[ix]; + let entry = self.entries[ix].clone(); let is_selected = self.selection == Some(ix); match entry { ListEntry::Header(section) => { - let is_collapsed = self.collapsed_sections.contains(section); - self.render_header(*section, is_selected, is_collapsed, cx) + let is_collapsed = self.collapsed_sections.contains(§ion); + self.render_header(section, is_selected, is_collapsed, cx) + .into_any_element() + } + ListEntry::Contact { contact, calling } => { + self.mark_contact_request_accepted_notifications_read(contact.user.id, cx); + self.render_contact(&contact, calling, is_selected, cx) .into_any_element() } - 
ListEntry::Contact { contact, calling } => self - .render_contact(contact, *calling, is_selected, cx) - .into_any_element(), ListEntry::ContactPlaceholder => self .render_contact_placeholder(is_selected, cx) .into_any_element(), ListEntry::IncomingRequest(user) => self - .render_contact_request(user, true, is_selected, cx) + .render_contact_request(&user, true, is_selected, cx) .into_any_element(), ListEntry::OutgoingRequest(user) => self - .render_contact_request(user, false, is_selected, cx) + .render_contact_request(&user, false, is_selected, cx) .into_any_element(), ListEntry::Channel { channel, @@ -2694,9 +2713,9 @@ impl CollabPanel { .. } => self .render_channel( - channel, - *depth, - *has_children, + &channel, + depth, + has_children, is_selected, ix, string_match.as_ref(), @@ -2704,10 +2723,10 @@ impl CollabPanel { ) .into_any_element(), ListEntry::ChannelEditor { depth } => self - .render_channel_editor(*depth, window, cx) + .render_channel_editor(depth, window, cx) .into_any_element(), ListEntry::ChannelInvite(channel) => self - .render_channel_invite(channel, is_selected, cx) + .render_channel_invite(&channel, is_selected, cx) .into_any_element(), ListEntry::CallParticipant { user, @@ -2715,7 +2734,7 @@ impl CollabPanel { is_pending, role, } => self - .render_call_participant(user, *peer_id, *is_pending, *role, is_selected, cx) + .render_call_participant(&user, peer_id, is_pending, role, is_selected, cx) .into_any_element(), ListEntry::ParticipantProject { project_id, @@ -2724,20 +2743,20 @@ impl CollabPanel { is_last, } => self .render_participant_project( - *project_id, - worktree_root_names, - *host_user_id, - *is_last, + project_id, + &worktree_root_names, + host_user_id, + is_last, is_selected, window, cx, ) .into_any_element(), ListEntry::ParticipantScreen { peer_id, is_last } => self - .render_participant_screen(*peer_id, *is_last, is_selected, window, cx) + .render_participant_screen(peer_id, is_last, is_selected, window, cx) 
.into_any_element(), ListEntry::ChannelNotes { channel_id } => self - .render_channel_notes(*channel_id, is_selected, window, cx) + .render_channel_notes(channel_id, is_selected, window, cx) .into_any_element(), } } @@ -2846,11 +2865,11 @@ impl CollabPanel { } }; - Some(channel.name.as_ref()) + Some(channel.name.clone()) }); if let Some(name) = channel_name { - SharedString::from(name.to_string()) + name } else { SharedString::from("Current Call") } @@ -3210,7 +3229,7 @@ impl CollabPanel { (IconName::Star, Color::Default, "Add to Favorites") }; - let height = px(24.); + let height = rems_from_px(24.); h_flex() .id(ix) @@ -3397,6 +3416,178 @@ impl CollabPanel { item.child(self.channel_name_editor.clone()) } } + + fn on_notification_event( + &mut self, + _: &Entity, + event: &NotificationEvent, + _window: &mut Window, + cx: &mut Context, + ) { + match event { + NotificationEvent::NewNotification { entry } => { + self.add_toast(entry, cx); + cx.notify(); + } + NotificationEvent::NotificationRemoved { entry } + | NotificationEvent::NotificationRead { entry } => { + self.remove_toast(entry.id, cx); + cx.notify(); + } + NotificationEvent::NotificationsUpdated { .. } => { + cx.notify(); + } + } + } + + fn present_notification( + &self, + entry: &NotificationEntry, + cx: &App, + ) -> Option<(Option>, String)> { + let user_store = self.user_store.read(cx); + match &entry.notification { + Notification::ContactRequest { sender_id } => { + let requester = user_store.get_cached_user(*sender_id)?; + Some(( + Some(requester.clone()), + format!("{} wants to add you as a contact", requester.github_login), + )) + } + Notification::ContactRequestAccepted { responder_id } => { + let responder = user_store.get_cached_user(*responder_id)?; + Some(( + Some(responder.clone()), + format!("{} accepted your contact request", responder.github_login), + )) + } + Notification::ChannelInvitation { + channel_name, + inviter_id, + .. 
+ } => { + let inviter = user_store.get_cached_user(*inviter_id)?; + Some(( + Some(inviter.clone()), + format!( + "{} invited you to join the #{channel_name} channel", + inviter.github_login + ), + )) + } + } + } + + fn add_toast(&mut self, entry: &NotificationEntry, cx: &mut Context) { + let Some((actor, text)) = self.present_notification(entry, cx) else { + return; + }; + + let notification = entry.notification.clone(); + let needs_response = matches!( + notification, + Notification::ContactRequest { .. } | Notification::ChannelInvitation { .. } + ); + + let notification_id = entry.id; + + self.current_notification_toast = Some(( + notification_id, + cx.spawn(async move |this, cx| { + cx.background_executor().timer(TOAST_DURATION).await; + this.update(cx, |this, cx| this.remove_toast(notification_id, cx)) + .ok(); + }), + )); + + let collab_panel = cx.entity().downgrade(); + self.workspace + .update(cx, |workspace, cx| { + let id = NotificationId::unique::(); + + workspace.dismiss_notification(&id, cx); + workspace.show_notification(id, cx, |cx| { + let workspace = cx.entity().downgrade(); + cx.new(|cx| CollabNotificationToast { + actor, + text, + notification: needs_response.then(|| notification), + workspace, + collab_panel: collab_panel.clone(), + focus_handle: cx.focus_handle(), + }) + }) + }) + .ok(); + } + + fn mark_notification_read(&mut self, notification_id: u64, cx: &mut Context) { + let client = self.client.clone(); + self.mark_as_read_tasks + .entry(notification_id) + .or_insert_with(|| { + cx.spawn(async move |this, cx| { + let request_result = client + .request(proto::MarkNotificationRead { notification_id }) + .await; + + this.update(cx, |this, _| { + this.mark_as_read_tasks.remove(¬ification_id); + })?; + + request_result?; + Ok(()) + }) + }); + } + + fn mark_contact_request_accepted_notifications_read( + &mut self, + contact_user_id: u64, + cx: &mut Context, + ) { + let notification_ids = self.notification_store.read_with(cx, |store, _| { + 
(0..store.notification_count()) + .filter_map(|index| { + let entry = store.notification_at(index)?; + if entry.is_read { + return None; + } + + match &entry.notification { + Notification::ContactRequestAccepted { responder_id } + if *responder_id == contact_user_id => + { + Some(entry.id) + } + _ => None, + } + }) + .collect::>() + }); + + for notification_id in notification_ids { + self.mark_notification_read(notification_id, cx); + } + } + + fn remove_toast(&mut self, notification_id: u64, cx: &mut Context) { + if let Some((current_id, _)) = &self.current_notification_toast { + if *current_id == notification_id { + self.dismiss_toast(cx); + } + } + } + + fn dismiss_toast(&mut self, cx: &mut Context) { + self.current_notification_toast.take(); + self.workspace + .update(cx, |workspace, cx| { + let id = NotificationId::unique::(); + workspace.dismiss_notification(&id, cx) + }) + .ok(); + } } fn render_tree_branch( @@ -3516,12 +3707,38 @@ impl Panel for CollabPanel { CollaborationPanelSettings::get_global(cx).default_width } + fn set_active(&mut self, active: bool, _window: &mut Window, cx: &mut Context) { + if active && self.current_notification_toast.is_some() { + self.current_notification_toast.take(); + let workspace = self.workspace.clone(); + cx.defer(move |cx| { + workspace + .update(cx, |workspace, cx| { + let id = NotificationId::unique::(); + workspace.dismiss_notification(&id, cx) + }) + .ok(); + }); + } + } + fn icon(&self, _window: &Window, cx: &App) -> Option { CollaborationPanelSettings::get_global(cx) .button .then_some(ui::IconName::UserGroup) } + fn icon_label(&self, _window: &Window, cx: &App) -> Option { + let user_store = self.user_store.read(cx); + let count = user_store.incoming_contact_requests().len() + + self.channel_store.read(cx).channel_invitations().len(); + if count == 0 { + None + } else { + Some(count.to_string()) + } + } + fn icon_tooltip(&self, _window: &Window, _cx: &App) -> Option<&'static str> { Some("Collab Panel") } @@ 
-3702,6 +3919,101 @@ impl Render for JoinChannelTooltip { } } +pub struct CollabNotificationToast { + actor: Option>, + text: String, + notification: Option, + workspace: WeakEntity, + collab_panel: WeakEntity, + focus_handle: FocusHandle, +} + +impl Focusable for CollabNotificationToast { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl WorkspaceNotification for CollabNotificationToast {} + +impl CollabNotificationToast { + fn focus_collab_panel(&self, window: &mut Window, cx: &mut Context) { + let workspace = self.workspace.clone(); + window.defer(cx, move |window, cx| { + workspace + .update(cx, |workspace, cx| { + workspace.focus_panel::(window, cx) + }) + .ok(); + }) + } + + fn respond(&mut self, accept: bool, window: &mut Window, cx: &mut Context) { + if let Some(notification) = self.notification.take() { + self.collab_panel + .update(cx, |collab_panel, cx| match notification { + Notification::ContactRequest { sender_id } => { + collab_panel.respond_to_contact_request(sender_id, accept, window, cx); + } + Notification::ChannelInvitation { channel_id, .. } => { + collab_panel.respond_to_channel_invite(ChannelId(channel_id), accept, cx); + } + Notification::ContactRequestAccepted { .. 
} => {} + }) + .ok(); + } + cx.emit(DismissEvent); + } +} + +impl Render for CollabNotificationToast { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let needs_response = self.notification.is_some(); + + let accept_button = if needs_response { + Button::new("accept", "Accept").on_click(cx.listener(|this, _, window, cx| { + this.respond(true, window, cx); + cx.stop_propagation(); + })) + } else { + Button::new("dismiss", "Dismiss").on_click(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })) + }; + + let decline_button = if needs_response { + Button::new("decline", "Decline").on_click(cx.listener(|this, _, window, cx| { + this.respond(false, window, cx); + cx.stop_propagation(); + })) + } else { + Button::new("close", "Close").on_click(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })) + }; + + let avatar_uri = self + .actor + .as_ref() + .map(|user| user.avatar_uri.clone()) + .unwrap_or_default(); + + div() + .id("collab_notification_toast") + .on_click(cx.listener(|this, _, window, cx| { + this.focus_collab_panel(window, cx); + cx.emit(DismissEvent); + })) + .child( + CollabNotification::new(avatar_uri, accept_button, decline_button) + .child(Label::new(self.text.clone())), + ) + } +} + +impl EventEmitter for CollabNotificationToast {} +impl EventEmitter for CollabNotificationToast {} + #[cfg(any(test, feature = "test-support"))] impl CollabPanel { pub fn entries_as_strings(&self) -> Vec { diff --git a/crates/collab_ui/src/collab_ui.rs b/crates/collab_ui/src/collab_ui.rs index 107b2ffa7f625d98dd9c54bb6bbf75df8b72d020..f9c463c0690343a3b4b1b9a048134265326a9f50 100644 --- a/crates/collab_ui/src/collab_ui.rs +++ b/crates/collab_ui/src/collab_ui.rs @@ -1,7 +1,6 @@ mod call_stats_modal; pub mod channel_view; pub mod collab_panel; -pub mod notification_panel; pub mod notifications; mod panel_settings; @@ -12,7 +11,7 @@ use gpui::{ App, Pixels, PlatformDisplay, Size, WindowBackgroundAppearance, WindowBounds, 
WindowDecorations, WindowKind, WindowOptions, point, }; -pub use panel_settings::{CollaborationPanelSettings, NotificationPanelSettings}; +pub use panel_settings::CollaborationPanelSettings; use release_channel::ReleaseChannel; use ui::px; use workspace::AppState; @@ -22,7 +21,6 @@ pub fn init(app_state: &Arc, cx: &mut App) { call_stats_modal::init(cx); channel_view::init(cx); collab_panel::init(cx); - notification_panel::init(cx); notifications::init(app_state, cx); title_bar::init(cx); } diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs deleted file mode 100644 index d7fef4873c687ab23a25b3144ba902cf4c42c137..0000000000000000000000000000000000000000 --- a/crates/collab_ui/src/notification_panel.rs +++ /dev/null @@ -1,727 +0,0 @@ -use crate::NotificationPanelSettings; -use anyhow::Result; -use channel::ChannelStore; -use client::{ChannelId, Client, Notification, User, UserStore}; -use collections::HashMap; -use futures::StreamExt; -use gpui::{ - AnyElement, App, AsyncWindowContext, ClickEvent, Context, DismissEvent, Element, Entity, - EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ListAlignment, - ListScrollEvent, ListState, ParentElement, Render, StatefulInteractiveElement, Styled, Task, - WeakEntity, Window, actions, div, img, list, px, -}; -use notifications::{NotificationEntry, NotificationEvent, NotificationStore}; -use project::Fs; -use rpc::proto; - -use settings::{Settings, SettingsStore}; -use std::{sync::Arc, time::Duration}; -use time::{OffsetDateTime, UtcOffset}; -use ui::{ - Avatar, Button, Icon, IconButton, IconName, Label, Tab, Tooltip, h_flex, prelude::*, v_flex, -}; -use util::ResultExt; -use workspace::notifications::{ - Notification as WorkspaceNotification, NotificationId, SuppressEvent, -}; -use workspace::{ - Workspace, - dock::{DockPosition, Panel, PanelEvent}, -}; - -const LOADING_THRESHOLD: usize = 30; -const MARK_AS_READ_DELAY: Duration = Duration::from_secs(1); 
-const TOAST_DURATION: Duration = Duration::from_secs(5); -const NOTIFICATION_PANEL_KEY: &str = "NotificationPanel"; - -pub struct NotificationPanel { - client: Arc, - user_store: Entity, - channel_store: Entity, - notification_store: Entity, - fs: Arc, - active: bool, - notification_list: ListState, - subscriptions: Vec, - workspace: WeakEntity, - current_notification_toast: Option<(u64, Task<()>)>, - local_timezone: UtcOffset, - focus_handle: FocusHandle, - mark_as_read_tasks: HashMap>>, - unseen_notifications: Vec, -} - -#[derive(Debug)] -pub enum Event { - DockPositionChanged, - Focus, - Dismissed, -} - -pub struct NotificationPresenter { - pub actor: Option>, - pub text: String, - pub icon: &'static str, - pub needs_response: bool, -} - -actions!( - notification_panel, - [ - /// Toggles the notification panel. - Toggle, - /// Toggles focus on the notification panel. - ToggleFocus - ] -); - -pub fn init(cx: &mut App) { - cx.observe_new(|workspace: &mut Workspace, _, _| { - workspace.register_action(|workspace, _: &ToggleFocus, window, cx| { - workspace.toggle_panel_focus::(window, cx); - }); - workspace.register_action(|workspace, _: &Toggle, window, cx| { - if !workspace.toggle_panel_focus::(window, cx) { - workspace.close_panel::(window, cx); - } - }); - }) - .detach(); -} - -impl NotificationPanel { - pub fn new( - workspace: &mut Workspace, - window: &mut Window, - cx: &mut Context, - ) -> Entity { - let fs = workspace.app_state().fs.clone(); - let client = workspace.app_state().client.clone(); - let user_store = workspace.app_state().user_store.clone(); - let workspace_handle = workspace.weak_handle(); - - cx.new(|cx| { - let mut status = client.status(); - cx.spawn_in(window, async move |this, cx| { - while (status.next().await).is_some() { - if this - .update(cx, |_: &mut Self, cx| { - cx.notify(); - }) - .is_err() - { - break; - } - } - }) - .detach(); - - let notification_list = ListState::new(0, ListAlignment::Top, px(1000.)); - 
notification_list.set_scroll_handler(cx.listener( - |this, event: &ListScrollEvent, _, cx| { - if event.count.saturating_sub(event.visible_range.end) < LOADING_THRESHOLD - && let Some(task) = this - .notification_store - .update(cx, |store, cx| store.load_more_notifications(false, cx)) - { - task.detach(); - } - }, - )); - - let local_offset = chrono::Local::now().offset().local_minus_utc(); - let mut this = Self { - fs, - client, - user_store, - local_timezone: UtcOffset::from_whole_seconds(local_offset).unwrap(), - channel_store: ChannelStore::global(cx), - notification_store: NotificationStore::global(cx), - notification_list, - workspace: workspace_handle, - focus_handle: cx.focus_handle(), - subscriptions: Default::default(), - current_notification_toast: None, - active: false, - mark_as_read_tasks: Default::default(), - unseen_notifications: Default::default(), - }; - - let mut old_dock_position = this.position(window, cx); - this.subscriptions.extend([ - cx.observe(&this.notification_store, |_, _, cx| cx.notify()), - cx.subscribe_in( - &this.notification_store, - window, - Self::on_notification_event, - ), - cx.observe_global_in::( - window, - move |this: &mut Self, window, cx| { - let new_dock_position = this.position(window, cx); - if new_dock_position != old_dock_position { - old_dock_position = new_dock_position; - cx.emit(Event::DockPositionChanged); - } - cx.notify(); - }, - ), - ]); - this - }) - } - - pub fn load( - workspace: WeakEntity, - cx: AsyncWindowContext, - ) -> Task>> { - cx.spawn(async move |cx| { - workspace.update_in(cx, |workspace, window, cx| Self::new(workspace, window, cx)) - }) - } - - fn render_notification( - &mut self, - ix: usize, - window: &mut Window, - cx: &mut Context, - ) -> Option { - let entry = self.notification_store.read(cx).notification_at(ix)?; - let notification_id = entry.id; - let now = OffsetDateTime::now_utc(); - let timestamp = entry.timestamp; - let NotificationPresenter { - actor, - text, - needs_response, - 
.. - } = self.present_notification(entry, cx)?; - - let response = entry.response; - let notification = entry.notification.clone(); - - if self.active && !entry.is_read { - self.did_render_notification(notification_id, ¬ification, window, cx); - } - - let relative_timestamp = time_format::format_localized_timestamp( - timestamp, - now, - self.local_timezone, - time_format::TimestampFormat::Relative, - ); - - let absolute_timestamp = time_format::format_localized_timestamp( - timestamp, - now, - self.local_timezone, - time_format::TimestampFormat::Absolute, - ); - - Some( - div() - .id(ix) - .flex() - .flex_row() - .size_full() - .px_2() - .py_1() - .gap_2() - .hover(|style| style.bg(cx.theme().colors().element_hover)) - .children(actor.map(|actor| { - img(actor.avatar_uri.clone()) - .flex_none() - .w_8() - .h_8() - .rounded_full() - })) - .child( - v_flex() - .gap_1() - .size_full() - .overflow_hidden() - .child(Label::new(text)) - .child( - h_flex() - .child( - div() - .id("notification_timestamp") - .hover(|style| { - style - .bg(cx.theme().colors().element_selected) - .rounded_sm() - }) - .child(Label::new(relative_timestamp).color(Color::Muted)) - .tooltip(move |_, cx| { - Tooltip::simple(absolute_timestamp.clone(), cx) - }), - ) - .children(if let Some(is_accepted) = response { - Some(div().flex().flex_grow().justify_end().child(Label::new( - if is_accepted { - "You accepted" - } else { - "You declined" - }, - ))) - } else if needs_response { - Some( - h_flex() - .flex_grow() - .justify_end() - .child(Button::new("decline", "Decline").on_click({ - let notification = notification.clone(); - let entity = cx.entity(); - move |_, _, cx| { - entity.update(cx, |this, cx| { - this.respond_to_notification( - notification.clone(), - false, - cx, - ) - }); - } - })) - .child(Button::new("accept", "Accept").on_click({ - let notification = notification.clone(); - let entity = cx.entity(); - move |_, _, cx| { - entity.update(cx, |this, cx| { - this.respond_to_notification( 
- notification.clone(), - true, - cx, - ) - }); - } - })), - ) - } else { - None - }), - ), - ) - .into_any(), - ) - } - - fn present_notification( - &self, - entry: &NotificationEntry, - cx: &App, - ) -> Option { - let user_store = self.user_store.read(cx); - let channel_store = self.channel_store.read(cx); - match entry.notification { - Notification::ContactRequest { sender_id } => { - let requester = user_store.get_cached_user(sender_id)?; - Some(NotificationPresenter { - icon: "icons/plus.svg", - text: format!("{} wants to add you as a contact", requester.github_login), - needs_response: user_store.has_incoming_contact_request(requester.id), - actor: Some(requester), - }) - } - Notification::ContactRequestAccepted { responder_id } => { - let responder = user_store.get_cached_user(responder_id)?; - Some(NotificationPresenter { - icon: "icons/plus.svg", - text: format!("{} accepted your contact invite", responder.github_login), - needs_response: false, - actor: Some(responder), - }) - } - Notification::ChannelInvitation { - ref channel_name, - channel_id, - inviter_id, - } => { - let inviter = user_store.get_cached_user(inviter_id)?; - Some(NotificationPresenter { - icon: "icons/hash.svg", - text: format!( - "{} invited you to join the #{channel_name} channel", - inviter.github_login - ), - needs_response: channel_store.has_channel_invitation(ChannelId(channel_id)), - actor: Some(inviter), - }) - } - } - } - - fn did_render_notification( - &mut self, - notification_id: u64, - notification: &Notification, - window: &mut Window, - cx: &mut Context, - ) { - let should_mark_as_read = match notification { - Notification::ContactRequestAccepted { .. } => true, - Notification::ContactRequest { .. } | Notification::ChannelInvitation { .. 
} => false, - }; - - if should_mark_as_read { - self.mark_as_read_tasks - .entry(notification_id) - .or_insert_with(|| { - let client = self.client.clone(); - cx.spawn_in(window, async move |this, cx| { - cx.background_executor().timer(MARK_AS_READ_DELAY).await; - client - .request(proto::MarkNotificationRead { notification_id }) - .await?; - this.update(cx, |this, _| { - this.mark_as_read_tasks.remove(¬ification_id); - })?; - Ok(()) - }) - }); - } - } - - fn on_notification_event( - &mut self, - _: &Entity, - event: &NotificationEvent, - window: &mut Window, - cx: &mut Context, - ) { - match event { - NotificationEvent::NewNotification { entry } => { - self.unseen_notifications.push(entry.clone()); - self.add_toast(entry, window, cx); - } - NotificationEvent::NotificationRemoved { entry } - | NotificationEvent::NotificationRead { entry } => { - self.unseen_notifications.retain(|n| n.id != entry.id); - self.remove_toast(entry.id, cx); - } - NotificationEvent::NotificationsUpdated { - old_range, - new_count, - } => { - self.notification_list.splice(old_range.clone(), *new_count); - cx.notify(); - } - } - } - - fn add_toast( - &mut self, - entry: &NotificationEntry, - window: &mut Window, - cx: &mut Context, - ) { - let Some(NotificationPresenter { actor, text, .. 
}) = self.present_notification(entry, cx) - else { - return; - }; - - let notification_id = entry.id; - self.current_notification_toast = Some(( - notification_id, - cx.spawn_in(window, async move |this, cx| { - cx.background_executor().timer(TOAST_DURATION).await; - this.update(cx, |this, cx| this.remove_toast(notification_id, cx)) - .ok(); - }), - )); - - self.workspace - .update(cx, |workspace, cx| { - let id = NotificationId::unique::(); - - workspace.dismiss_notification(&id, cx); - workspace.show_notification(id, cx, |cx| { - let workspace = cx.entity().downgrade(); - cx.new(|cx| NotificationToast { - actor, - text, - workspace, - focus_handle: cx.focus_handle(), - }) - }) - }) - .ok(); - } - - fn remove_toast(&mut self, notification_id: u64, cx: &mut Context) { - if let Some((current_id, _)) = &self.current_notification_toast - && *current_id == notification_id - { - self.current_notification_toast.take(); - self.workspace - .update(cx, |workspace, cx| { - let id = NotificationId::unique::(); - workspace.dismiss_notification(&id, cx) - }) - .ok(); - } - } - - fn respond_to_notification( - &mut self, - notification: Notification, - response: bool, - - cx: &mut Context, - ) { - self.notification_store.update(cx, |store, cx| { - store.respond_to_notification(notification, response, cx); - }); - } -} - -impl Render for NotificationPanel { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - v_flex() - .size_full() - .child( - h_flex() - .justify_between() - .px_2() - .py_1() - // Match the height of the tab bar so they line up. 
- .h(Tab::container_height(cx)) - .border_b_1() - .border_color(cx.theme().colors().border) - .child(Label::new("Notifications")) - .child(Icon::new(IconName::Envelope)), - ) - .map(|this| { - if !self.client.status().borrow().is_connected() { - this.child( - v_flex() - .gap_2() - .p_4() - .child( - Button::new("connect_prompt_button", "Connect") - .start_icon(Icon::new(IconName::Github).color(Color::Muted)) - .style(ButtonStyle::Filled) - .full_width() - .on_click({ - let client = self.client.clone(); - move |_, window, cx| { - let client = client.clone(); - window - .spawn(cx, async move |cx| { - match client.connect(true, cx).await { - util::ConnectionResult::Timeout => { - log::error!("Connection timeout"); - } - util::ConnectionResult::ConnectionReset => { - log::error!("Connection reset"); - } - util::ConnectionResult::Result(r) => { - r.log_err(); - } - } - }) - .detach() - } - }), - ) - .child( - div().flex().w_full().items_center().child( - Label::new("Connect to view notifications.") - .color(Color::Muted) - .size(LabelSize::Small), - ), - ), - ) - } else if self.notification_list.item_count() == 0 { - this.child( - v_flex().p_4().child( - div().flex().w_full().items_center().child( - Label::new("You have no notifications.") - .color(Color::Muted) - .size(LabelSize::Small), - ), - ), - ) - } else { - this.child( - list( - self.notification_list.clone(), - cx.processor(|this, ix, window, cx| { - this.render_notification(ix, window, cx) - .unwrap_or_else(|| div().into_any()) - }), - ) - .size_full(), - ) - } - }) - } -} - -impl Focusable for NotificationPanel { - fn focus_handle(&self, _: &App) -> FocusHandle { - self.focus_handle.clone() - } -} - -impl EventEmitter for NotificationPanel {} -impl EventEmitter for NotificationPanel {} - -impl Panel for NotificationPanel { - fn persistent_name() -> &'static str { - "NotificationPanel" - } - - fn panel_key() -> &'static str { - NOTIFICATION_PANEL_KEY - } - - fn position(&self, _: &Window, cx: &App) -> 
DockPosition { - NotificationPanelSettings::get_global(cx).dock - } - - fn position_is_valid(&self, position: DockPosition) -> bool { - matches!(position, DockPosition::Left | DockPosition::Right) - } - - fn set_position(&mut self, position: DockPosition, _: &mut Window, cx: &mut Context) { - settings::update_settings_file(self.fs.clone(), cx, move |settings, _| { - settings.notification_panel.get_or_insert_default().dock = Some(position.into()) - }); - } - - fn default_size(&self, _: &Window, cx: &App) -> Pixels { - NotificationPanelSettings::get_global(cx).default_width - } - - fn set_active(&mut self, active: bool, _: &mut Window, cx: &mut Context) { - self.active = active; - - if self.active { - self.unseen_notifications = Vec::new(); - cx.notify(); - } - - if self.notification_store.read(cx).notification_count() == 0 { - cx.emit(Event::Dismissed); - } - } - - fn icon(&self, _: &Window, cx: &App) -> Option { - let show_button = NotificationPanelSettings::get_global(cx).button; - if !show_button { - return None; - } - - if self.unseen_notifications.is_empty() { - return Some(IconName::Bell); - } - - Some(IconName::BellDot) - } - - fn icon_tooltip(&self, _window: &Window, _cx: &App) -> Option<&'static str> { - Some("Notification Panel") - } - - fn icon_label(&self, _window: &Window, cx: &App) -> Option { - if !NotificationPanelSettings::get_global(cx).show_count_badge { - return None; - } - let count = self.notification_store.read(cx).unread_notification_count(); - if count == 0 { - None - } else { - Some(count.to_string()) - } - } - - fn toggle_action(&self) -> Box { - Box::new(ToggleFocus) - } - - fn activation_priority(&self) -> u32 { - 4 - } -} - -pub struct NotificationToast { - actor: Option>, - text: String, - workspace: WeakEntity, - focus_handle: FocusHandle, -} - -impl Focusable for NotificationToast { - fn focus_handle(&self, _cx: &App) -> FocusHandle { - self.focus_handle.clone() - } -} - -impl WorkspaceNotification for NotificationToast {} - -impl 
NotificationToast { - fn focus_notification_panel(&self, window: &mut Window, cx: &mut Context) { - let workspace = self.workspace.clone(); - window.defer(cx, move |window, cx| { - workspace - .update(cx, |workspace, cx| { - workspace.focus_panel::(window, cx) - }) - .ok(); - }) - } -} - -impl Render for NotificationToast { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let user = self.actor.clone(); - - let suppress = window.modifiers().shift; - let (close_id, close_icon) = if suppress { - ("suppress", IconName::Minimize) - } else { - ("close", IconName::Close) - }; - - h_flex() - .id("notification_panel_toast") - .elevation_3(cx) - .p_2() - .justify_between() - .children(user.map(|user| Avatar::new(user.avatar_uri.clone()))) - .child(Label::new(self.text.clone())) - .on_modifiers_changed(cx.listener(|_, _, _, cx| cx.notify())) - .child( - IconButton::new(close_id, close_icon) - .tooltip(move |_window, cx| { - if suppress { - Tooltip::for_action( - "Suppress.\nClose with click.", - &workspace::SuppressNotification, - cx, - ) - } else { - Tooltip::for_action( - "Close.\nSuppress with shift-click", - &menu::Cancel, - cx, - ) - } - }) - .on_click(cx.listener(move |_, _: &ClickEvent, _, cx| { - if suppress { - cx.emit(SuppressEvent); - } else { - cx.emit(DismissEvent); - } - })), - ) - .on_click(cx.listener(|this, _, window, cx| { - this.focus_notification_panel(window, cx); - cx.emit(DismissEvent); - })) - } -} - -impl EventEmitter for NotificationToast {} -impl EventEmitter for NotificationToast {} diff --git a/crates/collab_ui/src/panel_settings.rs b/crates/collab_ui/src/panel_settings.rs index 938d33159e9adb7a9e63ceb73219b70724efee17..3d6de1015a3751751c13c8ccb6d4c5639755be20 100644 --- a/crates/collab_ui/src/panel_settings.rs +++ b/crates/collab_ui/src/panel_settings.rs @@ -10,14 +10,6 @@ pub struct CollaborationPanelSettings { pub default_width: Pixels, } -#[derive(Debug, RegisterSetting)] -pub struct 
NotificationPanelSettings { - pub button: bool, - pub dock: DockPosition, - pub default_width: Pixels, - pub show_count_badge: bool, -} - impl Settings for CollaborationPanelSettings { fn from_settings(content: &settings::SettingsContent) -> Self { let panel = content.collaboration_panel.as_ref().unwrap(); @@ -29,15 +21,3 @@ impl Settings for CollaborationPanelSettings { } } } - -impl Settings for NotificationPanelSettings { - fn from_settings(content: &settings::SettingsContent) -> Self { - let panel = content.notification_panel.as_ref().unwrap(); - return Self { - button: panel.button.unwrap(), - dock: panel.dock.unwrap().into(), - default_width: panel.default_width.map(px).unwrap(), - show_count_badge: panel.show_count_badge.unwrap(), - }; - } -} diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index eabb1641fd4fbec7b2f8ef0ba399a8fe9600dfa3..87ad4e42e7826cdda4fc6a8c31a27afe888830f0 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -21,8 +21,9 @@ heapless.workspace = true buffer_diff.workspace = true client.workspace = true clock.workspace = true +cloud_api_client.workspace = true cloud_api_types.workspace = true -cloud_llm_client.workspace = true +cloud_llm_client = { workspace = true, features = ["predict-edits"] } collections.workspace = true copilot.workspace = true copilot_ui.workspace = true diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 280427df006b510e1854ffb40cd7f995fcd9fdc6..2d90e13fb9b45aedd354f753502cd4e616ae3bcd 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,5 +1,6 @@ use anyhow::Result; use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token}; +use cloud_api_client::LlmApiToken; use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody}; use cloud_llm_client::predict_edits_v3::{ 
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, @@ -31,7 +32,6 @@ use heapless::Vec as ArrayVec; use language::language_settings::all_language_settings; use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; -use language_model::LlmApiToken; use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; diff --git a/crates/edit_prediction/src/ollama.rs b/crates/edit_prediction/src/ollama.rs index 0250ec44a46cf081c6badc6fa11a9c34ebb65c4a..0ae90dd9f6eca4bfe9f87950a5a66916d8894df4 100644 --- a/crates/edit_prediction/src/ollama.rs +++ b/crates/edit_prediction/src/ollama.rs @@ -57,7 +57,7 @@ pub fn fetch_models(cx: &mut App) -> Vec { let mut models: Vec = provider .provided_models(cx) .into_iter() - .map(|model| SharedString::from(model.id().0.to_string())) + .map(|model| model.id().0) .collect(); models.sort(); models diff --git a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs index c5e97fd87eaad9b98aeb9b946a9a69b1c1071db2..1a574e9389715ce888f8b8c5ec8be921ceab4a38 100644 --- a/crates/edit_prediction/src/zed_edit_prediction_delegate.rs +++ b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs @@ -177,7 +177,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { BufferEditPrediction::Local { prediction } => prediction, BufferEditPrediction::Jump { prediction } => { return Some(edit_prediction_types::EditPrediction::Jump { - id: Some(prediction.id.to_string().into()), + id: Some(prediction.id.0.clone()), snapshot: prediction.snapshot.clone(), target: prediction.edits.first().unwrap().0.start, }); @@ -228,7 +228,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate { } Some(edit_prediction_types::EditPrediction::Local { - id: Some(prediction.id.to_string().into()), + id: 
Some(prediction.id.0.clone()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), cursor_position: prediction.cursor_position, edit_preview: Some(prediction.edit_preview.clone()), diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 323ee3de41902b2140f95da22b0e37fb98d31fd5..a999fed2baf990273f0801bac15573b3aed0cc78 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -22,7 +22,7 @@ http_client.workspace = true chrono.workspace = true clap = "4" client.workspace = true -cloud_llm_client.workspace= true +cloud_llm_client = { workspace = true, features = ["predict-edits"] } collections.workspace = true db.workspace = true debug_adapter_extension.workspace = true diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index f95f1030276015af4825119fc98ac68b876d0e5f..7cb8040e282a47d27cf5d7b33e5453295b4f645f 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -98,7 +98,7 @@ use gpui::{ WeakEntity, }; use language::{ - Point, Subscription as BufferSubscription, + LanguageAwareStyling, Point, Subscription as BufferSubscription, language_settings::{AllLanguageSettings, LanguageSettings}, }; @@ -1769,7 +1769,10 @@ impl DisplaySnapshot { self.block_snapshot .chunks( BlockRow(display_row.0)..BlockRow(self.max_point().row().next_row().0), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, self.masked, Highlights::default(), ) @@ -1783,7 +1786,10 @@ impl DisplaySnapshot { self.block_snapshot .chunks( BlockRow(row)..BlockRow(row + 1), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, self.masked, Highlights::default(), ) @@ -1798,7 +1804,7 @@ impl DisplaySnapshot { pub fn chunks( &self, display_rows: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlight_styles: HighlightStyles, ) -> DisplayChunks<'_> { self.block_snapshot.chunks( @@ -1818,7 
+1824,7 @@ impl DisplaySnapshot { pub fn highlighted_chunks<'a>( &'a self, display_rows: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, editor_style: &'a EditorStyle, ) -> impl Iterator> { self.chunks( @@ -1910,7 +1916,10 @@ impl DisplaySnapshot { let chunks = custom_highlights::CustomHighlightsChunks::new( multibuffer_range, - true, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, None, Some(&self.semantic_token_highlights), multibuffer, @@ -1961,7 +1970,14 @@ impl DisplaySnapshot { let mut line = String::new(); let range = display_row..display_row.next_row(); - for chunk in self.highlighted_chunks(range, false, editor_style) { + for chunk in self.highlighted_chunks( + range, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + editor_style, + ) { line.push_str(chunk.text); let text_style = if let Some(style) = chunk.style { @@ -3388,7 +3404,14 @@ pub mod tests { let snapshot = map.update(cx, |map, cx| map.snapshot(cx)); let mut chunks = Vec::<(String, Option, Rgba)>::new(); - for chunk in snapshot.chunks(DisplayRow(0)..DisplayRow(5), true, Default::default()) { + for chunk in snapshot.chunks( + DisplayRow(0)..DisplayRow(5), + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + Default::default(), + ) { let color = chunk .highlight_style .and_then(|style| style.color) @@ -3940,7 +3963,14 @@ pub mod tests { ) -> Vec<(String, Option, Option)> { let snapshot = map.update(cx, |map, cx| map.snapshot(cx)); let mut chunks: Vec<(String, Option, Option)> = Vec::new(); - for chunk in snapshot.chunks(rows, true, HighlightStyles::default()) { + for chunk in snapshot.chunks( + rows, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + HighlightStyles::default(), + ) { let syntax_color = chunk .syntax_highlight_id .and_then(|id| theme.get(id)?.color); diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index 
67318e3300e73085fe40c2e22edfcd06778902c8..17fa7e3de4a361f6728664e76368583788053cfd 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -9,7 +9,7 @@ use crate::{ }; use collections::{Bound, HashMap, HashSet}; use gpui::{AnyElement, App, EntityId, Pixels, Window}; -use language::{Patch, Point}; +use language::{LanguageAwareStyling, Patch, Point}; use multi_buffer::{ Anchor, ExcerptBoundaryInfo, MultiBuffer, MultiBufferOffset, MultiBufferPoint, MultiBufferRow, MultiBufferSnapshot, RowInfo, ToOffset, ToPoint as _, @@ -2140,7 +2140,10 @@ impl BlockSnapshot { pub fn text(&self) -> String { self.chunks( BlockRow(0)..self.transforms.summary().output_rows, - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, false, Highlights::default(), ) @@ -2152,7 +2155,7 @@ impl BlockSnapshot { pub(crate) fn chunks<'a>( &'a self, rows: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, masked: bool, highlights: Highlights<'a>, ) -> BlockChunks<'a> { @@ -4300,7 +4303,10 @@ mod tests { let actual_text = blocks_snapshot .chunks( BlockRow(start_row as u32)..BlockRow(end_row as u32), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, false, Highlights::default(), ) diff --git a/crates/editor/src/display_map/custom_highlights.rs b/crates/editor/src/display_map/custom_highlights.rs index 39eabef2f9627b8088dc826ec64379bf76a6c9fa..6e93e562172decb0843da35c7f55fafd92ed21cc 100644 --- a/crates/editor/src/display_map/custom_highlights.rs +++ b/crates/editor/src/display_map/custom_highlights.rs @@ -1,6 +1,6 @@ use collections::BTreeMap; use gpui::HighlightStyle; -use language::Chunk; +use language::{Chunk, LanguageAwareStyling}; use multi_buffer::{MultiBufferChunks, MultiBufferOffset, MultiBufferSnapshot, ToOffset as _}; use std::{ cmp, @@ -34,7 +34,7 @@ impl<'a> CustomHighlightsChunks<'a> { #[ztracing::instrument(skip_all)] pub fn new( range: Range, - 
language_aware: bool, + language_aware: LanguageAwareStyling, text_highlights: Option<&'a TextHighlights>, semantic_token_highlights: Option<&'a SemanticTokensHighlights>, multibuffer_snapshot: &'a MultiBufferSnapshot, @@ -308,7 +308,10 @@ mod tests { // Get all chunks and verify their bitmaps let chunks = CustomHighlightsChunks::new( MultiBufferOffset(0)..buffer_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, None, None, &buffer_snapshot, diff --git a/crates/editor/src/display_map/fold_map.rs b/crates/editor/src/display_map/fold_map.rs index 1554bb96dab0e2f76a17df1396bd945f332af208..4c6c04b86cc3e2fb9ef10be58c14faae623dc65f 100644 --- a/crates/editor/src/display_map/fold_map.rs +++ b/crates/editor/src/display_map/fold_map.rs @@ -5,7 +5,7 @@ use super::{ inlay_map::{InlayBufferRows, InlayChunks, InlayEdit, InlayOffset, InlayPoint, InlaySnapshot}, }; use gpui::{AnyElement, App, ElementId, HighlightStyle, Pixels, SharedString, Stateful, Window}; -use language::{Edit, HighlightId, Point}; +use language::{Edit, HighlightId, LanguageAwareStyling, Point}; use multi_buffer::{ Anchor, AnchorRangeExt, MBTextSummary, MultiBufferOffset, MultiBufferRow, MultiBufferSnapshot, RowInfo, ToOffset, @@ -707,7 +707,10 @@ impl FoldSnapshot { pub fn text(&self) -> String { self.chunks( FoldOffset(MultiBufferOffset(0))..self.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .map(|c| c.text) @@ -909,7 +912,7 @@ impl FoldSnapshot { pub(crate) fn chunks<'a>( &'a self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlights: Highlights<'a>, ) -> FoldChunks<'a> { let mut transform_cursor = self @@ -954,7 +957,10 @@ impl FoldSnapshot { pub fn chars_at(&self, start: FoldPoint) -> impl '_ + Iterator { self.chunks( start.to_offset(self)..self.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, 
Highlights::default(), ) .flat_map(|chunk| chunk.text.chars()) @@ -964,7 +970,10 @@ impl FoldSnapshot { pub fn chunks_at(&self, start: FoldPoint) -> FoldChunks<'_> { self.chunks( start.to_offset(self)..self.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) } @@ -2131,7 +2140,14 @@ mod tests { let text = &expected_text[start.0.0..end.0.0]; assert_eq!( snapshot - .chunks(start..end, false, Highlights::default()) + .chunks( + start..end, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default() + ) .map(|c| c.text) .collect::(), text, @@ -2303,7 +2319,10 @@ mod tests { // Get all chunks and verify their bitmaps let chunks = snapshot.chunks( FoldOffset(MultiBufferOffset(0))..FoldOffset(snapshot.len().0), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ); diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index 47ca295ccb1a08768ce129b92d10506294a9cf78..698b58682d7ef7682094e7728f419348fd5d32d9 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ b/crates/editor/src/display_map/inlay_map.rs @@ -10,7 +10,7 @@ use crate::{ inlays::{Inlay, InlayContent}, }; use collections::BTreeSet; -use language::{Chunk, Edit, Point, TextSummary}; +use language::{Chunk, Edit, LanguageAwareStyling, Point, TextSummary}; use multi_buffer::{ MBTextSummary, MultiBufferOffset, MultiBufferRow, MultiBufferRows, MultiBufferSnapshot, RowInfo, ToOffset, @@ -1200,7 +1200,7 @@ impl InlaySnapshot { pub(crate) fn chunks<'a>( &'a self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlights: Highlights<'a>, ) -> InlayChunks<'a> { let mut cursor = self @@ -1234,9 +1234,16 @@ impl InlaySnapshot { #[cfg(test)] #[ztracing::instrument(skip_all)] pub fn text(&self) -> String { - self.chunks(Default::default()..self.len(), false, Highlights::default()) - 
.map(|chunk| chunk.chunk.text) - .collect() + self.chunks( + Default::default()..self.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default(), + ) + .map(|chunk| chunk.chunk.text) + .collect() } #[ztracing::instrument(skip_all)] @@ -1979,7 +1986,10 @@ mod tests { let actual_text = inlay_snapshot .chunks( range, - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights { text_highlights: Some(&text_highlights), inlay_highlights: Some(&inlay_highlights), @@ -2158,7 +2168,10 @@ mod tests { // Get all chunks and verify their bitmaps let chunks = snapshot.chunks( InlayOffset(MultiBufferOffset(0))..snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ); @@ -2293,7 +2306,10 @@ mod tests { let chunks: Vec<_> = inlay_snapshot .chunks( InlayOffset(MultiBufferOffset(0))..inlay_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, highlights, ) .collect(); @@ -2408,7 +2424,10 @@ mod tests { let chunks: Vec<_> = inlay_snapshot .chunks( InlayOffset(MultiBufferOffset(0))..inlay_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, highlights, ) .collect(); diff --git a/crates/editor/src/display_map/tab_map.rs b/crates/editor/src/display_map/tab_map.rs index 187ed8614e01ddb8dcdae930fd484de9594cf63f..bb0e642df380e04fcfa9b9533f027be7171b4975 100644 --- a/crates/editor/src/display_map/tab_map.rs +++ b/crates/editor/src/display_map/tab_map.rs @@ -3,7 +3,7 @@ use super::{ fold_map::{self, Chunk, FoldChunks, FoldEdit, FoldPoint, FoldSnapshot}, }; -use language::Point; +use language::{LanguageAwareStyling, Point}; use multi_buffer::MultiBufferSnapshot; use std::{cmp, num::NonZeroU32, ops::Range}; use sum_tree::Bias; @@ -101,7 +101,10 @@ impl TabMap { let mut last_tab_with_changed_expansion_offset = None; 'outer: for chunk in 
old_snapshot.fold_snapshot.chunks( fold_edit.old.end..old_end_row_successor_offset, - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) { let mut remaining_tabs = chunk.tabs; @@ -244,7 +247,14 @@ impl TabSnapshot { self.max_point() }; let first_line_chars = self - .chunks(range.start..line_end, false, Highlights::default()) + .chunks( + range.start..line_end, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default(), + ) .flat_map(|chunk| chunk.text.chars()) .take_while(|&c| c != '\n') .count() as u32; @@ -254,7 +264,10 @@ impl TabSnapshot { } else { self.chunks( TabPoint::new(range.end.row(), 0)..range.end, - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .flat_map(|chunk| chunk.text.chars()) @@ -274,7 +287,7 @@ impl TabSnapshot { pub(crate) fn chunks<'a>( &'a self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlights: Highlights<'a>, ) -> TabChunks<'a> { let (input_start, expanded_char_column, to_next_stop) = @@ -324,7 +337,10 @@ impl TabSnapshot { pub fn text(&self) -> String { self.chunks( TabPoint::zero()..self.max_point(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .map(|chunk| chunk.text) @@ -1170,7 +1186,10 @@ mod tests { tab_snapshot .chunks( TabPoint::new(0, ix as u32)..tab_snapshot.max_point(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .map(|c| c.text) @@ -1246,8 +1265,14 @@ mod tests { let mut chunks = Vec::new(); let mut was_tab = false; let mut text = String::new(); - for chunk in snapshot.chunks(start..snapshot.max_point(), false, Highlights::default()) - { + for chunk in snapshot.chunks( + start..snapshot.max_point(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default(), + ) { if 
chunk.is_tab != was_tab { if !text.is_empty() { chunks.push((mem::take(&mut text), was_tab)); @@ -1296,7 +1321,14 @@ mod tests { // This should not panic. let result: String = tab_snapshot - .chunks(start..end, false, Highlights::default()) + .chunks( + start..end, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default(), + ) .map(|c| c.text) .collect(); assert!(!result.is_empty()); @@ -1354,7 +1386,14 @@ mod tests { let expected_summary = TextSummary::from(expected_text.as_str()); assert_eq!( tabs_snapshot - .chunks(start..end, false, Highlights::default()) + .chunks( + start..end, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + Highlights::default() + ) .map(|c| c.text) .collect::(), expected_text, @@ -1436,7 +1475,10 @@ mod tests { let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let chunks = fold_snapshot.chunks( FoldOffset(MultiBufferOffset(0))..fold_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Default::default(), ); let mut cursor = TabStopCursor::new(chunks); @@ -1598,7 +1640,10 @@ mod tests { let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let chunks = fold_snapshot.chunks( FoldOffset(MultiBufferOffset(0))..fold_snapshot.len(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Default::default(), ); let mut cursor = TabStopCursor::new(chunks); diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index d21642977ed923e15a583dfe767fd566e78c5de9..4ff11b1ef67971c5159a81278a5afaaaea171a28 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ b/crates/editor/src/display_map/wrap_map.rs @@ -5,7 +5,7 @@ use super::{ tab_map::{self, TabEdit, TabPoint, TabSnapshot}, }; use gpui::{App, AppContext as _, Context, Entity, Font, LineWrapper, Pixels, Task}; -use language::Point; +use language::{LanguageAwareStyling, Point}; use 
multi_buffer::{MultiBufferSnapshot, RowInfo}; use smol::future::yield_now; use std::{cmp, collections::VecDeque, mem, ops::Range, sync::LazyLock, time::Duration}; @@ -513,7 +513,10 @@ impl WrapSnapshot { let mut remaining = None; let mut chunks = new_tab_snapshot.chunks( TabPoint::new(edit.new_rows.start, 0)..new_tab_snapshot.max_point(), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ); let mut edit_transforms = Vec::::new(); @@ -656,7 +659,7 @@ impl WrapSnapshot { pub(crate) fn chunks<'a>( &'a self, rows: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, highlights: Highlights<'a>, ) -> WrapChunks<'a> { let output_start = WrapPoint::new(rows.start, 0); @@ -960,7 +963,10 @@ impl WrapSnapshot { pub fn text_chunks(&self, wrap_row: WrapRow) -> impl Iterator { self.chunks( wrap_row..self.max_point().row() + WrapRow(1), - false, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, Highlights::default(), ) .map(|h| h.text) @@ -1719,7 +1725,10 @@ mod tests { let actual_text = self .chunks( WrapRow(start_row)..WrapRow(end_row), - true, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, Highlights::default(), ) .map(|c| c.text) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 6550d79c9f73799d37ccf6433db38f2719636ee6..e6f597de7ff9138b226cd2474353ef8c2ce16ebb 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -132,9 +132,9 @@ use language::{ AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, BufferRow, BufferSnapshot, Capability, CharClassifier, CharKind, CharScopeContext, CodeLabel, CursorShape, DiagnosticEntryRef, DiffOptions, EditPredictionsMode, EditPreview, HighlightedText, IndentKind, - IndentSize, Language, LanguageName, LanguageRegistry, LanguageScope, LocalFile, OffsetRangeExt, - OutlineItem, Point, Selection, SelectionGoal, TextObject, TransactionId, TreeSitterOptions, - 
WordsQuery, + IndentSize, Language, LanguageAwareStyling, LanguageName, LanguageRegistry, LanguageScope, + LocalFile, OffsetRangeExt, OutlineItem, Point, Selection, SelectionGoal, TextObject, + TransactionId, TreeSitterOptions, WordsQuery, language_settings::{ self, AllLanguageSettings, LanguageSettings, LspInsertMode, RewrapBehavior, WordsCompletionMode, all_language_settings, @@ -1265,6 +1265,7 @@ pub struct Editor { >, use_autoclose: bool, use_auto_surround: bool, + use_selection_highlight: bool, auto_replace_emoji_shortcode: bool, jsx_tag_auto_close_enabled_in_any_buffer: bool, show_git_blame_gutter: bool, @@ -2468,6 +2469,7 @@ impl Editor { read_only: is_minimap, use_autoclose: true, use_auto_surround: true, + use_selection_highlight: true, auto_replace_emoji_shortcode: false, jsx_tag_auto_close_enabled_in_any_buffer: false, leader_id: None, @@ -3547,6 +3549,10 @@ impl Editor { self.use_autoclose = autoclose; } + pub fn set_use_selection_highlight(&mut self, highlight: bool) { + self.use_selection_highlight = highlight; + } + pub fn set_use_auto_surround(&mut self, auto_surround: bool) { self.use_auto_surround = auto_surround; } @@ -7699,7 +7705,7 @@ impl Editor { if matches!(self.mode, EditorMode::SingleLine) { return None; } - if !EditorSettings::get_global(cx).selection_highlight { + if !self.use_selection_highlight || !EditorSettings::get_global(cx).selection_highlight { return None; } if self.selections.count() != 1 || self.selections.line_mode() { @@ -19147,7 +19153,13 @@ impl Editor { let range = buffer.anchor_before(rename_start)..buffer.anchor_after(rename_end); let mut old_highlight_id = None; let old_name: Arc = buffer - .chunks(rename_start..rename_end, true) + .chunks( + rename_start..rename_end, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ) .map(|chunk| { if old_highlight_id.is_none() { old_highlight_id = chunk.syntax_highlight_id; @@ -25005,7 +25017,13 @@ impl Editor { selection.range() }; - let chunks = 
snapshot.chunks(range, true); + let chunks = snapshot.chunks( + range, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ); let mut lines = Vec::new(); let mut line: VecDeque = VecDeque::new(); diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 7a532dc7a75ea3583456be6611ef072cd7692bc7..512fbb8855aa11d8c540065a55eb296919012821 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -51,7 +51,10 @@ use gpui::{ pattern_slash, point, px, quad, relative, size, solid_background, transparent_black, }; use itertools::Itertools; -use language::{HighlightedText, IndentGuideSettings, language_settings::ShowWhitespaceSetting}; +use language::{ + HighlightedText, IndentGuideSettings, LanguageAwareStyling, + language_settings::ShowWhitespaceSetting, +}; use markdown::Markdown; use multi_buffer::{ Anchor, ExcerptBoundaryInfo, ExpandExcerptDirection, ExpandInfo, MultiBufferPoint, @@ -3819,7 +3822,11 @@ impl EditorElement { } else { let use_tree_sitter = !snapshot.semantic_tokens_enabled || snapshot.use_tree_sitter_for_syntax(rows.start, cx); - let chunks = snapshot.highlighted_chunks(rows.clone(), use_tree_sitter, style); + let language_aware = LanguageAwareStyling { + tree_sitter: use_tree_sitter, + diagnostics: true, + }; + let chunks = snapshot.highlighted_chunks(rows.clone(), language_aware, style); LineWithInvisibles::from_chunks( chunks, style, @@ -11999,7 +12006,11 @@ pub fn layout_line( ) -> LineWithInvisibles { let use_tree_sitter = !snapshot.semantic_tokens_enabled || snapshot.use_tree_sitter_for_syntax(row, cx); - let chunks = snapshot.highlighted_chunks(row..row + DisplayRow(1), use_tree_sitter, style); + let language_aware = LanguageAwareStyling { + tree_sitter: use_tree_sitter, + diagnostics: true, + }; + let chunks = snapshot.highlighted_chunks(row..row + DisplayRow(1), language_aware, style); LineWithInvisibles::from_chunks( chunks, style, diff --git a/crates/editor/src/semantic_tokens.rs 
b/crates/editor/src/semantic_tokens.rs index 5e78be70d5627bd4f484a3efd44b13519b31b400..d485cfa70237fed542a240f202a8dc47b07467c4 100644 --- a/crates/editor/src/semantic_tokens.rs +++ b/crates/editor/src/semantic_tokens.rs @@ -475,13 +475,17 @@ mod tests { use gpui::{ AppContext as _, Entity, Focusable as _, HighlightStyle, TestAppContext, UpdateGlobal as _, }; - use language::{Language, LanguageConfig, LanguageMatcher}; + use language::{ + Diagnostic, DiagnosticEntry, DiagnosticSet, Language, LanguageAwareStyling, LanguageConfig, + LanguageMatcher, + }; use languages::FakeLspAdapter; + use lsp::LanguageServerId; use multi_buffer::{ AnchorRangeExt, ExpandExcerptDirection, MultiBuffer, MultiBufferOffset, PathKey, }; use project::Project; - use rope::Point; + use rope::{Point, PointUtf16}; use serde_json::json; use settings::{ GlobalLspSettingsContent, LanguageSettingsContent, SemanticTokenRule, SemanticTokenRules, @@ -2088,6 +2092,130 @@ mod tests { ); } + #[gpui::test] + async fn test_diagnostics_visible_when_semantic_token_set_to_full(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + update_test_language_settings(cx, &|language_settings| { + language_settings.languages.0.insert( + "Rust".into(), + LanguageSettingsContent { + semantic_tokens: Some(SemanticTokens::Full), + ..LanguageSettingsContent::default() + }, + ); + }); + + let mut cx = EditorLspTestContext::new_rust( + lsp::ServerCapabilities { + semantic_tokens_provider: Some( + lsp::SemanticTokensServerCapabilities::SemanticTokensOptions( + lsp::SemanticTokensOptions { + legend: lsp::SemanticTokensLegend { + token_types: vec!["function".into()], + token_modifiers: Vec::new(), + }, + full: Some(lsp::SemanticTokensFullOptions::Delta { delta: None }), + ..lsp::SemanticTokensOptions::default() + }, + ), + ), + ..lsp::ServerCapabilities::default() + }, + cx, + ) + .await; + + let mut full_request = cx + .set_request_handler::( + move |_, _, _| { + async move { + Ok(Some(lsp::SemanticTokensResult::Tokens( + 
lsp::SemanticTokens { + data: vec![ + 0, // delta_line + 3, // delta_start + 4, // length + 0, // token_type + 0, // token_modifiers_bitset + ], + result_id: Some("a".into()), + }, + ))) + } + }, + ); + + cx.set_state("ˇfn main() {}"); + assert!(full_request.next().await.is_some()); + + let task = cx.update_editor(|e, _, _| e.semantic_token_state.take_update_task()); + task.await; + + cx.update_buffer(|buffer, cx| { + buffer.update_diagnostics( + LanguageServerId(0), + DiagnosticSet::new( + [DiagnosticEntry { + range: PointUtf16::new(0, 3)..PointUtf16::new(0, 7), + diagnostic: Diagnostic { + severity: lsp::DiagnosticSeverity::ERROR, + group_id: 1, + message: "unused function".into(), + ..Default::default() + }, + }], + buffer, + ), + cx, + ) + }); + + cx.run_until_parked(); + let chunks = cx.update_editor(|editor, window, cx| { + editor + .snapshot(window, cx) + .display_snapshot + .chunks( + crate::display_map::DisplayRow(0)..crate::display_map::DisplayRow(1), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: true, + }, + crate::HighlightStyles::default(), + ) + .map(|chunk| { + ( + chunk.text.to_string(), + chunk.diagnostic_severity, + chunk.highlight_style, + ) + }) + .collect::>() + }); + + assert_eq!( + extract_semantic_highlights(&cx.editor, &cx), + vec![MultiBufferOffset(3)..MultiBufferOffset(7)] + ); + + assert!( + chunks.iter().any( + |(text, severity, style): &( + String, + Option, + Option + )| { + text == "main" + && *severity == Some(lsp::DiagnosticSeverity::ERROR) + && style.is_some() + } + ), + "expected 'main' chunk to have both diagnostic and semantic styling: {:?}", + chunks + ); + } + fn extract_semantic_highlight_styles( editor: &Entity, cx: &TestAppContext, diff --git a/crates/env_var/Cargo.toml b/crates/env_var/Cargo.toml index 2cbbd08c7833d3e57a09766d42ffffe35c620a93..3c879a2f49184e19a131046320d767931e1ca8ec 100644 --- a/crates/env_var/Cargo.toml +++ b/crates/env_var/Cargo.toml @@ -12,4 +12,4 @@ workspace = true path = 
"src/env_var.rs" [dependencies] -gpui.workspace = true +gpui_shared_string.workspace = true diff --git a/crates/env_var/src/env_var.rs b/crates/env_var/src/env_var.rs index 79f671e0147ebfaad4ab76a123cc477dc7e55cb7..cb436e95e0e734e4b7d8d271199246e1558a074d 100644 --- a/crates/env_var/src/env_var.rs +++ b/crates/env_var/src/env_var.rs @@ -1,4 +1,4 @@ -use gpui::SharedString; +use gpui_shared_string::SharedString; #[derive(Clone)] pub struct EnvVar { diff --git a/crates/file_finder/Cargo.toml b/crates/file_finder/Cargo.toml index 5eb36f0f5150263629b407dbe07dc73b6eff31cf..67ebab62295e8db90a12f99cbc05e9b9e56c2c6b 100644 --- a/crates/file_finder/Cargo.toml +++ b/crates/file_finder/Cargo.toml @@ -21,6 +21,7 @@ editor.workspace = true file_icons.workspace = true futures.workspace = true fuzzy.workspace = true +fuzzy_nucleo.workspace = true gpui.workspace = true menu.workspace = true open_path_prompt.workspace = true diff --git a/crates/file_finder/src/file_finder.rs b/crates/file_finder/src/file_finder.rs index 4302669ddc11c94f7df128534217d00c27ef083a..a4d9ea042dea898b9dd9db7d40354cf960d210d5 100644 --- a/crates/file_finder/src/file_finder.rs +++ b/crates/file_finder/src/file_finder.rs @@ -9,7 +9,8 @@ use client::ChannelId; use collections::HashMap; use editor::Editor; use file_icons::FileIcons; -use fuzzy::{CharBag, PathMatch, PathMatchCandidate, StringMatch, StringMatchCandidate}; +use fuzzy::{StringMatch, StringMatchCandidate}; +use fuzzy_nucleo::{PathMatch, PathMatchCandidate}; use gpui::{ Action, AnyElement, App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, KeyContext, Modifiers, ModifiersChangedEvent, ParentElement, Render, Styled, Task, WeakEntity, @@ -663,15 +664,6 @@ impl Matches { // For file-vs-file matches, use the existing detailed comparison. 
if let (Some(a_panel), Some(b_panel)) = (a.panel_match(), b.panel_match()) { - let a_in_filename = Self::is_filename_match(a_panel); - let b_in_filename = Self::is_filename_match(b_panel); - - match (a_in_filename, b_in_filename) { - (true, false) => return cmp::Ordering::Greater, - (false, true) => return cmp::Ordering::Less, - _ => {} - } - return a_panel.cmp(b_panel); } @@ -691,32 +683,6 @@ impl Matches { Match::CreateNew(_) => 0.0, } } - - /// Determines if the match occurred within the filename rather than in the path - fn is_filename_match(panel_match: &ProjectPanelOrdMatch) -> bool { - if panel_match.0.positions.is_empty() { - return false; - } - - if let Some(filename) = panel_match.0.path.file_name() { - let path_str = panel_match.0.path.as_unix_str(); - - if let Some(filename_pos) = path_str.rfind(filename) - && panel_match.0.positions[0] >= filename_pos - { - let mut prev_position = panel_match.0.positions[0]; - for p in &panel_match.0.positions[1..] { - if *p != prev_position + 1 { - return false; - } - prev_position = *p; - } - return true; - } - } - - false - } } fn matching_history_items<'a>( @@ -731,25 +697,16 @@ fn matching_history_items<'a>( let history_items_by_worktrees = history_items .into_iter() .chain(currently_opened) - .filter_map(|found_path| { + .map(|found_path| { let candidate = PathMatchCandidate { is_dir: false, // You can't open directories as project items path: &found_path.project.path, // Only match history items names, otherwise their paths may match too many queries, producing false positives. // E.g. `foo` would match both `something/foo/bar.rs` and `something/foo/foo.rs` and if the former is a history item, // it would be shown first always, despite the latter being a better match. - char_bag: CharBag::from_iter( - found_path - .project - .path - .file_name()? 
- .to_string() - .to_lowercase() - .chars(), - ), }; candidates_paths.insert(&found_path.project, found_path); - Some((found_path.project.worktree_id, candidate)) + (found_path.project.worktree_id, candidate) }) .fold( HashMap::default(), @@ -767,8 +724,9 @@ fn matching_history_items<'a>( let worktree_root_name = worktree_name_by_id .as_ref() .and_then(|w| w.get(&worktree).cloned()); + matching_history_paths.extend( - fuzzy::match_fixed_path_set( + fuzzy_nucleo::match_fixed_path_set( candidates, worktree.to_usize(), worktree_root_name, @@ -778,6 +736,18 @@ fn matching_history_items<'a>( path_style, ) .into_iter() + // filter matches where at least one matched position is in filename portion, to prevent directory matches, nucleo scores them higher as history items are matched against their full path + .filter(|path_match| { + if let Some(filename) = path_match.path.file_name() { + let filename_start = path_match.path.as_unix_str().len() - filename.len(); + path_match + .positions + .iter() + .any(|&pos| pos >= filename_start) + } else { + true + } + }) .filter_map(|path_match| { candidates_paths .remove_entry(&ProjectPath { @@ -940,7 +910,7 @@ impl FileFinderDelegate { self.cancel_flag = Arc::new(AtomicBool::new(false)); let cancel_flag = self.cancel_flag.clone(); cx.spawn_in(window, async move |picker, cx| { - let matches = fuzzy::match_path_sets( + let matches = fuzzy_nucleo::match_path_sets( candidate_sets.as_slice(), query.path_query(), &relative_to, @@ -1452,7 +1422,6 @@ impl PickerDelegate for FileFinderDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { - let raw_query = raw_query.replace(' ', ""); let raw_query = raw_query.trim(); let raw_query = match &raw_query.get(0..2) { diff --git a/crates/file_finder/src/file_finder_tests.rs b/crates/file_finder/src/file_finder_tests.rs index cd9cdeee1ff266717d380aeaecf7cbeb66ec8309..7a17202a5e4ba96b001ea46ed310518d02baf1ff 100644 --- a/crates/file_finder/src/file_finder_tests.rs +++ 
b/crates/file_finder/src/file_finder_tests.rs @@ -4161,3 +4161,233 @@ async fn test_clear_navigation_history(cx: &mut TestAppContext) { "Should have no history items after clearing" ); } + +#[gpui::test] +async fn test_order_independent_search(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "internal": { + "auth": { + "login.rs": "", + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + // forward order + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("auth internal"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert_eq!(matches.len(), 1); + assert_eq!(matches[0].path.as_unix_str(), "internal/auth/login.rs"); + }); + + // reverse order should give same result + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("internal auth"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert_eq!(matches.len(), 1); + assert_eq!(matches[0].path.as_unix_str(), "internal/auth/login.rs"); + }); +} + +#[gpui::test] +async fn test_filename_preferred_over_directory_match(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "crates": { + "settings_ui": { + "src": { + "pages": { + "audio_test_window.rs": "", + "audio_input_output_setup.rs": "", + } + } + }, + "audio": { + "src": { + "audio_settings.rs": "", + } + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + picker + .update_in(cx, |picker, window, cx| { + picker + 
.delegate + .spawn_search(test_path_position("settings audio"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert!(!matches.is_empty(),); + assert_eq!( + matches[0].path.as_unix_str(), + "crates/audio/src/audio_settings.rs" + ); + }); +} + +#[gpui::test] +async fn test_start_of_word_preferred_over_scattered_match(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "crates": { + "livekit_client": { + "src": { + "livekit_client": { + "playback.rs": "", + } + } + }, + "vim": { + "test_data": { + "test_record_replay_interleaved.json": "", + } + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("live pla"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert!(!matches.is_empty(),); + assert_eq!( + matches[0].path.as_unix_str(), + "crates/livekit_client/src/livekit_client/playback.rs", + ); + }); +} + +#[gpui::test] +async fn test_exact_filename_stem_preferred(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "assets": { + "icons": { + "file_icons": { + "nix.svg": "", + } + } + }, + "crates": { + "zed": { + "resources": { + "app-icon-nightly@2x.png": "", + "app-icon-preview@2x.png": "", + } + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("nix icon"), window, cx) + }) + .await; + picker.update(cx, |picker, 
_| { + let matches = collect_search_matches(picker).search_matches_only(); + assert!(!matches.is_empty(),); + assert_eq!( + matches[0].path.as_unix_str(), + "assets/icons/file_icons/nix.svg", + ); + }); +} + +#[gpui::test] +async fn test_exact_filename_with_directory_token(cx: &mut TestAppContext) { + let app_state = init_test(cx); + app_state + .fs + .as_fake() + .insert_tree( + "/src", + json!({ + "crates": { + "agent_servers": { + "src": { + "acp.rs": "", + "agent_server.rs": "", + "custom.rs": "", + } + } + } + }), + ) + .await; + let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await; + let (picker, _, cx) = build_find_picker(project, cx); + + picker + .update_in(cx, |picker, window, cx| { + picker + .delegate + .spawn_search(test_path_position("acp server"), window, cx) + }) + .await; + picker.update(cx, |picker, _| { + let matches = collect_search_matches(picker).search_matches_only(); + assert!(!matches.is_empty(),); + assert_eq!( + matches[0].path.as_unix_str(), + "crates/agent_servers/src/acp.rs", + ); + }); +} diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 751796fb83164b78dc5d6789f0ae7870eff16ce1..7b89a0751f17ef8c2bba837882f2a31c7d5451e5 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -6,9 +6,10 @@ use git::{ Oid, RunHook, blame::Blame, repository::{ - AskPassDelegate, Branch, CommitDataReader, CommitDetails, CommitOptions, FetchOptions, - GRAPH_CHUNK_SIZE, GitRepository, GitRepositoryCheckpoint, InitialGraphCommitData, LogOrder, - LogSource, PushOptions, Remote, RepoPath, ResetMode, SearchCommitArgs, Worktree, + AskPassDelegate, Branch, CommitDataReader, CommitDetails, CommitOptions, + CreateWorktreeTarget, FetchOptions, GRAPH_CHUNK_SIZE, GitRepository, + GitRepositoryCheckpoint, InitialGraphCommitData, LogOrder, LogSource, PushOptions, Remote, + RepoPath, ResetMode, SearchCommitArgs, Worktree, }, stash::GitStash, status::{ @@ -60,6 +61,7 @@ pub struct 
FakeGitRepositoryState { pub remotes: HashMap, pub simulated_index_write_error_message: Option, pub simulated_create_worktree_error: Option, + pub simulated_graph_error: Option, pub refs: HashMap, pub graph_commits: Vec>, pub stash_entries: GitStash, @@ -77,6 +79,7 @@ impl FakeGitRepositoryState { branches: Default::default(), simulated_index_write_error_message: Default::default(), simulated_create_worktree_error: Default::default(), + simulated_graph_error: None, refs: HashMap::from_iter([("HEAD".into(), "abc".into())]), merge_base_contents: Default::default(), oids: Default::default(), @@ -540,9 +543,8 @@ impl GitRepository for FakeGitRepository { fn create_worktree( &self, - branch_name: Option, + target: CreateWorktreeTarget, path: PathBuf, - from_commit: Option, ) -> BoxFuture<'_, Result<()>> { let fs = self.fs.clone(); let executor = self.executor.clone(); @@ -550,30 +552,82 @@ impl GitRepository for FakeGitRepository { let common_dir_path = self.common_dir_path.clone(); async move { executor.simulate_random_delay().await; - // Check for simulated error and duplicate branch before any side effects. - fs.with_git_state(&dot_git_path, false, |state| { - if let Some(message) = &state.simulated_create_worktree_error { - anyhow::bail!("{message}"); - } - if let Some(ref name) = branch_name { - if state.branches.contains(name) { - bail!("a branch named '{}' already exists", name); + + let branch_name = target.branch_name().map(ToOwned::to_owned); + let create_branch_ref = matches!(target, CreateWorktreeTarget::NewBranch { .. }); + + // Check for simulated error and validate branch state before any side effects. 
+ fs.with_git_state(&dot_git_path, false, { + let branch_name = branch_name.clone(); + move |state| { + if let Some(message) = &state.simulated_create_worktree_error { + anyhow::bail!("{message}"); + } + + match (create_branch_ref, branch_name.as_ref()) { + (true, Some(branch_name)) => { + if state.branches.contains(branch_name) { + bail!("a branch named '{}' already exists", branch_name); + } + } + (false, Some(branch_name)) => { + if !state.branches.contains(branch_name) { + bail!("no branch named '{}' exists", branch_name); + } + } + (false, None) => {} + (true, None) => bail!("branch name is required to create a branch"), } + + Ok(()) } - Ok(()) })??; + let (branch_name, sha, create_branch_ref) = match target { + CreateWorktreeTarget::ExistingBranch { branch_name } => { + let ref_name = format!("refs/heads/{branch_name}"); + let sha = fs.with_git_state(&dot_git_path, false, { + move |state| { + Ok::<_, anyhow::Error>( + state + .refs + .get(&ref_name) + .cloned() + .unwrap_or_else(|| "fake-sha".to_string()), + ) + } + })??; + (Some(branch_name), sha, false) + } + CreateWorktreeTarget::NewBranch { + branch_name, + base_sha: start_point, + } => ( + Some(branch_name), + start_point.unwrap_or_else(|| "fake-sha".to_string()), + true, + ), + CreateWorktreeTarget::Detached { + base_sha: start_point, + } => ( + None, + start_point.unwrap_or_else(|| "fake-sha".to_string()), + false, + ), + }; + // Create the worktree checkout directory. fs.create_dir(&path).await?; // Create .git/worktrees// directory with HEAD, commondir, gitdir. 
- let worktree_entry_name = branch_name - .as_deref() - .unwrap_or_else(|| path.file_name().unwrap().to_str().unwrap()); + let worktree_entry_name = branch_name.as_deref().unwrap_or_else(|| { + path.file_name() + .and_then(|name| name.to_str()) + .unwrap_or("detached") + }); let worktrees_entry_dir = common_dir_path.join("worktrees").join(worktree_entry_name); fs.create_dir(&worktrees_entry_dir).await?; - let sha = from_commit.unwrap_or_else(|| "fake-sha".to_string()); let head_content = if let Some(ref branch_name) = branch_name { let ref_name = format!("refs/heads/{branch_name}"); format!("ref: {ref_name}") @@ -604,15 +658,22 @@ impl GitRepository for FakeGitRepository { false, )?; - // Update git state: add ref and branch. - fs.with_git_state(&dot_git_path, true, move |state| { - if let Some(branch_name) = branch_name { - let ref_name = format!("refs/heads/{branch_name}"); - state.refs.insert(ref_name, sha); - state.branches.insert(branch_name); - } - Ok::<(), anyhow::Error>(()) - })??; + // Update git state for newly created branches. 
+ if create_branch_ref { + fs.with_git_state(&dot_git_path, true, { + let branch_name = branch_name.clone(); + let sha = sha.clone(); + move |state| { + if let Some(branch_name) = branch_name { + let ref_name = format!("refs/heads/{branch_name}"); + state.refs.insert(ref_name, sha); + state.branches.insert(branch_name); + } + Ok::<(), anyhow::Error>(()) + } + })??; + } + Ok(()) } .boxed() @@ -1268,8 +1329,17 @@ impl GitRepository for FakeGitRepository { let fs = self.fs.clone(); let dot_git_path = self.dot_git_path.clone(); async move { - let graph_commits = - fs.with_git_state(&dot_git_path, false, |state| state.graph_commits.clone())?; + let (graph_commits, simulated_error) = + fs.with_git_state(&dot_git_path, false, |state| { + ( + state.graph_commits.clone(), + state.simulated_graph_error.clone(), + ) + })?; + + if let Some(error) = simulated_error { + anyhow::bail!("{}", error); + } for chunk in graph_commits.chunks(GRAPH_CHUNK_SIZE) { request_tx.send(chunk.to_vec()).await.ok(); diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index a26abb81255003e4059f9bcc8a68aa3c6212a73a..52cae537b6f00837b50123af0cae7c093699dedf 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -2168,6 +2168,13 @@ impl FakeFs { .unwrap(); } + pub fn set_graph_error(&self, dot_git: &Path, error: Option) { + self.with_git_state(dot_git, true, |state| { + state.simulated_graph_error = error; + }) + .unwrap(); + } + /// Put the given git repository into a state with the given status, /// by mutating the head, index, and unmerged state. 
pub fn set_status_for_repo(&self, dot_git: &Path, statuses: &[(&str, FileStatus)]) { diff --git a/crates/fs/tests/integration/fake_git_repo.rs b/crates/fs/tests/integration/fake_git_repo.rs index f4192a22bb42f88f8769ef59f817b2bf2a288fb9..3be81ad7301e6fc4ee6f4529ce8bb587de3b4565 100644 --- a/crates/fs/tests/integration/fake_git_repo.rs +++ b/crates/fs/tests/integration/fake_git_repo.rs @@ -24,9 +24,11 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { // Create a worktree let worktree_1_dir = worktrees_dir.join("feature-branch"); repo.create_worktree( - Some("feature-branch".to_string()), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "feature-branch".to_string(), + base_sha: Some("abc123".to_string()), + }, worktree_1_dir.clone(), - Some("abc123".to_string()), ) .await .unwrap(); @@ -48,9 +50,11 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { // Create a second worktree (without explicit commit) let worktree_2_dir = worktrees_dir.join("bugfix-branch"); repo.create_worktree( - Some("bugfix-branch".to_string()), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "bugfix-branch".to_string(), + base_sha: None, + }, worktree_2_dir.clone(), - None, ) .await .unwrap(); diff --git a/crates/fuzzy_nucleo/Cargo.toml b/crates/fuzzy_nucleo/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..59e8b642524777f449f79edba85093eef069ebff --- /dev/null +++ b/crates/fuzzy_nucleo/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "fuzzy_nucleo" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/fuzzy_nucleo.rs" +doctest = false + +[dependencies] +nucleo.workspace = true +gpui.workspace = true +util.workspace = true + +[dev-dependencies] +util = {workspace = true, features = ["test-support"]} diff --git a/crates/fuzzy_nucleo/LICENSE-GPL b/crates/fuzzy_nucleo/LICENSE-GPL new file mode 120000 
index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/fuzzy_nucleo/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/fuzzy_nucleo/src/fuzzy_nucleo.rs b/crates/fuzzy_nucleo/src/fuzzy_nucleo.rs new file mode 100644 index 0000000000000000000000000000000000000000..ddaa5c3489cf55d41d31440f037214b1dce0358c --- /dev/null +++ b/crates/fuzzy_nucleo/src/fuzzy_nucleo.rs @@ -0,0 +1,5 @@ +mod matcher; +mod paths; +pub use paths::{ + PathMatch, PathMatchCandidate, PathMatchCandidateSet, match_fixed_path_set, match_path_sets, +}; diff --git a/crates/fuzzy_nucleo/src/matcher.rs b/crates/fuzzy_nucleo/src/matcher.rs new file mode 100644 index 0000000000000000000000000000000000000000..b31da011106341420095bcffbfd012f40014ad6c --- /dev/null +++ b/crates/fuzzy_nucleo/src/matcher.rs @@ -0,0 +1,39 @@ +use std::sync::Mutex; + +static MATCHERS: Mutex> = Mutex::new(Vec::new()); + +pub const LENGTH_PENALTY: f64 = 0.01; + +pub fn get_matcher(config: nucleo::Config) -> nucleo::Matcher { + let mut matchers = MATCHERS.lock().unwrap(); + match matchers.pop() { + Some(mut matcher) => { + matcher.config = config; + matcher + } + None => nucleo::Matcher::new(config), + } +} + +pub fn return_matcher(matcher: nucleo::Matcher) { + MATCHERS.lock().unwrap().push(matcher); +} + +pub fn get_matchers(n: usize, config: nucleo::Config) -> Vec { + let mut matchers: Vec<_> = { + let mut pool = MATCHERS.lock().unwrap(); + let available = pool.len().min(n); + pool.drain(..available) + .map(|mut matcher| { + matcher.config = config.clone(); + matcher + }) + .collect() + }; + matchers.resize_with(n, || nucleo::Matcher::new(config.clone())); + matchers +} + +pub fn return_matchers(mut matchers: Vec) { + MATCHERS.lock().unwrap().append(&mut matchers); +} diff --git a/crates/fuzzy_nucleo/src/paths.rs b/crates/fuzzy_nucleo/src/paths.rs new file mode 100644 index 
0000000000000000000000000000000000000000..ac766622c9d12c6e2a119fbcd7dd7fe7a3b5a90d --- /dev/null +++ b/crates/fuzzy_nucleo/src/paths.rs @@ -0,0 +1,352 @@ +use gpui::BackgroundExecutor; +use std::{ + cmp::Ordering, + sync::{ + Arc, + atomic::{self, AtomicBool}, + }, +}; +use util::{paths::PathStyle, rel_path::RelPath}; + +use nucleo::Utf32Str; +use nucleo::pattern::{Atom, AtomKind, CaseMatching, Normalization}; + +use crate::matcher::{self, LENGTH_PENALTY}; + +#[derive(Clone, Debug)] +pub struct PathMatchCandidate<'a> { + pub is_dir: bool, + pub path: &'a RelPath, +} + +#[derive(Clone, Debug)] +pub struct PathMatch { + pub score: f64, + pub positions: Vec, + pub worktree_id: usize, + pub path: Arc, + pub path_prefix: Arc, + pub is_dir: bool, + /// Number of steps removed from a shared parent with the relative path + /// Used to order closer paths first in the search list + pub distance_to_relative_ancestor: usize, +} + +pub trait PathMatchCandidateSet<'a>: Send + Sync { + type Candidates: Iterator>; + fn id(&self) -> usize; + fn len(&self) -> usize; + fn is_empty(&self) -> bool { + self.len() == 0 + } + fn root_is_file(&self) -> bool; + fn prefix(&self) -> Arc; + fn candidates(&'a self, start: usize) -> Self::Candidates; + fn path_style(&self) -> PathStyle; +} + +impl PartialEq for PathMatch { + fn eq(&self, other: &Self) -> bool { + self.cmp(other).is_eq() + } +} + +impl Eq for PathMatch {} + +impl PartialOrd for PathMatch { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PathMatch { + fn cmp(&self, other: &Self) -> Ordering { + self.score + .partial_cmp(&other.score) + .unwrap_or(Ordering::Equal) + .then_with(|| self.worktree_id.cmp(&other.worktree_id)) + .then_with(|| { + other + .distance_to_relative_ancestor + .cmp(&self.distance_to_relative_ancestor) + }) + .then_with(|| self.path.cmp(&other.path)) + } +} + +fn make_atoms(query: &str, smart_case: bool) -> Vec { + let case = if smart_case { + CaseMatching::Smart 
+ } else { + CaseMatching::Ignore + }; + query + .split_whitespace() + .map(|word| Atom::new(word, case, Normalization::Smart, AtomKind::Fuzzy, false)) + .collect() +} + +pub(crate) fn distance_between_paths(path: &RelPath, relative_to: &RelPath) -> usize { + let mut path_components = path.components(); + let mut relative_components = relative_to.components(); + + while path_components + .next() + .zip(relative_components.next()) + .map(|(path_component, relative_component)| path_component == relative_component) + .unwrap_or_default() + {} + path_components.count() + relative_components.count() + 1 +} + +fn get_filename_match_bonus( + candidate_buf: &str, + query_atoms: &[Atom], + matcher: &mut nucleo::Matcher, +) -> f64 { + let filename = match std::path::Path::new(candidate_buf).file_name() { + Some(f) => f.to_str().unwrap_or(""), + None => return 0.0, + }; + if filename.is_empty() || query_atoms.is_empty() { + return 0.0; + } + let mut buf = Vec::new(); + let haystack = Utf32Str::new(filename, &mut buf); + let mut total_score = 0u32; + for atom in query_atoms { + if let Some(score) = atom.score(haystack, matcher) { + total_score = total_score.saturating_add(score as u32); + } + } + total_score as f64 / filename.len().max(1) as f64 +} +struct Cancelled; + +fn path_match_helper<'a>( + matcher: &mut nucleo::Matcher, + atoms: &[Atom], + candidates: impl Iterator>, + results: &mut Vec, + worktree_id: usize, + path_prefix: &Arc, + root_is_file: bool, + relative_to: &Option>, + path_style: PathStyle, + cancel_flag: &AtomicBool, +) -> Result<(), Cancelled> { + let mut candidate_buf = if !path_prefix.is_empty() && !root_is_file { + let mut s = path_prefix.display(path_style).to_string(); + s.push_str(path_style.primary_separator()); + s + } else { + String::new() + }; + let path_prefix_len = candidate_buf.len(); + let mut buf = Vec::new(); + let mut matched_chars: Vec = Vec::new(); + let mut atom_matched_chars = Vec::new(); + for candidate in candidates { + buf.clear(); 
+ matched_chars.clear(); + if cancel_flag.load(atomic::Ordering::Relaxed) { + return Err(Cancelled); + } + + candidate_buf.truncate(path_prefix_len); + if root_is_file { + candidate_buf.push_str(path_prefix.as_unix_str()); + } else { + candidate_buf.push_str(candidate.path.as_unix_str()); + } + + let haystack = Utf32Str::new(&candidate_buf, &mut buf); + + let mut total_score: u32 = 0; + let mut all_matched = true; + + for atom in atoms { + atom_matched_chars.clear(); + if let Some(score) = atom.indices(haystack, matcher, &mut atom_matched_chars) { + total_score = total_score.saturating_add(score as u32); + matched_chars.extend_from_slice(&atom_matched_chars); + } else { + all_matched = false; + break; + } + } + + if all_matched && !atoms.is_empty() { + matched_chars.sort_unstable(); + matched_chars.dedup(); + + let length_penalty = candidate_buf.len() as f64 * LENGTH_PENALTY; + let filename_bonus = get_filename_match_bonus(&candidate_buf, atoms, matcher); + let adjusted_score = total_score as f64 + filename_bonus - length_penalty; + let mut positions: Vec = candidate_buf + .char_indices() + .enumerate() + .filter_map(|(char_offset, (byte_offset, _))| { + matched_chars + .contains(&(char_offset as u32)) + .then_some(byte_offset) + }) + .collect(); + positions.sort_unstable(); + + results.push(PathMatch { + score: adjusted_score, + positions, + worktree_id, + path: if root_is_file { + Arc::clone(path_prefix) + } else { + candidate.path.into() + }, + path_prefix: if root_is_file { + RelPath::empty().into() + } else { + Arc::clone(path_prefix) + }, + is_dir: candidate.is_dir, + distance_to_relative_ancestor: relative_to + .as_ref() + .map_or(usize::MAX, |relative_to| { + distance_between_paths(candidate.path, relative_to.as_ref()) + }), + }); + } + } + Ok(()) +} + +pub fn match_fixed_path_set( + candidates: Vec, + worktree_id: usize, + worktree_root_name: Option>, + query: &str, + smart_case: bool, + max_results: usize, + path_style: PathStyle, +) -> Vec { + let mut 
config = nucleo::Config::DEFAULT; + config.set_match_paths(); + let mut matcher = matcher::get_matcher(config); + + let atoms = make_atoms(query, smart_case); + + let root_is_file = worktree_root_name.is_some() && candidates.iter().all(|c| c.path.is_empty()); + + let path_prefix = worktree_root_name.unwrap_or_else(|| RelPath::empty().into()); + + let mut results = Vec::new(); + + path_match_helper( + &mut matcher, + &atoms, + candidates.into_iter(), + &mut results, + worktree_id, + &path_prefix, + root_is_file, + &None, + path_style, + &AtomicBool::new(false), + ) + .ok(); + util::truncate_to_bottom_n_sorted_by(&mut results, max_results, &|a, b| b.cmp(a)); + matcher::return_matcher(matcher); + results +} + +pub async fn match_path_sets<'a, Set: PathMatchCandidateSet<'a>>( + candidate_sets: &'a [Set], + query: &str, + relative_to: &Option>, + smart_case: bool, + max_results: usize, + cancel_flag: &AtomicBool, + executor: BackgroundExecutor, +) -> Vec { + let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum(); + if path_count == 0 { + return Vec::new(); + } + + let path_style = candidate_sets[0].path_style(); + + let query = if path_style.is_windows() { + query.replace('\\', "/") + } else { + query.to_owned() + }; + + let atoms = make_atoms(&query, smart_case); + + let num_cpus = executor.num_cpus().min(path_count); + let segment_size = path_count.div_ceil(num_cpus); + let mut segment_results = (0..num_cpus) + .map(|_| Vec::with_capacity(max_results)) + .collect::>(); + let mut config = nucleo::Config::DEFAULT; + config.set_match_paths(); + let mut matchers = matcher::get_matchers(num_cpus, config); + executor + .scoped(|scope| { + for (segment_idx, (results, matcher)) in segment_results + .iter_mut() + .zip(matchers.iter_mut()) + .enumerate() + { + let atoms = atoms.clone(); + let relative_to = relative_to.clone(); + scope.spawn(async move { + let segment_start = segment_idx * segment_size; + let segment_end = segment_start + segment_size; + + let mut 
tree_start = 0; + for candidate_set in candidate_sets { + let tree_end = tree_start + candidate_set.len(); + + if tree_start < segment_end && segment_start < tree_end { + let start = tree_start.max(segment_start) - tree_start; + let end = tree_end.min(segment_end) - tree_start; + let candidates = candidate_set.candidates(start).take(end - start); + + if path_match_helper( + matcher, + &atoms, + candidates, + results, + candidate_set.id(), + &candidate_set.prefix(), + candidate_set.root_is_file(), + &relative_to, + path_style, + cancel_flag, + ) + .is_err() + { + break; + } + } + + if tree_end >= segment_end { + break; + } + tree_start = tree_end; + } + }); + } + }) + .await; + + matcher::return_matchers(matchers); + if cancel_flag.load(atomic::Ordering::Acquire) { + return Vec::new(); + } + + let mut results = segment_results.concat(); + util::truncate_to_bottom_n_sorted_by(&mut results, max_results, &|a, b| b.cmp(a)); + results +} diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index c42d2e28cf041e40404c1b8276ddcf5d10ca5f01..d7049c0a50cb94c049556e395e818dbbddfb89bf 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -241,20 +241,57 @@ pub struct Worktree { pub is_main: bool, } +/// Describes how a new worktree should choose or create its checked-out HEAD. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub enum CreateWorktreeTarget { + /// Check out an existing local branch in the new worktree. + ExistingBranch { + /// The existing local branch to check out. + branch_name: String, + }, + /// Create a new local branch for the new worktree. + NewBranch { + /// The new local branch to create and check out. + branch_name: String, + /// The commit or ref to create the branch from. Uses `HEAD` when `None`. + base_sha: Option, + }, + /// Check out a commit or ref in detached HEAD state. + Detached { + /// The commit or ref to check out. Uses `HEAD` when `None`. 
+ base_sha: Option, + }, +} + +impl CreateWorktreeTarget { + pub fn branch_name(&self) -> Option<&str> { + match self { + Self::ExistingBranch { branch_name } | Self::NewBranch { branch_name, .. } => { + Some(branch_name) + } + Self::Detached { .. } => None, + } + } +} + impl Worktree { + /// Returns the branch name if the worktree is attached to a branch. + pub fn branch_name(&self) -> Option<&str> { + self.ref_name.as_ref().map(|ref_name| { + ref_name + .strip_prefix("refs/heads/") + .or_else(|| ref_name.strip_prefix("refs/remotes/")) + .unwrap_or(ref_name) + }) + } + /// Returns a display name for the worktree, suitable for use in the UI. /// /// If the worktree is attached to a branch, returns the branch name. /// Otherwise, returns the short SHA of the worktree's HEAD commit. pub fn display_name(&self) -> &str { - match self.ref_name { - Some(ref ref_name) => ref_name - .strip_prefix("refs/heads/") - .or_else(|| ref_name.strip_prefix("refs/remotes/")) - .unwrap_or(ref_name), - // Detached HEAD — show the short SHA as a fallback. 
- None => &self.sha[..self.sha.len().min(SHORT_SHA_LENGTH)], - } + self.branch_name() + .unwrap_or(&self.sha[..self.sha.len().min(SHORT_SHA_LENGTH)]) } } @@ -716,9 +753,8 @@ pub trait GitRepository: Send + Sync { fn create_worktree( &self, - branch_name: Option, + target: CreateWorktreeTarget, path: PathBuf, - from_commit: Option, ) -> BoxFuture<'_, Result<()>>; fn remove_worktree(&self, path: PathBuf, force: bool) -> BoxFuture<'_, Result<()>>; @@ -1667,24 +1703,36 @@ impl GitRepository for RealGitRepository { fn create_worktree( &self, - branch_name: Option, + target: CreateWorktreeTarget, path: PathBuf, - from_commit: Option, ) -> BoxFuture<'_, Result<()>> { let git_binary = self.git_binary(); let mut args = vec![OsString::from("worktree"), OsString::from("add")]; - if let Some(branch_name) = &branch_name { - args.push(OsString::from("-b")); - args.push(OsString::from(branch_name.as_str())); - } else { - args.push(OsString::from("--detach")); - } - args.push(OsString::from("--")); - args.push(OsString::from(path.as_os_str())); - if let Some(from_commit) = from_commit { - args.push(OsString::from(from_commit)); - } else { - args.push(OsString::from("HEAD")); + + match &target { + CreateWorktreeTarget::ExistingBranch { branch_name } => { + args.push(OsString::from("--")); + args.push(OsString::from(path.as_os_str())); + args.push(OsString::from(branch_name)); + } + CreateWorktreeTarget::NewBranch { + branch_name, + base_sha: start_point, + } => { + args.push(OsString::from("-b")); + args.push(OsString::from(branch_name)); + args.push(OsString::from("--")); + args.push(OsString::from(path.as_os_str())); + args.push(OsString::from(start_point.as_deref().unwrap_or("HEAD"))); + } + CreateWorktreeTarget::Detached { + base_sha: start_point, + } => { + args.push(OsString::from("--detach")); + args.push(OsString::from("--")); + args.push(OsString::from(path.as_os_str())); + args.push(OsString::from(start_point.as_deref().unwrap_or("HEAD"))); + } } self.executor @@ -2736,10 
+2784,11 @@ impl GitRepository for RealGitRepository { log_source.get_arg()?, ]); command.stdout(Stdio::piped()); - command.stderr(Stdio::null()); + command.stderr(Stdio::piped()); let mut child = command.spawn()?; let stdout = child.stdout.take().context("failed to get stdout")?; + let stderr = child.stderr.take().context("failed to get stderr")?; let mut reader = BufReader::new(stdout); let mut line_buffer = String::new(); @@ -2774,7 +2823,20 @@ impl GitRepository for RealGitRepository { } } - child.status().await?; + let status = child.status().await?; + if !status.success() { + let mut stderr_output = String::new(); + BufReader::new(stderr) + .read_to_string(&mut stderr_output) + .await + .log_err(); + + if stderr_output.is_empty() { + anyhow::bail!("git log command failed with {}", status); + } else { + anyhow::bail!("git log command failed with {}: {}", status, stderr_output); + } + } Ok(()) } .boxed() @@ -4054,9 +4116,11 @@ mod tests { // Create a new worktree repo.create_worktree( - Some("test-branch".to_string()), + CreateWorktreeTarget::NewBranch { + branch_name: "test-branch".to_string(), + base_sha: Some("HEAD".to_string()), + }, worktree_path.clone(), - Some("HEAD".to_string()), ) .await .unwrap(); @@ -4113,9 +4177,11 @@ mod tests { // Create a worktree let worktree_path = worktrees_dir.join("worktree-to-remove"); repo.create_worktree( - Some("to-remove".to_string()), + CreateWorktreeTarget::NewBranch { + branch_name: "to-remove".to_string(), + base_sha: Some("HEAD".to_string()), + }, worktree_path.clone(), - Some("HEAD".to_string()), ) .await .unwrap(); @@ -4137,9 +4203,11 @@ mod tests { // Create a worktree let worktree_path = worktrees_dir.join("dirty-wt"); repo.create_worktree( - Some("dirty-wt".to_string()), + CreateWorktreeTarget::NewBranch { + branch_name: "dirty-wt".to_string(), + base_sha: Some("HEAD".to_string()), + }, worktree_path.clone(), - Some("HEAD".to_string()), ) .await .unwrap(); @@ -4207,9 +4275,11 @@ mod tests { // Create a 
worktree let old_path = worktrees_dir.join("old-worktree-name"); repo.create_worktree( - Some("old-name".to_string()), + CreateWorktreeTarget::NewBranch { + branch_name: "old-name".to_string(), + base_sha: Some("HEAD".to_string()), + }, old_path.clone(), - Some("HEAD".to_string()), ) .await .unwrap(); diff --git a/crates/git_graph/src/git_graph.rs b/crates/git_graph/src/git_graph.rs index aa5f6bc6e1293cfd057baa0c5e9f77819da71086..7594a206f14705bf47a673dee9abefad5a3446de 100644 --- a/crates/git_graph/src/git_graph.rs +++ b/crates/git_graph/src/git_graph.rs @@ -2536,11 +2536,19 @@ impl Render for GitGraph { } }; + let error = self.get_repository(cx).and_then(|repo| { + repo.read(cx) + .get_graph_data(self.log_source.clone(), self.log_order) + .and_then(|data| data.error.clone()) + }); + let content = if commit_count == 0 { - let message = if is_loading { - "Loading" + let message = if let Some(error) = &error { + format!("Error loading: {}", error) + } else if is_loading { + "Loading".to_string() } else { - "No commits found" + "No commits found".to_string() }; let label = Label::new(message) .color(Color::Muted) @@ -2552,7 +2560,7 @@ impl Render for GitGraph { .items_center() .justify_center() .child(label) - .when(is_loading, |this| { + .when(is_loading && error.is_none(), |this| { this.child(self.render_loading_spinner(cx)) }) } else { @@ -3757,6 +3765,61 @@ mod tests { ); } + #[gpui::test] + async fn test_initial_graph_data_propagates_error(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + Path::new("/project"), + json!({ + ".git": {}, + "file.txt": "content", + }), + ) + .await; + + fs.set_graph_error( + Path::new("/project/.git"), + Some("fatal: bad default revision 'HEAD'".to_string()), + ); + + let project = Project::test(fs.clone(), [Path::new("/project")], cx).await; + + let repository = project.read_with(cx, |project, cx| { + project + .active_repository(cx) + .expect("should have a repository") + }); 
+ + repository.update(cx, |repo, cx| { + repo.graph_data( + crate::LogSource::default(), + crate::LogOrder::default(), + 0..usize::MAX, + cx, + ); + }); + + cx.run_until_parked(); + + let error = repository.read_with(cx, |repo, _| { + repo.get_graph_data(crate::LogSource::default(), crate::LogOrder::default()) + .and_then(|data| data.error.clone()) + }); + + assert!( + error.is_some(), + "graph data should contain an error after initial_graph_data fails" + ); + let error_message = error.unwrap(); + assert!( + error_message.contains("bad default revision"), + "error should contain the git error message, got: {}", + error_message + ); + } + #[gpui::test] async fn test_graph_data_repopulated_from_cache_after_repo_switch(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 83c8119a077ac1c024dbb3b3df948f762b072ec1..2bf4a1991f7a302ed73fe098e8914fedd0f9eb2a 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -1906,7 +1906,7 @@ mod tests { assert_eq!( remotes, vec![Remote { - name: SharedString::from("my_new_remote".to_string()) + name: SharedString::from("my_new_remote") }] ); } diff --git a/crates/git_ui/src/worktree_picker.rs b/crates/git_ui/src/worktree_picker.rs index 1b4497be1f4ea96bd4f0431c97bb538eda9faa57..bd1d694fa30bb914569fbb5e6e3c67de3e3d86a0 100644 --- a/crates/git_ui/src/worktree_picker.rs +++ b/crates/git_ui/src/worktree_picker.rs @@ -318,8 +318,13 @@ impl WorktreeListDelegate { .clone(); let new_worktree_path = repo.path_for_new_linked_worktree(&branch, &worktree_directory_setting)?; - let receiver = - repo.create_worktree(branch.clone(), new_worktree_path.clone(), commit); + let receiver = repo.create_worktree( + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: branch.clone(), + base_sha: commit, + }, + new_worktree_path.clone(), + ); anyhow::Ok((receiver, new_worktree_path)) })?; receiver.await??; diff --git 
a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml index 81e05e4836529e9b73b58b72683a7e72a4d5c984..d91d28851997723835ba85be343a453918301c71 100644 --- a/crates/google_ai/Cargo.toml +++ b/crates/google_ai/Cargo.toml @@ -18,8 +18,10 @@ schemars = ["dep:schemars"] anyhow.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true +log.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true -settings.workspace = true strum.workspace = true +tiktoken-rs.workspace = true diff --git a/crates/google_ai/src/completion.rs b/crates/google_ai/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..3a15fdaa0187e52cb82dc8c71b5b861eb797f1a8 --- /dev/null +++ b/crates/google_ai/src/completion.rs @@ -0,0 +1,492 @@ +use anyhow::Result; +use futures::{Stream, StreamExt}; +use language_model_core::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, + StopReason, TokenUsage, +}; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{self, AtomicU64}; + +use crate::{ + Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration, + GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode, + InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig, + UsageMetadata, +}; + +pub fn into_google( + mut request: LanguageModelRequest, + model_id: String, + mode: GoogleModelMode, +) -> crate::GenerateContentRequest { + fn map_content(content: Vec) -> Vec { + content + .into_iter() + .flat_map(|content| match content { + MessageContent::Text(text) => { + if !text.is_empty() { + vec![Part::TextPart(TextPart { text })] + } else { + vec![] + } + } + MessageContent::Thinking { + text: _, + signature: Some(signature), + } => { + if 
!signature.is_empty() { + vec![Part::ThoughtPart(crate::ThoughtPart { + thought: true, + thought_signature: signature, + })] + } else { + vec![] + } + } + MessageContent::Thinking { .. } => { + vec![] + } + MessageContent::RedactedThinking(_) => vec![], + MessageContent::Image(image) => { + vec![Part::InlineDataPart(InlineDataPart { + inline_data: GenerativeContentBlob { + mime_type: "image/png".to_string(), + data: image.source.to_string(), + }, + })] + } + MessageContent::ToolUse(tool_use) => { + // Normalize empty string signatures to None + let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty()); + + vec![Part::FunctionCallPart(crate::FunctionCallPart { + function_call: crate::FunctionCall { + name: tool_use.name.to_string(), + args: tool_use.input, + }, + thought_signature, + })] + } + MessageContent::ToolResult(tool_result) => { + match tool_result.content { + language_model_core::LanguageModelToolResultContent::Text(text) => { + vec![Part::FunctionResponsePart(crate::FunctionResponsePart { + function_response: crate::FunctionResponse { + name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object + response: serde_json::json!({ + "output": text + }), + }, + })] + } + language_model_core::LanguageModelToolResultContent::Image(image) => { + vec![ + Part::FunctionResponsePart(crate::FunctionResponsePart { + function_response: crate::FunctionResponse { + name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object + response: serde_json::json!({ + "output": "Tool responded with an image" + }), + }, + }), + Part::InlineDataPart(InlineDataPart { + inline_data: GenerativeContentBlob { + mime_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }), + ] + } + } + } + }) + .collect() + } + + let system_instructions = if request + .messages + .first() + .is_some_and(|msg| matches!(msg.role, Role::System)) + { + let message = request.messages.remove(0); + Some(SystemInstruction { + parts: 
map_content(message.content), + }) + } else { + None + }; + + crate::GenerateContentRequest { + model: ModelName { model_id }, + system_instruction: system_instructions, + contents: request + .messages + .into_iter() + .filter_map(|message| { + let parts = map_content(message.content); + if parts.is_empty() { + None + } else { + Some(Content { + parts, + role: match message.role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Model, + Role::System => crate::Role::User, // Google AI doesn't have a system role + }, + }) + } + }) + .collect(), + generation_config: Some(GenerationConfig { + candidate_count: Some(1), + stop_sequences: Some(request.stop), + max_output_tokens: None, + temperature: request.temperature.map(|t| t as f64).or(Some(1.0)), + thinking_config: match (request.thinking_allowed, mode) { + (true, GoogleModelMode::Thinking { budget_tokens }) => { + budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget }) + } + _ => None, + }, + top_p: None, + top_k: None, + }), + safety_settings: None, + tools: (!request.tools.is_empty()).then(|| { + vec![crate::Tool { + function_declarations: request + .tools + .into_iter() + .map(|tool| FunctionDeclaration { + name: tool.name, + description: tool.description, + parameters: tool.input_schema, + }) + .collect(), + }] + }), + tool_config: request.tool_choice.map(|choice| ToolConfig { + function_calling_config: FunctionCallingConfig { + mode: match choice { + LanguageModelToolChoice::Auto => FunctionCallingMode::Auto, + LanguageModelToolChoice::Any => FunctionCallingMode::Any, + LanguageModelToolChoice::None => FunctionCallingMode::None, + }, + allowed_function_names: None, + }, + }), + } +} + +pub struct GoogleEventMapper { + usage: UsageMetadata, + stop_reason: StopReason, +} + +impl GoogleEventMapper { + pub fn new() -> Self { + Self { + usage: UsageMetadata::default(), + stop_reason: StopReason::EndTurn, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl 
Stream> + { + events + .map(Some) + .chain(futures::stream::once(async { None })) + .flat_map(move |event| { + futures::stream::iter(match event { + Some(Ok(event)) => self.map_event(event), + Some(Err(error)) => { + vec![Err(LanguageModelCompletionError::from(error))] + } + None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))], + }) + }) + } + + pub fn map_event( + &mut self, + event: GenerateContentResponse, + ) -> Vec> { + static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); + + let mut events: Vec<_> = Vec::new(); + let mut wants_to_use_tool = false; + if let Some(usage_metadata) = event.usage_metadata { + update_usage(&mut self.usage, &usage_metadata); + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))) + } + + if let Some(prompt_feedback) = event.prompt_feedback + && let Some(block_reason) = prompt_feedback.block_reason.as_deref() + { + self.stop_reason = match block_reason { + "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => { + StopReason::Refusal + } + _ => { + log::error!("Unexpected Google block_reason: {block_reason}"); + StopReason::Refusal + } + }; + events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))); + + return events; + } + + if let Some(candidates) = event.candidates { + for candidate in candidates { + if let Some(finish_reason) = candidate.finish_reason.as_deref() { + self.stop_reason = match finish_reason { + "STOP" => StopReason::EndTurn, + "MAX_TOKENS" => StopReason::MaxTokens, + _ => { + log::error!("Unexpected google finish_reason: {finish_reason}"); + StopReason::EndTurn + } + }; + } + candidate + .content + .parts + .into_iter() + .for_each(|part| match part { + Part::TextPart(text_part) => { + events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text))) + } + Part::InlineDataPart(_) => {} + Part::FunctionCallPart(function_call_part) => { + wants_to_use_tool = true; + let name: Arc = 
function_call_part.function_call.name.into(); + let next_tool_id = + TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst); + let id: LanguageModelToolUseId = + format!("{}-{}", name, next_tool_id).into(); + + // Normalize empty string signatures to None + let thought_signature = function_call_part + .thought_signature + .filter(|s| !s.is_empty()); + + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id, + name, + is_input_complete: true, + raw_input: function_call_part.function_call.args.to_string(), + input: function_call_part.function_call.args, + thought_signature, + }, + ))); + } + Part::FunctionResponsePart(_) => {} + Part::ThoughtPart(part) => { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries? + signature: Some(part.thought_signature), + })); + } + }); + } + } + + // Even when Gemini wants to use a Tool, the API + // responds with `finish_reason: STOP` + if wants_to_use_tool { + self.stop_reason = StopReason::ToolUse; + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + events + } +} + +/// Count tokens for a Google AI model using tiktoken. This is synchronous; +/// callers should spawn it on a background thread if needed. +pub fn count_google_tokens(request: LanguageModelRequest) -> Result { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. 
+ tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) +} + +fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) { + if let Some(prompt_token_count) = new.prompt_token_count { + usage.prompt_token_count = Some(prompt_token_count); + } + if let Some(cached_content_token_count) = new.cached_content_token_count { + usage.cached_content_token_count = Some(cached_content_token_count); + } + if let Some(candidates_token_count) = new.candidates_token_count { + usage.candidates_token_count = Some(candidates_token_count); + } + if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count { + usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count); + } + if let Some(thoughts_token_count) = new.thoughts_token_count { + usage.thoughts_token_count = Some(thoughts_token_count); + } + if let Some(total_token_count) = new.total_token_count { + usage.total_token_count = Some(total_token_count); + } +} + +fn convert_usage(usage: &UsageMetadata) -> TokenUsage { + let prompt_tokens = usage.prompt_token_count.unwrap_or(0); + let cached_tokens = usage.cached_content_token_count.unwrap_or(0); + let input_tokens = prompt_tokens - cached_tokens; + let output_tokens = usage.candidates_token_count.unwrap_or(0); + + TokenUsage { + input_tokens, + output_tokens, + cache_read_input_tokens: cached_tokens, + cache_creation_input_tokens: 0, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse, + Part, Role as GoogleRole, + }; + use serde_json::json; + + #[test] + fn test_function_call_with_signature_creates_tool_use_with_signature() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: 
"test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("test_signature_123".to_string()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + assert_eq!(events.len(), 2); + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert_eq!(tool_use.name.as_ref(), "test_function"); + assert_eq!( + tool_use.thought_signature.as_deref(), + Some("test_signature_123") + ); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_function_call_without_signature_has_none() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: None, + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + assert_eq!(events.len(), 2); + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert!(tool_use.thought_signature.is_none()); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_empty_string_signature_normalized_to_none() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + 
thought_signature: Some("".to_string()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert!(tool_use.thought_signature.is_none()); + } else { + panic!("Expected ToolUse event"); + } + } +} diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 7659be8ab44da35efd16389c4abd0bf99d8cf3a4..5770c9a020b04bf280908993911b67ec3a5b980f 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -3,8 +3,9 @@ use std::mem; use anyhow::{Result, anyhow, bail}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +pub use language_model_core::ModelMode as GoogleModelMode; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -pub use settings::ModelMode as GoogleModelMode; +pub mod completion; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 915f0fc03e2cc5beaf40c810654724295c41cde8..efb4817ef0e0c037bc08d0c5a8ad702705cb996d 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -56,6 +56,7 @@ etagere = "0.2" futures.workspace = true futures-concurrency.workspace = true gpui_macros.workspace = true +gpui_shared_string.workspace = true http_client.workspace = true image.workspace = true inventory.workspace = true diff --git a/crates/gpui/src/gpui.rs b/crates/gpui/src/gpui.rs index 6d7d801cd42c3639d7892295a660319d21b05dfa..dbb57f46efc37678c07dfd4f02bb3faebc60c9a3 100644 --- a/crates/gpui/src/gpui.rs +++ b/crates/gpui/src/gpui.rs @@ -39,7 +39,6 @@ pub mod profiler; #[expect(missing_docs)] pub mod queue; mod scene; -mod 
shared_string; mod shared_uri; mod style; mod styled; @@ -92,6 +91,7 @@ pub use global::*; pub use gpui_macros::{ AppContext, IntoElement, Render, VisualContext, property_test, register_action, test, }; +pub use gpui_shared_string::*; pub use gpui_util::arc_cow::ArcCow; pub use http_client; pub use input::*; @@ -106,7 +106,6 @@ pub use profiler::*; pub use queue::{PriorityQueueReceiver, PriorityQueueSender}; pub use refineable::*; pub use scene::*; -pub use shared_string::*; pub use shared_uri::*; use std::{any::Any, future::Future}; pub use style::*; diff --git a/crates/gpui/src/svg_renderer.rs b/crates/gpui/src/svg_renderer.rs index 8653ab9b162031772ab29367b60ff988e33cd823..a766a25cc1ef66039f5b2a1d0aeaab51ace89578 100644 --- a/crates/gpui/src/svg_renderer.rs +++ b/crates/gpui/src/svg_renderer.rs @@ -105,18 +105,36 @@ pub enum SvgSize { impl SvgRenderer { /// Creates a new SVG renderer with the provided asset source. pub fn new(asset_source: Arc) -> Self { - static FONT_DB: LazyLock> = LazyLock::new(|| { + static SYSTEM_FONT_DB: LazyLock> = LazyLock::new(|| { let mut db = usvg::fontdb::Database::new(); db.load_system_fonts(); Arc::new(db) }); + + let fontdb = { + let mut db = (**SYSTEM_FONT_DB).clone(); + load_bundled_fonts(&*asset_source, &mut db); + fix_generic_font_families(&mut db); + Arc::new(db) + }; + let default_font_resolver = usvg::FontResolver::default_font_selector(); let font_resolver = Box::new( move |font: &usvg::Font, db: &mut Arc| { if db.is_empty() { - *db = FONT_DB.clone(); + *db = fontdb.clone(); + } + if let Some(id) = default_font_resolver(font, db) { + return Some(id); } - default_font_resolver(font, db) + // fontdb doesn't recognize CSS system font keywords like "system-ui" + // or "ui-sans-serif", so fall back to sans-serif before any face. 
+ let sans_query = usvg::fontdb::Query { + families: &[usvg::fontdb::Family::SansSerif], + ..Default::default() + }; + db.query(&sans_query) + .or_else(|| db.faces().next().map(|f| f.id)) }, ); let default_fallback_selection = usvg::FontResolver::default_fallback_selector(); @@ -226,14 +244,69 @@ impl SvgRenderer { } } +fn load_bundled_fonts(asset_source: &dyn AssetSource, db: &mut usvg::fontdb::Database) { + let font_paths = [ + "fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf", + "fonts/lilex/Lilex-Regular.ttf", + ]; + for path in font_paths { + match asset_source.load(path) { + Ok(Some(data)) => db.load_font_data(data.into_owned()), + Ok(None) => log::warn!("Bundled font not found: {path}"), + Err(error) => log::warn!("Failed to load bundled font {path}: {error}"), + } + } +} + +// fontdb defaults generic families to Microsoft fonts ("Arial", "Times New Roman") +// which aren't installed on most Linux systems. fontconfig normally overrides these, +// but when it fails the defaults remain and all generic family queries return None. +fn fix_generic_font_families(db: &mut usvg::fontdb::Database) { + use usvg::fontdb::{Family, Query}; + + let families_and_fallbacks: &[(Family<'_>, &str)] = &[ + (Family::SansSerif, "IBM Plex Sans"), + // No serif font bundled; use sans-serif as best available fallback. 
+ (Family::Serif, "IBM Plex Sans"), + (Family::Monospace, "Lilex"), + (Family::Cursive, "IBM Plex Sans"), + (Family::Fantasy, "IBM Plex Sans"), + ]; + + for (family, fallback_name) in families_and_fallbacks { + let query = Query { + families: &[*family], + ..Default::default() + }; + if db.query(&query).is_none() { + match family { + Family::SansSerif => db.set_sans_serif_family(*fallback_name), + Family::Serif => db.set_serif_family(*fallback_name), + Family::Monospace => db.set_monospace_family(*fallback_name), + Family::Cursive => db.set_cursive_family(*fallback_name), + Family::Fantasy => db.set_fantasy_family(*fallback_name), + _ => {} + } + } + } +} + #[cfg(test)] mod tests { use super::*; + use usvg::fontdb::{Database, Family, Query}; const IBM_PLEX_REGULAR: &[u8] = include_bytes!("../../../assets/fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf"); const LILEX_REGULAR: &[u8] = include_bytes!("../../../assets/fonts/lilex/Lilex-Regular.ttf"); + fn db_with_bundled_fonts() -> Database { + let mut db = Database::new(); + db.load_font_data(IBM_PLEX_REGULAR.to_vec()); + db.load_font_data(LILEX_REGULAR.to_vec()); + db + } + #[test] fn test_is_emoji_presentation() { let cases = [ @@ -266,11 +339,33 @@ mod tests { } #[test] - fn test_select_emoji_font_skips_family_without_glyph() { - let mut db = usvg::fontdb::Database::new(); + fn fix_generic_font_families_sets_all_families() { + let mut db = db_with_bundled_fonts(); + fix_generic_font_families(&mut db); + + let families = [ + Family::SansSerif, + Family::Serif, + Family::Monospace, + Family::Cursive, + Family::Fantasy, + ]; - db.load_font_data(IBM_PLEX_REGULAR.to_vec()); - db.load_font_data(LILEX_REGULAR.to_vec()); + for family in families { + let query = Query { + families: &[family], + ..Default::default() + }; + assert!( + db.query(&query).is_some(), + "Expected generic family {family:?} to resolve after fix_generic_font_families" + ); + } + } + + #[test] + fn test_select_emoji_font_skips_family_without_glyph() { + 
let mut db = db_with_bundled_fonts(); let ibm_plex_sans = db .query(&usvg::fontdb::Query { @@ -294,4 +389,22 @@ mod tests { assert!(!font_has_char(&db, ibm_plex_sans, '│')); assert!(font_has_char(&db, selected, '│')); } + + #[test] + fn fix_generic_font_families_monospace_resolves_to_lilex() { + let mut db = db_with_bundled_fonts(); + fix_generic_font_families(&mut db); + + let query = Query { + families: &[Family::Monospace], + ..Default::default() + }; + let id = db.query(&query).expect("Monospace should resolve"); + let face = db.face(id).expect("Face should exist"); + assert!( + face.families.iter().any(|(name, _)| name.contains("Lilex")), + "Monospace should map to Lilex, got {:?}", + face.families + ); + } } diff --git a/crates/gpui/src/text_system/line.rs b/crates/gpui/src/text_system/line.rs index 7b5714188ff97d0169806ac5da9f039f9be2c16a..611c979bc29f488fa18386c7b319a7310b6ce1c6 100644 --- a/crates/gpui/src/text_system/line.rs +++ b/crates/gpui/src/text_system/line.rs @@ -882,7 +882,7 @@ mod tests { ], len: 6, }), - text: SharedString::new("abcdef".to_string()), + text: "abcdef".into(), decoration_runs: SmallVec::new(), }; diff --git a/crates/gpui_shared_string/Cargo.toml b/crates/gpui_shared_string/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..4f7735b4f88253de7cd62d30445153d2a6284751 --- /dev/null +++ b/crates/gpui_shared_string/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "gpui_shared_string" +version = "0.1.0" +publish.workspace = true +edition.workspace = true + +[lib] +path = "gpui_shared_string.rs" + +[dependencies] +derive_more.workspace = true +gpui_util.workspace = true +schemars.workspace = true +serde.workspace = true + +[lints] +workspace = true diff --git a/crates/gpui_shared_string/LICENSE-APACHE b/crates/gpui_shared_string/LICENSE-APACHE new file mode 120000 index 0000000000000000000000000000000000000000..1cd601d0a3affae83854be02a0afdec3b7a9ec4d --- /dev/null +++ b/crates/gpui_shared_string/LICENSE-APACHE 
@@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/gpui/src/shared_string.rs b/crates/gpui_shared_string/gpui_shared_string.rs similarity index 100% rename from crates/gpui/src/shared_string.rs rename to crates/gpui_shared_string/gpui_shared_string.rs diff --git a/crates/http_client/src/github_download.rs b/crates/http_client/src/github_download.rs index 47ae2c2b36b1ab37b56ab70735c2ce018bc5e275..5d11f3e11b7ea951c6bc9c143c266d8802f88cc3 100644 --- a/crates/http_client/src/github_download.rs +++ b/crates/http_client/src/github_download.rs @@ -207,11 +207,7 @@ async fn extract_tar_gz( from: impl AsyncRead + Unpin, ) -> Result<(), anyhow::Error> { let decompressed_bytes = GzipDecoder::new(BufReader::new(from)); - let archive = async_tar::Archive::new(decompressed_bytes); - archive - .unpack(&destination_path) - .await - .with_context(|| format!("extracting {url} to {destination_path:?}"))?; + unpack_tar_archive(destination_path, url, decompressed_bytes).await?; Ok(()) } @@ -221,7 +217,21 @@ async fn extract_tar_bz2( from: impl AsyncRead + Unpin, ) -> Result<(), anyhow::Error> { let decompressed_bytes = BzDecoder::new(BufReader::new(from)); - let archive = async_tar::Archive::new(decompressed_bytes); + unpack_tar_archive(destination_path, url, decompressed_bytes).await?; + Ok(()) +} + +async fn unpack_tar_archive( + destination_path: &Path, + url: &str, + archive_bytes: impl AsyncRead + Unpin, +) -> Result<(), anyhow::Error> { + // We don't need to set the modified time. It's irrelevant to downloaded + // archive verification, and some filesystems return errors when asked to + // apply it after extraction. 
+ let archive = async_tar::ArchiveBuilder::new(archive_bytes) + .set_preserve_mtime(false) + .build(); archive .unpack(&destination_path) .await diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index e29b7d3593025556771d62dc0124786672c540de..bdc3890432414e0a78f69a226bb9174510453331 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -134,7 +134,7 @@ pub enum IconName { Flame, Folder, FolderOpen, - FolderPlus, + FolderOpenAdd, FolderSearch, Font, FontSize, @@ -184,6 +184,7 @@ pub enum IconName { NewThread, Notepad, OpenFolder, + OpenNewWindow, Option, PageDown, PageUp, diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index a467cd789555d39a32ad4e1d7b21da7b14df9c25..1e54134efcab4f0074a73b241f8e0d04cfbcbcdd 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -3733,16 +3733,24 @@ impl BufferSnapshot { /// returned in chunks where each chunk has a single syntax highlighting style and /// diagnostic status. #[ztracing::instrument(skip_all)] - pub fn chunks(&self, range: Range, language_aware: bool) -> BufferChunks<'_> { + pub fn chunks( + &self, + range: Range, + language_aware: LanguageAwareStyling, + ) -> BufferChunks<'_> { let range = range.start.to_offset(self)..range.end.to_offset(self); let mut syntax = None; - if language_aware { + if language_aware.tree_sitter { syntax = Some(self.get_highlights(range.clone())); } - // We want to look at diagnostic spans only when iterating over language-annotated chunks. 
- let diagnostics = language_aware; - BufferChunks::new(self.text.as_rope(), range, syntax, diagnostics, Some(self)) + BufferChunks::new( + self.text.as_rope(), + range, + syntax, + language_aware.diagnostics, + Some(self), + ) } pub fn highlighted_text_for_range( @@ -4477,7 +4485,13 @@ impl BufferSnapshot { let mut text = String::new(); let mut highlight_ranges = Vec::new(); let mut name_ranges = Vec::new(); - let mut chunks = self.chunks(source_range_for_text.clone(), true); + let mut chunks = self.chunks( + source_range_for_text.clone(), + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ); let mut last_buffer_range_end = 0; for (buffer_range, is_name) in buffer_ranges { let space_added = !text.is_empty() && buffer_range.start > last_buffer_range_end; @@ -5402,7 +5416,13 @@ impl BufferSnapshot { let mut words = BTreeMap::default(); let mut current_word_start_ix = None; let mut chunk_ix = query.range.start; - for chunk in self.chunks(query.range, false) { + for chunk in self.chunks( + query.range, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ) { for (i, c) in chunk.text.char_indices() { let ix = chunk_ix + i; if classifier.is_word(c) { @@ -5441,6 +5461,15 @@ impl BufferSnapshot { } } +/// A configuration to use when producing styled text chunks. +#[derive(Clone, Copy)] +pub struct LanguageAwareStyling { + /// Whether to highlight text chunks using tree-sitter. + pub tree_sitter: bool, + /// Whether to highlight text chunks based on the diagnostics data. + pub diagnostics: bool, +} + pub struct WordsQuery<'a> { /// Only returns words with all chars from the fuzzy string in them. 
pub fuzzy_contents: Option<&'a str>, diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs index 9308ee6f0a0ee207b30be9e6fafa73ba9452d94c..9f4562bf547f389c5ecc5ca29470ac4e49da0e04 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -4102,7 +4102,13 @@ fn test_random_chunk_bitmaps(cx: &mut App, mut rng: StdRng) { let snapshot = buffer.read(cx).snapshot(); // Get all chunks and verify their bitmaps - let chunks = snapshot.chunks(0..snapshot.len(), false); + let chunks = snapshot.chunks( + 0..snapshot.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ); for chunk in chunks { let chunk_text = chunk.text; diff --git a/crates/language_core/Cargo.toml b/crates/language_core/Cargo.toml index 4861632b4663c860706525c65cd8607133b3ec71..cd1143f61d3af1d3b72bb5bd3a23e53b27aa9aba 100644 --- a/crates/language_core/Cargo.toml +++ b/crates/language_core/Cargo.toml @@ -10,7 +10,7 @@ path = "src/language_core.rs" [dependencies] anyhow.workspace = true collections.workspace = true -gpui.workspace = true +gpui_shared_string.workspace = true log.workspace = true lsp.workspace = true parking_lot.workspace = true @@ -22,8 +22,6 @@ toml.workspace = true tree-sitter.workspace = true util.workspace = true -[dev-dependencies] -gpui = { workspace = true, features = ["test-support"] } [features] test-support = [] diff --git a/crates/language_core/src/diagnostic.rs b/crates/language_core/src/diagnostic.rs index 9a468a14b863a94ef23e00c3e15edd9fa2d8b09a..00abcb61d1b1290dd96c69b31296eebfd3900348 100644 --- a/crates/language_core/src/diagnostic.rs +++ b/crates/language_core/src/diagnostic.rs @@ -1,4 +1,4 @@ -use gpui::SharedString; +use gpui_shared_string::SharedString; use lsp::{DiagnosticSeverity, NumberOrString}; use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/crates/language_core/src/grammar.rs b/crates/language_core/src/grammar.rs index 
54e9a3f1b3309718436b206874802779925a9d04..44f73ac6dea235a522393b5b0bd10729999b45bf 100644 --- a/crates/language_core/src/grammar.rs +++ b/crates/language_core/src/grammar.rs @@ -4,7 +4,7 @@ use crate::{ }; use anyhow::{Context as _, Result}; use collections::HashMap; -use gpui::SharedString; +use gpui_shared_string::SharedString; use lsp::LanguageServerName; use parking_lot::Mutex; use std::sync::atomic::{AtomicUsize, Ordering::SeqCst}; diff --git a/crates/language_core/src/language_config.rs b/crates/language_core/src/language_config.rs index f412af418b7948b40e3bdac5a3a649d12d008e8a..89474dbad9171d37cfb1b7f55f70a137eeb535d5 100644 --- a/crates/language_core/src/language_config.rs +++ b/crates/language_core/src/language_config.rs @@ -1,6 +1,6 @@ use crate::LanguageName; use collections::{HashMap, HashSet, IndexSet}; -use gpui::SharedString; +use gpui_shared_string::SharedString; use lsp::LanguageServerName; use regex::Regex; use schemars::{JsonSchema, SchemaGenerator, json_schema}; diff --git a/crates/language_core/src/language_name.rs b/crates/language_core/src/language_name.rs index 764b54a48a566ad98212de3e22bce6aca9a1e393..14528435d9103b4faad3e055ea69bbdaf372113c 100644 --- a/crates/language_core/src/language_name.rs +++ b/crates/language_core/src/language_name.rs @@ -1,4 +1,4 @@ -use gpui::SharedString; +use gpui_shared_string::SharedString; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{ diff --git a/crates/language_core/src/lsp_adapter.rs b/crates/language_core/src/lsp_adapter.rs index 03012f71143428b49ea9d75a03b0118b50e413b4..8f449637b306c2a33a76cb5b356d0280903f4187 100644 --- a/crates/language_core/src/lsp_adapter.rs +++ b/crates/language_core/src/lsp_adapter.rs @@ -1,4 +1,4 @@ -use gpui::SharedString; +use gpui_shared_string::SharedString; use serde::{Deserialize, Serialize}; /// Converts a value into an LSP position. 
diff --git a/crates/language_core/src/manifest.rs b/crates/language_core/src/manifest.rs index 1e762ff6e7c364eef02eea16ce9e1ecaaa198554..864f89e6cee65b0dff7c4462c99940c32ba0830f 100644 --- a/crates/language_core/src/manifest.rs +++ b/crates/language_core/src/manifest.rs @@ -1,6 +1,6 @@ use std::borrow::Borrow; -use gpui::SharedString; +use gpui_shared_string::SharedString; #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ManifestName(SharedString); diff --git a/crates/language_core/src/toolchain.rs b/crates/language_core/src/toolchain.rs index a021cb86bd36295a065b16281209c5fc3b63cffc..78bd69917fbc0f66af454ba262c1eb3b7c357290 100644 --- a/crates/language_core/src/toolchain.rs +++ b/crates/language_core/src/toolchain.rs @@ -6,7 +6,7 @@ use std::{path::Path, sync::Arc}; -use gpui::SharedString; +use gpui_shared_string::SharedString; use util::rel_path::RelPath; use crate::{LanguageName, ManifestName}; diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 4712d86dff6c44f9cdd8576a08349ccfa7d0ecca..d679588138ccec0f8d9fd830d26d13f2f65d44a3 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -16,13 +16,9 @@ doctest = false test-support = [] [dependencies] -anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true credentials_provider.workspace = true base64.workspace = true -cloud_api_client.workspace = true -cloud_api_types.workspace = true -cloud_llm_client.workspace = true collections.workspace = true env_var.workspace = true futures.workspace = true @@ -30,14 +26,11 @@ gpui.workspace = true http_client.workspace = true icons.workspace = true image.workspace = true +language_model_core.workspace = true log.workspace = true -open_ai = { workspace = true, features = ["schemars"] } -open_router.workspace = true parking_lot.workspace = true -schemars.workspace = true serde.workspace = true serde_json.workspace = true -smol.workspace = true thiserror.workspace 
= true util.workspace = true diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index 50037f31facbac446de7ecf38536d1e4a24c7867..cee65c21e575e7c96579c271805386527a29d4da 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -5,11 +5,10 @@ use crate::{ LanguageModelRequest, LanguageModelToolChoice, }; use anyhow::anyhow; -use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; +use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream, stream::StreamExt}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use http_client::Result; use parking_lot::Mutex; -use smol::stream::StreamExt; use std::sync::{ Arc, atomic::{AtomicBool, Ordering::SeqCst}, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 3f309b7b1d4152c54324efaaf0ad3bdb7035eea4..60e8228fec52ffee763e19541f042ce47246dad2 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -1,380 +1,31 @@ mod api_key; mod model; -mod provider; -mod rate_limiter; mod registry; mod request; -mod role; -pub mod tool_schema; #[cfg(any(test, feature = "test-support"))] pub mod fake_provider; -use anyhow::{Result, anyhow}; -use cloud_llm_client::CompletionRequestStatus; +pub use language_model_core::*; + +use anyhow::Result; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window}; -use http_client::{StatusCode, http}; +use gpui::{AnyView, App, AsyncApp, Task, Window}; use icons::IconName; use parking_lot::Mutex; -use serde::{Deserialize, Serialize}; -use std::ops::{Add, Sub}; -use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; -use std::{fmt, io}; -use thiserror::Error; -use util::serde::is_default; pub use crate::api_key::{ApiKey, ApiKeyState}; pub use 
crate::model::*; -pub use crate::rate_limiter::*; pub use crate::registry::*; -pub use crate::request::*; -pub use crate::role::*; -pub use crate::tool_schema::LanguageModelToolSchemaFormat; +pub use crate::request::{LanguageModelImageExt, gpui_size_to_image_size, image_size_to_gpui}; pub use env_var::{EnvVar, env_var}; -pub use provider::*; pub fn init(cx: &mut App) { registry::init(cx); } -#[derive(Clone, Debug)] -pub struct LanguageModelCacheConfiguration { - pub max_cache_anchors: usize, - pub should_speculate: bool, - pub min_total_token: u64, -} - -/// A completion event from a language model. -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] -pub enum LanguageModelCompletionEvent { - Queued { - position: usize, - }, - Started, - Stop(StopReason), - Text(String), - Thinking { - text: String, - signature: Option, - }, - RedactedThinking { - data: String, - }, - ToolUse(LanguageModelToolUse), - ToolUseJsonParseError { - id: LanguageModelToolUseId, - tool_name: Arc, - raw_input: Arc, - json_parse_error: String, - }, - StartMessage { - message_id: String, - }, - ReasoningDetails(serde_json::Value), - UsageUpdate(TokenUsage), -} - -impl LanguageModelCompletionEvent { - pub fn from_completion_request_status( - status: CompletionRequestStatus, - upstream_provider: LanguageModelProviderName, - ) -> Result, LanguageModelCompletionError> { - match status { - CompletionRequestStatus::Queued { position } => { - Ok(Some(LanguageModelCompletionEvent::Queued { position })) - } - CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)), - CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None), - CompletionRequestStatus::Failed { - code, - message, - request_id: _, - retry_after, - } => Err(LanguageModelCompletionError::from_cloud_failure( - upstream_provider, - code, - message, - retry_after.map(Duration::from_secs_f64), - )), - } - } -} - -#[derive(Error, Debug)] -pub enum LanguageModelCompletionError { - 
#[error("prompt too large for context window")] - PromptTooLarge { tokens: Option }, - #[error("missing {provider} API key")] - NoApiKey { provider: LanguageModelProviderName }, - #[error("{provider}'s API rate limit exceeded")] - RateLimitExceeded { - provider: LanguageModelProviderName, - retry_after: Option, - }, - #[error("{provider}'s API servers are overloaded right now")] - ServerOverloaded { - provider: LanguageModelProviderName, - retry_after: Option, - }, - #[error("{provider}'s API server reported an internal server error: {message}")] - ApiInternalServerError { - provider: LanguageModelProviderName, - message: String, - }, - #[error("{message}")] - UpstreamProviderError { - message: String, - status: StatusCode, - retry_after: Option, - }, - #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")] - HttpResponseError { - provider: LanguageModelProviderName, - status_code: StatusCode, - message: String, - }, - - // Client errors - #[error("invalid request format to {provider}'s API: {message}")] - BadRequestFormat { - provider: LanguageModelProviderName, - message: String, - }, - #[error("authentication error with {provider}'s API: {message}")] - AuthenticationError { - provider: LanguageModelProviderName, - message: String, - }, - #[error("Permission error with {provider}'s API: {message}")] - PermissionError { - provider: LanguageModelProviderName, - message: String, - }, - #[error("language model provider API endpoint not found")] - ApiEndpointNotFound { provider: LanguageModelProviderName }, - #[error("I/O error reading response from {provider}'s API")] - ApiReadResponseError { - provider: LanguageModelProviderName, - #[source] - error: io::Error, - }, - #[error("error serializing request to {provider} API")] - SerializeRequest { - provider: LanguageModelProviderName, - #[source] - error: serde_json::Error, - }, - #[error("error building request body to {provider} API")] - BuildRequestBody { - provider: 
LanguageModelProviderName, - #[source] - error: http::Error, - }, - #[error("error sending HTTP request to {provider} API")] - HttpSend { - provider: LanguageModelProviderName, - #[source] - error: anyhow::Error, - }, - #[error("error deserializing {provider} API response")] - DeserializeResponse { - provider: LanguageModelProviderName, - #[source] - error: serde_json::Error, - }, - - #[error("stream from {provider} ended unexpectedly")] - StreamEndedUnexpectedly { provider: LanguageModelProviderName }, - - // TODO: Ideally this would be removed in favor of having a comprehensive list of errors. - #[error(transparent)] - Other(#[from] anyhow::Error), -} - -impl LanguageModelCompletionError { - fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> { - let error_json = serde_json::from_str::(message).ok()?; - let upstream_status = error_json - .get("upstream_status") - .and_then(|v| v.as_u64()) - .and_then(|status| u16::try_from(status).ok()) - .and_then(|status| StatusCode::from_u16(status).ok())?; - let inner_message = error_json - .get("message") - .and_then(|v| v.as_str()) - .unwrap_or(message) - .to_string(); - Some((upstream_status, inner_message)) - } - - pub fn from_cloud_failure( - upstream_provider: LanguageModelProviderName, - code: String, - message: String, - retry_after: Option, - ) -> Self { - if let Some(tokens) = parse_prompt_too_long(&message) { - // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR - // to be reported. This is a temporary workaround to handle this in the case where the - // token limit has been exceeded. 
- Self::PromptTooLarge { - tokens: Some(tokens), - } - } else if code == "upstream_http_error" { - if let Some((upstream_status, inner_message)) = - Self::parse_upstream_error_json(&message) - { - return Self::from_http_status( - upstream_provider, - upstream_status, - inner_message, - retry_after, - ); - } - anyhow!("completion request failed, code: {code}, message: {message}").into() - } else if let Some(status_code) = code - .strip_prefix("upstream_http_") - .and_then(|code| StatusCode::from_str(code).ok()) - { - Self::from_http_status(upstream_provider, status_code, message, retry_after) - } else if let Some(status_code) = code - .strip_prefix("http_") - .and_then(|code| StatusCode::from_str(code).ok()) - { - Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after) - } else { - anyhow!("completion request failed, code: {code}, message: {message}").into() - } - } - - pub fn from_http_status( - provider: LanguageModelProviderName, - status_code: StatusCode, - message: String, - retry_after: Option, - ) -> Self { - match status_code { - StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message }, - StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message }, - StatusCode::FORBIDDEN => Self::PermissionError { provider, message }, - StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider }, - StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge { - tokens: parse_prompt_too_long(&message), - }, - StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { - provider, - retry_after, - }, - StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message }, - StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { - provider, - retry_after, - }, - _ if status_code.as_u16() == 529 => Self::ServerOverloaded { - provider, - retry_after, - }, - _ => Self::HttpResponseError { - provider, - status_code, - message, - }, - } - } -} - -#[derive(Debug, PartialEq, Clone, Copy, Serialize, 
Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum StopReason { - EndTurn, - MaxTokens, - ToolUse, - Refusal, -} - -#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] -pub struct TokenUsage { - #[serde(default, skip_serializing_if = "is_default")] - pub input_tokens: u64, - #[serde(default, skip_serializing_if = "is_default")] - pub output_tokens: u64, - #[serde(default, skip_serializing_if = "is_default")] - pub cache_creation_input_tokens: u64, - #[serde(default, skip_serializing_if = "is_default")] - pub cache_read_input_tokens: u64, -} - -impl TokenUsage { - pub fn total_tokens(&self) -> u64 { - self.input_tokens - + self.output_tokens - + self.cache_read_input_tokens - + self.cache_creation_input_tokens - } -} - -impl Add for TokenUsage { - type Output = Self; - - fn add(self, other: Self) -> Self { - Self { - input_tokens: self.input_tokens + other.input_tokens, - output_tokens: self.output_tokens + other.output_tokens, - cache_creation_input_tokens: self.cache_creation_input_tokens - + other.cache_creation_input_tokens, - cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens, - } - } -} - -impl Sub for TokenUsage { - type Output = Self; - - fn sub(self, other: Self) -> Self { - Self { - input_tokens: self.input_tokens - other.input_tokens, - output_tokens: self.output_tokens - other.output_tokens, - cache_creation_input_tokens: self.cache_creation_input_tokens - - other.cache_creation_input_tokens, - cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] -pub struct LanguageModelToolUseId(Arc); - -impl fmt::Display for LanguageModelToolUseId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for LanguageModelToolUseId -where - T: Into>, -{ - fn from(value: T) -> Self { - Self(value.into()) - } -} - -#[derive(Debug, 
PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] -pub struct LanguageModelToolUse { - pub id: LanguageModelToolUseId, - pub name: Arc, - pub raw_input: String, - pub input: serde_json::Value, - pub is_input_complete: bool, - /// Thought signature the model sent us. Some models require that this - /// signature be preserved and sent back in conversation history for validation. - pub thought_signature: Option, -} - pub struct LanguageModelTextStream { pub message_id: Option, pub stream: BoxStream<'static, Result>, @@ -392,13 +43,6 @@ impl Default for LanguageModelTextStream { } } -#[derive(Debug, Clone)] -pub struct LanguageModelEffortLevel { - pub name: SharedString, - pub value: SharedString, - pub is_default: bool, -} - pub trait LanguageModel: Send + Sync { fn id(&self) -> LanguageModelId; fn name(&self) -> LanguageModelName; @@ -605,7 +249,7 @@ pub trait LanguageModel: Send + Sync { } impl std::fmt::Debug for dyn LanguageModel { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("") .field("id", &self.id()) .field("name", &self.name()) @@ -619,17 +263,6 @@ impl std::fmt::Debug for dyn LanguageModel { } } -/// An error that occurred when trying to authenticate the language model provider. -#[derive(Debug, Error)] -pub enum AuthenticateError { - #[error("connection refused")] - ConnectionRefused, - #[error("credentials not found")] - CredentialsNotFound, - #[error(transparent)] - Other(#[from] anyhow::Error), -} - /// Either a built-in icon name or a path to an external SVG. 
#[derive(Debug, Clone, PartialEq, Eq)] pub enum IconOrSvg { @@ -692,18 +325,6 @@ pub trait LanguageModelProviderState: 'static { } } -#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)] -pub struct LanguageModelId(pub SharedString); - -#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] -pub struct LanguageModelName(pub SharedString); - -#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] -pub struct LanguageModelProviderId(pub SharedString); - -#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] -pub struct LanguageModelProviderName(pub SharedString); - #[derive(Clone, Debug, PartialEq)] pub enum LanguageModelCostInfo { /// Cost per 1,000 input and output tokens @@ -741,245 +362,3 @@ impl LanguageModelCostInfo { } } } - -impl LanguageModelProviderId { - pub const fn new(id: &'static str) -> Self { - Self(SharedString::new_static(id)) - } -} - -impl LanguageModelProviderName { - pub const fn new(id: &'static str) -> Self { - Self(SharedString::new_static(id)) - } -} - -impl fmt::Display for LanguageModelProviderId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl fmt::Display for LanguageModelProviderName { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for LanguageModelId { - fn from(value: String) -> Self { - Self(SharedString::from(value)) - } -} - -impl From for LanguageModelName { - fn from(value: String) -> Self { - Self(SharedString::from(value)) - } -} - -impl From for LanguageModelProviderId { - fn from(value: String) -> Self { - Self(SharedString::from(value)) - } -} - -impl From for LanguageModelProviderName { - fn from(value: String) -> Self { - Self(SharedString::from(value)) - } -} - -impl From> for LanguageModelProviderId { - fn from(value: Arc) -> Self { - Self(SharedString::from(value)) - } -} - -impl From> for LanguageModelProviderName { - fn from(value: Arc) -> Self { - 
Self(SharedString::from(value)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_from_cloud_failure_with_upstream_http_error() { - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_error".to_string(), - r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ServerOverloaded { provider, .. } => { - assert_eq!(provider.0, "anthropic"); - } - _ => panic!( - "Expected ServerOverloaded error for 503 status, got: {:?}", - error - ), - } - - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_error".to_string(), - r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ApiInternalServerError { provider, message } => { - assert_eq!(provider.0, "anthropic"); - assert_eq!(message, "Internal server error"); - } - _ => panic!( - "Expected ApiInternalServerError for 500 status, got: {:?}", - error - ), - } - } - - #[test] - fn test_from_cloud_failure_with_standard_format() { - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_503".to_string(), - "Service unavailable".to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ServerOverloaded { provider, .. 
} => { - assert_eq!(provider.0, "anthropic"); - } - _ => panic!("Expected ServerOverloaded error for upstream_http_503"), - } - } - - #[test] - fn test_upstream_http_error_connection_timeout() { - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_error".to_string(), - r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ServerOverloaded { provider, .. } => { - assert_eq!(provider.0, "anthropic"); - } - _ => panic!( - "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}", - error - ), - } - - let error = LanguageModelCompletionError::from_cloud_failure( - String::from("anthropic").into(), - "upstream_http_error".to_string(), - r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(), - None, - ); - - match error { - LanguageModelCompletionError::ApiInternalServerError { provider, message } => { - assert_eq!(provider.0, "anthropic"); - assert_eq!( - message, - "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. 
reset reason: connection timeout" - ); - } - _ => panic!( - "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}", - error - ), - } - } - - #[test] - fn test_language_model_tool_use_serializes_with_signature() { - use serde_json::json; - - let tool_use = LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_tool".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: Some("test_signature".to_string()), - }; - - let serialized = serde_json::to_value(&tool_use).unwrap(); - - assert_eq!(serialized["id"], "test_id"); - assert_eq!(serialized["name"], "test_tool"); - assert_eq!(serialized["thought_signature"], "test_signature"); - } - - #[test] - fn test_language_model_tool_use_deserializes_with_missing_signature() { - use serde_json::json; - - let json = json!({ - "id": "test_id", - "name": "test_tool", - "raw_input": "{\"arg\":\"value\"}", - "input": {"arg": "value"}, - "is_input_complete": true - }); - - let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap(); - - assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id")); - assert_eq!(tool_use.name.as_ref(), "test_tool"); - assert_eq!(tool_use.thought_signature, None); - } - - #[test] - fn test_language_model_tool_use_round_trip_with_signature() { - use serde_json::json; - - let original = LanguageModelToolUse { - id: LanguageModelToolUseId::from("round_trip_id"), - name: "round_trip_tool".into(), - raw_input: json!({"key": "value"}).to_string(), - input: json!({"key": "value"}), - is_input_complete: true, - thought_signature: Some("round_trip_sig".to_string()), - }; - - let serialized = serde_json::to_value(&original).unwrap(); - let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); - - assert_eq!(deserialized.id, original.id); - assert_eq!(deserialized.name, original.name); - 
assert_eq!(deserialized.thought_signature, original.thought_signature); - } - - #[test] - fn test_language_model_tool_use_round_trip_without_signature() { - use serde_json::json; - - let original = LanguageModelToolUse { - id: LanguageModelToolUseId::from("no_sig_id"), - name: "no_sig_tool".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: None, - }; - - let serialized = serde_json::to_value(&original).unwrap(); - let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); - - assert_eq!(deserialized.id, original.id); - assert_eq!(deserialized.name, original.name); - assert_eq!(deserialized.thought_signature, None); - } -} diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index db926aab1f70a46a4e70b1b67c2c9e4c4f465c2c..8cd71928b10fb1e86f3df40ca118305c198c094f 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,10 +1,5 @@ use std::fmt; -use std::sync::Arc; -use cloud_api_client::ClientApiError; -use cloud_api_client::CloudApiClient; -use cloud_api_types::OrganizationId; -use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; #[derive(Error, Debug)] @@ -18,71 +13,3 @@ impl fmt::Display for PaymentRequiredError { ) } } - -#[derive(Clone, Default)] -pub struct LlmApiToken(Arc>>); - -impl LlmApiToken { - pub async fn acquire( - &self, - client: &CloudApiClient, - system_id: Option, - organization_id: Option, - ) -> Result { - let lock = self.0.upgradable_read().await; - if let Some(token) = lock.as_ref() { - Ok(token.to_string()) - } else { - Self::fetch( - RwLockUpgradableReadGuard::upgrade(lock).await, - client, - system_id, - organization_id, - ) - .await - } - } - - pub async fn refresh( - &self, - client: &CloudApiClient, - system_id: Option, - organization_id: Option, - ) -> Result { - 
Self::fetch(self.0.write().await, client, system_id, organization_id).await - } - - /// Clears the existing token before attempting to fetch a new one. - /// - /// Used when switching organizations so that a failed refresh doesn't - /// leave a token for the wrong organization. - pub async fn clear_and_refresh( - &self, - client: &CloudApiClient, - system_id: Option, - organization_id: Option, - ) -> Result { - let mut lock = self.0.write().await; - *lock = None; - Self::fetch(lock, client, system_id, organization_id).await - } - - async fn fetch( - mut lock: RwLockWriteGuard<'_, Option>, - client: &CloudApiClient, - system_id: Option, - organization_id: Option, - ) -> Result { - let result = client.create_llm_token(system_id, organization_id).await; - match result { - Ok(response) => { - *lock = Some(response.token.0.clone()); - Ok(response.token.0) - } - Err(err) => { - *lock = None; - Err(err) - } - } - } -} diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs deleted file mode 100644 index 707d8e2d618894e2898e253450dbfbb5e9483bba..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider.rs +++ /dev/null @@ -1,12 +0,0 @@ -pub mod anthropic; -pub mod google; -pub mod open_ai; -pub mod open_router; -pub mod x_ai; -pub mod zed; - -pub use anthropic::*; -pub use google::*; -pub use open_ai::*; -pub use x_ai::*; -pub use zed::*; diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs deleted file mode 100644 index 0878be2070fdbb9e57145684f59c962a32bb9fd2..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/anthropic.rs +++ /dev/null @@ -1,80 +0,0 @@ -use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName}; -use anthropic::AnthropicError; -pub use anthropic::parse_prompt_too_long; - -pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = - LanguageModelProviderId::new("anthropic"); 
-pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("Anthropic"); - -impl From for LanguageModelCompletionError { - fn from(error: AnthropicError) -> Self { - let provider = ANTHROPIC_PROVIDER_NAME; - match error { - AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, - AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, - AnthropicError::HttpSend(error) => Self::HttpSend { provider, error }, - AnthropicError::DeserializeResponse(error) => { - Self::DeserializeResponse { provider, error } - } - AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, - AnthropicError::HttpResponseError { - status_code, - message, - } => Self::HttpResponseError { - provider, - status_code, - message, - }, - AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { - provider, - retry_after: Some(retry_after), - }, - AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded { - provider, - retry_after, - }, - AnthropicError::ApiError(api_error) => api_error.into(), - } - } -} - -impl From for LanguageModelCompletionError { - fn from(error: anthropic::ApiError) -> Self { - use anthropic::ApiErrorCode::*; - let provider = ANTHROPIC_PROVIDER_NAME; - match error.code() { - Some(code) => match code { - InvalidRequestError => Self::BadRequestFormat { - provider, - message: error.message, - }, - AuthenticationError => Self::AuthenticationError { - provider, - message: error.message, - }, - PermissionError => Self::PermissionError { - provider, - message: error.message, - }, - NotFoundError => Self::ApiEndpointNotFound { provider }, - RequestTooLarge => Self::PromptTooLarge { - tokens: parse_prompt_too_long(&error.message), - }, - RateLimitError => Self::RateLimitExceeded { - provider, - retry_after: None, - }, - ApiError => Self::ApiInternalServerError { - provider, - message: error.message, - }, - OverloadedError => 
Self::ServerOverloaded { - provider, - retry_after: None, - }, - }, - None => Self::Other(error.into()), - } - } -} diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs deleted file mode 100644 index 1caee496b519f395dd10744b127bc29ee893849f..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/google.rs +++ /dev/null @@ -1,5 +0,0 @@ -use crate::{LanguageModelProviderId, LanguageModelProviderName}; - -pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google"); -pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("Google AI"); diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs deleted file mode 100644 index 3796eb9a3aef78628c52d92e92fabb3812249e04..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/open_ai.rs +++ /dev/null @@ -1,28 +0,0 @@ -use crate::{LanguageModelCompletionError, LanguageModelProviderId, LanguageModelProviderName}; -use http_client::http; -use std::time::Duration; - -pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai"); -pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("OpenAI"); - -impl From for LanguageModelCompletionError { - fn from(error: open_ai::RequestError) -> Self { - match error { - open_ai::RequestError::HttpResponseError { - provider, - status_code, - body, - headers, - } => { - let retry_after = headers - .get(http::header::RETRY_AFTER) - .and_then(|val| val.to_str().ok()?.parse::().ok()) - .map(Duration::from_secs); - - Self::from_http_status(provider.into(), status_code, body, retry_after) - } - open_ai::RequestError::Other(e) => Self::Other(e), - } - } -} diff --git a/crates/language_model/src/provider/open_router.rs b/crates/language_model/src/provider/open_router.rs deleted file mode 100644 index 
809e22f1fec0f2d205caa3ebbcb0baaf129b062c..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/open_router.rs +++ /dev/null @@ -1,69 +0,0 @@ -use crate::{LanguageModelCompletionError, LanguageModelProviderName}; -use http_client::StatusCode; -use open_router::OpenRouterError; - -impl From for LanguageModelCompletionError { - fn from(error: OpenRouterError) -> Self { - let provider = LanguageModelProviderName::new("OpenRouter"); - match error { - OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, - OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, - OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error }, - OpenRouterError::DeserializeResponse(error) => { - Self::DeserializeResponse { provider, error } - } - OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, - OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded { - provider, - retry_after: Some(retry_after), - }, - OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded { - provider, - retry_after, - }, - OpenRouterError::ApiError(api_error) => api_error.into(), - } - } -} - -impl From for LanguageModelCompletionError { - fn from(error: open_router::ApiError) -> Self { - use open_router::ApiErrorCode::*; - let provider = LanguageModelProviderName::new("OpenRouter"); - match error.code { - InvalidRequestError => Self::BadRequestFormat { - provider, - message: error.message, - }, - AuthenticationError => Self::AuthenticationError { - provider, - message: error.message, - }, - PaymentRequiredError => Self::AuthenticationError { - provider, - message: format!("Payment required: {}", error.message), - }, - PermissionError => Self::PermissionError { - provider, - message: error.message, - }, - RequestTimedOut => Self::HttpResponseError { - provider, - status_code: StatusCode::REQUEST_TIMEOUT, - message: error.message, - }, - 
RateLimitError => Self::RateLimitExceeded { - provider, - retry_after: None, - }, - ApiError => Self::ApiInternalServerError { - provider, - message: error.message, - }, - OverloadedError => Self::ServerOverloaded { - provider, - retry_after: None, - }, - } - } -} diff --git a/crates/language_model/src/provider/x_ai.rs b/crates/language_model/src/provider/x_ai.rs deleted file mode 100644 index 3d0f794fa4087a4beeb4a9b6253d016a9b592f0e..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/x_ai.rs +++ /dev/null @@ -1,4 +0,0 @@ -use crate::{LanguageModelProviderId, LanguageModelProviderName}; - -pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); -pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); diff --git a/crates/language_model/src/provider/zed.rs b/crates/language_model/src/provider/zed.rs deleted file mode 100644 index 0ba793e99aad1caa25f049a96faf02c16e8970fa..0000000000000000000000000000000000000000 --- a/crates/language_model/src/provider/zed.rs +++ /dev/null @@ -1,5 +0,0 @@ -use crate::{LanguageModelProviderId, LanguageModelProviderName}; - -pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev"); -pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = - LanguageModelProviderName::new("Zed"); diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index bf14fbb0b5804505b33074e6e4cbcc36ddf21fab..680078808ab33cc2a90caead8b304326beccf11b 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -1,6 +1,6 @@ use crate::{ LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderState, + LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID, }; use collections::{BTreeMap, HashSet}; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; @@ -101,7 +101,7 @@ impl ConfiguredModel 
{ } pub fn is_provided_by_zed(&self) -> bool { - self.provider.id() == crate::provider::ZED_CLOUD_PROVIDER_ID + self.provider.id() == ZED_CLOUD_PROVIDER_ID } } diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 9a5e96078cd4d952185261c79032c5c5fdf30060..ef73864fe3e2f5b58e73dec848c686123a61fcde 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -4,78 +4,13 @@ use std::sync::Arc; use anyhow::Result; use base64::write::EncoderWriter; use gpui::{ - App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task, - point, px, size, + App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, Size, Task, point, px, size, }; use image::GenericImageView as _; use image::codecs::png::PngEncoder; -use serde::{Deserialize, Serialize}; use util::ResultExt; -use crate::role::Role; -use crate::{LanguageModelToolUse, LanguageModelToolUseId}; - -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] -pub struct LanguageModelImage { - /// A base64-encoded PNG image. 
- pub source: SharedString, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub size: Option>, -} - -impl LanguageModelImage { - pub fn len(&self) -> usize { - self.source.len() - } - - pub fn is_empty(&self) -> bool { - self.source.is_empty() - } - - // Parse Self from a JSON object with case-insensitive field names - pub fn from_json(obj: &serde_json::Map) -> Option { - let mut source = None; - let mut size_obj = None; - - // Find source and size fields (case-insensitive) - for (k, v) in obj.iter() { - match k.to_lowercase().as_str() { - "source" => source = v.as_str(), - "size" => size_obj = v.as_object(), - _ => {} - } - } - - let source = source?; - let size_obj = size_obj?; - - let mut width = None; - let mut height = None; - - // Find width and height in size object (case-insensitive) - for (k, v) in size_obj.iter() { - match k.to_lowercase().as_str() { - "width" => width = v.as_i64().map(|w| w as i32), - "height" => height = v.as_i64().map(|h| h as i32), - _ => {} - } - } - - Some(Self { - size: Some(size(DevicePixels(width?), DevicePixels(height?))), - source: SharedString::from(source.to_string()), - }) - } -} - -impl std::fmt::Debug for LanguageModelImage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("LanguageModelImage") - .field("source", &format!("<{} bytes>", self.source.len())) - .field("size", &self.size) - .finish() - } -} +use language_model_core::{ImageSize, LanguageModelImage}; /// Anthropic wants uploaded images to be smaller than this in both dimensions. const ANTHROPIC_SIZE_LIMIT: f32 = 1568.; @@ -90,18 +25,16 @@ const DEFAULT_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024; /// `DEFAULT_IMAGE_MAX_BYTES`. const MAX_IMAGE_DOWNSCALE_PASSES: usize = 8; -impl LanguageModelImage { - // All language model images are encoded as PNGs. - pub const FORMAT: ImageFormat = ImageFormat::Png; +/// Extension trait for `LanguageModelImage` that provides GPUI-dependent functionality. 
+pub trait LanguageModelImageExt { + const FORMAT: ImageFormat; + fn from_image(data: Arc, cx: &mut App) -> Task>; +} - pub fn empty() -> Self { - Self { - source: "".into(), - size: None, - } - } +impl LanguageModelImageExt for LanguageModelImage { + const FORMAT: ImageFormat = ImageFormat::Png; - pub fn from_image(data: Arc, cx: &mut App) -> Task> { + fn from_image(data: Arc, cx: &mut App) -> Task> { cx.background_spawn(async move { let image_bytes = Cursor::new(data.bytes()); let dynamic_image = match data.format() { @@ -186,28 +119,14 @@ impl LanguageModelImage { let source = unsafe { String::from_utf8_unchecked(base64_image) }; Some(LanguageModelImage { - size: Some(image_size), + size: Some(ImageSize { + width: width as i32, + height: height as i32, + }), source: source.into(), }) }) } - - pub fn estimate_tokens(&self) -> usize { - let Some(size) = self.size.as_ref() else { - return 0; - }; - let width = size.width.0.unsigned_abs() as usize; - let height = size.height.0.unsigned_abs() as usize; - - // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs - // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this, - // so this method is more of a rough guess. 
- (width * height) / 750 - } - - pub fn to_base64_url(&self) -> String { - format!("data:image/png;base64,{}", self.source) - } } fn encode_png_bytes(image: &image::DynamicImage) -> Result> { @@ -228,512 +147,85 @@ fn encode_bytes_as_base64(bytes: &[u8]) -> Result> { Ok(base64_image) } -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] -pub struct LanguageModelToolResult { - pub tool_use_id: LanguageModelToolUseId, - pub tool_name: Arc, - pub is_error: bool, - /// The tool output formatted for presenting to the model - pub content: LanguageModelToolResultContent, - /// The raw tool output, if available, often for debugging or extra state for replay - pub output: Option, -} - -#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)] -pub enum LanguageModelToolResultContent { - Text(Arc), - Image(LanguageModelImage), -} - -impl<'de> Deserialize<'de> for LanguageModelToolResultContent { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - use serde::de::Error; - - let value = serde_json::Value::deserialize(deserializer)?; - - // Models can provide these responses in several styles. Try each in order. - - // 1. Try as plain string - if let Ok(text) = serde_json::from_value::(value.clone()) { - return Ok(Self::Text(Arc::from(text))); - } - - // 2. Try as object - if let Some(obj) = value.as_object() { - // get a JSON field case-insensitively - fn get_field<'a>( - obj: &'a serde_json::Map, - field: &str, - ) -> Option<&'a serde_json::Value> { - obj.iter() - .find(|(k, _)| k.to_lowercase() == field.to_lowercase()) - .map(|(_, v)| v) - } - - // Accept wrapped text format: { "type": "text", "text": "..." 
} - if let (Some(type_value), Some(text_value)) = - (get_field(obj, "type"), get_field(obj, "text")) - && let Some(type_str) = type_value.as_str() - && type_str.to_lowercase() == "text" - && let Some(text) = text_value.as_str() - { - return Ok(Self::Text(Arc::from(text))); - } - - // Check for wrapped Text variant: { "text": "..." } - if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") - && obj.len() == 1 - { - // Only one field, and it's "text" (case-insensitive) - if let Some(text) = value.as_str() { - return Ok(Self::Text(Arc::from(text))); - } - } - - // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } } - if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") - && obj.len() == 1 - { - // Only one field, and it's "image" (case-insensitive) - // Try to parse the nested image object - if let Some(image_obj) = value.as_object() - && let Some(image) = LanguageModelImage::from_json(image_obj) - { - return Ok(Self::Image(image)); - } - } - - // Try as direct Image (object with "source" and "size" fields) - if let Some(image) = LanguageModelImage::from_json(obj) { - return Ok(Self::Image(image)); - } - } - - // If none of the variants match, return an error with the problematic JSON - Err(D::Error::custom(format!( - "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \ - an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. 
Got: {}", - serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()) - ))) - } -} - -impl LanguageModelToolResultContent { - pub fn to_str(&self) -> Option<&str> { - match self { - Self::Text(text) => Some(text), - Self::Image(_) => None, - } - } - - pub fn is_empty(&self) -> bool { - match self { - Self::Text(text) => text.chars().all(|c| c.is_whitespace()), - Self::Image(_) => false, - } - } -} - -impl From<&str> for LanguageModelToolResultContent { - fn from(value: &str) -> Self { - Self::Text(Arc::from(value)) - } -} - -impl From for LanguageModelToolResultContent { - fn from(value: String) -> Self { - Self::Text(Arc::from(value)) - } -} - -impl From for LanguageModelToolResultContent { - fn from(image: LanguageModelImage) -> Self { - Self::Image(image) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] -pub enum MessageContent { - Text(String), - Thinking { - text: String, - signature: Option, - }, - RedactedThinking(String), - Image(LanguageModelImage), - ToolUse(LanguageModelToolUse), - ToolResult(LanguageModelToolResult), -} - -impl MessageContent { - pub fn to_str(&self) -> Option<&str> { - match self { - MessageContent::Text(text) => Some(text.as_str()), - MessageContent::Thinking { text, .. } => Some(text.as_str()), - MessageContent::RedactedThinking(_) => None, - MessageContent::ToolResult(tool_result) => tool_result.content.to_str(), - MessageContent::ToolUse(_) | MessageContent::Image(_) => None, - } - } - - pub fn is_empty(&self) -> bool { - match self { - MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), - MessageContent::Thinking { text, .. 
} => text.chars().all(|c| c.is_whitespace()), - MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(), - MessageContent::RedactedThinking(_) - | MessageContent::ToolUse(_) - | MessageContent::Image(_) => false, - } - } -} - -impl From for MessageContent { - fn from(value: String) -> Self { - MessageContent::Text(value) - } -} - -impl From<&str> for MessageContent { - fn from(value: &str) -> Self { - MessageContent::Text(value.to_string()) - } -} - -#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] -pub struct LanguageModelRequestMessage { - pub role: Role, - pub content: Vec, - pub cache: bool, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub reasoning_details: Option, -} - -impl LanguageModelRequestMessage { - pub fn string_contents(&self) -> String { - let mut buffer = String::new(); - for string in self.content.iter().filter_map(|content| content.to_str()) { - buffer.push_str(string); - } - - buffer - } - - pub fn contents_empty(&self) -> bool { - self.content.iter().all(|content| content.is_empty()) - } -} - -#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] -pub struct LanguageModelRequestTool { - pub name: String, - pub description: String, - pub input_schema: serde_json::Value, - pub use_input_streaming: bool, -} - -#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] -pub enum LanguageModelToolChoice { - Auto, - Any, - None, -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CompletionIntent { - UserPrompt, - Subagent, - ToolResults, - ThreadSummarization, - ThreadContextSummarization, - CreateFile, - EditFile, - InlineAssist, - TerminalInlineAssist, - GenerateGitCommitMessage, -} - -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] -pub struct LanguageModelRequest { - pub thread_id: Option, - pub prompt_id: Option, - pub intent: Option, - pub messages: Vec, - pub 
tools: Vec, - pub tool_choice: Option, - pub stop: Vec, - pub temperature: Option, - pub thinking_allowed: bool, - pub thinking_effort: Option, - pub speed: Option, -} - -#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum Speed { - #[default] - Standard, - Fast, -} - -impl Speed { - pub fn toggle(self) -> Self { - match self { - Speed::Standard => Speed::Fast, - Speed::Fast => Speed::Standard, - } +/// Convert a core `ImageSize` to a gpui `Size`. +pub fn image_size_to_gpui(size: ImageSize) -> Size { + Size { + width: DevicePixels(size.width), + height: DevicePixels(size.height), } } -impl From for anthropic::Speed { - fn from(speed: Speed) -> Self { - match speed { - Speed::Standard => anthropic::Speed::Standard, - Speed::Fast => anthropic::Speed::Fast, - } +/// Convert a gpui `Size` to a core `ImageSize`. +pub fn gpui_size_to_image_size(size: Size) -> ImageSize { + ImageSize { + width: size.width.0, + height: size.height.0, } } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct LanguageModelResponseMessage { - pub role: Option, - pub content: Option, -} - #[cfg(test)] mod tests { use super::*; use base64::Engine as _; use gpui::TestAppContext; - use image::ImageDecoder as _; - fn base64_to_png_bytes(base64_png: &str) -> Vec { + fn base64_to_png_bytes(base64: &str) -> Vec { base64::engine::general_purpose::STANDARD - .decode(base64_png.as_bytes()) - .expect("base64 should decode") + .decode(base64) + .expect("valid base64") } fn png_dimensions(png_bytes: &[u8]) -> (u32, u32) { - let decoder = - image::codecs::png::PngDecoder::new(Cursor::new(png_bytes)).expect("png should decode"); - decoder.dimensions() + let img = image::load_from_memory(png_bytes).expect("valid png"); + (img.width(), img.height()) } fn make_noisy_png_bytes(width: u32, height: u32) -> Vec { - // Create an RGBA image with per-pixel variance to avoid PNG compressing too well. 
- let mut img = image::RgbaImage::new(width, height); - for y in 0..height { - for x in 0..width { - let r = ((x ^ y) & 0xFF) as u8; - let g = ((x.wrapping_mul(31) ^ y.wrapping_mul(17)) & 0xFF) as u8; - let b = ((x.wrapping_mul(131) ^ y.wrapping_mul(7)) & 0xFF) as u8; - img.put_pixel(x, y, image::Rgba([r, g, b, 0xFF])); - } - } + use image::{ImageBuffer, Rgba}; + use std::hash::{Hash, Hasher}; + + let img = ImageBuffer::from_fn(width, height, |x, y| { + let mut hasher = std::hash::DefaultHasher::new(); + (x, y, width, height).hash(&mut hasher); + let h = hasher.finish(); + Rgba([h as u8, (h >> 8) as u8, (h >> 16) as u8, 255]) + }); - let mut out = Vec::new(); - image::DynamicImage::ImageRgba8(img) - .write_with_encoder(PngEncoder::new(&mut out)) - .expect("png encoding should succeed"); - out + let mut buf = Cursor::new(Vec::new()); + img.write_with_encoder(PngEncoder::new(&mut buf)) + .expect("encode"); + buf.into_inner() } #[gpui::test] async fn test_from_image_downscales_to_default_5mb_limit(cx: &mut TestAppContext) { - // Pick a size that reliably produces a PNG > 5MB when filled with noise. - // If this fails (image is too small), bump dimensions. 
- let original_png = make_noisy_png_bytes(4096, 4096); + let raw_png = make_noisy_png_bytes(4096, 4096); assert!( - original_png.len() > DEFAULT_IMAGE_MAX_BYTES, - "precondition failed: noisy PNG must exceed DEFAULT_IMAGE_MAX_BYTES" + raw_png.len() > DEFAULT_IMAGE_MAX_BYTES, + "Test image should exceed the 5 MB limit (actual: {} bytes)", + raw_png.len() ); - let image = gpui::Image::from_bytes(ImageFormat::Png, original_png); + let image = Arc::new(gpui::Image::from_bytes(ImageFormat::Png, raw_png)); let lm_image = cx - .update(|cx| LanguageModelImage::from_image(Arc::new(image), cx)) + .update(|cx| LanguageModelImage::from_image(Arc::clone(&image), cx)) .await - .expect("image conversion should succeed"); + .expect("from_image should succeed"); - let encoded_png = base64_to_png_bytes(lm_image.source.as_ref()); + let decoded_png = base64_to_png_bytes(lm_image.source.as_ref()); assert!( - encoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES, - "expected encoded PNG <= DEFAULT_IMAGE_MAX_BYTES, got {} bytes", - encoded_png.len() + decoded_png.len() <= DEFAULT_IMAGE_MAX_BYTES, + "Encoded PNG should be ≤ {} bytes after downscale, but was {} bytes", + DEFAULT_IMAGE_MAX_BYTES, + decoded_png.len() ); - // Ensure we actually downscaled in pixels (not just re-encoded). 
- let (w, h) = png_dimensions(&encoded_png); + let (w, h) = png_dimensions(&decoded_png); assert!( - w < 4096 || h < 4096, - "expected image to be downscaled in at least one dimension; got {w}x{h}" - ); - } - - #[test] - fn test_language_model_tool_result_content_deserialization() { - let json = r#""This is plain text""#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("This is plain text".into()) - ); - - let json = r#"{"type": "text", "text": "This is wrapped text"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("This is wrapped text".into()) - ); - - let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Case insensitive".into()) - ); - - let json = r#"{"Text": "Wrapped variant"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Wrapped variant".into()) - ); - - let json = r#"{"text": "Lowercase wrapped"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Lowercase wrapped".into()) + w < 4096 && h < 4096, + "Dimensions should have shrunk: got {}Ă—{}", + w, + h ); - - // Test image deserialization - let json = r#"{ - "source": "base64encodedimagedata", - "size": { - "width": 100, - "height": 200 - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "base64encodedimagedata"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 100); - 
assert_eq!(size.height.0, 200); - } - _ => panic!("Expected Image variant"), - } - - // Test wrapped Image variant - let json = r#"{ - "Image": { - "source": "wrappedimagedata", - "size": { - "width": 50, - "height": 75 - } - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "wrappedimagedata"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 50); - assert_eq!(size.height.0, 75); - } - _ => panic!("Expected Image variant"), - } - - // Test wrapped Image variant with case insensitive - let json = r#"{ - "image": { - "Source": "caseinsensitive", - "SIZE": { - "width": 30, - "height": 40 - } - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "caseinsensitive"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 30); - assert_eq!(size.height.0, 40); - } - _ => panic!("Expected Image variant"), - } - - // Test that wrapped text with wrong type fails - let json = r#"{"type": "blahblah", "text": "This should fail"}"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - // Test that malformed JSON fails - let json = r#"{"invalid": "structure"}"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - // Test edge cases - let json = r#""""#; // Empty string - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!(result, LanguageModelToolResultContent::Text("".into())); - - // Test with extra fields in wrapped text (should be ignored) - let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!(result, 
LanguageModelToolResultContent::Text("Hello".into())); - - // Test direct image with case-insensitive fields - let json = r#"{ - "SOURCE": "directimage", - "Size": { - "width": 200, - "height": 300 - } - }"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - match result { - LanguageModelToolResultContent::Image(image) => { - assert_eq!(image.source.as_ref(), "directimage"); - let size = image.size.expect("size"); - assert_eq!(size.width.0, 200); - assert_eq!(size.height.0, 300); - } - _ => panic!("Expected Image variant"), - } - - // Test that multiple fields prevent wrapped variant interpretation - let json = r#"{"Text": "not wrapped", "extra": "field"}"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - // Test wrapped text with uppercase TEXT variant - let json = r#"{"TEXT": "Uppercase variant"}"#; - let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap(); - assert_eq!( - result, - LanguageModelToolResultContent::Text("Uppercase variant".into()) - ); - - // Test that numbers and other JSON values fail gracefully - let json = r#"123"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - let json = r#"null"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); - - let json = r#"[1, 2, 3]"#; - let result: Result = serde_json::from_str(json); - assert!(result.is_err()); } } diff --git a/crates/language_model_core/Cargo.toml b/crates/language_model_core/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..7a6de00f3e4a774537d93e2f77ea9107845a7c50 --- /dev/null +++ b/crates/language_model_core/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "language_model_core" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/language_model_core.rs" +doctest = false + +[dependencies] +anyhow.workspace = 
true +cloud_llm_client.workspace = true +futures.workspace = true +gpui_shared_string.workspace = true +http_client.workspace = true +partial-json-fixer.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +strum.workspace = true +thiserror.workspace = true diff --git a/crates/language_model_core/LICENSE-GPL b/crates/language_model_core/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/language_model_core/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/language_model_core/src/language_model_core.rs b/crates/language_model_core/src/language_model_core.rs new file mode 100644 index 0000000000000000000000000000000000000000..5f932690869a2c17ec1c89cbe9401bcdef6e1e73 --- /dev/null +++ b/crates/language_model_core/src/language_model_core.rs @@ -0,0 +1,658 @@ +mod provider; +mod rate_limiter; +mod request; +mod role; +pub mod tool_schema; +pub mod util; + +use anyhow::{Result, anyhow}; +use cloud_llm_client::CompletionRequestStatus; +use http_client::{StatusCode, http}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::ops::{Add, Sub}; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use std::{fmt, io}; +use thiserror::Error; +fn is_default(value: &T) -> bool { + *value == T::default() +} + +pub use crate::provider::*; +pub use crate::rate_limiter::*; +pub use crate::request::*; +pub use crate::role::*; +pub use crate::tool_schema::LanguageModelToolSchemaFormat; +pub use crate::util::{fix_streamed_json, parse_prompt_too_long, parse_tool_arguments}; +pub use gpui_shared_string::SharedString; + +#[derive(Clone, Debug)] +pub struct LanguageModelCacheConfiguration { + pub max_cache_anchors: usize, + pub should_speculate: bool, + pub min_total_token: u64, +} + +/// A completion event from a language model. 
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub enum LanguageModelCompletionEvent { + Queued { + position: usize, + }, + Started, + Stop(StopReason), + Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking { + data: String, + }, + ToolUse(LanguageModelToolUse), + ToolUseJsonParseError { + id: LanguageModelToolUseId, + tool_name: Arc, + raw_input: Arc, + json_parse_error: String, + }, + StartMessage { + message_id: String, + }, + ReasoningDetails(serde_json::Value), + UsageUpdate(TokenUsage), +} + +impl LanguageModelCompletionEvent { + pub fn from_completion_request_status( + status: CompletionRequestStatus, + upstream_provider: LanguageModelProviderName, + ) -> Result, LanguageModelCompletionError> { + match status { + CompletionRequestStatus::Queued { position } => { + Ok(Some(LanguageModelCompletionEvent::Queued { position })) + } + CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)), + CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None), + CompletionRequestStatus::Failed { + code, + message, + request_id: _, + retry_after, + } => Err(LanguageModelCompletionError::from_cloud_failure( + upstream_provider, + code, + message, + retry_after.map(Duration::from_secs_f64), + )), + } + } +} + +#[derive(Error, Debug)] +pub enum LanguageModelCompletionError { + #[error("prompt too large for context window")] + PromptTooLarge { tokens: Option }, + #[error("missing {provider} API key")] + NoApiKey { provider: LanguageModelProviderName }, + #[error("{provider}'s API rate limit exceeded")] + RateLimitExceeded { + provider: LanguageModelProviderName, + retry_after: Option, + }, + #[error("{provider}'s API servers are overloaded right now")] + ServerOverloaded { + provider: LanguageModelProviderName, + retry_after: Option, + }, + #[error("{provider}'s API server reported an internal server error: {message}")] + ApiInternalServerError { + provider: 
LanguageModelProviderName, + message: String, + }, + #[error("{message}")] + UpstreamProviderError { + message: String, + status: StatusCode, + retry_after: Option, + }, + #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")] + HttpResponseError { + provider: LanguageModelProviderName, + status_code: StatusCode, + message: String, + }, + #[error("invalid request format to {provider}'s API: {message}")] + BadRequestFormat { + provider: LanguageModelProviderName, + message: String, + }, + #[error("authentication error with {provider}'s API: {message}")] + AuthenticationError { + provider: LanguageModelProviderName, + message: String, + }, + #[error("Permission error with {provider}'s API: {message}")] + PermissionError { + provider: LanguageModelProviderName, + message: String, + }, + #[error("language model provider API endpoint not found")] + ApiEndpointNotFound { provider: LanguageModelProviderName }, + #[error("I/O error reading response from {provider}'s API")] + ApiReadResponseError { + provider: LanguageModelProviderName, + #[source] + error: io::Error, + }, + #[error("error serializing request to {provider} API")] + SerializeRequest { + provider: LanguageModelProviderName, + #[source] + error: serde_json::Error, + }, + #[error("error building request body to {provider} API")] + BuildRequestBody { + provider: LanguageModelProviderName, + #[source] + error: http::Error, + }, + #[error("error sending HTTP request to {provider} API")] + HttpSend { + provider: LanguageModelProviderName, + #[source] + error: anyhow::Error, + }, + #[error("error deserializing {provider} API response")] + DeserializeResponse { + provider: LanguageModelProviderName, + #[source] + error: serde_json::Error, + }, + #[error("stream from {provider} ended unexpectedly")] + StreamEndedUnexpectedly { provider: LanguageModelProviderName }, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl LanguageModelCompletionError { + fn 
parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> { + let error_json = serde_json::from_str::(message).ok()?; + let upstream_status = error_json + .get("upstream_status") + .and_then(|v| v.as_u64()) + .and_then(|status| u16::try_from(status).ok()) + .and_then(|status| StatusCode::from_u16(status).ok())?; + let inner_message = error_json + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or(message) + .to_string(); + Some((upstream_status, inner_message)) + } + + pub fn from_cloud_failure( + upstream_provider: LanguageModelProviderName, + code: String, + message: String, + retry_after: Option, + ) -> Self { + if let Some(tokens) = parse_prompt_too_long(&message) { + Self::PromptTooLarge { + tokens: Some(tokens), + } + } else if code == "upstream_http_error" { + if let Some((upstream_status, inner_message)) = + Self::parse_upstream_error_json(&message) + { + return Self::from_http_status( + upstream_provider, + upstream_status, + inner_message, + retry_after, + ); + } + anyhow!("completion request failed, code: {code}, message: {message}").into() + } else if let Some(status_code) = code + .strip_prefix("upstream_http_") + .and_then(|code| StatusCode::from_str(code).ok()) + { + Self::from_http_status(upstream_provider, status_code, message, retry_after) + } else if let Some(status_code) = code + .strip_prefix("http_") + .and_then(|code| StatusCode::from_str(code).ok()) + { + Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after) + } else { + anyhow!("completion request failed, code: {code}, message: {message}").into() + } + } + + pub fn from_http_status( + provider: LanguageModelProviderName, + status_code: StatusCode, + message: String, + retry_after: Option, + ) -> Self { + match status_code { + StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message }, + StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message }, + StatusCode::FORBIDDEN => Self::PermissionError { provider, 
message }, + StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider }, + StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge { + tokens: parse_prompt_too_long(&message), + }, + StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { + provider, + retry_after, + }, + StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message }, + StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { + provider, + retry_after, + }, + _ if status_code.as_u16() == 529 => Self::ServerOverloaded { + provider, + retry_after, + }, + _ => Self::HttpResponseError { + provider, + status_code, + message, + }, + } + } +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum StopReason { + EndTurn, + MaxTokens, + ToolUse, + Refusal, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] +pub struct TokenUsage { + #[serde(default, skip_serializing_if = "is_default")] + pub input_tokens: u64, + #[serde(default, skip_serializing_if = "is_default")] + pub output_tokens: u64, + #[serde(default, skip_serializing_if = "is_default")] + pub cache_creation_input_tokens: u64, + #[serde(default, skip_serializing_if = "is_default")] + pub cache_read_input_tokens: u64, +} + +impl TokenUsage { + pub fn total_tokens(&self) -> u64 { + self.input_tokens + + self.output_tokens + + self.cache_read_input_tokens + + self.cache_creation_input_tokens + } +} + +impl Add for TokenUsage { + type Output = Self; + + fn add(self, other: Self) -> Self { + Self { + input_tokens: self.input_tokens + other.input_tokens, + output_tokens: self.output_tokens + other.output_tokens, + cache_creation_input_tokens: self.cache_creation_input_tokens + + other.cache_creation_input_tokens, + cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens, + } + } +} + +impl Sub for TokenUsage { + type Output = Self; + + fn sub(self, other: Self) -> Self { + Self { + input_tokens: 
self.input_tokens - other.input_tokens, + output_tokens: self.output_tokens - other.output_tokens, + cache_creation_input_tokens: self.cache_creation_input_tokens + - other.cache_creation_input_tokens, + cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens, + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelToolUseId(Arc); + +impl fmt::Display for LanguageModelToolUseId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for LanguageModelToolUseId +where + T: Into>, +{ + fn from(value: T) -> Self { + Self(value.into()) + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelToolUse { + pub id: LanguageModelToolUseId, + pub name: Arc, + pub raw_input: String, + pub input: serde_json::Value, + pub is_input_complete: bool, + /// Thought signature the model sent us. Some models require that this + /// signature be preserved and sent back in conversation history for validation. + pub thought_signature: Option, +} + +#[derive(Debug, Clone)] +pub struct LanguageModelEffortLevel { + pub name: SharedString, + pub value: SharedString, + pub is_default: bool, +} + +/// An error that occurred when trying to authenticate the language model provider. 
+#[derive(Debug, Error)] +pub enum AuthenticateError { + #[error("connection refused")] + ConnectionRefused, + #[error("credentials not found")] + CredentialsNotFound, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)] +pub struct LanguageModelId(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct LanguageModelName(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct LanguageModelProviderId(pub SharedString); + +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +pub struct LanguageModelProviderName(pub SharedString); + +impl LanguageModelProviderId { + pub const fn new(id: &'static str) -> Self { + Self(SharedString::new_static(id)) + } +} + +impl LanguageModelProviderName { + pub const fn new(id: &'static str) -> Self { + Self(SharedString::new_static(id)) + } +} + +impl fmt::Display for LanguageModelProviderId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl fmt::Display for LanguageModelProviderName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for LanguageModelId { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From for LanguageModelName { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From for LanguageModelProviderId { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From for LanguageModelProviderName { + fn from(value: String) -> Self { + Self(SharedString::from(value)) + } +} + +impl From> for LanguageModelProviderId { + fn from(value: Arc) -> Self { + Self(SharedString::from(value)) + } +} + +impl From> for LanguageModelProviderName { + fn from(value: Arc) -> Self { + Self(SharedString::from(value)) + } +} + +/// Settings-layer–free model 
mode enum. +/// +/// Mirrors the shape of `settings_content::ModelMode` but lives here so that +/// crates below the settings layer can reference it. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ModelMode { + #[default] + Default, + Thinking { + budget_tokens: Option, + }, +} + +/// Settings-layer–free reasoning-effort enum. +/// +/// Mirrors the shape of `settings_content::OpenAiReasoningEffort` but lives +/// here so that crates below the settings layer can reference it. +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, strum::EnumString, +)] +#[serde(rename_all = "lowercase")] +#[strum(serialize_all = "lowercase")] +pub enum ReasoningEffort { + Minimal, + Low, + Medium, + High, + XHigh, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_cloud_failure_with_upstream_http_error() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. 
} => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!(message, "Internal server error"); + } + _ => panic!( + "Expected ApiInternalServerError for 500 status, got: {:?}", + error + ), + } + } + + #[test] + fn test_from_cloud_failure_with_standard_format() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_503".to_string(), + "Service unavailable".to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. } => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!("Expected ServerOverloaded error for upstream_http_503"), + } + } + + #[test] + fn test_upstream_http_error_connection_timeout() { + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ServerOverloaded { provider, .. 
} => { + assert_eq!(provider.0, "anthropic"); + } + _ => panic!( + "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}", + error + ), + } + + let error = LanguageModelCompletionError::from_cloud_failure( + String::from("anthropic").into(), + "upstream_http_error".to_string(), + r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(), + None, + ); + + match error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider.0, "anthropic"); + assert_eq!( + message, + "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout" + ); + } + _ => panic!( + "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}", + error + ), + } + } + + #[test] + fn test_language_model_tool_use_serializes_with_signature() { + use serde_json::json; + + let tool_use = LanguageModelToolUse { + id: LanguageModelToolUseId::from("test_id"), + name: "test_tool".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: Some("test_signature".to_string()), + }; + + let serialized = serde_json::to_value(&tool_use).unwrap(); + + assert_eq!(serialized["id"], "test_id"); + assert_eq!(serialized["name"], "test_tool"); + assert_eq!(serialized["thought_signature"], "test_signature"); + } + + #[test] + fn test_language_model_tool_use_deserializes_with_missing_signature() { + use serde_json::json; + + let json = json!({ + "id": "test_id", + "name": "test_tool", + "raw_input": "{\"arg\":\"value\"}", + "input": {"arg": "value"}, + "is_input_complete": true + }); + + let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap(); + + assert_eq!(tool_use.id, 
LanguageModelToolUseId::from("test_id")); + assert_eq!(tool_use.name.as_ref(), "test_tool"); + assert_eq!(tool_use.thought_signature, None); + } + + #[test] + fn test_language_model_tool_use_round_trip_with_signature() { + use serde_json::json; + + let original = LanguageModelToolUse { + id: LanguageModelToolUseId::from("round_trip_id"), + name: "round_trip_tool".into(), + raw_input: json!({"key": "value"}).to_string(), + input: json!({"key": "value"}), + is_input_complete: true, + thought_signature: Some("round_trip_sig".to_string()), + }; + + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.id, original.id); + assert_eq!(deserialized.name, original.name); + assert_eq!(deserialized.thought_signature, original.thought_signature); + } + + #[test] + fn test_language_model_tool_use_round_trip_without_signature() { + use serde_json::json; + + let original = LanguageModelToolUse { + id: LanguageModelToolUseId::from("no_sig_id"), + name: "no_sig_tool".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: None, + }; + + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.id, original.id); + assert_eq!(deserialized.name, original.name); + assert_eq!(deserialized.thought_signature, None); + } +} diff --git a/crates/language_model_core/src/provider.rs b/crates/language_model_core/src/provider.rs new file mode 100644 index 0000000000000000000000000000000000000000..da8b208147ad1d5b58a35888dfd07c821965097c --- /dev/null +++ b/crates/language_model_core/src/provider.rs @@ -0,0 +1,21 @@ +use crate::{LanguageModelProviderId, LanguageModelProviderName}; + +pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId = + 
LanguageModelProviderId::new("anthropic"); +pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Anthropic"); + +pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai"); +pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("OpenAI"); + +pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google"); +pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Google AI"); + +pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); +pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); + +pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev"); +pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("Zed"); diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model_core/src/rate_limiter.rs similarity index 100% rename from crates/language_model/src/rate_limiter.rs rename to crates/language_model_core/src/rate_limiter.rs diff --git a/crates/language_model_core/src/request.rs b/crates/language_model_core/src/request.rs new file mode 100644 index 0000000000000000000000000000000000000000..48f7f00522bc3dd5c06747d662761efb003886c0 --- /dev/null +++ b/crates/language_model_core/src/request.rs @@ -0,0 +1,463 @@ +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +use crate::role::Role; +use crate::{LanguageModelToolUse, LanguageModelToolUseId, SharedString}; + +/// Dimensions of a `LanguageModelImage` +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ImageSize { + pub width: i32, + pub height: i32, +} + +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub struct LanguageModelImage { + /// A base64-encoded PNG image. 
+ pub source: SharedString, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub size: Option, +} + +impl LanguageModelImage { + pub fn len(&self) -> usize { + self.source.len() + } + + pub fn is_empty(&self) -> bool { + self.source.is_empty() + } + + pub fn empty() -> Self { + Self { + source: "".into(), + size: None, + } + } + + /// Parse Self from a JSON object with case-insensitive field names + pub fn from_json(obj: &serde_json::Map) -> Option { + let mut source = None; + let mut size_obj = None; + + for (k, v) in obj.iter() { + match k.to_lowercase().as_str() { + "source" => source = v.as_str(), + "size" => size_obj = v.as_object(), + _ => {} + } + } + + let source = source?; + let size_obj = size_obj?; + + let mut width = None; + let mut height = None; + + for (k, v) in size_obj.iter() { + match k.to_lowercase().as_str() { + "width" => width = v.as_i64().map(|w| w as i32), + "height" => height = v.as_i64().map(|h| h as i32), + _ => {} + } + } + + Some(Self { + size: Some(ImageSize { + width: width?, + height: height?, + }), + source: SharedString::from(source.to_string()), + }) + } + + pub fn estimate_tokens(&self) -> usize { + let Some(size) = self.size.as_ref() else { + return 0; + }; + let width = size.width.unsigned_abs() as usize; + let height = size.height.unsigned_abs() as usize; + + // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs + (width * height) / 750 + } + + pub fn to_base64_url(&self) -> String { + format!("data:image/png;base64,{}", self.source) + } +} + +impl std::fmt::Debug for LanguageModelImage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LanguageModelImage") + .field("source", &format!("<{} bytes>", self.source.len())) + .field("size", &self.size) + .finish() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub struct LanguageModelToolResult { + pub tool_use_id: LanguageModelToolUseId, + pub tool_name: Arc, 
+ pub is_error: bool, + /// The tool output formatted for presenting to the model + pub content: LanguageModelToolResultContent, + /// The raw tool output, if available, often for debugging or extra state for replay + pub output: Option, +} + +#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)] +pub enum LanguageModelToolResultContent { + Text(Arc), + Image(LanguageModelImage), +} + +impl<'de> Deserialize<'de> for LanguageModelToolResultContent { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + let value = serde_json::Value::deserialize(deserializer)?; + + // 1. Try as plain string + if let Ok(text) = serde_json::from_value::(value.clone()) { + return Ok(Self::Text(Arc::from(text))); + } + + // 2. Try as object + if let Some(obj) = value.as_object() { + fn get_field<'a>( + obj: &'a serde_json::Map, + field: &str, + ) -> Option<&'a serde_json::Value> { + obj.iter() + .find(|(k, _)| k.to_lowercase() == field.to_lowercase()) + .map(|(_, v)| v) + } + + // Accept wrapped text format: { "type": "text", "text": "..." } + if let (Some(type_value), Some(text_value)) = + (get_field(obj, "type"), get_field(obj, "text")) + && let Some(type_str) = type_value.as_str() + && type_str.to_lowercase() == "text" + && let Some(text) = text_value.as_str() + { + return Ok(Self::Text(Arc::from(text))); + } + + // Check for wrapped Text variant: { "text": "..." } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") + && obj.len() == 1 + { + if let Some(text) = value.as_str() { + return Ok(Self::Text(Arc::from(text))); + } + } + + // Check for wrapped Image variant: { "image": { "source": "...", "size": ... 
} } + if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") + && obj.len() == 1 + { + if let Some(image_obj) = value.as_object() + && let Some(image) = LanguageModelImage::from_json(image_obj) + { + return Ok(Self::Image(image)); + } + } + + // Try as direct Image + if let Some(image) = LanguageModelImage::from_json(obj) { + return Ok(Self::Image(image)); + } + } + + Err(D::Error::custom(format!( + "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \ + an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}", + serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string()) + ))) + } +} + +impl LanguageModelToolResultContent { + pub fn to_str(&self) -> Option<&str> { + match self { + Self::Text(text) => Some(text), + Self::Image(_) => None, + } + } + + pub fn is_empty(&self) -> bool { + match self { + Self::Text(text) => text.chars().all(|c| c.is_whitespace()), + Self::Image(_) => false, + } + } +} + +impl From<&str> for LanguageModelToolResultContent { + fn from(value: &str) -> Self { + Self::Text(Arc::from(value)) + } +} + +impl From for LanguageModelToolResultContent { + fn from(value: String) -> Self { + Self::Text(Arc::from(value)) + } +} + +impl From for LanguageModelToolResultContent { + fn from(image: LanguageModelImage) -> Self { + Self::Image(image) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] +pub enum MessageContent { + Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking(String), + Image(LanguageModelImage), + ToolUse(LanguageModelToolUse), + ToolResult(LanguageModelToolResult), +} + +impl MessageContent { + pub fn to_str(&self) -> Option<&str> { + match self { + MessageContent::Text(text) => Some(text.as_str()), + MessageContent::Thinking { text, .. 
} => Some(text.as_str()), + MessageContent::RedactedThinking(_) => None, + MessageContent::ToolResult(tool_result) => tool_result.content.to_str(), + MessageContent::ToolUse(_) | MessageContent::Image(_) => None, + } + } + + pub fn is_empty(&self) -> bool { + match self { + MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), + MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()), + MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(), + MessageContent::RedactedThinking(_) + | MessageContent::ToolUse(_) + | MessageContent::Image(_) => false, + } + } +} + +impl From for MessageContent { + fn from(value: String) -> Self { + MessageContent::Text(value) + } +} + +impl From<&str> for MessageContent { + fn from(value: &str) -> Self { + MessageContent::Text(value.to_string()) + } +} + +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] +pub struct LanguageModelRequestMessage { + pub role: Role, + pub content: Vec, + pub cache: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_details: Option, +} + +impl LanguageModelRequestMessage { + pub fn string_contents(&self) -> String { + let mut buffer = String::new(); + for string in self.content.iter().filter_map(|content| content.to_str()) { + buffer.push_str(string); + } + buffer + } + + pub fn contents_empty(&self) -> bool { + self.content.iter().all(|content| content.is_empty()) + } +} + +#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] +pub struct LanguageModelRequestTool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, + pub use_input_streaming: bool, +} + +#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] +pub enum LanguageModelToolChoice { + Auto, + Any, + None, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CompletionIntent { + UserPrompt, + 
Subagent, + ToolResults, + ThreadSummarization, + ThreadContextSummarization, + CreateFile, + EditFile, + InlineAssist, + TerminalInlineAssist, + GenerateGitCommitMessage, +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct LanguageModelRequest { + pub thread_id: Option, + pub prompt_id: Option, + pub intent: Option, + pub messages: Vec, + pub tools: Vec, + pub tool_choice: Option, + pub stop: Vec, + pub temperature: Option, + pub thinking_allowed: bool, + pub thinking_effort: Option, + pub speed: Option, +} + +#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum Speed { + #[default] + Standard, + Fast, +} + +impl Speed { + pub fn toggle(self) -> Self { + match self { + Speed::Standard => Speed::Fast, + Speed::Fast => Speed::Standard, + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct LanguageModelResponseMessage { + pub role: Option, + pub content: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_language_model_tool_result_content_deserialization() { + // Test plain string + let json = serde_json::json!("hello world"); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello world")) + ); + + // Test wrapped text format: { "type": "text", "text": "..." } + let json = serde_json::json!({"type": "text", "text": "hello"}); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello")) + ); + + // Test single-field text object: { "text": "..." 
} + let json = serde_json::json!({"text": "hello"}); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello")) + ); + + // Test case-insensitive type field + let json = serde_json::json!({"Type": "Text", "Text": "hello"}); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + assert_eq!( + content, + LanguageModelToolResultContent::Text(Arc::from("hello")) + ); + + // Test image object + let json = serde_json::json!({ + "source": "base64encodedimagedata", + "size": {"width": 100, "height": 200} + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "base64encodedimagedata"); + let size = image.size.expect("size"); + assert_eq!(size.width, 100); + assert_eq!(size.height, 200); + } + _ => panic!("Expected Image variant"), + } + + // Test wrapped image: { "image": { "source": "...", "size": ... 
} } + let json = serde_json::json!({ + "image": { + "source": "wrappedimagedata", + "size": {"width": 50, "height": 75} + } + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "wrappedimagedata"); + let size = image.size.expect("size"); + assert_eq!(size.width, 50); + assert_eq!(size.height, 75); + } + _ => panic!("Expected Image variant"), + } + + // Test case insensitive + let json = serde_json::json!({ + "Source": "caseinsensitive", + "Size": {"Width": 30, "Height": 40} + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "caseinsensitive"); + let size = image.size.expect("size"); + assert_eq!(size.width, 30); + assert_eq!(size.height, 40); + } + _ => panic!("Expected Image variant"), + } + + // Test direct image object + let json = serde_json::json!({ + "source": "directimage", + "size": {"width": 200, "height": 300} + }); + let content: LanguageModelToolResultContent = serde_json::from_value(json).unwrap(); + match content { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "directimage"); + let size = image.size.expect("size"); + assert_eq!(size.width, 200); + assert_eq!(size.height, 300); + } + _ => panic!("Expected Image variant"), + } + } +} diff --git a/crates/language_model/src/role.rs b/crates/language_model_core/src/role.rs similarity index 100% rename from crates/language_model/src/role.rs rename to crates/language_model_core/src/role.rs diff --git a/crates/language_model/src/tool_schema.rs b/crates/language_model_core/src/tool_schema.rs similarity index 92% rename from crates/language_model/src/tool_schema.rs rename to crates/language_model_core/src/tool_schema.rs index 
878870482a7527bf815797d16e03ad8edc79642e..0e82b2f41081469c6c04d16765e8336eb903fd94 100644 --- a/crates/language_model/src/tool_schema.rs +++ b/crates/language_model_core/src/tool_schema.rs @@ -77,8 +77,6 @@ pub fn adapt_schema_to_format( } fn preprocess_json_schema(json: &mut Value) -> Result<()> { - // `additionalProperties` defaults to `false` unless explicitly specified. - // This prevents models from hallucinating tool parameters. if let Value::Object(obj) = json && matches!(obj.get("type"), Some(Value::String(s)) if s == "object") { @@ -86,7 +84,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> { obj.insert("additionalProperties".to_string(), Value::Bool(false)); } - // OpenAI API requires non-missing `properties` if !obj.contains_key("properties") { obj.insert("properties".to_string(), Value::Object(Default::default())); } @@ -94,7 +91,6 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> { Ok(()) } -/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { if let Value::Object(obj) = json { const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"]; @@ -108,9 +104,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { const KEYS_TO_REMOVE: [(&str, fn(&Value) -> bool); 6] = [ ("format", |value| value.is_string()), - // Gemini doesn't support `additionalProperties` in any form (boolean or schema object) ("additionalProperties", |_| true), - // Gemini doesn't support `propertyNames` ("propertyNames", |_| true), ("exclusiveMinimum", |value| value.is_number()), ("exclusiveMaximum", |value| value.is_number()), @@ -124,7 +118,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { } } - // If a type is not specified for an input parameter, add a default type if matches!(obj.get("description"), Some(Value::String(_))) && !obj.contains_key("type") && !(obj.contains_key("anyOf") @@ -134,7 +127,6 @@ 
fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { obj.insert("type".to_string(), Value::String("string".to_string())); } - // Handle oneOf -> anyOf conversion if let Some(subschemas) = obj.get_mut("oneOf") && subschemas.is_array() { @@ -143,7 +135,6 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { obj.insert("anyOf".to_string(), subschemas_clone); } - // Recursively process all nested objects and arrays for (_, value) in obj.iter_mut() { if let Value::Object(_) | Value::Array(_) = value { adapt_to_json_schema_subset(value)?; @@ -178,7 +169,6 @@ mod tests { }) ); - // Ensure that we do not add a type if it is an object let mut json = json!({ "description": { "value": "abc", @@ -221,7 +211,6 @@ mod tests { }) ); - // Ensure that we do not remove keys that are actually supported (e.g. "format" can just be used as another property) let mut json = json!({ "description": "A test field", "type": "integer", @@ -239,7 +228,6 @@ mod tests { }) ); - // additionalProperties as an object schema is also unsupported by Gemini let mut json = json!({ "type": "object", "properties": { diff --git a/crates/language_models/src/provider/util.rs b/crates/language_model_core/src/util.rs similarity index 88% rename from crates/language_models/src/provider/util.rs rename to crates/language_model_core/src/util.rs index 76a02b6de40a3e36c7c506f11a6f6d34d2aaca3e..3db2e0b76fd76070aa4d30e97c525fa8f3460c9d 100644 --- a/crates/language_models/src/provider/util.rs +++ b/crates/language_model_core/src/util.rs @@ -38,13 +38,22 @@ fn strip_trailing_incomplete_escape(json: &str) -> &str { } } +/// Parses a "prompt is too long: N tokens ..." message and extracts the token count. +pub fn parse_prompt_too_long(message: &str) -> Option { + message + .strip_prefix("prompt is too long: ")? + .split_once(" tokens")? 
+ .0 + .parse() + .ok() +} + #[cfg(test)] mod tests { use super::*; #[test] fn test_fix_streamed_json_strips_incomplete_escape() { - // Trailing `\` inside a string — incomplete escape sequence let fixed = fix_streamed_json(r#"{"text": "hello\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "hello"); @@ -52,7 +61,6 @@ mod tests { #[test] fn test_fix_streamed_json_preserves_complete_escape() { - // `\\` is a complete escape (literal backslash) let fixed = fix_streamed_json(r#"{"text": "hello\\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "hello\\"); @@ -60,7 +68,6 @@ mod tests { #[test] fn test_fix_streamed_json_strips_escape_after_complete_escape() { - // `\\\` = complete `\\` (literal backslash) + incomplete `\` let fixed = fix_streamed_json(r#"{"text": "hello\\\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "hello\\"); @@ -75,12 +82,10 @@ mod tests { #[test] fn test_fix_streamed_json_newline_escape_boundary() { - // Simulates a stream boundary landing between `\` and `n` let fixed = fix_streamed_json(r#"{"text": "line1\"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "line1"); - // Next chunk completes the escape let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#); let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); assert_eq!(parsed["text"], "line1\nline2"); @@ -88,8 +93,6 @@ mod tests { #[test] fn test_fix_streamed_json_incremental_delta_correctness() { - // This is the actual scenario that causes the bug: - // chunk 1 ends mid-escape, chunk 2 completes it. 
let chunk1 = r#"{"replacement_text": "fn foo() {\"#; let fixed1 = fix_streamed_json(chunk1); let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json"); @@ -102,7 +105,6 @@ mod tests { let text2 = parsed2["replacement_text"].as_str().expect("string"); assert_eq!(text2, "fn foo() {\n return bar;\n}"); - // The delta should be the newline + rest, with no spurious backslash let delta = &text2[text1.len()..]; assert_eq!(delta, "\n return bar;\n}"); } diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 4b141205053efc2dbf2fce81087ec9ed8dc25e75..ffff0177d7b5172ed7f374b18de457b3f2a13a66 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -21,8 +21,8 @@ aws_http_client.workspace = true base64.workspace = true bedrock = { workspace = true, features = ["schemars"] } client.workspace = true +cloud_api_client.workspace = true cloud_api_types.workspace = true -cloud_llm_client.workspace = true collections.workspace = true component.workspace = true convert_case.workspace = true @@ -41,6 +41,7 @@ gpui_tokio.workspace = true http_client.workspace = true language.workspace = true language_model.workspace = true +language_models_cloud.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true menu.workspace = true @@ -49,18 +50,15 @@ ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } opencode = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } -partial-json-fixer.workspace = true rand.workspace = true release_channel.workspace = true schemars.workspace = true sha2.workspace = true -semver.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true strum.workspace = true -thiserror.workspace = true tiktoken-rs.workspace = true tokio = { workspace = true, features = ["rt", 
"rt-multi-thread"] } ui.workspace = true @@ -76,4 +74,3 @@ http_client = { workspace = true, features = ["test-support"] } language_model = { workspace = true, features = ["test-support"] } parking_lot.workspace = true pretty_assertions.workspace = true - diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index da4d2b808cf8a6c972351239efb97962a98b3b2c..8b37975f6c9fc83d9e13d3d35de260d59f1be85f 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -12,7 +12,7 @@ pub mod open_ai_compatible; pub mod open_router; pub mod openai_subscribed; pub mod opencode; -mod util; + pub mod vercel; pub mod vercel_ai_gateway; pub mod x_ai; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index c1b8bc1a3bb1b602b67ae5563d8acc3b05a94d47..58de77d573293345ec2120695866c824f10c6108 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1,13 +1,10 @@ pub mod telemetry; -use anthropic::{ - ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, CountTokensRequest, Event, - ResponseContent, ToolResultContent, ToolResultPart, Usage, -}; +use anthropic::{ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode}; use anyhow::Result; -use collections::{BTreeMap, HashMap}; +use collections::BTreeMap; use credentials_provider::CredentialsProvider; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; +use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Context, Entity, Task}; use http_client::HttpClient; use language_model::{ @@ -16,20 +13,19 @@ use language_model::{ LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, 
LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - RateLimiter, Role, StopReason, env_var, + LanguageModelToolChoice, RateLimiter, env_var, }; use settings::{Settings, SettingsStore}; -use std::pin::Pin; -use std::str::FromStr; use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; - +pub use anthropic::completion::{ + AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, + into_anthropic_count_tokens_request, +}; pub use settings::AnthropicAvailableModel as AvailableModel; const PROVIDER_ID: LanguageModelProviderId = ANTHROPIC_PROVIDER_ID; @@ -249,228 +245,6 @@ pub struct AnthropicModel { request_limiter: RateLimiter, } -fn to_anthropic_content(content: MessageContent) -> Option { - match content { - MessageContent::Text(text) => { - let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { - text.trim_end().to_string() - } else { - text - }; - if !text.is_empty() { - Some(anthropic::RequestContent::Text { - text, - cache_control: None, - }) - } else { - None - } - } - MessageContent::Thinking { - text: thinking, - signature, - } => { - if let Some(signature) = signature - && !thinking.is_empty() - { - Some(anthropic::RequestContent::Thinking { - thinking, - signature, - cache_control: None, - }) - } else { - None - } - } - MessageContent::RedactedThinking(data) => { - if !data.is_empty() { - Some(anthropic::RequestContent::RedactedThinking { data }) - } else { - None - } - } - MessageContent::Image(image) => Some(anthropic::RequestContent::Image { - source: anthropic::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - cache_control: None, - }), - 
MessageContent::ToolUse(tool_use) => Some(anthropic::RequestContent::ToolUse { - id: tool_use.id.to_string(), - name: tool_use.name.to_string(), - input: tool_use.input, - cache_control: None, - }), - MessageContent::ToolResult(tool_result) => Some(anthropic::RequestContent::ToolResult { - tool_use_id: tool_result.tool_use_id.to_string(), - is_error: tool_result.is_error, - content: match tool_result.content { - LanguageModelToolResultContent::Text(text) => { - ToolResultContent::Plain(text.to_string()) - } - LanguageModelToolResultContent::Image(image) => { - ToolResultContent::Multipart(vec![ToolResultPart::Image { - source: anthropic::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - }]) - } - }, - cache_control: None, - }), - } -} - -/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest. -pub fn into_anthropic_count_tokens_request( - request: LanguageModelRequest, - model: String, - mode: AnthropicModelMode, -) -> CountTokensRequest { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages { - if message.contents_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - let anthropic_message_content: Vec = message - .content - .into_iter() - .filter_map(to_anthropic_content) - .collect(); - let anthropic_role = match message.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("System role should never occur here"), - }; - if anthropic_message_content.is_empty() { - continue; - } - - if let Some(last_message) = new_messages.last_mut() - && last_message.role == anthropic_role - { - last_message.content.extend(anthropic_message_content); - continue; - } - - new_messages.push(anthropic::Message { - role: anthropic_role, - content: anthropic_message_content, - }); - } - Role::System => { - if 
!system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.string_contents()); - } - } - } - - CountTokensRequest { - model, - messages: new_messages, - system: if system_message.is_empty() { - None - } else { - Some(anthropic::StringOrContents::String(system_message)) - }, - thinking: if request.thinking_allowed { - match mode { - AnthropicModelMode::Thinking { budget_tokens } => { - Some(anthropic::Thinking::Enabled { budget_tokens }) - } - AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive), - AnthropicModelMode::Default => None, - } - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| anthropic::Tool { - name: tool.name, - description: tool.description, - input_schema: tool.input_schema, - eager_input_streaming: tool.use_input_streaming, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto, - LanguageModelToolChoice::Any => anthropic::ToolChoice::Any, - LanguageModelToolChoice::None => anthropic::ToolChoice::None, - }), - } -} - -/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable, -/// or by providers (like Zed Cloud) that don't have direct Anthropic API access. -pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result { - let messages = request.messages; - let mut tokens_from_images = 0; - let mut string_messages = Vec::with_capacity(messages.len()); - - for message in messages { - let mut string_contents = String::new(); - - for content in message.content { - match content { - MessageContent::Text(text) => { - string_contents.push_str(&text); - } - MessageContent::Thinking { .. } => { - // Thinking blocks are not included in the input token count. - } - MessageContent::RedactedThinking(_) => { - // Thinking blocks are not included in the input token count. 
- } - MessageContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - MessageContent::ToolUse(_tool_use) => { - // TODO: Estimate token usage from tool uses. - } - MessageContent::ToolResult(tool_result) => match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - string_contents.push_str(text); - } - LanguageModelToolResultContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - }, - } - } - - if !string_contents.is_empty() { - string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(string_contents), - name: None, - function_call: None, - }); - } - } - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. - tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) - .map(|tokens| (tokens + tokens_from_images) as u64) -} - impl AnthropicModel { fn stream_completion( &self, @@ -617,10 +391,13 @@ impl LanguageModel for AnthropicModel { ) }); + let background = cx.background_executor().clone(); async move { // If no API key, fall back to tiktoken estimation let Some(api_key) = api_key else { - return count_anthropic_tokens_with_tiktoken(request); + return background + .spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) + .await; }; let count_request = @@ -634,7 +411,9 @@ impl LanguageModel for AnthropicModel { log::error!( "Anthropic count_tokens API failed, falling back to tiktoken: {err:?}" ); - count_anthropic_tokens_with_tiktoken(request) + background + .spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) + .await } } } @@ -678,345 +457,6 @@ impl LanguageModel for AnthropicModel { } } -pub fn into_anthropic( - request: LanguageModelRequest, - model: String, - default_temperature: f32, - max_output_tokens: u64, - mode: AnthropicModelMode, -) -> 
anthropic::Request { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages { - if message.contents_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - let mut anthropic_message_content: Vec = message - .content - .into_iter() - .filter_map(to_anthropic_content) - .collect(); - let anthropic_role = match message.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("System role should never occur here"), - }; - if anthropic_message_content.is_empty() { - continue; - } - - if let Some(last_message) = new_messages.last_mut() - && last_message.role == anthropic_role - { - last_message.content.extend(anthropic_message_content); - continue; - } - - // Mark the last segment of the message as cached - if message.cache { - let cache_control_value = Some(anthropic::CacheControl { - cache_type: anthropic::CacheControlType::Ephemeral, - }); - for message_content in anthropic_message_content.iter_mut().rev() { - match message_content { - anthropic::RequestContent::RedactedThinking { .. } => { - // Caching is not possible, fallback to next message - } - anthropic::RequestContent::Text { cache_control, .. } - | anthropic::RequestContent::Thinking { cache_control, .. } - | anthropic::RequestContent::Image { cache_control, .. } - | anthropic::RequestContent::ToolUse { cache_control, .. } - | anthropic::RequestContent::ToolResult { cache_control, .. 
} => { - *cache_control = cache_control_value; - break; - } - } - } - } - - new_messages.push(anthropic::Message { - role: anthropic_role, - content: anthropic_message_content, - }); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.string_contents()); - } - } - } - - anthropic::Request { - model, - messages: new_messages, - max_tokens: max_output_tokens, - system: if system_message.is_empty() { - None - } else { - Some(anthropic::StringOrContents::String(system_message)) - }, - thinking: if request.thinking_allowed { - match mode { - AnthropicModelMode::Thinking { budget_tokens } => { - Some(anthropic::Thinking::Enabled { budget_tokens }) - } - AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive), - AnthropicModelMode::Default => None, - } - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| anthropic::Tool { - name: tool.name, - description: tool.description, - input_schema: tool.input_schema, - eager_input_streaming: tool.use_input_streaming, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto, - LanguageModelToolChoice::Any => anthropic::ToolChoice::Any, - LanguageModelToolChoice::None => anthropic::ToolChoice::None, - }), - metadata: None, - output_config: if request.thinking_allowed - && matches!(mode, AnthropicModelMode::AdaptiveThinking) - { - request.thinking_effort.as_deref().and_then(|effort| { - let effort = match effort { - "low" => Some(anthropic::Effort::Low), - "medium" => Some(anthropic::Effort::Medium), - "high" => Some(anthropic::Effort::High), - "max" => Some(anthropic::Effort::Max), - _ => None, - }; - effort.map(|effort| anthropic::OutputConfig { - effort: Some(effort), - }) - }) - } else { - None - }, - stop_sequences: Vec::new(), - speed: request.speed.map(From::from), - temperature: 
request.temperature.or(Some(default_temperature)), - top_k: None, - top_p: None, - } -} - -pub struct AnthropicEventMapper { - tool_uses_by_index: HashMap, - usage: Usage, - stop_reason: StopReason, -} - -impl AnthropicEventMapper { - pub fn new() -> Self { - Self { - tool_uses_by_index: HashMap::default(), - usage: Usage::default(), - stop_reason: StopReason::EndTurn, - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(error.into())], - }) - }) - } - - pub fn map_event( - &mut self, - event: Event, - ) -> Vec> { - match event { - Event::ContentBlockStart { - index, - content_block, - } => match content_block { - ResponseContent::Text { text } => { - vec![Ok(LanguageModelCompletionEvent::Text(text))] - } - ResponseContent::Thinking { thinking } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })] - } - ResponseContent::RedactedThinking { data } => { - vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] - } - ResponseContent::ToolUse { id, name, .. 
} => { - self.tool_uses_by_index.insert( - index, - RawToolUse { - id, - name, - input_json: String::new(), - }, - ); - Vec::new() - } - }, - Event::ContentBlockDelta { index, delta } => match delta { - ContentDelta::TextDelta { text } => { - vec![Ok(LanguageModelCompletionEvent::Text(text))] - } - ContentDelta::ThinkingDelta { thinking } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })] - } - ContentDelta::SignatureDelta { signature } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: "".to_string(), - signature: Some(signature), - })] - } - ContentDelta::InputJsonDelta { partial_json } => { - if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { - tool_use.input_json.push_str(&partial_json); - - // Try to convert invalid (incomplete) JSON into - // valid JSON that serde can accept, e.g. by closing - // unclosed delimiters. This way, we can update the - // UI with whatever has been streamed back so far. - if let Ok(input) = - serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json)) - { - return vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.clone().into(), - name: tool_use.name.clone().into(), - is_input_complete: false, - raw_input: tool_use.input_json.clone(), - input, - thought_signature: None, - }, - ))]; - } - } - vec![] - } - }, - Event::ContentBlockStop { index } => { - if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { - let input_json = tool_use.input_json.trim(); - let event_result = match parse_tool_arguments(input_json) { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.into(), - name: tool_use.name.into(), - is_input_complete: true, - input, - raw_input: tool_use.input_json.clone(), - thought_signature: None, - }, - )), - Err(json_parse_err) => { - Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: tool_use.id.into(), - tool_name: 
tool_use.name.into(), - raw_input: input_json.into(), - json_parse_error: json_parse_err.to_string(), - }) - } - }; - - vec![event_result] - } else { - Vec::new() - } - } - Event::MessageStart { message } => { - update_usage(&mut self.usage, &message.usage); - vec![ - Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( - &self.usage, - ))), - Ok(LanguageModelCompletionEvent::StartMessage { - message_id: message.id, - }), - ] - } - Event::MessageDelta { delta, usage } => { - update_usage(&mut self.usage, &usage); - if let Some(stop_reason) = delta.stop_reason.as_deref() { - self.stop_reason = match stop_reason { - "end_turn" => StopReason::EndTurn, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - "refusal" => StopReason::Refusal, - _ => { - log::error!("Unexpected anthropic stop_reason: {stop_reason}"); - StopReason::EndTurn - } - }; - } - vec![Ok(LanguageModelCompletionEvent::UsageUpdate( - convert_usage(&self.usage), - ))] - } - Event::MessageStop => { - vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] - } - Event::Error { error } => { - vec![Err(error.into())] - } - _ => Vec::new(), - } - } -} - -struct RawToolUse { - id: String, - name: String, - input_json: String, -} - -/// Updates usage data by preferring counts from `new`. 
-fn update_usage(usage: &mut Usage, new: &Usage) { - if let Some(input_tokens) = new.input_tokens { - usage.input_tokens = Some(input_tokens); - } - if let Some(output_tokens) = new.output_tokens { - usage.output_tokens = Some(output_tokens); - } - if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens { - usage.cache_creation_input_tokens = Some(cache_creation_input_tokens); - } - if let Some(cache_read_input_tokens) = new.cache_read_input_tokens { - usage.cache_read_input_tokens = Some(cache_read_input_tokens); - } -} - -fn convert_usage(usage: &Usage) -> language_model::TokenUsage { - language_model::TokenUsage { - input_tokens: usage.input_tokens.unwrap_or(0), - output_tokens: usage.output_tokens.unwrap_or(0), - cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), - cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), - } -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, @@ -1157,192 +597,3 @@ impl Render for ConfigurationView { } } } - -#[cfg(test)] -mod tests { - use super::*; - use anthropic::AnthropicModelMode; - use language_model::{LanguageModelRequestMessage, MessageContent}; - - #[test] - fn test_cache_control_only_on_last_segment() { - let request = LanguageModelRequest { - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![ - MessageContent::Text("Some prompt".to_string()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - ], - cache: true, - reasoning_details: None, - }], - thread_id: None, - prompt_id: None, - intent: None, - stop: vec![], - temperature: None, - tools: vec![], - tool_choice: None, - thinking_allowed: true, - thinking_effort: None, - speed: None, - }; - - let anthropic_request = into_anthropic( - 
request, - "claude-3-5-sonnet".to_string(), - 0.7, - 4096, - AnthropicModelMode::Default, - ); - - assert_eq!(anthropic_request.messages.len(), 1); - - let message = &anthropic_request.messages[0]; - assert_eq!(message.content.len(), 5); - - assert!(matches!( - message.content[0], - anthropic::RequestContent::Text { - cache_control: None, - .. - } - )); - for i in 1..3 { - assert!(matches!( - message.content[i], - anthropic::RequestContent::Image { - cache_control: None, - .. - } - )); - } - - assert!(matches!( - message.content[4], - anthropic::RequestContent::Image { - cache_control: Some(anthropic::CacheControl { - cache_type: anthropic::CacheControlType::Ephemeral, - }), - .. - } - )); - } - - fn request_with_assistant_content( - assistant_content: Vec, - ) -> anthropic::Request { - let mut request = LanguageModelRequest { - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text("Hello".to_string())], - cache: false, - reasoning_details: None, - }], - thinking_effort: None, - thread_id: None, - prompt_id: None, - intent: None, - stop: vec![], - temperature: None, - tools: vec![], - tool_choice: None, - thinking_allowed: true, - speed: None, - }; - request.messages.push(LanguageModelRequestMessage { - role: Role::Assistant, - content: assistant_content, - cache: false, - reasoning_details: None, - }); - into_anthropic( - request, - "claude-sonnet-4-5".to_string(), - 1.0, - 16000, - AnthropicModelMode::Thinking { - budget_tokens: Some(10000), - }, - ) - } - - #[test] - fn test_unsigned_thinking_blocks_stripped() { - let result = request_with_assistant_content(vec![ - MessageContent::Thinking { - text: "Cancelled mid-think, no signature".to_string(), - signature: None, - }, - MessageContent::Text("Some response text".to_string()), - ]); - - let assistant_message = result - .messages - .iter() - .find(|m| m.role == anthropic::Role::Assistant) - .expect("assistant message should still exist"); - - assert_eq!( - 
assistant_message.content.len(), - 1, - "Only the text content should remain; unsigned thinking block should be stripped" - ); - assert!(matches!( - &assistant_message.content[0], - anthropic::RequestContent::Text { text, .. } if text == "Some response text" - )); - } - - #[test] - fn test_signed_thinking_blocks_preserved() { - let result = request_with_assistant_content(vec![ - MessageContent::Thinking { - text: "Completed thinking".to_string(), - signature: Some("valid-signature".to_string()), - }, - MessageContent::Text("Response".to_string()), - ]); - - let assistant_message = result - .messages - .iter() - .find(|m| m.role == anthropic::Role::Assistant) - .expect("assistant message should exist"); - - assert_eq!( - assistant_message.content.len(), - 2, - "Both the signed thinking block and text should be preserved" - ); - assert!(matches!( - &assistant_message.content[0], - anthropic::RequestContent::Thinking { thinking, signature, .. } - if thinking == "Completed thinking" && signature == "valid-signature" - )); - } - - #[test] - fn test_only_unsigned_thinking_block_omits_entire_message() { - let result = request_with_assistant_content(vec![MessageContent::Thinking { - text: "Cancelled before any text or signature".to_string(), - signature: None, - }]); - - let assistant_messages: Vec<_> = result - .messages - .iter() - .filter(|m| m.role == anthropic::Role::Assistant) - .collect(); - - assert_eq!( - assistant_messages.len(), - 0, - "An assistant message whose only content was an unsigned thinking block \ - should be omitted entirely" - ); - } -} diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 4320763e2c5c6de7f3fe9238d7a4991565c3bfcd..80c758769cd990c00f5942433143bf6fb2216b7c 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -48,7 +48,7 @@ use ui_input::InputField; use util::ResultExt; use crate::AllLanguageModelSettings; -use 
crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; actions!(bedrock, [Tab, TabPrev]); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 29623cc998ad0fe933e9a29c45c651f7be010b07..294b44ecae9941481e26c2341018ce584d68b3ec 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,107 +1,93 @@ use ai_onboarding::YoungAccountBanner; -use anthropic::AnthropicModelMode; -use anyhow::{Context as _, Result, anyhow}; -use client::{ - Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls, -}; -use cloud_api_types::{OrganizationId, Plan}; -use cloud_llm_client::{ - CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, - CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, - CountTokensBody, CountTokensResponse, ListModelsResponse, - SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME, -}; -use futures::{ - AsyncBufReadExt, FutureExt, Stream, StreamExt, - future::BoxFuture, - stream::{self, BoxStream}, -}; -use google_ai::GoogleModelMode; -use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task}; -use http_client::http::{HeaderMap, HeaderValue}; -use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode}; +use anyhow::Result; +use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls}; +use cloud_api_client::LlmApiToken; +use cloud_api_types::OrganizationId; +use cloud_api_types::Plan; +use futures::StreamExt; +use futures::future::BoxFuture; +use gpui::AsyncApp; +use gpui::{AnyElement, AnyView, App, Context, Entity, Subscription, Task}; use language_model::{ - ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, AuthenticateError, GOOGLE_PROVIDER_ID, - 
GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel, - LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID, - OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, - ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME, + AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID, + ZED_CLOUD_PROVIDER_NAME, }; +use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider}; use release_channel::AppVersion; -use schemars::JsonSchema; -use semver::Version; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; + use settings::SettingsStore; pub use settings::ZedDotDevAvailableModel as AvailableModel; pub use settings::ZedDotDevAvailableProvider as AvailableProvider; -use smol::io::{AsyncReadExt, BufReader}; -use std::collections::VecDeque; -use std::pin::Pin; -use std::str::FromStr; use std::sync::Arc; -use std::task::Poll; -use std::time::Duration; -use thiserror::Error; use ui::{TintColor, prelude::*}; -use crate::provider::anthropic::{ - AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, -}; -use crate::provider::google::{GoogleEventMapper, into_google}; -use crate::provider::open_ai::{ - OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai, - into_open_ai_response, -}; -use crate::provider::x_ai::count_xai_tokens; - const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME; -#[derive(Default, Clone, Debug, PartialEq)] -pub struct ZedDotDevSettings { - pub 
available_models: Vec, -} -#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] -#[serde(tag = "type", rename_all = "lowercase")] -pub enum ModelMode { - #[default] - Default, - Thinking { - /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. - budget_tokens: Option, - }, +struct ClientTokenProvider { + client: Arc, + llm_api_token: LlmApiToken, + user_store: Entity, } -impl From for AnthropicModelMode { - fn from(value: ModelMode) -> Self { - match value { - ModelMode::Default => AnthropicModelMode::Default, - ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens }, - } +impl CloudLlmTokenProvider for ClientTokenProvider { + type AuthContext = Option; + + fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext { + self.user_store.read_with(cx, |user_store, _| { + user_store + .current_organization() + .map(|organization| organization.id.clone()) + }) } + + fn acquire_token( + &self, + organization_id: Self::AuthContext, + ) -> BoxFuture<'static, Result> { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + Box::pin(async move { + client + .acquire_llm_token(&llm_api_token, organization_id) + .await + }) + } + + fn refresh_token( + &self, + organization_id: Self::AuthContext, + ) -> BoxFuture<'static, Result> { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + Box::pin(async move { + client + .refresh_llm_token(&llm_api_token, organization_id) + .await + }) + } +} + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct ZedDotDevSettings { + pub available_models: Vec, } pub struct CloudLanguageModelProvider { - client: Arc, state: Entity, _maintain_client_status: Task<()>, } pub struct State { client: Arc, - llm_api_token: LlmApiToken, user_store: Entity, status: client::Status, - models: Vec>, - default_model: Option>, - default_fast_model: Option>, - recommended_models: 
Vec>, + provider: Entity>, _user_store_subscription: Subscription, _settings_subscription: Subscription, _llm_token_subscription: Subscription, + _provider_subscription: Subscription, } impl State { @@ -112,16 +98,26 @@ impl State { cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - let llm_api_token = global_llm_token(cx); + let token_provider = Arc::new(ClientTokenProvider { + client: client.clone(), + llm_api_token: global_llm_token(cx), + user_store: user_store.clone(), + }); + + let provider = cx.new(|cx| { + CloudModelProvider::new( + token_provider.clone(), + client.http_client(), + Some(AppVersion::global(cx)), + ) + }); + Self { client: client.clone(), - llm_api_token, user_store: user_store.clone(), status, - models: Vec::new(), - default_model: None, - default_fast_model: None, - recommended_models: Vec::new(), + _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()), + provider, _user_store_subscription: cx.subscribe( &user_store, move |this, _user_store, event, cx| match event { @@ -131,19 +127,7 @@ impl State { return; } - let client = this.client.clone(); - let llm_api_token = this.llm_api_token.clone(); - let organization_id = this - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - cx.spawn(async move |this, cx| { - let response = - Self::fetch_models(client, llm_api_token, organization_id).await?; - this.update(cx, |this, cx| this.update_models(response, cx)) - }) - .detach_and_log_err(cx); + this.refresh_models(cx); } _ => {} }, @@ -154,21 +138,7 @@ impl State { _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, move |this, _listener, _event, cx| { - let client = this.client.clone(); - let llm_api_token = this.llm_api_token.clone(); - let organization_id = this - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - cx.spawn(async move |this, cx| { - let response = - 
Self::fetch_models(client, llm_api_token, organization_id).await?; - this.update(cx, |this, cx| { - this.update_models(response, cx); - }) - }) - .detach_and_log_err(cx); + this.refresh_models(cx); }, ), } @@ -186,74 +156,10 @@ impl State { }) } - fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context) { - let mut models = Vec::new(); - - for model in response.models { - models.push(Arc::new(model.clone())); - } - - self.default_model = models - .iter() - .find(|model| { - response - .default_model - .as_ref() - .is_some_and(|default_model_id| &model.id == default_model_id) - }) - .cloned(); - self.default_fast_model = models - .iter() - .find(|model| { - response - .default_fast_model - .as_ref() - .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id) - }) - .cloned(); - self.recommended_models = response - .recommended_models - .iter() - .filter_map(|id| models.iter().find(|model| &model.id == id)) - .cloned() - .collect(); - self.models = models; - cx.notify(); - } - - async fn fetch_models( - client: Arc, - llm_api_token: LlmApiToken, - organization_id: Option, - ) -> Result { - let http_client = &client.http_client(); - let token = client - .acquire_llm_token(&llm_api_token, organization_id) - .await?; - - let request = http_client::Request::builder() - .method(Method::GET) - .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true") - .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref()) - .header("Authorization", format!("Bearer {token}")) - .body(AsyncBody::empty())?; - let mut response = http_client - .send(request) - .await - .context("failed to send list models request")?; - - if response.status().is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - Ok(serde_json::from_str(&body)?) 
- } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - anyhow::bail!( - "error listing models.\nStatus: {:?}\nBody: {body}", - response.status(), - ); - } + fn refresh_models(&mut self, cx: &mut Context) { + self.provider.update(cx, |provider, cx| { + provider.refresh_models(cx).detach_and_log_err(cx); + }); } } @@ -281,27 +187,10 @@ impl CloudLanguageModelProvider { }); Self { - client, state, _maintain_client_status: maintain_client_status, } } - - fn create_language_model( - &self, - model: Arc, - llm_api_token: LlmApiToken, - user_store: Entity, - ) -> Arc { - Arc::new(CloudLanguageModel { - id: LanguageModelId(SharedString::from(model.id.0.clone())), - model, - llm_api_token, - user_store, - client: self.client.clone(), - request_limiter: RateLimiter::new(4), - }) - } } impl LanguageModelProviderState for CloudLanguageModelProvider { @@ -327,45 +216,35 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn default_model(&self, cx: &App) -> Option> { let state = self.state.read(cx); - let default_model = state.default_model.clone()?; - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - Some(self.create_language_model(default_model, llm_api_token, user_store)) + let provider = state.provider.read(cx); + let model = provider.default_model()?; + Some(provider.create_model(model)) } fn default_fast_model(&self, cx: &App) -> Option> { let state = self.state.read(cx); - let default_fast_model = state.default_fast_model.clone()?; - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - Some(self.create_language_model(default_fast_model, llm_api_token, user_store)) + let provider = state.provider.read(cx); + let model = provider.default_fast_model()?; + Some(provider.create_model(model)) } fn recommended_models(&self, cx: &App) -> Vec> { let state = self.state.read(cx); - let llm_api_token = state.llm_api_token.clone(); - let 
user_store = state.user_store.clone(); - state - .recommended_models + let provider = state.provider.read(cx); + provider + .recommended_models() .iter() - .cloned() - .map(|model| { - self.create_language_model(model, llm_api_token.clone(), user_store.clone()) - }) + .map(|model| provider.create_model(model)) .collect() } fn provided_models(&self, cx: &App) -> Vec> { let state = self.state.read(cx); - let llm_api_token = state.llm_api_token.clone(); - let user_store = state.user_store.clone(); - state - .models + let provider = state.provider.read(cx); + provider + .models() .iter() - .cloned() - .map(|model| { - self.create_language_model(model, llm_api_token.clone(), user_store.clone()) - }) + .map(|model| provider.create_model(model)) .collect() } @@ -393,700 +272,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider { } } -pub struct CloudLanguageModel { - id: LanguageModelId, - model: Arc, - llm_api_token: LlmApiToken, - user_store: Entity, - client: Arc, - request_limiter: RateLimiter, -} - -struct PerformLlmCompletionResponse { - response: Response, - includes_status_messages: bool, -} - -impl CloudLanguageModel { - async fn perform_llm_completion( - client: Arc, - llm_api_token: LlmApiToken, - organization_id: Option, - app_version: Option, - body: CompletionBody, - ) -> Result { - let http_client = &client.http_client(); - - let mut token = client - .acquire_llm_token(&llm_api_token, organization_id.clone()) - .await?; - let mut refreshed_token = false; - - loop { - let request = http_client::Request::builder() - .method(Method::POST) - .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref()) - .when_some(app_version.as_ref(), |builder, app_version| { - builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string()) - }) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {token}")) - .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true") - 
.header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true") - .body(serde_json::to_string(&body)?.into())?; - - let mut response = http_client.send(request).await?; - let status = response.status(); - if status.is_success() { - let includes_status_messages = response - .headers() - .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME) - .is_some(); - - return Ok(PerformLlmCompletionResponse { - response, - includes_status_messages, - }); - } - - if !refreshed_token && response.needs_llm_token_refresh() { - token = client - .refresh_llm_token(&llm_api_token, organization_id.clone()) - .await?; - refreshed_token = true; - continue; - } - - if status == StatusCode::PAYMENT_REQUIRED { - return Err(anyhow!(PaymentRequiredError)); - } - - let mut body = String::new(); - let headers = response.headers().clone(); - response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!(ApiError { - status, - body, - headers - })); - } - } -} - -#[derive(Debug, Error)] -#[error("cloud language model request failed with status {status}: {body}")] -struct ApiError { - status: StatusCode, - body: String, - headers: HeaderMap, -} - -/// Represents error responses from Zed's cloud API. 
-/// -/// Example JSON for an upstream HTTP error: -/// ```json -/// { -/// "code": "upstream_http_error", -/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout", -/// "upstream_status": 503 -/// } -/// ``` -#[derive(Debug, serde::Deserialize)] -struct CloudApiError { - code: String, - message: String, - #[serde(default)] - #[serde(deserialize_with = "deserialize_optional_status_code")] - upstream_status: Option, - #[serde(default)] - retry_after: Option, -} - -fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - let opt: Option = Option::deserialize(deserializer)?; - Ok(opt.and_then(|code| StatusCode::from_u16(code).ok())) -} - -impl From for LanguageModelCompletionError { - fn from(error: ApiError) -> Self { - if let Ok(cloud_error) = serde_json::from_str::(&error.body) { - if cloud_error.code.starts_with("upstream_http_") { - let status = if let Some(status) = cloud_error.upstream_status { - status - } else if cloud_error.code.ends_with("_error") { - error.status - } else { - // If there's a status code in the code string (e.g. "upstream_http_429") - // then use that; otherwise, see if the JSON contains a status code. 
- cloud_error - .code - .strip_prefix("upstream_http_") - .and_then(|code_str| code_str.parse::().ok()) - .and_then(|code| StatusCode::from_u16(code).ok()) - .unwrap_or(error.status) - }; - - return LanguageModelCompletionError::UpstreamProviderError { - message: cloud_error.message, - status, - retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), - }; - } - - return LanguageModelCompletionError::from_http_status( - PROVIDER_NAME, - error.status, - cloud_error.message, - None, - ); - } - - let retry_after = None; - LanguageModelCompletionError::from_http_status( - PROVIDER_NAME, - error.status, - error.body, - retry_after, - ) - } -} - -impl LanguageModel for CloudLanguageModel { - fn id(&self) -> LanguageModelId { - self.id.clone() - } - - fn name(&self) -> LanguageModelName { - LanguageModelName::from(self.model.display_name.clone()) - } - - fn provider_id(&self) -> LanguageModelProviderId { - PROVIDER_ID - } - - fn provider_name(&self) -> LanguageModelProviderName { - PROVIDER_NAME - } - - fn upstream_provider_id(&self) -> LanguageModelProviderId { - use cloud_llm_client::LanguageModelProvider::*; - match self.model.provider { - Anthropic => ANTHROPIC_PROVIDER_ID, - OpenAi => OPEN_AI_PROVIDER_ID, - Google => GOOGLE_PROVIDER_ID, - XAi => X_AI_PROVIDER_ID, - } - } - - fn upstream_provider_name(&self) -> LanguageModelProviderName { - use cloud_llm_client::LanguageModelProvider::*; - match self.model.provider { - Anthropic => ANTHROPIC_PROVIDER_NAME, - OpenAi => OPEN_AI_PROVIDER_NAME, - Google => GOOGLE_PROVIDER_NAME, - XAi => X_AI_PROVIDER_NAME, - } - } - - fn is_latest(&self) -> bool { - self.model.is_latest - } - - fn supports_tools(&self) -> bool { - self.model.supports_tools - } - - fn supports_images(&self) -> bool { - self.model.supports_images - } - - fn supports_thinking(&self) -> bool { - self.model.supports_thinking - } - - fn supports_fast_mode(&self) -> bool { - self.model.supports_fast_mode - } - - fn supported_effort_levels(&self) -> Vec 
{ - self.model - .supported_effort_levels - .iter() - .map(|effort_level| LanguageModelEffortLevel { - name: effort_level.name.clone().into(), - value: effort_level.value.clone().into(), - is_default: effort_level.is_default.unwrap_or(false), - }) - .collect() - } - - fn supports_streaming_tools(&self) -> bool { - self.model.supports_streaming_tools - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - match choice { - LanguageModelToolChoice::Auto - | LanguageModelToolChoice::Any - | LanguageModelToolChoice::None => true, - } - } - - fn supports_split_token_display(&self) -> bool { - use cloud_llm_client::LanguageModelProvider::*; - matches!(self.model.provider, OpenAi | XAi) - } - - fn telemetry_id(&self) -> String { - format!("zed.dev/{}", self.model.id) - } - - fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { - match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic - | cloud_llm_client::LanguageModelProvider::OpenAi => { - LanguageModelToolSchemaFormat::JsonSchema - } - cloud_llm_client::LanguageModelProvider::Google - | cloud_llm_client::LanguageModelProvider::XAi => { - LanguageModelToolSchemaFormat::JsonSchemaSubset - } - } - } - - fn max_token_count(&self) -> u64 { - self.model.max_token_count as u64 - } - - fn max_output_tokens(&self) -> Option { - Some(self.model.max_output_tokens as u64) - } - - fn cache_configuration(&self) -> Option { - match &self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { - Some(LanguageModelCacheConfiguration { - min_total_token: 2_048, - should_speculate: true, - max_cache_anchors: 4, - }) - } - cloud_llm_client::LanguageModelProvider::OpenAi - | cloud_llm_client::LanguageModelProvider::XAi - | cloud_llm_client::LanguageModelProvider::Google => None, - } - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - match self.model.provider { - 
cloud_llm_client::LanguageModelProvider::Anthropic => cx - .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) - .boxed(), - cloud_llm_client::LanguageModelProvider::OpenAi => { - let model = match open_ai::Model::from_id(&self.model.id.0) { - Ok(model) => model, - Err(err) => return async move { Err(anyhow!(err)) }.boxed(), - }; - count_open_ai_tokens(request, model, cx) - } - cloud_llm_client::LanguageModelProvider::XAi => { - let model = match x_ai::Model::from_id(&self.model.id.0) { - Ok(model) => model, - Err(err) => return async move { Err(anyhow!(err)) }.boxed(), - }; - count_xai_tokens(request, model, cx) - } - cloud_llm_client::LanguageModelProvider::Google => { - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = self - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - let model_id = self.model.id.to_string(); - let generate_content_request = - into_google(request, model_id.clone(), GoogleModelMode::Default); - async move { - let http_client = &client.http_client(); - let token = client - .acquire_llm_token(&llm_api_token, organization_id) - .await?; - - let request_body = CountTokensBody { - provider: cloud_llm_client::LanguageModelProvider::Google, - model: model_id, - provider_request: serde_json::to_value(&google_ai::CountTokensRequest { - generate_content_request, - })?, - }; - let request = http_client::Request::builder() - .method(Method::POST) - .uri( - http_client - .build_zed_llm_url("/count_tokens", &[])? 
- .as_ref(), - ) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {token}")) - .body(serde_json::to_string(&request_body)?.into())?; - let mut response = http_client.send(request).await?; - let status = response.status(); - let headers = response.headers().clone(); - let mut response_body = String::new(); - response - .body_mut() - .read_to_string(&mut response_body) - .await?; - - if status.is_success() { - let response_body: CountTokensResponse = - serde_json::from_str(&response_body)?; - - Ok(response_body.tokens as u64) - } else { - Err(anyhow!(ApiError { - status, - body: response_body, - headers - })) - } - } - .boxed() - } - } - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream<'static, Result>, - LanguageModelCompletionError, - >, - > { - let thread_id = request.thread_id.clone(); - let prompt_id = request.prompt_id.clone(); - let app_version = Some(cx.update(|cx| AppVersion::global(cx))); - let user_store = self.user_store.clone(); - let organization_id = cx.update(|cx| { - user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()) - }); - let thinking_allowed = request.thinking_allowed; - let enable_thinking = thinking_allowed && self.model.supports_thinking; - let provider_name = provider_name(&self.model.provider); - match self.model.provider { - cloud_llm_client::LanguageModelProvider::Anthropic => { - let effort = request - .thinking_effort - .as_ref() - .and_then(|effort| anthropic::Effort::from_str(effort).ok()); - - let mut request = into_anthropic( - request, - self.model.id.to_string(), - 1.0, - self.model.max_output_tokens as u64, - if enable_thinking { - AnthropicModelMode::Thinking { - budget_tokens: Some(4_096), - } - } else { - AnthropicModelMode::Default - }, - ); - - if enable_thinking && effort.is_some() { - request.thinking = Some(anthropic::Thinking::Adaptive); - 
request.output_config = Some(anthropic::OutputConfig { effort }); - } - - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = organization_id.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::Anthropic, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await - .map_err(|err| match err.downcast::() { - Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)), - Err(err) => anyhow!(err), - })?; - - let mut mapper = AnthropicEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::OpenAi => { - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = organization_id.clone(); - let effort = request - .thinking_effort - .as_ref() - .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok()); - - let mut request = into_open_ai_response( - request, - &self.model.id.0, - self.model.supports_parallel_tool_calls, - true, - None, - None, - ); - - if enable_thinking && let Some(effort) = effort { - request.reasoning = Some(open_ai::responses::ReasoningConfig { - effort, - summary: Some(open_ai::responses::ReasoningSummaryMode::Auto), - }); - } - - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - 
app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::OpenAi, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = OpenAiResponseEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::XAi => { - let client = self.client.clone(); - let request = into_open_ai( - request, - &self.model.id.0, - self.model.supports_parallel_tool_calls, - false, - None, - None, - ); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = organization_id.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::XAi, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = OpenAiEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::Google => { - let client = self.client.clone(); - let request = - into_google(request, self.model.id.to_string(), GoogleModelMode::Default); - let llm_api_token = self.llm_api_token.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - 
client.clone(), - llm_api_token, - organization_id, - app_version, - CompletionBody { - thread_id, - prompt_id, - provider: cloud_llm_client::LanguageModelProvider::Google, - model: request.model.model_id.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = GoogleEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - } - } -} - -fn map_cloud_completion_events( - stream: Pin>> + Send>>, - provider: &LanguageModelProviderName, - mut map_callback: F, -) -> BoxStream<'static, Result> -where - T: DeserializeOwned + 'static, - F: FnMut(T) -> Vec> - + Send - + 'static, -{ - let provider = provider.clone(); - let mut stream = stream.fuse(); - - let mut saw_stream_ended = false; - - let mut done = false; - let mut pending = VecDeque::new(); - - stream::poll_fn(move |cx| { - loop { - if let Some(item) = pending.pop_front() { - return Poll::Ready(Some(item)); - } - - if done { - return Poll::Ready(None); - } - - match stream.poll_next_unpin(cx) { - Poll::Ready(Some(event)) => { - let items = match event { - Err(error) => { - vec![Err(LanguageModelCompletionError::from(error))] - } - Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => { - saw_stream_ended = true; - vec![] - } - Ok(CompletionEvent::Status(status)) => { - LanguageModelCompletionEvent::from_completion_request_status( - status, - provider.clone(), - ) - .transpose() - .map(|event| vec![event]) - .unwrap_or_default() - } - Ok(CompletionEvent::Event(event)) => map_callback(event), - }; - pending.extend(items); - } - Poll::Ready(None) => { - done = true; - - if !saw_stream_ended { - return Poll::Ready(Some(Err( - LanguageModelCompletionError::StreamEndedUnexpectedly { - provider: provider.clone(), - }, - ))); - } - } - Poll::Pending => 
return Poll::Pending, - } - } - }) - .boxed() -} - -fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName { - match provider { - cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME, - cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME, - } -} - -fn response_lines( - response: Response, - includes_status_messages: bool, -) -> impl Stream>> { - futures::stream::try_unfold( - (String::new(), BufReader::new(response.into_body())), - move |(mut line, mut body)| async move { - match body.read_line(&mut line).await { - Ok(0) => Ok(None), - Ok(_) => { - let event = if includes_status_messages { - serde_json::from_str::>(&line)? - } else { - CompletionEvent::Event(serde_json::from_str::(&line)?) - }; - - line.clear(); - Ok(Some((event, (line, body)))) - } - Err(e) => Err(e.into()), - } - }, - ) -} - #[derive(IntoElement, RegisterComponent)] struct ZedAiConfiguration { is_connected: bool, @@ -1281,155 +466,3 @@ impl Component for ZedAiConfiguration { ) } } - -#[cfg(test)] -mod tests { - use super::*; - use http_client::http::{HeaderMap, StatusCode}; - use language_model::LanguageModelCompletionError; - - #[test] - fn test_api_error_conversion_with_upstream_http_error() { - // upstream_http_error with 503 status should become ServerOverloaded - let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - 
LanguageModelCompletionError::UpstreamProviderError { message, .. } => { - assert_eq!( - message, - "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout" - ); - } - _ => panic!( - "Expected UpstreamProviderError for upstream 503, got: {:?}", - completion_error - ), - } - - // upstream_http_error with 500 status should become ApiInternalServerError - let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { message, .. } => { - assert_eq!( - message, - "Received an error from the OpenAI API: internal server error" - ); - } - _ => panic!( - "Expected UpstreamProviderError for upstream 500, got: {:?}", - completion_error - ), - } - - // upstream_http_error with 429 status should become RateLimitExceeded - let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { message, .. 
} => { - assert_eq!( - message, - "Received an error from the Google API: rate limit exceeded" - ); - } - _ => panic!( - "Expected UpstreamProviderError for upstream 429, got: {:?}", - completion_error - ), - } - - // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed - let error_body = "Regular internal server error"; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::ApiInternalServerError { provider, message } => { - assert_eq!(provider, PROVIDER_NAME); - assert_eq!(message, "Regular internal server error"); - } - _ => panic!( - "Expected ApiInternalServerError for regular 500, got: {:?}", - completion_error - ), - } - - // upstream_http_429 format should be converted to UpstreamProviderError - let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; - - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::UpstreamProviderError { - message, - status, - retry_after, - } => { - assert_eq!(message, "Upstream Anthropic rate limit exceeded."); - assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); - assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5))); - } - _ => panic!( - "Expected UpstreamProviderError for upstream_http_429, got: {:?}", - completion_error - ), - } - - // Invalid JSON in error body should fall back to regular error handling - let error_body = "Not JSON at all"; - - let api_error = ApiError { - status: StatusCode::INTERNAL_SERVER_ERROR, - body: error_body.to_string(), - headers: HeaderMap::new(), - }; 
- - let completion_error: LanguageModelCompletionError = api_error.into(); - - match completion_error { - LanguageModelCompletionError::ApiInternalServerError { provider, .. } => { - assert_eq!(provider, PROVIDER_NAME); - } - _ => panic!( - "Expected ApiInternalServerError for invalid JSON, got: {:?}", - completion_error - ), - } - } -} diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index a2d39e1945e2791d9d5c998cc717a07498ebc157..a77e3f880be18d8f9f0e97ec8717c32bc780e267 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -32,7 +32,7 @@ use ui::prelude::*; use util::debug_panic; use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic}; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat"); const PROVIDER_NAME: LanguageModelProviderName = @@ -268,15 +268,15 @@ impl LanguageModel for CopilotChatLanguageModel { levels .iter() .map(|level| { - let name: SharedString = match level.as_str() { + let name = match level.as_str() { "low" => "Low".into(), "medium" => "Medium".into(), "high" => "High".into(), - _ => SharedString::from(level.clone()), + _ => language_model::SharedString::from(level.clone()), }; LanguageModelEffortLevel { name, - value: SharedString::from(level.clone()), + value: language_model::SharedString::from(level.clone()), is_default: level == "high", } }) diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 0cfb1af425c7cb0279d98fa124a589437f1bb1a1..f3dccd5cc1a2e1a5ddfe2bc6b43901f2b549e532 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, 
List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek"); diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 244f7835a85ff67f0c4826321910ea13516371cb..92278839c6ff5119849f8881409928686f055331 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -1,32 +1,25 @@ use anyhow::{Context as _, Result}; use collections::BTreeMap; use credentials_provider::CredentialsProvider; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; -use google_ai::{ - FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction, - ThinkingConfig, UsageMetadata, -}; +use futures::{FutureExt, StreamExt, future::BoxFuture}; +pub use google_ai::completion::{GoogleEventMapper, count_google_tokens, into_google}; +use google_ai::{GenerateContentResponse, GoogleModelMode}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason, }; use language_model::{ GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, + LanguageModelProviderState, LanguageModelRequest, RateLimiter, }; use 
schemars::JsonSchema; use serde::{Deserialize, Serialize}; pub use settings::GoogleAvailableModel as AvailableModel; use settings::{Settings, SettingsStore}; -use std::pin::Pin; -use std::sync::{ - Arc, LazyLock, - atomic::{self, AtomicU64}, -}; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; @@ -394,369 +387,6 @@ impl LanguageModel for GoogleLanguageModel { } } -pub fn into_google( - mut request: LanguageModelRequest, - model_id: String, - mode: GoogleModelMode, -) -> google_ai::GenerateContentRequest { - fn map_content(content: Vec) -> Vec { - content - .into_iter() - .flat_map(|content| match content { - language_model::MessageContent::Text(text) => { - if !text.is_empty() { - vec![Part::TextPart(google_ai::TextPart { text })] - } else { - vec![] - } - } - language_model::MessageContent::Thinking { - text: _, - signature: Some(signature), - } => { - if !signature.is_empty() { - vec![Part::ThoughtPart(google_ai::ThoughtPart { - thought: true, - thought_signature: signature, - })] - } else { - vec![] - } - } - language_model::MessageContent::Thinking { .. 
} => { - vec![] - } - language_model::MessageContent::RedactedThinking(_) => vec![], - language_model::MessageContent::Image(image) => { - vec![Part::InlineDataPart(google_ai::InlineDataPart { - inline_data: google_ai::GenerativeContentBlob { - mime_type: "image/png".to_string(), - data: image.source.to_string(), - }, - })] - } - language_model::MessageContent::ToolUse(tool_use) => { - // Normalize empty string signatures to None - let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty()); - - vec![Part::FunctionCallPart(google_ai::FunctionCallPart { - function_call: google_ai::FunctionCall { - name: tool_use.name.to_string(), - args: tool_use.input, - }, - thought_signature, - })] - } - language_model::MessageContent::ToolResult(tool_result) => { - match tool_result.content { - language_model::LanguageModelToolResultContent::Text(text) => { - vec![Part::FunctionResponsePart( - google_ai::FunctionResponsePart { - function_response: google_ai::FunctionResponse { - name: tool_result.tool_name.to_string(), - // The API expects a valid JSON object - response: serde_json::json!({ - "output": text - }), - }, - }, - )] - } - language_model::LanguageModelToolResultContent::Image(image) => { - vec![ - Part::FunctionResponsePart(google_ai::FunctionResponsePart { - function_response: google_ai::FunctionResponse { - name: tool_result.tool_name.to_string(), - // The API expects a valid JSON object - response: serde_json::json!({ - "output": "Tool responded with an image" - }), - }, - }), - Part::InlineDataPart(google_ai::InlineDataPart { - inline_data: google_ai::GenerativeContentBlob { - mime_type: "image/png".to_string(), - data: image.source.to_string(), - }, - }), - ] - } - } - } - }) - .collect() - } - - let system_instructions = if request - .messages - .first() - .is_some_and(|msg| matches!(msg.role, Role::System)) - { - let message = request.messages.remove(0); - Some(SystemInstruction { - parts: map_content(message.content), - }) - } else { - None - 
}; - - google_ai::GenerateContentRequest { - model: google_ai::ModelName { model_id }, - system_instruction: system_instructions, - contents: request - .messages - .into_iter() - .filter_map(|message| { - let parts = map_content(message.content); - if parts.is_empty() { - None - } else { - Some(google_ai::Content { - parts, - role: match message.role { - Role::User => google_ai::Role::User, - Role::Assistant => google_ai::Role::Model, - Role::System => google_ai::Role::User, // Google AI doesn't have a system role - }, - }) - } - }) - .collect(), - generation_config: Some(google_ai::GenerationConfig { - candidate_count: Some(1), - stop_sequences: Some(request.stop), - max_output_tokens: None, - temperature: request.temperature.map(|t| t as f64).or(Some(1.0)), - thinking_config: match (request.thinking_allowed, mode) { - (true, GoogleModelMode::Thinking { budget_tokens }) => { - budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget }) - } - _ => None, - }, - top_p: None, - top_k: None, - }), - safety_settings: None, - tools: (!request.tools.is_empty()).then(|| { - vec![google_ai::Tool { - function_declarations: request - .tools - .into_iter() - .map(|tool| FunctionDeclaration { - name: tool.name, - description: tool.description, - parameters: tool.input_schema, - }) - .collect(), - }] - }), - tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig { - function_calling_config: google_ai::FunctionCallingConfig { - mode: match choice { - LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto, - LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any, - LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None, - }, - allowed_function_names: None, - }, - }), - } -} - -pub struct GoogleEventMapper { - usage: UsageMetadata, - stop_reason: StopReason, -} - -impl GoogleEventMapper { - pub fn new() -> Self { - Self { - usage: UsageMetadata::default(), - stop_reason: StopReason::EndTurn, - } - } - - pub fn 
map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events - .map(Some) - .chain(futures::stream::once(async { None })) - .flat_map(move |event| { - futures::stream::iter(match event { - Some(Ok(event)) => self.map_event(event), - Some(Err(error)) => { - vec![Err(LanguageModelCompletionError::from(error))] - } - None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))], - }) - }) - } - - pub fn map_event( - &mut self, - event: GenerateContentResponse, - ) -> Vec> { - static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); - - let mut events: Vec<_> = Vec::new(); - let mut wants_to_use_tool = false; - if let Some(usage_metadata) = event.usage_metadata { - update_usage(&mut self.usage, &usage_metadata); - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - convert_usage(&self.usage), - ))) - } - - if let Some(prompt_feedback) = event.prompt_feedback - && let Some(block_reason) = prompt_feedback.block_reason.as_deref() - { - self.stop_reason = match block_reason { - "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => { - StopReason::Refusal - } - _ => { - log::error!("Unexpected Google block_reason: {block_reason}"); - StopReason::Refusal - } - }; - events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))); - - return events; - } - - if let Some(candidates) = event.candidates { - for candidate in candidates { - if let Some(finish_reason) = candidate.finish_reason.as_deref() { - self.stop_reason = match finish_reason { - "STOP" => StopReason::EndTurn, - "MAX_TOKENS" => StopReason::MaxTokens, - _ => { - log::error!("Unexpected google finish_reason: {finish_reason}"); - StopReason::EndTurn - } - }; - } - candidate - .content - .parts - .into_iter() - .for_each(|part| match part { - Part::TextPart(text_part) => { - events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text))) - } - Part::InlineDataPart(_) => {} - Part::FunctionCallPart(function_call_part) => { - wants_to_use_tool = true; - let 
name: Arc = function_call_part.function_call.name.into(); - let next_tool_id = - TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst); - let id: LanguageModelToolUseId = - format!("{}-{}", name, next_tool_id).into(); - - // Normalize empty string signatures to None - let thought_signature = function_call_part - .thought_signature - .filter(|s| !s.is_empty()); - - events.push(Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id, - name, - is_input_complete: true, - raw_input: function_call_part.function_call.args.to_string(), - input: function_call_part.function_call.args, - thought_signature, - }, - ))); - } - Part::FunctionResponsePart(_) => {} - Part::ThoughtPart(part) => { - events.push(Ok(LanguageModelCompletionEvent::Thinking { - text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries? - signature: Some(part.thought_signature), - })); - } - }); - } - } - - // Even when Gemini wants to use a Tool, the API - // responds with `finish_reason: STOP` - if wants_to_use_tool { - self.stop_reason = StopReason::ToolUse; - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); - } - events - } -} - -pub fn count_google_tokens( - request: LanguageModelRequest, - cx: &App, -) -> BoxFuture<'static, Result> { - // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly. - // So we have to use tokenizer from tiktoken_rs to count tokens. 
- cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. - tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) - }) - .boxed() -} - -fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) { - if let Some(prompt_token_count) = new.prompt_token_count { - usage.prompt_token_count = Some(prompt_token_count); - } - if let Some(cached_content_token_count) = new.cached_content_token_count { - usage.cached_content_token_count = Some(cached_content_token_count); - } - if let Some(candidates_token_count) = new.candidates_token_count { - usage.candidates_token_count = Some(candidates_token_count); - } - if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count { - usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count); - } - if let Some(thoughts_token_count) = new.thoughts_token_count { - usage.thoughts_token_count = Some(thoughts_token_count); - } - if let Some(total_token_count) = new.total_token_count { - usage.total_token_count = Some(total_token_count); - } -} - -fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage { - let prompt_tokens = usage.prompt_token_count.unwrap_or(0); - let cached_tokens = usage.cached_content_token_count.unwrap_or(0); - let input_tokens = prompt_tokens - cached_tokens; - let output_tokens = usage.candidates_token_count.unwrap_or(0); - - language_model::TokenUsage { - input_tokens, - output_tokens, - cache_read_input_tokens: cached_tokens, - cache_creation_input_tokens: 0, - } -} - struct ConfigurationView { api_key_editor: 
Entity, state: Entity, @@ -895,428 +525,3 @@ impl Render for ConfigurationView { } } } - -#[cfg(test)] -mod tests { - use super::*; - use google_ai::{ - Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse, - Part, Role as GoogleRole, TextPart, - }; - use language_model::{LanguageModelToolUseId, MessageContent, Role}; - use serde_json::json; - - #[test] - fn test_function_call_with_signature_creates_tool_use_with_signature() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("test_signature_123".to_string()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - assert_eq!(events.len(), 2); // ToolUse event + Stop event - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.name.as_ref(), "test_function"); - assert_eq!( - tool_use.thought_signature.as_deref(), - Some("test_signature_123") - ); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_function_call_without_signature_has_none() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: None, - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: 
None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.thought_signature, None); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_empty_string_signature_normalized_to_none() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("".to_string()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.thought_signature, None); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_parallel_function_calls_preserve_signatures() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![ - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "function_1".to_string(), - args: json!({"arg": "value1"}), - }, - thought_signature: Some("signature_1".to_string()), - }), - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "function_2".to_string(), - args: json!({"arg": "value2"}), - }, - thought_signature: None, - }), - ], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - 
prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!(tool_use.name.as_ref(), "function_1"); - assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1")); - } else { - panic!("Expected ToolUse event for function_1"); - } - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { - assert_eq!(tool_use.name.as_ref(), "function_2"); - assert_eq!(tool_use.thought_signature, None); - } else { - panic!("Expected ToolUse event for function_2"); - } - } - - #[test] - fn test_tool_use_with_signature_converts_to_function_call_part() { - let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_function".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: Some("test_signature_456".to_string()), - }; - - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - assert_eq!(request.contents[0].parts.len(), 1); - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.function_call.name, "test_function"); - assert_eq!( - fc_part.thought_signature.as_deref(), - Some("test_signature_456") - ); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_tool_use_without_signature_omits_field() { - let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_function".into(), - raw_input: json!({"arg": 
"value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: None, - }; - - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - assert_eq!(request.contents[0].parts.len(), 1); - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.thought_signature, None); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_empty_signature_in_tool_use_normalized_to_none() { - let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test_id"), - name: "test_function".into(), - raw_input: json!({"arg": "value"}).to_string(), - input: json!({"arg": "value"}), - is_input_complete: true, - thought_signature: Some("".to_string()), - }; - - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.thought_signature, None); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_round_trip_preserves_signature() { - let mut mapper = GoogleEventMapper::new(); - - // Simulate receiving a response from Google with a signature - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: 
"test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some("round_trip_sig".to_string()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - tool_use.clone() - } else { - panic!("Expected ToolUse event"); - }; - - // Convert back to Google format - let request = super::into_google( - LanguageModelRequest { - messages: vec![language_model::LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use)], - cache: false, - reasoning_details: None, - }], - ..Default::default() - }, - "gemini-2.5-flash".to_string(), - GoogleModelMode::Default, - ); - - // Verify signature is preserved - if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { - assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig")); - } else { - panic!("Expected FunctionCallPart"); - } - } - - #[test] - fn test_mixed_text_and_function_call_with_signature() { - let mut mapper = GoogleEventMapper::new(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![ - Part::TextPart(TextPart { - text: "I'll help with that.".to_string(), - }), - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "helper_function".to_string(), - args: json!({"query": "help"}), - }, - thought_signature: Some("mixed_sig".to_string()), - }), - ], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - assert_eq!(events.len(), 3); // Text 
event + ToolUse event + Stop event - - if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] { - assert_eq!(text, "I'll help with that."); - } else { - panic!("Expected Text event"); - } - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { - assert_eq!(tool_use.name.as_ref(), "helper_function"); - assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig")); - } else { - panic!("Expected ToolUse event"); - } - } - - #[test] - fn test_special_characters_in_signature_preserved() { - let mut mapper = GoogleEventMapper::new(); - - let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string(); - - let response = GenerateContentResponse { - candidates: Some(vec![GenerateContentCandidate { - index: Some(0), - content: Content { - parts: vec![Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "test_function".to_string(), - args: json!({"arg": "value"}), - }, - thought_signature: Some(signature_with_special_chars.clone()), - })], - role: GoogleRole::Model, - }, - finish_reason: None, - finish_message: None, - safety_ratings: None, - citation_metadata: None, - }]), - prompt_feedback: None, - usage_metadata: None, - }; - - let events = mapper.map_event(response); - - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { - assert_eq!( - tool_use.thought_signature.as_deref(), - Some(signature_with_special_chars.as_str()) - ); - } else { - panic!("Expected ToolUse event"); - } - } -} diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 0d60fef16791087e35bac7d846b2ec99821d5470..a541da8cd8092d5d0fa43af1217c31833f10cdeb 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -28,7 +28,7 @@ use ui::{ use ui_input::InputField; use crate::AllLanguageModelSettings; -use crate::provider::util::parse_tool_arguments; +use language_model::util::parse_tool_arguments; 
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download"; const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models"; diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 4cd1375fe50cd792a3a7bc8c85ba7b5b5af9520a..5fef40b2b1badbc77133ebe67fbe0f1fe5521259 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -23,7 +23,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral"); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 6ed2c3e098cb4029f2b390dbee67402975f02cee..c3d93373a4fd27a307fa859e07b966eaa8616fb7 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,41 +1,33 @@ -use anyhow::{Result, anyhow}; -use collections::{BTreeMap, HashMap}; +use anyhow::Result; +use collections::BTreeMap; use credentials_provider::CredentialsProvider; -use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, - LanguageModelCompletionEvent, LanguageModelId, LanguageModelImage, LanguageModelName, - LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelToolChoice, 
LanguageModelToolResultContent, LanguageModelToolUse, - LanguageModelToolUseId, MessageContent, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, - RateLimiter, Role, StopReason, TokenUsage, env_var, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, + LanguageModelRequest, LanguageModelToolChoice, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, + RateLimiter, env_var, }; use menu; -use open_ai::responses::{ - ResponseFunctionCallItem, ResponseFunctionCallOutputContent, ResponseFunctionCallOutputItem, - ResponseInputContent, ResponseInputItem, ResponseMessageItem, -}; use open_ai::{ - ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, - responses::{ - Request as ResponseRequest, ResponseOutputItem, ResponseSummary as ResponsesSummary, - ResponseUsage as ResponsesUsage, StreamEvent as ResponsesStreamEvent, stream_response, - }, + OPEN_AI_API_URL, ResponseStreamEvent, + responses::{Request as ResponseRequest, StreamEvent as ResponsesStreamEvent, stream_response}, stream_completion, }; use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore}; -use std::pin::Pin; use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +pub use open_ai::completion::{ + OpenAiEventMapper, OpenAiResponseEventMapper, collect_tiktoken_messages, count_open_ai_tokens, + into_open_ai, into_open_ai_response, +}; const PROVIDER_ID: LanguageModelProviderId = OPEN_AI_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = OPEN_AI_PROVIDER_NAME; @@ -189,7 +181,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { max_tokens: model.max_tokens, max_output_tokens: model.max_output_tokens, max_completion_tokens: 
model.max_completion_tokens, - reasoning_effort: model.reasoning_effort.clone(), + reasoning_effort: model.reasoning_effort, supports_chat_completions: model.capabilities.chat_completions, }, ); @@ -384,7 +376,9 @@ impl LanguageModel for OpenAiLanguageModel { request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { - count_open_ai_tokens(request, self.model.clone(), cx) + let model = self.model.clone(); + cx.background_spawn(async move { count_open_ai_tokens(request, model) }) + .boxed() } fn stream_completion( @@ -435,856 +429,6 @@ impl LanguageModel for OpenAiLanguageModel { } } -pub fn into_open_ai( - request: LanguageModelRequest, - model_id: &str, - supports_parallel_tool_calls: bool, - supports_prompt_cache_key: bool, - max_output_tokens: Option, - reasoning_effort: Option, -) -> open_ai::Request { - let stream = !model_id.starts_with("o1-"); - - let mut messages = Vec::new(); - for message in request.messages { - for content in message.content { - match content { - MessageContent::Text(text) | MessageContent::Thinking { text, .. 
} => { - let should_add = if message.role == Role::User { - // Including whitespace-only user messages can cause error with OpenAI compatible APIs - // See https://github.com/zed-industries/zed/issues/40097 - !text.trim().is_empty() - } else { - !text.is_empty() - }; - if should_add { - add_message_content_part( - open_ai::MessagePart::Text { text }, - message.role, - &mut messages, - ); - } - } - MessageContent::RedactedThinking(_) => {} - MessageContent::Image(image) => { - add_message_content_part( - open_ai::MessagePart::Image { - image_url: ImageUrl { - url: image.to_base64_url(), - detail: None, - }, - }, - message.role, - &mut messages, - ); - } - MessageContent::ToolUse(tool_use) => { - let tool_call = open_ai::ToolCall { - id: tool_use.id.to_string(), - content: open_ai::ToolCallContent::Function { - function: open_ai::FunctionContent { - name: tool_use.name.to_string(), - arguments: serde_json::to_string(&tool_use.input) - .unwrap_or_default(), - }, - }, - }; - - if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. 
}) = - messages.last_mut() - { - tool_calls.push(tool_call); - } else { - messages.push(open_ai::RequestMessage::Assistant { - content: None, - tool_calls: vec![tool_call], - }); - } - } - MessageContent::ToolResult(tool_result) => { - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - vec![open_ai::MessagePart::Text { - text: text.to_string(), - }] - } - LanguageModelToolResultContent::Image(image) => { - vec![open_ai::MessagePart::Image { - image_url: ImageUrl { - url: image.to_base64_url(), - detail: None, - }, - }] - } - }; - - messages.push(open_ai::RequestMessage::Tool { - content: content.into(), - tool_call_id: tool_result.tool_use_id.to_string(), - }); - } - } - } - } - - open_ai::Request { - model: model_id.into(), - messages, - stream, - stream_options: if stream { - Some(open_ai::StreamOptions::default()) - } else { - None - }, - stop: request.stop, - temperature: request.temperature.or(Some(1.0)), - max_completion_tokens: max_output_tokens, - parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() { - Some(supports_parallel_tool_calls) - } else { - None - }, - prompt_cache_key: if supports_prompt_cache_key { - request.thread_id - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| open_ai::ToolDefinition::Function { - function: open_ai::FunctionDefinition { - name: tool.name, - description: Some(tool.description), - parameters: Some(tool.input_schema), - }, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto, - LanguageModelToolChoice::Any => open_ai::ToolChoice::Required, - LanguageModelToolChoice::None => open_ai::ToolChoice::None, - }), - reasoning_effort, - } -} - -pub fn into_open_ai_response( - request: LanguageModelRequest, - model_id: &str, - supports_parallel_tool_calls: bool, - supports_prompt_cache_key: bool, - max_output_tokens: Option, - reasoning_effort: 
Option, -) -> ResponseRequest { - let stream = !model_id.starts_with("o1-"); - - let LanguageModelRequest { - thread_id, - prompt_id: _, - intent: _, - messages, - tools, - tool_choice, - stop: _, - temperature, - thinking_allowed: _, - thinking_effort: _, - speed: _, - } = request; - - let mut input_items = Vec::new(); - for (index, message) in messages.into_iter().enumerate() { - append_message_to_response_items(message, index, &mut input_items); - } - - let tools: Vec<_> = tools - .into_iter() - .map(|tool| open_ai::responses::ToolDefinition::Function { - name: tool.name, - description: Some(tool.description), - parameters: Some(tool.input_schema), - strict: None, - }) - .collect(); - - ResponseRequest { - model: model_id.into(), - instructions: None, - input: input_items, - stream, - temperature, - top_p: None, - max_output_tokens, - parallel_tool_calls: if tools.is_empty() { - None - } else { - Some(supports_parallel_tool_calls) - }, - tool_choice: tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto, - LanguageModelToolChoice::Any => open_ai::ToolChoice::Required, - LanguageModelToolChoice::None => open_ai::ToolChoice::None, - }), - tools, - prompt_cache_key: if supports_prompt_cache_key { - thread_id - } else { - None - }, - reasoning: reasoning_effort.map(|effort| open_ai::responses::ReasoningConfig { - effort, - summary: Some(open_ai::responses::ReasoningSummaryMode::Auto), - }), - store: None, - } -} - -fn append_message_to_response_items( - message: LanguageModelRequestMessage, - index: usize, - input_items: &mut Vec, -) { - let mut content_parts: Vec = Vec::new(); - - for content in message.content { - match content { - MessageContent::Text(text) => { - push_response_text_part(&message.role, text, &mut content_parts); - } - MessageContent::Thinking { text, .. 
} => { - push_response_text_part(&message.role, text, &mut content_parts); - } - MessageContent::RedactedThinking(_) => {} - MessageContent::Image(image) => { - push_response_image_part(&message.role, image, &mut content_parts); - } - MessageContent::ToolUse(tool_use) => { - flush_response_parts(&message.role, index, &mut content_parts, input_items); - let call_id = tool_use.id.to_string(); - input_items.push(ResponseInputItem::FunctionCall(ResponseFunctionCallItem { - call_id, - name: tool_use.name.to_string(), - arguments: tool_use.raw_input, - })); - } - MessageContent::ToolResult(tool_result) => { - flush_response_parts(&message.role, index, &mut content_parts, input_items); - input_items.push(ResponseInputItem::FunctionCallOutput( - ResponseFunctionCallOutputItem { - call_id: tool_result.tool_use_id.to_string(), - output: match tool_result.content { - LanguageModelToolResultContent::Text(text) => { - ResponseFunctionCallOutputContent::Text(text.to_string()) - } - LanguageModelToolResultContent::Image(image) => { - ResponseFunctionCallOutputContent::List(vec![ - ResponseInputContent::Image { - image_url: image.to_base64_url(), - }, - ]) - } - }, - }, - )); - } - } - } - - flush_response_parts(&message.role, index, &mut content_parts, input_items); -} - -fn push_response_text_part( - role: &Role, - text: impl Into, - parts: &mut Vec, -) { - let text = text.into(); - if text.trim().is_empty() { - return; - } - - match role { - Role::Assistant => parts.push(ResponseInputContent::OutputText { - text, - annotations: Vec::new(), - }), - _ => parts.push(ResponseInputContent::Text { text }), - } -} - -fn push_response_image_part( - role: &Role, - image: LanguageModelImage, - parts: &mut Vec, -) { - match role { - Role::Assistant => parts.push(ResponseInputContent::OutputText { - text: "[image omitted]".to_string(), - annotations: Vec::new(), - }), - _ => parts.push(ResponseInputContent::Image { - image_url: image.to_base64_url(), - }), - } -} - -fn 
flush_response_parts( - role: &Role, - _index: usize, - parts: &mut Vec, - input_items: &mut Vec, -) { - if parts.is_empty() { - return; - } - - let item = ResponseInputItem::Message(ResponseMessageItem { - role: match role { - Role::User => open_ai::Role::User, - Role::Assistant => open_ai::Role::Assistant, - Role::System => open_ai::Role::System, - }, - content: parts.clone(), - }); - - input_items.push(item); - parts.clear(); -} - -fn add_message_content_part( - new_part: open_ai::MessagePart, - role: Role, - messages: &mut Vec, -) { - match (role, messages.last_mut()) { - (Role::User, Some(open_ai::RequestMessage::User { content })) - | ( - Role::Assistant, - Some(open_ai::RequestMessage::Assistant { - content: Some(content), - .. - }), - ) - | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => { - content.push_part(new_part); - } - _ => { - messages.push(match role { - Role::User => open_ai::RequestMessage::User { - content: open_ai::MessageContent::from(vec![new_part]), - }, - Role::Assistant => open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::from(vec![new_part])), - tool_calls: Vec::new(), - }, - Role::System => open_ai::RequestMessage::System { - content: open_ai::MessageContent::from(vec![new_part]), - }, - }); - } - } -} - -pub struct OpenAiEventMapper { - tool_calls_by_index: HashMap, -} - -impl OpenAiEventMapper { - pub fn new() -> Self { - Self { - tool_calls_by_index: HashMap::default(), - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], - }) - }) - } - - pub fn map_event( - &mut self, - event: ResponseStreamEvent, - ) -> Vec> { - let mut events = Vec::new(); - if let Some(usage) = event.usage { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: 
usage.prompt_tokens, - output_tokens: usage.completion_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }))); - } - - let Some(choice) = event.choices.first() else { - return events; - }; - - if let Some(delta) = choice.delta.as_ref() { - if let Some(reasoning_content) = delta.reasoning_content.clone() { - if !reasoning_content.is_empty() { - events.push(Ok(LanguageModelCompletionEvent::Thinking { - text: reasoning_content, - signature: None, - })); - } - } - if let Some(content) = delta.content.clone() { - if !content.is_empty() { - events.push(Ok(LanguageModelCompletionEvent::Text(content))); - } - } - - if let Some(tool_calls) = delta.tool_calls.as_ref() { - for tool_call in tool_calls { - let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); - - if let Some(tool_id) = tool_call.id.clone() { - entry.id = tool_id; - } - - if let Some(function) = tool_call.function.as_ref() { - if let Some(name) = function.name.clone() { - entry.name = name; - } - - if let Some(arguments) = function.arguments.clone() { - entry.arguments.push_str(&arguments); - } - } - - if !entry.id.is_empty() && !entry.name.is_empty() { - if let Ok(input) = serde_json::from_str::( - &fix_streamed_json(&entry.arguments), - ) { - events.push(Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: entry.id.clone().into(), - name: entry.name.as_str().into(), - is_input_complete: false, - input, - raw_input: entry.arguments.clone(), - thought_signature: None, - }, - ))); - } - } - } - } - } - - match choice.finish_reason.as_deref() { - Some("stop") => { - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); - } - Some("tool_calls") => { - events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { - match parse_tool_arguments(&tool_call.arguments) { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_call.id.clone().into(), - name: tool_call.name.as_str().into(), - 
is_input_complete: true, - input, - raw_input: tool_call.arguments.clone(), - thought_signature: None, - }, - )), - Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: tool_call.id.into(), - tool_name: tool_call.name.into(), - raw_input: tool_call.arguments.clone().into(), - json_parse_error: error.to_string(), - }), - } - })); - - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); - } - Some(stop_reason) => { - log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); - events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); - } - None => {} - } - - events - } -} - -#[derive(Default)] -struct RawToolCall { - id: String, - name: String, - arguments: String, -} - -pub struct OpenAiResponseEventMapper { - function_calls_by_item: HashMap, - pending_stop_reason: Option, -} - -#[derive(Default)] -struct PendingResponseFunctionCall { - call_id: String, - name: Arc, - arguments: String, -} - -impl OpenAiResponseEventMapper { - pub fn new() -> Self { - Self { - function_calls_by_item: HashMap::default(), - pending_stop_reason: None, - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], - }) - }) - } - - pub fn map_event( - &mut self, - event: ResponsesStreamEvent, - ) -> Vec> { - match event { - ResponsesStreamEvent::OutputItemAdded { item, .. 
} => { - let mut events = Vec::new(); - - match &item { - ResponseOutputItem::Message(message) => { - if let Some(id) = &message.id { - events.push(Ok(LanguageModelCompletionEvent::StartMessage { - message_id: id.clone(), - })); - } - } - ResponseOutputItem::FunctionCall(function_call) => { - if let Some(item_id) = function_call.id.clone() { - let call_id = function_call - .call_id - .clone() - .or_else(|| function_call.id.clone()) - .unwrap_or_else(|| item_id.clone()); - let entry = PendingResponseFunctionCall { - call_id, - name: Arc::::from( - function_call.name.clone().unwrap_or_default(), - ), - arguments: function_call.arguments.clone(), - }; - self.function_calls_by_item.insert(item_id, entry); - } - } - ResponseOutputItem::Reasoning(_) | ResponseOutputItem::Unknown => {} - } - events - } - ResponsesStreamEvent::ReasoningSummaryTextDelta { delta, .. } => { - if delta.is_empty() { - Vec::new() - } else { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: delta, - signature: None, - })] - } - } - ResponsesStreamEvent::OutputTextDelta { delta, .. } => { - if delta.is_empty() { - Vec::new() - } else { - vec![Ok(LanguageModelCompletionEvent::Text(delta))] - } - } - ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { - if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { - entry.arguments.push_str(&delta); - if let Ok(input) = serde_json::from_str::( - &fix_streamed_json(&entry.arguments), - ) { - return vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: LanguageModelToolUseId::from(entry.call_id.clone()), - name: entry.name.clone(), - is_input_complete: false, - input, - raw_input: entry.arguments.clone(), - thought_signature: None, - }, - ))]; - } - } - Vec::new() - } - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id, arguments, .. 
- } => { - if let Some(mut entry) = self.function_calls_by_item.remove(&item_id) { - if !arguments.is_empty() { - entry.arguments = arguments; - } - let raw_input = entry.arguments.clone(); - self.pending_stop_reason = Some(StopReason::ToolUse); - match parse_tool_arguments(&entry.arguments) { - Ok(input) => { - vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: LanguageModelToolUseId::from(entry.call_id.clone()), - name: entry.name.clone(), - is_input_complete: true, - input, - raw_input, - thought_signature: None, - }, - ))] - } - Err(error) => { - vec![Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: LanguageModelToolUseId::from(entry.call_id.clone()), - tool_name: entry.name.clone(), - raw_input: Arc::::from(raw_input), - json_parse_error: error.to_string(), - })] - } - } - } else { - Vec::new() - } - } - ResponsesStreamEvent::Completed { response } => { - self.handle_completion(response, StopReason::EndTurn) - } - ResponsesStreamEvent::Incomplete { response } => { - let reason = response - .status_details - .as_ref() - .and_then(|details| details.reason.as_deref()); - let stop_reason = match reason { - Some("max_output_tokens") => StopReason::MaxTokens, - Some("content_filter") => { - self.pending_stop_reason = Some(StopReason::Refusal); - StopReason::Refusal - } - _ => self - .pending_stop_reason - .take() - .unwrap_or(StopReason::EndTurn), - }; - - let mut events = Vec::new(); - if self.pending_stop_reason.is_none() { - events.extend(self.emit_tool_calls_from_output(&response.output)); - } - if let Some(usage) = response.usage.as_ref() { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - token_usage_from_response_usage(usage), - ))); - } - events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); - events - } - ResponsesStreamEvent::Failed { response } => { - let message = response - .status_details - .and_then(|details| details.error) - .map(|error| error.to_string()) - .unwrap_or_else(|| "response 
failed".to_string()); - vec![Err(LanguageModelCompletionError::Other(anyhow!(message)))] - } - ResponsesStreamEvent::Error { error } - | ResponsesStreamEvent::GenericError { error } => { - vec![Err(LanguageModelCompletionError::Other(anyhow!( - error.message - )))] - } - ResponsesStreamEvent::ReasoningSummaryPartAdded { summary_index, .. } => { - if summary_index > 0 { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: "\n\n".to_string(), - signature: None, - })] - } else { - Vec::new() - } - } - ResponsesStreamEvent::OutputTextDone { .. } - | ResponsesStreamEvent::OutputItemDone { .. } - | ResponsesStreamEvent::ContentPartAdded { .. } - | ResponsesStreamEvent::ContentPartDone { .. } - | ResponsesStreamEvent::ReasoningSummaryTextDone { .. } - | ResponsesStreamEvent::ReasoningSummaryPartDone { .. } - | ResponsesStreamEvent::Created { .. } - | ResponsesStreamEvent::InProgress { .. } - | ResponsesStreamEvent::Unknown => Vec::new(), - } - } - - fn handle_completion( - &mut self, - response: ResponsesSummary, - default_reason: StopReason, - ) -> Vec> { - let mut events = Vec::new(); - - if self.pending_stop_reason.is_none() { - events.extend(self.emit_tool_calls_from_output(&response.output)); - } - - if let Some(usage) = response.usage.as_ref() { - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - token_usage_from_response_usage(usage), - ))); - } - - let stop_reason = self.pending_stop_reason.take().unwrap_or(default_reason); - events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); - events - } - - fn emit_tool_calls_from_output( - &mut self, - output: &[ResponseOutputItem], - ) -> Vec> { - let mut events = Vec::new(); - for item in output { - if let ResponseOutputItem::FunctionCall(function_call) = item { - let Some(call_id) = function_call - .call_id - .clone() - .or_else(|| function_call.id.clone()) - else { - log::error!( - "Function call item missing both call_id and id: {:?}", - function_call - ); - continue; - }; - let name: Arc = 
Arc::from(function_call.name.clone().unwrap_or_default()); - let arguments = &function_call.arguments; - self.pending_stop_reason = Some(StopReason::ToolUse); - match parse_tool_arguments(arguments) { - Ok(input) => { - events.push(Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: LanguageModelToolUseId::from(call_id.clone()), - name: name.clone(), - is_input_complete: true, - input, - raw_input: arguments.clone(), - thought_signature: None, - }, - ))); - } - Err(error) => { - events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: LanguageModelToolUseId::from(call_id.clone()), - tool_name: name.clone(), - raw_input: Arc::::from(arguments.clone()), - json_parse_error: error.to_string(), - })); - } - } - } - } - events - } -} - -fn token_usage_from_response_usage(usage: &ResponsesUsage) -> TokenUsage { - TokenUsage { - input_tokens: usage.input_tokens.unwrap_or_default(), - output_tokens: usage.output_tokens.unwrap_or_default(), - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } -} - -pub(crate) fn collect_tiktoken_messages( - request: LanguageModelRequest, -) -> Vec { - request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>() -} - -pub fn count_open_ai_tokens( - request: LanguageModelRequest, - model: Model, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = collect_tiktoken_messages(request); - match model { - Model::Custom { max_tokens, .. 
} => { - let model = if max_tokens >= 100_000 { - // If the max tokens is 100k or more, it likely uses the o200k_base tokenizer - "gpt-4o" - } else { - // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are - // supported with this tiktoken method - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model, &messages) - } - // Currently supported by tiktoken_rs - // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch - // arm with an override. We enumerate all supported models here so that we can check if new - // models are supported yet or not. - Model::ThreePointFiveTurbo - | Model::Four - | Model::FourTurbo - | Model::FourOmniMini - | Model::FourPointOneNano - | Model::O1 - | Model::O3 - | Model::O3Mini - | Model::O4Mini - | Model::Five - | Model::FiveCodex - | Model::FiveMini - | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), - // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated tiktoken support; use gpt-5 tokenizer - Model::FivePointOne - | Model::FivePointTwo - | Model::FivePointTwoCodex - | Model::FivePointThreeCodex - | Model::FivePointFour - | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages), - } - .map(|tokens| tokens as u64) - }) - .boxed() -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, @@ -1464,874 +608,3 @@ impl Render for ConfigurationView { } } } - -#[cfg(test)] -mod tests { - use futures::{StreamExt, executor::block_on}; - use gpui::TestAppContext; - use language_model::{ - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - }; - use open_ai::responses::{ - ReasoningSummaryPart, ResponseFunctionToolCall, ResponseOutputItem, ResponseOutputMessage, - ResponseReasoningItem, ResponseStatusDetails, ResponseSummary, ResponseUsage, - StreamEvent as ResponsesStreamEvent, - }; - use pretty_assertions::assert_eq; - use serde_json::json; - - use super::*; - 
- fn map_response_events(events: Vec) -> Vec { - block_on(async { - OpenAiResponseEventMapper::new() - .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok)))) - .collect::>() - .await - .into_iter() - .map(Result::unwrap) - .collect() - }) - } - - fn response_item_message(id: &str) -> ResponseOutputItem { - ResponseOutputItem::Message(ResponseOutputMessage { - id: Some(id.to_string()), - role: Some("assistant".to_string()), - status: Some("in_progress".to_string()), - content: vec![], - }) - } - - fn response_item_function_call(id: &str, args: Option<&str>) -> ResponseOutputItem { - ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { - id: Some(id.to_string()), - status: Some("in_progress".to_string()), - name: Some("get_weather".to_string()), - call_id: Some("call_123".to_string()), - arguments: args.map(|s| s.to_string()).unwrap_or_default(), - }) - } - - #[gpui::test] - fn tiktoken_rs_support(cx: &TestAppContext) { - let request = LanguageModelRequest { - thread_id: None, - prompt_id: None, - intent: None, - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text("message".into())], - cache: false, - reasoning_details: None, - }], - tools: vec![], - tool_choice: None, - stop: vec![], - temperature: None, - thinking_allowed: true, - thinking_effort: None, - speed: None, - }; - - // Validate that all models are supported by tiktoken-rs - for model in Model::iter() { - let count = cx - .foreground_executor() - .block_on(count_open_ai_tokens( - request.clone(), - model, - &cx.app.borrow(), - )) - .unwrap(); - assert!(count > 0); - } - } - - #[test] - fn responses_stream_maps_text_and_usage() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_message("msg_123"), - }, - ResponsesStreamEvent::OutputTextDelta { - item_id: "msg_123".into(), - output_index: 0, - content_index: Some(0), - delta: "Hello".into(), - }, - 
ResponsesStreamEvent::Completed { - response: ResponseSummary { - usage: Some(ResponseUsage { - input_tokens: Some(5), - output_tokens: Some(3), - total_tokens: Some(8), - }), - ..Default::default() - }, - }, - ]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_123" - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Text(ref text) if text == "Hello" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: 5, - output_tokens: 3, - .. - }) - )); - assert!(matches!( - mapped[3], - LanguageModelCompletionEvent::Stop(StopReason::EndTurn) - )); - } - - #[test] - fn into_open_ai_response_builds_complete_payload() { - let tool_call_id = LanguageModelToolUseId::from("call-42"); - let tool_input = json!({ "city": "Boston" }); - let tool_arguments = serde_json::to_string(&tool_input).unwrap(); - let tool_use = LanguageModelToolUse { - id: tool_call_id.clone(), - name: Arc::from("get_weather"), - raw_input: tool_arguments.clone(), - input: tool_input, - is_input_complete: true, - thought_signature: None, - }; - let tool_result = LanguageModelToolResult { - tool_use_id: tool_call_id, - tool_name: Arc::from("get_weather"), - is_error: false, - content: LanguageModelToolResultContent::Text(Arc::from("Sunny")), - output: Some(json!({ "forecast": "Sunny" })), - }; - let user_image = LanguageModelImage { - source: SharedString::from("aGVsbG8="), - size: None, - }; - let expected_image_url = user_image.to_base64_url(); - - let request = LanguageModelRequest { - thread_id: Some("thread-123".into()), - prompt_id: None, - intent: None, - messages: vec![ - LanguageModelRequestMessage { - role: Role::System, - content: vec![MessageContent::Text("System context".into())], - cache: false, - reasoning_details: None, - }, - LanguageModelRequestMessage { - role: Role::User, - content: vec![ - 
MessageContent::Text("Please check the weather.".into()), - MessageContent::Image(user_image), - ], - cache: false, - reasoning_details: None, - }, - LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![ - MessageContent::Text("Looking that up.".into()), - MessageContent::ToolUse(tool_use), - ], - cache: false, - reasoning_details: None, - }, - LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolResult(tool_result)], - cache: false, - reasoning_details: None, - }, - ], - tools: vec![LanguageModelRequestTool { - name: "get_weather".into(), - description: "Fetches the weather".into(), - input_schema: json!({ "type": "object" }), - use_input_streaming: false, - }], - tool_choice: Some(LanguageModelToolChoice::Any), - stop: vec!["".into()], - temperature: None, - thinking_allowed: false, - thinking_effort: None, - speed: None, - }; - - let response = into_open_ai_response( - request, - "custom-model", - true, - true, - Some(2048), - Some(ReasoningEffort::Low), - ); - - let serialized = serde_json::to_value(&response).unwrap(); - let expected = json!({ - "model": "custom-model", - "input": [ - { - "type": "message", - "role": "system", - "content": [ - { "type": "input_text", "text": "System context" } - ] - }, - { - "type": "message", - "role": "user", - "content": [ - { "type": "input_text", "text": "Please check the weather." 
}, - { "type": "input_image", "image_url": expected_image_url } - ] - }, - { - "type": "message", - "role": "assistant", - "content": [ - { "type": "output_text", "text": "Looking that up.", "annotations": [] } - ] - }, - { - "type": "function_call", - "call_id": "call-42", - "name": "get_weather", - "arguments": tool_arguments - }, - { - "type": "function_call_output", - "call_id": "call-42", - "output": "Sunny" - } - ], - "stream": true, - "max_output_tokens": 2048, - "parallel_tool_calls": true, - "tool_choice": "required", - "tools": [ - { - "type": "function", - "name": "get_weather", - "description": "Fetches the weather", - "parameters": { "type": "object" } - } - ], - "prompt_cache_key": "thread-123", - "reasoning": { "effort": "low", "summary": "auto" } - }); - - assert_eq!(serialized, expected); - } - - #[test] - fn responses_stream_maps_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"city\":\"Bos")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "ton\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{\"city\":\"Boston\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 3); - // First event is the partial tool use (from FunctionCallArgumentsDelta) - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: false, - .. 
- }) - )); - // Second event is the complete tool use (from FunctionCallArgumentsDone) - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - ref id, - ref name, - ref raw_input, - is_input_complete: true, - .. - }) if id.to_string() == "call_123" - && name.as_ref() == "get_weather" - && raw_input == "{\"city\":\"Boston\"}" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_uses_max_tokens_stop_reason() { - let events = vec![ResponsesStreamEvent::Incomplete { - response: ResponseSummary { - status_details: Some(ResponseStatusDetails { - reason: Some("max_output_tokens".into()), - r#type: Some("incomplete".into()), - error: None, - }), - usage: Some(ResponseUsage { - input_tokens: Some(10), - output_tokens: Some(20), - total_tokens: Some(30), - }), - ..Default::default() - }, - }]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::UsageUpdate(TokenUsage { - input_tokens: 10, - output_tokens: 20, - .. 
- }) - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) - )); - } - - #[test] - fn responses_stream_handles_multiple_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn1", Some("{\"city\":\"NYC\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn1".into(), - output_index: 0, - arguments: "{\"city\":\"NYC\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::OutputItemAdded { - output_index: 1, - sequence_number: None, - item: response_item_function_call("item_fn2", Some("{\"city\":\"LA\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn2".into(), - output_index: 1, - arguments: "{\"city\":\"LA\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 3); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"city\":\"NYC\"}" - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. 
}) - if raw_input == "{\"city\":\"LA\"}" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_handles_mixed_text_and_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_message("msg_123"), - }, - ResponsesStreamEvent::OutputTextDelta { - item_id: "msg_123".into(), - output_index: 0, - content_index: Some(0), - delta: "Let me check that".into(), - }, - ResponsesStreamEvent::OutputItemAdded { - output_index: 1, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"query\":\"test\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 1, - arguments: "{\"query\":\"test\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::StartMessage { .. } - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Text(ref text) if text == "Let me check that" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. 
}) - if raw_input == "{\"query\":\"test\"}" - )); - assert!(matches!( - mapped[3], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_handles_json_parse_error() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{invalid json")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{invalid json".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUseJsonParseError { - ref raw_input, - .. - } if raw_input.as_ref() == "{invalid json" - )); - } - - #[test] - fn responses_stream_handles_incomplete_function_call() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"city\":")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "\"Boston\"".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Incomplete { - response: ResponseSummary { - status_details: Some(ResponseStatusDetails { - reason: Some("max_output_tokens".into()), - r#type: Some("incomplete".into()), - error: None, - }), - output: vec![response_item_function_call( - "item_fn", - Some("{\"city\":\"Boston\"}"), - )], - ..Default::default() - }, - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 3); - // First event is the partial tool use (from FunctionCallArgumentsDelta) - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: false, - .. 
- }) - )); - // Second event is the complete tool use (from the Incomplete response output) - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - ref raw_input, - is_input_complete: true, - .. - }) - if raw_input == "{\"city\":\"Boston\"}" - )); - assert!(matches!( - mapped[2], - LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) - )); - } - - #[test] - fn responses_stream_incomplete_does_not_duplicate_tool_calls() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("{\"city\":\"Boston\"}")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{\"city\":\"Boston\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Incomplete { - response: ResponseSummary { - status_details: Some(ResponseStatusDetails { - reason: Some("max_output_tokens".into()), - r#type: Some("incomplete".into()), - error: None, - }), - output: vec![response_item_function_call( - "item_fn", - Some("{\"city\":\"Boston\"}"), - )], - ..Default::default() - }, - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 2); - assert!(matches!( - mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. 
}) - if raw_input == "{\"city\":\"Boston\"}" - )); - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) - )); - } - - #[test] - fn responses_stream_handles_empty_tool_arguments() { - // Test that tools with no arguments (empty string) are handled correctly - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: response_item_function_call("item_fn", Some("")), - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - assert_eq!(mapped.len(), 2); - - // Should produce a ToolUse event with an empty object - assert!(matches!( - &mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - id, - name, - raw_input, - input, - .. - }) if id.to_string() == "call_123" - && name.as_ref() == "get_weather" - && raw_input == "" - && input.is_object() - && input.as_object().unwrap().is_empty() - )); - - assert!(matches!( - mapped[1], - LanguageModelCompletionEvent::Stop(StopReason::ToolUse) - )); - } - - #[test] - fn responses_stream_emits_partial_tool_use_events() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { - id: Some("item_fn".to_string()), - status: Some("in_progress".to_string()), - name: Some("get_weather".to_string()), - call_id: Some("call_abc".to_string()), - arguments: String::new(), - }), - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "{\"city\":\"Bos".into(), - sequence_number: None, - }, - ResponsesStreamEvent::FunctionCallArgumentsDelta { - item_id: "item_fn".into(), - output_index: 0, - delta: "ton\"}".into(), - 
sequence_number: None, - }, - ResponsesStreamEvent::FunctionCallArgumentsDone { - item_id: "item_fn".into(), - output_index: 0, - arguments: "{\"city\":\"Boston\"}".into(), - sequence_number: None, - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - // Two partial events + one complete event + Stop - assert!(mapped.len() >= 3); - - // The last complete ToolUse event should have is_input_complete: true - let complete_tool_use = mapped.iter().find(|e| { - matches!( - e, - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: true, - .. - }) - ) - }); - assert!( - complete_tool_use.is_some(), - "should have a complete tool use event" - ); - - // All ToolUse events before the final one should have is_input_complete: false - let tool_uses: Vec<_> = mapped - .iter() - .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_))) - .collect(); - assert!( - tool_uses.len() >= 2, - "should have at least one partial and one complete event" - ); - - let last = tool_uses.last().unwrap(); - assert!(matches!( - last, - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { - is_input_complete: true, - .. 
- }) - )); - } - - #[test] - fn responses_stream_maps_reasoning_summary_deltas() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_123".into()), - summary: vec![], - }), - }, - ResponsesStreamEvent::ReasoningSummaryPartAdded { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 0, - }, - ResponsesStreamEvent::ReasoningSummaryTextDelta { - item_id: "rs_123".into(), - output_index: 0, - delta: "Thinking about".into(), - }, - ResponsesStreamEvent::ReasoningSummaryTextDelta { - item_id: "rs_123".into(), - output_index: 0, - delta: " the answer".into(), - }, - ResponsesStreamEvent::ReasoningSummaryTextDone { - item_id: "rs_123".into(), - output_index: 0, - text: "Thinking about the answer".into(), - }, - ResponsesStreamEvent::ReasoningSummaryPartDone { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 0, - }, - ResponsesStreamEvent::ReasoningSummaryPartAdded { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 1, - }, - ResponsesStreamEvent::ReasoningSummaryTextDelta { - item_id: "rs_123".into(), - output_index: 0, - delta: "Second part".into(), - }, - ResponsesStreamEvent::ReasoningSummaryTextDone { - item_id: "rs_123".into(), - output_index: 0, - text: "Second part".into(), - }, - ResponsesStreamEvent::ReasoningSummaryPartDone { - item_id: "rs_123".into(), - output_index: 0, - summary_index: 1, - }, - ResponsesStreamEvent::OutputItemDone { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_123".into()), - summary: vec![ - ReasoningSummaryPart::SummaryText { - text: "Thinking about the answer".into(), - }, - ReasoningSummaryPart::SummaryText { - text: "Second part".into(), - }, - ], - }), - }, - ResponsesStreamEvent::OutputItemAdded { - output_index: 1, - sequence_number: None, - item: response_item_message("msg_456"), - }, - 
ResponsesStreamEvent::OutputTextDelta { - item_id: "msg_456".into(), - output_index: 1, - content_index: Some(0), - delta: "The answer is 42".into(), - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - - let thinking_events: Vec<_> = mapped - .iter() - .filter(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })) - .collect(); - assert_eq!( - thinking_events.len(), - 4, - "expected 4 thinking events (2 deltas + separator + second delta), got {:?}", - thinking_events, - ); - - assert!(matches!( - &thinking_events[0], - LanguageModelCompletionEvent::Thinking { text, .. } if text == "Thinking about" - )); - assert!(matches!( - &thinking_events[1], - LanguageModelCompletionEvent::Thinking { text, .. } if text == " the answer" - )); - assert!( - matches!( - &thinking_events[2], - LanguageModelCompletionEvent::Thinking { text, .. } if text == "\n\n" - ), - "expected separator between summary parts" - ); - assert!(matches!( - &thinking_events[3], - LanguageModelCompletionEvent::Thinking { text, .. 
} if text == "Second part" - )); - - assert!(mapped.iter().any(|e| matches!( - e, - LanguageModelCompletionEvent::Text(t) if t == "The answer is 42" - ))); - } - - #[test] - fn responses_stream_maps_reasoning_from_done_only() { - let events = vec![ - ResponsesStreamEvent::OutputItemAdded { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_789".into()), - summary: vec![], - }), - }, - ResponsesStreamEvent::OutputItemDone { - output_index: 0, - sequence_number: None, - item: ResponseOutputItem::Reasoning(ResponseReasoningItem { - id: Some("rs_789".into()), - summary: vec![ReasoningSummaryPart::SummaryText { - text: "Summary without deltas".into(), - }], - }), - }, - ResponsesStreamEvent::Completed { - response: ResponseSummary::default(), - }, - ]; - - let mapped = map_response_events(events); - - assert!( - !mapped - .iter() - .any(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })), - "OutputItemDone reasoning should not produce Thinking events (no delta/done text events)" - ); - } -} diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index e418d08ac63b985606926355e0503e56539f028a..9b0edf3040ef935e6ae08047e0ea9746b96e86ba 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -403,7 +403,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { self.model.capabilities.parallel_tool_calls, self.model.capabilities.prompt_cache_key, self.max_output_tokens(), - self.model.reasoning_effort.clone(), + self.model.reasoning_effort, ); let completions = self.stream_completion(request, cx); async move { @@ -418,7 +418,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { self.model.capabilities.parallel_tool_calls, self.model.capabilities.prompt_cache_key, self.max_output_tokens(), - self.model.reasoning_effort.clone(), + 
self.model.reasoning_effort, ); let completions = self.stream_response(request, cx); async move { diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 09c8eb768d12c61ed1dc86a1251ad52114be6162..fba3a6938aecf1db80680e014e408e4d59c42ff7 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; +use language_model::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 88189864c7b4b650a24afb2b872c1d6105cf9782..e95bc1ba72fabcf9632b2ed2efd94254fb1313cd 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -9,7 +9,7 @@ use language_model::{ LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, - Role, env_var, + env_var, }; use open_ai::ResponseStreamEvent; pub use settings::XaiAvailableModel as AvailableModel; @@ -19,7 +19,8 @@ use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use x_ai::{Model, XAI_API_URL}; +use x_ai::XAI_API_URL; +pub use x_ai::completion::count_xai_tokens; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); const PROVIDER_NAME: 
LanguageModelProviderName = LanguageModelProviderName::new("xAI"); @@ -320,7 +321,9 @@ impl LanguageModel for XAiLanguageModel { request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { - count_xai_tokens(request, self.model.clone(), cx) + let model = self.model.clone(); + cx.background_spawn(async move { count_xai_tokens(request, model) }) + .boxed() } fn stream_completion( @@ -354,37 +357,6 @@ impl LanguageModel for XAiLanguageModel { } } -pub fn count_xai_tokens( - request: LanguageModelRequest, - model: Model, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = request - .messages - .into_iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(message.string_contents()), - name: None, - function_call: None, - }) - .collect::>(); - - let model_name = if model.max_token_count() >= 100_000 { - "gpt-4o" - } else { - "gpt-4" - }; - tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64) - }) - .boxed() -} - struct ConfigurationView { api_key_editor: Entity, state: Entity, diff --git a/crates/language_models_cloud/Cargo.toml b/crates/language_models_cloud/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..b08acc5ecd5c2a718e936378c2dbfbc3d1c32df0 --- /dev/null +++ b/crates/language_models_cloud/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "language_models_cloud" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/language_models_cloud.rs" + +[dependencies] +anthropic = { workspace = true, features = ["schemars"] } +anyhow.workspace = true +cloud_llm_client.workspace = true +futures.workspace = true +google_ai = { workspace = true, features = ["schemars"] } +gpui.workspace = 
true +http_client.workspace = true +language_model.workspace = true +open_ai = { workspace = true, features = ["schemars"] } +schemars.workspace = true +semver.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +thiserror.workspace = true +x_ai = { workspace = true, features = ["schemars"] } + +[dev-dependencies] +language_model = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models_cloud/LICENSE-GPL b/crates/language_models_cloud/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/language_models_cloud/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/language_models_cloud/src/language_models_cloud.rs b/crates/language_models_cloud/src/language_models_cloud.rs new file mode 100644 index 0000000000000000000000000000000000000000..24c8ec87d5c672dbc18b20164f2fe28c9b46b2e1 --- /dev/null +++ b/crates/language_models_cloud/src/language_models_cloud.rs @@ -0,0 +1,1059 @@ +use anthropic::AnthropicModelMode; +use anyhow::{Context as _, Result, anyhow}; +use cloud_llm_client::{ + CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, + CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, + CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, + OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, + ZED_VERSION_HEADER_NAME, +}; +use futures::{ + AsyncBufReadExt, FutureExt, Stream, StreamExt, + future::BoxFuture, + stream::{self, BoxStream}, +}; +use google_ai::GoogleModelMode; +use gpui::{App, AppContext, AsyncApp, Context, Task}; +use http_client::http::{HeaderMap, HeaderValue}; +use http_client::{ + AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode, +}; +use language_model::{ + 
ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, + LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID, + OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, + ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME, +}; + +use schemars::JsonSchema; +use semver::Version; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use smol::io::{AsyncReadExt, BufReader}; +use std::collections::VecDeque; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; +use thiserror::Error; + +use anthropic::completion::{ + AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic, +}; +use google_ai::completion::{GoogleEventMapper, into_google}; +use open_ai::completion::{ + OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai, + into_open_ai_response, +}; +use x_ai::completion::count_xai_tokens; + +const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME; + +/// Trait for acquiring and refreshing LLM authentication tokens. 
+pub trait CloudLlmTokenProvider: Send + Sync { + type AuthContext: Clone + Send + 'static; + + fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext; + fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result>; + fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result>; +} + +#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ModelMode { + #[default] + Default, + Thinking { + /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. + budget_tokens: Option, + }, +} + +impl From for AnthropicModelMode { + fn from(value: ModelMode) -> Self { + match value { + ModelMode::Default => AnthropicModelMode::Default, + ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens }, + } + } +} + +pub struct CloudLanguageModel { + pub id: LanguageModelId, + pub model: Arc, + pub token_provider: Arc, + pub http_client: Arc, + pub app_version: Option, + pub request_limiter: RateLimiter, +} + +pub struct PerformLlmCompletionResponse { + pub response: Response, + pub includes_status_messages: bool, +} + +impl CloudLanguageModel { + pub async fn perform_llm_completion( + http_client: &HttpClientWithUrl, + token_provider: &TP, + auth_context: TP::AuthContext, + app_version: Option, + body: CompletionBody, + ) -> Result { + let mut token = token_provider.acquire_token(auth_context.clone()).await?; + let mut refreshed_token = false; + + loop { + let request = http_client::Request::builder() + .method(Method::POST) + .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref()) + .when_some(app_version.as_ref(), |builder, app_version| { + builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string()) + }) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {token}")) + 
.header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true") + .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true") + .body(serde_json::to_string(&body)?.into())?; + + let mut response = http_client.send(request).await?; + let status = response.status(); + if status.is_success() { + let includes_status_messages = response + .headers() + .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME) + .is_some(); + + return Ok(PerformLlmCompletionResponse { + response, + includes_status_messages, + }); + } + + if !refreshed_token && needs_llm_token_refresh(&response) { + token = token_provider.refresh_token(auth_context.clone()).await?; + refreshed_token = true; + continue; + } + + if status == StatusCode::PAYMENT_REQUIRED { + return Err(anyhow!(PaymentRequiredError)); + } + + let mut body = String::new(); + let headers = response.headers().clone(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!(ApiError { + status, + body, + headers + })); + } + } +} + +fn needs_llm_token_refresh(response: &Response) -> bool { + response + .headers() + .get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() + || response + .headers() + .get(OUTDATED_LLM_TOKEN_HEADER_NAME) + .is_some() +} + +#[derive(Debug, Error)] +#[error("cloud language model request failed with status {status}: {body}")] +struct ApiError { + status: StatusCode, + body: String, + headers: HeaderMap, +} + +/// Represents error responses from Zed's cloud API. 
+/// +/// Example JSON for an upstream HTTP error: +/// ```json +/// { +/// "code": "upstream_http_error", +/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout", +/// "upstream_status": 503 +/// } +/// ``` +#[derive(Debug, serde::Deserialize)] +struct CloudApiError { + code: String, + message: String, + #[serde(default)] + #[serde(deserialize_with = "deserialize_optional_status_code")] + upstream_status: Option, + #[serde(default)] + retry_after: Option, +} + +fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let opt: Option = Option::deserialize(deserializer)?; + Ok(opt.and_then(|code| StatusCode::from_u16(code).ok())) +} + +impl From for LanguageModelCompletionError { + fn from(error: ApiError) -> Self { + if let Ok(cloud_error) = serde_json::from_str::(&error.body) { + if cloud_error.code.starts_with("upstream_http_") { + let status = if let Some(status) = cloud_error.upstream_status { + status + } else if cloud_error.code.ends_with("_error") { + error.status + } else { + // If there's a status code in the code string (e.g. "upstream_http_429") + // then use that; otherwise, see if the JSON contains a status code. 
+ cloud_error + .code + .strip_prefix("upstream_http_") + .and_then(|code_str| code_str.parse::().ok()) + .and_then(|code| StatusCode::from_u16(code).ok()) + .unwrap_or(error.status) + }; + + return LanguageModelCompletionError::UpstreamProviderError { + message: cloud_error.message, + status, + retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), + }; + } + + return LanguageModelCompletionError::from_http_status( + PROVIDER_NAME, + error.status, + cloud_error.message, + None, + ); + } + + let retry_after = None; + LanguageModelCompletionError::from_http_status( + PROVIDER_NAME, + error.status, + error.body, + retry_after, + ) + } +} + +impl LanguageModel for CloudLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name.clone()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn provider_name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn upstream_provider_id(&self) -> LanguageModelProviderId { + use cloud_llm_client::LanguageModelProvider::*; + match self.model.provider { + Anthropic => ANTHROPIC_PROVIDER_ID, + OpenAi => OPEN_AI_PROVIDER_ID, + Google => GOOGLE_PROVIDER_ID, + XAi => X_AI_PROVIDER_ID, + } + } + + fn upstream_provider_name(&self) -> LanguageModelProviderName { + use cloud_llm_client::LanguageModelProvider::*; + match self.model.provider { + Anthropic => ANTHROPIC_PROVIDER_NAME, + OpenAi => OPEN_AI_PROVIDER_NAME, + Google => GOOGLE_PROVIDER_NAME, + XAi => X_AI_PROVIDER_NAME, + } + } + + fn is_latest(&self) -> bool { + self.model.is_latest + } + + fn supports_tools(&self) -> bool { + self.model.supports_tools + } + + fn supports_images(&self) -> bool { + self.model.supports_images + } + + fn supports_thinking(&self) -> bool { + self.model.supports_thinking + } + + fn supports_fast_mode(&self) -> bool { + self.model.supports_fast_mode + } + + fn supported_effort_levels(&self) -> Vec 
{ + self.model + .supported_effort_levels + .iter() + .map(|effort_level| LanguageModelEffortLevel { + name: effort_level.name.clone().into(), + value: effort_level.value.clone().into(), + is_default: effort_level.is_default.unwrap_or(false), + }) + .collect() + } + + fn supports_streaming_tools(&self) -> bool { + self.model.supports_streaming_tools + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => true, + } + } + + fn supports_split_token_display(&self) -> bool { + use cloud_llm_client::LanguageModelProvider::*; + matches!(self.model.provider, OpenAi | XAi) + } + + fn telemetry_id(&self) -> String { + format!("zed.dev/{}", self.model.id) + } + + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + match self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic + | cloud_llm_client::LanguageModelProvider::OpenAi => { + LanguageModelToolSchemaFormat::JsonSchema + } + cloud_llm_client::LanguageModelProvider::Google + | cloud_llm_client::LanguageModelProvider::XAi => { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } + } + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count as u64 + } + + fn max_output_tokens(&self) -> Option { + Some(self.model.max_output_tokens as u64) + } + + fn cache_configuration(&self) -> Option { + match &self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic => { + Some(LanguageModelCacheConfiguration { + min_total_token: 2_048, + should_speculate: true, + max_cache_anchors: 4, + }) + } + cloud_llm_client::LanguageModelProvider::OpenAi + | cloud_llm_client::LanguageModelProvider::XAi + | cloud_llm_client::LanguageModelProvider::Google => None, + } + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + match self.model.provider { + 
cloud_llm_client::LanguageModelProvider::Anthropic => cx + .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) }) + .boxed(), + cloud_llm_client::LanguageModelProvider::OpenAi => { + let model = match open_ai::Model::from_id(&self.model.id.0) { + Ok(model) => model, + Err(err) => return async move { Err(anyhow!(err)) }.boxed(), + }; + cx.background_spawn(async move { count_open_ai_tokens(request, model) }) + .boxed() + } + cloud_llm_client::LanguageModelProvider::XAi => { + let model = match x_ai::Model::from_id(&self.model.id.0) { + Ok(model) => model, + Err(err) => return async move { Err(anyhow!(err)) }.boxed(), + }; + cx.background_spawn(async move { count_xai_tokens(request, model) }) + .boxed() + } + cloud_llm_client::LanguageModelProvider::Google => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let model_id = self.model.id.to_string(); + let generate_content_request = + into_google(request, model_id.clone(), GoogleModelMode::Default); + let auth_context = token_provider.auth_context(&cx.to_async()); + async move { + let token = token_provider.acquire_token(auth_context).await?; + + let request_body = CountTokensBody { + provider: cloud_llm_client::LanguageModelProvider::Google, + model: model_id, + provider_request: serde_json::to_value(&google_ai::CountTokensRequest { + generate_content_request, + })?, + }; + let request = http_client::Request::builder() + .method(Method::POST) + .uri( + http_client + .build_zed_llm_url("/count_tokens", &[])? 
+ .as_ref(), + ) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {token}")) + .body(serde_json::to_string(&request_body)?.into())?; + let mut response = http_client.send(request).await?; + let status = response.status(); + let headers = response.headers().clone(); + let mut response_body = String::new(); + response + .body_mut() + .read_to_string(&mut response_body) + .await?; + + if status.is_success() { + let response_body: CountTokensResponse = + serde_json::from_str(&response_body)?; + + Ok(response_body.tokens as u64) + } else { + Err(anyhow!(ApiError { + status, + body: response_body, + headers + })) + } + } + .boxed() + } + } + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { + let thread_id = request.thread_id.clone(); + let prompt_id = request.prompt_id.clone(); + let app_version = self.app_version.clone(); + let thinking_allowed = request.thinking_allowed; + let enable_thinking = thinking_allowed && self.model.supports_thinking; + let provider_name = provider_name(&self.model.provider); + match self.model.provider { + cloud_llm_client::LanguageModelProvider::Anthropic => { + let effort = request + .thinking_effort + .as_ref() + .and_then(|effort| anthropic::Effort::from_str(effort).ok()); + + let mut request = into_anthropic( + request, + self.model.id.to_string(), + 1.0, + self.model.max_output_tokens as u64, + if enable_thinking { + AnthropicModelMode::Thinking { + budget_tokens: Some(4_096), + } + } else { + AnthropicModelMode::Default + }, + ); + + if enable_thinking && effort.is_some() { + request.thinking = Some(anthropic::Thinking::Adaptive); + request.output_config = Some(anthropic::OutputConfig { effort }); + } + + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let auth_context = 
token_provider.auth_context(cx); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + &http_client, + &*token_provider, + auth_context, + app_version, + CompletionBody { + thread_id, + prompt_id, + provider: cloud_llm_client::LanguageModelProvider::Anthropic, + model: request.model.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await + .map_err(|err| match err.downcast::() { + Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)), + Err(err) => anyhow!(err), + })?; + + let mut mapper = AnthropicEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + cloud_llm_client::LanguageModelProvider::OpenAi => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let effort = request + .thinking_effort + .as_ref() + .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok()); + + let mut request = into_open_ai_response( + request, + &self.model.id.0, + self.model.supports_parallel_tool_calls, + true, + None, + None, + ); + + if enable_thinking && let Some(effort) = effort { + request.reasoning = Some(open_ai::responses::ReasoningConfig { + effort, + summary: Some(open_ai::responses::ReasoningSummaryMode::Auto), + }); + } + + let auth_context = token_provider.auth_context(cx); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + &http_client, + &*token_provider, + auth_context, + app_version, + CompletionBody { + thread_id, + prompt_id, + provider: cloud_llm_client::LanguageModelProvider::OpenAi, + model: request.model.clone(), + 
provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = OpenAiResponseEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + cloud_llm_client::LanguageModelProvider::XAi => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let request = into_open_ai( + request, + &self.model.id.0, + self.model.supports_parallel_tool_calls, + false, + None, + None, + ); + let auth_context = token_provider.auth_context(cx); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + &http_client, + &*token_provider, + auth_context, + app_version, + CompletionBody { + thread_id, + prompt_id, + provider: cloud_llm_client::LanguageModelProvider::XAi, + model: request.model.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = OpenAiEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + cloud_llm_client::LanguageModelProvider::Google => { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + let request = + into_google(request, self.model.id.to_string(), GoogleModelMode::Default); + let auth_context = token_provider.auth_context(cx); + let future = self.request_limiter.stream(async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + &http_client, + &*token_provider, + auth_context, + app_version, + 
CompletionBody { + thread_id, + prompt_id, + provider: cloud_llm_client::LanguageModelProvider::Google, + model: request.model.model_id.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = GoogleEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + } + } +} + +pub struct CloudModelProvider { + token_provider: Arc, + http_client: Arc, + app_version: Option, + models: Vec>, + default_model: Option>, + default_fast_model: Option>, + recommended_models: Vec>, +} + +impl CloudModelProvider { + pub fn new( + token_provider: Arc, + http_client: Arc, + app_version: Option, + ) -> Self { + Self { + token_provider, + http_client, + app_version, + models: Vec::new(), + default_model: None, + default_fast_model: None, + recommended_models: Vec::new(), + } + } + + pub fn refresh_models(&self, cx: &mut Context) -> Task> { + let http_client = self.http_client.clone(); + let token_provider = self.token_provider.clone(); + cx.spawn(async move |this, cx| { + let auth_context = token_provider.auth_context(cx); + let response = + Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?; + this.update(cx, |this, cx| { + this.update_models(response); + cx.notify(); + }) + }) + } + + async fn fetch_models_request( + http_client: &HttpClientWithUrl, + token_provider: &TP, + auth_context: TP::AuthContext, + ) -> Result { + let token = token_provider.acquire_token(auth_context).await?; + + let request = http_client::Request::builder() + .method(Method::GET) + .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true") + .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref()) + .header("Authorization", format!("Bearer {token}")) + .body(AsyncBody::empty())?; + let mut response = http_client + .send(request) + 
.await + .context("failed to send list models request")?; + + if response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + Ok(serde_json::from_str(&body)?) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + anyhow::bail!( + "error listing models.\nStatus: {:?}\nBody: {body}", + response.status(), + ); + } + } + + pub fn update_models(&mut self, response: ListModelsResponse) { + let models: Vec<_> = response.models.into_iter().map(Arc::new).collect(); + + self.default_model = models + .iter() + .find(|model| { + response + .default_model + .as_ref() + .is_some_and(|default_model_id| &model.id == default_model_id) + }) + .cloned(); + self.default_fast_model = models + .iter() + .find(|model| { + response + .default_fast_model + .as_ref() + .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id) + }) + .cloned(); + self.recommended_models = response + .recommended_models + .iter() + .filter_map(|id| models.iter().find(|model| &model.id == id)) + .cloned() + .collect(); + self.models = models; + } + + pub fn create_model( + &self, + model: &Arc, + ) -> Arc { + Arc::new(CloudLanguageModel:: { + id: LanguageModelId::from(model.id.0.to_string()), + model: model.clone(), + token_provider: self.token_provider.clone(), + http_client: self.http_client.clone(), + app_version: self.app_version.clone(), + request_limiter: RateLimiter::new(4), + }) + } + + pub fn models(&self) -> &[Arc] { + &self.models + } + + pub fn default_model(&self) -> Option<&Arc> { + self.default_model.as_ref() + } + + pub fn default_fast_model(&self) -> Option<&Arc> { + self.default_fast_model.as_ref() + } + + pub fn recommended_models(&self) -> &[Arc] { + &self.recommended_models + } +} + +pub fn map_cloud_completion_events( + stream: Pin>> + Send>>, + provider: &LanguageModelProviderName, + mut map_callback: F, +) -> BoxStream<'static, Result> +where + T: DeserializeOwned 
+ 'static, + F: FnMut(T) -> Vec> + + Send + + 'static, +{ + let provider = provider.clone(); + let mut stream = stream.fuse(); + + let mut saw_stream_ended = false; + + let mut done = false; + let mut pending = VecDeque::new(); + + stream::poll_fn(move |cx| { + loop { + if let Some(item) = pending.pop_front() { + return Poll::Ready(Some(item)); + } + + if done { + return Poll::Ready(None); + } + + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(event)) => { + let items = match event { + Err(error) => { + vec![Err(LanguageModelCompletionError::from(error))] + } + Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => { + saw_stream_ended = true; + vec![] + } + Ok(CompletionEvent::Status(status)) => { + LanguageModelCompletionEvent::from_completion_request_status( + status, + provider.clone(), + ) + .transpose() + .map(|event| vec![event]) + .unwrap_or_default() + } + Ok(CompletionEvent::Event(event)) => map_callback(event), + }; + pending.extend(items); + } + Poll::Ready(None) => { + done = true; + + if !saw_stream_ended { + return Poll::Ready(Some(Err( + LanguageModelCompletionError::StreamEndedUnexpectedly { + provider: provider.clone(), + }, + ))); + } + } + Poll::Pending => return Poll::Pending, + } + } + }) + .boxed() +} + +pub fn provider_name( + provider: &cloud_llm_client::LanguageModelProvider, +) -> LanguageModelProviderName { + match provider { + cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME, + cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME, + } +} + +pub fn response_lines( + response: Response, + includes_status_messages: bool, +) -> impl Stream>> { + futures::stream::try_unfold( + (String::new(), BufReader::new(response.into_body())), + move |(mut line, mut body)| async move { + match body.read_line(&mut line).await { + Ok(0) => 
Ok(None), + Ok(_) => { + let event = if includes_status_messages { + serde_json::from_str::>(&line)? + } else { + CompletionEvent::Event(serde_json::from_str::(&line)?) + }; + + line.clear(); + Ok(Some((event, (line, body)))) + } + Err(e) => Err(e.into()), + } + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use http_client::http::{HeaderMap, StatusCode}; + use language_model::LanguageModelCompletionError; + + #[test] + fn test_api_error_conversion_with_upstream_http_error() { + // upstream_http_error with 503 status should become ServerOverloaded + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. } => { + assert_eq!( + message, + "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 503, got: {:?}", + completion_error + ), + } + + // upstream_http_error with 500 status should become ApiInternalServerError + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. 
} => { + assert_eq!( + message, + "Received an error from the OpenAI API: internal server error" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 500, got: {:?}", + completion_error + ), + } + + // upstream_http_error with 429 status should become RateLimitExceeded + let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { message, .. } => { + assert_eq!( + message, + "Received an error from the Google API: rate limit exceeded" + ); + } + _ => panic!( + "Expected UpstreamProviderError for upstream 429, got: {:?}", + completion_error + ), + } + + // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed + let error_body = "Regular internal server error"; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::ApiInternalServerError { provider, message } => { + assert_eq!(provider, PROVIDER_NAME); + assert_eq!(message, "Regular internal server error"); + } + _ => panic!( + "Expected ApiInternalServerError for regular 500, got: {:?}", + completion_error + ), + } + + // upstream_http_429 format should be converted to UpstreamProviderError + let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let 
completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::UpstreamProviderError { + message, + status, + retry_after, + } => { + assert_eq!(message, "Upstream Anthropic rate limit exceeded."); + assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); + assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5))); + } + _ => panic!( + "Expected UpstreamProviderError for upstream_http_429, got: {:?}", + completion_error + ), + } + + // Invalid JSON in error body should fall back to regular error handling + let error_body = "Not JSON at all"; + + let api_error = ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: error_body.to_string(), + headers: HeaderMap::new(), + }; + + let completion_error: LanguageModelCompletionError = api_error.into(); + + match completion_error { + LanguageModelCompletionError::ApiInternalServerError { provider, .. } => { + assert_eq!(provider, PROVIDER_NAME); + } + _ => panic!( + "Expected ApiInternalServerError for invalid JSON, got: {:?}", + completion_error + ), + } + } +} diff --git a/crates/markdown/src/html/html_parser.rs b/crates/markdown/src/html/html_parser.rs index 20338ec2abef2314b7cd6ca91e45ee05be909745..8aa5da0cea7ea160721875fa889a720fe4c8bed1 100644 --- a/crates/markdown/src/html/html_parser.rs +++ b/crates/markdown/src/html/html_parser.rs @@ -1,6 +1,6 @@ use std::{cell::RefCell, collections::HashMap, mem, ops::Range}; -use gpui::{DefiniteLength, FontWeight, SharedString, px, relative}; +use gpui::{DefiniteLength, FontWeight, SharedString, TextAlign, px, relative}; use html5ever::{ Attribute, LocalName, ParseOpts, local_name, parse_document, tendril::TendrilSink, }; @@ -24,10 +24,17 @@ pub(crate) enum ParsedHtmlElement { List(ParsedHtmlList), Table(ParsedHtmlTable), BlockQuote(ParsedHtmlBlockQuote), - Paragraph(HtmlParagraph), + Paragraph(ParsedHtmlParagraph), Image(HtmlImage), } +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] 
+pub(crate) struct ParsedHtmlParagraph { + pub text_align: Option, + pub contents: HtmlParagraph, +} + impl ParsedHtmlElement { pub fn source_range(&self) -> Option> { Some(match self { @@ -35,7 +42,7 @@ impl ParsedHtmlElement { Self::List(list) => list.source_range.clone(), Self::Table(table) => table.source_range.clone(), Self::BlockQuote(block_quote) => block_quote.source_range.clone(), - Self::Paragraph(text) => match text.first()? { + Self::Paragraph(paragraph) => match paragraph.contents.first()? { HtmlParagraphChunk::Text(text) => text.source_range.clone(), HtmlParagraphChunk::Image(image) => image.source_range.clone(), }, @@ -83,6 +90,7 @@ pub(crate) struct ParsedHtmlHeading { pub source_range: Range, pub level: HeadingLevel, pub contents: HtmlParagraph, + pub text_align: Option, } #[derive(Debug, Clone)] @@ -236,20 +244,21 @@ fn parse_html_node( consume_children(source_range, node, elements, context); } NodeData::Text { contents } => { - elements.push(ParsedHtmlElement::Paragraph(vec![ - HtmlParagraphChunk::Text(ParsedHtmlText { + elements.push(ParsedHtmlElement::Paragraph(ParsedHtmlParagraph { + text_align: None, + contents: vec![HtmlParagraphChunk::Text(ParsedHtmlText { source_range, highlights: Vec::default(), links: Vec::default(), contents: contents.borrow().to_string().into(), - }), - ])); + })], + })); } NodeData::Comment { .. } => {} NodeData::Element { name, attrs, .. 
} => { - let mut styles = if let Some(styles) = - html_style_from_html_styles(extract_styles_from_attributes(attrs)) - { + let styles_map = extract_styles_from_attributes(attrs); + let text_align = text_align_from_attributes(attrs, &styles_map); + let mut styles = if let Some(styles) = html_style_from_html_styles(styles_map) { vec![styles] } else { Vec::default() @@ -270,7 +279,10 @@ fn parse_html_node( ); if !paragraph.is_empty() { - elements.push(ParsedHtmlElement::Paragraph(paragraph)); + elements.push(ParsedHtmlElement::Paragraph(ParsedHtmlParagraph { + text_align, + contents: paragraph, + })); } } else if matches!( name.local, @@ -303,6 +315,7 @@ fn parse_html_node( _ => unreachable!(), }, contents: paragraph, + text_align, })); } } else if name.local == local_name!("ul") || name.local == local_name!("ol") { @@ -589,6 +602,30 @@ fn html_style_from_html_styles(styles: HashMap) -> Option Option { + match value.trim().to_ascii_lowercase().as_str() { + "left" => Some(TextAlign::Left), + "center" => Some(TextAlign::Center), + "right" => Some(TextAlign::Right), + _ => None, + } +} + +fn text_align_from_styles(styles: &HashMap) -> Option { + styles + .get("text-align") + .and_then(|value| parse_text_align(value)) +} + +fn text_align_from_attributes( + attrs: &RefCell>, + styles: &HashMap, +) -> Option { + text_align_from_styles(styles).or_else(|| { + attr_value(attrs, local_name!("align")).and_then(|value| parse_text_align(&value)) + }) +} + fn extract_styles_from_attributes(attrs: &RefCell>) -> HashMap { let mut styles = HashMap::new(); @@ -770,6 +807,7 @@ fn extract_html_table(node: &Node, source_range: Range) -> Optionx

", 0..40).unwrap(); + let ParsedHtmlElement::Paragraph(paragraph) = &parsed.children[0] else { + panic!("expected paragraph"); + }; + assert_eq!(paragraph.text_align, Some(TextAlign::Center)); + } + + #[test] + fn parses_heading_text_align_from_style() { + let parsed = parse_html_block("

Title

", 0..45).unwrap(); + let ParsedHtmlElement::Heading(heading) = &parsed.children[0] else { + panic!("expected heading"); + }; + assert_eq!(heading.text_align, Some(TextAlign::Right)); + } + + #[test] + fn parses_paragraph_text_align_from_align_attribute() { + let parsed = parse_html_block("

x

", 0..24).unwrap(); + let ParsedHtmlElement::Paragraph(paragraph) = &parsed.children[0] else { + panic!("expected paragraph"); + }; + assert_eq!(paragraph.text_align, Some(TextAlign::Center)); + } + + #[test] + fn parses_heading_text_align_from_align_attribute() { + let parsed = parse_html_block("

Title

", 0..30).unwrap(); + let ParsedHtmlElement::Heading(heading) = &parsed.children[0] else { + panic!("expected heading"); + }; + assert_eq!(heading.text_align, Some(TextAlign::Right)); + } + + #[test] + fn prefers_style_text_align_over_align_attribute() { + let parsed = parse_html_block( + "

x

", + 0..50, + ) + .unwrap(); + let ParsedHtmlElement::Paragraph(paragraph) = &parsed.children[0] else { + panic!("expected paragraph"); + }; + assert_eq!(paragraph.text_align, Some(TextAlign::Center)); + } } diff --git a/crates/markdown/src/html/html_rendering.rs b/crates/markdown/src/html/html_rendering.rs index 103e2a6accb7dce9bc429419aafd27cbdf5080ce..6ae25eff0b4ba2ec8dedde8118ebd8d60e8fce7d 100644 --- a/crates/markdown/src/html/html_rendering.rs +++ b/crates/markdown/src/html/html_rendering.rs @@ -79,9 +79,20 @@ impl MarkdownElement { match element { ParsedHtmlElement::Paragraph(paragraph) => { - self.push_markdown_paragraph(builder, &source_range, markdown_end); - self.render_html_paragraph(paragraph, source_allocator, builder, cx, markdown_end); - builder.pop_div(); + self.push_markdown_paragraph( + builder, + &source_range, + markdown_end, + paragraph.text_align, + ); + self.render_html_paragraph( + ¶graph.contents, + source_allocator, + builder, + cx, + markdown_end, + ); + self.pop_markdown_paragraph(builder); } ParsedHtmlElement::Heading(heading) => { self.push_markdown_heading( @@ -89,6 +100,7 @@ impl MarkdownElement { heading.level, &heading.source_range, markdown_end, + heading.text_align, ); self.render_html_paragraph( &heading.contents, diff --git a/crates/markdown/src/markdown.rs b/crates/markdown/src/markdown.rs index 247c082d223005a7e0bd6d57696751ce76cc4d86..e6ad1b1f2ac9154eaabc6d18dbcb9c8695ae019d 100644 --- a/crates/markdown/src/markdown.rs +++ b/crates/markdown/src/markdown.rs @@ -36,8 +36,8 @@ use gpui::{ FocusHandle, Focusable, FontStyle, FontWeight, GlobalElementId, Hitbox, Hsla, Image, ImageFormat, ImageSource, KeyContext, Length, MouseButton, MouseDownEvent, MouseEvent, MouseMoveEvent, MouseUpEvent, Point, ScrollHandle, Stateful, StrikethroughStyle, - StyleRefinement, StyledText, Task, TextLayout, TextRun, TextStyle, TextStyleRefinement, - actions, img, point, quad, + StyleRefinement, StyledText, Task, TextAlign, TextLayout, TextRun, 
TextStyle, + TextStyleRefinement, actions, img, point, quad, }; use language::{CharClassifier, Language, LanguageRegistry, Rope}; use parser::CodeBlockMetadata; @@ -1025,8 +1025,17 @@ impl MarkdownElement { width: Option, height: Option, ) { + let align = builder.text_style().text_align; builder.modify_current_div(|el| { - el.items_center().flex().flex_row().child( + let mut image_container = el.flex().flex_row().items_center(); + + image_container = match align { + TextAlign::Left => image_container.justify_start(), + TextAlign::Center => image_container.justify_center(), + TextAlign::Right => image_container.justify_end(), + }; + + image_container.child( img(source) .max_w_full() .when_some(height, |this, height| this.h(height)) @@ -1041,14 +1050,29 @@ impl MarkdownElement { builder: &mut MarkdownElementBuilder, range: &Range, markdown_end: usize, + text_align_override: Option, ) { - builder.push_div( - div().when(!self.style.height_is_multiple_of_line_height, |el| { - el.mb_2().line_height(rems(1.3)) - }), - range, - markdown_end, - ); + let align = text_align_override.unwrap_or(self.style.base_text_style.text_align); + let mut paragraph = div().when(!self.style.height_is_multiple_of_line_height, |el| { + el.mb_2().line_height(rems(1.3)) + }); + + paragraph = match align { + TextAlign::Center => paragraph.text_center(), + TextAlign::Left => paragraph.text_left(), + TextAlign::Right => paragraph.text_right(), + }; + + builder.push_text_style(TextStyleRefinement { + text_align: Some(align), + ..Default::default() + }); + builder.push_div(paragraph, range, markdown_end); + } + + fn pop_markdown_paragraph(&self, builder: &mut MarkdownElementBuilder) { + builder.pop_div(); + builder.pop_text_style(); } fn push_markdown_heading( @@ -1057,15 +1081,26 @@ impl MarkdownElement { level: pulldown_cmark::HeadingLevel, range: &Range, markdown_end: usize, + text_align_override: Option, ) { + let align = text_align_override.unwrap_or(self.style.base_text_style.text_align); let 
mut heading = div().mb_2(); heading = apply_heading_style(heading, level, self.style.heading_level_styles.as_ref()); + heading = match align { + TextAlign::Center => heading.text_center(), + TextAlign::Left => heading.text_left(), + TextAlign::Right => heading.text_right(), + }; + let mut heading_style = self.style.heading.clone(); let heading_text_style = heading_style.text_style().clone(); heading.style().refine(&heading_style); - builder.push_text_style(heading_text_style); + builder.push_text_style(TextStyleRefinement { + text_align: Some(align), + ..heading_text_style + }); builder.push_div(heading, range, markdown_end); } @@ -1571,10 +1606,16 @@ impl Element for MarkdownElement { } } MarkdownTag::Paragraph => { - self.push_markdown_paragraph(&mut builder, range, markdown_end); + self.push_markdown_paragraph(&mut builder, range, markdown_end, None); } MarkdownTag::Heading { level, .. } => { - self.push_markdown_heading(&mut builder, *level, range, markdown_end); + self.push_markdown_heading( + &mut builder, + *level, + range, + markdown_end, + None, + ); } MarkdownTag::BlockQuote => { self.push_markdown_block_quote(&mut builder, range, markdown_end); @@ -1826,7 +1867,7 @@ impl Element for MarkdownElement { current_img_block_range.take(); } MarkdownTagEnd::Paragraph => { - builder.pop_div(); + self.pop_markdown_paragraph(&mut builder); } MarkdownTagEnd::Heading(_) => { self.pop_markdown_heading(&mut builder); diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index a54ff64af028f44adced1758933f794e9a002c5a..47c1288c8f9baeebf4afd54dd0597bfe5a41d15f 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -21,9 +21,9 @@ use itertools::Itertools; use language::{ AutoindentMode, Buffer, BufferChunks, BufferRow, BufferSnapshot, Capability, CharClassifier, CharKind, CharScopeContext, Chunk, CursorShape, DiagnosticEntryRef, File, IndentGuideSettings, - IndentSize, Language, 
LanguageScope, OffsetRangeExt, OffsetUtf16, Outline, OutlineItem, Point, - PointUtf16, Selection, TextDimension, TextObject, ToOffset as _, ToPoint as _, TransactionId, - TreeSitterOptions, Unclipped, + IndentSize, Language, LanguageAwareStyling, LanguageScope, OffsetRangeExt, OffsetUtf16, + Outline, OutlineItem, Point, PointUtf16, Selection, TextDimension, TextObject, ToOffset as _, + ToPoint as _, TransactionId, TreeSitterOptions, Unclipped, language_settings::{AllLanguageSettings, LanguageSettings}, }; @@ -1072,7 +1072,7 @@ pub struct MultiBufferChunks<'a> { range: Range, excerpt_offset_range: Range, excerpt_chunks: Option>, - language_aware: bool, + language_aware: LanguageAwareStyling, snapshot: &'a MultiBufferSnapshot, } @@ -3340,9 +3340,15 @@ impl EventEmitter for MultiBuffer {} impl MultiBufferSnapshot { pub fn text(&self) -> String { - self.chunks(MultiBufferOffset::ZERO..self.len(), false) - .map(|chunk| chunk.text) - .collect() + self.chunks( + MultiBufferOffset::ZERO..self.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ) + .map(|chunk| chunk.text) + .collect() } pub fn reversed_chars_at(&self, position: T) -> impl Iterator + '_ { @@ -3378,7 +3384,14 @@ impl MultiBufferSnapshot { } pub fn text_for_range(&self, range: Range) -> impl Iterator + '_ { - self.chunks(range, false).map(|chunk| chunk.text) + self.chunks( + range, + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ) + .map(|chunk| chunk.text) } pub fn is_line_blank(&self, row: MultiBufferRow) -> bool { @@ -4178,7 +4191,7 @@ impl MultiBufferSnapshot { pub fn chunks( &self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, ) -> MultiBufferChunks<'_> { let mut chunks = MultiBufferChunks { excerpt_offset_range: ExcerptDimension(MultiBufferOffset::ZERO) @@ -7227,7 +7240,7 @@ impl Excerpt { fn chunks_in_range<'a>( &'a self, range: Range, - language_aware: bool, + language_aware: LanguageAwareStyling, snapshot: &'a 
MultiBufferSnapshot, ) -> ExcerptChunks<'a> { let buffer = self.buffer_snapshot(snapshot); diff --git a/crates/multi_buffer/src/multi_buffer_tests.rs b/crates/multi_buffer/src/multi_buffer_tests.rs index bc904d1a05488ee365ebddf36c3b30accdfb9301..cebc9073e9d87a3c6eaf71d78e181d3e833ad56a 100644 --- a/crates/multi_buffer/src/multi_buffer_tests.rs +++ b/crates/multi_buffer/src/multi_buffer_tests.rs @@ -5039,7 +5039,13 @@ fn check_edits( fn assert_chunks_in_ranges(snapshot: &MultiBufferSnapshot) { let full_text = snapshot.text(); for ix in 0..full_text.len() { - let mut chunks = snapshot.chunks(MultiBufferOffset(0)..snapshot.len(), false); + let mut chunks = snapshot.chunks( + MultiBufferOffset(0)..snapshot.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ); chunks.seek(MultiBufferOffset(ix)..snapshot.len()); let tail = chunks.map(|chunk| chunk.text).collect::(); assert_eq!(tail, &full_text[ix..], "seek to range: {:?}", ix..); @@ -5300,7 +5306,13 @@ fn test_random_chunk_bitmaps(cx: &mut App, mut rng: StdRng) { let snapshot = multibuffer.read(cx).snapshot(cx); - let chunks = snapshot.chunks(MultiBufferOffset(0)..snapshot.len(), false); + let chunks = snapshot.chunks( + MultiBufferOffset(0)..snapshot.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ); for chunk in chunks { let chunk_text = chunk.text; @@ -5466,7 +5478,13 @@ fn test_random_chunk_bitmaps_with_diffs(cx: &mut App, mut rng: StdRng) { let snapshot = multibuffer.read(cx).snapshot(cx); - let chunks = snapshot.chunks(MultiBufferOffset(0)..snapshot.len(), false); + let chunks = snapshot.chunks( + MultiBufferOffset(0)..snapshot.len(), + LanguageAwareStyling { + tree_sitter: false, + diagnostics: false, + }, + ); for chunk in chunks { let chunk_text = chunk.text; diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml index 3de3a4dc3fcb8c9519f4c67be7cead75401f6281..9a73e73196fa225691fa68e2ca839a19783bc3ca 100644 --- 
a/crates/open_ai/Cargo.toml +++ b/crates/open_ai/Cargo.toml @@ -17,13 +17,18 @@ schemars = ["dep:schemars"] [dependencies] anyhow.workspace = true +collections.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true rand.workspace = true schemars = { workspace = true, optional = true } log.workspace = true serde.workspace = true serde_json.workspace = true -settings.workspace = true strum.workspace = true thiserror.workspace = true +tiktoken-rs.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/open_ai/src/completion.rs b/crates/open_ai/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..b1b9b58e70a254f5b259cfb0111674fd81d4b82f --- /dev/null +++ b/crates/open_ai/src/completion.rs @@ -0,0 +1,1696 @@ +use anyhow::{Result, anyhow}; +use collections::HashMap; +use futures::{Stream, StreamExt}; +use language_model_core::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice, + LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, + Role, StopReason, TokenUsage, + util::{fix_streamed_json, parse_tool_arguments}, +}; +use std::pin::Pin; +use std::sync::Arc; + +use crate::responses::{ + Request as ResponseRequest, ResponseFunctionCallItem, ResponseFunctionCallOutputContent, + ResponseFunctionCallOutputItem, ResponseInputContent, ResponseInputItem, ResponseMessageItem, + ResponseOutputItem, ResponseSummary as ResponsesSummary, ResponseUsage as ResponsesUsage, + StreamEvent as ResponsesStreamEvent, +}; +use crate::{ + FunctionContent, FunctionDefinition, ImageUrl, MessagePart, Model, ReasoningEffort, + ResponseStreamEvent, ToolCall, ToolCallContent, +}; + +pub fn into_open_ai( + request: LanguageModelRequest, + model_id: &str, + supports_parallel_tool_calls: bool, + supports_prompt_cache_key: bool, + 
max_output_tokens: Option, + reasoning_effort: Option, +) -> crate::Request { + let stream = !model_id.starts_with("o1-"); + + let mut messages = Vec::new(); + for message in request.messages { + for content in message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { + let should_add = if message.role == Role::User { + // Including whitespace-only user messages can cause error with OpenAI compatible APIs + // See https://github.com/zed-industries/zed/issues/40097 + !text.trim().is_empty() + } else { + !text.is_empty() + }; + if should_add { + add_message_content_part( + MessagePart::Text { text }, + message.role, + &mut messages, + ); + } + } + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(image) => { + add_message_content_part( + MessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + detail: None, + }, + }, + message.role, + &mut messages, + ); + } + MessageContent::ToolUse(tool_use) => { + let tool_call = ToolCall { + id: tool_use.id.to_string(), + content: ToolCallContent::Function { + function: FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + }, + }, + }; + + if let Some(crate::RequestMessage::Assistant { tool_calls, .. 
}) = + messages.last_mut() + { + tool_calls.push(tool_call); + } else { + messages.push(crate::RequestMessage::Assistant { + content: None, + tool_calls: vec![tool_call], + }); + } + } + MessageContent::ToolResult(tool_result) => { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + vec![MessagePart::Text { + text: text.to_string(), + }] + } + LanguageModelToolResultContent::Image(image) => { + vec![MessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + detail: None, + }, + }] + } + }; + + messages.push(crate::RequestMessage::Tool { + content: content.into(), + tool_call_id: tool_result.tool_use_id.to_string(), + }); + } + } + } + } + + crate::Request { + model: model_id.into(), + messages, + stream, + stream_options: if stream { + Some(crate::StreamOptions::default()) + } else { + None + }, + stop: request.stop, + temperature: request.temperature.or(Some(1.0)), + max_completion_tokens: max_output_tokens, + parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() { + Some(supports_parallel_tool_calls) + } else { + None + }, + prompt_cache_key: if supports_prompt_cache_key { + request.thread_id + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| crate::ToolDefinition::Function { + function: FunctionDefinition { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => crate::ToolChoice::Auto, + LanguageModelToolChoice::Any => crate::ToolChoice::Required, + LanguageModelToolChoice::None => crate::ToolChoice::None, + }), + reasoning_effort, + } +} + +pub fn into_open_ai_response( + request: LanguageModelRequest, + model_id: &str, + supports_parallel_tool_calls: bool, + supports_prompt_cache_key: bool, + max_output_tokens: Option, + reasoning_effort: Option, +) -> ResponseRequest { + let stream 
= !model_id.starts_with("o1-"); + + let LanguageModelRequest { + thread_id, + prompt_id: _, + intent: _, + messages, + tools, + tool_choice, + stop: _, + temperature, + thinking_allowed: _, + thinking_effort: _, + speed: _, + } = request; + + let mut input_items = Vec::new(); + for (index, message) in messages.into_iter().enumerate() { + append_message_to_response_items(message, index, &mut input_items); + } + + let tools: Vec<_> = tools + .into_iter() + .map(|tool| crate::responses::ToolDefinition::Function { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + strict: None, + }) + .collect(); + + ResponseRequest { + model: model_id.into(), + instructions: None, + input: input_items, + stream, + temperature, + top_p: None, + max_output_tokens, + parallel_tool_calls: if tools.is_empty() { + None + } else { + Some(supports_parallel_tool_calls) + }, + tool_choice: tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => crate::ToolChoice::Auto, + LanguageModelToolChoice::Any => crate::ToolChoice::Required, + LanguageModelToolChoice::None => crate::ToolChoice::None, + }), + tools, + prompt_cache_key: if supports_prompt_cache_key { + thread_id + } else { + None + }, + reasoning: reasoning_effort.map(|effort| crate::responses::ReasoningConfig { + effort, + summary: Some(crate::responses::ReasoningSummaryMode::Auto), + }), + store: None, + } +} + +fn append_message_to_response_items( + message: LanguageModelRequestMessage, + index: usize, + input_items: &mut Vec, +) { + let mut content_parts: Vec = Vec::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + push_response_text_part(&message.role, text, &mut content_parts); + } + MessageContent::Thinking { text, .. 
} => { + push_response_text_part(&message.role, text, &mut content_parts); + } + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(image) => { + push_response_image_part(&message.role, image, &mut content_parts); + } + MessageContent::ToolUse(tool_use) => { + flush_response_parts(&message.role, index, &mut content_parts, input_items); + let call_id = tool_use.id.to_string(); + input_items.push(ResponseInputItem::FunctionCall(ResponseFunctionCallItem { + call_id, + name: tool_use.name.to_string(), + arguments: tool_use.raw_input, + })); + } + MessageContent::ToolResult(tool_result) => { + flush_response_parts(&message.role, index, &mut content_parts, input_items); + input_items.push(ResponseInputItem::FunctionCallOutput( + ResponseFunctionCallOutputItem { + call_id: tool_result.tool_use_id.to_string(), + output: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ResponseFunctionCallOutputContent::Text(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ResponseFunctionCallOutputContent::List(vec![ + ResponseInputContent::Image { + image_url: image.to_base64_url(), + }, + ]) + } + }, + }, + )); + } + } + } + + flush_response_parts(&message.role, index, &mut content_parts, input_items); +} + +fn push_response_text_part( + role: &Role, + text: impl Into, + parts: &mut Vec, +) { + let text = text.into(); + if text.trim().is_empty() { + return; + } + + match role { + Role::Assistant => parts.push(ResponseInputContent::OutputText { + text, + annotations: Vec::new(), + }), + _ => parts.push(ResponseInputContent::Text { text }), + } +} + +fn push_response_image_part( + role: &Role, + image: LanguageModelImage, + parts: &mut Vec, +) { + match role { + Role::Assistant => parts.push(ResponseInputContent::OutputText { + text: "[image omitted]".to_string(), + annotations: Vec::new(), + }), + _ => parts.push(ResponseInputContent::Image { + image_url: image.to_base64_url(), + }), + } +} + +fn 
flush_response_parts( + role: &Role, + _index: usize, + parts: &mut Vec, + input_items: &mut Vec, +) { + if parts.is_empty() { + return; + } + + let item = ResponseInputItem::Message(ResponseMessageItem { + role: match role { + Role::User => crate::Role::User, + Role::Assistant => crate::Role::Assistant, + Role::System => crate::Role::System, + }, + content: parts.clone(), + }); + + input_items.push(item); + parts.clear(); +} + +fn add_message_content_part( + new_part: MessagePart, + role: Role, + messages: &mut Vec, +) { + match (role, messages.last_mut()) { + (Role::User, Some(crate::RequestMessage::User { content })) + | ( + Role::Assistant, + Some(crate::RequestMessage::Assistant { + content: Some(content), + .. + }), + ) + | (Role::System, Some(crate::RequestMessage::System { content, .. })) => { + content.push_part(new_part); + } + _ => { + messages.push(match role { + Role::User => crate::RequestMessage::User { + content: crate::MessageContent::from(vec![new_part]), + }, + Role::Assistant => crate::RequestMessage::Assistant { + content: Some(crate::MessageContent::from(vec![new_part])), + tool_calls: Vec::new(), + }, + Role::System => crate::RequestMessage::System { + content: crate::MessageContent::from(vec![new_part]), + }, + }); + } + } +} + +pub struct OpenAiEventMapper { + tool_calls_by_index: HashMap, +} + +impl OpenAiEventMapper { + pub fn new() -> Self { + Self { + tool_calls_by_index: HashMap::default(), + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], + }) + }) + } + + pub fn map_event( + &mut self, + event: ResponseStreamEvent, + ) -> Vec> { + let mut events = Vec::new(); + if let Some(usage) = event.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.prompt_tokens, + 
output_tokens: usage.completion_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }))); + } + + let Some(choice) = event.choices.first() else { + return events; + }; + + if let Some(delta) = choice.delta.as_ref() { + if let Some(reasoning_content) = delta.reasoning_content.clone() { + if !reasoning_content.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text: reasoning_content, + signature: None, + })); + } + } + if let Some(content) = delta.content.clone() { + if !content.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + } + + if let Some(tool_calls) = delta.tool_calls.as_ref() { + for tool_call in tool_calls { + let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; + } + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); + } + } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &fix_streamed_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))); + } + } + } + } + } + + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + Some("tool_calls") => { + events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { + match parse_tool_arguments(&tool_call.arguments) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + 
input, + raw_input: tool_call.arguments.clone(), + thought_signature: None, + }, + )), + Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_call.id.into(), + tool_name: tool_call.name.into(), + raw_input: tool_call.arguments.clone().into(), + json_parse_error: error.to_string(), + }), + } + })); + + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + Some(stop_reason) => { + log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + None => {} + } + + events + } +} + +#[derive(Default)] +struct RawToolCall { + id: String, + name: String, + arguments: String, +} + +pub struct OpenAiResponseEventMapper { + function_calls_by_item: HashMap, + pending_stop_reason: Option, +} + +#[derive(Default)] +struct PendingResponseFunctionCall { + call_id: String, + name: Arc, + arguments: String, +} + +impl OpenAiResponseEventMapper { + pub fn new() -> Self { + Self { + function_calls_by_item: HashMap::default(), + pending_stop_reason: None, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], + }) + }) + } + + pub fn map_event( + &mut self, + event: ResponsesStreamEvent, + ) -> Vec> { + match event { + ResponsesStreamEvent::OutputItemAdded { item, .. 
} => { + let mut events = Vec::new(); + + match &item { + ResponseOutputItem::Message(message) => { + if let Some(id) = &message.id { + events.push(Ok(LanguageModelCompletionEvent::StartMessage { + message_id: id.clone(), + })); + } + } + ResponseOutputItem::FunctionCall(function_call) => { + if let Some(item_id) = function_call.id.clone() { + let call_id = function_call + .call_id + .clone() + .or_else(|| function_call.id.clone()) + .unwrap_or_else(|| item_id.clone()); + let entry = PendingResponseFunctionCall { + call_id, + name: Arc::::from( + function_call.name.clone().unwrap_or_default(), + ), + arguments: function_call.arguments.clone(), + }; + self.function_calls_by_item.insert(item_id, entry); + } + } + ResponseOutputItem::Reasoning(_) | ResponseOutputItem::Unknown => {} + } + events + } + ResponsesStreamEvent::ReasoningSummaryTextDelta { delta, .. } => { + if delta.is_empty() { + Vec::new() + } else { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: delta, + signature: None, + })] + } + } + ResponsesStreamEvent::OutputTextDelta { delta, .. } => { + if delta.is_empty() { + Vec::new() + } else { + vec![Ok(LanguageModelCompletionEvent::Text(delta))] + } + } + ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { + if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { + entry.arguments.push_str(&delta); + if let Ok(input) = serde_json::from_str::( + &fix_streamed_json(&entry.arguments), + ) { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + name: entry.name.clone(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))]; + } + } + Vec::new() + } + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id, arguments, .. 
+ } => { + if let Some(mut entry) = self.function_calls_by_item.remove(&item_id) { + if !arguments.is_empty() { + entry.arguments = arguments; + } + let raw_input = entry.arguments.clone(); + self.pending_stop_reason = Some(StopReason::ToolUse); + match parse_tool_arguments(&entry.arguments) { + Ok(input) => { + vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + name: entry.name.clone(), + is_input_complete: true, + input, + raw_input, + thought_signature: None, + }, + ))] + } + Err(error) => { + vec![Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + tool_name: entry.name.clone(), + raw_input: Arc::::from(raw_input), + json_parse_error: error.to_string(), + })] + } + } + } else { + Vec::new() + } + } + ResponsesStreamEvent::Completed { response } => { + self.handle_completion(response, StopReason::EndTurn) + } + ResponsesStreamEvent::Incomplete { response } => { + let reason = response + .status_details + .as_ref() + .and_then(|details| details.reason.as_deref()); + let stop_reason = match reason { + Some("max_output_tokens") => StopReason::MaxTokens, + Some("content_filter") => { + self.pending_stop_reason = Some(StopReason::Refusal); + StopReason::Refusal + } + _ => self + .pending_stop_reason + .take() + .unwrap_or(StopReason::EndTurn), + }; + + let mut events = Vec::new(); + if self.pending_stop_reason.is_none() { + events.extend(self.emit_tool_calls_from_output(&response.output)); + } + if let Some(usage) = response.usage.as_ref() { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + token_usage_from_response_usage(usage), + ))); + } + events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); + events + } + ResponsesStreamEvent::Failed { response } => { + let message = response + .status_details + .and_then(|details| details.error) + .map(|error| error.to_string()) + .unwrap_or_else(|| "response 
failed".to_string()); + vec![Err(LanguageModelCompletionError::Other(anyhow!(message)))] + } + ResponsesStreamEvent::Error { error } + | ResponsesStreamEvent::GenericError { error } => { + vec![Err(LanguageModelCompletionError::Other(anyhow!( + error.message + )))] + } + ResponsesStreamEvent::ReasoningSummaryPartAdded { summary_index, .. } => { + if summary_index > 0 { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "\n\n".to_string(), + signature: None, + })] + } else { + Vec::new() + } + } + ResponsesStreamEvent::OutputTextDone { .. } + | ResponsesStreamEvent::OutputItemDone { .. } + | ResponsesStreamEvent::ContentPartAdded { .. } + | ResponsesStreamEvent::ContentPartDone { .. } + | ResponsesStreamEvent::ReasoningSummaryTextDone { .. } + | ResponsesStreamEvent::ReasoningSummaryPartDone { .. } + | ResponsesStreamEvent::Created { .. } + | ResponsesStreamEvent::InProgress { .. } + | ResponsesStreamEvent::Unknown => Vec::new(), + } + } + + fn handle_completion( + &mut self, + response: ResponsesSummary, + default_reason: StopReason, + ) -> Vec> { + let mut events = Vec::new(); + + if self.pending_stop_reason.is_none() { + events.extend(self.emit_tool_calls_from_output(&response.output)); + } + + if let Some(usage) = response.usage.as_ref() { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + token_usage_from_response_usage(usage), + ))); + } + + let stop_reason = self.pending_stop_reason.take().unwrap_or(default_reason); + events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); + events + } + + fn emit_tool_calls_from_output( + &mut self, + output: &[ResponseOutputItem], + ) -> Vec> { + let mut events = Vec::new(); + for item in output { + if let ResponseOutputItem::FunctionCall(function_call) = item { + let Some(call_id) = function_call + .call_id + .clone() + .or_else(|| function_call.id.clone()) + else { + log::error!( + "Function call item missing both call_id and id: {:?}", + function_call + ); + continue; + }; + let name: Arc = 
Arc::from(function_call.name.clone().unwrap_or_default()); + let arguments = &function_call.arguments; + self.pending_stop_reason = Some(StopReason::ToolUse); + match parse_tool_arguments(arguments) { + Ok(input) => { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(call_id.clone()), + name: name.clone(), + is_input_complete: true, + input, + raw_input: arguments.clone(), + thought_signature: None, + }, + ))); + } + Err(error) => { + events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: LanguageModelToolUseId::from(call_id.clone()), + tool_name: name.clone(), + raw_input: Arc::::from(arguments.clone()), + json_parse_error: error.to_string(), + })); + } + } + } + } + events + } +} + +fn token_usage_from_response_usage(usage: &ResponsesUsage) -> TokenUsage { + TokenUsage { + input_tokens: usage.input_tokens.unwrap_or_default(), + output_tokens: usage.output_tokens.unwrap_or_default(), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + } +} + +pub fn collect_tiktoken_messages( + request: LanguageModelRequest, +) -> Vec { + request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>() +} + +/// Count tokens for an OpenAI model. This is synchronous; callers should spawn +/// it on a background thread if needed. +pub fn count_open_ai_tokens(request: LanguageModelRequest, model: Model) -> Result { + let messages = collect_tiktoken_messages(request); + match model { + Model::Custom { max_tokens, .. 
} => { + let model = if max_tokens >= 100_000 { + // If the max tokens is 100k or more, it likely uses the o200k_base tokenizer + "gpt-4o" + } else { + // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are + // supported with this tiktoken method + "gpt-4" + }; + tiktoken_rs::num_tokens_from_messages(model, &messages) + } + // Currently supported by tiktoken_rs + // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch + // arm with an override. We enumerate all supported models here so that we can check if new + // models are supported yet or not. + Model::ThreePointFiveTurbo + | Model::Four + | Model::FourTurbo + | Model::FourOmniMini + | Model::FourPointOneNano + | Model::O1 + | Model::O3 + | Model::O3Mini + | Model::O4Mini + | Model::Five + | Model::FiveCodex + | Model::FiveMini + | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), + // GPT-5.1, 5.2, 5.2-codex, 5.3-codex, 5.4, and 5.4-pro don't have dedicated tiktoken support; use gpt-5 tokenizer + Model::FivePointOne + | Model::FivePointTwo + | Model::FivePointTwoCodex + | Model::FivePointThreeCodex + | Model::FivePointFour + | Model::FivePointFourPro => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages), + } + .map(|tokens| tokens as u64) +} + +#[cfg(test)] +mod tests { + use crate::responses::{ + ReasoningSummaryPart, ResponseFunctionToolCall, ResponseOutputItem, ResponseOutputMessage, + ResponseReasoningItem, ResponseStatusDetails, ResponseSummary, ResponseUsage, + StreamEvent as ResponsesStreamEvent, + }; + use futures::{StreamExt, executor::block_on}; + use language_model_core::{ + LanguageModelImage, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse, + LanguageModelToolUseId, SharedString, + }; + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::*; + + fn map_response_events(events: Vec) -> Vec { + 
block_on(async { + OpenAiResponseEventMapper::new() + .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok)))) + .collect::>() + .await + .into_iter() + .map(Result::unwrap) + .collect() + }) + } + + fn response_item_message(id: &str) -> ResponseOutputItem { + ResponseOutputItem::Message(ResponseOutputMessage { + id: Some(id.to_string()), + role: Some("assistant".to_string()), + status: Some("in_progress".to_string()), + content: vec![], + }) + } + + fn response_item_function_call(id: &str, args: Option<&str>) -> ResponseOutputItem { + ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { + id: Some(id.to_string()), + status: Some("in_progress".to_string()), + name: Some("get_weather".to_string()), + call_id: Some("call_123".to_string()), + arguments: args.map(|s| s.to_string()).unwrap_or_default(), + }) + } + + #[test] + fn tiktoken_rs_support() { + let request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + intent: None, + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text("message".into())], + cache: false, + reasoning_details: None, + }], + tools: vec![], + tool_choice: None, + stop: vec![], + temperature: None, + thinking_allowed: true, + thinking_effort: None, + speed: None, + }; + + // Validate that all models are supported by tiktoken-rs + for model in ::iter() { + let count = count_open_ai_tokens(request.clone(), model).unwrap(); + assert!(count > 0); + } + } + + #[test] + fn responses_stream_maps_text_and_usage() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_message("msg_123"), + }, + ResponsesStreamEvent::OutputTextDelta { + item_id: "msg_123".into(), + output_index: 0, + content_index: Some(0), + delta: "Hello".into(), + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary { + usage: Some(ResponseUsage { + input_tokens: Some(5), + output_tokens: Some(3), + 
total_tokens: Some(8), + }), + ..Default::default() + }, + }, + ]; + + let mapped = map_response_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_123" + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Text(ref text) if text == "Hello" + )); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: 5, + output_tokens: 3, + .. + }) + )); + assert!(matches!( + mapped[3], + LanguageModelCompletionEvent::Stop(StopReason::EndTurn) + )); + } + + #[test] + fn into_open_ai_response_builds_complete_payload() { + let tool_call_id = LanguageModelToolUseId::from("call-42"); + let tool_input = json!({ "city": "Boston" }); + let tool_arguments = serde_json::to_string(&tool_input).unwrap(); + let tool_use = LanguageModelToolUse { + id: tool_call_id.clone(), + name: Arc::from("get_weather"), + raw_input: tool_arguments.clone(), + input: tool_input, + is_input_complete: true, + thought_signature: None, + }; + let tool_result = LanguageModelToolResult { + tool_use_id: tool_call_id, + tool_name: Arc::from("get_weather"), + is_error: false, + content: LanguageModelToolResultContent::Text(Arc::from("Sunny")), + output: Some(json!({ "forecast": "Sunny" })), + }; + let user_image = LanguageModelImage { + source: SharedString::from("aGVsbG8="), + size: None, + }; + let expected_image_url = user_image.to_base64_url(); + + let request = LanguageModelRequest { + thread_id: Some("thread-123".into()), + prompt_id: None, + intent: None, + messages: vec![ + LanguageModelRequestMessage { + role: Role::System, + content: vec![MessageContent::Text("System context".into())], + cache: false, + reasoning_details: None, + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![ + MessageContent::Text("Please check the weather.".into()), + MessageContent::Image(user_image), + ], + cache: false, + reasoning_details: None, + }, + 
LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![ + MessageContent::Text("Looking that up.".into()), + MessageContent::ToolUse(tool_use), + ], + cache: false, + reasoning_details: None, + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolResult(tool_result)], + cache: false, + reasoning_details: None, + }, + ], + tools: vec![LanguageModelRequestTool { + name: "get_weather".into(), + description: "Fetches the weather".into(), + input_schema: json!({ "type": "object" }), + use_input_streaming: false, + }], + tool_choice: Some(LanguageModelToolChoice::Any), + stop: vec!["".into()], + temperature: None, + thinking_allowed: false, + thinking_effort: None, + speed: None, + }; + + let response = into_open_ai_response( + request, + "custom-model", + true, + true, + Some(2048), + Some(ReasoningEffort::Low), + ); + + let serialized = serde_json::to_value(&response).unwrap(); + let expected = json!({ + "model": "custom-model", + "input": [ + { + "type": "message", + "role": "system", + "content": [ + { "type": "input_text", "text": "System context" } + ] + }, + { + "type": "message", + "role": "user", + "content": [ + { "type": "input_text", "text": "Please check the weather." 
}, + { "type": "input_image", "image_url": expected_image_url } + ] + }, + { + "type": "message", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "Looking that up.", "annotations": [] } + ] + }, + { + "type": "function_call", + "call_id": "call-42", + "name": "get_weather", + "arguments": tool_arguments + }, + { + "type": "function_call_output", + "call_id": "call-42", + "output": "Sunny" + } + ], + "stream": true, + "max_output_tokens": 2048, + "parallel_tool_calls": true, + "tool_choice": "required", + "tools": [ + { + "type": "function", + "name": "get_weather", + "description": "Fetches the weather", + "parameters": { "type": "object" } + } + ], + "prompt_cache_key": "thread-123", + "reasoning": { "effort": "low", "summary": "auto" } + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn responses_stream_maps_tool_calls() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{\"city\":\"Bos")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "ton\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{\"city\":\"Boston\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: false, + .. + }) + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + ref id, + ref name, + ref raw_input, + is_input_complete: true, + .. 
+ }) if id.to_string() == "call_123" + && name.as_ref() == "get_weather" + && raw_input == "{\"city\":\"Boston\"}" + )); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_uses_max_tokens_stop_reason() { + let events = vec![ResponsesStreamEvent::Incomplete { + response: ResponseSummary { + status_details: Some(ResponseStatusDetails { + reason: Some("max_output_tokens".into()), + r#type: Some("incomplete".into()), + error: None, + }), + usage: Some(ResponseUsage { + input_tokens: Some(10), + output_tokens: Some(20), + total_tokens: Some(30), + }), + ..Default::default() + }, + }]; + + let mapped = map_response_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: 10, + output_tokens: 20, + .. + }) + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) + )); + } + + #[test] + fn responses_stream_handles_multiple_tool_calls() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn1", Some("{\"city\":\"NYC\"}")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn1".into(), + output_index: 0, + arguments: "{\"city\":\"NYC\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::OutputItemAdded { + output_index: 1, + sequence_number: None, + item: response_item_function_call("item_fn2", Some("{\"city\":\"LA\"}")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn2".into(), + output_index: 1, + arguments: "{\"city\":\"LA\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref 
raw_input, .. }) + if raw_input == "{\"city\":\"NYC\"}" + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) + if raw_input == "{\"city\":\"LA\"}" + )); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_handles_mixed_text_and_tool_calls() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_message("msg_123"), + }, + ResponsesStreamEvent::OutputTextDelta { + item_id: "msg_123".into(), + output_index: 0, + content_index: Some(0), + delta: "Let me check that".into(), + }, + ResponsesStreamEvent::OutputItemAdded { + output_index: 1, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{\"query\":\"test\"}")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 1, + arguments: "{\"query\":\"test\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::StartMessage { .. } + )); + assert!( + matches!(mapped[1], LanguageModelCompletionEvent::Text(ref text) if text == "Let me check that") + ); + assert!( + matches!(mapped[2], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. 
}) if raw_input == "{\"query\":\"test\"}") + ); + assert!(matches!( + mapped[3], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_handles_json_parse_error() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{invalid json")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{invalid json".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUseJsonParseError { ref raw_input, .. } + if raw_input.as_ref() == "{invalid json" + )); + } + + #[test] + fn responses_stream_handles_incomplete_function_call() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{\"city\":")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "\"Boston\"".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Incomplete { + response: ResponseSummary { + status_details: Some(ResponseStatusDetails { + reason: Some("max_output_tokens".into()), + r#type: Some("incomplete".into()), + error: None, + }), + output: vec![response_item_function_call( + "item_fn", + Some("{\"city\":\"Boston\"}"), + )], + ..Default::default() + }, + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: false, + .. + }) + )); + assert!( + matches!(mapped[1], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, is_input_complete: true, .. 
}) if raw_input == "{\"city\":\"Boston\"}") + ); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) + )); + } + + #[test] + fn responses_stream_incomplete_does_not_duplicate_tool_calls() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("{\"city\":\"Boston\"}")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{\"city\":\"Boston\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Incomplete { + response: ResponseSummary { + status_details: Some(ResponseStatusDetails { + reason: Some("max_output_tokens".into()), + r#type: Some("incomplete".into()), + error: None, + }), + output: vec![response_item_function_call( + "item_fn", + Some("{\"city\":\"Boston\"}"), + )], + ..Default::default() + }, + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 2); + assert!( + matches!(mapped[0], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) if raw_input == "{\"city\":\"Boston\"}") + ); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) + )); + } + + #[test] + fn responses_stream_handles_empty_tool_arguments() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: response_item_function_call("item_fn", Some("")), + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert_eq!(mapped.len(), 2); + assert!(matches!( + &mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + id, name, raw_input, input, .. 
+ }) if id.to_string() == "call_123" + && name.as_ref() == "get_weather" + && raw_input == "" + && input.is_object() + && input.as_object().unwrap().is_empty() + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_emits_partial_tool_use_events() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::FunctionCall( + crate::responses::ResponseFunctionToolCall { + id: Some("item_fn".to_string()), + status: Some("in_progress".to_string()), + name: Some("get_weather".to_string()), + call_id: Some("call_abc".to_string()), + arguments: String::new(), + }, + ), + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "{\"city\":\"Bos".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "ton\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{\"city\":\"Boston\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert!(mapped.len() >= 3); + + let complete_tool_use = mapped.iter().find(|e| { + matches!( + e, + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: true, + .. 
+ }) + ) + }); + assert!( + complete_tool_use.is_some(), + "should have a complete tool use event" + ); + + let tool_uses: Vec<_> = mapped + .iter() + .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_))) + .collect(); + assert!( + tool_uses.len() >= 2, + "should have at least one partial and one complete event" + ); + assert!(matches!( + tool_uses.last().unwrap(), + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: true, + .. + }) + )); + } + + #[test] + fn responses_stream_maps_reasoning_summary_deltas() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_123".into()), + summary: vec![], + }), + }, + ResponsesStreamEvent::ReasoningSummaryPartAdded { + item_id: "rs_123".into(), + output_index: 0, + summary_index: 0, + }, + ResponsesStreamEvent::ReasoningSummaryTextDelta { + item_id: "rs_123".into(), + output_index: 0, + delta: "Thinking about".into(), + }, + ResponsesStreamEvent::ReasoningSummaryTextDelta { + item_id: "rs_123".into(), + output_index: 0, + delta: " the answer".into(), + }, + ResponsesStreamEvent::ReasoningSummaryTextDone { + item_id: "rs_123".into(), + output_index: 0, + text: "Thinking about the answer".into(), + }, + ResponsesStreamEvent::ReasoningSummaryPartDone { + item_id: "rs_123".into(), + output_index: 0, + summary_index: 0, + }, + ResponsesStreamEvent::ReasoningSummaryPartAdded { + item_id: "rs_123".into(), + output_index: 0, + summary_index: 1, + }, + ResponsesStreamEvent::ReasoningSummaryTextDelta { + item_id: "rs_123".into(), + output_index: 0, + delta: "Second part".into(), + }, + ResponsesStreamEvent::ReasoningSummaryTextDone { + item_id: "rs_123".into(), + output_index: 0, + text: "Second part".into(), + }, + ResponsesStreamEvent::ReasoningSummaryPartDone { + item_id: "rs_123".into(), + output_index: 0, + summary_index: 1, + }, + 
ResponsesStreamEvent::OutputItemDone { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_123".into()), + summary: vec![ + ReasoningSummaryPart::SummaryText { + text: "Thinking about the answer".into(), + }, + ReasoningSummaryPart::SummaryText { + text: "Second part".into(), + }, + ], + }), + }, + ResponsesStreamEvent::OutputItemAdded { + output_index: 1, + sequence_number: None, + item: response_item_message("msg_456"), + }, + ResponsesStreamEvent::OutputTextDelta { + item_id: "msg_456".into(), + output_index: 1, + content_index: Some(0), + delta: "The answer is 42".into(), + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + + let thinking_events: Vec<_> = mapped + .iter() + .filter(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })) + .collect(); + assert_eq!( + thinking_events.len(), + 4, + "expected 4 thinking events, got {:?}", + thinking_events + ); + assert!( + matches!(&thinking_events[0], LanguageModelCompletionEvent::Thinking { text, .. } if text == "Thinking about") + ); + assert!( + matches!(&thinking_events[1], LanguageModelCompletionEvent::Thinking { text, .. } if text == " the answer") + ); + assert!( + matches!(&thinking_events[2], LanguageModelCompletionEvent::Thinking { text, .. } if text == "\n\n"), + "expected separator between summary parts" + ); + assert!( + matches!(&thinking_events[3], LanguageModelCompletionEvent::Thinking { text, .. 
} if text == "Second part") + ); + + assert!(mapped.iter().any( + |e| matches!(e, LanguageModelCompletionEvent::Text(t) if t == "The answer is 42") + )); + } + + #[test] + fn responses_stream_maps_reasoning_from_done_only() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_789".into()), + summary: vec![], + }), + }, + ResponsesStreamEvent::OutputItemDone { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::Reasoning(ResponseReasoningItem { + id: Some("rs_789".into()), + summary: vec![ReasoningSummaryPart::SummaryText { + text: "Summary without deltas".into(), + }], + }), + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + assert!( + !mapped + .iter() + .any(|e| matches!(e, LanguageModelCompletionEvent::Thinking { .. })), + "OutputItemDone reasoning should not produce Thinking events" + ); + } +} diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index ef4a20c9153c7f4ee07160d0dd7c558d9392259e..17343d16d0582be9bc70322b0182037f6e1521e6 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,4 +1,5 @@ pub mod batches; +pub mod completion; pub mod responses; use anyhow::{Context as _, Result, anyhow}; @@ -7,9 +8,9 @@ use http_client::{ AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode, http::{HeaderMap, HeaderValue}, }; +pub use language_model_core::ReasoningEffort; use serde::{Deserialize, Serialize}; use serde_json::Value; -pub use settings::OpenAiReasoningEffort as ReasoningEffort; use std::{convert::TryFrom, future::Future}; use strum::EnumIter; use thiserror::Error; @@ -727,3 +728,26 @@ pub fn embed<'a>( Ok(response) } } + +// -- Conversions to `language_model_core` types -- + +impl From<RequestError> for language_model_core::LanguageModelCompletionError { + fn
from(error: RequestError) -> Self { + match error { + RequestError::HttpResponseError { + provider, + status_code, + body, + headers, + } => { + let retry_after = headers + .get(http_client::http::header::RETRY_AFTER) + .and_then(|val| val.to_str().ok()?.parse::<u64>().ok()) + .map(std::time::Duration::from_secs); + + Self::from_http_status(provider.into(), status_code, body, retry_after) + } + RequestError::Other(e) => Self::Other(e), + } + } +} diff --git a/crates/open_router/Cargo.toml b/crates/open_router/Cargo.toml index cccb92c33b05b8fff0e5e78277c9f7fa29844ace..2cc5d3d00e2eb5d755cef971be51a315bcdf254f 100644 --- a/crates/open_router/Cargo.toml +++ b/crates/open_router/Cargo.toml @@ -19,6 +19,7 @@ anyhow.workspace = true futures.workspace = true http_client.workspace = true +language_model_core.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs index 9841c7b1ae19a57878fd8e84625bc4058b809613..b94631f9a0e6764ab5cfe487e7851a820fa80b1d 100644 --- a/crates/open_router/src/open_router.rs +++ b/crates/open_router/src/open_router.rs @@ -744,3 +744,71 @@ impl ApiErrorCode { } } } + +// -- Conversions to `language_model_core` types -- + +impl From<OpenRouterError> for language_model_core::LanguageModelCompletionError { + fn from(error: OpenRouterError) -> Self { + let provider = language_model_core::LanguageModelProviderName::new("OpenRouter"); + match error { + OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, + OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, + OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error }, + OpenRouterError::DeserializeResponse(error) => { + Self::DeserializeResponse { provider, error } + } + OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error }, +
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded { + provider, + retry_after: Some(retry_after), + }, + OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded { + provider, + retry_after, + }, + OpenRouterError::ApiError(api_error) => api_error.into(), + } + } +} + +impl From<ApiError> for language_model_core::LanguageModelCompletionError { + fn from(error: ApiError) -> Self { + use ApiErrorCode::*; + let provider = language_model_core::LanguageModelProviderName::new("OpenRouter"); + match error.code { + InvalidRequestError => Self::BadRequestFormat { + provider, + message: error.message, + }, + AuthenticationError => Self::AuthenticationError { + provider, + message: error.message, + }, + PaymentRequiredError => Self::AuthenticationError { + provider, + message: format!("Payment required: {}", error.message), + }, + PermissionError => Self::PermissionError { + provider, + message: error.message, + }, + RequestTimedOut => Self::HttpResponseError { + provider, + status_code: http_client::StatusCode::REQUEST_TIMEOUT, + message: error.message, + }, + RateLimitError => Self::RateLimitExceeded { + provider, + retry_after: None, + }, + ApiError => Self::ApiInternalServerError { + provider, + message: error.message, + }, + OverloadedError => Self::ServerOverloaded { + provider, + retry_after: None, + }, + } + } +} diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index b7d5afcb687c017fdf253717a9dae2c95c55b53b..fa23b805cd48461dabaddbb7670155cdfe1ba8b0 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -23,8 +23,8 @@ use gpui::{ uniform_list, }; use itertools::Itertools; -use language::language_settings::LanguageSettings; use language::{Anchor, BufferId, BufferSnapshot, OffsetRangeExt, OutlineItem}; +use language::{LanguageAwareStyling, language_settings::LanguageSettings}; use menu::{Cancel, SelectFirst, SelectLast, SelectNext,
SelectPrevious}; use std::{ @@ -217,10 +217,13 @@ impl SearchState { let mut offset = context_offset_range.start; let mut context_text = String::new(); let mut highlight_ranges = Vec::new(); - for mut chunk in highlight_arguments - .multi_buffer_snapshot - .chunks(context_offset_range.start..context_offset_range.end, true) - { + for mut chunk in highlight_arguments.multi_buffer_snapshot.chunks( + context_offset_range.start..context_offset_range.end, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ) { if !non_whitespace_symbol_occurred { for c in chunk.text.chars() { if c.is_whitespace() { diff --git a/crates/picker/src/highlighted_match_with_paths.rs b/crates/picker/src/highlighted_match_with_paths.rs index 74271047621b26be573dc2eebfffe9e9e0f1a138..7c88213437feea17e6b431dff9c97b0b8557872a 100644 --- a/crates/picker/src/highlighted_match_with_paths.rs +++ b/crates/picker/src/highlighted_match_with_paths.rs @@ -5,6 +5,7 @@ pub struct HighlightedMatchWithPaths { pub prefix: Option, pub match_label: HighlightedMatch, pub paths: Vec, + pub active: bool, } #[derive(Debug, Clone, IntoElement)] @@ -63,18 +64,30 @@ impl HighlightedMatchWithPaths { .color(Color::Muted) })) } + + pub fn is_active(mut self, active: bool) -> Self { + self.active = active; + self + } } impl RenderOnce for HighlightedMatchWithPaths { fn render(mut self, _window: &mut Window, _: &mut App) -> impl IntoElement { v_flex() .child( - h_flex().gap_1().child(self.match_label.clone()).when_some( - self.prefix.as_ref(), - |this, prefix| { + h_flex() + .gap_1() + .child(self.match_label.clone()) + .when_some(self.prefix.as_ref(), |this, prefix| { this.child(Label::new(format!("({})", prefix)).color(Color::Muted)) - }, - ), + }) + .when(self.active, |this| { + this.child( + Icon::new(IconName::Check) + .size(IconSize::Small) + .color(Color::Accent), + ) + }), ) .when(!self.paths.is_empty(), |this| { self.render_paths_children(this) diff --git a/crates/project/Cargo.toml 
b/crates/project/Cargo.toml index cd037786a399eb979fd5d9053c57efe3100dd473..628e979aab939a74bb4838477ae3e3657e2c91bc 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -52,6 +52,7 @@ fancy-regex.workspace = true fs.workspace = true futures.workspace = true fuzzy.workspace = true +fuzzy_nucleo.workspace = true git.workspace = true git_hosting_providers.workspace = true globset.workspace = true diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index e7e84ffe673881d898a56b64892887b9c8d6c809..8da5a14e41d9cb97865d78f4dfc2ed79f76faebd 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -32,10 +32,10 @@ use git::{ blame::Blame, parse_git_remote_url, repository::{ - Branch, CommitDetails, CommitDiff, CommitFile, CommitOptions, DiffType, FetchOptions, - GitRepository, GitRepositoryCheckpoint, GraphCommitData, InitialGraphCommitData, LogOrder, - LogSource, PushOptions, Remote, RemoteCommandOutput, RepoPath, ResetMode, SearchCommitArgs, - UpstreamTrackingStatus, Worktree as GitWorktree, + Branch, CommitDetails, CommitDiff, CommitFile, CommitOptions, CreateWorktreeTarget, + DiffType, FetchOptions, GitRepository, GitRepositoryCheckpoint, GraphCommitData, + InitialGraphCommitData, LogOrder, LogSource, PushOptions, Remote, RemoteCommandOutput, + RepoPath, ResetMode, SearchCommitArgs, UpstreamTrackingStatus, Worktree as GitWorktree, }, stash::{GitStash, StashEntry}, status::{ @@ -329,12 +329,6 @@ pub struct GraphDataResponse<'a> { pub error: Option, } -#[derive(Clone, Debug)] -enum CreateWorktreeStartPoint { - Detached, - Branched { name: String }, -} - pub struct Repository { this: WeakEntity, snapshot: RepositorySnapshot, @@ -2414,18 +2408,23 @@ impl GitStore { let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; let directory = PathBuf::from(envelope.payload.directory); - let 
start_point = if envelope.payload.name.is_empty() { - CreateWorktreeStartPoint::Detached + let name = envelope.payload.name; + let commit = envelope.payload.commit; + let use_existing_branch = envelope.payload.use_existing_branch; + let target = if name.is_empty() { + CreateWorktreeTarget::Detached { base_sha: commit } + } else if use_existing_branch { + CreateWorktreeTarget::ExistingBranch { branch_name: name } } else { - CreateWorktreeStartPoint::Branched { - name: envelope.payload.name, + CreateWorktreeTarget::NewBranch { + branch_name: name, + base_sha: commit, } }; - let commit = envelope.payload.commit; repository_handle .update(&mut cx, |repository_handle, _| { - repository_handle.create_worktree_with_start_point(start_point, directory, commit) + repository_handle.create_worktree(target, directory) }) .await??; @@ -6004,50 +6003,43 @@ impl Repository { }) } - fn create_worktree_with_start_point( + pub fn create_worktree( &mut self, - start_point: CreateWorktreeStartPoint, + target: CreateWorktreeTarget, path: PathBuf, - commit: Option, ) -> oneshot::Receiver> { - if matches!( - &start_point, - CreateWorktreeStartPoint::Branched { name } if name.is_empty() - ) { - let (sender, receiver) = oneshot::channel(); - sender - .send(Err(anyhow!("branch name cannot be empty"))) - .ok(); - return receiver; - } - let id = self.id; - let message = match &start_point { - CreateWorktreeStartPoint::Detached => "git worktree add (detached)".into(), - CreateWorktreeStartPoint::Branched { name } => { - format!("git worktree add: {name}").into() - } + let job_description = match target.branch_name() { + Some(branch_name) => format!("git worktree add: {branch_name}"), + None => "git worktree add (detached)".to_string(), }; - - self.send_job(Some(message), move |repo, _cx| async move { - let branch_name = match start_point { - CreateWorktreeStartPoint::Detached => None, - CreateWorktreeStartPoint::Branched { name } => Some(name), - }; - let remote_name = 
branch_name.clone().unwrap_or_default(); - + self.send_job(Some(job_description.into()), move |repo, _cx| async move { match repo { RepositoryState::Local(LocalRepositoryState { backend, .. }) => { - backend.create_worktree(branch_name, path, commit).await + backend.create_worktree(target, path).await } RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => { + let (name, commit, use_existing_branch) = match target { + CreateWorktreeTarget::ExistingBranch { branch_name } => { + (branch_name, None, true) + } + CreateWorktreeTarget::NewBranch { + branch_name, + base_sha: start_point, + } => (branch_name, start_point, false), + CreateWorktreeTarget::Detached { + base_sha: start_point, + } => (String::new(), start_point, false), + }; + client .request(proto::GitCreateWorktree { project_id: project_id.0, repository_id: id.to_proto(), - name: remote_name, + name, directory: path.to_string_lossy().to_string(), commit, + use_existing_branch, }) .await?; @@ -6057,28 +6049,16 @@ impl Repository { }) } - pub fn create_worktree( - &mut self, - branch_name: String, - path: PathBuf, - commit: Option, - ) -> oneshot::Receiver> { - self.create_worktree_with_start_point( - CreateWorktreeStartPoint::Branched { name: branch_name }, - path, - commit, - ) - } - pub fn create_worktree_detached( &mut self, path: PathBuf, commit: String, ) -> oneshot::Receiver> { - self.create_worktree_with_start_point( - CreateWorktreeStartPoint::Detached, + self.create_worktree( + CreateWorktreeTarget::Detached { + base_sha: Some(commit), + }, path, - Some(commit), ) } diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 2f579f5a724db143bbd4b0f9853a217bd6b14655..9ea50fdc8f12b68147c1073219625c4fd257afd3 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -72,9 +72,10 @@ use itertools::Itertools as _; use language::{ Bias, BinaryStatus, Buffer, BufferRow, BufferSnapshot, CachedLspAdapter, Capability, CodeLabel, 
CodeLabelExt, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, Diff, - File as _, Language, LanguageName, LanguageRegistry, LocalFile, LspAdapter, LspAdapterDelegate, - LspInstaller, ManifestDelegate, ManifestName, ModelineSettings, OffsetUtf16, Patch, PointUtf16, - TextBufferSnapshot, ToOffset, ToOffsetUtf16, ToPointUtf16, Toolchain, Transaction, Unclipped, + File as _, Language, LanguageAwareStyling, LanguageName, LanguageRegistry, LocalFile, + LspAdapter, LspAdapterDelegate, LspInstaller, ManifestDelegate, ManifestName, ModelineSettings, + OffsetUtf16, Patch, PointUtf16, TextBufferSnapshot, ToOffset, ToOffsetUtf16, ToPointUtf16, + Toolchain, Transaction, Unclipped, language_settings::{ AllLanguageSettings, FormatOnSave, Formatter, LanguageSettings, all_language_settings, }, @@ -13527,7 +13528,13 @@ fn resolve_word_completion(snapshot: &BufferSnapshot, completion: &mut Completio } let mut offset = 0; - for chunk in snapshot.chunks(word_range.clone(), true) { + for chunk in snapshot.chunks( + word_range.clone(), + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ) { let end_offset = offset + chunk.text.len(); if let Some(highlight_id) = chunk.syntax_highlight_id { completion diff --git a/crates/project/src/prettier_store.rs b/crates/project/src/prettier_store.rs index b66f2d5e0c041e104cf109a48b6bad249b492b88..faa2cca79866f31682a497eebab819b75e778ffb 100644 --- a/crates/project/src/prettier_store.rs +++ b/crates/project/src/prettier_store.rs @@ -412,7 +412,7 @@ impl PrettierStore { prettier_store .update(cx, |prettier_store, cx| { let name = if is_default { - LanguageServerName("prettier (default)".to_string().into()) + LanguageServerName("prettier (default)".into()) } else { let worktree_path = worktree_id .and_then(|id| { diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 0ec3366ca8f9f6c6e4e3cbd411e1894de4d0f2b8..b90972b3489c25f8a2bf10d7dbdb6d6cfe0c4c6c 100644 --- 
a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -6186,6 +6186,76 @@ impl<'a> Iterator for PathMatchCandidateSetIter<'a> { } } +impl<'a> fuzzy_nucleo::PathMatchCandidateSet<'a> for PathMatchCandidateSet { + type Candidates = PathMatchCandidateSetNucleoIter<'a>; + fn id(&self) -> usize { + self.snapshot.id().to_usize() + } + fn len(&self) -> usize { + match self.candidates { + Candidates::Files => { + if self.include_ignored { + self.snapshot.file_count() + } else { + self.snapshot.visible_file_count() + } + } + Candidates::Directories => { + if self.include_ignored { + self.snapshot.dir_count() + } else { + self.snapshot.visible_dir_count() + } + } + Candidates::Entries => { + if self.include_ignored { + self.snapshot.entry_count() + } else { + self.snapshot.visible_entry_count() + } + } + } + } + fn prefix(&self) -> Arc { + if self.snapshot.root_entry().is_some_and(|e| e.is_file()) || self.include_root_name { + self.snapshot.root_name().into() + } else { + RelPath::empty().into() + } + } + fn root_is_file(&self) -> bool { + self.snapshot.root_entry().is_some_and(|f| f.is_file()) + } + fn path_style(&self) -> PathStyle { + self.snapshot.path_style() + } + fn candidates(&'a self, start: usize) -> Self::Candidates { + PathMatchCandidateSetNucleoIter { + traversal: match self.candidates { + Candidates::Directories => self.snapshot.directories(self.include_ignored, start), + Candidates::Files => self.snapshot.files(self.include_ignored, start), + Candidates::Entries => self.snapshot.entries(self.include_ignored, start), + }, + } + } +} + +pub struct PathMatchCandidateSetNucleoIter<'a> { + traversal: Traversal<'a>, +} + +impl<'a> Iterator for PathMatchCandidateSetNucleoIter<'a> { + type Item = fuzzy_nucleo::PathMatchCandidate<'a>; + fn next(&mut self) -> Option { + self.traversal + .next() + .map(|entry| fuzzy_nucleo::PathMatchCandidate { + is_dir: entry.kind.is_dir(), + path: &entry.path, + }) + } +} + impl EventEmitter for Project {} impl<'a> 
From<&'a ProjectPath> for SettingsLocation<'a> { diff --git a/crates/project/tests/integration/git_store.rs b/crates/project/tests/integration/git_store.rs index 02f752b28b24a8135e2cba9307a5eacdc16f0fa3..bbe5c64d7cf7f5b2ffa9160df6130cd88ddc5d69 100644 --- a/crates/project/tests/integration/git_store.rs +++ b/crates/project/tests/integration/git_store.rs @@ -1267,9 +1267,11 @@ mod git_worktrees { cx.update(|cx| { repository.update(cx, |repository, _| { repository.create_worktree( - "feature-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "feature-branch".to_string(), + base_sha: Some("abc123".to_string()), + }, worktree_1_directory.clone(), - Some("abc123".to_string()), ) }) }) @@ -1297,9 +1299,11 @@ mod git_worktrees { cx.update(|cx| { repository.update(cx, |repository, _| { repository.create_worktree( - "bugfix-branch".to_string(), + git::repository::CreateWorktreeTarget::NewBranch { + branch_name: "bugfix-branch".to_string(), + base_sha: None, + }, worktree_2_directory.clone(), - None, ) }) }) diff --git a/crates/project/tests/integration/project_tests.rs b/crates/project/tests/integration/project_tests.rs index d6c2ce37c9e60e17bd43c3f6c3ad10cde52b4bec..f680ccee78e997064af2647f68d8aa3631fa4bd3 100644 --- a/crates/project/tests/integration/project_tests.rs +++ b/crates/project/tests/integration/project_tests.rs @@ -41,9 +41,10 @@ use gpui::{ use itertools::Itertools; use language::{ Buffer, BufferEvent, Diagnostic, DiagnosticEntry, DiagnosticEntryRef, DiagnosticSet, - DiagnosticSourceKind, DiskState, FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, - LanguageName, LineEnding, ManifestName, ManifestProvider, ManifestQuery, OffsetRangeExt, Point, - ToPoint, Toolchain, ToolchainList, ToolchainLister, ToolchainMetadata, + DiagnosticSourceKind, DiskState, FakeLspAdapter, Language, LanguageAwareStyling, + LanguageConfig, LanguageMatcher, LanguageName, LineEnding, ManifestName, ManifestProvider, + ManifestQuery, 
OffsetRangeExt, Point, ToPoint, Toolchain, ToolchainList, ToolchainLister, + ToolchainMetadata, language_settings::{LanguageSettings, LanguageSettingsContent}, markdown_lang, rust_lang, tree_sitter_typescript, }; @@ -4382,7 +4383,13 @@ fn chunks_with_diagnostics( range: Range, ) -> Vec<(String, Option)> { let mut chunks: Vec<(String, Option)> = Vec::new(); - for chunk in buffer.snapshot().chunks(range, true) { + for chunk in buffer.snapshot().chunks( + range, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, + ) { if chunks .last() .is_some_and(|prev_chunk| prev_chunk.1 == chunk.diagnostic_severity) diff --git a/crates/project_panel/src/project_panel_tests.rs b/crates/project_panel/src/project_panel_tests.rs index 55b53cde8b6252f8b9732cf4effc35ea53c073e0..603cfd892a218d866383f485d058296ad179da05 100644 --- a/crates/project_panel/src/project_panel_tests.rs +++ b/crates/project_panel/src/project_panel_tests.rs @@ -11,7 +11,7 @@ use std::path::{Path, PathBuf}; use util::{path, paths::PathStyle, rel_path::rel_path}; use workspace::{ AppState, ItemHandle, MultiWorkspace, Pane, Workspace, - item::{Item, ProjectItem}, + item::{Item, ProjectItem, test::TestItem}, register_project_item, }; @@ -6015,6 +6015,150 @@ async fn test_explicit_reveal(cx: &mut gpui::TestAppContext) { ); } +#[gpui::test] +async fn test_reveal_in_project_panel_notifications(cx: &mut gpui::TestAppContext) { + init_test_with_editor(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + "/workspace", + json!({ + "README.md": "" + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/workspace".as_ref()], cx).await; + let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + let workspace = window + .read_with(cx, |mw, _| mw.workspace().clone()) + .unwrap(); + let cx = &mut VisualTestContext::from_window(window.into(), cx); + let panel = workspace.update_in(cx, ProjectPanel::new); + cx.run_until_parked(); + 
+ // Ensure that, attempting to run `pane: reveal in project panel` without + // any active item does nothing, i.e., does not focus the project panel but + // it also does not show a notification. + cx.dispatch_action(workspace::RevealInProjectPanel::default()); + cx.run_until_parked(); + + panel.update_in(cx, |panel, window, cx| { + assert!( + !panel.focus_handle(cx).is_focused(window), + "Project panel should not be focused after attempting to reveal an invisible worktree entry" + ); + + panel.workspace.update(cx, |workspace, cx| { + assert!( + workspace.active_item(cx).is_none(), + "Workspace should not have an active item" + ); + assert_eq!( + workspace.notification_ids(), + vec![], + "No notification should be shown when there's no active item" + ); + }).unwrap(); + }); + + // Create a file in a different folder than the one in the project so we can + // later open it and ensure that, attempting to reveal it in the project + // panel shows a notification and does not focus the project panel. 
+ fs.insert_tree( + "/external", + json!({ + "file.txt": "External File", + }), + ) + .await; + + let (worktree, _) = project + .update(cx, |project, cx| { + project.find_or_create_worktree("/external/file.txt", false, cx) + }) + .await + .unwrap(); + + workspace + .update_in(cx, |workspace, window, cx| { + let worktree_id = worktree.read(cx).id(); + let path = rel_path("").into(); + let project_path = ProjectPath { worktree_id, path }; + + workspace.open_path(project_path, None, true, window, cx) + }) + .await + .unwrap(); + cx.run_until_parked(); + + cx.dispatch_action(workspace::RevealInProjectPanel::default()); + cx.run_until_parked(); + + panel.update_in(cx, |panel, window, cx| { + assert!( + !panel.focus_handle(cx).is_focused(window), + "Project panel should not be focused after attempting to reveal an invisible worktree entry" + ); + + panel.workspace.update(cx, |workspace, cx| { + assert!( + workspace.active_item(cx).is_some(), + "Workspace should have an active item" + ); + + let notification_ids = workspace.notification_ids(); + assert_eq!( + notification_ids.len(), + 1, + "A notification should be shown when trying to reveal an invisible worktree entry" + ); + + workspace.dismiss_notification(¬ification_ids[0], cx); + assert_eq!( + workspace.notification_ids().len(), + 0, + "No notifications should be left after dismissing" + ); + }).unwrap(); + }); + + // Create an empty buffer so we can ensure that, attempting to reveal it in + // the project panel shows a notification and does not focus the project + // panel. 
+ let pane = workspace.update(cx, |workspace, _| workspace.active_pane().clone()); + pane.update_in(cx, |pane, window, cx| { + let item = cx.new(|cx| TestItem::new(cx).with_label("Unsaved buffer")); + pane.add_item(Box::new(item), false, false, None, window, cx); + }); + + cx.dispatch_action(workspace::RevealInProjectPanel::default()); + cx.run_until_parked(); + + panel.update_in(cx, |panel, window, cx| { + assert!( + !panel.focus_handle(cx).is_focused(window), + "Project panel should not be focused after attempting to reveal an unsaved buffer" + ); + + panel + .workspace + .update(cx, |workspace, cx| { + assert!( + workspace.active_item(cx).is_some(), + "Workspace should have an active item" + ); + + let notification_ids = workspace.notification_ids(); + assert_eq!( + notification_ids.len(), + 1, + "A notification should be shown when trying to reveal an unsaved buffer" + ); + }) + .unwrap(); + }); +} + #[gpui::test] async fn test_creating_excluded_entries(cx: &mut gpui::TestAppContext) { init_test(cx); diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index 9324feb21b1f50ac1041ed0afc8b59cb9b7fe2c6..d0a594a2817ec50d9d35383587619e311f2950d8 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -594,6 +594,7 @@ message GitCreateWorktree { string name = 3; string directory = 4; optional string commit = 5; + bool use_existing_branch = 6; } message GitCreateCheckpoint { diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index e3bfc0dc08c95c0ce57b818e50965433a6c6bc98..57754dadec20146cb1f21039266de88a0bd5da9f 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -720,6 +720,9 @@ impl RecentProjects { picker.delegate.workspaces.get(hit.candidate_id) { let workspace_id = *workspace_id; + if picker.delegate.is_current_workspace(workspace_id, cx) { + return; + } picker .delegate 
.remove_sibling_workspace(workspace_id, window, cx); @@ -939,7 +942,7 @@ impl PickerDelegate for RecentProjectsDelegate { .workspaces .iter() .enumerate() - .filter(|(_, (id, _, _, _))| self.is_sibling_workspace(*id, cx)) + .filter(|(_, (id, _, _, _))| self.sibling_workspace_ids.contains(id)) .map(|(id, (_, _, paths, _))| { let combined_string = paths .ordered_paths() @@ -1028,7 +1031,7 @@ impl PickerDelegate for RecentProjectsDelegate { if is_empty_query { for (id, (workspace_id, _, _, _)) in self.workspaces.iter().enumerate() { - if self.is_sibling_workspace(*workspace_id, cx) { + if self.sibling_workspace_ids.contains(workspace_id) { entries.push(ProjectPickerEntry::OpenProject(StringMatch { candidate_id: id, score: 0.0, @@ -1106,6 +1109,11 @@ impl PickerDelegate for RecentProjectsDelegate { }; let workspace_id = *workspace_id; + if self.is_current_workspace(workspace_id, cx) { + cx.emit(DismissEvent); + return; + } + if let Some(handle) = window.window_handle().downcast::() { cx.defer(move |cx| { handle @@ -1349,6 +1357,7 @@ impl PickerDelegate for RecentProjectsDelegate { ProjectPickerEntry::OpenProject(hit) => { let (workspace_id, location, paths, _) = self.workspaces.get(hit.candidate_id)?; let workspace_id = *workspace_id; + let is_current = self.is_current_workspace(workspace_id, cx); let ordered_paths: Vec<_> = paths .ordered_paths() .map(|p| p.compact().to_string_lossy().to_string()) @@ -1388,6 +1397,7 @@ impl PickerDelegate for RecentProjectsDelegate { prefix, match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "), paths, + active: is_current, }; let icon = icon_for_remote_connection(match location { @@ -1397,20 +1407,24 @@ impl PickerDelegate for RecentProjectsDelegate { let secondary_actions = h_flex() .gap_1() - .child( - IconButton::new("remove_open_project", IconName::Close) - .icon_size(IconSize::Small) - .tooltip(Tooltip::text("Remove Project from Window")) - .on_click(cx.listener(move |picker, _, window, cx| { - 
cx.stop_propagation(); - window.prevent_default(); - picker - .delegate - .remove_sibling_workspace(workspace_id, window, cx); - let query = picker.query(cx); - picker.update_matches(query, window, cx); - })), - ) + .when(!is_current, |this| { + this.child( + IconButton::new("remove_open_project", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Remove Project from Window")) + .on_click(cx.listener(move |picker, _, window, cx| { + cx.stop_propagation(); + window.prevent_default(); + picker.delegate.remove_sibling_workspace( + workspace_id, + window, + cx, + ); + let query = picker.query(cx); + picker.update_matches(query, window, cx); + })), + ) + }) .into_any_element(); Some( @@ -1483,6 +1497,7 @@ impl PickerDelegate for RecentProjectsDelegate { prefix, match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "), paths, + active: false, }; let focus_handle = self.focus_handle.clone(); @@ -1491,9 +1506,16 @@ impl PickerDelegate for RecentProjectsDelegate { .gap_px() .when(is_local, |this| { this.child( - IconButton::new("add_to_workspace", IconName::FolderPlus) + IconButton::new("add_to_workspace", IconName::FolderOpenAdd) .icon_size(IconSize::Small) - .tooltip(Tooltip::text("Add Project to this Workspace")) + .tooltip(move |_, cx| { + Tooltip::with_meta( + "Add Project to this Workspace", + None, + "As a multi-root folder project", + cx, + ) + }) .on_click({ let paths_to_add = paths_to_add.clone(); cx.listener(move |picker, _event, window, cx| { @@ -1509,8 +1531,8 @@ impl PickerDelegate for RecentProjectsDelegate { ) }) .child( - IconButton::new("open_new_window", IconName::ArrowUpRight) - .icon_size(IconSize::XSmall) + IconButton::new("open_new_window", IconName::OpenNewWindow) + .icon_size(IconSize::Small) .tooltip({ move |_, cx| { Tooltip::for_action_in( @@ -1565,7 +1587,14 @@ impl PickerDelegate for RecentProjectsDelegate { } highlighted.render(window, cx) }) - .tooltip(Tooltip::text(tooltip_path)), + .tooltip(move |_, 
cx| { + Tooltip::with_meta( + "Open Project in This Window", + None, + tooltip_path.clone(), + cx, + ) + }), ) .end_slot(secondary_actions) .show_end_slot_on_hover() @@ -1625,27 +1654,41 @@ impl PickerDelegate for RecentProjectsDelegate { let selected_entry = self.filtered_entries.get(self.selected_index); + let is_current_workspace_entry = + if let Some(ProjectPickerEntry::OpenProject(hit)) = selected_entry { + self.workspaces + .get(hit.candidate_id) + .map(|(id, ..)| self.is_current_workspace(*id, cx)) + .unwrap_or(false) + } else { + false + }; + let secondary_footer_actions: Option = match selected_entry { - Some(ProjectPickerEntry::OpenFolder { .. } | ProjectPickerEntry::OpenProject(_)) => { - let label = if matches!(selected_entry, Some(ProjectPickerEntry::OpenFolder { .. })) - { - "Remove Folder" - } else { - "Remove from Window" - }; - Some( - Button::new("remove_selected", label) - .key_binding(KeyBinding::for_action_in( - &RemoveSelected, - &focus_handle, - cx, - )) - .on_click(|_, window, cx| { - window.dispatch_action(RemoveSelected.boxed_clone(), cx) - }) - .into_any_element(), - ) - } + Some(ProjectPickerEntry::OpenFolder { .. 
}) => Some( + Button::new("remove_selected", "Remove Folder") + .key_binding(KeyBinding::for_action_in( + &RemoveSelected, + &focus_handle, + cx, + )) + .on_click(|_, window, cx| { + window.dispatch_action(RemoveSelected.boxed_clone(), cx) + }) + .into_any_element(), + ), + Some(ProjectPickerEntry::OpenProject(_)) if !is_current_workspace_entry => Some( + Button::new("remove_selected", "Remove from Window") + .key_binding(KeyBinding::for_action_in( + &RemoveSelected, + &focus_handle, + cx, + )) + .on_click(|_, window, cx| { + window.dispatch_action(RemoveSelected.boxed_clone(), cx) + }) + .into_any_element(), + ), Some(ProjectPickerEntry::RecentProject(_)) => Some( Button::new("delete_recent", "Delete") .key_binding(KeyBinding::for_action_in( @@ -1748,7 +1791,7 @@ impl PickerDelegate for RecentProjectsDelegate { menu.context(focus_handle) .when(show_add_to_workspace, |menu| { menu.action( - "Add to Workspace", + "Add to this Workspace", AddToWorkspace.boxed_clone(), ) .separator() diff --git a/crates/recent_projects/src/sidebar_recent_projects.rs b/crates/recent_projects/src/sidebar_recent_projects.rs index 1fe0d2ae86aefdad45136c496f8049689d77e048..dec269c07eada3a1d6172482cb886f9ed44d784c 100644 --- a/crates/recent_projects/src/sidebar_recent_projects.rs +++ b/crates/recent_projects/src/sidebar_recent_projects.rs @@ -374,6 +374,7 @@ impl PickerDelegate for SidebarRecentProjectsDelegate { prefix, match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "), paths: Vec::new(), + active: false, }; let icon = icon_for_remote_connection(match location { @@ -395,7 +396,14 @@ impl PickerDelegate for SidebarRecentProjectsDelegate { }) .child(highlighted_match.render(window, cx)), ) - .tooltip(Tooltip::text(tooltip_path)) + .tooltip(move |_, cx| { + Tooltip::with_meta( + "Open Project in This Window", + None, + tooltip_path.clone(), + cx, + ) + }) .into_any_element(), ) } diff --git a/crates/repl/src/kernels/ssh_kernel.rs 
b/crates/repl/src/kernels/ssh_kernel.rs index 53be6622379cfcbf3ceeb6db425eeede9b226860..797b111a14345267e01c60c6803787c8f1d0f6a2 100644 --- a/crates/repl/src/kernels/ssh_kernel.rs +++ b/crates/repl/src/kernels/ssh_kernel.rs @@ -215,7 +215,7 @@ impl SshRunningKernel { &session_id, ) .await - .context("failed to create iopub connection")?; + .context("Failed to create iopub connection. Is `ipykernel` installed in the remote environment? Try running `pip install ipykernel` on the remote host.")?; let peer_identity = runtimelib::peer_identity_for_session(&session_id)?; let shell_socket = runtimelib::create_client_shell_connection_with_identity( diff --git a/crates/repl/src/kernels/wsl_kernel.rs b/crates/repl/src/kernels/wsl_kernel.rs index d9ac05c5fc8c2cb756898ff449d6714b78cb7997..be76d7ddccb7f199a368b76a1f21bf65fe6f2902 100644 --- a/crates/repl/src/kernels/wsl_kernel.rs +++ b/crates/repl/src/kernels/wsl_kernel.rs @@ -354,7 +354,8 @@ impl WslRunningKernel { "", &session_id, ) - .await?; + .await + .context("Failed to create iopub connection. Is `ipykernel` installed in the WSL environment? 
Try running `pip install ipykernel` inside your WSL distribution.")?; let peer_identity = runtimelib::peer_identity_for_session(&session_id)?; let shell_socket = runtimelib::create_client_shell_connection_with_identity( diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index 46177c5642a8d05daaf22e9fb24b205cd10ca42b..3a5fbe3fcae6241495deb43930b83bb78ba81968 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -849,6 +849,7 @@ impl BufferSearchBar { let query_editor = cx.new(|cx| { let mut editor = Editor::auto_height(1, 4, window, cx); editor.set_use_autoclose(false); + editor.set_use_selection_highlight(false); editor }); cx.subscribe_in(&query_editor, window, Self::on_query_editor_event) diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 1bccf1ae52fb2c52a8d01e53aabb1b3ff5c7c16f..7e7903674e3d883bfb98ac8d57b5f407237f66d1 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -769,6 +769,17 @@ impl ProjectSearchView { } } + fn set_search_option_enabled( + &mut self, + option: SearchOptions, + enabled: bool, + cx: &mut Context, + ) { + if self.search_options.contains(option) != enabled { + self.toggle_search_option(option, cx); + } + } + fn toggle_search_option(&mut self, option: SearchOptions, cx: &mut Context) { self.search_options.toggle(option); ActiveSettings::update_global(cx, |settings, cx| { @@ -928,6 +939,7 @@ impl ProjectSearchView { let mut editor = Editor::auto_height(1, 4, window, cx); editor.set_placeholder_text("Search all files…", window, cx); editor.set_use_autoclose(false); + editor.set_use_selection_highlight(false); editor.set_text(query_text, window, cx); editor }); @@ -1153,7 +1165,7 @@ impl ProjectSearchView { window: &mut Window, cx: &mut Context, ) { - Self::existing_or_new_search(workspace, None, &DeploySearch::find(), window, cx) + Self::existing_or_new_search(workspace, None, 
&DeploySearch::default(), window, cx) } fn existing_or_new_search( @@ -1203,8 +1215,29 @@ impl ProjectSearchView { search.update(cx, |search, cx| { search.replace_enabled |= action.replace_enabled; + if let Some(regex) = action.regex { + search.set_search_option_enabled(SearchOptions::REGEX, regex, cx); + } + if let Some(case_sensitive) = action.case_sensitive { + search.set_search_option_enabled(SearchOptions::CASE_SENSITIVE, case_sensitive, cx); + } + if let Some(whole_word) = action.whole_word { + search.set_search_option_enabled(SearchOptions::WHOLE_WORD, whole_word, cx); + } + if let Some(include_ignored) = action.include_ignored { + search.set_search_option_enabled( + SearchOptions::INCLUDE_IGNORED, + include_ignored, + cx, + ); + } + let query = action + .query + .as_deref() + .filter(|q| !q.is_empty()) + .or(query.as_deref()); if let Some(query) = query { - search.set_query(&query, window, cx); + search.set_query(query, window, cx); } if let Some(included_files) = action.included_files.as_deref() { search @@ -3101,7 +3134,7 @@ pub mod tests { ProjectSearchView::deploy_search( workspace, - &workspace::DeploySearch::find(), + &workspace::DeploySearch::default(), window, cx, ) @@ -3252,7 +3285,7 @@ pub mod tests { workspace.update_in(cx, |workspace, window, cx| { ProjectSearchView::deploy_search( workspace, - &workspace::DeploySearch::find(), + &workspace::DeploySearch::default(), window, cx, ) @@ -3325,7 +3358,7 @@ pub mod tests { ProjectSearchView::deploy_search( workspace, - &workspace::DeploySearch::find(), + &workspace::DeploySearch::default(), window, cx, ) @@ -4560,7 +4593,7 @@ pub mod tests { }); // Deploy a new search - cx.dispatch_action(DeploySearch::find()); + cx.dispatch_action(DeploySearch::default()); // Both panes should now have a project search in them workspace.update_in(cx, |workspace, window, cx| { @@ -4585,7 +4618,7 @@ pub mod tests { .unwrap(); // Deploy a new search - cx.dispatch_action(DeploySearch::find()); + 
cx.dispatch_action(DeploySearch::default()); // The project search view should now be focused in the second pane // And the number of items should be unchanged. @@ -4823,7 +4856,7 @@ pub mod tests { assert!(workspace.has_active_modal(window, cx)); }); - cx.dispatch_action(DeploySearch::find()); + cx.dispatch_action(DeploySearch::default()); workspace.update_in(cx, |workspace, window, cx| { assert!(!workspace.has_active_modal(window, cx)); @@ -5136,6 +5169,271 @@ pub mod tests { .unwrap(); } + #[gpui::test] + async fn test_deploy_search_applies_and_resets_options(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/dir"), + json!({ + "one.rs": "const ONE: usize = 1;", + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project, window, cx)); + let workspace = window + .read_with(cx, |mw, _| mw.workspace().clone()) + .unwrap(); + let cx = &mut VisualTestContext::from_window(window.into(), cx); + let search_bar = window.build_entity(cx, |_, _| ProjectSearchBar::new()); + + workspace.update_in(cx, |workspace, window, cx| { + workspace.panes()[0].update(cx, |pane, cx| { + pane.toolbar() + .update(cx, |toolbar, cx| toolbar.add_item(search_bar, window, cx)) + }); + + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + regex: Some(true), + case_sensitive: Some(true), + whole_word: Some(true), + include_ignored: Some(true), + query: Some("Test_Query".into()), + ..Default::default() + }, + window, + cx, + ) + }); + + let search_view = cx + .read(|cx| { + workspace + .read(cx) + .active_pane() + .read(cx) + .active_item() + .and_then(|item| item.downcast::()) + }) + .expect("Search view should be active after deploy"); + + search_view.update_in(cx, |search_view, _window, cx| { + assert!( + search_view.search_options.contains(SearchOptions::REGEX), + "Regex option 
should be enabled" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::CASE_SENSITIVE), + "Case sensitive option should be enabled" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::WHOLE_WORD), + "Whole word option should be enabled" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::INCLUDE_IGNORED), + "Include ignored option should be enabled" + ); + let query_text = search_view.query_editor.read(cx).text(cx); + assert_eq!( + query_text, "Test_Query", + "Query should be set from the action" + ); + }); + + // Redeploy with only regex - unspecified options should be preserved. + cx.dispatch_action(menu::Cancel); + workspace.update_in(cx, |workspace, window, cx| { + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + regex: Some(true), + ..Default::default() + }, + window, + cx, + ) + }); + + search_view.update_in(cx, |search_view, _window, _cx| { + assert!( + search_view.search_options.contains(SearchOptions::REGEX), + "Regex should still be enabled" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::CASE_SENSITIVE), + "Case sensitive should be preserved from previous deploy" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::WHOLE_WORD), + "Whole word should be preserved from previous deploy" + ); + assert!( + search_view + .search_options + .contains(SearchOptions::INCLUDE_IGNORED), + "Include ignored should be preserved from previous deploy" + ); + }); + + // Redeploy explicitly turning off options. 
+ cx.dispatch_action(menu::Cancel); + workspace.update_in(cx, |workspace, window, cx| { + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + regex: Some(true), + case_sensitive: Some(false), + whole_word: Some(false), + include_ignored: Some(false), + ..Default::default() + }, + window, + cx, + ) + }); + + search_view.update_in(cx, |search_view, _window, _cx| { + assert_eq!( + search_view.search_options, + SearchOptions::REGEX, + "Explicit Some(false) should turn off options" + ); + }); + + // Redeploy with an empty query - should not overwrite the existing query. + cx.dispatch_action(menu::Cancel); + workspace.update_in(cx, |workspace, window, cx| { + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + query: Some("".into()), + ..Default::default() + }, + window, + cx, + ) + }); + + search_view.update_in(cx, |search_view, _window, cx| { + let query_text = search_view.query_editor.read(cx).text(cx); + assert_eq!( + query_text, "Test_Query", + "Empty query string should not overwrite the existing query" + ); + }); + } + + #[gpui::test] + async fn test_smartcase_overrides_explicit_case_sensitive(cx: &mut TestAppContext) { + init_test(cx); + + cx.update(|cx| { + cx.update_global::(|store, cx| { + store.update_default_settings(cx, |settings| { + settings.editor.use_smartcase_search = Some(true); + }); + }); + }); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/dir"), + json!({ + "one.rs": "const ONE: usize = 1;", + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project, window, cx)); + let workspace = window + .read_with(cx, |mw, _| mw.workspace().clone()) + .unwrap(); + let cx = &mut VisualTestContext::from_window(window.into(), cx); + let search_bar = window.build_entity(cx, |_, _| ProjectSearchBar::new()); + + workspace.update_in(cx, |workspace, window, cx| { + 
workspace.panes()[0].update(cx, |pane, cx| { + pane.toolbar() + .update(cx, |toolbar, cx| toolbar.add_item(search_bar, window, cx)) + }); + + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + case_sensitive: Some(true), + query: Some("lowercase_query".into()), + ..Default::default() + }, + window, + cx, + ) + }); + + let search_view = cx + .read(|cx| { + workspace + .read(cx) + .active_pane() + .read(cx) + .active_item() + .and_then(|item| item.downcast::()) + }) + .expect("Search view should be active after deploy"); + + // Smartcase should override the explicit case_sensitive flag + // because the query is all lowercase. + search_view.update_in(cx, |search_view, _window, cx| { + assert!( + !search_view + .search_options + .contains(SearchOptions::CASE_SENSITIVE), + "Smartcase should disable case sensitivity for a lowercase query, \ + even when case_sensitive was explicitly set in the action" + ); + let query_text = search_view.query_editor.read(cx).text(cx); + assert_eq!(query_text, "lowercase_query"); + }); + + // Now deploy with an uppercase query - smartcase should enable case sensitivity. 
+ workspace.update_in(cx, |workspace, window, cx| { + ProjectSearchView::deploy_search( + workspace, + &workspace::DeploySearch { + query: Some("Uppercase_Query".into()), + ..Default::default() + }, + window, + cx, + ) + }); + + search_view.update_in(cx, |search_view, _window, cx| { + assert!( + search_view + .search_options + .contains(SearchOptions::CASE_SENSITIVE), + "Smartcase should enable case sensitivity for a query containing uppercase" + ); + let query_text = search_view.query_editor.read(cx).text(cx); + assert_eq!(query_text, "Uppercase_Query"); + }); + } + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings = SettingsStore::test(cx); diff --git a/crates/settings/src/vscode_import.rs b/crates/settings/src/vscode_import.rs index 1211cbd8a4519ea295773eb0d979b48258908311..4c7ce085aed5ad0cf7c48308b4211815cf5aad75 100644 --- a/crates/settings/src/vscode_import.rs +++ b/crates/settings/src/vscode_import.rs @@ -198,7 +198,7 @@ impl VsCodeSettings { log: None, message_editor: None, node: self.node_binary_settings(), - notification_panel: None, + outline_panel: self.outline_panel_settings_content(), preview_tabs: self.preview_tabs_settings_content(), project: self.project_settings_content(), diff --git a/crates/settings_content/Cargo.toml b/crates/settings_content/Cargo.toml index b3599e9eef3b7ac5680f441369a7cbdc98a5d043..59cccb4167ed64a2ece8ae5a73ac570ca7dabd97 100644 --- a/crates/settings_content/Cargo.toml +++ b/crates/settings_content/Cargo.toml @@ -19,6 +19,7 @@ anyhow.workspace = true collections.workspace = true derive_more.workspace = true gpui.workspace = true +language_model_core.workspace = true log.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/settings_content/src/agent.rs b/crates/settings_content/src/agent.rs index 5b1b3c014f8c538cb0dff506e05d84a80dc863d1..7a9a1ddb16ac91f90f73e17b3972cd31536d7a66 100644 --- a/crates/settings_content/src/agent.rs +++ b/crates/settings_content/src/agent.rs @@ 
-128,6 +128,12 @@ pub struct AgentSettingsContent { /// Default: 320 #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")] pub default_height: Option, + /// Maximum content width in pixels for the agent panel. Content will be + /// centered when the panel is wider than this value. + /// + /// Default: 850 + #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")] + pub max_content_width: Option, /// The default model to use when creating new chats and for other features when a specific model is not specified. pub default_model: Option, /// Favorite models to show at the top of the model selector. diff --git a/crates/settings_content/src/language_model.rs b/crates/settings_content/src/language_model.rs index 4b72c2ad3f47d834dfa38555d80a8646e3940f51..00ecf42537459496102495c51628b54405968214 100644 --- a/crates/settings_content/src/language_model.rs +++ b/crates/settings_content/src/language_model.rs @@ -1,8 +1,8 @@ +use crate::merge_from::MergeFrom; use collections::HashMap; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings_macros::{MergeFrom, with_fallible_options}; -use strum::EnumString; use std::sync::Arc; @@ -237,15 +237,12 @@ pub struct OpenAiAvailableModel { pub capabilities: OpenAiModelCapabilities, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, EnumString, JsonSchema, MergeFrom)] -#[serde(rename_all = "lowercase")] -#[strum(serialize_all = "lowercase")] -pub enum OpenAiReasoningEffort { - Minimal, - Low, - Medium, - High, - XHigh, +pub use language_model_core::ReasoningEffort as OpenAiReasoningEffort; + +impl MergeFrom for OpenAiReasoningEffort { + fn merge_from(&mut self, other: &Self) { + *self = *other; + } } #[with_fallible_options] @@ -479,15 +476,10 @@ pub struct LanguageModelCacheConfiguration { pub min_total_token: u64, } -#[derive( - Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema, MergeFrom, -)] -#[serde(tag = "type", 
rename_all = "lowercase")] -pub enum ModelMode { - #[default] - Default, - Thinking { - /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. - budget_tokens: Option, - }, +pub use language_model_core::ModelMode; + +impl MergeFrom for ModelMode { + fn merge_from(&mut self, other: &Self) { + *self = *other; + } } diff --git a/crates/settings_content/src/settings_content.rs b/crates/settings_content/src/settings_content.rs index 6c60a7010f7cfc5b4fadf9a8cc386fe6e3267abc..3c3c0f600769b8437dc56016426eee4f84d2fc7a 100644 --- a/crates/settings_content/src/settings_content.rs +++ b/crates/settings_content/src/settings_content.rs @@ -174,9 +174,6 @@ pub struct SettingsContent { /// Configuration for Node-related features pub node: Option, - /// Configuration for the Notification Panel - pub notification_panel: Option, - pub proxy: Option, /// The URL of the Zed server to connect to. @@ -631,28 +628,6 @@ pub struct ScrollbarSettings { pub show: Option, } -#[with_fallible_options] -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug, PartialEq)] -pub struct NotificationPanelSettingsContent { - /// Whether to show the panel button in the status bar. - /// - /// Default: true - pub button: Option, - /// Where to dock the panel. - /// - /// Default: right - pub dock: Option, - /// Default width of the panel in pixels. - /// - /// Default: 300 - #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")] - pub default_width: Option, - /// Whether to show a badge on the notification panel icon with the count of unread notifications. 
- /// - /// Default: false - pub show_count_badge: Option, -} - #[with_fallible_options] #[derive(Clone, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug, PartialEq)] pub struct PanelSettingsContent { diff --git a/crates/settings_ui/src/components/input_field.rs b/crates/settings_ui/src/components/input_field.rs index 35e63078c154dd324c8dd622b8d98c2de36beb68..e93944cf32cddc02e10a5e4f3251e80563c992b4 100644 --- a/crates/settings_ui/src/components/input_field.rs +++ b/crates/settings_ui/src/components/input_field.rs @@ -109,16 +109,37 @@ impl RenderOnce for SettingsInputField { ..Default::default() }; + let first_render_initial_text = window.use_state(cx, |_, _| self.initial_text.clone()); + let editor = if let Some(id) = self.id { window.use_keyed_state(id, cx, { let initial_text = self.initial_text.clone(); let placeholder = self.placeholder; + let mut confirm = self.confirm.clone(); + move |window, cx| { let mut editor = Editor::single_line(window, cx); + let editor_focus_handle = editor.focus_handle(cx); if let Some(text) = initial_text { editor.set_text(text, window, cx); } + if let Some(confirm) = confirm.take() + && !self.display_confirm_button + && !self.display_clear_button + && !self.clear_on_confirm + { + cx.on_focus_out( + &editor_focus_handle, + window, + move |editor, _, window, cx| { + let text = Some(editor.text(cx)); + confirm(text, window, cx); + }, + ) + .detach(); + } + if let Some(placeholder) = placeholder { editor.set_placeholder_text(placeholder, window, cx); } @@ -130,12 +151,31 @@ impl RenderOnce for SettingsInputField { window.use_state(cx, { let initial_text = self.initial_text.clone(); let placeholder = self.placeholder; + let mut confirm = self.confirm.clone(); + move |window, cx| { let mut editor = Editor::single_line(window, cx); + let editor_focus_handle = editor.focus_handle(cx); if let Some(text) = initial_text { editor.set_text(text, window, cx); } + if let Some(confirm) = confirm.take() + && 
!self.display_confirm_button + && !self.display_clear_button + && !self.clear_on_confirm + { + cx.on_focus_out( + &editor_focus_handle, + window, + move |editor, _, window, cx| { + let text = Some(editor.text(cx)); + confirm(text, window, cx); + }, + ) + .detach(); + } + if let Some(placeholder) = placeholder { editor.set_placeholder_text(placeholder, window, cx); } @@ -149,11 +189,20 @@ impl RenderOnce for SettingsInputField { // re-renders but use_keyed_state returns the cached editor with stale text. // Reconcile with the expected initial_text when the editor is not focused, // so we don't clobber what the user is actively typing. - if let Some(initial_text) = &self.initial_text { - let current_text = editor.read(cx).text(cx); - if current_text != *initial_text && !editor.read(cx).is_focused(window) { - editor.update(cx, |editor, cx| { - editor.set_text(initial_text.clone(), window, cx); + if let Some(initial_text) = &self.initial_text + && let Some(first_initial) = first_render_initial_text.read(cx) + { + if initial_text != first_initial && !editor.read(cx).is_focused(window) { + *first_render_initial_text.as_mut(cx) = self.initial_text.clone(); + let weak_editor = editor.downgrade(); + let initial_text = initial_text.clone(); + + window.defer(cx, move |window, cx| { + weak_editor + .update(cx, |editor, cx| { + editor.set_text(initial_text, window, cx); + }) + .ok(); }); } } diff --git a/crates/settings_ui/src/page_data.rs b/crates/settings_ui/src/page_data.rs index 9978832c05bb29c97f118fccbe301214d81fa0c6..c77bf5a326c6b48dea2c85f0744de0066d8c0236 100644 --- a/crates/settings_ui/src/page_data.rs +++ b/crates/settings_ui/src/page_data.rs @@ -5579,96 +5579,6 @@ fn panels_page() -> SettingsPage { ] } - fn notification_panel_section() -> [SettingsPageItem; 5] { - [ - SettingsPageItem::SectionHeader("Notification Panel"), - SettingsPageItem::SettingItem(SettingItem { - title: "Notification Panel Button", - description: "Show the notification panel button in the 
status bar.", - field: Box::new(SettingField { - json_path: Some("notification_panel.button"), - pick: |settings_content| { - settings_content - .notification_panel - .as_ref()? - .button - .as_ref() - }, - write: |settings_content, value| { - settings_content - .notification_panel - .get_or_insert_default() - .button = value; - }, - }), - metadata: None, - files: USER, - }), - SettingsPageItem::SettingItem(SettingItem { - title: "Notification Panel Dock", - description: "Where to dock the notification panel.", - field: Box::new(SettingField { - json_path: Some("notification_panel.dock"), - pick: |settings_content| { - settings_content.notification_panel.as_ref()?.dock.as_ref() - }, - write: |settings_content, value| { - settings_content - .notification_panel - .get_or_insert_default() - .dock = value; - }, - }), - metadata: None, - files: USER, - }), - SettingsPageItem::SettingItem(SettingItem { - title: "Notification Panel Default Width", - description: "Default width of the notification panel in pixels.", - field: Box::new(SettingField { - json_path: Some("notification_panel.default_width"), - pick: |settings_content| { - settings_content - .notification_panel - .as_ref()? - .default_width - .as_ref() - }, - write: |settings_content, value| { - settings_content - .notification_panel - .get_or_insert_default() - .default_width = value; - }, - }), - metadata: None, - files: USER, - }), - SettingsPageItem::SettingItem(SettingItem { - title: "Show Count Badge", - description: "Show a badge on the notification panel icon with the count of unread notifications.", - field: Box::new(SettingField { - json_path: Some("notification_panel.show_count_badge"), - pick: |settings_content| { - settings_content - .notification_panel - .as_ref()? 
- .show_count_badge - .as_ref() - }, - write: |settings_content, value| { - settings_content - .notification_panel - .get_or_insert_default() - .show_count_badge = value; - }, - }), - metadata: None, - files: USER, - }), - ] - } - fn collaboration_panel_section() -> [SettingsPageItem; 4] { [ SettingsPageItem::SectionHeader("Collaboration Panel"), @@ -5737,7 +5647,7 @@ fn panels_page() -> SettingsPage { ] } - fn agent_panel_section() -> [SettingsPageItem; 6] { + fn agent_panel_section() -> [SettingsPageItem; 7] { [ SettingsPageItem::SectionHeader("Agent Panel"), SettingsPageItem::SettingItem(SettingItem { @@ -5812,6 +5722,24 @@ fn panels_page() -> SettingsPage { metadata: None, files: USER, }), + SettingsPageItem::SettingItem(SettingItem { + title: "Agent Panel Max Content Width", + description: "Maximum content width in pixels. Content will be centered when the panel is wider than this value.", + field: Box::new(SettingField { + json_path: Some("agent.max_content_width"), + pick: |settings_content| { + settings_content.agent.as_ref()?.max_content_width.as_ref() + }, + write: |settings_content, value| { + settings_content + .agent + .get_or_insert_default() + .max_content_width = value; + }, + }), + metadata: None, + files: USER, + }), ] } @@ -5823,7 +5751,6 @@ fn panels_page() -> SettingsPage { outline_panel_section(), git_panel_section(), debugger_panel_section(), - notification_panel_section(), collaboration_panel_section(), agent_panel_section(), ], diff --git a/crates/sidebar/src/sidebar_tests.rs b/crates/sidebar/src/sidebar_tests.rs index 60881acfe9461f7897d6013831970444b7a65544..09fd44af35679a69908e1d86d203ea8c3aa5c545 100644 --- a/crates/sidebar/src/sidebar_tests.rs +++ b/crates/sidebar/src/sidebar_tests.rs @@ -5064,6 +5064,7 @@ async fn test_legacy_thread_with_canonical_path_opens_main_repo_workspace(cx: &m mod property_test { use super::*; + use gpui::proptest::prelude::*; struct UnopenedWorktree { path: String, @@ -5658,7 +5659,10 @@ mod property_test { 
Ok(()) } - #[gpui::property_test] + #[gpui::property_test(config = ProptestConfig { + cases: 10, + ..Default::default() + })] async fn test_sidebar_invariants( #[strategy = gpui::proptest::collection::vec(0u32..DISTRIBUTION_SLOTS * 10, 1..5)] raw_operations: Vec, diff --git a/crates/tasks_ui/src/modal.rs b/crates/tasks_ui/src/modal.rs index 285a07c9562849b26b4cbba3de3979614384d875..3b7edef415f10f8723ab041e5a81ac672d603371 100644 --- a/crates/tasks_ui/src/modal.rs +++ b/crates/tasks_ui/src/modal.rs @@ -566,9 +566,7 @@ impl PickerDelegate for TasksModalDelegate { .checked_sub(1); picker.refresh(window, cx); })) - .tooltip(|_, cx| { - Tooltip::simple("Delete Previously Scheduled Task", cx) - }), + .tooltip(|_, cx| Tooltip::simple("Delete from Recent Tasks", cx)), ); item.end_slot_on_hover(delete_button) } else { diff --git a/crates/ui/src/components/collab/collab_notification.rs b/crates/ui/src/components/collab/collab_notification.rs index 0c3fca84e9b9fb3246de20b9b1f077202fa3ebdb..28d28b0a292076a575a5443b80eae9b788e2b62e 100644 --- a/crates/ui/src/components/collab/collab_notification.rs +++ b/crates/ui/src/components/collab/collab_notification.rs @@ -67,7 +67,7 @@ impl Component for CollabNotification { let avatar = "https://avatars.githubusercontent.com/u/67129314?v=4"; let container = || div().h(px(72.)).w(px(400.)); // Size of the actual notification window - let examples = vec![ + let call_examples = vec![ single_example( "Incoming Call", container() @@ -129,6 +129,58 @@ impl Component for CollabNotification { ), ]; - Some(example_group(examples).vertical().into_any_element()) + let toast_examples = vec![ + single_example( + "Contact Request", + container() + .child( + CollabNotification::new( + avatar, + Button::new("accept", "Accept"), + Button::new("decline", "Decline"), + ) + .child(Label::new("maxbrunsfeld wants to add you as a contact")), + ) + .into_any_element(), + ), + single_example( + "Contact Request Accepted", + container() + .child( + 
CollabNotification::new( + avatar, + Button::new("dismiss", "Dismiss"), + Button::new("close", "Close"), + ) + .child(Label::new("maxbrunsfeld accepted your contact request")), + ) + .into_any_element(), + ), + single_example( + "Channel Invitation", + container() + .child( + CollabNotification::new( + avatar, + Button::new("accept", "Accept"), + Button::new("decline", "Decline"), + ) + .child(Label::new( + "maxbrunsfeld invited you to join the #zed channel", + )), + ) + .into_any_element(), + ), + ]; + + Some( + v_flex() + .gap_6() + .child(example_group_with_title("Calls & Projects", call_examples).vertical()) + .child( + example_group_with_title("Contact & Channel Toasts", toast_examples).vertical(), + ) + .into_any_element(), + ) } } diff --git a/crates/ui/src/components/list/list_item.rs b/crates/ui/src/components/list/list_item.rs index 9a764efd58cfd3365d92e534a715a0f23ce46e90..ece1fd3c61ec486c090808891a8eec662138b1b4 100644 --- a/crates/ui/src/components/list/list_item.rs +++ b/crates/ui/src/components/list/list_item.rs @@ -52,7 +52,7 @@ pub struct ListItem { overflow_x: bool, focused: Option, docked_right: bool, - height: Option, + height: Option, } impl ListItem { @@ -207,8 +207,8 @@ impl ListItem { self } - pub fn height(mut self, height: Pixels) -> Self { - self.height = Some(height); + pub fn height(mut self, height: impl Into) -> Self { + self.height = Some(height.into()); self } } diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index fd19a5dc400a24b9f27617c44bd71fe38073c757..06fa6ead775809c3df775d959fb080a93ee84aad 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -1782,7 +1782,6 @@ fn generate_commands(_: &App) -> Vec { VimCommand::str(("te", "rm"), "terminal_panel::Toggle"), VimCommand::str(("T", "erm"), "terminal_panel::Toggle"), VimCommand::str(("C", "ollab"), "collab_panel::ToggleFocus"), - VimCommand::str(("No", "tifications"), "notification_panel::ToggleFocus"), VimCommand::str(("A", "I"), 
"agent::ToggleFocus"), VimCommand::str(("G", "it"), "git_panel::ToggleFocus"), VimCommand::str(("D", "ebug"), "debug_panel::ToggleFocus"), diff --git a/crates/vim/src/state.rs b/crates/vim/src/state.rs index 4dd557199ab9aebe0a2b26438bdaa0e321a956b2..9e9b42d31900e0ceb160df4ad4dd3ce3a530e155 100644 --- a/crates/vim/src/state.rs +++ b/crates/vim/src/state.rs @@ -17,7 +17,7 @@ use gpui::{ Action, App, AppContext, BorrowAppContext, ClipboardEntry, ClipboardItem, DismissEvent, Entity, EntityId, Global, HighlightStyle, StyledText, Subscription, Task, TextStyle, WeakEntity, }; -use language::{Buffer, BufferEvent, BufferId, Chunk, Point}; +use language::{Buffer, BufferEvent, BufferId, Chunk, LanguageAwareStyling, Point}; use multi_buffer::MultiBufferRow; use picker::{Picker, PickerDelegate}; @@ -1504,7 +1504,10 @@ impl PickerDelegate for MarksViewDelegate { position.row, snapshot.line_len(MultiBufferRow(position.row)), ), - true, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, ); matches.push(MarksMatch { name: name.clone(), @@ -1530,7 +1533,10 @@ impl PickerDelegate for MarksViewDelegate { let chunks = snapshot.chunks( Point::new(position.row, 0) ..Point::new(position.row, snapshot.line_len(position.row)), - true, + LanguageAwareStyling { + tree_sitter: true, + diagnostics: true, + }, ); matches.push(MarksMatch { diff --git a/crates/web_search_providers/Cargo.toml b/crates/web_search_providers/Cargo.toml index ff264edcb150063237c633de746b2f6b9f6f250c..e2bbc1aeb2dd5718596b905788b4a88826357401 100644 --- a/crates/web_search_providers/Cargo.toml +++ b/crates/web_search_providers/Cargo.toml @@ -14,6 +14,7 @@ path = "src/web_search_providers.rs" [dependencies] anyhow.workspace = true client.workspace = true +cloud_api_client.workspace = true cloud_api_types.workspace = true cloud_llm_client.workspace = true futures.workspace = true diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 
11227d8fb5c7152dc5b7e03b95fadea6cb714717..16707003c49921bce6244b69d0e7387f935ed8e1 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -2,12 +2,12 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token}; +use cloud_api_client::LlmApiToken; use cloud_api_types::OrganizationId; use cloud_llm_client::{WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Task}; use http_client::{HttpClient, Method}; -use language_model::LlmApiToken; use web_search::{WebSearchProvider, WebSearchProviderId}; pub struct CloudWebSearchProvider { diff --git a/crates/workspace/src/multi_workspace.rs b/crates/workspace/src/multi_workspace.rs index a61ad3576c57ecd8b1811363d6b5607ead737821..1b057e3fb1e3b5e0639e4a44462fc7528f6db85d 100644 --- a/crates/workspace/src/multi_workspace.rs +++ b/crates/workspace/src/multi_workspace.rs @@ -276,6 +276,7 @@ pub struct MultiWorkspace { pending_removal_tasks: Vec>, _serialize_task: Option>, _subscriptions: Vec, + previous_focus_handle: Option, } impl EventEmitter for MultiWorkspace {} @@ -333,6 +334,7 @@ impl MultiWorkspace { quit_subscription, settings_subscription, ], + previous_focus_handle: None, } } @@ -387,6 +389,7 @@ impl MultiWorkspace { if self.sidebar_open() { self.close_sidebar(window, cx); } else { + self.previous_focus_handle = window.focused(cx); self.open_sidebar(cx); if let Some(sidebar) = &self.sidebar { sidebar.prepare_for_focus(window, cx); @@ -417,14 +420,16 @@ impl MultiWorkspace { .is_some_and(|s| s.focus_handle(cx).contains_focused(window, cx)); if sidebar_is_focused { - let pane = self.workspace().read(cx).active_pane().clone(); - let pane_focus = pane.read(cx).focus_handle(cx); - window.focus(&pane_focus, cx); - } else if let Some(sidebar) = &self.sidebar { - sidebar.prepare_for_focus(window, cx); - sidebar.focus(window, cx); + 
self.restore_previous_focus(false, window, cx); + } else { + self.previous_focus_handle = window.focused(cx); + if let Some(sidebar) = &self.sidebar { + sidebar.prepare_for_focus(window, cx); + sidebar.focus(window, cx); + } } } else { + self.previous_focus_handle = window.focused(cx); self.open_sidebar(cx); if let Some(sidebar) = &self.sidebar { sidebar.prepare_for_focus(window, cx); @@ -457,13 +462,26 @@ impl MultiWorkspace { workspace.set_sidebar_focus_handle(None); }); } - let pane = self.workspace().read(cx).active_pane().clone(); - let pane_focus = pane.read(cx).focus_handle(cx); - window.focus(&pane_focus, cx); + self.restore_previous_focus(true, window, cx); self.serialize(cx); cx.notify(); } + fn restore_previous_focus(&mut self, clear: bool, window: &mut Window, cx: &mut Context) { + let focus_handle = if clear { + self.previous_focus_handle.take() + } else { + self.previous_focus_handle.clone() + }; + + if let Some(previous_focus) = focus_handle { + previous_focus.focus(window, cx); + } else { + let pane = self.workspace().read(cx).active_pane().clone(); + window.focus(&pane.read(cx).focus_handle(cx), cx); + } + } + pub fn close_window(&mut self, _: &CloseWindow, window: &mut Window, cx: &mut Context) { cx.spawn_in(window, async move |this, cx| { let workspaces = this.update(cx, |multi_workspace, _cx| { diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 27cc96ae80a010db2dd5357a9a0bc037ca762875..cbcd60b734644cb61473bef85e27f2403e3c7d3c 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -10,7 +10,10 @@ use crate::{ TabContentParams, TabTooltipContent, WeakItemHandle, }, move_item, - notifications::NotifyResultExt, + notifications::{ + NotificationId, NotifyResultExt, show_app_notification, + simple_message_notification::MessageNotification, + }, toolbar::Toolbar, workspace_settings::{AutosaveSetting, FocusFollowsMouse, TabBarSettings, WorkspaceSettings}, }; @@ -195,6 +198,16 @@ pub struct 
DeploySearch { pub included_files: Option, #[serde(default)] pub excluded_files: Option, + #[serde(default)] + pub query: Option, + #[serde(default)] + pub regex: Option, + #[serde(default)] + pub case_sensitive: Option, + #[serde(default)] + pub whole_word: Option, + #[serde(default)] + pub include_ignored: Option, } #[derive(Clone, Copy, PartialEq, Debug, Deserialize, JsonSchema, Default)] @@ -306,16 +319,6 @@ actions!( ] ); -impl DeploySearch { - pub fn find() -> Self { - Self { - replace_enabled: false, - included_files: None, - excluded_files: None, - } - } -} - const MAX_NAVIGATION_HISTORY_LEN: usize = 1024; pub enum Event { @@ -4185,15 +4188,7 @@ fn default_render_tab_bar_buttons( menu.action("New File", NewFile.boxed_clone()) .action("Open File", ToggleFileFinder::default().boxed_clone()) .separator() - .action( - "Search Project", - DeploySearch { - replace_enabled: false, - included_files: None, - excluded_files: None, - } - .boxed_clone(), - ) + .action("Search Project", DeploySearch::default().boxed_clone()) .action("Search Symbols", ToggleProjectSymbols.boxed_clone()) .separator() .action("New Terminal", NewTerminal::default().boxed_clone()) @@ -4400,17 +4395,64 @@ impl Render for Pane { )) .on_action( cx.listener(|pane: &mut Self, action: &RevealInProjectPanel, _, cx| { + let Some(active_item) = pane.active_item() else { + return; + }; + let entry_id = action .entry_id .map(ProjectEntryId::from_proto) - .or_else(|| pane.active_item()?.project_entry_ids(cx).first().copied()); - if let Some(entry_id) = entry_id { - pane.project - .update(cx, |_, cx| { - cx.emit(project::Event::RevealInProjectPanel(entry_id)) - }) - .ok(); + .or_else(|| active_item.project_entry_ids(cx).first().copied()); + + let show_reveal_error_toast = |display_name: &str, cx: &mut App| { + let notification_id = NotificationId::unique::(); + let message = SharedString::from(format!( + "\"{display_name}\" is not part of any open projects." 
+ )); + + show_app_notification(notification_id, cx, move |cx| { + let message = message.clone(); + cx.new(|cx| MessageNotification::new(message, cx)) + }); + }; + + let Some(entry_id) = entry_id else { + // When working with an unsaved buffer, display a toast + // informing the user that the buffer is not present in + // any of the open projects and stop execution, as we + // don't want to open the project panel. + let display_name = active_item + .tab_tooltip_text(cx) + .unwrap_or_else(|| active_item.tab_content_text(0, cx)); + + return show_reveal_error_toast(&display_name, cx); + }; + + // We'll now check whether the entry belongs to a visible + // worktree and, if that's not the case, it means the user + // is interacting with a file that does not belong to any of + // the open projects, so we'll show a toast informing them + // of this and stop execution. + let display_name = pane + .project + .read_with(cx, |project, cx| { + project + .worktree_for_entry(entry_id, cx) + .filter(|worktree| !worktree.read(cx).is_visible()) + .map(|worktree| worktree.read(cx).root_name_str().to_string()) + }) + .ok() + .flatten(); + + if let Some(display_name) = display_name { + return show_reveal_error_toast(&display_name, cx); } + + pane.project + .update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(entry_id)) + }) + .log_err(); }), ) .on_action(cx.listener(|_, _: &menu::Cancel, window, cx| { diff --git a/crates/x_ai/Cargo.toml b/crates/x_ai/Cargo.toml index 8ff020df8c1ccaf284157d8b46ddaa0e678b3cd7..2d1c9d0ecebeb8a1e0965b0ac914603b41383f00 100644 --- a/crates/x_ai/Cargo.toml +++ b/crates/x_ai/Cargo.toml @@ -17,6 +17,8 @@ schemars = ["dep:schemars"] [dependencies] anyhow.workspace = true +language_model_core.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true strum.workspace = true +tiktoken-rs.workspace = true diff --git a/crates/x_ai/src/completion.rs b/crates/x_ai/src/completion.rs new file mode 100644 index 
0000000000000000000000000000000000000000..aad03d227eb82768c972283f7e1617ea7486f22f --- /dev/null +++ b/crates/x_ai/src/completion.rs @@ -0,0 +1,30 @@ +use anyhow::Result; +use language_model_core::{LanguageModelRequest, Role}; + +use crate::Model; + +/// Count tokens for an xAI model using tiktoken. This is synchronous; +/// callers should spawn it on a background thread if needed. +pub fn count_xai_tokens(request: LanguageModelRequest, model: Model) -> Result { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + let model_name = if model.max_token_count() >= 100_000 { + "gpt-4o" + } else { + "gpt-4" + }; + tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64) +} diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs index 1abb2b53771fa1e29e2979560e9f394744b26158..fd141a1723a28d235311d5d875bf4cc0388cab61 100644 --- a/crates/x_ai/src/x_ai.rs +++ b/crates/x_ai/src/x_ai.rs @@ -1,3 +1,5 @@ +pub mod completion; + use anyhow::Result; use serde::{Deserialize, Serialize}; use strum::EnumIter; diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index b59123a1a159487f802210f3916e16856daf8e61..9f69cd3458c194228f37cfdeedcf0c9023b9b7bd 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -3080,7 +3080,7 @@ fn run_start_thread_in_selector_visual_tests( cx: &mut VisualTestAppContext, update_baseline: bool, ) -> Result { - use agent_ui::{AgentPanel, StartThreadIn, WorktreeCreationStatus}; + use agent_ui::{AgentPanel, NewWorktreeBranchTarget, StartThreadIn, WorktreeCreationStatus}; // Enable feature flags so the thread target selector renders cx.update(|cx| 
{ @@ -3401,7 +3401,13 @@ edition = "2021" cx.update_window(workspace_window.into(), |_, _window, cx| { panel.update(cx, |panel, cx| { - panel.set_start_thread_in_for_tests(StartThreadIn::NewWorktree, cx); + panel.set_start_thread_in_for_tests( + StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }, + cx, + ); }); })?; cx.run_until_parked(); @@ -3474,7 +3480,13 @@ edition = "2021" cx.run_until_parked(); cx.update_window(workspace_window.into(), |_, window, cx| { - window.dispatch_action(Box::new(StartThreadIn::NewWorktree), cx); + window.dispatch_action( + Box::new(StartThreadIn::NewWorktree { + worktree_name: None, + branch_target: NewWorktreeBranchTarget::default(), + }), + cx, + ); })?; cx.run_until_parked(); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 03e128415e1aa8390d1b95816755d3644064dada..293125c0089e0a4315eb9c28f30be5f840bd6052 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -652,10 +652,6 @@ fn initialize_panels(window: &mut Window, cx: &mut Context) -> Task) -> Task(window, cx); }, ) - .register_action( - |workspace: &mut Workspace, - _: &collab_ui::notification_panel::ToggleFocus, - window: &mut Window, - cx: &mut Context| { - workspace.toggle_panel_focus::( - window, cx, - ); - }, - ) .register_action( |workspace: &mut Workspace, _: &terminal_panel::ToggleFocus, @@ -4962,7 +4947,6 @@ mod tests { "multi_workspace", "new_process_modal", "notebook", - "notification_panel", "onboarding", "outline", "outline_panel", diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index 3edbcad2d81d63b56e777218a3db5e57a42de7bc..f3913a6556626e2919024ca02bcba0f1f41819eb 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -165,7 +165,7 @@ pub fn app_menus(cx: &mut App) -> Vec { MenuItem::os_action("Paste", editor::actions::Paste, OsAction::Paste), MenuItem::separator(), MenuItem::action("Find", 
search::buffer_search::Deploy::find()), - MenuItem::action("Find in Project", workspace::DeploySearch::find()), + MenuItem::action("Find in Project", workspace::DeploySearch::default()), MenuItem::separator(), MenuItem::action( "Toggle Line Comment", diff --git a/docs/src/visual-customization.md b/docs/src/visual-customization.md index 3c285bc3d10fc3bcb5fba6f735304ede438104a3..7597cdac293dd842b6a6a9f5747551a6f172bbf3 100644 --- a/docs/src/visual-customization.md +++ b/docs/src/visual-customization.md @@ -105,7 +105,7 @@ To disable this behavior use: // "outline_panel": {"button": false }, // "collaboration_panel": {"button": false }, // "git_panel": {"button": false }, - // "notification_panel": {"button": false }, + // "agent": {"button": false }, // "debugger": {"button": false }, // "diagnostics": {"button": false }, @@ -588,16 +588,6 @@ See [Terminal settings](./reference/all-settings.md#terminal) for additional non "dock": "left", // Where to dock: left, right "default_width": 240 // Default width of the collaboration panel. }, - "show_call_status_icon": true, // Shown call status in the OS status bar. - - // Notification Panel - "notification_panel": { - // Whether to show the notification panel button in the status bar. - "button": true, - // Where to dock the notification panel. Can be 'left' or 'right'. - "dock": "right", - // Default width of the notification panel. - "default_width": 380 - } + "show_call_status_icon": true // Shown call status in the OS status bar. } ```