From fbf7caf93eb18d817bed6bd405d914590ba409d2 Mon Sep 17 00:00:00 2001
From: Michael Sloan
Date: Sat, 19 Apr 2025 17:26:29 -0600
Subject: [PATCH] Default to fast model for thread summaries and titles + don't
 include system prompt / context / thinking segments (#29102)

* Adds a fast / cheaper model to providers and defaults thread
  summarization to this model. The initial motivation was that
  https://github.com/zed-industries/zed/pull/29099 would cause these
  requests to fail when used with a thinking model, and a thinking model
  doesn't seem like the right choice for summarization anyway.

* Skips the system prompt, context, and thinking segments when building
  summarization requests.

* If tool use is happening, allows 2 tool uses + one more agent response
  before summarizing. The downside is that these requests could
  previously reuse some of the prompt prefix cache, especially for title
  summarization (thread summarization omitted tool results and so would
  not share a prefix for those). This seems fine, as these requests
  should typically be fairly small. Even for full thread summarization,
  skipping all tool use / context should greatly reduce token use.

Release Notes:

- N/A
---
 crates/agent/src/active_thread.rs             |   6 +-
 crates/agent/src/assistant.rs                 |   2 +-
 crates/agent/src/message_editor.rs            |  15 +-
 crates/agent/src/thread.rs                    | 169 +++++++++---------
 crates/anthropic/src/anthropic.rs             |   4 +
 crates/bedrock/src/models.rs                  |   4 +
 crates/copilot/src/copilot_chat.rs            |   4 +
 crates/deepseek/src/deepseek.rs               |   4 +
 crates/eval/src/example.rs                    |   4 +-
 crates/google_ai/src/google_ai.rs             |   4 +
 crates/language_model/src/fake_provider.rs    |   4 +
 crates/language_model/src/language_model.rs   |   1 +
 crates/language_model/src/registry.rs         |  32 +++-
 .../language_models/src/provider/anthropic.rs |  15 +-
 .../language_models/src/provider/bedrock.rs   |  39 ++--
 crates/language_models/src/provider/cloud.rs  |  16 +-
 .../src/provider/copilot_chat.rs              |  24 +--
 .../language_models/src/provider/deepseek.rs  |  33 ++--
 crates/language_models/src/provider/google.rs |  23 ++-
 .../language_models/src/provider/lmstudio.rs  |   4 +
 .../language_models/src/provider/mistral.rs   |  23 ++-
 crates/language_models/src/provider/ollama.rs |   4 +
 .../language_models/src/provider/open_ai.rs   |  33 ++--
 crates/mistral/src/mistral.rs                 |   4 +
 crates/open_ai/src/open_ai.rs                 |   4 +
 25 files changed, 270 insertions(+), 205 deletions(-)
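The core of the change is a new lookup order in `LanguageModelRegistry::thread_summary_model`: an explicitly configured summary model wins, then the provider's new fast model, then the regular default. A minimal, self-contained model of that fallback chain (the structs here are simplified stand-ins, not Zed's real `ConfiguredModel` and registry types):

```rust
#[derive(Clone, Debug, PartialEq)]
struct ConfiguredModel(&'static str);

#[derive(Default)]
struct Registry {
    thread_summary_model: Option<ConfiguredModel>,
    default_fast_model: Option<ConfiguredModel>,
    default_model: Option<ConfiguredModel>,
}

impl Registry {
    // Mirrors the fallback order added to registry.rs below.
    fn thread_summary_model(&self) -> Option<ConfiguredModel> {
        self.thread_summary_model
            .clone()
            .or_else(|| self.default_fast_model.clone())
            .or_else(|| self.default_model.clone())
    }
}

fn main() {
    // With no explicit summary model configured, the provider's fast model
    // wins over the (typically larger) default model.
    let registry = Registry {
        thread_summary_model: None,
        default_fast_model: Some(ConfiguredModel("claude-3-5-haiku")),
        default_model: Some(ConfiguredModel("claude-3-7-sonnet")),
    };
    assert_eq!(
        registry.thread_summary_model(),
        Some(ConfiguredModel("claude-3-5-haiku"))
    );
}
```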
diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs
index 4c3fbe878ecaf73b309907e1eaf58f59737799b7..2a4a00cf23e8c9cbef84da8d9d30c521ec4b6279 100644
--- a/crates/agent/src/active_thread.rs
+++ b/crates/agent/src/active_thread.rs
@@ -1,8 +1,8 @@
 use crate::context::{AssistantContext, ContextId, format_context_as_string};
 use crate::context_picker::MentionLink;
 use crate::thread::{
-    LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
-    ThreadEvent, ThreadFeedback,
+    LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
+    ThreadFeedback,
 };
 use crate::thread_store::{RulesLoadingError, ThreadStore};
 use crate::tool_use::{PendingToolUseStatus, ToolUse};
@@ -1285,7 +1285,7 @@ impl ActiveThread {
 
         self.thread.update(cx, |thread, cx| {
             thread.advance_prompt_id();
-            thread.send_to_model(model.model, RequestKind::Chat, cx)
+            thread.send_to_model(model.model, cx)
         });
         cx.notify();
     }
diff --git a/crates/agent/src/assistant.rs b/crates/agent/src/assistant.rs
index 1f067af7343aab5fa79b96dd17ec2e74710abe75..03e13d6f68355ab25e083427509f9baf2c5d3f26 100644
--- a/crates/agent/src/assistant.rs
+++ b/crates/agent/src/assistant.rs
@@ -40,7 +40,7 @@
 pub use crate::active_thread::ActiveThread;
 use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
 pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
 pub use crate::inline_assistant::InlineAssistant;
-pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent};
+pub use crate::thread::{Message, Thread, ThreadEvent};
 pub use crate::thread_store::ThreadStore;
 pub use agent_diff::{AgentDiff, AgentDiffToolbar};
diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs
index d64c95b9b99301da0ed900510d3dc4b0dff819d5..ac16df4c973f8491f3dacb330530968346be57ea 100644
--- a/crates/agent/src/message_editor.rs
+++ b/crates/agent/src/message_editor.rs
@@ -34,7 +34,7 @@ use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
 use crate::context_store::{ContextStore, refresh_context_store_text};
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
 use crate::profile_selector::ProfileSelector;
-use crate::thread::{RequestKind, Thread, TokenUsageRatio};
+use crate::thread::{Thread, TokenUsageRatio};
 use crate::thread_store::ThreadStore;
 use crate::{
     AgentDiff, Chat, ChatMode, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
@@ -234,7 +234,7 @@ impl MessageEditor {
         }
 
         self.set_editor_is_expanded(false, cx);
-        self.send_to_model(RequestKind::Chat, window, cx);
+        self.send_to_model(window, cx);
         cx.notify();
     }
@@ -249,12 +249,7 @@ impl MessageEditor {
             .is_some()
    }
 
-    fn send_to_model(
-        &mut self,
-        request_kind: RequestKind,
-        window: &mut Window,
-        cx: &mut Context<Self>,
-    ) {
+    fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
         let model_registry = LanguageModelRegistry::read_global(cx);
         let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
             return;
@@ -331,7 +326,7 @@
             thread
                 .update(cx, |thread, cx| {
                     thread.advance_prompt_id();
-                    thread.send_to_model(model, request_kind, cx);
+                    thread.send_to_model(model, cx);
                 })
                 .log_err();
         })
@@ -345,7 +340,7 @@
 
         if cancelled {
             self.set_editor_is_expanded(false, cx);
-            self.send_to_model(RequestKind::Chat, window, cx);
+            self.send_to_model(window, cx);
         }
     }
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index e8ca584fa0e780479c7094bb133ba9e8c764cb81..17f7e203877a03fc8e4e0e82245c41c96a0e4df3 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -40,13 +40,6 @@ use crate::thread_store::{
 };
 use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
 
-#[derive(Debug, Clone, Copy)]
-pub enum RequestKind {
-    Chat,
-    /// Used when summarizing a thread.
-    Summarize,
-}
-
 #[derive(
     Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
 )]
@@ -949,13 +942,8 @@ impl Thread {
         })
     }
 
-    pub fn send_to_model(
-        &mut self,
-        model: Arc<dyn LanguageModel>,
-        request_kind: RequestKind,
-        cx: &mut Context<Self>,
-    ) {
-        let mut request = self.to_completion_request(request_kind, cx);
+    pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
+        let mut request = self.to_completion_request(cx);
         if model.supports_tools() {
             request.tools = {
                 let mut tools = Vec::new();
@@ -994,11 +982,7 @@ impl Thread {
         false
     }
 
-    pub fn to_completion_request(
-        &self,
-        request_kind: RequestKind,
-        cx: &mut Context<Self>,
-    ) -> LanguageModelRequest {
+    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
         let mut request = LanguageModelRequest {
             thread_id: Some(self.id.to_string()),
             prompt_id: Some(self.last_prompt_id.to_string()),
@@ -1045,18 +1029,8 @@
                 cache: false,
             };
 
-            match request_kind {
-                RequestKind::Chat => {
-                    self.tool_use
-                        .attach_tool_results(message.id, &mut request_message);
-                }
-                RequestKind::Summarize => {
-                    // We don't care about tool use during summarization.
-                    if self.tool_use.message_has_tool_results(message.id) {
-                        continue;
-                    }
-                }
-            }
+            self.tool_use
+                .attach_tool_results(message.id, &mut request_message);
 
             if !message.context.is_empty() {
                 request_message
@@ -1089,15 +1063,8 @@
                 };
             }
 
-            match request_kind {
-                RequestKind::Chat => {
-                    self.tool_use
-                        .attach_tool_uses(message.id, &mut request_message);
-                }
-                RequestKind::Summarize => {
-                    // We don't care about tool use during summarization.
-                }
-            };
+            self.tool_use
+                .attach_tool_uses(message.id, &mut request_message);
 
             request.messages.push(request_message);
         }
@@ -1112,6 +1079,54 @@
         request
     }
 
+    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
+        let mut request = LanguageModelRequest {
+            thread_id: None,
+            prompt_id: None,
+            messages: vec![],
+            tools: Vec::new(),
+            stop: Vec::new(),
+            temperature: None,
+        };
+
+        for message in &self.messages {
+            let mut request_message = LanguageModelRequestMessage {
+                role: message.role,
+                content: Vec::new(),
+                cache: false,
+            };
+
+            // Skip tool results during summarization.
+            if self.tool_use.message_has_tool_results(message.id) {
+                continue;
+            }
+
+            for segment in &message.segments {
+                match segment {
+                    MessageSegment::Text(text) => request_message
+                        .content
+                        .push(MessageContent::Text(text.clone())),
+                    MessageSegment::Thinking { .. } => {}
+                    MessageSegment::RedactedThinking(_) => {}
+                }
+            }
+
+            if request_message.content.is_empty() {
+                continue;
+            }
+
+            request.messages.push(request_message);
+        }
+
+        request.messages.push(LanguageModelRequestMessage {
+            role: Role::User,
+            content: vec![MessageContent::Text(added_user_message)],
+            cache: false,
+        });
+
+        request
+    }
+
     fn attached_tracked_files_state(
         &self,
         messages: &mut Vec<LanguageModelRequestMessage>,
@@ -1293,7 +1308,12 @@
                         .pending_completions
                         .retain(|completion| completion.id != pending_completion_id);
 
-                    if thread.summary.is_none() && thread.messages.len() >= 2 {
+                    // If there is a response without tool use, summarize the message. Otherwise,
+                    // allow two tool uses before summarizing.
+                    if thread.summary.is_none()
+                        && thread.messages.len() >= 2
+                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
+                    {
                         thread.summarize(cx);
                     }
                 })?;
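The hunk above packs the new summarization trigger into a single condition. Restated as a self-contained predicate (field names simplified; `has_pending_tool_uses` stands in for the method of the same name on `Thread`):

```rust
// Summarize once a thread has at least one exchange; if tools are in play,
// defer until two tool-use rounds plus a final response have accumulated.
struct Thread {
    summary: Option<String>,
    message_count: usize,
    has_pending_tool_uses: bool,
}

fn should_summarize(thread: &Thread) -> bool {
    thread.summary.is_none()
        && thread.message_count >= 2
        && (!thread.has_pending_tool_uses || thread.message_count >= 6)
}

fn main() {
    // A plain user/assistant exchange summarizes immediately.
    assert!(should_summarize(&Thread {
        summary: None,
        message_count: 2,
        has_pending_tool_uses: false,
    }));
    // With pending tool uses, summarization waits until the thread reaches
    // six messages (two tool uses plus one more agent response).
    assert!(!should_summarize(&Thread {
        summary: None,
        message_count: 4,
        has_pending_tool_uses: true,
    }));
    assert!(should_summarize(&Thread {
        summary: None,
        message_count: 6,
        has_pending_tool_uses: true,
    }));
}
```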
@@ -1403,18 +1423,12 @@
             return;
         }
 
-        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
-        request.messages.push(LanguageModelRequestMessage {
-            role: Role::User,
-            content: vec![
-                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
-                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
-                 If the conversation is about a specific subject, include it in the title. \
-                 Be descriptive. DO NOT speak in the first person."
-                    .into(),
-            ],
-            cache: false,
-        });
+        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
+             Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
+             If the conversation is about a specific subject, include it in the title. \
+             Be descriptive. DO NOT speak in the first person.";
+
+        let request = self.to_summarize_request(added_user_message.into());
 
         self.pending_summary = cx.spawn(async move |this, cx| {
             async move {
@@ -1476,21 +1490,14 @@
             return None;
         }
 
-        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
+        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
+             1. A brief overview of what was discussed\n\
+             2. Key facts or information discovered\n\
+             3. Outcomes or conclusions reached\n\
+             4. Any action items or next steps if any\n\
+             Format it in Markdown with headings and bullet points.";
 
-        request.messages.push(LanguageModelRequestMessage {
-            role: Role::User,
-            content: vec![
-                "Generate a detailed summary of this conversation. Include:\n\
-                1. A brief overview of what was discussed\n\
-                2. Key facts or information discovered\n\
-                3. Outcomes or conclusions reached\n\
-                4. Any action items or next steps if any\n\
-                Format it in Markdown with headings and bullet points."
-                    .into(),
-            ],
-            cache: false,
-        });
+        let request = self.to_summarize_request(added_user_message.into());
 
         let task = cx.spawn(async move |thread, cx| {
             let stream = model.stream_completion_text(request, &cx);
@@ -1538,7 +1545,7 @@ pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
         self.auto_capture_telemetry(cx);
 
-        let request = self.to_completion_request(RequestKind::Chat, cx);
+        let request = self.to_completion_request(cx);
         let messages = Arc::new(request.messages);
         let pending_tool_uses = self
             .tool_use
@@ -1650,7 +1657,7 @@
         if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
             self.attach_tool_results(cx);
             if !canceled {
-                self.send_to_model(model, RequestKind::Chat, cx);
+                self.send_to_model(model, cx);
             }
         }
     }
@@ -2275,9 +2282,7 @@ fn main() {{
     assert_eq!(message.context, expected_context);
 
     // Check message in request
-    let request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
 
     assert_eq!(request.messages.len(), 2);
     let expected_full_message = format!("{}Please explain this code", expected_context);
@@ -2367,9 +2372,7 @@ fn main() {{
     assert!(message3.context.contains("file3.rs"));
 
     // Check entire request to make sure all contexts are properly included
-    let request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
 
     // The request should contain all 3 messages
     assert_eq!(request.messages.len(), 4);
@@ -2419,9 +2422,7 @@ fn main() {{
     assert_eq!(message.context, "");
 
     // Check message in request
-    let request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
 
     assert_eq!(request.messages.len(), 2);
     assert_eq!(
@@ -2439,9 +2440,7 @@ fn main() {{
     assert_eq!(message2.context, "");
 
     // Check that both messages appear in the request
-    let request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
 
     assert_eq!(request.messages.len(), 3);
     assert_eq!(
@@ -2481,9 +2480,7 @@ fn main() {{
     });
 
     // Create a request and check that it doesn't have a stale buffer warning yet
-    let initial_request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
 
     // Make sure we don't have a stale file warning yet
     let has_stale_warning = initial_request.messages.iter().any(|msg| {
@@ -2511,9 +2508,7 @@ fn main() {{
     });
 
     // Create a new request and check for the stale buffer warning
-    let new_request = thread.update(cx, |thread, cx| {
-        thread.to_completion_request(RequestKind::Chat, cx)
-    });
+    let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
 
     // We should have a stale file warning as the last message
     let last_message = new_request
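`to_summarize_request` (added above) is the piece that keeps summary prompts cheap: it drops messages carrying tool results, drops thinking output, and skips messages left empty after filtering. A simplified, runnable sketch of just that filtering logic, with stand-in types in place of Zed's `MessageSegment` and tool-use state:

```rust
#[derive(Clone)]
enum MessageSegment {
    Text(String),
    Thinking(String),
    RedactedThinking(Vec<u8>),
}

struct Message {
    segments: Vec<MessageSegment>,
    has_tool_results: bool,
}

fn summarizable_text(messages: &[Message]) -> Vec<String> {
    messages
        .iter()
        // Messages with tool results are skipped entirely during summarization.
        .filter(|message| !message.has_tool_results)
        .filter_map(|message| {
            let text: String = message
                .segments
                .iter()
                .filter_map(|segment| match segment {
                    MessageSegment::Text(text) => Some(text.as_str()),
                    // Thinking output is model scratch work; it adds tokens
                    // without adding summarizable content.
                    MessageSegment::Thinking(_) => None,
                    MessageSegment::RedactedThinking(_) => None,
                })
                .collect();
            // A message that was all thinking ends up empty and is dropped.
            (!text.is_empty()).then_some(text)
        })
        .collect()
}

fn main() {
    let messages = vec![
        Message {
            segments: vec![MessageSegment::Text("How do I sort a Vec?".into())],
            has_tool_results: false,
        },
        Message {
            segments: vec![MessageSegment::Thinking("considering options...".into())],
            has_tool_results: false,
        },
    ];
    // Only the plain text survives.
    assert_eq!(summarizable_text(&messages), vec!["How do I sort a Vec?"]);
}
```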
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index 684feaca3b0d3625aad1d2d5ff0cb1fbdc5f488e..e7edb0e086d1aee0d6bedd9b8e983edf8feb848d 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -74,6 +74,10 @@ pub enum Model {
 }
 
 impl Model {
+    pub fn default_fast() -> Self {
+        Self::Claude3_5Haiku
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         if id.starts_with("claude-3-5-sonnet") {
             Ok(Self::Claude3_5Sonnet)
diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs
index 052e5c2ca1d8a0678d1fb48e9ddd6663d795342c..8ead77f9c47e7091da4120f8aed3fe8c796247ad 100644
--- a/crates/bedrock/src/models.rs
+++ b/crates/bedrock/src/models.rs
@@ -84,6 +84,10 @@ pub enum Model {
 }
 
 impl Model {
+    pub fn default_fast() -> Self {
+        Self::Claude3_5Haiku
+    }
+
     pub fn from_id(id: &str) -> anyhow::Result<Self> {
         if id.starts_with("claude-3-5-sonnet-v2") {
             Ok(Self::Claude3_5SonnetV2)
diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index 255c39cb84ec1aa979060e1dd57edf4be2963046..2bcb82c1ee64c9a6b1b617964567d079cd1ce8cf 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -61,6 +61,10 @@ pub enum Model {
 }
 
 impl Model {
+    pub fn default_fast() -> Self {
+        Self::Claude3_7Sonnet
+    }
+
     pub fn uses_streaming(&self) -> bool {
         match self {
             Self::Gpt4o
diff --git a/crates/deepseek/src/deepseek.rs b/crates/deepseek/src/deepseek.rs
index 07f6a959e100c9bd544544bd23e08adf097a94e6..9c19f1ae2f43d7d1186620935db0f7890923f81f 100644
--- a/crates/deepseek/src/deepseek.rs
+++ b/crates/deepseek/src/deepseek.rs
@@ -64,6 +64,10 @@ pub enum Model {
 }
 
 impl Model {
+    pub fn default_fast() -> Self {
+        Model::Chat
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "deepseek-chat" => Ok(Self::Chat),
diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs
index 982daeaed79fc323a8be9966b11b8a500b5ca691..78b7eb7af91e8812ff3cf86f78adb9c6c29fd6bf 100644
--- a/crates/eval/src/example.rs
+++ b/crates/eval/src/example.rs
@@ -1,4 +1,4 @@
-use agent::{RequestKind, ThreadEvent, ThreadStore};
+use agent::{ThreadEvent, ThreadStore};
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::ToolWorkingSet;
 use client::proto::LspWorkProgress;
@@ -472,7 +472,7 @@ impl Example {
         thread.update(cx, |thread, cx| {
             let context = vec![];
             thread.insert_user_message(this.prompt.clone(), context, None, cx);
-            thread.send_to_model(model, RequestKind::Chat, cx);
+            thread.send_to_model(model, cx);
         })?;
 
         event_handler_task.await?;
diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs
index 09278d6ed212a95457f8ec3ad32178101368f342..e26750936d46663ea709b5cff778ea69dd2dfcda 100644
--- a/crates/google_ai/src/google_ai.rs
+++ b/crates/google_ai/src/google_ai.rs
@@ -412,6 +412,10 @@ pub enum Model {
 }
 
 impl Model {
+    pub fn default_fast() -> Model {
+        Model::Gemini15Flash
+    }
+
     pub fn id(&self) -> &str {
         match self {
             Model::Gemini15Pro => "gemini-1.5-pro",
diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs
index 56df184d365ca0d539abcaafcdca421b24a06e1f..25f2a496e7eb7420c845fdf4835eadd4100daad6 100644
--- a/crates/language_model/src/fake_provider.rs
+++ b/crates/language_model/src/fake_provider.rs
@@ -49,6 +49,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
         Some(Arc::new(FakeLanguageModel::default()))
     }
 
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(Arc::new(FakeLanguageModel::default()))
+    }
+
     fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
         vec![Arc::new(FakeLanguageModel::default())]
     }
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 206958e82fafeeed5f22ee2ed5d3119340470595..71d8551bd5b32d6cc76305132a52ca2515331aaf 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -370,6 +370,7 @@ pub trait LanguageModelProvider: 'static {
         IconName::ZedAssistant
     }
     fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
     fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         Vec::new()
diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs
index 45be22457f0af2b37b1951afc47a5de85237d302..62f216094bde6264b8986a9a7274b9545d277707 100644
--- a/crates/language_model/src/registry.rs
+++ b/crates/language_model/src/registry.rs
@@ -5,6 +5,7 @@ use crate::{
 use collections::BTreeMap;
 use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
 use std::sync::Arc;
+use util::maybe;
 
 pub fn init(cx: &mut App) {
     let registry = cx.new(|_cx| LanguageModelRegistry::default());
@@ -18,6 +19,7 @@ impl Global for GlobalLanguageModelRegistry {}
 #[derive(Default)]
 pub struct LanguageModelRegistry {
     default_model: Option<ConfiguredModel>,
+    default_fast_model: Option<ConfiguredModel>,
     inline_assistant_model: Option<ConfiguredModel>,
     commit_message_model: Option<ConfiguredModel>,
     thread_summary_model: Option<ConfiguredModel>,
@@ -202,6 +204,14 @@ impl LanguageModelRegistry {
             (None, None) => {}
             _ => cx.emit(Event::DefaultModelChanged),
         }
+        self.default_fast_model = maybe!({
+            let provider = &model.as_ref()?.provider;
+            let fast_model = provider.default_fast_model(cx)?;
+            Some(ConfiguredModel {
+                provider: provider.clone(),
+                model: fast_model,
+            })
+        });
         self.default_model = model;
     }
@@ -254,21 +264,37 @@ impl LanguageModelRegistry {
     }
 
     pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
+        #[cfg(debug_assertions)]
+        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+            return None;
+        }
+
         self.inline_assistant_model
             .clone()
-            .or_else(|| self.default_model())
+            .or_else(|| self.default_model.clone())
     }
 
     pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
+        #[cfg(debug_assertions)]
+        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+            return None;
+        }
+
         self.commit_message_model
             .clone()
-            .or_else(|| self.default_model())
+            .or_else(|| self.default_model.clone())
     }
 
     pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
+        #[cfg(debug_assertions)]
+        if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+            return None;
+        }
+
         self.thread_summary_model
             .clone()
-            .or_else(|| self.default_model())
+            .or_else(|| self.default_fast_model.clone())
+            .or_else(|| self.default_model.clone())
     }
 
     /// The models to use for inline assists. Returns the union of the active
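The new `default_fast_model` derivation above uses `util::maybe!`, which (assuming it is Zed's usual immediately-invoked-closure helper) exists so `?` can early-return `None` from the middle of a function that doesn't itself return `Option`. A desugared equivalent with simplified stand-in types:

```rust
struct Provider;

impl Provider {
    fn default_fast_model(&self) -> Option<&'static str> {
        Some("claude-3-5-haiku")
    }
}

struct ConfiguredModel {
    provider: &'static str,
    model: &'static str,
}

fn derive_fast_model(model: Option<(&'static str, Provider)>) -> Option<ConfiguredModel> {
    // maybe!({ ... }) is assumed equivalent to (|| { ... })(): the closure
    // gives `?` something to early-return from.
    (|| {
        let (name, provider) = model?;
        let fast = provider.default_fast_model()?;
        Some(ConfiguredModel {
            provider: name,
            model: fast,
        })
    })()
}

fn main() {
    // No default model configured -> no fast model is derived either.
    assert!(derive_fast_model(None).is_none());

    let configured = derive_fast_model(Some(("anthropic", Provider))).unwrap();
    assert_eq!(configured.provider, "anthropic");
    assert_eq!(configured.model, "claude-3-5-haiku");
}
```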
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index 6a299765046376c9e23e657282edb235f5ceeb7a..f998969bfec3da9a2d0f230fe68b6245ff9eea1d 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -201,7 +201,7 @@ impl AnthropicLanguageModelProvider {
             state: self.state.clone(),
             http_client: self.http_client.clone(),
             request_limiter: RateLimiter::new(4),
-        }) as Arc<dyn LanguageModel>
+        })
     }
 }
@@ -227,14 +227,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
     }
 
     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = anthropic::Model::default();
-        Some(Arc::new(AnthropicModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(anthropic::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(anthropic::Model::default_fast()))
     }
 
     fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs
index c4ef48404f8f008be8809b8a8bc4bd9b5d9a9467..a2748b45be4aeaeabd5f54a848e2c1243b70f8d1 100644
--- a/crates/language_models/src/provider/bedrock.rs
+++ b/crates/language_models/src/provider/bedrock.rs
@@ -286,6 +286,18 @@ impl BedrockLanguageModelProvider {
             state,
         }
     }
+
+    fn create_language_model(&self, model: bedrock::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(BedrockModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            http_client: self.http_client.clone(),
+            handler: self.handler.clone(),
+            state: self.state.clone(),
+            client: OnceCell::new(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }
 
 impl LanguageModelProvider for BedrockLanguageModelProvider {
@@ -302,16 +314,11 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
     }
 
     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = bedrock::Model::default();
-        Some(Arc::new(BedrockModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            http_client: self.http_client.clone(),
-            handler: self.handler.clone(),
-            state: self.state.clone(),
-            client: OnceCell::new(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(bedrock::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(bedrock::Model::default_fast()))
     }
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -343,17 +350,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
 
         models
             .into_values()
-            .map(|model| {
-                Arc::new(BedrockModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    http_client: self.http_client.clone(),
-                    handler: self.handler.clone(),
-                    state: self.state.clone(),
-                    client: OnceCell::new(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 80c8d0dcc383102ed9ff68b738a492f6ad0fcdbb..b9911d5d46a175f38f4956afc667467bea921b9d 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -242,7 +242,7 @@ impl CloudLanguageModelProvider {
             llm_api_token: llm_api_token.clone(),
             client: self.client.clone(),
             request_limiter: RateLimiter::new(4),
-        }) as Arc<dyn LanguageModel>
+        })
     }
 }
@@ -270,13 +270,13 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
         let llm_api_token = self.state.read(cx).llm_api_token.clone();
         let model = CloudModel::Anthropic(anthropic::Model::default());
-        Some(Arc::new(CloudLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            llm_api_token: llm_api_token.clone(),
-            client: self.client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(model, llm_api_token))
+    }
+
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let llm_api_token = self.state.read(cx).llm_api_token.clone();
+        let model = CloudModel::Anthropic(anthropic::Model::default_fast());
+        Some(self.create_language_model(model, llm_api_token))
     }
 
     fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 3d4924b890a1623d26365473f09c0c270ce7691b..255de2d536e56414da07509d58e75192685ec3f7 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -70,6 +70,13 @@ impl CopilotChatLanguageModelProvider {
         Self { state }
     }
+
+    fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
+        Arc::new(CopilotChatLanguageModel {
+            model,
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }
 
 impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
@@ -94,21 +101,16 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
     }
 
     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = CopilotChatModel::default();
-        Some(Arc::new(CopilotChatLanguageModel {
-            model,
-            request_limiter: RateLimiter::new(4),
-        }) as Arc<dyn LanguageModel>)
+        Some(self.create_language_model(CopilotChatModel::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(CopilotChatModel::default_fast()))
     }
 
     fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         CopilotChatModel::iter()
-            .map(|model| {
-                Arc::new(CopilotChatLanguageModel {
-                    model,
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }
diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs
index e4f1cd830a854b492991109c2a5637df57caab07..9989e4c6b12e8615f39088af7c4de1c274936c5c 100644
--- a/crates/language_models/src/provider/deepseek.rs
+++ b/crates/language_models/src/provider/deepseek.rs
@@ -140,6 +140,16 @@ impl DeepSeekLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: deepseek::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(DeepSeekLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }) as Arc<dyn LanguageModel>
+    }
 }
 
 impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
@@ -164,14 +174,11 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
     }
 
     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = deepseek::Model::Chat;
-        Some(Arc::new(DeepSeekLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(deepseek::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(deepseek::Model::default_fast()))
     }
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -198,15 +205,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
 
         models
             .into_values()
-            .map(|model| {
-                Arc::new(DeepSeekLanguageModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    state: self.state.clone(),
-                    http_client: self.http_client.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }
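Most of the per-provider churn above is the same mechanical refactor: inline `Arc::new(...) as Arc<dyn LanguageModel>` constructions are replaced by one `create_language_model` helper. Giving the helper a return type of `Arc<dyn LanguageModel>` makes the unsized coercion happen once at the return, so call sites like `.map(|model| self.create_language_model(model))` need no cast. A self-contained sketch of the pattern (the trait and types are stand-ins, not Zed's):

```rust
use std::sync::Arc;

trait LanguageModel {
    fn id(&self) -> &str;
}

struct ConcreteModel {
    id: String,
}

impl LanguageModel for ConcreteModel {
    fn id(&self) -> &str {
        &self.id
    }
}

struct Provider;

impl Provider {
    // Returning Arc<dyn LanguageModel> performs the unsized coercion here,
    // once, instead of `as Arc<dyn LanguageModel>` at every call site.
    fn create_language_model(&self, id: &str) -> Arc<dyn LanguageModel> {
        Arc::new(ConcreteModel { id: id.to_string() })
    }
}

fn main() {
    let provider = Provider;
    // Call sites collect trait objects without any explicit cast.
    let models: Vec<Arc<dyn LanguageModel>> = ["fast-model", "default-model"]
        .iter()
        .map(|id| provider.create_language_model(id))
        .collect();
    assert_eq!(models[1].id(), "default-model");
}
```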
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index c754e63bbd5760302ae670436ad82f8fecf5ea08..bbe6c5835352f8e4ff58820800e6f819ee086c87 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -150,6 +150,16 @@ impl GoogleLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(GoogleLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }
 
 impl LanguageModelProviderState for GoogleLanguageModelProvider {
@@ -174,14 +184,11 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
     }
 
     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = google_ai::Model::default();
-        Some(Arc::new(GoogleLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(google_ai::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(google_ai::Model::default_fast()))
     }
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index bd8b6303f8aa1878738520ed535f20d354ba3030..2f5ae9ebb65ec970c3b392c42814010ebfa21603 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -157,6 +157,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
         self.provided_models(cx).into_iter().next()
     }
 
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        self.default_model(cx)
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index b9017398e69e089e448e1b94fc2065b9217a7090..a5009c76a63829e70d448b94c3f39bd3e826d838 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -144,6 +144,16 @@ impl MistralLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(MistralLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }
 
 impl LanguageModelProviderState for MistralLanguageModelProvider {
@@ -168,14 +178,11 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
     }
 
     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = mistral::Model::default();
-        Some(Arc::new(MistralLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(mistral::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(mistral::Model::default_fast()))
     }
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index 465dc0d65998a716a12f9b03237c87cb72cda7b0..17c50c8eaf55b0dc1026d2b778b7ae307e657624 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -162,6 +162,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
         self.provided_models(cx).into_iter().next()
     }
 
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        self.default_model(cx)
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 020c64252057c7f092480b4fad704db9b9a04a26..188a219e2d8a8430ed06eeb8335fa3fe1e6a2f83 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -148,6 +148,16 @@ impl OpenAiLanguageModelProvider {
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(OpenAiLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
 }
 
 impl LanguageModelProviderState for OpenAiLanguageModelProvider {
@@ -172,14 +182,11 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
     }
 
     fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let model = open_ai::Model::default();
-        Some(Arc::new(OpenAiLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        }))
+        Some(self.create_language_model(open_ai::Model::default()))
+    }
+
+    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        Some(self.create_language_model(open_ai::Model::default_fast()))
     }
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -211,15 +218,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
 
         models
             .into_values()
-            .map(|model| {
-                Arc::new(OpenAiLanguageModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    state: self.state.clone(),
-                    http_client: self.http_client.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }
diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs
index de2457e0bf9730e9ac1822d39e424157df23f9d6..27de5fccc285569a75d9099bcec89771867284a6 100644
--- a/crates/mistral/src/mistral.rs
+++ b/crates/mistral/src/mistral.rs
@@ -69,6 +69,10 @@ pub enum Model {
 }
 
 impl Model {
+    pub fn default_fast() -> Self {
+        Model::MistralSmallLatest
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "codestral-latest" => Ok(Self::CodestralLatest),
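Finally, each model crate's `Model` enum gains an associated `default_fast()` constructor naming its cheap tier (Haiku, Gemini Flash, `mistral-small-latest`, and so on), while local providers like Ollama and LM Studio, which have no fixed cheap tier, simply delegate `default_fast_model` to `default_model`. A condensed sketch of both halves of the pattern, with simplified stand-in types:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Model {
    Big,
    Small,
}

impl Model {
    fn default() -> Self {
        Model::Big
    }

    // The cheap tier used for titles and thread summaries.
    fn default_fast() -> Self {
        Model::Small
    }
}

struct LocalProvider {
    installed: Vec<Model>,
}

impl LocalProvider {
    fn default_model(&self) -> Option<Model> {
        self.installed.first().copied()
    }

    // Local inference has no fixed cheap tier: fall back to the default.
    fn default_fast_model(&self) -> Option<Model> {
        self.default_model()
    }
}

fn main() {
    assert_eq!(Model::default(), Model::Big);
    assert_eq!(Model::default_fast(), Model::Small);

    let provider = LocalProvider { installed: vec![Model::Big] };
    assert_eq!(provider.default_fast_model(), Some(Model::Big));
}
```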
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 9284b4a9b2771a8ac49f461009592439530bc96e..f9e0b7d4e372f8dac4adf8bdc78b65310f4d91f6 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -102,6 +102,10 @@ pub enum Model {
 }
 
 impl Model {
+    pub fn default_fast() -> Self {
+        Self::FourPointOneMini
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),