From 21050e2d372b358357c62ee242c16c454b0a2813 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Tue, 20 Jan 2026 21:51:54 -0500 Subject: [PATCH] Fix nested request rate limiting deadlock for subagent edit_file (#47232) ## Problem When subagents use the `edit_file` tool, it creates an `EditAgent` that makes its own model request to get the edit instructions. These "nested" requests compete with the parent subagent conversation requests for rate limiter permits. The rate limiter uses a semaphore with a limit of 4 concurrent requests per model instance. When multiple subagents run in parallel: 1. 3 subagents each hold 1 permit for their ongoing conversation streams (3 permits used) 2. When all 3 try to use `edit_file` simultaneously, their edit agents need permits too 3. Only 1 edit agent can get the 4th permit; the other 2 block waiting 4. The blocked edit agents can't complete, so their parent subagent conversations can't complete 5. The parent conversations hold their permits, so the blocked edit agents stay blocked 6. **Deadlock** ## Solution Added a `bypass_rate_limit` field to `LanguageModelRequest`. When set to `true`, the request skips the rate limiter semaphore entirely. The `EditAgent` sets this flag because its requests are already "part of" a rate-limited parent request. (No release notes because subagents are still feature-flagged.) Release Notes: - N/A --------- Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com> --- crates/agent/src/edit_agent.rs | 4 + .../agent/src/tests/edit_file_thread_test.rs | 618 ++++++++++++++++++ crates/agent/src/tests/mod.rs | 1 + crates/agent/src/thread.rs | 37 +- crates/agent/src/tools/edit_file_tool.rs | 17 +- crates/agent/src/tools/read_file_tool.rs | 15 + crates/agent_ui/src/buffer_codegen.rs | 2 + .../agent_ui/src/terminal_inline_assistant.rs | 1 + .../assistant_text_thread/src/text_thread.rs | 1 + crates/eval/src/instance.rs | 1 + crates/git_ui/src/git_panel.rs | 1 + crates/language_model/src/rate_limiter.rs | 219 ++++++- crates/language_model/src/request.rs | 5 + .../language_models/src/provider/anthropic.rs | 13 +- .../language_models/src/provider/bedrock.rs | 22 +- crates/language_models/src/provider/cloud.rs | 262 ++++---- .../src/provider/copilot_chat.rs | 32 +- .../language_models/src/provider/deepseek.rs | 29 +- crates/language_models/src/provider/google.rs | 12 +- .../language_models/src/provider/lmstudio.rs | 18 +- .../language_models/src/provider/mistral.rs | 31 +- crates/language_models/src/provider/ollama.rs | 22 +- .../language_models/src/provider/open_ai.rs | 74 ++- .../src/provider/open_ai_compatible.rs | 69 +- .../src/provider/open_router.rs | 12 +- crates/language_models/src/provider/vercel.rs | 37 +- crates/language_models/src/provider/x_ai.rs | 37 +- crates/rules_library/src/rules_library.rs | 1 + 28 files changed, 1305 insertions(+), 288 deletions(-) create mode 100644 crates/agent/src/tests/edit_file_thread_test.rs diff --git a/crates/agent/src/edit_agent.rs b/crates/agent/src/edit_agent.rs index ecd0b20f674c0a8efdfca3b28cce5780a882cedb..0c234b8cb152e46ad3be29f21639e5a20a1ffb86 100644 --- a/crates/agent/src/edit_agent.rs +++ b/crates/agent/src/edit_agent.rs @@ -732,6 +732,10 @@ impl EditAgent { stop: Vec::new(), temperature: None, thinking_allowed: true, + // Bypass the rate limiter for nested requests (edit agent requests spawned + // from within a tool call) to avoid deadlocks when multiple subagents try + // to use edit_file simultaneously. + bypass_rate_limit: true, }; Ok(self.model.stream_completion_text(request, cx).await?.stream) diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs new file mode 100644 index 0000000000000000000000000000000000000000..67a0aa07255f15445ae236e2042864acba9833c4 --- /dev/null +++ b/crates/agent/src/tests/edit_file_thread_test.rs @@ -0,0 +1,618 @@ +use super::*; +use acp_thread::UserMessageId; +use action_log::ActionLog; +use fs::FakeFs; +use language_model::{ + LanguageModelCompletionEvent, LanguageModelToolUse, MessageContent, StopReason, + fake_provider::FakeLanguageModel, +}; +use prompt_store::ProjectContext; +use serde_json::json; +use std::{collections::BTreeMap, sync::Arc, time::Duration}; +use util::path; + +#[gpui::test] +async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) { + // This test verifies that the edit_file tool works correctly when invoked + // through the full thread flow (model sends ToolUse event -> tool runs -> result sent back). + // This is different from tests that call tool.run() directly. + super::init_test(cx); + super::always_allow_tools(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "src": { + "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n" + } + }), + ) + .await; + + let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let fake_model = model.as_fake(); + + let thread = cx.new(|cx| { + let mut thread = crate::Thread::new( + project.clone(), + project_context, + context_server_registry, + crate::Templates::new(), + Some(model.clone()), + cx, + ); + // Add just the tools we need for this test + let language_registry = project.read(cx).languages().clone(); + thread.add_tool(crate::ReadFileTool::new( + cx.weak_entity(), + project.clone(), + thread.action_log().clone(), + )); + thread.add_tool(crate::EditFileTool::new( + project.clone(), + cx.weak_entity(), + language_registry, + crate::Templates::new(), + )); + thread + }); + + // First, read the file so the thread knows about its contents + let _events = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Read the file src/main.rs"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + // Model calls read_file tool + let read_tool_use = LanguageModelToolUse { + id: "read_tool_1".into(), + name: "read_file".into(), + raw_input: json!({"path": "project/src/main.rs"}).to_string(), + input: json!({"path": "project/src/main.rs"}), + is_input_complete: true, + thought_signature: None, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use)); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Wait for the read tool to complete and model to be called again + while fake_model.pending_completions().is_empty() { + cx.run_until_parked(); + } + + // Model responds after seeing the file content, then calls edit_file + fake_model.send_last_completion_stream_text_chunk("I'll edit the file now."); + let edit_tool_use = LanguageModelToolUse { + id: "edit_tool_1".into(), + name: "edit_file".into(), + raw_input: json!({ + "display_description": "Change greeting message", + "path": "project/src/main.rs", + "mode": "edit" + }) + .to_string(), + input: json!({ + "display_description": "Change greeting message", + "path": "project/src/main.rs", + "mode": "edit" + }), + is_input_complete: true, + thought_signature: None, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use)); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // The edit_file tool creates an EditAgent which makes its own model request. + // We need to respond to that request with the edit instructions. + // Wait for the edit agent's completion request + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while fake_model.pending_completions().is_empty() { + if std::time::Instant::now() >= deadline { + panic!( + "Timed out waiting for edit agent completion request. Pending: {}", + fake_model.pending_completions().len() + ); + } + cx.run_until_parked(); + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + + // Send the edit agent's response with the XML format it expects + let edit_response = "println!(\"Hello, world!\");\nprintln!(\"Hello, Zed!\");"; + fake_model.send_last_completion_stream_text_chunk(edit_response); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Wait for the edit to complete and the thread to call the model again with tool results + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while fake_model.pending_completions().is_empty() { + if std::time::Instant::now() >= deadline { + panic!("Timed out waiting for model to be called after edit completion"); + } + cx.run_until_parked(); + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + + // Verify the file was edited + let file_content = fs + .load(path!("/project/src/main.rs").as_ref()) + .await + .expect("file should exist"); + assert!( + file_content.contains("Hello, Zed!"), + "File should have been edited. Content: {}", + file_content + ); + assert!( + !file_content.contains("Hello, world!"), + "Old content should be replaced. Content: {}", + file_content + ); + + // Verify the tool result was sent back to the model + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "Model should have been called with tool result" + ); + + let last_request = pending.last().unwrap(); + let has_tool_result = last_request.messages.iter().any(|m| { + m.content + .iter() + .any(|c| matches!(c, language_model::MessageContent::ToolResult(_))) + }); + assert!( + has_tool_result, + "Tool result should be in the messages sent back to the model" + ); + + // Complete the turn + fake_model.send_last_completion_stream_text_chunk("I've updated the greeting message."); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Verify the thread completed successfully + thread.update(cx, |thread, _cx| { + assert!( + thread.is_turn_complete(), + "Thread should be complete after the turn ends" + ); + }); +} + +#[gpui::test] +async fn test_subagent_uses_read_file_tool(cx: &mut TestAppContext) { + // This test verifies that subagents can successfully use the read_file tool + // through the full thread flow, and that tools are properly rebound to use + // the subagent's thread ID instead of the parent's. + super::init_test(cx); + super::always_allow_tools(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "src": { + "lib.rs": "pub fn hello() -> &'static str {\n \"Hello from lib!\"\n}\n" + } + }), + ) + .await; + + let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let fake_model = model.as_fake(); + + // Create subagent context + let subagent_context = crate::SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"), + depth: 1, + summary_prompt: "Summarize what you found".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + // Create parent tools that will be passed to the subagent + // This simulates how the subagent_tool passes tools to new_subagent + let parent_tools: BTreeMap> = { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + // Create a "fake" parent thread reference - this should get rebound + let fake_parent_thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)), + crate::Templates::new(), + Some(model.clone()), + cx, + ) + }); + let mut tools: BTreeMap> = + BTreeMap::new(); + tools.insert( + "read_file".into(), + crate::ReadFileTool::new(fake_parent_thread.downgrade(), project.clone(), action_log) + .erase(), + ); + tools + }; + + // Create subagent - tools should be rebound to use subagent's thread + let subagent = cx.new(|cx| { + crate::Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + crate::Templates::new(), + model.clone(), + subagent_context, + parent_tools, + cx, + ) + }); + + // Get the subagent's thread ID + let _subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string()); + + // Verify the subagent has the read_file tool + subagent.read_with(cx, |thread, _| { + assert!( + thread.has_registered_tool("read_file"), + "subagent should have read_file tool" + ); + }); + + // Submit a user message to the subagent + subagent + .update(cx, |thread, cx| { + thread.submit_user_message("Read the file src/lib.rs", cx) + }) + .unwrap(); + cx.run_until_parked(); + + // Simulate the model calling the read_file tool + let read_tool_use = LanguageModelToolUse { + id: "read_tool_1".into(), + name: "read_file".into(), + raw_input: json!({"path": "project/src/lib.rs"}).to_string(), + input: json!({"path": "project/src/lib.rs"}), + is_input_complete: true, + thought_signature: None, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Wait for the tool to complete and the model to be called again with tool results + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while fake_model.pending_completions().is_empty() { + if std::time::Instant::now() >= deadline { + panic!("Timed out waiting for model to be called after read_file tool completion"); + } + cx.run_until_parked(); + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + + // Verify the tool result was sent back to the model + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "Model should have been called with tool result" + ); + + let last_request = pending.last().unwrap(); + let tool_result = last_request.messages.iter().find_map(|m| { + m.content.iter().find_map(|c| match c { + MessageContent::ToolResult(result) => Some(result), + _ => None, + }) + }); + assert!( + tool_result.is_some(), + "Tool result should be in the messages sent back to the model" + ); + + // Verify the tool result contains the file content + let result = tool_result.unwrap(); + let result_text = match &result.content { + language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), + _ => panic!("expected text content in tool result"), + }; + assert!( + result_text.contains("Hello from lib!"), + "Tool result should contain file content, got: {}", + result_text + ); + + // Verify the subagent is ready for more input (tool completed, model called again) + // This test verifies the subagent can successfully use read_file tool. + // The summary flow is tested separately in test_subagent_returns_summary_on_completion. +} + +#[gpui::test] +async fn test_subagent_uses_edit_file_tool(cx: &mut TestAppContext) { + // This test verifies that subagents can successfully use the edit_file tool + // through the full thread flow, including the edit agent's model request. + // It also verifies that the edit agent uses the subagent's thread ID, not the parent's. + super::init_test(cx); + super::always_allow_tools(cx); + + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "src": { + "config.rs": "pub const VERSION: &str = \"1.0.0\";\n" + } + }), + ) + .await; + + let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let fake_model = model.as_fake(); + + // Create a "parent" thread to simulate the real scenario where tools are inherited + let parent_thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)), + crate::Templates::new(), + Some(model.clone()), + cx, + ) + }); + let parent_thread_id = parent_thread.read_with(cx, |thread, _| thread.id().to_string()); + + // Create parent tools that reference the parent thread + let parent_tools: BTreeMap> = { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let language_registry = project.read_with(cx, |p, _| p.languages().clone()); + let mut tools: BTreeMap> = + BTreeMap::new(); + tools.insert( + "read_file".into(), + crate::ReadFileTool::new(parent_thread.downgrade(), project.clone(), action_log) + .erase(), + ); + tools.insert( + "edit_file".into(), + crate::EditFileTool::new( + project.clone(), + parent_thread.downgrade(), + language_registry, + crate::Templates::new(), + ) + .erase(), + ); + tools + }; + + // Create subagent context + let subagent_context = crate::SubagentContext { + parent_thread_id: agent_client_protocol::SessionId::new("parent-id"), + tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"), + depth: 1, + summary_prompt: "Summarize what you changed".to_string(), + context_low_prompt: "Context low".to_string(), + }; + + // Create subagent - tools should be rebound to use subagent's thread + let subagent = cx.new(|cx| { + crate::Thread::new_subagent( + project.clone(), + project_context, + context_server_registry, + crate::Templates::new(), + model.clone(), + subagent_context, + parent_tools, + cx, + ) + }); + + // Get the subagent's thread ID - it should be different from parent + let subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string()); + assert_ne!( + parent_thread_id, subagent_thread_id, + "Subagent should have a different thread ID than parent" + ); + + // Verify the subagent has the tools + subagent.read_with(cx, |thread, _| { + assert!( + thread.has_registered_tool("read_file"), + "subagent should have read_file tool" + ); + assert!( + thread.has_registered_tool("edit_file"), + "subagent should have edit_file tool" + ); + }); + + // Submit a user message to the subagent + subagent + .update(cx, |thread, cx| { + thread.submit_user_message("Update the version in config.rs to 2.0.0", cx) + }) + .unwrap(); + cx.run_until_parked(); + + // First, model calls read_file to see the current content + let read_tool_use = LanguageModelToolUse { + id: "read_tool_1".into(), + name: "read_file".into(), + raw_input: json!({"path": "project/src/config.rs"}).to_string(), + input: json!({"path": "project/src/config.rs"}), + is_input_complete: true, + thought_signature: None, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Wait for the read tool to complete and model to be called again + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while fake_model.pending_completions().is_empty() { + if std::time::Instant::now() >= deadline { + panic!("Timed out waiting for model to be called after read_file tool"); + } + cx.run_until_parked(); + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + + // Model responds and calls edit_file + fake_model.send_last_completion_stream_text_chunk("I'll update the version now."); + let edit_tool_use = LanguageModelToolUse { + id: "edit_tool_1".into(), + name: "edit_file".into(), + raw_input: json!({ + "display_description": "Update version to 2.0.0", + "path": "project/src/config.rs", + "mode": "edit" + }) + .to_string(), + input: json!({ + "display_description": "Update version to 2.0.0", + "path": "project/src/config.rs", + "mode": "edit" + }), + is_input_complete: true, + thought_signature: None, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // The edit_file tool creates an EditAgent which makes its own model request. + // Wait for that request. + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while fake_model.pending_completions().is_empty() { + if std::time::Instant::now() >= deadline { + panic!( + "Timed out waiting for edit agent completion request in subagent. Pending: {}", + fake_model.pending_completions().len() + ); + } + cx.run_until_parked(); + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + + // Verify the edit agent's request uses the SUBAGENT's thread ID, not the parent's + let pending = fake_model.pending_completions(); + let edit_agent_request = pending.last().unwrap(); + let edit_agent_thread_id = edit_agent_request.thread_id.as_ref().unwrap(); + std::assert_eq!( + edit_agent_thread_id, + &subagent_thread_id, + "Edit agent should use subagent's thread ID, not parent's. Got: {}, expected: {}", + edit_agent_thread_id, + subagent_thread_id + ); + std::assert_ne!( + edit_agent_thread_id, + &parent_thread_id, + "Edit agent should NOT use parent's thread ID" + ); + + // Send the edit agent's response with the XML format it expects + let edit_response = "pub const VERSION: &str = \"1.0.0\";\npub const VERSION: &str = \"2.0.0\";"; + fake_model.send_last_completion_stream_text_chunk(edit_response); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Wait for the edit to complete and the thread to call the model again with tool results + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while fake_model.pending_completions().is_empty() { + if std::time::Instant::now() >= deadline { + panic!("Timed out waiting for model to be called after edit completion in subagent"); + } + cx.run_until_parked(); + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + + // Verify the file was edited + let file_content = fs + .load(path!("/project/src/config.rs").as_ref()) + .await + .expect("file should exist"); + assert!( + file_content.contains("2.0.0"), + "File should have been edited to contain new version. Content: {}", + file_content + ); + assert!( + !file_content.contains("1.0.0"), + "Old version should be replaced. Content: {}", + file_content + ); + + // Verify the tool result was sent back to the model + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "Model should have been called with tool result" + ); + + let last_request = pending.last().unwrap(); + let has_tool_result = last_request.messages.iter().any(|m| { + m.content + .iter() + .any(|c| matches!(c, MessageContent::ToolResult(_))) + }); + assert!( + has_tool_result, + "Tool result should be in the messages sent back to the model" + ); +} diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 762b4d9a1393f96116e77d3e265b53633f014e7a..a4706f6a752b0ae2fd251320106da998819b0b47 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -52,6 +52,7 @@ use std::{ }; use util::path; +mod edit_file_thread_test; mod test_tools; use test_tools::*; diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index b1f868a4e42e9e7ddc8ddbf866986f72360e35fc..d58e3cb0c4a14489b8e6f5321e4c7a5178ebe766 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -836,6 +836,19 @@ impl Thread { let action_log = cx.new(|_cx| ActionLog::new(project.clone())); let (prompt_capabilities_tx, prompt_capabilities_rx) = watch::channel(Self::prompt_capabilities(Some(model.as_ref()))); + + // Rebind tools that hold thread references to use this subagent's thread + // instead of the parent's thread. This is critical for tools like EditFileTool + // that make model requests using the thread's ID. + let weak_self = cx.weak_entity(); + let tools: BTreeMap> = parent_tools + .into_iter() + .map(|(name, tool)| { + let rebound = tool.rebind_thread(weak_self.clone()).unwrap_or(tool); + (name, rebound) + }) + .collect(); + Self { id: acp::SessionId::new(uuid::Uuid::new_v4().to_string()), prompt_id: PromptId::new(), @@ -849,7 +862,7 @@ impl Thread { running_turn: None, queued_messages: Vec::new(), pending_message: None, - tools: parent_tools, + tools, request_token_usage: HashMap::default(), cumulative_token_usage: TokenUsage::default(), initial_project_snapshot: Task::ready(None).shared(), @@ -2274,6 +2287,7 @@ impl Thread { stop: Vec::new(), temperature: AgentSettings::temperature_for_model(model, cx), thinking_allowed: true, + bypass_rate_limit: false, }; log::debug!("Completion request built successfully"); @@ -2690,6 +2704,15 @@ where fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) } + + /// Create a new instance of this tool bound to a different thread. + /// This is used when creating subagents, so that tools like EditFileTool + /// that hold a thread reference will use the subagent's thread instead + /// of the parent's thread. + /// Returns None if the tool doesn't need rebinding (most tools). + fn rebind_thread(&self, _new_thread: WeakEntity) -> Option> { + None + } } pub struct Erased(T); @@ -2721,6 +2744,14 @@ pub trait AnyAgentTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Result<()>; + /// Create a new instance of this tool bound to a different thread. + /// This is used when creating subagents, so that tools like EditFileTool + /// that hold a thread reference will use the subagent's thread instead + /// of the parent's thread. + /// Returns None if the tool doesn't need rebinding (most tools). + fn rebind_thread(&self, _new_thread: WeakEntity) -> Option> { + None + } } impl AnyAgentTool for Erased> @@ -2784,6 +2815,10 @@ where let output = serde_json::from_value(output)?; self.0.replay(input, output, event_stream, cx) } + + fn rebind_thread(&self, new_thread: WeakEntity) -> Option> { + self.0.rebind_thread(new_thread) + } } #[derive(Clone)] diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index 4f8288c24382d373a41fd15b779970676ad09fae..bc7e5b5289937d6212c662f97238e43ea185684d 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -144,6 +144,15 @@ impl EditFileTool { } } + pub fn with_thread(&self, new_thread: WeakEntity) -> Self { + Self { + project: self.project.clone(), + thread: new_thread, + language_registry: self.language_registry.clone(), + templates: self.templates.clone(), + } + } + fn authorize( &self, input: &EditFileToolInput, @@ -398,7 +407,6 @@ impl AgentTool for EditFileTool { }) .await; - let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) { edit_agent.edit( buffer.clone(), @@ -575,6 +583,13 @@ impl AgentTool for EditFileTool { })); Ok(()) } + + fn rebind_thread( + &self, + new_thread: gpui::WeakEntity, + ) -> Option> { + Some(self.with_thread(new_thread).erase()) + } } /// Validate that the file path is valid, meaning: diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index bc7647739035a41b91c481d2f25b5fbd0f7856c7..8b13452e9357921a1f7a43a51a3364594b481c42 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -65,6 +65,14 @@ impl ReadFileTool { action_log, } } + + pub fn with_thread(&self, new_thread: WeakEntity) -> Self { + Self { + thread: new_thread, + project: self.project.clone(), + action_log: self.action_log.clone(), + } + } } impl AgentTool for ReadFileTool { @@ -308,6 +316,13 @@ impl AgentTool for ReadFileTool { result }) } + + fn rebind_thread( + &self, + new_thread: WeakEntity, + ) -> Option> { + Some(self.with_thread(new_thread).erase()) + } } #[cfg(test)] diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 6d860bea9974b209289fd57276351e96ed744b20..5cfd161e9fb01cfde6fefe65ef7b3a5fbd89a6f7 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -544,6 +544,7 @@ impl CodegenAlternative { temperature, messages, thinking_allowed: false, + bypass_rate_limit: false, } })) } @@ -622,6 +623,7 @@ impl CodegenAlternative { temperature, messages: vec![request_message], thinking_allowed: false, + bypass_rate_limit: false, } })) } diff --git a/crates/agent_ui/src/terminal_inline_assistant.rs b/crates/agent_ui/src/terminal_inline_assistant.rs index 58a73131e2d20d0776b2cdc49a0b395834b5008f..04dafa588dd3960eb435ef7d9225217d2dfb3354 100644 --- a/crates/agent_ui/src/terminal_inline_assistant.rs +++ b/crates/agent_ui/src/terminal_inline_assistant.rs @@ -275,6 +275,7 @@ impl TerminalInlineAssistant { stop: Vec::new(), temperature, thinking_allowed: false, + bypass_rate_limit: false, } })) } diff --git a/crates/assistant_text_thread/src/text_thread.rs b/crates/assistant_text_thread/src/text_thread.rs index 034314e349a306040fc0cd37dbc3ad9a5ea6e81b..042169eb93b51b681e88091cf994d95fa7b88436 100644 --- a/crates/assistant_text_thread/src/text_thread.rs +++ b/crates/assistant_text_thread/src/text_thread.rs @@ -2269,6 +2269,7 @@ impl TextThread { stop: Vec::new(), temperature: model.and_then(|model| AgentSettings::temperature_for_model(model, cx)), thinking_allowed: true, + bypass_rate_limit: false, }; for message in self.messages(cx) { if message.status != MessageStatus::Done { diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 17b5adfc4aa9621ac4638f873c30e62ab6244107..1cb5e9a10c3c8814f154643b590e280af724188a 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -563,6 +563,7 @@ impl ExampleInstance { tool_choice: None, stop: Vec::new(), thinking_allowed: true, + bypass_rate_limit: false, }; let model = model.clone(); diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 3f2f406136bdc3a8d9a813e37e264e97633bd214..1628358c602fd5a2ff67a178331ec2172f3d7a67 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -2691,6 +2691,7 @@ impl GitPanel { stop: Vec::new(), temperature, thinking_allowed: false, + bypass_rate_limit: false, }; let stream = model.stream_completion_text(request, cx); diff --git a/crates/language_model/src/rate_limiter.rs b/crates/language_model/src/rate_limiter.rs index 790be05ac069b8f394e442cbcb6383f611326a69..f4aea8177666e406f8c22f2c92fd0c6f9b4619a8 100644 --- a/crates/language_model/src/rate_limiter.rs +++ b/crates/language_model/src/rate_limiter.rs @@ -16,7 +16,7 @@ pub struct RateLimiter { pub struct RateLimitGuard { inner: T, - _guard: SemaphoreGuardArc, + _guard: Option, } impl Stream for RateLimitGuard @@ -68,6 +68,36 @@ impl RateLimiter { async move { let guard = guard.await; let inner = future.await?; + Ok(RateLimitGuard { + inner, + _guard: Some(guard), + }) + } + } + + /// Like `stream`, but conditionally bypasses the rate limiter based on the flag. + /// Used for nested requests (like edit agent requests) that are already "part of" + /// a rate-limited request to avoid deadlocks. + pub fn stream_with_bypass<'a, Fut, T>( + &self, + future: Fut, + bypass: bool, + ) -> impl 'a + + Future< + Output = Result + use, LanguageModelCompletionError>, + > + where + Fut: 'a + Future>, + T: Stream, + { + let semaphore = self.semaphore.clone(); + async move { + let guard = if bypass { + None + } else { + Some(semaphore.acquire_arc().await) + }; + let inner = future.await?; Ok(RateLimitGuard { inner, _guard: guard, @@ -75,3 +105,190 @@ impl RateLimiter { } } } + +#[cfg(test)] +mod tests { + use super::*; + use futures::stream; + use smol::lock::Barrier; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; + use std::time::{Duration, Instant}; + + /// Tests that nested requests without bypass_rate_limit cause deadlock, + /// while requests with bypass_rate_limit complete successfully. + /// + /// This test simulates the scenario where multiple "parent" requests each + /// try to spawn a "nested" request (like edit_file tool spawning an edit agent). + /// With a rate limit of 2 and 2 parent requests, without bypass the nested + /// requests would block forever waiting for permits that the parents hold. + #[test] + fn test_nested_requests_bypass_prevents_deadlock() { + smol::block_on(async { + // Use only 2 permits so we can guarantee deadlock conditions + let rate_limiter = RateLimiter::new(2); + let completed = Arc::new(AtomicUsize::new(0)); + // Barrier ensures all parents acquire permits before any tries nested request + let barrier = Arc::new(Barrier::new(2)); + + // Spawn 2 "parent" requests that each try to make a "nested" request + let mut handles = Vec::new(); + for _ in 0..2 { + let limiter = rate_limiter.clone(); + let completed = completed.clone(); + let barrier = barrier.clone(); + + let handle = smol::spawn(async move { + // Parent request acquires a permit via stream_with_bypass (bypass=false) + let parent_stream = limiter + .stream_with_bypass( + async { + // Wait for all parents to acquire permits + barrier.wait().await; + + // While holding the parent permit, make a nested request + // WITH bypass=true (simulating EditAgent behavior) + let nested_stream = limiter + .stream_with_bypass( + async { Ok(stream::iter(vec![1, 2, 3])) }, + true, // bypass - this is the key! + ) + .await?; + + // Consume the nested stream + use futures::StreamExt; + let _: Vec<_> = nested_stream.collect().await; + + Ok(stream::iter(vec!["done"])) + }, + false, // parent does NOT bypass + ) + .await + .unwrap(); + + // Consume parent stream + use futures::StreamExt; + let _: Vec<_> = parent_stream.collect().await; + + completed.fetch_add(1, Ordering::SeqCst); + }); + handles.push(handle); + } + + // With bypass=true for nested requests, this should complete quickly + let timed_out = Arc::new(AtomicBool::new(false)); + let timed_out_clone = timed_out.clone(); + + // Spawn a watchdog that sets timed_out after 2 seconds + let watchdog = smol::spawn(async move { + let start = Instant::now(); + while start.elapsed() < Duration::from_secs(2) { + smol::future::yield_now().await; + } + timed_out_clone.store(true, Ordering::SeqCst); + }); + + // Wait for all handles to complete + for handle in handles { + handle.await; + } + + // Cancel the watchdog + drop(watchdog); + + if timed_out.load(Ordering::SeqCst) { + panic!( + "Test timed out - deadlock detected! This means bypass_rate_limit is not working." + ); + } + assert_eq!(completed.load(Ordering::SeqCst), 2); + }); + } + + /// Tests that without bypass, nested requests DO cause deadlock. + /// This test verifies the problem exists when bypass is not used. + #[test] + fn test_nested_requests_without_bypass_deadlocks() { + smol::block_on(async { + // Use only 2 permits so we can guarantee deadlock conditions + let rate_limiter = RateLimiter::new(2); + let completed = Arc::new(AtomicUsize::new(0)); + // Barrier ensures all parents acquire permits before any tries nested request + let barrier = Arc::new(Barrier::new(2)); + + // Spawn 2 "parent" requests that each try to make a "nested" request + let mut handles = Vec::new(); + for _ in 0..2 { + let limiter = rate_limiter.clone(); + let completed = completed.clone(); + let barrier = barrier.clone(); + + let handle = smol::spawn(async move { + // Parent request acquires a permit + let parent_stream = limiter + .stream_with_bypass( + async { + // Wait for all parents to acquire permits - this guarantees + // that all 2 permits are held before any nested request starts + barrier.wait().await; + + // Nested request WITHOUT bypass - this will deadlock! + // Both parents hold permits, so no permits available + let nested_stream = limiter + .stream_with_bypass( + async { Ok(stream::iter(vec![1, 2, 3])) }, + false, // NO bypass - will try to acquire permit + ) + .await?; + + use futures::StreamExt; + let _: Vec<_> = nested_stream.collect().await; + + Ok(stream::iter(vec!["done"])) + }, + false, + ) + .await + .unwrap(); + + use futures::StreamExt; + let _: Vec<_> = parent_stream.collect().await; + + completed.fetch_add(1, Ordering::SeqCst); + }); + handles.push(handle); + } + + // This SHOULD timeout because of deadlock (both parents hold permits, + // both nested requests wait for permits) + let timed_out = Arc::new(AtomicBool::new(false)); + let timed_out_clone = timed_out.clone(); + + // Spawn a watchdog that sets timed_out after 100ms + let watchdog = smol::spawn(async move { + let start = Instant::now(); + while start.elapsed() < Duration::from_millis(100) { + smol::future::yield_now().await; + } + timed_out_clone.store(true, Ordering::SeqCst); + }); + + // Poll briefly to let everything run + let start = Instant::now(); + while start.elapsed() < Duration::from_millis(100) { + smol::future::yield_now().await; + } + + // Cancel the watchdog + drop(watchdog); + + // Expected - deadlock occurred, which proves the bypass is necessary + let count = completed.load(Ordering::SeqCst); + assert_eq!( + count, 0, + "Expected complete deadlock (0 completed) but {} requests completed", + count + ); + }); + } +} diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 1d0b536ec43bd9e930c24e3a733448ae12a8d65b..9c3643e3471d52913d1defdb776365750b1870c6 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -451,6 +451,11 @@ pub struct LanguageModelRequest { pub stop: Vec, pub temperature: Option, pub thinking_allowed: bool, + /// When true, this request bypasses the rate limiter. Used for nested requests + /// (like edit agent requests spawned from within a tool call) that are already + /// "part of" a rate-limited request to avoid deadlocks. + #[serde(default)] + pub bypass_rate_limit: bool, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 3ea4a8ee37d76474714d6a4b875982e11f5c691b..d88205b267ada4883bb3e3b568ec9ac7cee57514 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -578,6 +578,7 @@ impl LanguageModel for AnthropicModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let request = into_anthropic( request, self.model.request_id().into(), @@ -586,10 +587,13 @@ impl LanguageModel for AnthropicModel { self.model.mode(), ); let request = self.stream_completion(request, cx); - let future = self.request_limiter.stream(async move { - let response = request.await?; - Ok(AnthropicEventMapper::new().map_stream(response)) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let response = request.await?; + Ok(AnthropicEventMapper::new().map_stream(response)) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -1164,6 +1168,7 @@ mod tests { tools: vec![], tool_choice: None, thinking_allowed: true, + bypass_rate_limit: false, }; let anthropic_request = into_anthropic( diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 05c35fb047956740c86f7fc87f6691564360b4d5..d02ed461f8e0cbbefc17b5c44526849724b3a3bf 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -679,6 +679,7 @@ impl LanguageModel for BedrockModel { }; let deny_tool_calls = request.tool_choice == Some(LanguageModelToolChoice::None); + let bypass_rate_limit = request.bypass_rate_limit; let request = match into_bedrock( request, @@ -693,16 +694,19 @@ impl LanguageModel for BedrockModel { }; let request = self.stream_completion(request, cx); - let future = self.request_limiter.stream(async move { - let response = request.await.map_err(|err| anyhow!(err))?; - let events = map_to_language_model_completion_events(response); + let future = self.request_limiter.stream_with_bypass( + async move { + let response = request.await.map_err(|err| anyhow!(err))?; + let events = map_to_language_model_completion_events(response); - if deny_tool_calls { - Ok(deny_tool_use_events(events).boxed()) - } else { - Ok(events.boxed()) - } - }); + if deny_tool_calls { + Ok(deny_tool_use_events(events).boxed()) + } else { + Ok(events.boxed()) + } + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 4cef74cd325736944fd91612303897f8ea0062c5..a96a0618e15217c648836c89fad51725b16b2b43 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -719,6 +719,7 @@ impl LanguageModel for CloudLanguageModel { let thread_id = request.thread_id.clone(); let prompt_id = request.prompt_id.clone(); let intent = request.intent; + let bypass_rate_limit = request.bypass_rate_limit; let app_version = Some(cx.update(|cx| AppVersion::global(cx))); let use_responses_api = cx.update(|cx| cx.has_flag::()); let thinking_allowed = request.thinking_allowed; @@ -740,53 +741,8 @@ impl LanguageModel for CloudLanguageModel { ); let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - app_version, - CompletionBody { - thread_id, - prompt_id, - intent, - provider: cloud_llm_client::LanguageModelProvider::Anthropic, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await - .map_err(|err| match err.downcast::() { - Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)), - Err(err) => anyhow!(err), - })?; - - let mut mapper = AnthropicEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - cloud_llm_client::LanguageModelProvider::OpenAi => { - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - - if use_responses_api { - let request = into_open_ai_response( - request, - &self.model.id.0, - self.model.supports_parallel_tool_calls, - true, - None, - None, - ); - let future = self.request_limiter.stream(async move { + let future = self.request_limiter.stream_with_bypass( + async move { let PerformLlmCompletionResponse { response, includes_status_messages, @@ -798,21 +754,74 @@ impl LanguageModel for CloudLanguageModel { thread_id, prompt_id, intent, - provider: cloud_llm_client::LanguageModelProvider::OpenAi, + provider: cloud_llm_client::LanguageModelProvider::Anthropic, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, }, ) - .await?; - - let mut mapper = OpenAiResponseEventMapper::new(); + .await + .map_err(|err| { + match err.downcast::() { + Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)), + Err(err) => anyhow!(err), + } + })?; + + let mut mapper = AnthropicEventMapper::new(); Ok(map_cloud_completion_events( Box::pin(response_lines(response, includes_status_messages)), &provider_name, move |event| mapper.map_event(event), )) - }); + }, + bypass_rate_limit, + ); + async move { Ok(future.await?.boxed()) }.boxed() + } + cloud_llm_client::LanguageModelProvider::OpenAi => { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + + if use_responses_api { + let request = into_open_ai_response( + request, + &self.model.id.0, + self.model.supports_parallel_tool_calls, + true, + None, + None, + ); + let future = self.request_limiter.stream_with_bypass( + async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + client.clone(), + llm_api_token, + app_version, + CompletionBody { + thread_id, + prompt_id, + intent, + provider: cloud_llm_client::LanguageModelProvider::OpenAi, + model: request.model.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = OpenAiResponseEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } else { let request = into_open_ai( @@ -823,7 +832,52 @@ impl LanguageModel for CloudLanguageModel { None, None, ); - let future = self.request_limiter.stream(async move { + let future = self.request_limiter.stream_with_bypass( + async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + client.clone(), + llm_api_token, + app_version, + CompletionBody { + thread_id, + prompt_id, + intent, + provider: cloud_llm_client::LanguageModelProvider::OpenAi, + model: request.model.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = OpenAiEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }, + bypass_rate_limit, + ); + async move { Ok(future.await?.boxed()) }.boxed() + } + } + cloud_llm_client::LanguageModelProvider::XAi => { + let client = self.client.clone(); + let request = into_open_ai( + request, + &self.model.id.0, + self.model.supports_parallel_tool_calls, + false, + None, + None, + ); + let llm_api_token = self.llm_api_token.clone(); + let future = self.request_limiter.stream_with_bypass( + async move { let PerformLlmCompletionResponse { response, includes_status_messages, @@ -835,7 +889,7 @@ impl LanguageModel for CloudLanguageModel { thread_id, prompt_id, intent, - provider: cloud_llm_client::LanguageModelProvider::OpenAi, + provider: cloud_llm_client::LanguageModelProvider::XAi, model: request.model.clone(), provider_request: serde_json::to_value(&request) .map_err(|e| anyhow!(e))?, @@ -849,48 +903,9 @@ impl LanguageModel for CloudLanguageModel { &provider_name, move |event| mapper.map_event(event), )) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - } - cloud_llm_client::LanguageModelProvider::XAi => { - let client = self.client.clone(); - let request = into_open_ai( - request, - &self.model.id.0, - self.model.supports_parallel_tool_calls, - false, - None, - None, + }, + bypass_rate_limit, ); - let llm_api_token = self.llm_api_token.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - app_version, - CompletionBody { - thread_id, - prompt_id, - intent, - provider: cloud_llm_client::LanguageModelProvider::XAi, - model: request.model.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = OpenAiEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); async move { Ok(future.await?.boxed()) }.boxed() } cloud_llm_client::LanguageModelProvider::Google => { @@ -898,33 +913,36 @@ impl LanguageModel for CloudLanguageModel { let request = into_google(request, self.model.id.to_string(), GoogleModelMode::Default); let llm_api_token = self.llm_api_token.clone(); - let future = self.request_limiter.stream(async move { - let PerformLlmCompletionResponse { - response, - includes_status_messages, - } = Self::perform_llm_completion( - client.clone(), - llm_api_token, - app_version, - CompletionBody { - thread_id, - prompt_id, - intent, - provider: cloud_llm_client::LanguageModelProvider::Google, - model: request.model.model_id.clone(), - provider_request: serde_json::to_value(&request) - .map_err(|e| anyhow!(e))?, - }, - ) - .await?; - - let mut mapper = GoogleEventMapper::new(); - Ok(map_cloud_completion_events( - Box::pin(response_lines(response, includes_status_messages)), - &provider_name, - move |event| mapper.map_event(event), - )) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let PerformLlmCompletionResponse { + response, + includes_status_messages, + } = Self::perform_llm_completion( + client.clone(), + llm_api_token, + app_version, + CompletionBody { + thread_id, + prompt_id, + intent, + provider: cloud_llm_client::LanguageModelProvider::Google, + model: request.model.model_id.clone(), + provider_request: serde_json::to_value(&request) + .map_err(|e| anyhow!(e))?, + }, + ) + .await?; + + let mut mapper = GoogleEventMapper::new(); + Ok(map_cloud_completion_events( + Box::pin(response_lines(response, includes_status_messages)), + &provider_name, + move |event| mapper.map_event(event), + )) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } } diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 76685e076982f2007d68ac5ae22e469a933cf80e..8c9b84417d33f823da80d221072a766d48bc59ce 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -307,6 +307,7 @@ impl LanguageModel for CopilotChatLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let is_user_initiated = request.intent.is_none_or(|intent| match intent { CompletionIntent::UserPrompt | CompletionIntent::ThreadContextSummarization @@ -327,11 +328,14 @@ impl LanguageModel for CopilotChatLanguageModel { let request = CopilotChat::stream_response(responses_request, is_user_initiated, cx.clone()); request_limiter - .stream(async move { - let stream = request.await?; - let mapper = CopilotResponsesEventMapper::new(); - Ok(mapper.map_stream(stream).boxed()) - }) + .stream_with_bypass( + async move { + let stream = request.await?; + let mapper = CopilotResponsesEventMapper::new(); + Ok(mapper.map_stream(stream).boxed()) + }, + bypass_rate_limit, + ) .await }); return async move { Ok(future.await?.boxed()) }.boxed(); @@ -348,13 +352,16 @@ impl LanguageModel for CopilotChatLanguageModel { let request = CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone()); request_limiter - .stream(async move { - let response = request.await?; - Ok(map_to_language_model_completion_events( - response, - is_streaming, - )) - }) + .stream_with_bypass( + async move { + let response = request.await?; + Ok(map_to_language_model_completion_events( + response, + is_streaming, + )) + }, + bypass_rate_limit, + ) .await }); async move { Ok(future.await?.boxed()) }.boxed() @@ -929,6 +936,7 @@ fn into_copilot_responses( stop: _, temperature, thinking_allowed: _, + bypass_rate_limit: _, } = request; let mut input_items: Vec = Vec::new(); diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index ea623d2cf24f26ce32e8d1fd309ac747e469096e..18f36641d82cd8d993dcb61a9576e089a5a89fda 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -199,6 +199,7 @@ impl DeepSeekLanguageModel { fn stream_completion( &self, request: deepseek::Request, + bypass_rate_limit: bool, cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); @@ -208,17 +209,20 @@ impl DeepSeekLanguageModel { (state.api_key_state.key(&api_url), api_url) }); - let future = self.request_limiter.stream(async move { - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { - provider: PROVIDER_NAME, - }); - }; - let request = - deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request); - let response = request.await?; - Ok(response) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; + let request = + deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -302,8 +306,9 @@ impl LanguageModel for DeepSeekLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let request = into_deepseek(request, &self.model, self.max_output_tokens()); - let stream = self.stream_completion(request, cx); + let stream = self.stream_completion(request, bypass_rate_limit, cx); async move { let mapper = DeepSeekEventMapper::new(); diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 2f5b3b3701d51e4f4faadae0f8ef83f8bf6b5b2f..5e3700496e23ea8f4d3f9cec1fdda31ea44de0ce 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -370,16 +370,20 @@ impl LanguageModel for GoogleLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let request = into_google( request, self.model.request_id().to_string(), self.model.mode(), ); let request = self.stream_completion(request, cx); - let future = self.request_limiter.stream(async move { - let response = request.await.map_err(LanguageModelCompletionError::from)?; - Ok(GoogleEventMapper::new().map_stream(response)) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let response = request.await.map_err(LanguageModelCompletionError::from)?; + Ok(GoogleEventMapper::new().map_stream(response)) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } } diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 041dfedf86e4195d98689d4f06031b32fb162e51..f55257d8b22e86968e4b407ede1658d0210e1cf0 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -370,6 +370,7 @@ impl LmStudioLanguageModel { fn stream_completion( &self, request: lmstudio::ChatCompletionRequest, + bypass_rate_limit: bool, cx: &AsyncApp, ) -> BoxFuture< 'static, @@ -381,11 +382,15 @@ impl LmStudioLanguageModel { settings.api_url.clone() }); - let future = self.request_limiter.stream(async move { - let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request); - let response = request.await?; - Ok(response) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let request = + lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request); + let response = request.await?; + Ok(response) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -460,8 +465,9 @@ impl LanguageModel for LmStudioLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let request = self.to_lmstudio_request(request); - let completions = self.stream_completion(request, cx); + let completions = self.stream_completion(request, bypass_rate_limit, cx); async move { let mapper = LmStudioEventMapper::new(); Ok(mapper.map_stream(completions.await?).boxed()) diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index c8c34d7d2942ca1b42613d8733dc2219800bd66c..1e6cfee116c355875ab97de1096b2aa26cca53d9 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -264,6 +264,7 @@ impl MistralLanguageModel { fn stream_completion( &self, request: mistral::Request, + bypass_rate_limit: bool, cx: &AsyncApp, ) -> BoxFuture< 'static, @@ -276,17 +277,20 @@ impl MistralLanguageModel { (state.api_key_state.key(&api_url), api_url) }); - let future = self.request_limiter.stream(async move { - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { - provider: PROVIDER_NAME, - }); - }; - let request = - mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request); - let response = request.await?; - Ok(response) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; + let request = + mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -370,8 +374,9 @@ impl LanguageModel for MistralLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let request = into_mistral(request, self.model.clone(), self.max_output_tokens()); - let stream = self.stream_completion(request, cx); + let stream = self.stream_completion(request, bypass_rate_limit, cx); async move { let stream = stream.await?; @@ -901,6 +906,7 @@ mod tests { intent: None, stop: vec![], thinking_allowed: true, + bypass_rate_limit: false, }; let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None); @@ -934,6 +940,7 @@ mod tests { intent: None, stop: vec![], thinking_allowed: true, + bypass_rate_limit: false, }; let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None); diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 79ac51a870782817723c7c32253946068d6570e3..8c9df749e9af17631a85a5934ac9ef3211bff55a 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -480,6 +480,7 @@ impl LanguageModel for OllamaLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let request = self.to_ollama_request(request); let http_client = self.http_client.clone(); @@ -488,13 +489,20 @@ impl LanguageModel for OllamaLanguageModel { (state.api_key_state.key(&api_url), api_url) }); - let future = self.request_limiter.stream(async move { - let stream = - stream_chat_completion(http_client.as_ref(), &api_url, api_key.as_deref(), request) - .await?; - let stream = map_to_language_model_completion_events(stream); - Ok(stream) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let stream = stream_chat_completion( + http_client.as_ref(), + &api_url, + api_key.as_deref(), + request, + ) + .await?; + let stream = map_to_language_model_completion_events(stream); + Ok(stream) + }, + bypass_rate_limit, + ); future.map_ok(|f| f.boxed()).boxed() } diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 915fd325494f523f1a1d64ef17cec63bf29cc44d..fde221f86f270cd5955adccc4588ea5811e02ad8 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -213,6 +213,7 @@ impl OpenAiLanguageModel { fn stream_completion( &self, request: open_ai::Request, + bypass_rate_limit: bool, cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { @@ -223,21 +224,24 @@ impl OpenAiLanguageModel { (state.api_key_state.key(&api_url), api_url) }); - let future = self.request_limiter.stream(async move { - let provider = PROVIDER_NAME; - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { provider }); - }; - let request = stream_completion( - http_client.as_ref(), - provider.0.as_str(), - &api_url, - &api_key, - request, - ); - let response = request.await?; - Ok(response) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let provider = PROVIDER_NAME; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { provider }); + }; + let request = stream_completion( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); + let response = request.await?; + Ok(response) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -245,6 +249,7 @@ impl OpenAiLanguageModel { fn stream_response( &self, request: ResponseRequest, + bypass_rate_limit: bool, cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { @@ -256,20 +261,23 @@ impl OpenAiLanguageModel { }); let provider = PROVIDER_NAME; - let future = self.request_limiter.stream(async move { - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { provider }); - }; - let request = stream_response( - http_client.as_ref(), - provider.0.as_str(), - &api_url, - &api_key, - request, - ); - let response = request.await?; - Ok(response) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { provider }); + }; + let request = stream_response( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); + let response = request.await?; + Ok(response) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -368,6 +376,7 @@ impl LanguageModel for OpenAiLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; if self.model.supports_chat_completions() { let request = into_open_ai( request, @@ -377,7 +386,7 @@ impl LanguageModel for OpenAiLanguageModel { self.max_output_tokens(), self.model.reasoning_effort(), ); - let completions = self.stream_completion(request, cx); + let completions = self.stream_completion(request, bypass_rate_limit, cx); async move { let mapper = OpenAiEventMapper::new(); Ok(mapper.map_stream(completions.await?).boxed()) @@ -392,7 +401,7 @@ impl LanguageModel for OpenAiLanguageModel { self.max_output_tokens(), self.model.reasoning_effort(), ); - let completions = self.stream_response(request, cx); + let completions = self.stream_response(request, bypass_rate_limit, cx); async move { let mapper = OpenAiResponseEventMapper::new(); Ok(mapper.map_stream(completions.await?).boxed()) @@ -545,6 +554,7 @@ pub fn into_open_ai_response( stop: _, temperature, thinking_allowed: _, + bypass_rate_limit: _, } = request; let mut input_items = Vec::new(); @@ -1417,6 +1427,7 @@ mod tests { stop: vec![], temperature: None, thinking_allowed: true, + bypass_rate_limit: false, }; // Validate that all models are supported by tiktoken-rs @@ -1553,6 +1564,7 @@ mod tests { stop: vec!["".into()], temperature: None, thinking_allowed: false, + bypass_rate_limit: false, }; let response = into_open_ai_response( diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index d47ea26c594ab0abb5c859ed549d43e0ed3f859b..0f949740ff8a6b9b006af6cdccce52613fbeddfa 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -204,6 +204,7 @@ impl OpenAiCompatibleLanguageModel { fn stream_completion( &self, request: open_ai::Request, + bypass_rate_limit: bool, cx: &AsyncApp, ) -> BoxFuture< 'static, @@ -223,20 +224,23 @@ impl OpenAiCompatibleLanguageModel { }); let provider = self.provider_name.clone(); - let future = self.request_limiter.stream(async move { - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { provider }); - }; - let request = stream_completion( - http_client.as_ref(), - provider.0.as_str(), - &api_url, - &api_key, - request, - ); - let response = request.await?; - Ok(response) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { provider }); + }; + let request = stream_completion( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); + let response = request.await?; + Ok(response) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -244,6 +248,7 @@ impl OpenAiCompatibleLanguageModel { fn stream_response( &self, request: ResponseRequest, + bypass_rate_limit: bool, cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { @@ -258,20 +263,23 @@ impl OpenAiCompatibleLanguageModel { }); let provider = self.provider_name.clone(); - let future = self.request_limiter.stream(async move { - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { provider }); - }; - let request = stream_response( - http_client.as_ref(), - provider.0.as_str(), - &api_url, - &api_key, - request, - ); - let response = request.await?; - Ok(response) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { provider }); + }; + let request = stream_response( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); + let response = request.await?; + Ok(response) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -370,6 +378,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; if self.model.capabilities.chat_completions { let request = into_open_ai( request, @@ -379,7 +388,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { self.max_output_tokens(), None, ); - let completions = self.stream_completion(request, cx); + let completions = self.stream_completion(request, bypass_rate_limit, cx); async move { let mapper = OpenAiEventMapper::new(); Ok(mapper.map_stream(completions.await?).boxed()) @@ -394,7 +403,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { self.max_output_tokens(), None, ); - let completions = self.stream_response(request, cx); + let completions = self.stream_response(request, bypass_rate_limit, cx); async move { let mapper = OpenAiResponseEventMapper::new(); Ok(mapper.map_stream(completions.await?).boxed()) diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 273b45ea23f76936a41584c9c58cd3c73c5c4967..1cdb84437574a82168fea83c46d2c55d33c7c22e 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -368,12 +368,16 @@ impl LanguageModel for OpenRouterLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let openrouter_request = into_open_router(request, &self.model, self.max_output_tokens()); let request = self.stream_completion(openrouter_request, cx); - let future = self.request_limiter.stream(async move { - let response = request.await?; - Ok(OpenRouterEventMapper::new().map_stream(response)) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let response = request.await?; + Ok(OpenRouterEventMapper::new().map_stream(response)) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } } diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 3b324e46927f5864d83a5e4b74c46f5e39e8ab3a..3b84434ed4aa2431662de176666da7aeaeb50392 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -193,6 +193,7 @@ impl VercelLanguageModel { fn stream_completion( &self, request: open_ai::Request, + bypass_rate_limit: bool, cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { @@ -203,21 +204,24 @@ impl VercelLanguageModel { (state.api_key_state.key(&api_url), api_url) }); - let future = self.request_limiter.stream(async move { - let provider = PROVIDER_NAME; - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { provider }); - }; - let request = open_ai::stream_completion( - http_client.as_ref(), - provider.0.as_str(), - &api_url, - &api_key, - request, - ); - let response = request.await?; - Ok(response) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let provider = PROVIDER_NAME; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { provider }); + }; + let request = open_ai::stream_completion( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); + let response = request.await?; + Ok(response) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -290,6 +294,7 @@ impl LanguageModel for VercelLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let request = crate::provider::open_ai::into_open_ai( request, self.model.id(), @@ -298,7 +303,7 @@ impl LanguageModel for VercelLanguageModel { self.max_output_tokens(), None, ); - let completions = self.stream_completion(request, cx); + let completions = self.stream_completion(request, bypass_rate_limit, cx); async move { let mapper = crate::provider::open_ai::OpenAiEventMapper::new(); Ok(mapper.map_stream(completions.await?).boxed()) diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 06564224dea9621d594e5cf3f4a84093f1620446..c77c2a8aca9a7d2d79d92763d5ab31dc7bb3eb1b 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -197,6 +197,7 @@ impl XAiLanguageModel { fn stream_completion( &self, request: open_ai::Request, + bypass_rate_limit: bool, cx: &AsyncApp, ) -> BoxFuture< 'static, @@ -212,21 +213,24 @@ impl XAiLanguageModel { (state.api_key_state.key(&api_url), api_url) }); - let future = self.request_limiter.stream(async move { - let provider = PROVIDER_NAME; - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { provider }); - }; - let request = open_ai::stream_completion( - http_client.as_ref(), - provider.0.as_str(), - &api_url, - &api_key, - request, - ); - let response = request.await?; - Ok(response) - }); + let future = self.request_limiter.stream_with_bypass( + async move { + let provider = PROVIDER_NAME; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { provider }); + }; + let request = open_ai::stream_completion( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); + let response = request.await?; + Ok(response) + }, + bypass_rate_limit, + ); async move { Ok(future.await?.boxed()) }.boxed() } @@ -307,6 +311,7 @@ impl LanguageModel for XAiLanguageModel { LanguageModelCompletionError, >, > { + let bypass_rate_limit = request.bypass_rate_limit; let request = crate::provider::open_ai::into_open_ai( request, self.model.id(), @@ -315,7 +320,7 @@ impl LanguageModel for XAiLanguageModel { self.max_output_tokens(), None, ); - let completions = self.stream_completion(request, cx); + let completions = self.stream_completion(request, bypass_rate_limit, cx); async move { let mapper = crate::provider::open_ai::OpenAiEventMapper::new(); Ok(mapper.map_stream(completions.await?).boxed()) diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs index 09089a6bcba83b4159b346c0e9da2dfd53289389..320016f200e9af6f04d059b6b8064ea82f204d6d 100644 --- a/crates/rules_library/src/rules_library.rs +++ b/crates/rules_library/src/rules_library.rs @@ -1100,6 +1100,7 @@ impl RulesLibrary { stop: Vec::new(), temperature: None, thinking_allowed: true, + bypass_rate_limit: false, }, cx, )