diff --git a/crates/agent/src/edit_agent.rs b/crates/agent/src/edit_agent.rs
index ecd0b20f674c0a8efdfca3b28cce5780a882cedb..0c234b8cb152e46ad3be29f21639e5a20a1ffb86 100644
--- a/crates/agent/src/edit_agent.rs
+++ b/crates/agent/src/edit_agent.rs
@@ -732,6 +732,10 @@ impl EditAgent {
             stop: Vec::new(),
             temperature: None,
             thinking_allowed: true,
+            // Bypass the rate limiter for nested requests (edit agent requests spawned
+            // from within a tool call) to avoid deadlocks when multiple subagents try
+            // to use edit_file simultaneously.
+            bypass_rate_limit: true,
         };
 
         Ok(self.model.stream_completion_text(request, cx).await?.stream)
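For context, here is a minimal, self-contained sketch of the deadlock this flag works around. It is illustrative only — `async_lock`/`smol` stand in for the real executor and semaphore, and the names are assumed, not taken from this change:

```rust
use std::{sync::Arc, time::Duration};

fn main() {
    smol::block_on(async {
        // Two permits, like a RateLimiter::new(2).
        let semaphore = Arc::new(async_lock::Semaphore::new(2));
        let tasks: Vec<_> = (0..2)
            .map(|_| {
                let semaphore = semaphore.clone();
                smol::spawn(async move {
                    // "Parent" request takes a permit...
                    let _outer = semaphore.acquire_arc().await;
                    // ...then its nested "edit agent" request waits forever:
                    // both permits are already held by the two parents.
                    let _inner = semaphore.acquire_arc().await;
                })
            })
            .collect();
        smol::Timer::after(Duration::from_millis(100)).await;
        // Neither task can make progress.
        assert!(tasks.iter().all(|task| !task.is_finished()));
    });
}
```

Setting `bypass_rate_limit: true` for the nested request skips the second acquisition, which is exactly what the new field enables.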
diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs
new file mode 100644
index 0000000000000000000000000000000000000000..67a0aa07255f15445ae236e2042864acba9833c4
--- /dev/null
+++ b/crates/agent/src/tests/edit_file_thread_test.rs
@@ -0,0 +1,618 @@
+use super::*;
+use acp_thread::UserMessageId;
+use action_log::ActionLog;
+use fs::FakeFs;
+use language_model::{
+    LanguageModelCompletionEvent, LanguageModelToolUse, MessageContent, StopReason,
+    fake_provider::FakeLanguageModel,
+};
+use prompt_store::ProjectContext;
+use serde_json::json;
+use std::{collections::BTreeMap, sync::Arc, time::Duration};
+use util::path;
+
+#[gpui::test]
+async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
+    // This test verifies that the edit_file tool works correctly when invoked
+    // through the full thread flow (model sends ToolUse event -> tool runs -> result sent back).
+    // This is different from tests that call tool.run() directly.
+    super::init_test(cx);
+    super::always_allow_tools(cx);
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        path!("/project"),
+        json!({
+            "src": {
+                "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}\n"
+            }
+        }),
+    )
+    .await;
+
+    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+    let project_context = cx.new(|_cx| ProjectContext::default());
+    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+    let context_server_registry =
+        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
+    let model = Arc::new(FakeLanguageModel::default());
+    let fake_model = model.as_fake();
+
+    let thread = cx.new(|cx| {
+        let mut thread = crate::Thread::new(
+            project.clone(),
+            project_context,
+            context_server_registry,
+            crate::Templates::new(),
+            Some(model.clone()),
+            cx,
+        );
+        // Add just the tools we need for this test
+        let language_registry = project.read(cx).languages().clone();
+        thread.add_tool(crate::ReadFileTool::new(
+            cx.weak_entity(),
+            project.clone(),
+            thread.action_log().clone(),
+        ));
+        thread.add_tool(crate::EditFileTool::new(
+            project.clone(),
+            cx.weak_entity(),
+            language_registry,
+            crate::Templates::new(),
+        ));
+        thread
+    });
+
+    // First, read the file so the thread knows about its contents
+    let _events = thread
+        .update(cx, |thread, cx| {
+            thread.send(UserMessageId::new(), ["Read the file src/main.rs"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    // Model calls read_file tool
+    let read_tool_use = LanguageModelToolUse {
+        id: "read_tool_1".into(),
+        name: "read_file".into(),
+        raw_input: json!({"path": "project/src/main.rs"}).to_string(),
+        input: json!({"path": "project/src/main.rs"}),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Wait for the read tool to complete and model to be called again
+    while fake_model.pending_completions().is_empty() {
+        cx.run_until_parked();
+    }
+
+    // Model responds after seeing the file content, then calls edit_file
+    fake_model.send_last_completion_stream_text_chunk("I'll edit the file now.");
+    let edit_tool_use = LanguageModelToolUse {
+        id: "edit_tool_1".into(),
+        name: "edit_file".into(),
+        raw_input: json!({
+            "display_description": "Change greeting message",
+            "path": "project/src/main.rs",
+            "mode": "edit"
+        })
+        .to_string(),
+        input: json!({
+            "display_description": "Change greeting message",
+            "path": "project/src/main.rs",
+            "mode": "edit"
+        }),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // The edit_file tool creates an EditAgent which makes its own model request.
+    // We need to respond to that request with the edit instructions.
+    // Wait for the edit agent's completion request
+    let deadline = std::time::Instant::now() + Duration::from_secs(5);
+    while fake_model.pending_completions().is_empty() {
+        if std::time::Instant::now() >= deadline {
+            panic!(
+                "Timed out waiting for edit agent completion request. Pending: {}",
+                fake_model.pending_completions().len()
+            );
+        }
+        cx.run_until_parked();
+        cx.background_executor
+            .timer(Duration::from_millis(10))
+            .await;
+    }
+
+    // Send the edit agent's response with the XML format it expects
+    let edit_response = "<old_text>println!(\"Hello, world!\");</old_text>\n<new_text>println!(\"Hello, Zed!\");</new_text>";
+    fake_model.send_last_completion_stream_text_chunk(edit_response);
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Wait for the edit to complete and the thread to call the model again with tool results
+    let deadline = std::time::Instant::now() + Duration::from_secs(5);
+    while fake_model.pending_completions().is_empty() {
+        if std::time::Instant::now() >= deadline {
+            panic!("Timed out waiting for model to be called after edit completion");
+        }
+        cx.run_until_parked();
+        cx.background_executor
+            .timer(Duration::from_millis(10))
+            .await;
+    }
+
+    // Verify the file was edited
+    let file_content = fs
+        .load(path!("/project/src/main.rs").as_ref())
+        .await
+        .expect("file should exist");
+    assert!(
+        file_content.contains("Hello, Zed!"),
+        "File should have been edited. Content: {}",
+        file_content
+    );
+    assert!(
+        !file_content.contains("Hello, world!"),
+        "Old content should be replaced. Content: {}",
+        file_content
+    );
+
+    // Verify the tool result was sent back to the model
+    let pending = fake_model.pending_completions();
+    assert!(
+        !pending.is_empty(),
+        "Model should have been called with tool result"
+    );
+
+    let last_request = pending.last().unwrap();
+    let has_tool_result = last_request.messages.iter().any(|m| {
+        m.content
+            .iter()
+            .any(|c| matches!(c, language_model::MessageContent::ToolResult(_)))
+    });
+    assert!(
+        has_tool_result,
+        "Tool result should be in the messages sent back to the model"
+    );
+
+    // Complete the turn
+    fake_model.send_last_completion_stream_text_chunk("I've updated the greeting message.");
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Verify the thread completed successfully
+    thread.update(cx, |thread, _cx| {
+        assert!(
+            thread.is_turn_complete(),
+            "Thread should be complete after the turn ends"
+        );
+    });
+}
+
+#[gpui::test]
+async fn test_subagent_uses_read_file_tool(cx: &mut TestAppContext) {
+    // This test verifies that subagents can successfully use the read_file tool
+    // through the full thread flow, and that tools are properly rebound to use
+    // the subagent's thread ID instead of the parent's.
+    super::init_test(cx);
+    super::always_allow_tools(cx);
+
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["subagents".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        path!("/project"),
+        json!({
+            "src": {
+                "lib.rs": "pub fn hello() -> &'static str {\n    \"Hello from lib!\"\n}\n"
+            }
+        }),
+    )
+    .await;
+
+    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+    let project_context = cx.new(|_cx| ProjectContext::default());
+    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+    let context_server_registry =
+        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
+    let model = Arc::new(FakeLanguageModel::default());
+    let fake_model = model.as_fake();
+
+    // Create subagent context
+    let subagent_context = crate::SubagentContext {
+        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+        tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"),
+        depth: 1,
+        summary_prompt: "Summarize what you found".to_string(),
+        context_low_prompt: "Context low".to_string(),
+    };
+
+    // Create parent tools that will be passed to the subagent
+    // This simulates how the subagent_tool passes tools to new_subagent
+    let parent_tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>> = {
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        // Create a "fake" parent thread reference - this should get rebound
+        let fake_parent_thread = cx.new(|cx| {
+            crate::Thread::new(
+                project.clone(),
+                cx.new(|_cx| ProjectContext::default()),
+                cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)),
+                crate::Templates::new(),
+                Some(model.clone()),
+                cx,
+            )
+        });
+        let mut tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>> =
+            BTreeMap::new();
+        tools.insert(
+            "read_file".into(),
+            crate::ReadFileTool::new(fake_parent_thread.downgrade(), project.clone(), action_log)
+                .erase(),
+        );
+        tools
+    };
) + }); + + // Get the subagent's thread ID + let _subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string()); + + // Verify the subagent has the read_file tool + subagent.read_with(cx, |thread, _| { + assert!( + thread.has_registered_tool("read_file"), + "subagent should have read_file tool" + ); + }); + + // Submit a user message to the subagent + subagent + .update(cx, |thread, cx| { + thread.submit_user_message("Read the file src/lib.rs", cx) + }) + .unwrap(); + cx.run_until_parked(); + + // Simulate the model calling the read_file tool + let read_tool_use = LanguageModelToolUse { + id: "read_tool_1".into(), + name: "read_file".into(), + raw_input: json!({"path": "project/src/lib.rs"}).to_string(), + input: json!({"path": "project/src/lib.rs"}), + is_input_complete: true, + thought_signature: None, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Wait for the tool to complete and the model to be called again with tool results + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while fake_model.pending_completions().is_empty() { + if std::time::Instant::now() >= deadline { + panic!("Timed out waiting for model to be called after read_file tool completion"); + } + cx.run_until_parked(); + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + + // Verify the tool result was sent back to the model + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "Model should have been called with tool result" + ); + + let last_request = pending.last().unwrap(); + let tool_result = last_request.messages.iter().find_map(|m| { + m.content.iter().find_map(|c| match c { + MessageContent::ToolResult(result) => Some(result), + _ => None, + }) + }); + assert!( + tool_result.is_some(), + "Tool result should be in the messages sent back to the model" + ); + + // Verify the tool result contains the file content + let result = tool_result.unwrap(); + let result_text = match &result.content { + language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), + _ => panic!("expected text content in tool result"), + }; + assert!( + result_text.contains("Hello from lib!"), + "Tool result should contain file content, got: {}", + result_text + ); + + // Verify the subagent is ready for more input (tool completed, model called again) + // This test verifies the subagent can successfully use read_file tool. + // The summary flow is tested separately in test_subagent_returns_summary_on_completion. +} + +#[gpui::test] +async fn test_subagent_uses_edit_file_tool(cx: &mut TestAppContext) { + // This test verifies that subagents can successfully use the edit_file tool + // through the full thread flow, including the edit agent's model request. + // It also verifies that the edit agent uses the subagent's thread ID, not the parent's. 
+    super::init_test(cx);
+    super::always_allow_tools(cx);
+
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["subagents".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        path!("/project"),
+        json!({
+            "src": {
+                "config.rs": "pub const VERSION: &str = \"1.0.0\";\n"
+            }
+        }),
+    )
+    .await;
+
+    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+    let project_context = cx.new(|_cx| ProjectContext::default());
+    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+    let context_server_registry =
+        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
+    let model = Arc::new(FakeLanguageModel::default());
+    let fake_model = model.as_fake();
+
+    // Create a "parent" thread to simulate the real scenario where tools are inherited
+    let parent_thread = cx.new(|cx| {
+        crate::Thread::new(
+            project.clone(),
+            cx.new(|_cx| ProjectContext::default()),
+            cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)),
+            crate::Templates::new(),
+            Some(model.clone()),
+            cx,
+        )
+    });
+    let parent_thread_id = parent_thread.read_with(cx, |thread, _| thread.id().to_string());
+
+    // Create parent tools that reference the parent thread
+    let parent_tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>> = {
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let language_registry = project.read_with(cx, |p, _| p.languages().clone());
+        let mut tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>> =
+            BTreeMap::new();
+        tools.insert(
+            "read_file".into(),
+            crate::ReadFileTool::new(parent_thread.downgrade(), project.clone(), action_log)
+                .erase(),
+        );
+        tools.insert(
+            "edit_file".into(),
+            crate::EditFileTool::new(
+                project.clone(),
+                parent_thread.downgrade(),
+                language_registry,
+                crate::Templates::new(),
+            )
+            .erase(),
+        );
+        tools
+    };
+
+    // Create subagent context
+    let subagent_context = crate::SubagentContext {
+        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+        tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"),
+        depth: 1,
+        summary_prompt: "Summarize what you changed".to_string(),
+        context_low_prompt: "Context low".to_string(),
+    };
+
+    // Create subagent - tools should be rebound to use subagent's thread
+    let subagent = cx.new(|cx| {
+        crate::Thread::new_subagent(
+            project.clone(),
+            project_context,
+            context_server_registry,
+            crate::Templates::new(),
+            model.clone(),
+            subagent_context,
+            parent_tools,
+            cx,
+        )
+    });
+
+    // Get the subagent's thread ID - it should be different from parent
+    let subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string());
+    assert_ne!(
+        parent_thread_id, subagent_thread_id,
+        "Subagent should have a different thread ID than parent"
+    );
+
+    // Verify the subagent has the tools
+    subagent.read_with(cx, |thread, _| {
+        assert!(
+            thread.has_registered_tool("read_file"),
+            "subagent should have read_file tool"
+        );
+        assert!(
+            thread.has_registered_tool("edit_file"),
+            "subagent should have edit_file tool"
+        );
+    });
+
+    // Submit a user message to the subagent
+    subagent
+        .update(cx, |thread, cx| {
+            thread.submit_user_message("Update the version in config.rs to 2.0.0", cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    // First, model calls read_file to see the current content
+    let read_tool_use = LanguageModelToolUse {
+        id: "read_tool_1".into(),
+        name: "read_file".into(),
+        raw_input: json!({"path": "project/src/config.rs"}).to_string(),
+        input: json!({"path": "project/src/config.rs"}),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Wait for the read tool to complete and model to be called again
+    let deadline = std::time::Instant::now() + Duration::from_secs(5);
+    while fake_model.pending_completions().is_empty() {
+        if std::time::Instant::now() >= deadline {
+            panic!("Timed out waiting for model to be called after read_file tool");
+        }
+        cx.run_until_parked();
+        cx.background_executor
+            .timer(Duration::from_millis(10))
+            .await;
+    }
+
+    // Model responds and calls edit_file
+    fake_model.send_last_completion_stream_text_chunk("I'll update the version now.");
+    let edit_tool_use = LanguageModelToolUse {
+        id: "edit_tool_1".into(),
+        name: "edit_file".into(),
+        raw_input: json!({
+            "display_description": "Update version to 2.0.0",
+            "path": "project/src/config.rs",
+            "mode": "edit"
+        })
+        .to_string(),
+        input: json!({
+            "display_description": "Update version to 2.0.0",
+            "path": "project/src/config.rs",
+            "mode": "edit"
+        }),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // The edit_file tool creates an EditAgent which makes its own model request.
+    // Wait for that request.
+    let deadline = std::time::Instant::now() + Duration::from_secs(5);
+    while fake_model.pending_completions().is_empty() {
+        if std::time::Instant::now() >= deadline {
+            panic!(
+                "Timed out waiting for edit agent completion request in subagent. Pending: {}",
+                fake_model.pending_completions().len()
+            );
+        }
+        cx.run_until_parked();
+        cx.background_executor
+            .timer(Duration::from_millis(10))
+            .await;
+    }
+
+    // Verify the edit agent's request uses the SUBAGENT's thread ID, not the parent's
+    let pending = fake_model.pending_completions();
+    let edit_agent_request = pending.last().unwrap();
+    let edit_agent_thread_id = edit_agent_request.thread_id.as_ref().unwrap();
+    std::assert_eq!(
+        edit_agent_thread_id,
+        &subagent_thread_id,
+        "Edit agent should use subagent's thread ID, not parent's. Got: {}, expected: {}",
+        edit_agent_thread_id,
+        subagent_thread_id
+    );
+    std::assert_ne!(
+        edit_agent_thread_id,
+        &parent_thread_id,
+        "Edit agent should NOT use parent's thread ID"
+    );
+
+    // Send the edit agent's response with the XML format it expects
+    let edit_response = "<old_text>pub const VERSION: &str = \"1.0.0\";</old_text>\n<new_text>pub const VERSION: &str = \"2.0.0\";</new_text>";
+    fake_model.send_last_completion_stream_text_chunk(edit_response);
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // Wait for the edit to complete and the thread to call the model again with tool results
+    let deadline = std::time::Instant::now() + Duration::from_secs(5);
+    while fake_model.pending_completions().is_empty() {
+        if std::time::Instant::now() >= deadline {
+            panic!("Timed out waiting for model to be called after edit completion in subagent");
+        }
+        cx.run_until_parked();
+        cx.background_executor
+            .timer(Duration::from_millis(10))
+            .await;
+    }
+
+    // Verify the file was edited
+    let file_content = fs
+        .load(path!("/project/src/config.rs").as_ref())
+        .await
+        .expect("file should exist");
+    assert!(
+        file_content.contains("2.0.0"),
+        "File should have been edited to contain new version. Content: {}",
+        file_content
+    );
+    assert!(
+        !file_content.contains("1.0.0"),
+        "Old version should be replaced. Content: {}",
+        file_content
+    );
+
+    // Verify the tool result was sent back to the model
+    let pending = fake_model.pending_completions();
+    assert!(
+        !pending.is_empty(),
+        "Model should have been called with tool result"
+    );
+
+    let last_request = pending.last().unwrap();
+    let has_tool_result = last_request.messages.iter().any(|m| {
+        m.content
+            .iter()
+            .any(|c| matches!(c, MessageContent::ToolResult(_)))
+    });
+    assert!(
+        has_tool_result,
+        "Tool result should be in the messages sent back to the model"
+    );
+}
diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs
index 762b4d9a1393f96116e77d3e265b53633f014e7a..a4706f6a752b0ae2fd251320106da998819b0b47 100644
--- a/crates/agent/src/tests/mod.rs
+++ b/crates/agent/src/tests/mod.rs
@@ -52,6 +52,7 @@ use std::{
 };
 use util::path;
 
+mod edit_file_thread_test;
 mod test_tools;
 use test_tools::*;
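Review note: the tests above repeat the same deadline-poll loop while waiting for the `FakeLanguageModel` to receive a completion request. A shared helper along these lines — hypothetical, not part of this change, built only from calls the tests already use — could express that intent once:

```rust
async fn wait_for_pending_completion(
    fake_model: &FakeLanguageModel,
    cx: &mut TestAppContext,
    timeout: Duration,
) {
    // Poll until the fake model has a pending completion,
    // failing loudly if the deadline passes first.
    let deadline = std::time::Instant::now() + timeout;
    while fake_model.pending_completions().is_empty() {
        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for a completion request");
        }
        cx.run_until_parked();
        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
}
```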
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index b1f868a4e42e9e7ddc8ddbf866986f72360e35fc..d58e3cb0c4a14489b8e6f5321e4c7a5178ebe766 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -836,6 +836,19 @@ impl Thread {
         let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
         let (prompt_capabilities_tx, prompt_capabilities_rx) =
             watch::channel(Self::prompt_capabilities(Some(model.as_ref())));
+
+        // Rebind tools that hold thread references to use this subagent's thread
+        // instead of the parent's thread. This is critical for tools like EditFileTool
+        // that make model requests using the thread's ID.
+        let weak_self = cx.weak_entity();
+        let tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>> = parent_tools
+            .into_iter()
+            .map(|(name, tool)| {
+                let rebound = tool.rebind_thread(weak_self.clone()).unwrap_or(tool);
+                (name, rebound)
+            })
+            .collect();
+
         Self {
             id: acp::SessionId::new(uuid::Uuid::new_v4().to_string()),
             prompt_id: PromptId::new(),
@@ -849,7 +862,7 @@ impl Thread {
             running_turn: None,
             queued_messages: Vec::new(),
             pending_message: None,
-            tools: parent_tools,
+            tools,
             request_token_usage: HashMap::default(),
             cumulative_token_usage: TokenUsage::default(),
             initial_project_snapshot: Task::ready(None).shared(),
@@ -2274,6 +2287,7 @@ impl Thread {
             stop: Vec::new(),
             temperature: AgentSettings::temperature_for_model(model, cx),
             thinking_allowed: true,
+            bypass_rate_limit: false,
         };
 
         log::debug!("Completion request built successfully");
@@ -2690,6 +2704,15 @@ where
     fn erase(self) -> Arc<dyn AnyAgentTool> {
         Arc::new(Erased(Arc::new(self)))
    }
+
+    /// Create a new instance of this tool bound to a different thread.
+    /// This is used when creating subagents, so that tools like EditFileTool
+    /// that hold a thread reference will use the subagent's thread instead
+    /// of the parent's thread.
+    /// Returns None if the tool doesn't need rebinding (most tools).
+    fn rebind_thread(&self, _new_thread: WeakEntity<Thread>) -> Option<Arc<dyn AnyAgentTool>> {
+        None
+    }
 }
 
 pub struct Erased<T>(T);
@@ -2721,6 +2744,14 @@ pub trait AnyAgentTool {
         event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Result<()>;
+    /// Create a new instance of this tool bound to a different thread.
+    /// This is used when creating subagents, so that tools like EditFileTool
+    /// that hold a thread reference will use the subagent's thread instead
+    /// of the parent's thread.
+    /// Returns None if the tool doesn't need rebinding (most tools).
+    fn rebind_thread(&self, _new_thread: WeakEntity<Thread>) -> Option<Arc<dyn AnyAgentTool>> {
+        None
+    }
 }
 
 impl<T> AnyAgentTool for Erased<Arc<T>>
 where
@@ -2784,6 +2815,10 @@ where
         let output = serde_json::from_value(output)?;
         self.0.replay(input, output, event_stream, cx)
     }
+
+    fn rebind_thread(&self, new_thread: WeakEntity<Thread>) -> Option<Arc<dyn AnyAgentTool>> {
+        self.0.rebind_thread(new_thread)
+    }
 }
 
 #[derive(Clone)]
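To make the rebinding contract concrete, here is a self-contained sketch of the pattern the doc comments above describe. All names here (`Tool`, `WeakThread`, `SearchTool`, `rebind_all`) are stand-ins invented for illustration, not types from this change:

```rust
use std::sync::Arc;

#[derive(Clone)]
struct WeakThread; // stand-in for WeakEntity<Thread>

trait Tool {
    fn rebind_thread(&self, _new_thread: WeakThread) -> Option<Arc<dyn Tool>> {
        None // default: most tools hold no thread reference
    }
}

struct SearchTool {
    thread: WeakThread,
}

impl Tool for SearchTool {
    fn rebind_thread(&self, new_thread: WeakThread) -> Option<Arc<dyn Tool>> {
        // Thread-bound tools rebuild themselves against the subagent's thread.
        Some(Arc::new(Self { thread: new_thread }))
    }
}

fn rebind_all(tools: Vec<Arc<dyn Tool>>, thread: &WeakThread) -> Vec<Arc<dyn Tool>> {
    tools
        .into_iter()
        // `unwrap_or(tool)` carries over tools that returned None unchanged,
        // so only thread-bound tools pay the cost of reconstruction.
        .map(|tool| tool.rebind_thread(thread.clone()).unwrap_or(tool))
        .collect()
}
```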
diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs
index 4f8288c24382d373a41fd15b779970676ad09fae..bc7e5b5289937d6212c662f97238e43ea185684d 100644
--- a/crates/agent/src/tools/edit_file_tool.rs
+++ b/crates/agent/src/tools/edit_file_tool.rs
@@ -144,6 +144,15 @@ impl EditFileTool {
         }
     }
 
+    pub fn with_thread(&self, new_thread: WeakEntity<Thread>) -> Self {
+        Self {
+            project: self.project.clone(),
+            thread: new_thread,
+            language_registry: self.language_registry.clone(),
+            templates: self.templates.clone(),
+        }
+    }
+
     fn authorize(
         &self,
         input: &EditFileToolInput,
@@ -398,7 +407,6 @@ impl AgentTool for EditFileTool {
             })
             .await;
 
-
         let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
             edit_agent.edit(
                 buffer.clone(),
@@ -575,6 +583,13 @@ impl AgentTool for EditFileTool {
         }));
         Ok(())
     }
+
+    fn rebind_thread(
+        &self,
+        new_thread: gpui::WeakEntity<crate::Thread>,
+    ) -> Option<Arc<dyn AnyAgentTool>> {
+        Some(self.with_thread(new_thread).erase())
+    }
 }
 
 /// Validate that the file path is valid, meaning:
diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs
index bc7647739035a41b91c481d2f25b5fbd0f7856c7..8b13452e9357921a1f7a43a51a3364594b481c42 100644
--- a/crates/agent/src/tools/read_file_tool.rs
+++ b/crates/agent/src/tools/read_file_tool.rs
@@ -65,6 +65,14 @@ impl ReadFileTool {
             action_log,
         }
     }
+
+    pub fn with_thread(&self, new_thread: WeakEntity<Thread>) -> Self {
+        Self {
+            thread: new_thread,
+            project: self.project.clone(),
+            action_log: self.action_log.clone(),
+        }
+    }
 }
 
 impl AgentTool for ReadFileTool {
@@ -308,6 +316,13 @@ impl AgentTool for ReadFileTool {
             result
         })
     }
+
+    fn rebind_thread(
+        &self,
+        new_thread: WeakEntity<Thread>,
+    ) -> Option<Arc<dyn AnyAgentTool>> {
+        Some(self.with_thread(new_thread).erase())
+    }
 }
 
 #[cfg(test)]
diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs
index 6d860bea9974b209289fd57276351e96ed744b20..5cfd161e9fb01cfde6fefe65ef7b3a5fbd89a6f7 100644
--- a/crates/agent_ui/src/buffer_codegen.rs
+++ b/crates/agent_ui/src/buffer_codegen.rs
@@ -544,6 +544,7 @@ impl CodegenAlternative {
                 temperature,
                 messages,
                 thinking_allowed: false,
+                bypass_rate_limit: false,
             }
         }))
     }
@@ -622,6 +623,7 @@ impl CodegenAlternative {
                 temperature,
                 messages: vec![request_message],
                 thinking_allowed: false,
+                bypass_rate_limit: false,
             }
         }))
     }
diff --git a/crates/agent_ui/src/terminal_inline_assistant.rs b/crates/agent_ui/src/terminal_inline_assistant.rs
index 58a73131e2d20d0776b2cdc49a0b395834b5008f..04dafa588dd3960eb435ef7d9225217d2dfb3354 100644
--- a/crates/agent_ui/src/terminal_inline_assistant.rs
+++ b/crates/agent_ui/src/terminal_inline_assistant.rs
@@ -275,6 +275,7 @@ impl TerminalInlineAssistant {
                 stop: Vec::new(),
                 temperature,
                 thinking_allowed: false,
+                bypass_rate_limit: false,
             }
         }))
     }
diff --git a/crates/assistant_text_thread/src/text_thread.rs b/crates/assistant_text_thread/src/text_thread.rs
index 034314e349a306040fc0cd37dbc3ad9a5ea6e81b..042169eb93b51b681e88091cf994d95fa7b88436 100644
--- a/crates/assistant_text_thread/src/text_thread.rs
+++ b/crates/assistant_text_thread/src/text_thread.rs
@@ -2269,6 +2269,7 @@ impl TextThread {
             stop: Vec::new(),
             temperature: model.and_then(|model| AgentSettings::temperature_for_model(model, cx)),
             thinking_allowed: true,
+            bypass_rate_limit: false,
         };
         for message in self.messages(cx) {
             if message.status != MessageStatus::Done {
diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs
index 17b5adfc4aa9621ac4638f873c30e62ab6244107..1cb5e9a10c3c8814f154643b590e280af724188a 100644
--- a/crates/eval/src/instance.rs
+++ b/crates/eval/src/instance.rs
@@ -563,6 +563,7 @@ impl ExampleInstance {
             tool_choice: None,
             stop: Vec::new(),
             thinking_allowed: true,
+            bypass_rate_limit: false,
         };
 
         let model = model.clone();
diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs
index 3f2f406136bdc3a8d9a813e37e264e97633bd214..1628358c602fd5a2ff67a178331ec2172f3d7a67 100644
--- a/crates/git_ui/src/git_panel.rs
+++ b/crates/git_ui/src/git_panel.rs
@@ -2691,6 +2691,7 @@ impl GitPanel {
             stop: Vec::new(),
             temperature,
             thinking_allowed: false,
+            bypass_rate_limit: false,
         };
 
         let stream = model.stream_completion_text(request, cx);
diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model/src/rate_limiter.rs
index 790be05ac069b8f394e442cbcb6383f611326a69..f4aea8177666e406f8c22f2c92fd0c6f9b4619a8 100644
--- a/crates/language_model/src/rate_limiter.rs
+++ b/crates/language_model/src/rate_limiter.rs
@@ -16,7 +16,7 @@ pub struct RateLimiter {
 
 pub struct RateLimitGuard<T> {
     inner: T,
-    _guard: SemaphoreGuardArc,
+    _guard: Option<SemaphoreGuardArc>,
 }
 
 impl<T> Stream for RateLimitGuard<T>
 where
@@ -68,6 +68,36 @@ impl RateLimiter {
         async move {
             let guard = guard.await;
             let inner = future.await?;
+            Ok(RateLimitGuard {
+                inner,
+                _guard: Some(guard),
+            })
+        }
+    }
+
+    /// Like `stream`, but conditionally bypasses the rate limiter based on the flag.
+    /// Used for nested requests (like edit agent requests) that are already "part of"
+    /// a rate-limited request to avoid deadlocks.
+    pub fn stream_with_bypass<'a, Fut, T>(
+        &self,
+        future: Fut,
+        bypass: bool,
+    ) -> impl 'a
+    + Future<
+        Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
+    >
+    where
+        Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
+        T: Stream,
+    {
+        let semaphore = self.semaphore.clone();
+        async move {
+            let guard = if bypass {
+                None
+            } else {
+                Some(semaphore.acquire_arc().await)
+            };
+            let inner = future.await?;
             Ok(RateLimitGuard {
                 inner,
                 _guard: guard,
@@ -75,3 +105,190 @@ impl RateLimiter {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use futures::stream;
+    use smol::lock::Barrier;
+    use std::sync::Arc;
+    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+    use std::time::{Duration, Instant};
+
+    /// Tests that nested requests without bypass_rate_limit cause deadlock,
+    /// while requests with bypass_rate_limit complete successfully.
+    ///
+    /// This test simulates the scenario where multiple "parent" requests each
+    /// try to spawn a "nested" request (like edit_file tool spawning an edit agent).
+    /// With a rate limit of 2 and 2 parent requests, without bypass the nested
+    /// requests would block forever waiting for permits that the parents hold.
+    #[test]
+    fn test_nested_requests_bypass_prevents_deadlock() {
+        smol::block_on(async {
+            // Use only 2 permits so we can guarantee deadlock conditions
+            let rate_limiter = RateLimiter::new(2);
+            let completed = Arc::new(AtomicUsize::new(0));
+            // Barrier ensures all parents acquire permits before any tries nested request
+            let barrier = Arc::new(Barrier::new(2));
+
+            // Spawn 2 "parent" requests that each try to make a "nested" request
+            let mut handles = Vec::new();
+            for _ in 0..2 {
+                let limiter = rate_limiter.clone();
+                let completed = completed.clone();
+                let barrier = barrier.clone();
+
+                let handle = smol::spawn(async move {
+                    // Parent request acquires a permit via stream_with_bypass (bypass=false)
+                    let parent_stream = limiter
+                        .stream_with_bypass(
+                            async {
+                                // Wait for all parents to acquire permits
+                                barrier.wait().await;
+
+                                // While holding the parent permit, make a nested request
+                                // WITH bypass=true (simulating EditAgent behavior)
+                                let nested_stream = limiter
+                                    .stream_with_bypass(
+                                        async { Ok(stream::iter(vec![1, 2, 3])) },
+                                        true, // bypass - this is the key!
+                                    )
+                                    .await?;
+
+                                // Consume the nested stream
+                                use futures::StreamExt;
+                                let _: Vec<_> = nested_stream.collect().await;
+
+                                Ok(stream::iter(vec!["done"]))
+                            },
+                            false, // parent does NOT bypass
+                        )
+                        .await
+                        .unwrap();
+
+                    // Consume parent stream
+                    use futures::StreamExt;
+                    let _: Vec<_> = parent_stream.collect().await;
+
+                    completed.fetch_add(1, Ordering::SeqCst);
+                });
+                handles.push(handle);
+            }
+
+            // With bypass=true for nested requests, this should complete quickly
+            let timed_out = Arc::new(AtomicBool::new(false));
+            let timed_out_clone = timed_out.clone();
+
+            // Spawn a watchdog that sets timed_out after 2 seconds
+            let watchdog = smol::spawn(async move {
+                let start = Instant::now();
+                while start.elapsed() < Duration::from_secs(2) {
+                    smol::future::yield_now().await;
+                }
+                timed_out_clone.store(true, Ordering::SeqCst);
+            });
+
+            // Wait for all handles to complete
+            for handle in handles {
+                handle.await;
+            }
+
+            // Cancel the watchdog
+            drop(watchdog);
+
+            if timed_out.load(Ordering::SeqCst) {
+                panic!(
+                    "Test timed out - deadlock detected! This means bypass_rate_limit is not working."
+                );
+            }
+            assert_eq!(completed.load(Ordering::SeqCst), 2);
+        });
+    }
+
+    /// Tests that without bypass, nested requests DO cause deadlock.
+    /// This test verifies the problem exists when bypass is not used.
+    #[test]
+    fn test_nested_requests_without_bypass_deadlocks() {
+        smol::block_on(async {
+            // Use only 2 permits so we can guarantee deadlock conditions
+            let rate_limiter = RateLimiter::new(2);
+            let completed = Arc::new(AtomicUsize::new(0));
+            // Barrier ensures all parents acquire permits before any tries nested request
+            let barrier = Arc::new(Barrier::new(2));
+
+            // Spawn 2 "parent" requests that each try to make a "nested" request
+            let mut handles = Vec::new();
+            for _ in 0..2 {
+                let limiter = rate_limiter.clone();
+                let completed = completed.clone();
+                let barrier = barrier.clone();
+
+                let handle = smol::spawn(async move {
+                    // Parent request acquires a permit
+                    let parent_stream = limiter
+                        .stream_with_bypass(
+                            async {
+                                // Wait for all parents to acquire permits - this guarantees
+                                // that all 2 permits are held before any nested request starts
+                                barrier.wait().await;
+
+                                // Nested request WITHOUT bypass - this will deadlock!
+                                // Both parents hold permits, so no permits available
+                                let nested_stream = limiter
+                                    .stream_with_bypass(
+                                        async { Ok(stream::iter(vec![1, 2, 3])) },
+                                        false, // NO bypass - will try to acquire permit
+                                    )
+                                    .await?;
+
+                                use futures::StreamExt;
+                                let _: Vec<_> = nested_stream.collect().await;
+
+                                Ok(stream::iter(vec!["done"]))
+                            },
+                            false,
+                        )
+                        .await
+                        .unwrap();
+
+                    use futures::StreamExt;
+                    let _: Vec<_> = parent_stream.collect().await;
+
+                    completed.fetch_add(1, Ordering::SeqCst);
+                });
+                handles.push(handle);
+            }
+
+            // This SHOULD timeout because of deadlock (both parents hold permits,
+            // both nested requests wait for permits)
+            let timed_out = Arc::new(AtomicBool::new(false));
+            let timed_out_clone = timed_out.clone();
+
+            // Spawn a watchdog that sets timed_out after 100ms
+            let watchdog = smol::spawn(async move {
+                let start = Instant::now();
+                while start.elapsed() < Duration::from_millis(100) {
+                    smol::future::yield_now().await;
+                }
+                timed_out_clone.store(true, Ordering::SeqCst);
+            });
+
+            // Poll briefly to let everything run
+            let start = Instant::now();
+            while start.elapsed() < Duration::from_millis(100) {
+                smol::future::yield_now().await;
+            }
+
+            // Cancel the watchdog
+            drop(watchdog);
+
+            // Expected - deadlock occurred, which proves the bypass is necessary
+            let count = completed.load(Ordering::SeqCst);
+            assert_eq!(
+                count, 0,
+                "Expected complete deadlock (0 completed) but {} requests completed",
+                count
+            );
+        });
+    }
+}
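A minimal usage sketch of the new API, using only items defined in this file (`RateLimiter::new` with a small capacity, as the tests above do). `stream(f)` remains equivalent to `stream_with_bypass(f, false)`; passing `true` neither waits for nor consumes a permit:

```rust
use futures::{StreamExt, stream};

fn main() {
    smol::block_on(async {
        // Capacity of one: a second non-bypassing request would have to wait.
        let limiter = RateLimiter::new(1);
        let events = limiter
            .stream_with_bypass(
                async { Ok::<_, LanguageModelCompletionError>(stream::iter(0..3)) },
                true, // nested request: skip permit acquisition entirely
            )
            .await
            .unwrap();
        assert_eq!(events.collect::<Vec<_>>().await, vec![0, 1, 2]);
    });
}
```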
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index 1d0b536ec43bd9e930c24e3a733448ae12a8d65b..9c3643e3471d52913d1defdb776365750b1870c6 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -451,6 +451,11 @@ pub struct LanguageModelRequest {
     pub stop: Vec<String>,
     pub temperature: Option<f32>,
     pub thinking_allowed: bool,
+    /// When true, this request bypasses the rate limiter. Used for nested requests
+    /// (like edit agent requests spawned from within a tool call) that are already
+    /// "part of" a rate-limited request to avoid deadlocks.
+    #[serde(default)]
+    pub bypass_rate_limit: bool,
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
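The `#[serde(default)]` is what keeps previously serialized requests deserializable. A quick stand-alone illustration — using a small stand-in struct, since `LanguageModelRequest` has many more fields:

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Request {
    prompt: String,
    #[serde(default)]
    bypass_rate_limit: bool,
}

fn main() {
    // A payload written before the field existed still deserializes;
    // the missing flag falls back to `false`.
    let old: Request = serde_json::from_str(r#"{ "prompt": "hi" }"#).unwrap();
    assert!(!old.bypass_rate_limit);
}
```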
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index 3ea4a8ee37d76474714d6a4b875982e11f5c691b..d88205b267ada4883bb3e3b568ec9ac7cee57514 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -578,6 +578,7 @@ impl LanguageModel for AnthropicModel {
             LanguageModelCompletionError,
         >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let request = into_anthropic(
             request,
             self.model.request_id().into(),
@@ -586,10 +587,13 @@ impl LanguageModel for AnthropicModel {
             self.model.mode(),
         );
         let request = self.stream_completion(request, cx);
-        let future = self.request_limiter.stream(async move {
-            let response = request.await?;
-            Ok(AnthropicEventMapper::new().map_stream(response))
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let response = request.await?;
+                Ok(AnthropicEventMapper::new().map_stream(response))
+            },
+            bypass_rate_limit,
+        );
         async move { Ok(future.await?.boxed()) }.boxed()
     }
 
@@ -1164,6 +1168,7 @@ mod tests {
             tools: vec![],
             tool_choice: None,
             thinking_allowed: true,
+            bypass_rate_limit: false,
         };
 
         let anthropic_request = into_anthropic(
diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs
index 05c35fb047956740c86f7fc87f6691564360b4d5..d02ed461f8e0cbbefc17b5c44526849724b3a3bf 100644
--- a/crates/language_models/src/provider/bedrock.rs
+++ b/crates/language_models/src/provider/bedrock.rs
@@ -679,6 +679,7 @@ impl LanguageModel for BedrockModel {
         };
 
         let deny_tool_calls = request.tool_choice == Some(LanguageModelToolChoice::None);
+        let bypass_rate_limit = request.bypass_rate_limit;
 
         let request = match into_bedrock(
             request,
@@ -693,16 +694,19 @@ impl LanguageModel for BedrockModel {
         };
 
         let request = self.stream_completion(request, cx);
-        let future = self.request_limiter.stream(async move {
-            let response = request.await.map_err(|err| anyhow!(err))?;
-            let events = map_to_language_model_completion_events(response);
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let response = request.await.map_err(|err| anyhow!(err))?;
+                let events = map_to_language_model_completion_events(response);
 
-            if deny_tool_calls {
-                Ok(deny_tool_use_events(events).boxed())
-            } else {
-                Ok(events.boxed())
-            }
-        });
+                if deny_tool_calls {
+                    Ok(deny_tool_use_events(events).boxed())
+                } else {
+                    Ok(events.boxed())
+                }
+            },
+            bypass_rate_limit,
+        );
         async move { Ok(future.await?.boxed()) }.boxed()
     }
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 4cef74cd325736944fd91612303897f8ea0062c5..a96a0618e15217c648836c89fad51725b16b2b43 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -719,6 +719,7 @@ impl LanguageModel for CloudLanguageModel {
         let thread_id = request.thread_id.clone();
         let prompt_id = request.prompt_id.clone();
         let intent = request.intent;
+        let bypass_rate_limit = request.bypass_rate_limit;
         let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
         let use_responses_api = cx.update(|cx| cx.has_flag::());
         let thinking_allowed = request.thinking_allowed;
@@ -740,53 +741,8 @@ impl LanguageModel for CloudLanguageModel {
                 );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let PerformLlmCompletionResponse {
-                        response,
-                        includes_status_messages,
-                    } = Self::perform_llm_completion(
-                        client.clone(),
-                        llm_api_token,
-                        app_version,
-                        CompletionBody {
-                            thread_id,
-                            prompt_id,
-                            intent,
-                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
-                            model: request.model.clone(),
-                            provider_request: serde_json::to_value(&request)
-                                .map_err(|e| anyhow!(e))?,
-                        },
-                    )
-                    .await
-                    .map_err(|err| match err.downcast::<ApiError>() {
-                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
-                        Err(err) => anyhow!(err),
-                    })?;
-
-                    let mut mapper = AnthropicEventMapper::new();
-                    Ok(map_cloud_completion_events(
-                        Box::pin(response_lines(response, includes_status_messages)),
-                        &provider_name,
-                        move |event| mapper.map_event(event),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
-            }
-            cloud_llm_client::LanguageModelProvider::OpenAi => {
-                let client = self.client.clone();
-                let llm_api_token = self.llm_api_token.clone();
-
-                if use_responses_api {
-                    let request = into_open_ai_response(
-                        request,
-                        &self.model.id.0,
-                        self.model.supports_parallel_tool_calls,
-                        true,
-                        None,
-                        None,
-                    );
-                    let future = self.request_limiter.stream(async move {
+                let future = self.request_limiter.stream_with_bypass(
+                    async move {
                         let PerformLlmCompletionResponse {
                             response,
                             includes_status_messages,
                         } = Self::perform_llm_completion(
@@ -798,21 +754,74 @@ impl LanguageModel for CloudLanguageModel {
                             thread_id,
                             prompt_id,
                             intent,
-                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                             model: request.model.clone(),
                             provider_request: serde_json::to_value(&request)
                                 .map_err(|e| anyhow!(e))?,
                         },
                     )
-                    .await?;
-
-                    let mut mapper = OpenAiResponseEventMapper::new();
+                    .await
+                    .map_err(|err| {
+                        match err.downcast::<ApiError>() {
+                            Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
+                            Err(err) => anyhow!(err),
+                        }
+                    })?;
+
+                    let mut mapper = AnthropicEventMapper::new();
                     Ok(map_cloud_completion_events(
                         Box::pin(response_lines(response, includes_status_messages)),
                         &provider_name,
                         move |event| mapper.map_event(event),
                     ))
-                    });
+                    },
+                    bypass_rate_limit,
+                );
+                async move { Ok(future.await?.boxed()) }.boxed()
+            }
+            cloud_llm_client::LanguageModelProvider::OpenAi => {
+                let client = self.client.clone();
+                let llm_api_token = self.llm_api_token.clone();
+
+                if use_responses_api {
+                    let request = into_open_ai_response(
+                        request,
+                        &self.model.id.0,
+                        self.model.supports_parallel_tool_calls,
+                        true,
+                        None,
+                        None,
+                    );
+                    let future = self.request_limiter.stream_with_bypass(
+                        async move {
+                            let PerformLlmCompletionResponse {
+                                response,
+                                includes_status_messages,
+                            } = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                app_version,
+                                CompletionBody {
+                                    thread_id,
+                                    prompt_id,
+                                    intent,
+                                    provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+                                    model: request.model.clone(),
+                                    provider_request: serde_json::to_value(&request)
+                                        .map_err(|e| anyhow!(e))?,
+                                },
+                            )
+                            .await?;
+
+                            let mut mapper = OpenAiResponseEventMapper::new();
+                            Ok(map_cloud_completion_events(
+                                Box::pin(response_lines(response, includes_status_messages)),
+                                &provider_name,
+                                move |event| mapper.map_event(event),
+                            ))
+                        },
+                        bypass_rate_limit,
+                    );
                     async move { Ok(future.await?.boxed()) }.boxed()
                 } else {
                     let request = into_open_ai(
@@ -823,7 +832,52 @@ impl LanguageModel for CloudLanguageModel {
                         None,
                         None,
                     );
-                    let future = self.request_limiter.stream(async move {
+                    let future = self.request_limiter.stream_with_bypass(
+                        async move {
+                            let PerformLlmCompletionResponse {
+                                response,
+                                includes_status_messages,
+                            } = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                app_version,
+                                CompletionBody {
+                                    thread_id,
+                                    prompt_id,
+                                    intent,
+                                    provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+                                    model: request.model.clone(),
+                                    provider_request: serde_json::to_value(&request)
+                                        .map_err(|e| anyhow!(e))?,
+                                },
+                            )
+                            .await?;
+
+                            let mut mapper = OpenAiEventMapper::new();
+                            Ok(map_cloud_completion_events(
+                                Box::pin(response_lines(response, includes_status_messages)),
+                                &provider_name,
+                                move |event| mapper.map_event(event),
+                            ))
+                        },
+                        bypass_rate_limit,
+                    );
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
+            }
+            cloud_llm_client::LanguageModelProvider::XAi => {
+                let client = self.client.clone();
+                let request = into_open_ai(
+                    request,
+                    &self.model.id.0,
+                    self.model.supports_parallel_tool_calls,
+                    false,
+                    None,
+                    None,
+                );
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream_with_bypass(
+                    async move {
                         let PerformLlmCompletionResponse {
                             response,
                             includes_status_messages,
                         } = Self::perform_llm_completion(
@@ -835,7 +889,7 @@ impl LanguageModel for CloudLanguageModel {
                             thread_id,
                             prompt_id,
                             intent,
-                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                             model: request.model.clone(),
                             provider_request: serde_json::to_value(&request)
                                 .map_err(|e| anyhow!(e))?,
@@ -849,48 +903,9 @@ impl LanguageModel for CloudLanguageModel {
                             &provider_name,
                             move |event| mapper.map_event(event),
                         ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
-            }
-            cloud_llm_client::LanguageModelProvider::XAi => {
-                let client = self.client.clone();
-                let request = into_open_ai(
-                    request,
-                    &self.model.id.0,
-                    self.model.supports_parallel_tool_calls,
-                    false,
-                    None,
-                    None,
+                    },
+                    bypass_rate_limit,
                 );
-                let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let PerformLlmCompletionResponse {
-                        response,
-                        includes_status_messages,
-                    } = Self::perform_llm_completion(
-                        client.clone(),
-                        llm_api_token,
-                        app_version,
-                        CompletionBody {
-                            thread_id,
-                            prompt_id,
-                            intent,
-                            provider: cloud_llm_client::LanguageModelProvider::XAi,
-                            model: request.model.clone(),
-                            provider_request: serde_json::to_value(&request)
-                                .map_err(|e| anyhow!(e))?,
-                        },
-                    )
-                    .await?;
-
-                    let mut mapper = OpenAiEventMapper::new();
-                    Ok(map_cloud_completion_events(
-                        Box::pin(response_lines(response, includes_status_messages)),
-                        &provider_name,
-                        move |event| mapper.map_event(event),
-                    ))
-                });
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
             cloud_llm_client::LanguageModelProvider::Google => {
@@ -898,33 +913,36 @@ impl LanguageModel for CloudLanguageModel {
                 let request =
                     into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                 let llm_api_token = self.llm_api_token.clone();
-                let future = self.request_limiter.stream(async move {
-                    let PerformLlmCompletionResponse {
-                        response,
-                        includes_status_messages,
-                    } = Self::perform_llm_completion(
-                        client.clone(),
-                        llm_api_token,
-                        app_version,
-                        CompletionBody {
-                            thread_id,
-                            prompt_id,
-                            intent,
-                            provider: cloud_llm_client::LanguageModelProvider::Google,
-                            model: request.model.model_id.clone(),
-                            provider_request: serde_json::to_value(&request)
-                                .map_err(|e| anyhow!(e))?,
-                        },
-                    )
-                    .await?;
-
-                    let mut mapper = GoogleEventMapper::new();
-                    Ok(map_cloud_completion_events(
-                        Box::pin(response_lines(response, includes_status_messages)),
-                        &provider_name,
-                        move |event| mapper.map_event(event),
-                    ))
-                });
+                let future = self.request_limiter.stream_with_bypass(
+                    async move {
+                        let PerformLlmCompletionResponse {
+                            response,
+                            includes_status_messages,
+                        } = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            app_version,
+                            CompletionBody {
+                                thread_id,
+                                prompt_id,
+                                intent,
+                                provider: cloud_llm_client::LanguageModelProvider::Google,
+                                model: request.model.model_id.clone(),
+                                provider_request: serde_json::to_value(&request)
+                                    .map_err(|e| anyhow!(e))?,
+                            },
+                        )
+                        .await?;
+
+                        let mut mapper = GoogleEventMapper::new();
+                        Ok(map_cloud_completion_events(
+                            Box::pin(response_lines(response, includes_status_messages)),
+                            &provider_name,
+                            move |event| mapper.map_event(event),
+                        ))
+                    },
+                    bypass_rate_limit,
+                );
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 76685e076982f2007d68ac5ae22e469a933cf80e..8c9b84417d33f823da80d221072a766d48bc59ce 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -307,6 +307,7 @@ impl LanguageModel for CopilotChatLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let is_user_initiated = request.intent.is_none_or(|intent| match intent {
             CompletionIntent::UserPrompt
             | CompletionIntent::ThreadContextSummarization
@@ -327,11 +328,14 @@ impl LanguageModel for CopilotChatLanguageModel {
                 let request =
                     CopilotChat::stream_response(responses_request, is_user_initiated, cx.clone());
                 request_limiter
-                    .stream(async move {
-                        let stream = request.await?;
-                        let mapper = CopilotResponsesEventMapper::new();
-                        Ok(mapper.map_stream(stream).boxed())
-                    })
+                    .stream_with_bypass(
+                        async move {
+                            let stream = request.await?;
+                            let mapper = CopilotResponsesEventMapper::new();
+                            Ok(mapper.map_stream(stream).boxed())
+                        },
+                        bypass_rate_limit,
+                    )
                     .await
             });
             return async move { Ok(future.await?.boxed()) }.boxed();
@@ -348,13 +352,16 @@ impl LanguageModel for CopilotChatLanguageModel {
             let request =
                 CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
             request_limiter
-                .stream(async move {
-                    let response = request.await?;
-                    Ok(map_to_language_model_completion_events(
-                        response,
-                        is_streaming,
-                    ))
-                })
+                .stream_with_bypass(
+                    async move {
+                        let response = request.await?;
+                        Ok(map_to_language_model_completion_events(
+                            response,
+                            is_streaming,
+                        ))
+                    },
+                    bypass_rate_limit,
+                )
                 .await
         });
         async move { Ok(future.await?.boxed()) }.boxed()
@@ -929,6 +936,7 @@ fn into_copilot_responses(
         stop: _,
         temperature,
         thinking_allowed: _,
+        bypass_rate_limit: _,
     } = request;
 
     let mut input_items: Vec = Vec::new();
diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs
index ea623d2cf24f26ce32e8d1fd309ac747e469096e..18f36641d82cd8d993dcb61a9576e089a5a89fda 100644
--- a/crates/language_models/src/provider/deepseek.rs
+++ b/crates/language_models/src/provider/deepseek.rs
@@ -199,6 +199,7 @@ impl DeepSeekLanguageModel {
     fn stream_completion(
         &self,
         request: deepseek::Request,
+        bypass_rate_limit: bool,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<StreamResponse>>>> {
         let http_client = self.http_client.clone();
@@ -208,17 +209,20 @@ impl DeepSeekLanguageModel {
             (state.api_key_state.key(&api_url), api_url)
         });
 
-        let future = self.request_limiter.stream(async move {
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
-            };
-            let request =
-                deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
-            let response = request.await?;
-            Ok(response)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let Some(api_key) = api_key else {
+                    return Err(LanguageModelCompletionError::NoApiKey {
+                        provider: PROVIDER_NAME,
+                    });
+                };
+                let request =
+                    deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+                let response = request.await?;
+                Ok(response)
+            },
+            bypass_rate_limit,
+        );
 
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -302,8 +306,9 @@ impl LanguageModel for DeepSeekLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let request = into_deepseek(request, &self.model, self.max_output_tokens());
-        let stream = self.stream_completion(request, cx);
+        let stream = self.stream_completion(request, bypass_rate_limit, cx);
 
         async move {
             let mapper = DeepSeekEventMapper::new();
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index 2f5b3b3701d51e4f4faadae0f8ef83f8bf6b5b2f..5e3700496e23ea8f4d3f9cec1fdda31ea44de0ce 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -370,16 +370,20 @@ impl LanguageModel for GoogleLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let request = into_google(
             request,
             self.model.request_id().to_string(),
             self.model.mode(),
         );
         let request = self.stream_completion(request, cx);
-        let future = self.request_limiter.stream(async move {
-            let response = request.await.map_err(LanguageModelCompletionError::from)?;
-            Ok(GoogleEventMapper::new().map_stream(response))
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let response = request.await.map_err(LanguageModelCompletionError::from)?;
+                Ok(GoogleEventMapper::new().map_stream(response))
+            },
+            bypass_rate_limit,
+        );
         async move { Ok(future.await?.boxed()) }.boxed()
     }
 }
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index 041dfedf86e4195d98689d4f06031b32fb162e51..f55257d8b22e86968e4b407ede1658d0210e1cf0 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -370,6 +370,7 @@ impl LmStudioLanguageModel {
     fn stream_completion(
         &self,
         request: lmstudio::ChatCompletionRequest,
+        bypass_rate_limit: bool,
         cx: &AsyncApp,
     ) -> BoxFuture<
         'static,
         Result<BoxStream<'static, Result<lmstudio::ChatResponse>>>,
     > {
         let http_client = self.http_client.clone();
@@ -381,11 +382,15 @@ impl LmStudioLanguageModel {
             settings.api_url.clone()
         });
 
-        let future = self.request_limiter.stream(async move {
-            let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
-            let response = request.await?;
-            Ok(response)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let request =
+                    lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
+                let response = request.await?;
+                Ok(response)
+            },
+            bypass_rate_limit,
+        );
 
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -460,8 +465,9 @@ impl LanguageModel for LmStudioLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let request = self.to_lmstudio_request(request);
-        let completions = self.stream_completion(request, cx);
+        let completions = self.stream_completion(request, bypass_rate_limit, cx);
         async move {
             let mapper = LmStudioEventMapper::new();
             Ok(mapper.map_stream(completions.await?).boxed())
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index c8c34d7d2942ca1b42613d8733dc2219800bd66c..1e6cfee116c355875ab97de1096b2aa26cca53d9 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -264,6 +264,7 @@ impl MistralLanguageModel {
     fn stream_completion(
         &self,
         request: mistral::Request,
+        bypass_rate_limit: bool,
         cx: &AsyncApp,
     ) -> BoxFuture<
         'static,
         Result<BoxStream<'static, Result<mistral::StreamResponse>>>,
     > {
         let http_client = self.http_client.clone();
@@ -276,17 +277,20 @@ impl MistralLanguageModel {
             (state.api_key_state.key(&api_url), api_url)
         });
 
-        let future = self.request_limiter.stream(async move {
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
-            };
-            let request =
-                mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
-            let response = request.await?;
-            Ok(response)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let Some(api_key) = api_key else {
+                    return Err(LanguageModelCompletionError::NoApiKey {
+                        provider: PROVIDER_NAME,
+                    });
+                };
+                let request =
+                    mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+                let response = request.await?;
+                Ok(response)
+            },
+            bypass_rate_limit,
+        );
 
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -370,8 +374,9 @@ impl LanguageModel for MistralLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
-        let stream = self.stream_completion(request, cx);
+        let stream = self.stream_completion(request, bypass_rate_limit, cx);
 
         async move {
             let stream = stream.await?;
@@ -901,6 +906,7 @@ mod tests {
             intent: None,
             stop: vec![],
             thinking_allowed: true,
+            bypass_rate_limit: false,
         };
 
         let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);
@@ -934,6 +940,7 @@ mod tests {
             intent: None,
             stop: vec![],
             thinking_allowed: true,
+            bypass_rate_limit: false,
         };
 
         let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index 79ac51a870782817723c7c32253946068d6570e3..8c9df749e9af17631a85a5934ac9ef3211bff55a 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -480,6 +480,7 @@ impl LanguageModel for OllamaLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let request = self.to_ollama_request(request);
 
         let http_client = self.http_client.clone();
@@ -488,13 +489,20 @@ impl LanguageModel for OllamaLanguageModel {
             (state.api_key_state.key(&api_url), api_url)
         });
 
-        let future = self.request_limiter.stream(async move {
-            let stream =
-                stream_chat_completion(http_client.as_ref(), &api_url, api_key.as_deref(), request)
-                    .await?;
-            let stream = map_to_language_model_completion_events(stream);
-            Ok(stream)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let stream = stream_chat_completion(
+                    http_client.as_ref(),
+                    &api_url,
+                    api_key.as_deref(),
+                    request,
+                )
+                .await?;
+                let stream = map_to_language_model_completion_events(stream);
+                Ok(stream)
+            },
+            bypass_rate_limit,
+        );
 
         future.map_ok(|f| f.boxed()).boxed()
     }
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 915fd325494f523f1a1d64ef17cec63bf29cc44d..fde221f86f270cd5955adccc4588ea5811e02ad8 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -213,6 +213,7 @@ impl OpenAiLanguageModel {
     fn stream_completion(
         &self,
         request: open_ai::Request,
+        bypass_rate_limit: bool,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<ResponseStreamEvent>>>> {
         let http_client = self.http_client.clone();
@@ -223,21 +224,24 @@ impl OpenAiLanguageModel {
             (state.api_key_state.key(&api_url), api_url)
         });
 
-        let future = self.request_limiter.stream(async move {
-            let provider = PROVIDER_NAME;
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
-            };
-            let request = stream_completion(
-                http_client.as_ref(),
-                provider.0.as_str(),
-                &api_url,
-                &api_key,
-                request,
-            );
-            let response = request.await?;
-            Ok(response)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let provider = PROVIDER_NAME;
+                let Some(api_key) = api_key else {
+                    return Err(LanguageModelCompletionError::NoApiKey { provider });
+                };
+                let request = stream_completion(
+                    http_client.as_ref(),
+                    provider.0.as_str(),
+                    &api_url,
+                    &api_key,
+                    request,
+                );
+                let response = request.await?;
+                Ok(response)
+            },
+            bypass_rate_limit,
+        );
 
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -245,6 +249,7 @@ impl OpenAiLanguageModel {
     fn stream_response(
        &self,
        request: ResponseRequest,
+        bypass_rate_limit: bool,
        cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<ResponsesStreamEvent>>>> {
@@ -256,20 +261,23 @@ impl OpenAiLanguageModel {
         });
 
         let provider = PROVIDER_NAME;
-        let future = self.request_limiter.stream(async move {
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
-            };
-            let request = stream_response(
-                http_client.as_ref(),
-                provider.0.as_str(),
-                &api_url,
-                &api_key,
-                request,
-            );
-            let response = request.await?;
-            Ok(response)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let Some(api_key) = api_key else {
+                    return Err(LanguageModelCompletionError::NoApiKey { provider });
+                };
+                let request = stream_response(
+                    http_client.as_ref(),
+                    provider.0.as_str(),
+                    &api_url,
+                    &api_key,
+                    request,
+                );
+                let response = request.await?;
+                Ok(response)
+            },
+            bypass_rate_limit,
+        );
 
         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -368,6 +376,7 @@ impl LanguageModel for OpenAiLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         if self.model.supports_chat_completions() {
             let request = into_open_ai(
                 request,
@@ -377,7 +386,7 @@ impl LanguageModel for OpenAiLanguageModel {
                 self.max_output_tokens(),
                 self.model.reasoning_effort(),
             );
-            let completions = self.stream_completion(request, cx);
+            let completions = self.stream_completion(request, bypass_rate_limit, cx);
             async move {
                 let mapper = OpenAiEventMapper::new();
                 Ok(mapper.map_stream(completions.await?).boxed())
@@ -392,7 +401,7 @@ impl LanguageModel for OpenAiLanguageModel {
                 self.max_output_tokens(),
                 self.model.reasoning_effort(),
             );
-            let completions = self.stream_response(request, cx);
+            let completions = self.stream_response(request, bypass_rate_limit, cx);
             async move {
                 let mapper = OpenAiResponseEventMapper::new();
                 Ok(mapper.map_stream(completions.await?).boxed())
@@ -545,6 +554,7 @@ pub fn into_open_ai_response(
         stop: _,
         temperature,
         thinking_allowed: _,
+        bypass_rate_limit: _,
     } = request;
 
diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs
index d47ea26c594ab0abb5c859ed549d43e0ed3f859b..0f949740ff8a6b9b006af6cdccce52613fbeddfa 100644
--- a/crates/language_models/src/provider/open_ai_compatible.rs
+++ b/crates/language_models/src/provider/open_ai_compatible.rs
@@ -204,6 +204,7 @@ impl OpenAiCompatibleLanguageModel {
     fn stream_completion(
         &self,
         request: open_ai::Request,
+        bypass_rate_limit: bool,
         cx: &AsyncApp,
     ) -> BoxFuture<
         'static,
@@ -223,20 +224,23 @@ impl OpenAiCompatibleLanguageModel {
             (state.api_key_state.key(&api_url), api_url)
         });
         let provider = self.provider_name.clone();

-        let future = self.request_limiter.stream(async move {
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
-            };
-            let request = stream_completion(
-                http_client.as_ref(),
-                provider.0.as_str(),
-                &api_url,
-                &api_key,
-                request,
-            );
-            let response = request.await?;
-            Ok(response)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let Some(api_key) = api_key else {
+                    return Err(LanguageModelCompletionError::NoApiKey { provider });
+                };
+                let request = stream_completion(
+                    http_client.as_ref(),
+                    provider.0.as_str(),
+                    &api_url,
+                    &api_key,
+                    request,
+                );
+                let response = request.await?;
+                Ok(response)
+            },
+            bypass_rate_limit,
+        );

         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -244,6 +248,7 @@ impl OpenAiCompatibleLanguageModel {
     fn stream_response(
         &self,
         request: ResponseRequest,
+        bypass_rate_limit: bool,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result>>> {
@@ -258,20 +263,23 @@ impl OpenAiCompatibleLanguageModel {
             (state.api_key_state.key(&api_url), api_url)
         });
         let provider = self.provider_name.clone();

-        let future = self.request_limiter.stream(async move {
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
-            };
-            let request = stream_response(
-                http_client.as_ref(),
-                provider.0.as_str(),
-                &api_url,
-                &api_key,
-                request,
-            );
-            let response = request.await?;
-            Ok(response)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let Some(api_key) = api_key else {
+                    return Err(LanguageModelCompletionError::NoApiKey { provider });
+                };
+                let request = stream_response(
+                    http_client.as_ref(),
+                    provider.0.as_str(),
+                    &api_url,
+                    &api_key,
+                    request,
+                );
+                let response = request.await?;
+                Ok(response)
+            },
+            bypass_rate_limit,
+        );

         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -370,6 +378,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
                 LanguageModelCompletionError,
             >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         if self.model.capabilities.chat_completions {
             let request = into_open_ai(
                 request,
@@ -379,7 +388,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
                 self.max_output_tokens(),
                 None,
             );
-            let completions = self.stream_completion(request, cx);
+            let completions = self.stream_completion(request, bypass_rate_limit, cx);
             async move {
                 let mapper = OpenAiEventMapper::new();
                 Ok(mapper.map_stream(completions.await?).boxed())
@@ -394,7 +403,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
                 self.max_output_tokens(),
                 None,
             );
-            let completions = self.stream_response(request, cx);
+            let completions = self.stream_response(request, bypass_rate_limit, cx);
             async move {
                 let mapper = OpenAiResponseEventMapper::new();
                 Ok(mapper.map_stream(completions.await?).boxed())
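Note: every provider in this diff copies the flag out first (`let bypass_rate_limit = request.bypass_rate_limit;`) because the `into_open_ai`-style conversions take the LanguageModelRequest by value. A condensed, self-contained sketch of why that ordering is required, using stand-in types rather than Zed's:

    // Stand-ins for LanguageModelRequest and a provider wire-format request.
    struct LanguageModelRequest {
        bypass_rate_limit: bool,
        messages: Vec<String>,
    }

    struct ProviderRequest {
        messages: Vec<String>,
    }

    // Like into_open_ai / into_mistral, this consumes the request.
    fn into_provider_request(request: LanguageModelRequest) -> ProviderRequest {
        ProviderRequest { messages: request.messages }
    }

    fn main() {
        let request = LanguageModelRequest {
            bypass_rate_limit: true,
            messages: vec!["hello".into()],
        };
        // Copy the flag first; after the conversion below, `request` has been
        // moved and the field is no longer reachable.
        let bypass_rate_limit = request.bypass_rate_limit;
        let provider_request = into_provider_request(request);
        assert!(bypass_rate_limit);
        assert_eq!(provider_request.messages.len(), 1);
    }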
diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs
index 273b45ea23f76936a41584c9c58cd3c73c5c4967..1cdb84437574a82168fea83c46d2c55d33c7c22e 100644
--- a/crates/language_models/src/provider/open_router.rs
+++ b/crates/language_models/src/provider/open_router.rs
@@ -368,12 +368,16 @@ impl LanguageModel for OpenRouterLanguageModel {
                 LanguageModelCompletionError,
             >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let openrouter_request = into_open_router(request, &self.model, self.max_output_tokens());
         let request = self.stream_completion(openrouter_request, cx);
-        let future = self.request_limiter.stream(async move {
-            let response = request.await?;
-            Ok(OpenRouterEventMapper::new().map_stream(response))
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let response = request.await?;
+                Ok(OpenRouterEventMapper::new().map_stream(response))
+            },
+            bypass_rate_limit,
+        );
         async move { Ok(future.await?.boxed()) }.boxed()
     }
 }
diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs
index 3b324e46927f5864d83a5e4b74c46f5e39e8ab3a..3b84434ed4aa2431662de176666da7aeaeb50392 100644
--- a/crates/language_models/src/provider/vercel.rs
+++ b/crates/language_models/src/provider/vercel.rs
@@ -193,6 +193,7 @@ impl VercelLanguageModel {
     fn stream_completion(
         &self,
         request: open_ai::Request,
+        bypass_rate_limit: bool,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result>>> {
@@ -203,21 +204,24 @@ impl VercelLanguageModel {
             (state.api_key_state.key(&api_url), api_url)
         });

-        let future = self.request_limiter.stream(async move {
-            let provider = PROVIDER_NAME;
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
-            };
-            let request = open_ai::stream_completion(
-                http_client.as_ref(),
-                provider.0.as_str(),
-                &api_url,
-                &api_key,
-                request,
-            );
-            let response = request.await?;
-            Ok(response)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let provider = PROVIDER_NAME;
+                let Some(api_key) = api_key else {
+                    return Err(LanguageModelCompletionError::NoApiKey { provider });
+                };
+                let request = open_ai::stream_completion(
+                    http_client.as_ref(),
+                    provider.0.as_str(),
+                    &api_url,
+                    &api_key,
+                    request,
+                );
+                let response = request.await?;
+                Ok(response)
+            },
+            bypass_rate_limit,
+        );

         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -290,6 +294,7 @@ impl LanguageModel for VercelLanguageModel {
                 LanguageModelCompletionError,
             >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let request = crate::provider::open_ai::into_open_ai(
             request,
             self.model.id(),
@@ -298,7 +303,7 @@ impl LanguageModel for VercelLanguageModel {
             self.max_output_tokens(),
             None,
         );
-        let completions = self.stream_completion(request, cx);
+        let completions = self.stream_completion(request, bypass_rate_limit, cx);
         async move {
             let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
             Ok(mapper.map_stream(completions.await?).boxed())
diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs
index 06564224dea9621d594e5cf3f4a84093f1620446..c77c2a8aca9a7d2d79d92763d5ab31dc7bb3eb1b 100644
--- a/crates/language_models/src/provider/x_ai.rs
+++ b/crates/language_models/src/provider/x_ai.rs
@@ -197,6 +197,7 @@ impl XAiLanguageModel {
     fn stream_completion(
         &self,
         request: open_ai::Request,
+        bypass_rate_limit: bool,
         cx: &AsyncApp,
     ) -> BoxFuture<
         'static,
@@ -212,21 +213,24 @@ impl XAiLanguageModel {
             (state.api_key_state.key(&api_url), api_url)
         });

-        let future = self.request_limiter.stream(async move {
-            let provider = PROVIDER_NAME;
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
-            };
-            let request = open_ai::stream_completion(
-                http_client.as_ref(),
-                provider.0.as_str(),
-                &api_url,
-                &api_key,
-                request,
-            );
-            let response = request.await?;
-            Ok(response)
-        });
+        let future = self.request_limiter.stream_with_bypass(
+            async move {
+                let provider = PROVIDER_NAME;
+                let Some(api_key) = api_key else {
+                    return Err(LanguageModelCompletionError::NoApiKey { provider });
+                };
+                let request = open_ai::stream_completion(
+                    http_client.as_ref(),
+                    provider.0.as_str(),
+                    &api_url,
+                    &api_key,
+                    request,
+                );
+                let response = request.await?;
+                Ok(response)
+            },
+            bypass_rate_limit,
+        );

         async move { Ok(future.await?.boxed()) }.boxed()
     }
@@ -307,6 +311,7 @@ impl LanguageModel for XAiLanguageModel {
                 LanguageModelCompletionError,
             >,
     > {
+        let bypass_rate_limit = request.bypass_rate_limit;
         let request = crate::provider::open_ai::into_open_ai(
             request,
             self.model.id(),
@@ -315,7 +320,7 @@ impl LanguageModel for XAiLanguageModel {
             self.max_output_tokens(),
             None,
         );
-        let completions = self.stream_completion(request, cx);
+        let completions = self.stream_completion(request, bypass_rate_limit, cx);
         async move {
             let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
             Ok(mapper.map_stream(completions.await?).boxed())
diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs
index 09089a6bcba83b4159b346c0e9da2dfd53289389..320016f200e9af6f04d059b6b8064ea82f204d6d 100644
--- a/crates/rules_library/src/rules_library.rs
+++ b/crates/rules_library/src/rules_library.rs
@@ -1100,6 +1100,7 @@ impl RulesLibrary {
                     stop: Vec::new(),
                     temperature: None,
                     thinking_allowed: true,
+                    bypass_rate_limit: false,
                 },
                 cx,
             )
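Note: for the motivation behind the flag, consider a limiter capped at one in-flight request per provider. An outer completion that spawns a nested completion while still holding its permit can never finish, because the nested request queues behind that same permit. A self-contained sketch of the failure mode and the bypass fix (tokio semaphore as a hypothetical stand-in for the request limiter; not Zed's actual code):

    use std::sync::Arc;
    use tokio::sync::Semaphore;

    // An "outer" completion request that spawns a nested one while still
    // holding its own concurrency permit.
    async fn outer(limiter: Arc<Semaphore>, bypass_inner: bool) {
        let _permit = limiter.clone().acquire_owned().await.unwrap();
        if bypass_inner {
            nested_request().await; // runs immediately, no permit required
        } else {
            // With a limit of 1 this waits on the permit the enclosing call
            // still holds: a deadlock.
            let _inner = limiter.acquire().await.unwrap();
            nested_request().await;
        }
    }

    async fn nested_request() {
        // Stand-in for the nested model request.
    }

    #[tokio::main]
    async fn main() {
        let limiter = Arc::new(Semaphore::new(1));
        // Bypassing lets the nested request run and the program finish;
        // passing `false` here would hang forever.
        outer(limiter, true).await;
        println!("nested request completed without deadlock");
    }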