Detailed changes
@@ -732,6 +732,10 @@ impl EditAgent {
stop: Vec::new(),
temperature: None,
thinking_allowed: true,
+ // Bypass the rate limiter for nested requests (edit agent requests spawned
+ // from within a tool call) to avoid deadlocks when multiple subagents try
+ // to use edit_file simultaneously.
+ bypass_rate_limit: true,
};
Ok(self.model.stream_completion_text(request, cx).await?.stream)
@@ -0,0 +1,618 @@
+use super::*;
+use acp_thread::UserMessageId;
+use action_log::ActionLog;
+use fs::FakeFs;
+use language_model::{
+ LanguageModelCompletionEvent, LanguageModelToolUse, MessageContent, StopReason,
+ fake_provider::FakeLanguageModel,
+};
+use prompt_store::ProjectContext;
+use serde_json::json;
+use std::{collections::BTreeMap, sync::Arc, time::Duration};
+use util::path;
+
+#[gpui::test]
+async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
+ // This test verifies that the edit_file tool works correctly when invoked
+ // through the full thread flow (model sends ToolUse event -> tool runs -> result sent back).
+ // This is different from tests that call tool.run() directly.
+ super::init_test(cx);
+ super::always_allow_tools(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "src": {
+ "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n"
+ }
+ }),
+ )
+ .await;
+
+ let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+ let fake_model = model.as_fake();
+
+ let thread = cx.new(|cx| {
+ let mut thread = crate::Thread::new(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ crate::Templates::new(),
+ Some(model.clone()),
+ cx,
+ );
+ // Add just the tools we need for this test
+ let language_registry = project.read(cx).languages().clone();
+ thread.add_tool(crate::ReadFileTool::new(
+ cx.weak_entity(),
+ project.clone(),
+ thread.action_log().clone(),
+ ));
+ thread.add_tool(crate::EditFileTool::new(
+ project.clone(),
+ cx.weak_entity(),
+ language_registry,
+ crate::Templates::new(),
+ ));
+ thread
+ });
+
+ // First, read the file so the thread knows about its contents
+ let _events = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Read the file src/main.rs"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ // Model calls read_file tool
+ let read_tool_use = LanguageModelToolUse {
+ id: "read_tool_1".into(),
+ name: "read_file".into(),
+ raw_input: json!({"path": "project/src/main.rs"}).to_string(),
+ input: json!({"path": "project/src/main.rs"}),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Wait for the read tool to complete and the model to be called again
+ let deadline = std::time::Instant::now() + Duration::from_secs(5);
+ while fake_model.pending_completions().is_empty() {
+ assert!(
+ std::time::Instant::now() < deadline,
+ "Timed out waiting for the model to be called after the read_file tool"
+ );
+ cx.run_until_parked();
+ }
+
+ // Model responds after seeing the file content, then calls edit_file
+ fake_model.send_last_completion_stream_text_chunk("I'll edit the file now.");
+ let edit_tool_use = LanguageModelToolUse {
+ id: "edit_tool_1".into(),
+ name: "edit_file".into(),
+ raw_input: json!({
+ "display_description": "Change greeting message",
+ "path": "project/src/main.rs",
+ "mode": "edit"
+ })
+ .to_string(),
+ input: json!({
+ "display_description": "Change greeting message",
+ "path": "project/src/main.rs",
+ "mode": "edit"
+ }),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // The edit_file tool creates an EditAgent which makes its own model request.
+ // We need to respond to that request with the edit instructions.
+ // Wait for the edit agent's completion request
+ let deadline = std::time::Instant::now() + Duration::from_secs(5);
+ while fake_model.pending_completions().is_empty() {
+ if std::time::Instant::now() >= deadline {
+ panic!("Timed out waiting for the edit agent's completion request");
+ }
+ cx.run_until_parked();
+ cx.background_executor
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+
+ // Send the edit agent's response with the XML format it expects
+ let edit_response = "<old_text>println!(\"Hello, world!\");</old_text>\n<new_text>println!(\"Hello, Zed!\");</new_text>";
+ fake_model.send_last_completion_stream_text_chunk(edit_response);
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Wait for the edit to complete and the thread to call the model again with tool results
+ let deadline = std::time::Instant::now() + Duration::from_secs(5);
+ while fake_model.pending_completions().is_empty() {
+ if std::time::Instant::now() >= deadline {
+ panic!("Timed out waiting for model to be called after edit completion");
+ }
+ cx.run_until_parked();
+ cx.background_executor
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+
+ // Verify the file was edited
+ let file_content = fs
+ .load(path!("/project/src/main.rs").as_ref())
+ .await
+ .expect("file should exist");
+ assert!(
+ file_content.contains("Hello, Zed!"),
+ "File should have been edited. Content: {}",
+ file_content
+ );
+ assert!(
+ !file_content.contains("Hello, world!"),
+ "Old content should be replaced. Content: {}",
+ file_content
+ );
+
+ // Verify the tool result was sent back to the model
+ let pending = fake_model.pending_completions();
+ assert!(
+ !pending.is_empty(),
+ "Model should have been called with tool result"
+ );
+
+ let last_request = pending.last().unwrap();
+ let has_tool_result = last_request.messages.iter().any(|m| {
+ m.content
+ .iter()
+ .any(|c| matches!(c, language_model::MessageContent::ToolResult(_)))
+ });
+ assert!(
+ has_tool_result,
+ "Tool result should be in the messages sent back to the model"
+ );
+
+ // Complete the turn
+ fake_model.send_last_completion_stream_text_chunk("I've updated the greeting message.");
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Verify the thread completed successfully
+ thread.update(cx, |thread, _cx| {
+ assert!(
+ thread.is_turn_complete(),
+ "Thread should be complete after the turn ends"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_subagent_uses_read_file_tool(cx: &mut TestAppContext) {
+ // This test verifies that subagents can successfully use the read_file tool
+ // through the full thread flow, and that tools are properly rebound to use
+ // the subagent's thread ID instead of the parent's.
+ super::init_test(cx);
+ super::always_allow_tools(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "src": {
+ "lib.rs": "pub fn hello() -> &'static str {\n \"Hello from lib!\"\n}\n"
+ }
+ }),
+ )
+ .await;
+
+ let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+ let fake_model = model.as_fake();
+
+ // Create subagent context
+ let subagent_context = crate::SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize what you found".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ // Create parent tools that will be passed to the subagent
+ // This simulates how the subagent_tool passes tools to new_subagent
+ let parent_tools: BTreeMap<gpui::SharedString, std::sync::Arc<dyn crate::AnyAgentTool>> = {
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ // Create a "fake" parent thread reference - this should get rebound
+ let fake_parent_thread = cx.new(|cx| {
+ crate::Thread::new(
+ project.clone(),
+ cx.new(|_cx| ProjectContext::default()),
+ cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)),
+ crate::Templates::new(),
+ Some(model.clone()),
+ cx,
+ )
+ });
+ let mut tools: BTreeMap<gpui::SharedString, std::sync::Arc<dyn crate::AnyAgentTool>> =
+ BTreeMap::new();
+ tools.insert(
+ "read_file".into(),
+ crate::ReadFileTool::new(fake_parent_thread.downgrade(), project.clone(), action_log)
+ .erase(),
+ );
+ tools
+ };
+
+ // Create subagent - tools should be rebound to use subagent's thread
+ let subagent = cx.new(|cx| {
+ crate::Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ crate::Templates::new(),
+ model.clone(),
+ subagent_context,
+ parent_tools,
+ cx,
+ )
+ });
+
+ // Get the subagent's thread ID
+ let _subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string());
+
+ // Verify the subagent has the read_file tool
+ subagent.read_with(cx, |thread, _| {
+ assert!(
+ thread.has_registered_tool("read_file"),
+ "subagent should have read_file tool"
+ );
+ });
+
+ // Submit a user message to the subagent
+ subagent
+ .update(cx, |thread, cx| {
+ thread.submit_user_message("Read the file src/lib.rs", cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ // Simulate the model calling the read_file tool
+ let read_tool_use = LanguageModelToolUse {
+ id: "read_tool_1".into(),
+ name: "read_file".into(),
+ raw_input: json!({"path": "project/src/lib.rs"}).to_string(),
+ input: json!({"path": "project/src/lib.rs"}),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Wait for the tool to complete and the model to be called again with tool results
+ let deadline = std::time::Instant::now() + Duration::from_secs(5);
+ while fake_model.pending_completions().is_empty() {
+ if std::time::Instant::now() >= deadline {
+ panic!("Timed out waiting for model to be called after read_file tool completion");
+ }
+ cx.run_until_parked();
+ cx.background_executor
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+
+ // Verify the tool result was sent back to the model
+ let pending = fake_model.pending_completions();
+ assert!(
+ !pending.is_empty(),
+ "Model should have been called with tool result"
+ );
+
+ let last_request = pending.last().unwrap();
+ let tool_result = last_request.messages.iter().find_map(|m| {
+ m.content.iter().find_map(|c| match c {
+ MessageContent::ToolResult(result) => Some(result),
+ _ => None,
+ })
+ });
+ assert!(
+ tool_result.is_some(),
+ "Tool result should be in the messages sent back to the model"
+ );
+
+ // Verify the tool result contains the file content
+ let result = tool_result.unwrap();
+ let result_text = match &result.content {
+ language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
+ _ => panic!("expected text content in tool result"),
+ };
+ assert!(
+ result_text.contains("Hello from lib!"),
+ "Tool result should contain file content, got: {}",
+ result_text
+ );
+
+ // At this point the read_file tool has completed and the model has been called
+ // again, so the subagent is ready for more input.
+ // The summary flow is tested separately in test_subagent_returns_summary_on_completion.
+}
+
+#[gpui::test]
+async fn test_subagent_uses_edit_file_tool(cx: &mut TestAppContext) {
+ // This test verifies that subagents can successfully use the edit_file tool
+ // through the full thread flow, including the edit agent's model request.
+ // It also verifies that the edit agent uses the subagent's thread ID, not the parent's.
+ super::init_test(cx);
+ super::always_allow_tools(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "src": {
+ "config.rs": "pub const VERSION: &str = \"1.0.0\";\n"
+ }
+ }),
+ )
+ .await;
+
+ let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+ let fake_model = model.as_fake();
+
+ // Create a "parent" thread to simulate the real scenario where tools are inherited
+ let parent_thread = cx.new(|cx| {
+ crate::Thread::new(
+ project.clone(),
+ cx.new(|_cx| ProjectContext::default()),
+ cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)),
+ crate::Templates::new(),
+ Some(model.clone()),
+ cx,
+ )
+ });
+ let parent_thread_id = parent_thread.read_with(cx, |thread, _| thread.id().to_string());
+
+ // Create parent tools that reference the parent thread
+ let parent_tools: BTreeMap<gpui::SharedString, std::sync::Arc<dyn crate::AnyAgentTool>> = {
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let language_registry = project.read_with(cx, |p, _| p.languages().clone());
+ let mut tools: BTreeMap<gpui::SharedString, std::sync::Arc<dyn crate::AnyAgentTool>> =
+ BTreeMap::new();
+ tools.insert(
+ "read_file".into(),
+ crate::ReadFileTool::new(parent_thread.downgrade(), project.clone(), action_log)
+ .erase(),
+ );
+ tools.insert(
+ "edit_file".into(),
+ crate::EditFileTool::new(
+ project.clone(),
+ parent_thread.downgrade(),
+ language_registry,
+ crate::Templates::new(),
+ )
+ .erase(),
+ );
+ tools
+ };
+
+ // Create subagent context
+ let subagent_context = crate::SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize what you changed".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ // Create subagent - tools should be rebound to use subagent's thread
+ let subagent = cx.new(|cx| {
+ crate::Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ crate::Templates::new(),
+ model.clone(),
+ subagent_context,
+ parent_tools,
+ cx,
+ )
+ });
+
+ // Get the subagent's thread ID - it should be different from parent
+ let subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string());
+ assert_ne!(
+ parent_thread_id, subagent_thread_id,
+ "Subagent should have a different thread ID than parent"
+ );
+
+ // Verify the subagent has the tools
+ subagent.read_with(cx, |thread, _| {
+ assert!(
+ thread.has_registered_tool("read_file"),
+ "subagent should have read_file tool"
+ );
+ assert!(
+ thread.has_registered_tool("edit_file"),
+ "subagent should have edit_file tool"
+ );
+ });
+
+ // Submit a user message to the subagent
+ subagent
+ .update(cx, |thread, cx| {
+ thread.submit_user_message("Update the version in config.rs to 2.0.0", cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ // First, model calls read_file to see the current content
+ let read_tool_use = LanguageModelToolUse {
+ id: "read_tool_1".into(),
+ name: "read_file".into(),
+ raw_input: json!({"path": "project/src/config.rs"}).to_string(),
+ input: json!({"path": "project/src/config.rs"}),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Wait for the read tool to complete and model to be called again
+ let deadline = std::time::Instant::now() + Duration::from_secs(5);
+ while fake_model.pending_completions().is_empty() {
+ if std::time::Instant::now() >= deadline {
+ panic!("Timed out waiting for model to be called after read_file tool");
+ }
+ cx.run_until_parked();
+ cx.background_executor
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+
+ // Model responds and calls edit_file
+ fake_model.send_last_completion_stream_text_chunk("I'll update the version now.");
+ let edit_tool_use = LanguageModelToolUse {
+ id: "edit_tool_1".into(),
+ name: "edit_file".into(),
+ raw_input: json!({
+ "display_description": "Update version to 2.0.0",
+ "path": "project/src/config.rs",
+ "mode": "edit"
+ })
+ .to_string(),
+ input: json!({
+ "display_description": "Update version to 2.0.0",
+ "path": "project/src/config.rs",
+ "mode": "edit"
+ }),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // The edit_file tool creates an EditAgent which makes its own model request.
+ // Wait for that request.
+ let deadline = std::time::Instant::now() + Duration::from_secs(5);
+ while fake_model.pending_completions().is_empty() {
+ if std::time::Instant::now() >= deadline {
+ panic!("Timed out waiting for the edit agent's completion request in the subagent");
+ }
+ cx.run_until_parked();
+ cx.background_executor
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+
+ // Verify the edit agent's request uses the SUBAGENT's thread ID, not the parent's
+ let pending = fake_model.pending_completions();
+ let edit_agent_request = pending.last().unwrap();
+ let edit_agent_thread_id = edit_agent_request.thread_id.as_ref().unwrap();
+ assert_eq!(
+ edit_agent_thread_id,
+ &subagent_thread_id,
+ "Edit agent should use subagent's thread ID, not parent's. Got: {}, expected: {}",
+ edit_agent_thread_id,
+ subagent_thread_id
+ );
+ assert_ne!(
+ edit_agent_thread_id,
+ &parent_thread_id,
+ "Edit agent should NOT use parent's thread ID"
+ );
+
+ // Send the edit agent's response with the XML format it expects
+ let edit_response = "<old_text>pub const VERSION: &str = \"1.0.0\";</old_text>\n<new_text>pub const VERSION: &str = \"2.0.0\";</new_text>";
+ fake_model.send_last_completion_stream_text_chunk(edit_response);
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // Wait for the edit to complete and the thread to call the model again with tool results
+ let deadline = std::time::Instant::now() + Duration::from_secs(5);
+ while fake_model.pending_completions().is_empty() {
+ if std::time::Instant::now() >= deadline {
+ panic!("Timed out waiting for model to be called after edit completion in subagent");
+ }
+ cx.run_until_parked();
+ cx.background_executor
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+
+ // Verify the file was edited
+ let file_content = fs
+ .load(path!("/project/src/config.rs").as_ref())
+ .await
+ .expect("file should exist");
+ assert!(
+ file_content.contains("2.0.0"),
+ "File should have been edited to contain new version. Content: {}",
+ file_content
+ );
+ assert!(
+ !file_content.contains("1.0.0"),
+ "Old version should be replaced. Content: {}",
+ file_content
+ );
+
+ // Verify the tool result was sent back to the model
+ let pending = fake_model.pending_completions();
+ assert!(
+ !pending.is_empty(),
+ "Model should have been called with tool result"
+ );
+
+ let last_request = pending.last().unwrap();
+ let has_tool_result = last_request.messages.iter().any(|m| {
+ m.content
+ .iter()
+ .any(|c| matches!(c, MessageContent::ToolResult(_)))
+ });
+ assert!(
+ has_tool_result,
+ "Tool result should be in the messages sent back to the model"
+ );
+}
@@ -52,6 +52,7 @@ use std::{
};
use util::path;
+mod edit_file_thread_test;
mod test_tools;
use test_tools::*;
@@ -836,6 +836,19 @@ impl Thread {
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
let (prompt_capabilities_tx, prompt_capabilities_rx) =
watch::channel(Self::prompt_capabilities(Some(model.as_ref())));
+
+ // Rebind tools that hold thread references to use this subagent's thread
+ // instead of the parent's thread. This is critical for tools like EditFileTool
+ // that make model requests using the thread's ID.
+ let weak_self = cx.weak_entity();
+ let tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>> = parent_tools
+ .into_iter()
+ .map(|(name, tool)| {
+ let rebound = tool.rebind_thread(weak_self.clone()).unwrap_or(tool);
+ (name, rebound)
+ })
+ .collect();
+
Self {
id: acp::SessionId::new(uuid::Uuid::new_v4().to_string()),
prompt_id: PromptId::new(),
@@ -849,7 +862,7 @@ impl Thread {
running_turn: None,
queued_messages: Vec::new(),
pending_message: None,
- tools: parent_tools,
+ tools,
request_token_usage: HashMap::default(),
cumulative_token_usage: TokenUsage::default(),
initial_project_snapshot: Task::ready(None).shared(),
@@ -2274,6 +2287,7 @@ impl Thread {
stop: Vec::new(),
temperature: AgentSettings::temperature_for_model(model, cx),
thinking_allowed: true,
+ bypass_rate_limit: false,
};
log::debug!("Completion request built successfully");
@@ -2690,6 +2704,15 @@ where
fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self)))
}
+
+ /// Create a new instance of this tool bound to a different thread.
+ /// This is used when creating subagents, so that tools like EditFileTool
+ /// that hold a thread reference will use the subagent's thread instead
+ /// of the parent's thread.
+ /// Returns None if the tool doesn't need rebinding (most tools).
+ fn rebind_thread(&self, _new_thread: WeakEntity<Thread>) -> Option<Arc<dyn AnyAgentTool>> {
+ None
+ }
}
pub struct Erased<T>(T);
@@ -2721,6 +2744,14 @@ pub trait AnyAgentTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()>;
+ /// Create a new instance of this tool bound to a different thread.
+ /// This is used when creating subagents, so that tools like EditFileTool
+ /// that hold a thread reference will use the subagent's thread instead
+ /// of the parent's thread.
+ /// Returns None if the tool doesn't need rebinding (most tools).
+ fn rebind_thread(&self, _new_thread: WeakEntity<Thread>) -> Option<Arc<dyn AnyAgentTool>> {
+ None
+ }
}
impl<T> AnyAgentTool for Erased<Arc<T>>
@@ -2784,6 +2815,10 @@ where
let output = serde_json::from_value(output)?;
self.0.replay(input, output, event_stream, cx)
}
+
+ fn rebind_thread(&self, new_thread: WeakEntity<Thread>) -> Option<Arc<dyn AnyAgentTool>> {
+ self.0.rebind_thread(new_thread)
+ }
}
#[derive(Clone)]
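
To make the rebinding hook above concrete, here is a standalone toy version of the pattern (simplified trait and types, not the real `AnyAgentTool`; this sketch is not part of the diff): tools that hold a thread handle override `rebind_thread`, and everything else keeps the `None` default and is reused as-is.

    use std::sync::Arc;

    #[derive(Clone)]
    struct ThreadHandle(String);

    trait Tool {
        fn thread(&self) -> Option<&str> {
            None
        }
        fn rebind_thread(&self, _new: ThreadHandle) -> Option<Arc<dyn Tool>> {
            None // most tools hold no thread reference and need no rebinding
        }
    }

    struct GrepTool;
    impl Tool for GrepTool {}

    struct EditTool {
        thread: ThreadHandle,
    }
    impl Tool for EditTool {
        fn thread(&self) -> Option<&str> {
            Some(&self.thread.0)
        }
        fn rebind_thread(&self, new: ThreadHandle) -> Option<Arc<dyn Tool>> {
            Some(Arc::new(EditTool { thread: new }))
        }
    }

    fn main() {
        let subagent = ThreadHandle("subagent".into());
        let parent_tools: Vec<Arc<dyn Tool>> = vec![
            Arc::new(GrepTool),
            Arc::new(EditTool {
                thread: ThreadHandle("parent".into()),
            }),
        ];
        // Mirrors Thread::new_subagent: rebind where supported, otherwise reuse the tool.
        let tools: Vec<Arc<dyn Tool>> = parent_tools
            .into_iter()
            .map(|tool| tool.rebind_thread(subagent.clone()).unwrap_or(tool))
            .collect();
        assert_eq!(tools[0].thread(), None);
        assert_eq!(tools[1].thread(), Some("subagent"));
    }
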
@@ -144,6 +144,15 @@ impl EditFileTool {
}
}
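+ /// Returns a copy of this tool bound to `new_thread`; used by `rebind_thread` when a subagent inherits the parent's tools.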
+ pub fn with_thread(&self, new_thread: WeakEntity<Thread>) -> Self {
+ Self {
+ project: self.project.clone(),
+ thread: new_thread,
+ language_registry: self.language_registry.clone(),
+ templates: self.templates.clone(),
+ }
+ }
+
fn authorize(
&self,
input: &EditFileToolInput,
@@ -398,7 +407,6 @@ impl AgentTool for EditFileTool {
})
.await;
-
let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
edit_agent.edit(
buffer.clone(),
@@ -575,6 +583,13 @@ impl AgentTool for EditFileTool {
}));
Ok(())
}
+
+ fn rebind_thread(
+ &self,
+ new_thread: gpui::WeakEntity<crate::Thread>,
+ ) -> Option<std::sync::Arc<dyn crate::AnyAgentTool>> {
+ Some(self.with_thread(new_thread).erase())
+ }
}
/// Validate that the file path is valid, meaning:
@@ -65,6 +65,14 @@ impl ReadFileTool {
action_log,
}
}
+
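+ /// Returns a copy of this tool bound to `new_thread`; used by `rebind_thread` when a subagent inherits the parent's tools.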
+ pub fn with_thread(&self, new_thread: WeakEntity<Thread>) -> Self {
+ Self {
+ thread: new_thread,
+ project: self.project.clone(),
+ action_log: self.action_log.clone(),
+ }
+ }
}
impl AgentTool for ReadFileTool {
@@ -308,6 +316,13 @@ impl AgentTool for ReadFileTool {
result
})
}
+
+ fn rebind_thread(
+ &self,
+ new_thread: WeakEntity<Thread>,
+ ) -> Option<std::sync::Arc<dyn crate::AnyAgentTool>> {
+ Some(self.with_thread(new_thread).erase())
+ }
}
#[cfg(test)]
@@ -544,6 +544,7 @@ impl CodegenAlternative {
temperature,
messages,
thinking_allowed: false,
+ bypass_rate_limit: false,
}
}))
}
@@ -622,6 +623,7 @@ impl CodegenAlternative {
temperature,
messages: vec![request_message],
thinking_allowed: false,
+ bypass_rate_limit: false,
}
}))
}
@@ -275,6 +275,7 @@ impl TerminalInlineAssistant {
stop: Vec::new(),
temperature,
thinking_allowed: false,
+ bypass_rate_limit: false,
}
}))
}
@@ -2269,6 +2269,7 @@ impl TextThread {
stop: Vec::new(),
temperature: model.and_then(|model| AgentSettings::temperature_for_model(model, cx)),
thinking_allowed: true,
+ bypass_rate_limit: false,
};
for message in self.messages(cx) {
if message.status != MessageStatus::Done {
@@ -563,6 +563,7 @@ impl ExampleInstance {
tool_choice: None,
stop: Vec::new(),
thinking_allowed: true,
+ bypass_rate_limit: false,
};
let model = model.clone();
@@ -2691,6 +2691,7 @@ impl GitPanel {
stop: Vec::new(),
temperature,
thinking_allowed: false,
+ bypass_rate_limit: false,
};
let stream = model.stream_completion_text(request, cx);
@@ -16,7 +16,7 @@ pub struct RateLimiter {
pub struct RateLimitGuard<T> {
inner: T,
- _guard: SemaphoreGuardArc,
+ _guard: Option<SemaphoreGuardArc>,
}
impl<T> Stream for RateLimitGuard<T>
@@ -68,6 +68,36 @@ impl RateLimiter {
async move {
let guard = guard.await;
let inner = future.await?;
+ Ok(RateLimitGuard {
+ inner,
+ _guard: Some(guard),
+ })
+ }
+ }
+
+ /// Like `stream`, but conditionally bypasses the rate limiter based on the flag.
+ /// Used for nested requests (like edit agent requests) that are already "part of"
+ /// a rate-limited request to avoid deadlocks.
+ pub fn stream_with_bypass<'a, Fut, T>(
+ &self,
+ future: Fut,
+ bypass: bool,
+ ) -> impl 'a
+ + Future<
+ Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
+ >
+ where
+ Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
+ T: Stream,
+ {
+ let semaphore = self.semaphore.clone();
+ async move {
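+ // When bypassing, hold no permit at all; the nested request runs under the permit its parent already holds.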
+ let guard = if bypass {
+ None
+ } else {
+ Some(semaphore.acquire_arc().await)
+ };
+ let inner = future.await?;
Ok(RateLimitGuard {
inner,
_guard: guard,
@@ -75,3 +105,190 @@ impl RateLimiter {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use futures::stream;
+ use smol::lock::Barrier;
+ use std::sync::Arc;
+ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+ use std::time::{Duration, Instant};
+
+ /// Tests that nested requests which bypass the rate limiter complete successfully
+ /// even while the parent requests hold every available permit.
+ ///
+ /// This test simulates the scenario where multiple "parent" requests each
+ /// try to spawn a "nested" request (like edit_file tool spawning an edit agent).
+ /// With a rate limit of 2 and 2 parent requests, without bypass the nested
+ /// requests would block forever waiting for permits that the parents hold.
+ #[test]
+ fn test_nested_requests_bypass_prevents_deadlock() {
+ smol::block_on(async {
+ // Use only 2 permits so we can guarantee deadlock conditions
+ let rate_limiter = RateLimiter::new(2);
+ let completed = Arc::new(AtomicUsize::new(0));
+ // Barrier ensures all parents acquire permits before any of them makes a nested request
+ let barrier = Arc::new(Barrier::new(2));
+
+ // Spawn 2 "parent" requests that each try to make a "nested" request
+ let mut handles = Vec::new();
+ for _ in 0..2 {
+ let limiter = rate_limiter.clone();
+ let completed = completed.clone();
+ let barrier = barrier.clone();
+
+ let handle = smol::spawn(async move {
+ // Parent request acquires a permit via stream_with_bypass (bypass=false)
+ let parent_stream = limiter
+ .stream_with_bypass(
+ async {
+ // Wait for all parents to acquire permits
+ barrier.wait().await;
+
+ // While holding the parent permit, make a nested request
+ // WITH bypass=true (simulating EditAgent behavior)
+ let nested_stream = limiter
+ .stream_with_bypass(
+ async { Ok(stream::iter(vec![1, 2, 3])) },
+ true, // bypass - this is the key!
+ )
+ .await?;
+
+ // Consume the nested stream
+ use futures::StreamExt;
+ let _: Vec<_> = nested_stream.collect().await;
+
+ Ok(stream::iter(vec!["done"]))
+ },
+ false, // parent does NOT bypass
+ )
+ .await
+ .unwrap();
+
+ // Consume parent stream
+ use futures::StreamExt;
+ let _: Vec<_> = parent_stream.collect().await;
+
+ completed.fetch_add(1, Ordering::SeqCst);
+ });
+ handles.push(handle);
+ }
+
+ // With bypass=true for nested requests, this should complete quickly
+ let timed_out = Arc::new(AtomicBool::new(false));
+ let timed_out_clone = timed_out.clone();
+
+ // Spawn a watchdog that sets timed_out after 2 seconds
+ let watchdog = smol::spawn(async move {
+ let start = Instant::now();
+ while start.elapsed() < Duration::from_secs(2) {
+ smol::future::yield_now().await;
+ }
+ timed_out_clone.store(true, Ordering::SeqCst);
+ });
+
+ // Wait for all handles to complete
+ for handle in handles {
+ handle.await;
+ }
+
+ // Cancel the watchdog
+ drop(watchdog);
+
+ if timed_out.load(Ordering::SeqCst) {
+ panic!(
+ "Completing the nested requests took more than 2 seconds; bypass_rate_limit does not appear to be working"
+ );
+ }
+ assert_eq!(completed.load(Ordering::SeqCst), 2);
+ });
+ }
+
+ /// Tests that without bypass, nested requests DO cause deadlock.
+ /// This test verifies the problem exists when bypass is not used.
+ #[test]
+ fn test_nested_requests_without_bypass_deadlocks() {
+ smol::block_on(async {
+ // Use only 2 permits so we can guarantee deadlock conditions
+ let rate_limiter = RateLimiter::new(2);
+ let completed = Arc::new(AtomicUsize::new(0));
+ // Barrier ensures all parents acquire permits before any of them makes a nested request
+ let barrier = Arc::new(Barrier::new(2));
+
+ // Spawn 2 "parent" requests that each try to make a "nested" request
+ let mut handles = Vec::new();
+ for _ in 0..2 {
+ let limiter = rate_limiter.clone();
+ let completed = completed.clone();
+ let barrier = barrier.clone();
+
+ let handle = smol::spawn(async move {
+ // Parent request acquires a permit
+ let parent_stream = limiter
+ .stream_with_bypass(
+ async {
+ // Wait for all parents to acquire permits - this guarantees
+ // that all 2 permits are held before any nested request starts
+ barrier.wait().await;
+
+ // Nested request WITHOUT bypass - this will deadlock!
+ // Both parents hold permits, so no permits available
+ let nested_stream = limiter
+ .stream_with_bypass(
+ async { Ok(stream::iter(vec![1, 2, 3])) },
+ false, // NO bypass - will try to acquire permit
+ )
+ .await?;
+
+ use futures::StreamExt;
+ let _: Vec<_> = nested_stream.collect().await;
+
+ Ok(stream::iter(vec!["done"]))
+ },
+ false,
+ )
+ .await
+ .unwrap();
+
+ use futures::StreamExt;
+ let _: Vec<_> = parent_stream.collect().await;
+
+ completed.fetch_add(1, Ordering::SeqCst);
+ });
+ handles.push(handle);
+ }
+
+ // This SHOULD deadlock: both parents hold permits while both nested
+ // requests wait for permits that will never be released.
+
+ // Poll briefly to give the tasks a chance to run; none of them can make progress.
+ let start = Instant::now();
+ while start.elapsed() < Duration::from_millis(100) {
+ smol::future::yield_now().await;
+ }
+
+ // The deadlock is the expected outcome; it is exactly what bypass_rate_limit prevents.
+ let count = completed.load(Ordering::SeqCst);
+ assert_eq!(
+ count, 0,
+ "Expected complete deadlock (0 completed) but {} requests completed",
+ count
+ );
+ });
+ }
+}
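
For reviewers skimming the guard change above: `RateLimitGuard` now wraps an `Option<SemaphoreGuardArc>`, so a bypassed request simply never takes a permit. A standalone sketch of that pattern (toy code, not part of the diff; assumes the async-lock semaphore re-exported by smol, which the tests above already use):

    fn main() {
        smol::block_on(async {
            let semaphore = std::sync::Arc::new(smol::lock::Semaphore::new(1));

            // Non-bypassing request: the guard holds the single permit, so a
            // nested acquire would have to wait (try_acquire fails here).
            let guard = semaphore.acquire_arc().await;
            assert!(semaphore.try_acquire().is_none());
            drop(guard);

            // Bypassing request: no guard is taken, so the permit stays free
            // and a nested request can still get through.
            assert!(semaphore.try_acquire().is_some());
        });
    }
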
@@ -451,6 +451,11 @@ pub struct LanguageModelRequest {
pub stop: Vec<String>,
pub temperature: Option<f32>,
pub thinking_allowed: bool,
+ /// When true, this request bypasses the rate limiter. Used for nested requests
+ /// (like edit agent requests spawned from within a tool call) that are already
+ /// "part of" a rate-limited request to avoid deadlocks.
+ #[serde(default)]
+ pub bypass_rate_limit: bool,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
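
The `#[serde(default)]` on the new field means payloads produced before this change, which lack the `bypass_rate_limit` key, still deserialize and get `false`. A minimal standalone illustration with a stand-in struct (not the real `LanguageModelRequest`; assumes the usual serde/serde_json dependencies):

    use serde::Deserialize;

    #[derive(Deserialize)]
    struct Request {
        prompt: String,
        #[serde(default)]
        bypass_rate_limit: bool,
    }

    fn main() {
        // A payload written before the field existed has no `bypass_rate_limit` key.
        let old_payload = r#"{ "prompt": "hello" }"#;
        let request: Request = serde_json::from_str(old_payload).unwrap();
        assert_eq!(request.prompt, "hello");
        assert!(!request.bypass_rate_limit); // defaults to false
    }
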
@@ -578,6 +578,7 @@ impl LanguageModel for AnthropicModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let request = into_anthropic(
request,
self.model.request_id().into(),
@@ -586,10 +587,13 @@ impl LanguageModel for AnthropicModel {
self.model.mode(),
);
let request = self.stream_completion(request, cx);
- let future = self.request_limiter.stream(async move {
- let response = request.await?;
- Ok(AnthropicEventMapper::new().map_stream(response))
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let response = request.await?;
+ Ok(AnthropicEventMapper::new().map_stream(response))
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -1164,6 +1168,7 @@ mod tests {
tools: vec![],
tool_choice: None,
thinking_allowed: true,
+ bypass_rate_limit: false,
};
let anthropic_request = into_anthropic(
@@ -679,6 +679,7 @@ impl LanguageModel for BedrockModel {
};
let deny_tool_calls = request.tool_choice == Some(LanguageModelToolChoice::None);
+ let bypass_rate_limit = request.bypass_rate_limit;
let request = match into_bedrock(
request,
@@ -693,16 +694,19 @@ impl LanguageModel for BedrockModel {
};
let request = self.stream_completion(request, cx);
- let future = self.request_limiter.stream(async move {
- let response = request.await.map_err(|err| anyhow!(err))?;
- let events = map_to_language_model_completion_events(response);
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let response = request.await.map_err(|err| anyhow!(err))?;
+ let events = map_to_language_model_completion_events(response);
- if deny_tool_calls {
- Ok(deny_tool_use_events(events).boxed())
- } else {
- Ok(events.boxed())
- }
- });
+ if deny_tool_calls {
+ Ok(deny_tool_use_events(events).boxed())
+ } else {
+ Ok(events.boxed())
+ }
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -719,6 +719,7 @@ impl LanguageModel for CloudLanguageModel {
let thread_id = request.thread_id.clone();
let prompt_id = request.prompt_id.clone();
let intent = request.intent;
+ let bypass_rate_limit = request.bypass_rate_limit;
let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
let use_responses_api = cx.update(|cx| cx.has_flag::<OpenAiResponsesApiFeatureFlag>());
let thinking_allowed = request.thinking_allowed;
@@ -740,53 +741,8 @@ impl LanguageModel for CloudLanguageModel {
);
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let PerformLlmCompletionResponse {
- response,
- includes_status_messages,
- } = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- app_version,
- CompletionBody {
- thread_id,
- prompt_id,
- intent,
- provider: cloud_llm_client::LanguageModelProvider::Anthropic,
- model: request.model.clone(),
- provider_request: serde_json::to_value(&request)
- .map_err(|e| anyhow!(e))?,
- },
- )
- .await
- .map_err(|err| match err.downcast::<ApiError>() {
- Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
- Err(err) => anyhow!(err),
- })?;
-
- let mut mapper = AnthropicEventMapper::new();
- Ok(map_cloud_completion_events(
- Box::pin(response_lines(response, includes_status_messages)),
- &provider_name,
- move |event| mapper.map_event(event),
- ))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- }
- cloud_llm_client::LanguageModelProvider::OpenAi => {
- let client = self.client.clone();
- let llm_api_token = self.llm_api_token.clone();
-
- if use_responses_api {
- let request = into_open_ai_response(
- request,
- &self.model.id.0,
- self.model.supports_parallel_tool_calls,
- true,
- None,
- None,
- );
- let future = self.request_limiter.stream(async move {
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
let PerformLlmCompletionResponse {
response,
includes_status_messages,
@@ -798,21 +754,74 @@ impl LanguageModel for CloudLanguageModel {
thread_id,
prompt_id,
intent,
- provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+ provider: cloud_llm_client::LanguageModelProvider::Anthropic,
model: request.model.clone(),
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
},
)
- .await?;
-
- let mut mapper = OpenAiResponseEventMapper::new();
+ .await
+ .map_err(|err| {
+ match err.downcast::<ApiError>() {
+ Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
+ Err(err) => anyhow!(err),
+ }
+ })?;
+
+ let mut mapper = AnthropicEventMapper::new();
Ok(map_cloud_completion_events(
Box::pin(response_lines(response, includes_status_messages)),
&provider_name,
move |event| mapper.map_event(event),
))
- });
+ },
+ bypass_rate_limit,
+ );
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+ cloud_llm_client::LanguageModelProvider::OpenAi => {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+
+ if use_responses_api {
+ let request = into_open_ai_response(
+ request,
+ &self.model.id.0,
+ self.model.supports_parallel_tool_calls,
+ true,
+ None,
+ None,
+ );
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let PerformLlmCompletionResponse {
+ response,
+ includes_status_messages,
+ } = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ app_version,
+ CompletionBody {
+ thread_id,
+ prompt_id,
+ intent,
+ provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+ model: request.model.clone(),
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
+ },
+ )
+ .await?;
+
+ let mut mapper = OpenAiResponseEventMapper::new();
+ Ok(map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_status_messages)),
+ &provider_name,
+ move |event| mapper.map_event(event),
+ ))
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
} else {
let request = into_open_ai(
@@ -823,7 +832,52 @@ impl LanguageModel for CloudLanguageModel {
None,
None,
);
- let future = self.request_limiter.stream(async move {
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let PerformLlmCompletionResponse {
+ response,
+ includes_status_messages,
+ } = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ app_version,
+ CompletionBody {
+ thread_id,
+ prompt_id,
+ intent,
+ provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+ model: request.model.clone(),
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
+ },
+ )
+ .await?;
+
+ let mut mapper = OpenAiEventMapper::new();
+ Ok(map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_status_messages)),
+ &provider_name,
+ move |event| mapper.map_event(event),
+ ))
+ },
+ bypass_rate_limit,
+ );
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+ }
+ cloud_llm_client::LanguageModelProvider::XAi => {
+ let client = self.client.clone();
+ let request = into_open_ai(
+ request,
+ &self.model.id.0,
+ self.model.supports_parallel_tool_calls,
+ false,
+ None,
+ None,
+ );
+ let llm_api_token = self.llm_api_token.clone();
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
let PerformLlmCompletionResponse {
response,
includes_status_messages,
@@ -835,7 +889,7 @@ impl LanguageModel for CloudLanguageModel {
thread_id,
prompt_id,
intent,
- provider: cloud_llm_client::LanguageModelProvider::OpenAi,
+ provider: cloud_llm_client::LanguageModelProvider::XAi,
model: request.model.clone(),
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
@@ -849,48 +903,9 @@ impl LanguageModel for CloudLanguageModel {
&provider_name,
move |event| mapper.map_event(event),
))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- }
- }
- cloud_llm_client::LanguageModelProvider::XAi => {
- let client = self.client.clone();
- let request = into_open_ai(
- request,
- &self.model.id.0,
- self.model.supports_parallel_tool_calls,
- false,
- None,
- None,
+ },
+ bypass_rate_limit,
);
- let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let PerformLlmCompletionResponse {
- response,
- includes_status_messages,
- } = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- app_version,
- CompletionBody {
- thread_id,
- prompt_id,
- intent,
- provider: cloud_llm_client::LanguageModelProvider::XAi,
- model: request.model.clone(),
- provider_request: serde_json::to_value(&request)
- .map_err(|e| anyhow!(e))?,
- },
- )
- .await?;
-
- let mut mapper = OpenAiEventMapper::new();
- Ok(map_cloud_completion_events(
- Box::pin(response_lines(response, includes_status_messages)),
- &provider_name,
- move |event| mapper.map_event(event),
- ))
- });
async move { Ok(future.await?.boxed()) }.boxed()
}
cloud_llm_client::LanguageModelProvider::Google => {
@@ -898,33 +913,36 @@ impl LanguageModel for CloudLanguageModel {
let request =
into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let PerformLlmCompletionResponse {
- response,
- includes_status_messages,
- } = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- app_version,
- CompletionBody {
- thread_id,
- prompt_id,
- intent,
- provider: cloud_llm_client::LanguageModelProvider::Google,
- model: request.model.model_id.clone(),
- provider_request: serde_json::to_value(&request)
- .map_err(|e| anyhow!(e))?,
- },
- )
- .await?;
-
- let mut mapper = GoogleEventMapper::new();
- Ok(map_cloud_completion_events(
- Box::pin(response_lines(response, includes_status_messages)),
- &provider_name,
- move |event| mapper.map_event(event),
- ))
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let PerformLlmCompletionResponse {
+ response,
+ includes_status_messages,
+ } = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ app_version,
+ CompletionBody {
+ thread_id,
+ prompt_id,
+ intent,
+ provider: cloud_llm_client::LanguageModelProvider::Google,
+ model: request.model.model_id.clone(),
+ provider_request: serde_json::to_value(&request)
+ .map_err(|e| anyhow!(e))?,
+ },
+ )
+ .await?;
+
+ let mut mapper = GoogleEventMapper::new();
+ Ok(map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_status_messages)),
+ &provider_name,
+ move |event| mapper.map_event(event),
+ ))
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
}
@@ -307,6 +307,7 @@ impl LanguageModel for CopilotChatLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let is_user_initiated = request.intent.is_none_or(|intent| match intent {
CompletionIntent::UserPrompt
| CompletionIntent::ThreadContextSummarization
@@ -327,11 +328,14 @@ impl LanguageModel for CopilotChatLanguageModel {
let request =
CopilotChat::stream_response(responses_request, is_user_initiated, cx.clone());
request_limiter
- .stream(async move {
- let stream = request.await?;
- let mapper = CopilotResponsesEventMapper::new();
- Ok(mapper.map_stream(stream).boxed())
- })
+ .stream_with_bypass(
+ async move {
+ let stream = request.await?;
+ let mapper = CopilotResponsesEventMapper::new();
+ Ok(mapper.map_stream(stream).boxed())
+ },
+ bypass_rate_limit,
+ )
.await
});
return async move { Ok(future.await?.boxed()) }.boxed();
@@ -348,13 +352,16 @@ impl LanguageModel for CopilotChatLanguageModel {
let request =
CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
request_limiter
- .stream(async move {
- let response = request.await?;
- Ok(map_to_language_model_completion_events(
- response,
- is_streaming,
- ))
- })
+ .stream_with_bypass(
+ async move {
+ let response = request.await?;
+ Ok(map_to_language_model_completion_events(
+ response,
+ is_streaming,
+ ))
+ },
+ bypass_rate_limit,
+ )
.await
});
async move { Ok(future.await?.boxed()) }.boxed()
@@ -929,6 +936,7 @@ fn into_copilot_responses(
stop: _,
temperature,
thinking_allowed: _,
+ bypass_rate_limit: _,
} = request;
let mut input_items: Vec<responses::ResponseInputItem> = Vec::new();
@@ -199,6 +199,7 @@ impl DeepSeekLanguageModel {
fn stream_completion(
&self,
request: deepseek::Request,
+ bypass_rate_limit: bool,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<deepseek::StreamResponse>>>> {
let http_client = self.http_client.clone();
@@ -208,17 +209,20 @@ impl DeepSeekLanguageModel {
(state.api_key_state.key(&api_url), api_url)
});
- let future = self.request_limiter.stream(async move {
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey {
- provider: PROVIDER_NAME,
- });
- };
- let request =
- deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
- let response = request.await?;
- Ok(response)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request =
+ deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+ let response = request.await?;
+ Ok(response)
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -302,8 +306,9 @@ impl LanguageModel for DeepSeekLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let request = into_deepseek(request, &self.model, self.max_output_tokens());
- let stream = self.stream_completion(request, cx);
+ let stream = self.stream_completion(request, bypass_rate_limit, cx);
async move {
let mapper = DeepSeekEventMapper::new();
@@ -370,16 +370,20 @@ impl LanguageModel for GoogleLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let request = into_google(
request,
self.model.request_id().to_string(),
self.model.mode(),
);
let request = self.stream_completion(request, cx);
- let future = self.request_limiter.stream(async move {
- let response = request.await.map_err(LanguageModelCompletionError::from)?;
- Ok(GoogleEventMapper::new().map_stream(response))
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let response = request.await.map_err(LanguageModelCompletionError::from)?;
+ Ok(GoogleEventMapper::new().map_stream(response))
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
}
@@ -370,6 +370,7 @@ impl LmStudioLanguageModel {
fn stream_completion(
&self,
request: lmstudio::ChatCompletionRequest,
+ bypass_rate_limit: bool,
cx: &AsyncApp,
) -> BoxFuture<
'static,
@@ -381,11 +382,15 @@ impl LmStudioLanguageModel {
settings.api_url.clone()
});
- let future = self.request_limiter.stream(async move {
- let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
- let response = request.await?;
- Ok(response)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let request =
+ lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
+ let response = request.await?;
+ Ok(response)
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -460,8 +465,9 @@ impl LanguageModel for LmStudioLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let request = self.to_lmstudio_request(request);
- let completions = self.stream_completion(request, cx);
+ let completions = self.stream_completion(request, bypass_rate_limit, cx);
async move {
let mapper = LmStudioEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
@@ -264,6 +264,7 @@ impl MistralLanguageModel {
fn stream_completion(
&self,
request: mistral::Request,
+ bypass_rate_limit: bool,
cx: &AsyncApp,
) -> BoxFuture<
'static,
@@ -276,17 +277,20 @@ impl MistralLanguageModel {
(state.api_key_state.key(&api_url), api_url)
});
- let future = self.request_limiter.stream(async move {
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey {
- provider: PROVIDER_NAME,
- });
- };
- let request =
- mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
- let response = request.await?;
- Ok(response)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request =
+ mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+ let response = request.await?;
+ Ok(response)
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -370,8 +374,9 @@ impl LanguageModel for MistralLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
- let stream = self.stream_completion(request, cx);
+ let stream = self.stream_completion(request, bypass_rate_limit, cx);
async move {
let stream = stream.await?;
@@ -901,6 +906,7 @@ mod tests {
intent: None,
stop: vec![],
thinking_allowed: true,
+ bypass_rate_limit: false,
};
let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);
@@ -934,6 +940,7 @@ mod tests {
intent: None,
stop: vec![],
thinking_allowed: true,
+ bypass_rate_limit: false,
};
let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);
@@ -480,6 +480,7 @@ impl LanguageModel for OllamaLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
@@ -488,13 +489,20 @@ impl LanguageModel for OllamaLanguageModel {
(state.api_key_state.key(&api_url), api_url)
});
- let future = self.request_limiter.stream(async move {
- let stream =
- stream_chat_completion(http_client.as_ref(), &api_url, api_key.as_deref(), request)
- .await?;
- let stream = map_to_language_model_completion_events(stream);
- Ok(stream)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let stream = stream_chat_completion(
+ http_client.as_ref(),
+ &api_url,
+ api_key.as_deref(),
+ request,
+ )
+ .await?;
+ let stream = map_to_language_model_completion_events(stream);
+ Ok(stream)
+ },
+ bypass_rate_limit,
+ );
future.map_ok(|f| f.boxed()).boxed()
}
@@ -213,6 +213,7 @@ impl OpenAiLanguageModel {
fn stream_completion(
&self,
request: open_ai::Request,
+ bypass_rate_limit: bool,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
@@ -223,21 +224,24 @@ impl OpenAiLanguageModel {
(state.api_key_state.key(&api_url), api_url)
});
- let future = self.request_limiter.stream(async move {
- let provider = PROVIDER_NAME;
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey { provider });
- };
- let request = stream_completion(
- http_client.as_ref(),
- provider.0.as_str(),
- &api_url,
- &api_key,
- request,
- );
- let response = request.await?;
- Ok(response)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let provider = PROVIDER_NAME;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey { provider });
+ };
+ let request = stream_completion(
+ http_client.as_ref(),
+ provider.0.as_str(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -245,6 +249,7 @@ impl OpenAiLanguageModel {
fn stream_response(
&self,
request: ResponseRequest,
+ bypass_rate_limit: bool,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponsesStreamEvent>>>>
{
@@ -256,20 +261,23 @@ impl OpenAiLanguageModel {
});
let provider = PROVIDER_NAME;
- let future = self.request_limiter.stream(async move {
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey { provider });
- };
- let request = stream_response(
- http_client.as_ref(),
- provider.0.as_str(),
- &api_url,
- &api_key,
- request,
- );
- let response = request.await?;
- Ok(response)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey { provider });
+ };
+ let request = stream_response(
+ http_client.as_ref(),
+ provider.0.as_str(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -368,6 +376,7 @@ impl LanguageModel for OpenAiLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
if self.model.supports_chat_completions() {
let request = into_open_ai(
request,
@@ -377,7 +386,7 @@ impl LanguageModel for OpenAiLanguageModel {
self.max_output_tokens(),
self.model.reasoning_effort(),
);
- let completions = self.stream_completion(request, cx);
+ let completions = self.stream_completion(request, bypass_rate_limit, cx);
async move {
let mapper = OpenAiEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
@@ -392,7 +401,7 @@ impl LanguageModel for OpenAiLanguageModel {
self.max_output_tokens(),
self.model.reasoning_effort(),
);
- let completions = self.stream_response(request, cx);
+ let completions = self.stream_response(request, bypass_rate_limit, cx);
async move {
let mapper = OpenAiResponseEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
@@ -545,6 +554,7 @@ pub fn into_open_ai_response(
stop: _,
temperature,
thinking_allowed: _,
+ bypass_rate_limit: _,
} = request;
let mut input_items = Vec::new();
@@ -1417,6 +1427,7 @@ mod tests {
stop: vec![],
temperature: None,
thinking_allowed: true,
+ bypass_rate_limit: false,
};
// Validate that all models are supported by tiktoken-rs
@@ -1553,6 +1564,7 @@ mod tests {
stop: vec!["<STOP>".into()],
temperature: None,
thinking_allowed: false,
+ bypass_rate_limit: false,
};
let response = into_open_ai_response(
@@ -204,6 +204,7 @@ impl OpenAiCompatibleLanguageModel {
fn stream_completion(
&self,
request: open_ai::Request,
+ bypass_rate_limit: bool,
cx: &AsyncApp,
) -> BoxFuture<
'static,
@@ -223,20 +224,23 @@ impl OpenAiCompatibleLanguageModel {
});
let provider = self.provider_name.clone();
- let future = self.request_limiter.stream(async move {
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey { provider });
- };
- let request = stream_completion(
- http_client.as_ref(),
- provider.0.as_str(),
- &api_url,
- &api_key,
- request,
- );
- let response = request.await?;
- Ok(response)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey { provider });
+ };
+ let request = stream_completion(
+ http_client.as_ref(),
+ provider.0.as_str(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -244,6 +248,7 @@ impl OpenAiCompatibleLanguageModel {
fn stream_response(
&self,
request: ResponseRequest,
+ bypass_rate_limit: bool,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponsesStreamEvent>>>>
{
@@ -258,20 +263,23 @@ impl OpenAiCompatibleLanguageModel {
});
let provider = self.provider_name.clone();
- let future = self.request_limiter.stream(async move {
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey { provider });
- };
- let request = stream_response(
- http_client.as_ref(),
- provider.0.as_str(),
- &api_url,
- &api_key,
- request,
- );
- let response = request.await?;
- Ok(response)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey { provider });
+ };
+ let request = stream_response(
+ http_client.as_ref(),
+ provider.0.as_str(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -370,6 +378,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
if self.model.capabilities.chat_completions {
let request = into_open_ai(
request,
@@ -379,7 +388,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
self.max_output_tokens(),
None,
);
- let completions = self.stream_completion(request, cx);
+ let completions = self.stream_completion(request, bypass_rate_limit, cx);
async move {
let mapper = OpenAiEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
@@ -394,7 +403,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
self.max_output_tokens(),
None,
);
- let completions = self.stream_response(request, cx);
+ let completions = self.stream_response(request, bypass_rate_limit, cx);
async move {
let mapper = OpenAiResponseEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
@@ -368,12 +368,16 @@ impl LanguageModel for OpenRouterLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let openrouter_request = into_open_router(request, &self.model, self.max_output_tokens());
let request = self.stream_completion(openrouter_request, cx);
- let future = self.request_limiter.stream(async move {
- let response = request.await?;
- Ok(OpenRouterEventMapper::new().map_stream(response))
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let response = request.await?;
+ Ok(OpenRouterEventMapper::new().map_stream(response))
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
}
@@ -193,6 +193,7 @@ impl VercelLanguageModel {
fn stream_completion(
&self,
request: open_ai::Request,
+ bypass_rate_limit: bool,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
@@ -203,21 +204,24 @@ impl VercelLanguageModel {
(state.api_key_state.key(&api_url), api_url)
});
- let future = self.request_limiter.stream(async move {
- let provider = PROVIDER_NAME;
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey { provider });
- };
- let request = open_ai::stream_completion(
- http_client.as_ref(),
- provider.0.as_str(),
- &api_url,
- &api_key,
- request,
- );
- let response = request.await?;
- Ok(response)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let provider = PROVIDER_NAME;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey { provider });
+ };
+ let request = open_ai::stream_completion(
+ http_client.as_ref(),
+ provider.0.as_str(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -290,6 +294,7 @@ impl LanguageModel for VercelLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let request = crate::provider::open_ai::into_open_ai(
request,
self.model.id(),
@@ -298,7 +303,7 @@ impl LanguageModel for VercelLanguageModel {
self.max_output_tokens(),
None,
);
- let completions = self.stream_completion(request, cx);
+ let completions = self.stream_completion(request, bypass_rate_limit, cx);
async move {
let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
@@ -197,6 +197,7 @@ impl XAiLanguageModel {
fn stream_completion(
&self,
request: open_ai::Request,
+ bypass_rate_limit: bool,
cx: &AsyncApp,
) -> BoxFuture<
'static,
@@ -212,21 +213,24 @@ impl XAiLanguageModel {
(state.api_key_state.key(&api_url), api_url)
});
- let future = self.request_limiter.stream(async move {
- let provider = PROVIDER_NAME;
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey { provider });
- };
- let request = open_ai::stream_completion(
- http_client.as_ref(),
- provider.0.as_str(),
- &api_url,
- &api_key,
- request,
- );
- let response = request.await?;
- Ok(response)
- });
+ let future = self.request_limiter.stream_with_bypass(
+ async move {
+ let provider = PROVIDER_NAME;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey { provider });
+ };
+ let request = open_ai::stream_completion(
+ http_client.as_ref(),
+ provider.0.as_str(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ },
+ bypass_rate_limit,
+ );
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -307,6 +311,7 @@ impl LanguageModel for XAiLanguageModel {
LanguageModelCompletionError,
>,
> {
+ let bypass_rate_limit = request.bypass_rate_limit;
let request = crate::provider::open_ai::into_open_ai(
request,
self.model.id(),
@@ -315,7 +320,7 @@ impl LanguageModel for XAiLanguageModel {
self.max_output_tokens(),
None,
);
- let completions = self.stream_completion(request, cx);
+ let completions = self.stream_completion(request, bypass_rate_limit, cx);
async move {
let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
@@ -1100,6 +1100,7 @@ impl RulesLibrary {
stop: Vec::new(),
temperature: None,
thinking_allowed: true,
+ bypass_rate_limit: false,
},
cx,
)