Detailed changes
@@ -1,8 +1,8 @@
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::context_picker::MentionLink;
use crate::thread::{
- LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
- ThreadEvent, ThreadFeedback,
+ LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
+ ThreadFeedback,
};
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse};
@@ -1285,7 +1285,7 @@ impl ActiveThread {
self.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
- thread.send_to_model(model.model, RequestKind::Chat, cx)
+ thread.send_to_model(model.model, cx)
});
cx.notify();
}
@@ -40,7 +40,7 @@ pub use crate::active_thread::ActiveThread;
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
pub use crate::inline_assistant::InlineAssistant;
-pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent};
+pub use crate::thread::{Message, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;
pub use agent_diff::{AgentDiff, AgentDiffToolbar};
@@ -34,7 +34,7 @@ use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_store::{ContextStore, refresh_context_store_text};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::profile_selector::ProfileSelector;
-use crate::thread::{RequestKind, Thread, TokenUsageRatio};
+use crate::thread::{Thread, TokenUsageRatio};
use crate::thread_store::ThreadStore;
use crate::{
AgentDiff, Chat, ChatMode, ExpandMessageEditor, NewThread, OpenAgentDiff, RemoveAllContext,
@@ -234,7 +234,7 @@ impl MessageEditor {
}
self.set_editor_is_expanded(false, cx);
- self.send_to_model(RequestKind::Chat, window, cx);
+ self.send_to_model(window, cx);
cx.notify();
}
@@ -249,12 +249,7 @@ impl MessageEditor {
.is_some()
}
- fn send_to_model(
- &mut self,
- request_kind: RequestKind,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
+ fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
return;
@@ -331,7 +326,7 @@ impl MessageEditor {
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
- thread.send_to_model(model, request_kind, cx);
+ thread.send_to_model(model, cx);
})
.log_err();
})
@@ -345,7 +340,7 @@ impl MessageEditor {
if cancelled {
self.set_editor_is_expanded(false, cx);
- self.send_to_model(RequestKind::Chat, window, cx);
+ self.send_to_model(window, cx);
}
}
@@ -40,13 +40,6 @@ use crate::thread_store::{
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
-#[derive(Debug, Clone, Copy)]
-pub enum RequestKind {
- Chat,
- /// Used when summarizing a thread.
- Summarize,
-}
-
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
@@ -949,13 +942,8 @@ impl Thread {
})
}
- pub fn send_to_model(
- &mut self,
- model: Arc<dyn LanguageModel>,
- request_kind: RequestKind,
- cx: &mut Context<Self>,
- ) {
- let mut request = self.to_completion_request(request_kind, cx);
+ pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
+ let mut request = self.to_completion_request(cx);
if model.supports_tools() {
request.tools = {
let mut tools = Vec::new();
@@ -994,11 +982,7 @@ impl Thread {
false
}
- pub fn to_completion_request(
- &self,
- request_kind: RequestKind,
- cx: &mut Context<Self>,
- ) -> LanguageModelRequest {
+ pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
thread_id: Some(self.id.to_string()),
prompt_id: Some(self.last_prompt_id.to_string()),
@@ -1045,18 +1029,8 @@ impl Thread {
cache: false,
};
- match request_kind {
- RequestKind::Chat => {
- self.tool_use
- .attach_tool_results(message.id, &mut request_message);
- }
- RequestKind::Summarize => {
- // We don't care about tool use during summarization.
- if self.tool_use.message_has_tool_results(message.id) {
- continue;
- }
- }
- }
+ self.tool_use
+ .attach_tool_results(message.id, &mut request_message);
if !message.context.is_empty() {
request_message
@@ -1089,15 +1063,8 @@ impl Thread {
};
}
- match request_kind {
- RequestKind::Chat => {
- self.tool_use
- .attach_tool_uses(message.id, &mut request_message);
- }
- RequestKind::Summarize => {
- // We don't care about tool use during summarization.
- }
- };
+ self.tool_use
+ .attach_tool_uses(message.id, &mut request_message);
request.messages.push(request_message);
}
@@ -1112,6 +1079,54 @@ impl Thread {
request
}
+ fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
+ let mut request = LanguageModelRequest {
+ thread_id: None,
+ prompt_id: None,
+ messages: vec![],
+ tools: Vec::new(),
+ stop: Vec::new(),
+ temperature: None,
+ };
+
+ for message in &self.messages {
+ let mut request_message = LanguageModelRequestMessage {
+ role: message.role,
+ content: Vec::new(),
+ cache: false,
+ };
+
+ // Skip tool results during summarization.
+ if self.tool_use.message_has_tool_results(message.id) {
+ continue;
+ }
+
+ for segment in &message.segments {
+ match segment {
+ MessageSegment::Text(text) => request_message
+ .content
+ .push(MessageContent::Text(text.clone())),
+ MessageSegment::Thinking { .. } => {}
+ MessageSegment::RedactedThinking(_) => {}
+ }
+ }
+
+ if request_message.content.is_empty() {
+ continue;
+ }
+
+ request.messages.push(request_message);
+ }
+
+ request.messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::Text(added_user_message)],
+ cache: false,
+ });
+
+ request
+ }
+
fn attached_tracked_files_state(
&self,
messages: &mut Vec<LanguageModelRequestMessage>,
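The new to_summarize_request above amounts to a filter over the thread's messages. A minimal free-standing sketch of that filtering, using simplified stand-in types rather than the real Message/MessageSegment structures:

    // Simplified stand-ins; only the filtering logic mirrors to_summarize_request.
    struct SimpleMessage {
        has_tool_results: bool,
        text_segments: Vec<String>,
    }

    fn summary_inputs(messages: &[SimpleMessage], added_user_message: &str) -> Vec<String> {
        let mut out: Vec<String> = messages
            .iter()
            // Messages carrying tool results are skipped entirely.
            .filter(|message| !message.has_tool_results)
            // Only text segments survive; thinking segments are dropped.
            .map(|message| message.text_segments.concat())
            // Messages that end up empty are omitted.
            .filter(|text| !text.is_empty())
            .collect();
        // The summarization instruction is always appended as a final user message.
        out.push(added_user_message.to_string());
        out
    }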
@@ -1293,7 +1308,12 @@ impl Thread {
.pending_completions
.retain(|completion| completion.id != pending_completion_id);
- if thread.summary.is_none() && thread.messages.len() >= 2 {
+ // If the model responded without using any tools, summarize the thread right away.
+ // Otherwise, allow up to two rounds of tool use before summarizing.
+ if thread.summary.is_none()
+ && thread.messages.len() >= 2
+ && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
+ {
thread.summarize(cx);
}
})?;
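As a rough illustration of the gate above, with plain values in place of the real Thread state (should_summarize and its parameters are hypothetical stand-ins for summary.is_none(), messages.len(), and has_pending_tool_uses()):

    // Hypothetical stand-in for the summarization gate; not the real Thread API.
    fn should_summarize(has_summary: bool, message_count: usize, has_pending_tool_uses: bool) -> bool {
        !has_summary && message_count >= 2 && (!has_pending_tool_uses || message_count >= 6)
    }

    fn main() {
        assert!(should_summarize(false, 2, false));   // plain response: summarize immediately
        assert!(!should_summarize(false, 4, true));   // tools still running: hold off
        assert!(should_summarize(false, 6, true));    // after roughly two tool-use rounds: summarize anyway
        assert!(!should_summarize(true, 10, false));  // already summarized: never re-trigger here
    }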
@@ -1403,18 +1423,12 @@ impl Thread {
return;
}
- let mut request = self.to_completion_request(RequestKind::Summarize, cx);
- request.messages.push(LanguageModelRequestMessage {
- role: Role::User,
- content: vec![
- "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
- Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
- If the conversation is about a specific subject, include it in the title. \
- Be descriptive. DO NOT speak in the first person."
- .into(),
- ],
- cache: false,
- });
+ let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
+ Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
+ If the conversation is about a specific subject, include it in the title. \
+ Be descriptive. DO NOT speak in the first person.";
+
+ let request = self.to_summarize_request(added_user_message.into());
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
@@ -1476,21 +1490,14 @@ impl Thread {
return None;
}
- let mut request = self.to_completion_request(RequestKind::Summarize, cx);
+ let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
+ 1. A brief overview of what was discussed\n\
+ 2. Key facts or information discovered\n\
+ 3. Outcomes or conclusions reached\n\
+ 4. Any action items or next steps if any\n\
+ Format it in Markdown with headings and bullet points.";
- request.messages.push(LanguageModelRequestMessage {
- role: Role::User,
- content: vec![
- "Generate a detailed summary of this conversation. Include:\n\
- 1. A brief overview of what was discussed\n\
- 2. Key facts or information discovered\n\
- 3. Outcomes or conclusions reached\n\
- 4. Any action items or next steps if any\n\
- Format it in Markdown with headings and bullet points."
- .into(),
- ],
- cache: false,
- });
+ let request = self.to_summarize_request(added_user_message.into());
let task = cx.spawn(async move |thread, cx| {
let stream = model.stream_completion_text(request, &cx);
@@ -1538,7 +1545,7 @@ impl Thread {
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx);
- let request = self.to_completion_request(RequestKind::Chat, cx);
+ let request = self.to_completion_request(cx);
let messages = Arc::new(request.messages);
let pending_tool_uses = self
.tool_use
@@ -1650,7 +1657,7 @@ impl Thread {
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.attach_tool_results(cx);
if !canceled {
- self.send_to_model(model, RequestKind::Chat, cx);
+ self.send_to_model(model, cx);
}
}
}
@@ -2275,9 +2282,7 @@ fn main() {{
assert_eq!(message.context, expected_context);
// Check message in request
- let request = thread.update(cx, |thread, cx| {
- thread.to_completion_request(RequestKind::Chat, cx)
- });
+ let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
assert_eq!(request.messages.len(), 2);
let expected_full_message = format!("{}Please explain this code", expected_context);
@@ -2367,9 +2372,7 @@ fn main() {{
assert!(message3.context.contains("file3.rs"));
// Check entire request to make sure all contexts are properly included
- let request = thread.update(cx, |thread, cx| {
- thread.to_completion_request(RequestKind::Chat, cx)
- });
+ let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
// The request should contain all 3 messages
assert_eq!(request.messages.len(), 4);
@@ -2419,9 +2422,7 @@ fn main() {{
assert_eq!(message.context, "");
// Check message in request
- let request = thread.update(cx, |thread, cx| {
- thread.to_completion_request(RequestKind::Chat, cx)
- });
+ let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
assert_eq!(request.messages.len(), 2);
assert_eq!(
@@ -2439,9 +2440,7 @@ fn main() {{
assert_eq!(message2.context, "");
// Check that both messages appear in the request
- let request = thread.update(cx, |thread, cx| {
- thread.to_completion_request(RequestKind::Chat, cx)
- });
+ let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
assert_eq!(request.messages.len(), 3);
assert_eq!(
@@ -2481,9 +2480,7 @@ fn main() {{
});
// Create a request and check that it doesn't have a stale buffer warning yet
- let initial_request = thread.update(cx, |thread, cx| {
- thread.to_completion_request(RequestKind::Chat, cx)
- });
+ let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
// Make sure we don't have a stale file warning yet
let has_stale_warning = initial_request.messages.iter().any(|msg| {
@@ -2511,9 +2508,7 @@ fn main() {{
});
// Create a new request and check for the stale buffer warning
- let new_request = thread.update(cx, |thread, cx| {
- thread.to_completion_request(RequestKind::Chat, cx)
- });
+ let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
// We should have a stale file warning as the last message
let last_message = new_request
@@ -74,6 +74,10 @@ pub enum Model {
}
impl Model {
+ pub fn default_fast() -> Self {
+ Self::Claude3_5Haiku
+ }
+
pub fn from_id(id: &str) -> Result<Self> {
if id.starts_with("claude-3-5-sonnet") {
Ok(Self::Claude3_5Sonnet)
@@ -84,6 +84,10 @@ pub enum Model {
}
impl Model {
+ pub fn default_fast() -> Self {
+ Self::Claude3_5Haiku
+ }
+
pub fn from_id(id: &str) -> anyhow::Result<Self> {
if id.starts_with("claude-3-5-sonnet-v2") {
Ok(Self::Claude3_5SonnetV2)
@@ -61,6 +61,10 @@ pub enum Model {
}
impl Model {
+ pub fn default_fast() -> Self {
+ Self::Claude3_7Sonnet
+ }
+
pub fn uses_streaming(&self) -> bool {
match self {
Self::Gpt4o
@@ -64,6 +64,10 @@ pub enum Model {
}
impl Model {
+ pub fn default_fast() -> Self {
+ Model::Chat
+ }
+
pub fn from_id(id: &str) -> Result<Self> {
match id {
"deepseek-chat" => Ok(Self::Chat),
@@ -1,4 +1,4 @@
-use agent::{RequestKind, ThreadEvent, ThreadStore};
+use agent::{ThreadEvent, ThreadStore};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ToolWorkingSet;
use client::proto::LspWorkProgress;
@@ -472,7 +472,7 @@ impl Example {
thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(this.prompt.clone(), context, None, cx);
- thread.send_to_model(model, RequestKind::Chat, cx);
+ thread.send_to_model(model, cx);
})?;
event_handler_task.await?;
@@ -412,6 +412,10 @@ pub enum Model {
}
impl Model {
+ pub fn default_fast() -> Model {
+ Model::Gemini15Flash
+ }
+
pub fn id(&self) -> &str {
match self {
Model::Gemini15Pro => "gemini-1.5-pro",
@@ -49,6 +49,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
Some(Arc::new(FakeLanguageModel::default()))
}
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(Arc::new(FakeLanguageModel::default()))
+ }
+
fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
vec![Arc::new(FakeLanguageModel::default())]
}
@@ -370,6 +370,7 @@ pub trait LanguageModelProvider: 'static {
IconName::ZedAssistant
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
+ fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
Vec::new()
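A minimal sketch of what implementing the new required method can look like for a provider with no separate cheap tier (stand-in trait and types, not the real gpui-based signatures); this mirrors how the Ollama and LM Studio providers below simply reuse their default model:

    use std::sync::Arc;

    // Stand-in trait objects; the real methods also take `&App`.
    trait LanguageModel {}
    trait LanguageModelProvider {
        fn default_model(&self) -> Option<Arc<dyn LanguageModel>>;
        fn default_fast_model(&self) -> Option<Arc<dyn LanguageModel>>;
    }

    struct LocalModel;
    impl LanguageModel for LocalModel {}

    struct LocalProvider;
    impl LanguageModelProvider for LocalProvider {
        fn default_model(&self) -> Option<Arc<dyn LanguageModel>> {
            Some(Arc::new(LocalModel))
        }
        fn default_fast_model(&self) -> Option<Arc<dyn LanguageModel>> {
            // No distinct fast model available, so fall back to the default.
            self.default_model()
        }
    }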
@@ -5,6 +5,7 @@ use crate::{
use collections::BTreeMap;
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::sync::Arc;
+use util::maybe;
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| LanguageModelRegistry::default());
@@ -18,6 +19,7 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
default_model: Option<ConfiguredModel>,
+ default_fast_model: Option<ConfiguredModel>,
inline_assistant_model: Option<ConfiguredModel>,
commit_message_model: Option<ConfiguredModel>,
thread_summary_model: Option<ConfiguredModel>,
@@ -202,6 +204,14 @@ impl LanguageModelRegistry {
(None, None) => {}
_ => cx.emit(Event::DefaultModelChanged),
}
+ self.default_fast_model = maybe!({
+ let provider = &model.as_ref()?.provider;
+ let fast_model = provider.default_fast_model(cx)?;
+ Some(ConfiguredModel {
+ provider: provider.clone(),
+ model: fast_model,
+ })
+ });
self.default_model = model;
}
@@ -254,21 +264,37 @@ impl LanguageModelRegistry {
}
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
+ #[cfg(debug_assertions)]
+ if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+ return None;
+ }
+
self.inline_assistant_model
.clone()
- .or_else(|| self.default_model())
+ .or_else(|| self.default_model.clone())
}
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
+ #[cfg(debug_assertions)]
+ if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+ return None;
+ }
+
self.commit_message_model
.clone()
- .or_else(|| self.default_model())
+ .or_else(|| self.default_model.clone())
}
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
+ #[cfg(debug_assertions)]
+ if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
+ return None;
+ }
+
self.thread_summary_model
.clone()
- .or_else(|| self.default_model())
+ .or_else(|| self.default_fast_model.clone())
+ .or_else(|| self.default_model.clone())
}
/// The models to use for inline assists. Returns the union of the active
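After this change the per-use-case accessors resolve in a fixed order: an explicitly configured model first, then (for thread summaries only) the provider's fast model, then the default model; in debug builds the ZED_SIMULATE_NO_LLM_PROVIDER environment variable forces all three accessors to return None. A small sketch of the thread-summary fallback chain with plain Options (hypothetical model names, not the registry API):

    // Hypothetical stand-ins for the three registry fields involved.
    fn thread_summary_model<'a>(
        configured: Option<&'a str>,
        default_fast: Option<&'a str>,
        default: Option<&'a str>,
    ) -> Option<&'a str> {
        configured.or(default_fast).or(default)
    }

    fn main() {
        // No explicit thread-summary model configured: the provider's fast model wins.
        assert_eq!(
            thread_summary_model(None, Some("claude-3-5-haiku"), Some("claude-3-7-sonnet")),
            Some("claude-3-5-haiku")
        );
        // Provider has no fast model: fall back to the default model.
        assert_eq!(
            thread_summary_model(None, None, Some("claude-3-7-sonnet")),
            Some("claude-3-7-sonnet")
        );
    }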
@@ -201,7 +201,7 @@ impl AnthropicLanguageModelProvider {
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
+ })
}
}
@@ -227,14 +227,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let model = anthropic::Model::default();
- Some(Arc::new(AnthropicModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- }))
+ Some(self.create_language_model(anthropic::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(anthropic::Model::default_fast()))
}
fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -286,6 +286,18 @@ impl BedrockLanguageModelProvider {
state,
}
}
+
+ fn create_language_model(&self, model: bedrock::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(BedrockModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ http_client: self.http_client.clone(),
+ handler: self.handler.clone(),
+ state: self.state.clone(),
+ client: OnceCell::new(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
}
impl LanguageModelProvider for BedrockLanguageModelProvider {
@@ -302,16 +314,11 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let model = bedrock::Model::default();
- Some(Arc::new(BedrockModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- http_client: self.http_client.clone(),
- handler: self.handler.clone(),
- state: self.state.clone(),
- client: OnceCell::new(),
- request_limiter: RateLimiter::new(4),
- }))
+ Some(self.create_language_model(bedrock::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(bedrock::Model::default_fast()))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -343,17 +350,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
models
.into_values()
- .map(|model| {
- Arc::new(BedrockModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- http_client: self.http_client.clone(),
- handler: self.handler.clone(),
- state: self.state.clone(),
- client: OnceCell::new(),
- request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
- })
+ .map(|model| self.create_language_model(model))
.collect()
}
@@ -242,7 +242,7 @@ impl CloudLanguageModelProvider {
llm_api_token: llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
+ })
}
}
@@ -270,13 +270,13 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
let llm_api_token = self.state.read(cx).llm_api_token.clone();
let model = CloudModel::Anthropic(anthropic::Model::default());
- Some(Arc::new(CloudLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- llm_api_token: llm_api_token.clone(),
- client: self.client.clone(),
- request_limiter: RateLimiter::new(4),
- }))
+ Some(self.create_language_model(model, llm_api_token))
+ }
+
+ fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ let llm_api_token = self.state.read(cx).llm_api_token.clone();
+ let model = CloudModel::Anthropic(anthropic::Model::default_fast());
+ Some(self.create_language_model(model, llm_api_token))
}
fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -70,6 +70,13 @@ impl CopilotChatLanguageModelProvider {
Self { state }
}
+
+ fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
+ Arc::new(CopilotChatLanguageModel {
+ model,
+ request_limiter: RateLimiter::new(4),
+ })
+ }
}
impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
@@ -94,21 +101,16 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let model = CopilotChatModel::default();
- Some(Arc::new(CopilotChatLanguageModel {
- model,
- request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>)
+ Some(self.create_language_model(CopilotChatModel::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(CopilotChatModel::default_fast()))
}
fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
CopilotChatModel::iter()
- .map(|model| {
- Arc::new(CopilotChatLanguageModel {
- model,
- request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
- })
+ .map(|model| self.create_language_model(model))
.collect()
}
@@ -140,6 +140,16 @@ impl DeepSeekLanguageModelProvider {
Self { http_client, state }
}
+
+ fn create_language_model(&self, model: deepseek::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(DeepSeekLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ }) as Arc<dyn LanguageModel>
+ }
}
impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
@@ -164,14 +174,11 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let model = deepseek::Model::Chat;
- Some(Arc::new(DeepSeekLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- }))
+ Some(self.create_language_model(deepseek::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(deepseek::Model::default_fast()))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -198,15 +205,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
models
.into_values()
- .map(|model| {
- Arc::new(DeepSeekLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
- })
+ .map(|model| self.create_language_model(model))
.collect()
}
@@ -150,6 +150,16 @@ impl GoogleLanguageModelProvider {
Self { http_client, state }
}
+
+ fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(GoogleLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
}
impl LanguageModelProviderState for GoogleLanguageModelProvider {
@@ -174,14 +184,11 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let model = google_ai::Model::default();
- Some(Arc::new(GoogleLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- }))
+ Some(self.create_language_model(google_ai::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(google_ai::Model::default_fast()))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -157,6 +157,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
self.provided_models(cx).into_iter().next()
}
+ fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ self.default_model(cx)
+ }
+
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
@@ -144,6 +144,16 @@ impl MistralLanguageModelProvider {
Self { http_client, state }
}
+
+ fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(MistralLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
}
impl LanguageModelProviderState for MistralLanguageModelProvider {
@@ -168,14 +178,11 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let model = mistral::Model::default();
- Some(Arc::new(MistralLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- }))
+ Some(self.create_language_model(mistral::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(mistral::Model::default_fast()))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -162,6 +162,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
self.provided_models(cx).into_iter().next()
}
+ fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ self.default_model(cx)
+ }
+
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
@@ -148,6 +148,16 @@ impl OpenAiLanguageModelProvider {
Self { http_client, state }
}
+
+ fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(OpenAiLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
}
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
@@ -172,14 +182,11 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let model = open_ai::Model::default();
- Some(Arc::new(OpenAiLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- }))
+ Some(self.create_language_model(open_ai::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(open_ai::Model::default_fast()))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -211,15 +218,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
models
.into_values()
- .map(|model| {
- Arc::new(OpenAiLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
- })
+ .map(|model| self.create_language_model(model))
.collect()
}
@@ -69,6 +69,10 @@ pub enum Model {
}
impl Model {
+ pub fn default_fast() -> Self {
+ Model::MistralSmallLatest
+ }
+
pub fn from_id(id: &str) -> Result<Self> {
match id {
"codestral-latest" => Ok(Self::CodestralLatest),
@@ -102,6 +102,10 @@ pub enum Model {
}
impl Model {
+ pub fn default_fast() -> Self {
+ Self::FourPointOneMini
+ }
+
pub fn from_id(id: &str) -> Result<Self> {
match id {
"gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),